optimモジュールで最適化をわざわざ記載しなくていい様になります。
色々な最適化手法も導入できる様になるよ。
PyTorchのチュートリアル「LEARNING PYTORCH WITH EXAMPLES」を少しづつ行なっていくというのも終盤にさしかかります。今回は、optimモジュールについてです。初めてPytorchでチュートリアルをする方は、チュートリアル1回目から行なった方が理解が深まると思います。
こんな人の役に立つかも
・機械学習プログラミングを勉強している人
・PyTorchで機械学習プログラミングを勉強している人
・PyTorchでニューラルネットワークを勉強している人
optimパッケージ
これまで、チュートリアルでは、学習可能なパラメーターを保持するテンソルを手動で変更することにより、モデルの重みを更新しました。テンソルをautogradを利用して直接計算したり、nnモジュールでの処理を行うとautogradの追跡が自動的に行われてしまいますので、今まで最適化(勾配降下法)を行うときはtorch.no_gradを利用して、重みの更新時の計算は明示的に追跡しない様にしていました。
これは、確率的勾配降下法などの単純な最適化アルゴリズムにとって大きな負担ではありませんが、AdaGrad、RMSProp、Adamなどのより高度な最適化手法を使用してニューラルネットワークをトレーニングすることがよくあります。
特に、Adamは使われているところをよく見ますね。
PyTorchのoptimパッケージで、最適化アルゴリズムの概念を抽象化し、一般的に使用される最適化アルゴリズムを利用することができます。
また、この様な最適化手法は、学術論文などで発表され、現在も研究、改善されている様な分野です。完全に理解して利用できると良いのですが、まずは実装ベースで考えるなら、optimモジュールに実装されている最適化の特徴を理解して利用するということが現実的です。
この例では、以前のようにnnパッケージを使用してモデルを定義しますが、optimパッケージによって提供されるAdamアルゴリズムを使用してモデルを最適化します。
チュートリアルプログラムの実装
前回のnnモジュールを利用した方法でニューラルネットワークの定義などを行なっています。
# -*- coding: utf-8 -*-
import torch
# N :バッチサイズ
# D_in :入力次元数
# H :隠れ層の次元数
# D_out:出力次元数
N, D_in, H, D_out = 64, 1000, 100, 10
# ランダムな入力データと出力データの作成
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)
# ニューラルネットワークと損失関数の定義
model = torch.nn.Sequential(
torch.nn.Linear(D_in, H),
torch.nn.ReLU(),
torch.nn.Linear(H, D_out),
)
#損失関数の定義
loss_fn = torch.nn.MSELoss(reduction='sum')
次のプログラムで、最適化手法として、Adamオプティマイザーを設定します。optimizerの設定では、引数として、モデルのパラメータを渡します。この様にして、更新する重みを渡しています。
learning_rate = 1e-4
#①最適化手法をAdamで設定して「optimizer」というインスタンスに
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for t in range(500):
# 順伝播
y_pred = model(x)
# 損失の計算
loss = loss_fn(y_pred, y)
if t % 100 == 99:
print(t, loss.item())
# ②「optimizer」のメソッドで勾配の初期化
optimizer.zero_grad()
# 逆伝播
loss.backward()
# ③「optimizer」のメソッドで重みの更新
optimizer.step()
今回の重要なポイントは、optimの利用方法です。実際にプログラム内では、「optimizer」としてインスタンス化されています。
①optimizerは、宣言時引数としてモデルのパラメータを渡しておきます。
②optimizerのzero_gradメソッドで勾配を初期化できます。
③optimizerのstepメソッドで、モデルのパラメータを更新します。
optimモジュールを導入してnnと連携することですごく抽象的に実装ができる様になりました。
こちらのGitHubのプログラムもご参考ください。GoogleColaboで実行ができます。
続きの記事はこちらです。