【Pytorch×Flask】Pytorchのflask画像分類アプリ、flaskプロジェクトを改造4、画像送信部分のエラー処理を追加

エラー処理、大事ですね。

アプリって、ここが一番苦労するよね・・・

Pytorchのflask画像分類アプリを作成して、改造してきました。現状のアプリでは、送信する画像ファイル名やファイルの拡張子が自由なので、色々なセキュリティの脆弱性が発生します。

ということで、今回は画像入力に対するエラー処理を実装していきたいと思います。

flaskのチュートリアルには次の入力画像に対するエラー処理のページがありましたので、参考にしています。

こんな人の役に立つかも

・flaskでWebアプリのテンプレートを実装したい人

・Pytorchのflask画像分類アプリチュートリアルを発展させたい人

・Pythonのflaskチュートリアルが物足りない人

目次

view関数の改良

「classifier.py」の内容を変更していきます。

ファイル拡張子の確認関数

allow_file関数を作成しました。これは、flaskチュートリアルをそのまま利用しています。

secure_filenameは、今回は利用していませんが、勉強のため入れてみています。

#今回は利用しませんが、secure_filenameをimportしておきます。
from werkzeug.utils import secure_filename

#predictの③で利用する入力ファイルの拡張子を制限する関数
ALLOWED_EXTENSIONS = {'jpg', 'jpeg'}#指定する拡張子は小文字である必要があります。
def allowed_file(filename):
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS

「ALLOEWD_EXTENSIONS」はset型データで、中にあらかじめ利用して良い拡張子の文字列を入れておきます。今回は「jpg」と「jpeg」にしました。

allowed_file関数の中ではretuen文で全てが完結しています。また、次の2つの条件がTrueとなるかをみています。

①「.」がファイル名に含まれるかどうか

②ファイル名の「.」より右側の文字列を小文字に変換して「ALLOWED_EXTENSIONS」に含まれている要素であるかどうか

「rsplit」は指定した文字より右側を取得するメソッドで、[ファイル名,拡張子]のようなリストを返してくれます。その右側の拡張子である[1]の要素を小文字化(lower)してその要素が「ALLOWED_EXTENSIONS」に含まれているものなのかという計算になっています。したがって、最初に設定した「ALLOWED_EXTENSIONS」の拡張子名は小文字である必要があります。

view関数でのエラー処理

predict関数を以下のように変更しました。

だいぶ原型がなくなったような・・・

@bp.route('/predict', methods=['GET','POST'])
def predict():
    if request.method == 'POST':
        #①POSTのリクエストデータにfile項目がない場合
        if 'file' not in request.files:
            flash('POST request error')
            return render_template('prediction.html')
        file = request.files['file']
        #②ユーザーがファイルを選択せずに送信ボタンを押した場合
        if file.filename == '':
            flash('ファイルが選択されていません。')
            return render_template('prediction.html')
        #③正常にファイルが存在していてファイル名の拡張子が許可されたものの場合
        if file and allowed_file(file.filename):
            #④ファイル名に階層情報などが含まれている場合に削除します。
            filename = secure_filename(file.filename)
            #今後の処理でファイル名自体の文字列を利用するときは安全な「filename」変数を利用します。
            get_model()
            get_label()
            input_tensor = transform_image(file)
            prediction_idx = get_prediction(input_tensor)
            class_id, class_name = render_prediction(prediction_idx)
            #return jsonify({'class_id': class_id, 'class_name': class_name})
            return render_template('result.html', class_id = class_id, class_name = class_name)
        else:
            flash('ファイルの拡張子は.jpgまたは.jpegのみです。')
    
    #GETリクエストで返すフォームHTML
    return render_template('prediction.html')

入力ファイルに対するエラー処理をチュートリアルを参考に3つ条件追加を行いました。

①つめは必ずしもテンプレートHTMLであるフォームからリクエストがあるわけではないので、リクエストデータの中にfileという名称のデータが存在しているかどうかを確認しておきます。ここではflash関数でエラーメッセージを処理することにしました。

②には、ユーザーがファイルを選択せず送信ボタンを押した場合のエラーを入れます。ここでもflash関数でエラーメッセージを処理しています。

③の最後の条件としては、fileにデータが存在している場合かつ、ファイルの拡張子が先ほど定義したallowed_file関数がTrueを返してきてくれているか、という条件です。そのため、else文ではflashに拡張子のエラーメッセージを入れておきました。(厳密にはfileにデータが存在指定ない場合も含まれますが・・・)

これらの3つの条件をくぐり抜けたデータが正常に処理されるようにしています。

また、今回は利用していないのですが、fileの名称を何かしらの処理で利用したい場合④のように、secure_filename関数を通した後の文字列を利用するのが安全になります。ファイル名が「../../../../」のような階層情報が含まれていたりする場合、この階層情報を示す文字列を削除してくれます。

テンプレートの改善

view関数でflashしているエラーメッセージを表示できるように以下のように変更しました。

<h1>Prediction</h1>
{% with messages = get_flashed_messages() %}
  {% if messages %}
    <ul class=flashes>
    {% for message in messages %}
      <li>{{ message }}</li>
    {% endfor %}
    </ul>
  {% endif %}
{% endwith %}
<form method="post" enctype="multipart/form-data">
  <input type="file" name="file">
  <button type="submit">送信</button>
</form>

flash関数はやはり便利ですね。

過去に勉強したことが生きているね

エラーの検証

kitten.jpgをGIMPでkitten.pngに変換して入れてみました。(ちなみに、制限をかける前のプログラムではpngでも正常に動作します。)

jpeg画像以外のものはエラーとなって返ってきています。

ファイルを選択せずに送信ボタンを押してみます。

ファイル名が空白のflashメッセージが表示されることが確認できました。

よかったらシェアしてね!
  • URLをコピーしました!
  • URLをコピーしました!
目次