【AIプログラミング】PyTorchで畳み込みニューラルネットワーク、チュートリアル Part4-1

AIプログラミング

PyTorchの60MinuteBlitzのチュートリアルも、最後のTraining a Classifierに入りました。

やっとここまできたね。60分どころじゃなかったね。

やっとPyTorchチュートリアル「A 60 MINUTE BLITZ」も最後のチュートリアルにきました。今まで構築したニューラルネットワークの知識で、「CIFAR10」という画像分類の課題をおこなっていくような内容となっています。今回はCIFAR10というデータセットをPyTorchでロードして確認する、というところまでを完了しました。

このチュートリアルをやっています。

こんな人の役に立つかも

・機械学習プログラミングの勉強をしている人

・PyTorchのチュートリアルをやっている人

・PyTorchで畳み込みニューラルネットワークを構築したい人

スポンサーリンク

データの扱いについて

画像、テキストファイル、オーディオファイル、ビデオなどをPythonプログラムで扱うときには、配列としてデータをロードします。これらのデータは、PyTorchのテンソル型に変換することもできます。

一般的に、次のようなライブラリはデータの扱いを簡単にしてくれます。

画像:「Pillow」や「OpenCV」

音声:「scipy」や「librosa」

テキスト:Pythonに組み込まれている関数やCython、「NLTK」や「SpaCy」

PyTorchでは、画像処理に対しては「torchvision」というパッケージがあります。これは、一般的なデータセット(MNISTやCIFAR10など)のローダーを備えていたり、「torchvision.datasets」や「torch.utils.data.DataLoader」といった機能でより便利になっているとのことです。

今回のチュートリアルでは、torchvisionからCIFAR10を呼び出すという内容になっています。

シーファーテン、胃に優しそうな名前。

それは、ガ●ターだね・・・

CIFAR10概要

10クラス分類問題のための画像データセットになります。

答えのクラスとして「airplane」「automobile」「bird」「cat」「deer」「dog」「frog」「horse」「ship」「truck」 の種類があります。

画像のデータは「3×32×32」となっています。最初の3は、チャンネル数で、RGBの3チャンネル分の画像データを持っていることになります。

テンソルとしては3階のテンソルで表現ができます。

データ数は、訓練データとして「60000枚」テストデータとして「10000枚」が用意されています。

チュートリアル全体の流れ

最後の画像分類チュートリアルは、次のような流れで行われます。

1.データ読み込みと前処理 torchvisionを利用してCIFAR10の訓練データとテストデータを読み込み、データを正規化します。

2.畳み込みニューラルネットワークを定義

3.損失関数を定義

4.訓練データでネットワークを訓練

5.テストデータで評価

本記事では、まず「1」の部分を行います。

CIFAR10データの読み込みとNormalize

※標準化、正規化など日本語だとややこしい感じがするので、Normalizeとしています。-1~1の範囲にデータをスケーリングするような処理です。以前、scikit-klearnで勉強した標準化、standardizationとは違う処理です。

CIFAR10のデータは、torchvisionを利用するとめちゃくちゃ楽にロードできる、らしいです。

import torch
import torchvision
import torchvision.transforms as transforms

torchvisionからロードするデータは、「0から1」の範囲の値のデータとなっています。テンソル型のデータとして読み込む際に、Normalize、「-1から1」の範囲にNormalize、も同時に行います。

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

transformという変換するためのライブラリからComposeを呼び出すことで、いくつかの画像処理をまとめて行うことができます。一つ目は、テンソルに変換「ToTensor」、二つ目が「Normalize」の処理です。 Normalizeは、引数に(平均,標準偏差)を与えるのですが、今回RGBの3チャンネルの画像のため、平均が(0.5,0.5,0.5)のようにチャンネル分指定してあります。また、標準偏差も同様に(0.5,0.5,0.5)としてあります。この辺りのNormalizetionについてはもう少し別で調べる必要がありそうですが、今回は0から1のデータを-1から1にスケーリングする用途ということで、このパラメータを定型化できそうなので、ここまでにしておきます。

PyTorchのドキュメント

まずはCIFAR10の訓練データセットをダウンロードして、Normalizeなど変換処理をおこなったものをtrainsetに格納です。 そして、ついでにバッチ単位にデータを分割しておきます。バッチ単位にまとめるのは、「DataLoader」の部分が行ってくれます。

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

上のプログラムを実行してデータをダウンロードすると、colaboの左の方の「ファイル」に、以下のようにCIFAR10のファイルが追加されます。

テストデータも同様に読み込んでおきます。

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

答えクラスのラベルを次のようにして格納しておきます。

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

CIFAR10の画像を確認

ダウンロードしたCIFAR10データセットのいくつかを表示して確認します。 画像の表示には、matplotlibを利用します。

import matplotlib.pyplot as plt
import numpy as np

# functions to show an image
def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

# get some random training images
dataiter = iter(trainloader)
print(dataiter)
images, labels = dataiter.next()

# show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))

まずは、imshow関数の定義です。

imgのデータはNormalizeされているデータを受け取る予定です。 受け取ったデータを、0から1の範囲のデータに戻しています。そして、numpyのデータとして(ndarray型)npimgに格納します。

matplotlibのimshowは、引数として最後にチャンネル数を取るため、numpyのtransposeにて「32×32×3」という順番の多次元配列に変換しています。 最後にshowで画像表示です。

次に、Pythonのiter()「イテレータ」でtrainloaderのバッチを一つ読み込んでいます。1つのバッチには4枚の画像がありますので、nextでたどると4枚の画像とラベルが取り出せるようになってるようです。

最後に、CIFAR10のデータの数を確認しておきます。

print("訓練データの数")
print(len(trainset))
print("テストデータの数")
print(len(testset))
訓練データの数
50000
テストデータの数
10000

とりあえず、torchvisionからCIFAR10を読み込むことができました。

タイトルとURLをコピーしました