転移学習の醍醐味であるモデルの固定を行いました。
画像分類のモデルにいろいろ応用できそうだね。
PyTorchのチュートリアル「TRANSFER LEARNING FOR COMPUTER VISION TUTORIAL」のモデルを固定する部分から進めていきます。モデルを固定するという点でも、PyTorchnのテンソルのAutoGradの機能が活躍しています。
こんな人の役に立つかも
・機械学習プログラミングを勉強している人
・PyTorchのチュートリアルを勉強している人
・PyTorchで転移学習を勉強している人
重みを固定したConvNetで訓練する
ここまでチュートリアルで行ってきたモデルでは、ResNet18をベースにして、そこに全結合層を追加して利用するというもので、ResNet18の部分も訓練していたようです。
次にチュートリアルでは、モデルのパラメータを固定し、追加した層のみ訓練するような方法を見ていきます。
モデルのパラメータを固定する方法としては、テンソルのAutoGradをオフにすることで実現できるようです。テンソルのRequired_gradをFalseにしておくということです。
モデルのパラメータはテンソルとして格納されていますので、具体的にはパラメータを格納しているテンソルのRequired_gradフラグをFalseにすることでモデルを固定します。ResNet18のモデル部分のAutoGradをFalseにして、全結合層を追加することで、全結合層の部分のAutoGradはTrueの状態となり、勾配の更新も一部のみとなります。
AutoGradの詳細な動作については、こちらのページに詳しく記載されているようです。
#「model_conv」にresnet18を格納
model_conv = torchvision.models.resnet18(pretrained=True)
#①for文でパラメータを格納しているテンソルのAutoGradフラグをFalseに
for param in model_conv.parameters():
param.requires_grad = False
#新たに作成されるモジュール(nn.Linear)は初期状態でRequired_gradがTrue
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 2)
#モデルのデバイスを指定
model_conv = model_conv.to(device)
#損失関数をクロスエントロピーに設定
criterion = nn.CrossEntropyLoss()
#②最適化手法にMomentumSGDを設定していますが、与えるパラメータは追加した層のみとしています。
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)
#学習率を減衰させるためのschedulerの設定を行います。
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)
①resnet18モデルの固定
for文でモデルのパラメータ(重み、バイアス)を格納しているテンソルのAutoGradをFalseにします。これによって既存モデルのresnet18のパラメータの勾配を求めず、resnet18は固定されることになります。
一方で、resnet18に追加する全結合層(nn.Linear)は追加した時点でRequired_gradがTrueとなっています。そのため、このまま訓練させることでこの層のパラメータは勾配が計算されることになります。
②最適化手法の設定
最適化手法に与えるパラメータとして、追加した全結合層のみのパラメータを与えていることも注意点です。
そして、「model_conv」を訓練させていきます。
model_conv = train_model(model_conv, criterion, optimizer_conv,
exp_lr_scheduler, num_epochs=25)
Epoch 0/24
----------
train Loss: 0.6817 Acc: 0.6352
val Loss: 0.2115 Acc: 0.9281
略
Epoch 24/24
----------
train Loss: 0.3147 Acc: 0.8566
val Loss: 0.2180 Acc: 0.9281
Training complete in 1m 34s
Best val Acc: 0.954248
チュートリアルのまとめ
resnet18は色々な画像を訓練させているようで、今回の「アリ」や「ミツバチ」の画像分類のように、訓練させる画像を変化させることで色々な画像分類アプリケーションに応用できそうです。
転移学習という項目を学び、最近画像分類関連のアプリが多く出てきている意味も少し納得できたような気がします。