【PyTorchチュートリアル】TRANSFER LEARNING FOR COMPUTER VISION TUTORIAL-1、イントロ

転移学習に興味がありましたので、転移学習に関するチュートリアルを始めようと思います。

他のモデルを利用して訓練を簡単にするやつですね。

転移学習について勉強しようと思います。PyTorchには、「TRANSFER LEARNING FOR COMPUTER VISION TUTORIAL」という画像分類の畳み込みニューラルネットワークの転移学習のチュートリアルがありましたので、このチュートリアルをやっていこうと思います。転移学習についてはまだよくわかりませんので、どのようなものかを調べつつ、チュートリアルを進めていきたいと思います。

チュートリアルのサイトはこちらです。

こんな人の役に立つかも

・機械学習プログラミングを勉強している人

・PyTorchのチュートリアルをしている人

・転移学習の勉強をしている人

目次

転移学習について

転移学習は、「すでに訓練させたモデルを別のモデルに利用する」ことです。

転移学習(TRANSFER LEARNING)

画像識別などの畳み込みニューラルネットワークを訓練させるときに、とても多くのデータが必要になったり、訓練の時間やハードウェアのリソースが必要になったりします。全ての人がこのような作業を簡単にできるというものではなく、ある程度似たような問題に対してはすでに過去に作成したモデルを利用できないか、というような考え方から転移学習が利用されています。

例えば、自動運転などは、現実世界でデータを収集することは非常に難しいので、シミュレーションなどでデータを収集するということが考えられます。

また、猫の画像分類機をすでに作成している場合、それを犬の識別モデルを作成するときに転移学習を利用すれば、0から訓練のプロセスを行うことなく、より少ないデータ量で目的を達成することができそうです。

すでにあるモデルを転移できるという背景には、ニューラルネットワークが訓練によって共通の特徴を捉えているという性質があるようで、このような転移学習ができるらしいです。

転移学習の構成

転移学習の訓練では、すでに訓練させたモデルに対しては、手を加えず、出力層側に新しく層を追加して、追加した層を訓練していくようです。

すでに訓練したモデルの一部を訓練すると、「ファインチューニング」という違う手法になるようなので注意が必要です。

訓練済みモデルの特徴を生かして新しい問題に調整していくイメージですね。

チュートリアルの準備

チュートリアルでは、「アリ」と「ミツバチ」の画像を分類する畳み込みニューラルネットワークを訓練するようです。

画像はチュートリアルの「download here」のところからダウンロードできます。

このデータは、

・アリの訓練データ:124枚

・ミツバチの訓練データ:121枚

・アリの検証データ:70枚

・ミツバチの検証データ:83枚

から構成されています。

何万というデータと比較するととても少ないデータです。

今回のチュートリアルは、Google Colaboで実行していきます。

それでは、まずはimportからです。

# License: BSD
# Author: Sasank Chilamkurthy

from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy

plt.ion()   # interactive mode

あらかじめ先にダウンロードした「アリとミツバチのデータ」である「hymenoptera_data」フォルダをGoogleDriveにアップロードしておきます。今回私は、「My Drive」直下に配置しました。

次のプログラムでGoogleColaboがDriveにアクセスすることを許可しておきます。

from google.colab import drive 
drive.mount('/content/drive')

訓練データと検証データを読み込み、データローダーでミニバッチを作成します。

訓練データは、データ拡張といって、「RandomResizedCrop」と「RandomHorizontalFlip」の処理に通してからデータのNormalizeを行なっています。

# 訓練データはデータ拡張してNormalizeを行う。
# 検証データにはNormalizeのみをかける。
data_transforms = {
    'train': transforms.Compose([
        #ランダムなサイズ、                         
        transforms.RandomResizedCrop(224),
        #50:50の確率で水平に反転
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

data_dir = 'drive/My Drive/hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                             shuffle=True, num_workers=4)
              for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

最後に、matplotlibで画像が正常に読み込まれているかを確認しましょう。

def imshow(inp, title=None):
    """Imshow for Tensor."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated


# Get a batch of training data
inputs, classes = next(iter(dataloaders['train']))

# Make a grid from batch
out = torchvision.utils.make_grid(inputs)

imshow(out, title=[class_names[x] for x in classes])

これで、GoogleColaboでチュートリアルを実行できる準備が整いました。次から、チュートリアルを進めていきます。

続きの記事はこちらです。

ぱんだクリップ
【PyTorchチュートリアル】TRANSFER LEARNING FOR COMPUTER VISION TUTORIAL-2、転移学習 | ぱんだクリップ 今回のチュートリアルは説明が少ないですね。 いろいろ調べながら進めましょう。 今回は、「TRANSFER LEARNING FOR COMPUTER VISION TUTORIAL」チュートリアルのモデルを訓...
よかったらシェアしてね!
  • URLをコピーしました!
  • URLをコピーしました!
目次