【Pytorch×Flask】Pytorchのflaskチュートリアル、プログラムの内容を理解1

久しぶりのPytorch、復習しながら勉強を進めたいです。

復習しながら、プログラムの構成をみていきましょう~

Pytorchのチュートリアルでflaskで画像分類アプリを作成しています。公式ドキュメントはこちらのページです。

前回は、とりあえずPytorchとFlaskの画像分類アプリを動作させました。これで、大体どんな動作をするのかがわかりました。

ぱんだクリップ
【Pytorch×Flask】Pytorchのflaskチュートリアル、画像分類アプリを動作させる | ぱんだクリップ Pytorchのチュートリアルにflaskと連携するものがありました。 とても実践的なサンプルだね〜 Pytorchのチュートリアルを久しぶりに散策していたら、flaskを利用してPytorc...

久しぶりの機械学習でしたので、最初はPytorchのプログラムをみてもピンときませんでした^^;ちなみに、今回のプログラムは、DenseNetや転移学習など、過去に勉強したものがふんだんに利用されているだけでした・・・記憶力とははかないものです(´· ·`)

こんな人の役に立つかも

・画像分類アプリを作成したい人

・flaskとPytorchでWebアプリを作成したい人

・Pytorchを勉強している人

目次

Pytorchを中心にプログラムを理解

import

torchvisionは、pytorchに実装されている画像関係のモジュールです。PILはPythonの画像処理モジュールで、トリミングなど、シンプルな処理ができます。flaskはWebフレームワークですね。

import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request

モデルの定義

今回、分類器としてはtorchvisionにあらかじめ学習済みの「DenseNet121」を利用しています。

DenseNetの公式ドキュメントはこちらにあります。

PyTorch
Densenet import torch model = torch.hub.load('pytorch/vision:v0.10.0', 'densenet121', pretrained=True) # or any of these variants # model = torch.hub.load('pytorch/visio...

DenseNet121は、224×224サイズ3チャンネル画像を入力とする畳み込みニューラルネットワークモデルです。「pretrained=True」として呼び出すことで、あらかじめ学習された1000クラスの分類を行うことができます。

#densenet121モデルを読み込みます
model = models.densenet121(pretrained=True)
#検証モード
model.eval()

Pytorchでは、autogradという機能があり、tensorの計算を記録することで勾配を簡単に求めることができました。しかし、分類器として利用する際は、勾配を求める必要がないので、evalとすることで検証用の状態に変更できます。ニューラルネットのモデルには、「訓練する状態」と訓練した状態のネットワークを「使う状態」の2状態があって、evalで「使う状態」にできる、というようなイメージです。

ニューラルネットの内部で勾配がどのように利用されているかなどは、こちらのチュートリアルをやると理解が深まると思います。(ちょっと時間がかかりますが^^;)

ぱんだクリップ
【PyTorchチュートリアル】WHAT IS TORCH.NN REALLY?の1、MNISTデータをセットアップ | ぱんだクリップ PyTorchのチュートリアルをやるとすごく理解が深まります。 PyTorchにはまってますね。 PyTorchのチュートリアルを順番に進めていっています。次のチュートリアルは、「WHA...

入力画像の変換処理

torchvisionのtransformsを利用して、入力画像をニューラルネットの入力サイズに変換する関数です。

Composeのメソッドが一連の変換処理を実行してくれます。一連の変換処理は「input_transforms」というリスト形式で表現しています。一連の処理として「画像を255×255に大きさを拡大縮小」→「画像を中心から224×224を切り取り」→「Tensor型に変換」→「RGBチャンネルをそれぞれ平均,標準偏差の順番の引数でNormalize処理」という流れです。

def transform_image(infile):
    #入力画像への「一連の処理」を定義
    input_transforms = [transforms.Resize(255),#画像サイズの変更
        transforms.CenterCrop(224),#画像を中心から224×224に切り取り
        transforms.ToTensor(),#Tensor型に変換
        transforms.Normalize([0.485, 0.456, 0.406],
            [0.229, 0.224, 0.225])]#Normalize(平均,標準偏差)を行います。
    #一連の処理をmy_transformsとします。
    my_transforms = transforms.Compose(input_transforms)
    #入力画像をPILとして読み込みます。
    image = Image.open(infile)
    #timgとして画像を変換して格納しています。
    timg = my_transforms(image)
    #unsqueezedeで2次元画像を横並びのデータに変換
    timg.unsqueeze_(0)
    return timg

最後のunsqueezeでは、バッチ形式にデータを加工しています。バッチというのは、ニューラルネットワークへ与える画像のセットで、一つのバッチに何枚もの画像データが入っています。1枚だけ与える場合でも1枚だけ画像が入ったバッチにしなければいけません。

squeezeする前の画像は、次の図のように、RGBの画像が3チャンネル(②のくくりが3個)と、画像全体としてのくくりである①が一つですが、この状態ではニューラルネットは受け付けてくれません。

squeezeは指定の次元を挿入するということで、0番目、下の図で言う③のようなくくりを追加することで、画像枚数をくくるような次元を追加することになります。

また、squeeseの後の「_」によって、同じtimgを上書きして利用するように指定しています。Pytorchのテンソルは計算の過程が記録されていきますので、このように変数を上書きするときも明示的に「_」で明記しないといけないんですね。

予測をする関数

get_predictionはニューラルネットワークにデータを入れて順伝播させ、予測値を得る関数です。

シンプルに、モデルのforwardメソッドを実行しています。得られた出力から、最も値が大きいクラスを選択しています。得られた答えはテンソル型のため、数値にするためにitemメソッドで変換をしています。

def get_prediction(input_tensor):
    outputs = model.forward(input_tensor)
    _, y_hat = outputs.max(1)
    #itemメソッドは、テンソルから数値を取得するメソッド
    prediction = y_hat.item()
    return prediction

クラス名を取得する関数

render_prediction関数は、ニューラルネットから得られたクラス番号のラベルを探して返す関数です。DenseNet121は、1000個のクラス分類ができ、答えとしては0~999の数値が返ってきます。その数値をあらかじめjson形式で対応付けしたラベル一覧「index_to_name.json」を参照して答えの単語を返すというような処理を行います。

def render_prediction(prediction_idx):
    #入力としてニューラルネットの答えを受けます。そして文字列に変換します。
    stridx = str(prediction_idx)
    class_name = 'Unknown'
    #img_calss_mapをループして、一致するラベルを検索します。
    if img_class_map is not None:
        if stridx in img_class_map is not None:
            class_name = img_class_map[stridx][1]

    return prediction_idx, class_name

ちなみに、ループ内で利用されるimg_class_mapは、プログラム冒頭処理で次のように「index_to_name.json」から読み込まれているラベル名称一覧になります。

img_class_map = None
mapping_file_path = 'index_to_name.json'#index_to_name.jsonからクラス名を取得
if os.path.isfile(mapping_file_path):
    with open (mapping_file_path) as f:
        img_class_map = json.load(f)

これで、ニューラルネットが動作する関数は理解できました。

次回は、ニューラルネットのそれぞれの部品を組み合わせて動作のトリガーなどを定義してるflask周りのプログラム構成についてみていきたいと思います。

よかったらシェアしてね!
  • URLをコピーしました!
  • URLをコピーしました!
目次