【AIプログラミング】PyTorchで畳み込みニューラルネットワーク、チュートリアル Part4-2、CIFAR-10

PyTorchのチュートリアル「TRAINING A CLASSIFIER」の二回目です。今回は、訓練するところまで行きました。

少しづつ理解できてきたね。

PyTorchの「A 60 MINUTE BLITZ」チュートリアルの「TRAINING A CLASSIFIER」チュートリアルを進めました。個人的に不明な点や、その他、プログラムについてもメモしながら進めています。今回は、畳み込みニューラルネットワークの定義~訓練までを行っていきます。

前回の記事はこちらをご参考ください。

ぱんだクリップ
【AIプログラミング】PyTorchで畳み込みニューラルネットワーク、チュートリアル Part4-1、CIFAR-10 | ぱん... PyTorchの60MinuteBlitzのチュートリアルも、最後のTraining a Classifierに入りました。 やっとここまできたね。60分どころじゃなかったね。 やっとPyTorchチュートリアル...

こんな人の役に立つかも

・機械学習プログラミングを勉強している人

・PyTorchで畳み込みニューラルネットワークのプログラミングを勉強している人

・PyTorchのチュートリアルを行っている人

目次

ネットワーク定義~訓練

チュートリアルの全体像は次のような感じです。今回進めることができた項目は以下の通りとなります。

1.CIFAR10データの読み込みとNormalize

2.畳み込みニューラルネットワークを定義(今回)

3.損失関数を定義(今回)

4.訓練データでネットワークを訓練(今回)

5.テストデータで評価

畳み込みニューラルネットワークを定義

以前チュートリアルで定義したものよりシンプルになっています。

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        #①畳み込み層1の入力画像を3チャンネルに
        self.conv1 = nn.Conv2d(3, 6, 5)
        #②プーリング層として定義する
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        #このように畳み込み層とプーリング層を表現することもできる
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        #③全結合ニューラルネットワークに入力するためにデータをベクトル化
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()

①のところでは、「Conv1」畳み込み層1の入力を以前は1チャンネルとしていたところを3チャンネルと変更します。今回は1画像に対してRGBの3チャンネルの情報が存在します。

②では、「pool」としてプーリング層を定義しています。 以前は、max_pool2dメソッドを利用してforwardメソッド内でプーリングを実行していましたが、プーリング層として定義に入れることもできるようです。今回は2×2のフィルタをもつプーリング層を定義しています。

また、③では、全結合ニューラルネットワークへ入れるデータは横並びデータ(ベクトル)になっていないといけないので、これをviewメソッドを利用して行っています。

畳み込みニューラルネットワークの層の構成自体は同じですが、前回のチュートリアルよりシンプルな書き方となっています。

以前のチュートリアルよりスマートに定義できている気が・・・

3.損失関数を定義

pytorchライブラリのoptimを利用します。

import torch.optim as optim

#①損失関数(誤差関数)
criterion = nn.CrossEntropyLoss()
#②最適化手法はMomentum SGD
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

①にて、損失関数を「クロスエントロピー誤差関数」に設定します。

クロスエントロピー誤差関数は、過去にも多クラス分類のときによく利用する誤差関数、程度でしか紹介はしていませんが、今回もそのノリで進めさせていただきます。

そして、②では、最適化手法を「確率的勾配降下法(SGD)」と設定します。

SGDには学習率というハイパーパラメータを与えないといけませんので、lr=0.001として学習率を設定しています。SGDには、MomentumSGDという進化したものがあります。過去の更新した勾配も少しだけ影響させて重みの更新をおこなう、ようなイメージのSGDです。(だと思います、まだ詳しくは調査します。)その過去の勾配の影響具合をmomentumという係数で指定ている模様です。momentumを0にすると普通のSGDになります。

訓練データでネットワークを訓練

一番メインの訓練の処理です。

for epoch in range(2):  #①チュートリアルでは2エポック訓練を行う。

    running_loss = 0.0
    #②enumerateでiにリスト番号、dataに内容が入る。
    for i, data in enumerate(trainloader, 0):
        #data変数は[inputs, labels]のリストデータ。
        inputs, labels = data

        #勾配のパラメータを初期化
        optimizer.zero_grad()

        #順伝播し、逆伝播して重みを更新
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:    #ミニバッチ2000ごとに状況の表示
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')

①の最初のforループは、エポックのループです。ニューラルネットワークは「1エポックで1回、全データを使用」します。

②の内側のforループにてバッチ単位で重みを更新していく処理になります。trainloaderにはバッチ分割されたデータがはいっているので、これを読みだしてループを行います。

enumerate()はリストの添え字とデータを同時に取ってきます。第二引数は始まるインデックスNoを指定しますが、今回は0なのでtrainloaderの0番目のデータから取得してきます。

running_loss変数には、2000ミニバッチ分の正解との誤差が累積されていきます。

累積された誤差は、2000回のミニバッチの訓練が完了した際に、2000で割ることで2000回のミニバッチ訓練を1セットとした平均的な誤差が出力されます。2000個のミニバッチ毎にこれを行うことで、毎回誤差を出力するのではなく、一定間隔で誤差が小さくなってきているのかを確認することができます。

.item()メソッドは、テンソル型から一つだけ値を取り出すときに利用するのでした。このような利用方法もあるのですね。もちろん、runnning_loss変数は、2000回のミニバッチ完了毎に0にリセットします。

とりあえず一通り畳み込みニューラルネットワークが理解できるようになってきました。

かなりの進歩。

交差エントロピー誤差や、MomentumSGD等については、今後、余裕があり次第別途深堀していきたいなと思う次第であります。

訓練モデルの保存

PyTorchでは、次のように訓練モデルを保存できます。

PATH = './cifar_net.pth'
torch.save(net.state_dict(), PATH)

Google colaboでは左の「ファイル」欄に追加されています。

続きの記事はこちらになります。

ぱんだクリップ
【AIプログラミング】PyTorchで畳み込みニューラルネットワーク、チュートリアル Part4-3 | ぱんだクリップ これでPyTorchの60分チュートリアルも最後になります。 60分と言いつつもかなり色々予備知識が必要だったね。 今回で、やっとPyTorchのチュートリアルである「A 60 MINUTE ...
よかったらシェアしてね!
  • URLをコピーしました!
  • URLをコピーしました!
目次