【PyTorchチュートリアル】WHAT IS TORCH.NN REALLY?の5、Validation setの追加など

AIプログラミング

ニューラルネットでValidationデータの利用方法がよくわかりました。

一つづつ実装していくとどのようになっているかがわかるのが良いね。

PyTorchチュートリアル「WHAT IS TORCH.NN REALLY?」の「Add validation」から行なっていきます。

こんな人の役に立つかも

・機械学習プログラミングを勉強している人

・Pytorchのチュートリアルを行なっている人

スポンサーリンク

validationの追加

一番最初のpickelデータの読み込みのところで、「x_valid」「y_valid」というデータも読み込んでいましたが、利用していませんでした。検証データという名前でこのセクションから利用することになります。

(validationは、検証データと言って、訓練データに利用していないデータでモデルの損失を計算することで、過学習がどうかを見分けることができるようになります。)

前のセクションでは、訓練ループのためにデータを使うだけでした。実際には、モデルが過学習していないかを判断するために「Validation set」も作成する必要があります。

バッチと過学習の相関関係を防ぐために、データをシャッフルすることも重要です。一方で、検証セット(validation set)のシャッフルをしても、モデルから出力される損失は同じになるので、意味がありません。

検証セット(validation set)には、トレーニングセットの2倍のバッチサイズを使用します。これは、検証セットがバックプロパゲーションを必要としないため、メモリ消費が少ないためです(勾配計算を保存する必要はありません)。これを利用して、より大きなバッチサイズを使用し、損失をより迅速に計算します。

・訓練データ:データをシャッフル

・検証データ:データをシャッフルしない、バッチサイズ2倍、勾配計算なし

#訓練データ
train_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)

#検証データ
valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=bs * 2)

各エポックの終わりに検証データでの損失を計算し、表示します。

トレーニングの前に常にmodel.train()を呼び出し、推論の前にmodel.eval()を呼び出すことに注意してください。これらはnn.BatchNorm2dやnn.Dropoutなどのレイヤーによって使用され、これらの異なるフェーズに対する適切な動作を保証します。

model, opt = get_model()

#現在2エポックで検証中
for epoch in range(epochs):
    #モデルを訓練モードに
    model.train()
    for xb, yb in train_dl:
        pred = model(xb)
        loss = loss_func(pred, yb)

        loss.backward()
        opt.step()
        opt.zero_grad()

    #モデルを検証モードに
    model.eval()
    #検証データで損失を求めます。
    with torch.no_grad():
        valid_loss = sum(loss_func(model(xb), yb) for xb, yb in valid_dl)

    #検証データでの損失の表示
    print(str(epoch + 1) + "エポック目")
    print("Loss:{:.4f}" .format(valid_loss.item() / len(valid_dl)))

表示の部分は見やすいように変更しております。

1エポック目
Loss:0.3098
2エポック目
Loss:0.3018

Create fit()とget_data()

損失計算という観点から、訓練データと検証データで2回同じことをしているので「loss_batch」という関数を作成してまとめます。

def loss_batch(model, loss_func, xb, yb, opt=None):
    loss = loss_func(model(xb), yb)

    if opt is not None:
        loss.backward()
        opt.step()
        opt.zero_grad()

    return loss.item(), len(xb)

「opt」という引数がポイントで、トレーニングデータの損失を計算するときは、optにオプティマイザを渡します。それによって、訓練データの場合には、そのまま逆伝播、重みの更新、初期化と処理が続きますが、検証データの場合はモデルの損失計算のみを行うような関数となっています。

ここで、訓練の関数を「fit」としてまとめます。検証データで損失を計算する機能も入りました。

import numpy as np

def fit(epochs, model, loss_func, opt, train_dl, valid_dl):
    for epoch in range(epochs):
        #訓練データ
        model.train()
        for xb, yb in train_dl:
            #ここでloss_batch
            loss_batch(model, loss_func, xb, yb, opt)

        #検証データ
        model.eval()
        with torch.no_grad():
            losses, nums = zip(
                #ここでloss_batch
                *[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl]
            )
        val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)

        print(str(epoch + 1) + "エポック目")
        print("Loss:{:.4f}" .format(val_loss))

検証データの損失計算部分はわかりにくいです。まずは、loss_batch周りがどのようなデータを吐き出しているかを確認してみます。ここでは、Python3の「リスト内包表記」という方法でloss_batchを回しています。リスト内包表記では、式 + forループという書き方をします。そこで、最終的に以下のプログラムで表示するデータが得られます。リストの中に、loss_batchの返り値2つが並んでいる形です。

リスト形式で保存されているので、値としてloss_batchの答えを1組づつを取り出すため、アスタリスクでアンパックを行います。そして最後にzip関数で1組ごとに取り出してlossesには損失値を、numsにはデータ数を入れていきます。

以下はデータの流れを追うためだけに個人的に追加したプログラムなので、わかる人は無視してください。

#リスト内包表記で出力されるデータ
print([loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl])

losses_test, nums_test = zip(
                *[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl]
            )

#lossesに格納されるデータ
print(losses_test)
#numsに格納されるデータ
print(nums_test)
[(0.33914339542388916, 128), (0.46684250235557556, 128), (0.4368799030780792, 128・・・略
(0.33914339542388916, 0.46684250235557556, 0.4368799030780792,・・・略
(128, 128, 128, 128, 128, 128, 128, 128・・・略

これで、データローダーを取得して、モデルを訓練させる流れを3行で記載できるようになりました。

train_dl, valid_dl = get_data(train_ds, valid_ds, bs)
model, opt = get_model()
fit(epochs, model, loss_func, opt, train_dl, valid_dl)
1エポック目
Loss:0.3033
2エポック目
Loss:0.2809

GitHubにプログラムを配置していますので、ご利用ください。

タイトルとURLをコピーしました