【AIプログラミング】ロジスティック回帰で多クラス分類をしてみる

AIプログラミング

前回、ロジスティック回帰で2クラスの分類をしたので、今回は3クラス分類をしてみるよ。

だんだん小難しくなってきたね。

前回は、ロジスティック回帰の2クラス問題の境界線を確認しました。今回は、one versus restという用語を勉強して、3クラスのアヤメデータの境界線を引くプログラムをうごかしてみました。

前回の記事はこちらからどうぞ

こんな人の役に立つかも

・機械学習プログラミングの勉強をしている人

・ロジスティック回帰の3クラス分類の境界線を引くプログラムを作りたい人

スポンサーリンク

3クラスでの分類

ロジスティック回帰は、データの境界線を直線で求めるもので、基本的に、2クラスのデータの境界線を引く、というものでした。

そこで、多クラス(答えが3以上ある場合)に拡張するため、One versus restという考え方が使われます。

3クラス以上あるデータは、One versus restという考え方で分類できます。

プログラム

前回同様、importからアヤメデータの読み込みが以下のプログラムです。

#ロジスティック回帰
from sklearn.linear_model import LogisticRegression
from sklearn.datasets import load_iris
import numpy as np
import matplotlib.pyplot as plt

panda_box = load_iris()

X = panda_box.data
y = panda_box.target

X_2dim = panda_box.data[:,:2]

境界線を描く関数を作成します。

#関数を作成
def make_meshgrid(x, y, h=.02):
    x_min, x_max = x.min() - 1, x.max() + 1
    y_min, y_max = y.min() - 1, y.max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                         np.arange(y_min, y_max, h))
    return xx, yy

def plot_contours(ax, clf, xx, yy, **params):
    Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)
    out = ax.contourf(xx, yy, Z, **params)
    return out

def graf_setting(ax):
    #ax.set_xticks(())
    #ax.set_yticks(())
    ax.set_xlabel('Sepal length')
    ax.set_ylabel('Sepal width')

3クラス(「setosa」「verginica」「versicolor」の3個の答え)のアヤメデータをロジスティック回帰で分類するプログラムです。

#3クラス
#===ロジスティック回帰の作成===
#訓練
clf = LogisticRegression().fit(X_2dim, y)

#空のグラフを作成
fig = plt.figure(figsize=(9, 4))
ax = fig.add_subplot(121)

#グリッドのデータを作成
X0, X1 = X_2dim[:, 0], X_2dim[:, 1]
xx, yy = make_meshgrid(X0, X1)

#グラフに境界線とデータをプロット
plot_contours(ax, clf, xx, yy, cmap=plt.cm.coolwarm, alpha=0.8)
ax.scatter(X0, X1, c=y, cmap=plt.cm.coolwarm, s=20, edgecolors='k')
graf_setting(ax)
3クラスの分類の結果

このように、ロジスティック回帰で3クラスの分類ができました。

OVR(One Versus Rest)について

本来、直線で境界線を引くということは、2クラスにしかデータを分離できないはずなのですが、3クラス以上のデータを分類する際にはOne versus restという考え方を利用することで、分類できるそうです。

ワン バーサス レスト・・・

one  versus restという考え方では、「一つ」対「残り」で比較していく方法です。

例えば、アヤメの3クラスのデータは、次のように3種類の直線による境界線を引くことができます。

①「setosa」VS「verginica と versicolor」

②「verginica」VS「setosa と versicolor」

③「versicolor」VS「setosa と verginica」

というデータの比較を行います。

①のsetosaの境界線は図のようになります。

setosa VS 他

境界線の数値を0として、setosaの方向に行くほど+のスコアが高くなり、反対方向ほど、-のスコアとなるような境界線が作成されます。

①~③のそれぞれの境界線を0として、境界線同士が重なる領域には、スコアの大きい数値を採用することで、それぞれの領域を決定しています。

プログラムで境界線を書いてみたい・・・できなかったので、簡単なイメージ図ですみません。

数学的にもわかるといいよね。

タイトルとURLをコピーしました