【PyTorchチュートリアル】WHAT IS TORCH.NN REALLY?の1、MNISTデータをセットアップ

PyTorchのチュートリアルをやるとすごく理解が深まります。

PyTorchにはまってますね。

PyTorchのチュートリアルを順番に進めていっています。次のチュートリアルは、「WHAT IS TORCH.NN REALLY?」というチュートリアルを行なっていきます。これは、機械学習の手書き数字分類で有名な「MNIST」データセットを分類するニューラルネットワークを作成する、という題材でチュートリアルを行なっています。

このチュートリアルでは、PyTorchのtensorについての基礎的な知識と、ニューラルネットワークについての基礎的な知識が必要になります。

・tensorについてはこちらの記事もご参考ください。

PyTorchのチュートリアルpart1-1、tensor型の計算など

PyTorchのチュートリアルPart1-2、tensor型の操作など

・ニューラルネットワークについてはこちらの記事もご参考いただけます。

ニューラルネットワークの要素技術について勉強①

ニューラルネットワークの要素技術について勉強②

こんな人の役にたつかも

・機械学習プログラミングを勉強している人

・PyTorchのチュートリアルをしている人

目次

MNISTデータセットのセットアップ

このチュートリアルでは、MNISTデータセットを利用して、基本的なニューラルネットワークを訓練するという内容らしいです。最初はPyTorchのテンソルの機能のみで作成して行くとのことです。まずは、チュートリアルで利用するデータセットを準備していきます。

MNIST

MNISTは機械学習の有名なデータセットで、0~9の手書き数字(白黒画像)のデータです。分類問題として、正しく手書き数字を分類できるネットワークの訓練を目指します。

次のプログラム(Python3標準ライブラリのpathlibを利用)で、リクエストを利用して、MNISTデータセットをダウンロードします。

GoogleColaboで実行する場合、次のプログラムでファイルの「data/mnist/」に追加されます。

from pathlib import Path
import requests

DATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"

PATH.mkdir(parents=True, exist_ok=True)

URL = "http://deeplearning.net/data/mnist/"
FILENAME = "mnist.pkl.gz"

if not (PATH / FILENAME).exists():
        content = requests.get(URL + FILENAME).content
        (PATH / FILENAME).open("wb").write(content)

ここでダウンロードしたデータセットはnumpy形式で、python固有のpickelで保存されています。そのため、次のプログラムでデータを取り出します。

gzipをopenしたのち、pickel形式のデータを読み込んでいます。

import pickle
import gzip

with gzip.open((PATH / FILENAME).as_posix(), "rb") as f:
        ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")

pickel、機械学習のデータはこの形式で保存されていることが多いのかな。よくみます。

データの確認と変換

この状態で、各画像は784個のデータの平坦化された行(1次元の配列)として保存されていますので、28×28の画像として変換する必要があります。

#確認用
print(x_train[0].shape)

1次元配列で784個のデータとなっています。

torch.Size([784])

画像として確認するために、一時的にreshapeメソッドで28×28の行列に変換してmatplotlibで表示してみます。

また、x_trainには50000個の手書き文字データが格納されています。

PyTorchでデータを利用するためにはnumpy配列から「torch.tensor」に変換しする必要があります。

import torch

x_train, y_train, x_valid, y_valid = map(
    torch.tensor, (x_train, y_train, x_valid, y_valid)
)
n, c = x_train.shape
x_train, x_train.shape, y_train.min(), y_train.max()
print(x_train, y_train)
print(x_train.shape)
print(y_train.min(), y_train.max())

やっとtorchが出てきました。

Pythonユーザーは当たり前かもしれませんが、map関数は、pythonの関数で、map(①関数,②処理する変数)と与えることで、②を引数として①に実行を行うような関数です。

今回は、map関数1行で4つのnumpy配列を一括でtensorに変換することができました。

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]]) tensor([5, 0, 4,  ..., 8, 4, 8])
torch.Size([50000, 784])
tensor(0) tensor(9)

今回は、MNISTデータセットを読み込み、データセットを準備しました。次回からニューラルネットワークの実装に入ります。

続きの記事はこちらです。

ぱんだクリップ
【PyTorchチュートリアル】WHAT IS TORCH.NN REALLY?の2、スクラッチでニューラルネットワーク | ぱんだク... スクラッチでニューラルネットを作成すると、データの流れの確認になります。 シンプルな例で仕組みを知ると応用できるのかもね。 PyTorchチュートリアル「WHAT IS TORCH.N...
よかったらシェアしてね!
  • URLをコピーしました!
  • URLをコピーしました!
目次