アプリケーションのグローバルな変数といったらgオブジェクトですよね。
flaskの仕組み中心になってきたね〜、Pytorchの存在感・・・
前回に引き続き、flaskアプリの改造を行なっていきます。
ニューラルネットのインスタンスを格納している「model」とラベルの内容が記載された「index_to_json.json」は、「classifier.py」で実行されるグローバルな変数として読み込んでいます。正直なところ、読み込みのみの変数なので、gオブジェクトにする必要があるのか疑問ですが、アプリケーションのグローバルな変数をgを使って表現する、というように変更していきたいと思います。(ログインユーザー名のようなリクエスト開始から終了までの間に書き換えるようなグローバル変数はgで管理しないと同時アクセスした別クライアントから書き換えられたり、いろいろと不都合が発生します。)
gオブジェクトの実践的な使い方だと思ってやってみたいと思います。
gに値を取得する
classifier.pyのプログラムの一部です。以下のbpの↓の部分に値の取得関数を追加しました。
今回は、「classifier.py」を修正していきます。
...
bp = Blueprint('classifier', __name__)
#追加:astは、strをdictに変換するためのモジュールです。
from flask import current_app, g
import ast
#①index_to_name.jsonの内容をg.labelに取得します。
def get_label():
if 'label' not in g:
with current_app.open_resource('index_to_name.json') as f:
#g.labelの型をstrからdictに変換して読み込みます。
g.label = ast.literal_eval(f.read().decode('utf8'))
return g.label
#②modelをg.modelへと格納します。
def get_model():
if 'model' not in g:
g.model = models.densenet121(pretrained=True)
g.model.eval()
return g.model
2つの関数を追加してデータを取得できるようにしました。
①get_label()
jsonファイルを読み込む処理を、行います。current_appというコンテキストのproxyを利用して、open_resourceメソッドを利用して、アプリからの相対的な階層でファイルオープンすることができています。ファイルを読み込むとstr型となるので、astモジュールでdict型に変換しておきます。後で利用するときはdict型でないといけないので、この中で変換しておきます。
returnで「g.label」を返してgコンテキストにグローバルな値として登録しておきます。
②get_model
modelをpytorchを利用して作成してreturnで返しています。そのままシンプルに関数化しただけですね。
このget_labelとget_modelは、アプリケーションコンテキストが有効な時に呼び出されないとエラーとなります。
また、以前グローバルに作成していた以下の部分のプログラムは削除しておきます。
#以下のプログラムは、削除しました。
#model = models.densenet121(pretrained=True)
#model.eval()
#img_class_map = None
#mapping_file_path = 'C:/Users/JIMU04/Python/Pytorch_flask/densenet/index_to_name.json'
#print("reading-file")
#if os.path.isfile(mapping_file_path):
# print("file_path")
# with open (mapping_file_path) as f:
# img_class_map = json.load(f)
...以下略
gオブジェクトにアクセスする
modelという変数と、img_class_mapの代わりに「g.model」と「g.label」としましたので、利用していた関数の変数を変更しておきます。
get_prediction関数の中のmodel変数を「g.model」に変更しました。
#g.modelでアクセスするように変更
def get_prediction(input_tensor):
#テストで型を確認しました。
#print("modeltype")
#print(type(g.model))
outputs = g.model.forward(input_tensor)
_, y_hat = outputs.max(1)
print(y_hat)
prediction = y_hat.item()
print(prediction)
return prediction
render_prediction関数のimg_class_mapをg.labelに変更しました。
#g.labelでアクセスするように変更
def render_prediction(prediction_idx):
stridx = str(prediction_idx)
class_name = 'Unknown'
if g.label is not None:
#temp = ast.literal_eval(g.label)
if stridx in g.label is not None:
class_name = g.label[stridx][1]
return prediction_idx, class_name
コンテキストが有効なところで利用する
最後のview関数のところで、「get_model」と「get_label」を呼び出してgオブジェクトに値を取得します。このview関数は、リクエストがきっかけで処理が開始されますので、コンテキストのproxyであるgなどを利用することができます。
@bp.route('/predict', methods=['POST'])
def predict():
if request.method == 'POST':
file = request.files['file']
if file is not None:
#modelとpathをgオブジェクトに取得
get_model()
get_label()
input_tensor = transform_image(file)
prediction_idx = get_prediction(input_tensor)
class_id, class_name = render_prediction(prediction_idx)
return jsonify({'class_id': class_id, 'class_name': class_name})
gの使い方がだいぶわかってきました。
考えて作ると使い方がわかるよね~