【PyTorchチュートリアル】WHAT IS TORCH.NN REALLY?の3、プログラムのリファクタリング1

1_プログラミング

前回のスクラッチコードをPyTorchのモジュールで置き換えていきます。ボリュームがあったので、今回は半分だけやりました。

リファクタリングしてプログラムが短くわかりやすくなるんだね。

今回は、プログラムのリファクタリングをしていくことになります。リファクタリング項目が少しボリュームありましたので、今回は「nn.Functional」「nn.Module」「nn.Linear」のところをやっていきます。

こんな人の役に立つかも

・機械学習プログラミングを勉強している人

・PyTorchのチュートリアルをしている人

スポンサーリンク

torch.nn.functionalを使う

いちばん簡単なステップとしては、手動で定義していた損失関数を「torch.nn.functional」で置き換えることです。「torch.nn.functional」は一般的に「F」という名前空間でimportされます。

損失関数として、「負の尤度関数(negative log likelihood)」または「尤度関数(log softmax)」を利用している場合、PyTorchでは「F.cross_entropy」関数でカバーしています。これを利用して、損失関数をモジュールを利用して簡潔に記述することができます。

Functionalは、「as F」とするのが一般的のようですね。

import torch.nn.functional as F

loss_func = F.cross_entropy

def model(xb):
    return xb @ weights + bias

スクラッチプログラムと同じ損失と同じ精度が算出されることを確認します。

print(loss_func(model(xb), yb), accuracy(model(xb), yb))
tensor(0.0804, grad_fn=<NllLossBackward>) tensor(1.)

nn.Moduleでリファクタリング

次に、訓練ループを整理するために、「nn.Module」と「nn.Parameter」を利用します。

nn.Moduleのサブクラスとしてクラスを定義することで利用します。サブクラスでは、「重み、バイアス」といったネットワークの訓練パラメータを保持し、順伝播をメソッドとして実装していきたいです。

nn.Moduleには、あらかじめ便利なメソッド

・Parameter()

・zero_grad()

が存在します。このnn.Moduleを継承することで、一括で重み、バイアスにアクセスできるようになります。

nn.Moduleはめちゃくちゃ便利ですね。

#nnをimport
from torch import nn

#nn.Moduleを継承したMnistLogisticクラス
class Mnist_Logistic(nn.Module):
    def __init__(self):
        super().__init__()
        #コンストラクタで重み、バイアスの初期化
        self.weights = nn.Parameter(torch.randn(784, 10) / math.sqrt(784))
        self.bias = nn.Parameter(torch.zeros(10))

    #forwardメソッドの定義
    def forward(self, xb):
        return xb @ self.weights + self.bias

クラスのコンストラクタでXavier initialisationとバイアスのゼロ初期化を行うようになりました。

また、forwardをクラスのメソッドとして定義しました。nn.Moduleのクラスを利用することで、テンソルの追跡も自動的に行われるので、手動で行っていた時のようにRequired_gradを気にする必要もなくなりました。

次のように、定義したモデルをインスタンス化します。

model = Mnist_Logistic()

以前の記載方法と同じ記載で損失を計算することができます。

nn.Moduleのオブジェクトは関数のように使用することができます。関数のように呼び出すことで、自動的にforwardメソッドを実行してくれるようになっています。

print(loss_func(model(xb), yb))
tensor(0.0821, grad_fn=<NllLossBackward>)

以前の訓練方法では、それぞれのパラメータ(重み、バイアス)を変数名で指定して更新を行い、それぞれのパラメータの勾配を明示的に初期化を行わなければいけませんでした。

#いままでの手動での勾配降下法の実装
#with torch.no_grad():
#    weights -= weights.grad * lr
#    bias -= bias.grad * lr
#    weights.grad.zero_()
#    bias.grad.zero_()

nn.Moduleクラスを導入することで、「model.parameters()」でモデル全体のパラメータにアクセスができ、「model.zero_grad() 」でモデル全体のパラメータの勾配を一括で初期化できるようになりました。これは、もっと複雑なネットワークを定義するときにとても役に立ちます。重みの更新忘れなどの人為的ミスも防ぐことになります。

#↓こうなります
#with torch.no_grad():
#    for p in model.parameters(): p -= p.grad * lr
#    model.zero_grad()

このparameters()とzero_grad()を使った勾配降下法で訓練関数fit()を作成してみます。

def fit():
    for epoch in range(epochs):
        for i in range((n - 1) // bs + 1):
            start_i = i * bs
            end_i = start_i + bs
            xb = x_train[start_i:end_i]
            yb = y_train[start_i:end_i]
            pred = model(xb)
            loss = loss_func(pred, yb)

            loss.backward()
            with torch.no_grad():
                for p in model.parameters():
                    p -= p.grad * lr
                model.zero_grad()

fit()

fit()で関数を呼び出して訓練です。そして、最後に損失を求めます。

print(loss_func(model(xb), yb))
tensor(0.0817, grad_fn=<NllLossBackward>)

nn.Linearでリファクタリング

次に、モデルの定義の中で、「self.weights」や「self.bias」と行った変数を定義して初期化していました。そして、「xb @ self.weights + self.bias」と順伝播の式を手動で定義していました。

中間層が増加すると、「weights」や「bias」と行った変数も中間層分増加することになり、順伝播の式もより複雑になっていきます。

今回のような線形な層の場合、nn.Linearを利用して、簡潔に記載することができます。線形の層以外の畳み込み層など一般的に利用されるものもpytorchに存在し、とても早く実装することができるようになります。

class Mnist_Logistic(nn.Module):
    def __init__(self):
        super().__init__()
        #入力が784次元、出力が10次元の線形の層を定義
        self.lin = nn.Linear(784, 10)

    def forward(self, xb):
        return self.lin(xb)

このクラスは、以前と同様のプログラムでインスタンス化して、損失を計算することができます。

model = Mnist_Logistic()
print(loss_func(model(xb), yb))
tensor(2.2910, grad_fn=<NllLossBackward>)

先ほどリファクタリングした訓練用の関数であるfitも同様に利用することができます。

fit()

print(loss_func(model(xb), yb))
tensor(0.0821, grad_fn=<NllLossBackward>)

かなりシンプルなプログラムで作成できるようになってきました。

GitHubにチュートリアルプログラムを配置します。随時更新していきます。

続きの記事はこちらです。

タイトルとURLをコピーしました