【PyTorchチュートリアル】WHAT IS TORCH.NN REALLY?の4、プログラムのリファクタリング2

今回のリファクタリングでかなりスッキリとしたプログラムになります。

PyTorchの便利さがよくわかりますね。

PyTorchのチュートリアル「WHAT IS TORCH.NN REALLY?」の「Refactor using optim」から引き続きリファクタリングのチュートリアルを進めました。チュートリアルページはこちらです。

こんな人の役に立つかも

・機械学習プログラミングを勉強している人

・PyTorchのチュートリアルを勉強している人

目次

optimでリファクタリング

Pytorchには、さまざまな最適化アルゴリズムを含むパッケージtorch.optimもあります。各パラメーターを手動で更新する代わりに、オプティマイザーのstepメソッドを利用できます。

手動でコーディングした最適化手順をそのまま置き換えてみます。

#以前の手動でのパラメータ更新
#with torch.no_grad():
#    for p in model.parameters(): p -= p.grad * lr
#    model.zero_grad()

これを単純に次のプログラムに置き換えます。

#opt.step()
#opt.zero_grad()

実際に修正したプログラムをみてみましょう。

モデルとオプティマイザを一度に作成するための小さな関数も「get_model」として追加作成します。

from torch import optim

#この関数を呼び出すと、モデルとオプティマイザのインスタンス化ができます。
def get_model():
    model = Mnist_Logistic()
    return model, optim.SGD(model.parameters(), lr=lr)

#ここで呼び出します。
model, opt = get_model()
print(loss_func(model(xb), yb))

#fit()という関数にまとめていた訓練部分
#今回はとりあえず普通に記載しています。
for epoch in range(epochs):
    for i in range((n - 1) // bs + 1):
        start_i = i * bs
        end_i = start_i + bs
        xb = x_train[start_i:end_i]
        yb = y_train[start_i:end_i]
        pred = model(xb)
        loss = loss_func(pred, yb)

        loss.backward()
        #重みの更新がこのように簡潔になる
        opt.step()
        opt.zero_grad()

print(loss_func(model(xb), yb))
tensor(2.3214, grad_fn=<NllLossBackward>)
tensor(0.0821, grad_fn=<NllLossBackward>)

Datasetでリファクタリング

PyTorchには抽象的なデータクラスがあります。

データセットは、__len____getitem__をインデックス付けする方法を持つものであればなんでもOKとのことです。 (ここは、カスタムDatasetを作成するチュートリアルが詳しいので、また追加で勉強する必要があります。ここでは、Datasetクラスで__len____getitem__が呼び出せるようになっているという理解で良いかと思います。DatasetのサブクラスとしてカスタムDatasetを作成する方法が<https://pytorch.org/tutorials/beginner/data_loading_tutorial.html>で確認できます。

PyTorchのTensorDatasetは、データセットをラップするテンソルです。これは、インデックス付けの長さと方法を定義することにより、テンソルの最初の次元に沿って反復、インデックス付け、およびスライスする方法も提供します。これにより、トレーニング中に同じ行で独立変数と従属変数の両方にアクセスしやすくなります。

ここは、TensorDatasetを利用することで、データと答えの紐付けができて管理しやすくなる、という理解で良いかと思います。

まずは、「TensorDataset」をimportしましょう。

from torch.utils.data import TensorDataset

x_trainとy_trainの両方を1つのTensorDatasetに組み合わせることができます。これにより、繰り返し処理やスライスが容易になります。

(チュートリアルの一番最初にpickelで取得したデータです。x_trainには画像データ、y_trainには答えデータが格納されています。)

#TensorDatasetで画像と答えラベルをまとめる
train_ds = TensorDataset(x_train, y_train)

以前のプログラムでは、ミニバッチで画像データ「x」と答えデータ「y」は別のデータとして紐づけていました。

#以前のミニバッチへデータを読み込むプログラム
#xb = x_train[start_i:end_i]
#yb = y_train[start_i:end_i]

次の1行にまとめます。

#xbとxyの読み込みを1行に
#xb,yb = train_ds[i*bs : i*bs+bs]

以下のように、訓練ループに組み込みます。

model, opt = get_model()

#訓練ループの修正
for epoch in range(epochs):
    for i in range((n - 1) // bs + 1):
        #このように1行で書くことができるように
        xb, yb = train_ds[i * bs: i * bs + bs]
        pred = model(xb)
        loss = loss_func(pred, yb)

        loss.backward()
        opt.step()
        opt.zero_grad()

print(loss_func(model(xb), yb))
tensor(0.0833, grad_fn=<NllLossBackward>)

このDatasetでは、プログラムの記載する分量としては増えているのですが、この後のDataLoaderを利用することでバッチ単位でのデータ管理がとても簡単になります。

DataLoaderでリファクタリング

PyTorchのDataLoaderはバッチの管理を行います。DataLoaderはDatasetから作ることができます。マニュアルでデータを切り取るようなプログラム「train_ds[i*bs : i*bs+bs]」を利用しなくても、DataLoaderが自動的にミニバッチのデータを与えてくれるようになります。

from torch.utils.data import DataLoader

train_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=bs)

DataLoaderには、TensorDatasetとバッチサイズを与えるだけになります。

以前のプログラムでは以下のようになっていました。

#for i in range((n-1)//bs + 1):
#    xb,yb = train_ds[i*bs : i*bs+bs]
#    pred = model(xb)

forループのバッチサイズ毎に割り算したり、めんどくさいイメージがありました。DataLoaderを利用すると次のようになります。

#for xb,yb in train_dl:
#    pred = model(xb)

DataLoaderまでを利用した訓練ループは次のようになります。

model, opt = get_model()

for epoch in range(epochs):
    for xb, yb in train_dl:
        pred = model(xb)
        loss = loss_func(pred, yb)

        loss.backward()
        opt.step()
        opt.zero_grad()

print(loss_func(model(xb), yb))
tensor(0.0814, grad_fn=<NllLossBackward>)

Pytorchのnn.Module、nn.Parameter、Dataset、DataLoaderのおかげで、トレーニングループが劇的に小さくなり、理解しやすくなりました。

GitHubへデータを更新しましたので、ご参考までにご利用ください。

GitHub
machine-learning/WHAT_IS_TORCH_NN_REALLY.ipynb at master · perfectpanda-works/machine-learning ML code. Contribute to perfectpanda-works/machine-learning development by creating an account on GitHub.

続きの記事はこちらです。

ぱんだクリップ
【PyTorchチュートリアル】WHAT IS TORCH.NN REALLY?の5、Validation setの追加など | ぱんだクリップ ニューラルネットでValidationデータの利用方法がよくわかりました。 一つづつ実装していくとどのようになっているかがわかるのが良いね。 PyTorchチュートリアル「WHAT IS...
よかったらシェアしてね!
  • URLをコピーしました!
  • URLをコピーしました!
目次