もう少し本格的なflaskアプリとして成り立たせたいです。
flask周りをちゃんと整えたいね!
せっかくflaskを勉強したので、flaskを本格的に利用したアプリに挑戦していきたいと思いました。Pytorchのflask画像分類アプリは、すごく基本的なflaskの使い方しかしていないので、flaskアプリで勉強した「アプリケーションファクトリ」や「blueprint」機能を利用したプロジェクト構成に改造してみたいと思いました。
※現状は未完状態で、jsonラベルの変換機能など、改造にいくつかの課題を残しました・・・–;
ということで、flaskプロジェクトの構成を改造していきたいと思います。
こんな人の役に立つかも
・flaskの勉強をしている人
・flaskのプロジェクト構成を勉強している人
・flaskアプリの開発規模を拡大したい人
アプリをモジュールの構成にする
まずは、flaskをモジュールの構成とするために、プログラムを入れるためのフォルダを作成しました。フォルダ名(アプリ名)を「densenet」としました。
「__init__.py」のアプリケーションファクトリ
「__init__.py」としてチュートリアルで勉強したアプリケーションファクトリのcreate_appmのプログラムを配置しました。
import os
from flask import Flask
def create_app(test_config=None):
# create and configure the app
app = Flask(__name__, instance_relative_config=True)
app.config.from_mapping(
SECRET_KEY='dev',
DATABASE=os.path.join(app.instance_path, 'flaskr.sqlite'),
)
if test_config is None:
# load the instance config, if it exists, when not testing
app.config.from_pyfile('config.py', silent=True)
else:
# load the test config if passed in
app.config.from_mapping(test_config)
# ensure the instance folder exists
try:
os.makedirs(app.instance_path)
except OSError:
pass
# a simple page that says hello
@app.route('/hello')
def hello():
return 'Hello, World!'
return app
フォルダと「__init__.py」はこのようになりました。
Blueprintのプログラム
Pytorch関連の機械学習プログラムを「classifier.py」というプログラムにまとめたいと思います。その際、flaskのblueprint機能でアプリケーションファクトリに機能を追加していく形を取っていきます。まずは次の基本的なblueprintのプログラムをclassifier.pyに記載します。
チュートリアルはこちらのページです。
import functools
from flask import (
Blueprint, flash, g, redirect, render_template, request, session, url_for
)
from werkzeug.security import check_password_hash, generate_password_hash
#blueprintをclassifierという名前で作成します。
bp = Blueprint('classifier', __name__)
このblueprintnをアプリケーションファクトリに入れます。「__init__.py」にblueprintを追加します。
#return app の前に以下のblueprint追加のプログラムを記載しました。
from . import classifier
app.register_blueprint(classifier.bp)
return app
この時点でのフォルダ構成は次のような感じです。
機械学習処理を追加
次のプログラムを「classifier.py」に追記しました。
以前の処理との変更点としては、
#以下でpytorch関連のプログラム処理を記載
import io
import json
import os
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request
#app = Flask(__name__)#ここはアプリケーションファクトリでアプリインスタンスを作成するため削除します。
model = models.densenet121(pretrained=True)
model.eval()
img_class_map = None
#カレントディレクトリに「index_to_name.json」を配置しておきます。
mapping_file_path = 'index_to_name.json'
if os.path.isfile(mapping_file_path):
with open (mapping_file_path) as f:
img_class_map = json.load(f)
def transform_image(infile):
input_transforms = [transforms.Resize(255),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])]
my_transforms = transforms.Compose(input_transforms)
image = Image.open(infile)
timg = my_transforms(image)
timg.unsqueeze_(0)
return timg
def get_prediction(input_tensor):
outputs = model.forward(input_tensor)
_, y_hat = outputs.max(1)
prediction = y_hat.item()
return prediction
def render_prediction(prediction_idx):
stridx = str(prediction_idx)
class_name = 'Unknown'
if img_class_map is not None:
if stridx in img_class_map is not None:
class_name = img_class_map[stridx][1]
return prediction_idx, class_name
#blueprintのルーティング:app→bpに変更
@bp.route('/', methods=['GET'])
def root():
return jsonify({'msg' : 'Try POSTing to the /predict endpoint with an RGB image attachment'})
@bp.route('/predict', methods=['POST'])
def predict():
if request.method == 'POST':
file = request.files['file']
if file is not None:
input_tensor = transform_image(file)
prediction_idx = get_prediction(input_tensor)
class_id, class_name = render_prediction(prediction_idx)
return jsonify({'class_id': class_id, 'class_name': class_name})
appインスタンスは、「__init__.py」で作成しているので、コメントアウトしました。
「index_to_name.json」は階層を移動させておきました。
ルーティングは、blueprintを利用するので、appからbpへと変更をしております。
動作検証
アプリを動作させてみます。venv仮想環境に入り、flaskをデバッグ用のWebサーバーに立ち上げます。
※環境によっては、Cythonがインストールされていない、みたいなメッセージが出ましたので、pip install Cythonで仮想環境にpipする必要があるかもしれません。私はwindowsで試したら出ました。
mac
$. venv/bin/activate
(venv)$ export FLASK_APP=densenet
(venv)$ export FLASK_ENV=development
(venv)$ flask run
windows
>venv¥Scripts¥activate
(venv)> set FLASK_APP=densenet
(venv)> set FLASK_ENV=development
(venv)> flask run
新しくターミナル、またはコマンドプロンプトをひらき、アプリの階層に移動して、curlコマンドを投げてみました。
※今回、私は前回動作させたプロジェクトをコピーして「Pytorch_flask2」というフォルダで作業を行いましたので、ここに移動してcurlを投げています。
※windowのコマンドでも同様に動作します。
$ cd /Python/Pytorch_flask2
$ curl -X POST -H "multipart/form-data" http://localhost:5000/predict -F "file=@kitten.jpg"
次のような結果が帰ってきました。pytorchの分類器の結果はうまく取得できていますが、jsonファイルのラベル変換がうまくいっていないようです。
class_nameがうまく取得できていませんね・・・
調査したところ「index_to_name.json」ファイルがうまく読みこめていませんでした。「index_to_name.json」をフルパスで指定したら、うまく動作しました。
mapping_file_path = '任意の階層のフルパス…/index_to_name.json'
考察と改良点
よく考えたら、「model」と「index_to_name.json」は、コンテキストから取得したほうが良いような気がしてます。今回の「model」と「index_to_json」はリクエストによって書き換えが発生するプログラムではないので問題は発生しないと考えていますが、プログラム的にはグローバルな変数として代入されていますし、「g」proxyオブジェクトに、「model」と「index_to_name.json」を格納する設計に変更することでチュートリアルでやったデータベースコネクションのようなスマートな設計になることでしょう!
index_to_name.jsonへのパスもフルパスで指定するというパワープレイをしているので、current_appオブジェクトでコンテキストの中からファイルオープンをする、というやり方が良さそうです。
・アプリケーションコンテキストの「g」を使って「model」と「index_to_name.json」を利用できるようにしてみたいです。
また、ちょっとハードルが上がるかもしれないのですが、templateのHTMLからアップロードするような形も理想的です。
・curlコマンドではなく、テンプレートに画像をアップロードして処理してみたいです。
とりあえず、基本的な構成としてはモジュール化ができたような気がしています。
少しづつ改善していこう