【AIプログラミング】多クラスの時のConfusionMatrix、アヤメのデータ分類でConfusionMatrixを利用

1_プログラミング

多クラス分類のアヤメデータでConfusionMatrixを利用してみたよ。

多クラスでもConfusionMatrixは利用できたよね。

ConfusionMatrixについて勉強した時、基本的な2クラス分類の時で考えていましたが、多クラスの時のConfusionMatrixがどのようになってくるのかを見てみました。懐かしのアヤメデータで色々と試してみたいと思います。

ConfusionMatrixについては、こちらの記事もご参考ください。

こんな人の役に立つかも

・機械学習プログラミングの勉強をしている人

・多クラスのConfusionMatrixについて知りたい人

・scikit-learnのConfusionMatrixプログラムを勉強している人

スポンサーリンク

アヤメデータの多クラス分類をやってみる

アヤメデータのデータ構成などについては、こちらの記事もご参考ください。

早速、機械学習アルゴリズムについては、GradientBoost、利用してみました。また、単純にホールドアウトした訓練データで訓練をするだけでなく、交差検証にて「learning_rate」パラメータを模索してみました。

GradientBoostでアヤメ多クラス分類

久しぶりに交差検証を利用しました。交差検証は、訓練データをさらに分割することでテストデータに全く触れることなく、パラメータチューニングができるようになる方法です。こちらの記事もご参考ください。

import ~交差検証

まずは、importと、アヤメデータの読み込みを行い、訓練データとテストデータに分割します。そして、GradientBoostClassifierを作成して交差検証を行いました。

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report

#アヤメデータ
panda_box = load_iris()

X = panda_box.data
y = panda_box.target

#訓練データとテストデータに分割
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.25, stratify=y, random_state=0)

#交差検証
from sklearn.model_selection import cross_val_score

#勾配ブースティングマシン
sk_clf = GradientBoostingClassifier()

#交差検証を行う。
score = cross_val_score(sk_clf, X_train, y_train, cv=3)

#結果の表示
print("交差検証の結果")
print(score)
print("交差検証の平均")
print("{:.4f}".format(np.mean(score)))
交差検証の結果
[0.86842105 0.97297297 0.94594595]
交差検証の平均
0.9291

久しぶりに交差検証を利用したよ・・・使わないとね。

GradientBoostのパラメータ調整

訓練データを利用した交差検証の正解率は、92%程度となっています。ここで、一度GradientBoostingClassifierの「learning_rate」パラメータを調整してみます。

learning_rateを変更することで先ほど行った訓練データへの交差検証のスコアがどのように変化するのか、グラフ化してみたいと思います。今回は、「learning_rate」を0.01〜0.15まで変化させるように15回のループとしました。(事前に数回実行して、ある程度あたりをつけてこの範囲にしました。)

import matplotlib.pyplot as plt

train_dat_array = []

loop = 15

for i in range(loop):
    clf = GradientBoostingClassifier(learning_rate=((i+1)*0.01))
    score = cross_val_score(clf, X_train, y_train, cv=3)
    train_dat_array.append(np.mean(score))

#グラフ
X_axis = np.linspace(0.01,loop*0.01,loop)
plt.plot(X_axis, train_dat_array)
plt.xlabel("learning_rate")
plt.ylabel("score")

どうやら、learning_rateは、0.02をピークにスコアが一定になってしまうようです。ということで、0.02に決めました。

多クラスのCondusionMatrixなどを表示

sk_clf = GradientBoostingClassifier(learning_rate=0.02).fit(X_train, y_train)

sk_pred_train = sk_clf.predict(X_train)
sk_pred_test = sk_clf.predict(X_test)

#単純な正解率
print("テストデータ正解率")
print(accuracy_score(y_test ,sk_pred_test))
#ConfusionMatrixを作成
print("==ConfusionMatrix==")
print(confusion_matrix(y_test, sk_pred_test))

#ClassificationReport
print("==ClassificationReport==")
target_names = panda_box.target_names
print(classification_report(y_test, sk_pred_test, digits=4, target_names=target_names))

今回は、classification_report()のパラメータを追加して指定してみました。

・digits:小数点以下いくつまでを表示するかを指定

・target_names:リストで指定することで、答えの数値の代わりにラベル名で出力

結果をみると、precisionなどの数値が小数点4桁までに、また、以前は単純な数値だった答えのラベルが、アヤメの種類で表示されるようになりました。

テストデータ正解率
0.9736842105263158
==ConfusionMatrix==
[[13  0  0]
 [ 0 13  0]
 [ 0  1 11]]
==ClassificationReport==
              precision    recall  f1-score   support

      setosa     1.0000    1.0000    1.0000        13
  versicolor     0.9286    1.0000    0.9630        13
   virginica     1.0000    0.9167    0.9565        12

    accuracy                         0.9737        38
   macro avg     0.9762    0.9722    0.9732        38
weighted avg     0.9756    0.9737    0.9736        38

上のConfusionMatrixを表で整理してみると、

のようになりました。

これは、アルゴリズムがどこでどのように間違えたのかを明確にみることができる表になります。

今回は、予測アルゴリズムの答えで「1(virginica)」と答えたもので一つだけ実際の答えが「2(versicolor)」であったという間違いが1つあるのみでした。

GradientBoostすると、アヤメデータをほぼ完璧に分類している。すごいね。

PrecisionとRecallの値は?

Precisionは、予測したものがどれだけ正解していたかなので、それぞれの答えデータに対して一つづつ計算されます。

今回、1(virginica)と予測したものが一つ間違っていましたので、1(virginica)の時の計算をみてみます。

1(virginica)のprecisionは、「(予測値1の正解数13)÷(予測値1の全体の数)」で、0.928・・・となります。青色の枠の計算がprecisionです。

1(virginica)のrecallは、オレンジの枠で、実際の答えデータ1(virginica)に予測値の間違いがないので1となりました。

このように、「1(virginica)」と「それ以外」というような計算となるので、2クラス問題とは少し見方の様子が違ってくるようです。

タイトルとURLをコピーしました