【機械学習 no.3】Pythonで実装したK最近傍法の識別結果を可視化してみる

前回,irisデータをK最近傍法(KNN)で識別しましたが,今度はそれをプロットして可視化してみたいと思います.今回の目的はPythonでプロットする方法を学ぶことです.

と言っても,irisは4次元データなのでその中の2つの特徴量(petal kengthとpeta width)だけに絞り込んで2次元プロットしてみます.そして,KNNの識別境界も図示してみたいと思います.

ちなみに,識別境界と言ってますが,パーセプトロンのように明確な識別境界の式が定まるわけじゃないので,ある意味強引な図示です(^^;)

とりあえず実行プログラムを先に示します.前回のコードに少しコードを追加しました.

import numpy as np
import pandas as pd
from collections import Counter
from matplotlib import pyplot

def knn(train, test, cl, k=1):
    dist = np.sum((test-train)**2, axis=1)  
    k_arg = np.argsort(dist)[0:k]
    cnt = Counter(cl[k_arg])
    argmin_label,dummy = cnt.most_common()[0]
    return argmin_label

df = pd.read_csv('https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data', header=None)

x = df.iloc[:,0:4].values
label = df.iloc[:,4].values

correct = 0;
result = []
for n in range(150):
	ind = np.ones(150, dtype=bool)
	ind[n] = False
	
	train = x[ind]
	cl = label[ind]
	test = x[n]
	
	r = knn(train, test, cl, k=21)
	result.append(r)
	if r==label[n]:
		correct += 1

print('correct classification: {}%'.format(100*correct/150))
result = np.array(result)
print(pd.crosstab(result, label))

#ここから今回のコード
pyplot.clf()
for xx in np.arange(1,7,0.1):
	for yy in np.arange(0,2.5,0.05):
		r = knn(x[:,[2,3]], np.array([xx,yy]), label, k=21)
		if r=='Iris-setosa':
			pyplot.plot(xx,yy,'s', color='#ffcccc')
		elif r=='Iris-versicolor':
			pyplot.plot(xx,yy,'s', color='#ccccff')
		else:
			pyplot.plot(xx,yy,'s', color='#ccffcc')

pyplot.plot(x[0:50,2],x[0:50,3],'ro')
pyplot.plot(x[50:100,2],x[50:100,3],'b*')
pyplot.plot(x[100:150,2],x[100:150,3],'g+')
pyplot.xlabel('petal length')
pyplot.ylabel('petal width')
pyplot.legend()
pyplot.show()

結果は,こんな感じ.

37行目以降が今回追加した内容です.行数は多いですけどやってることはプロットしているだけです.また,グラフが表示されるまでかなり遅いです.それはご勘弁を(汗)

で,今回はグラフ表示の関数を追加しましたが,これは,matplotというライブラリをインポートすればいいらしいです.名前から察するに,MATLABのプロット関数を模したような気がしたけど,実際に使ってみると確かにモロMATLABでした(笑).

それで,plot関数はコードを見れば大体わかると思うので,特に説明の必要はないと思います.pyplot.plotの中でシングルクォートでro, b*, g+と書いてますが,これは「赤(r)の丸(o)でプロット」「青(b)のアスタリスク(*)でプロット」「緑(g)のプラス記号(+)でプロット」という意味です.

参考にしたのはこちらのページ↓


あと,

for xx in np.arange(1,7,0.1):
	for yy in np.arange(0,2.5,0.05):
		r = knn(x[:,[2,3]], np.array([xx,yy]), label, k=21)
		if r=='Iris-setosa':
			pyplot.plot(xx,yy,'s', color='#ffcccc')
		elif r=='Iris-versicolor':
			pyplot.plot(xx,yy,'s', color='#ccccff')
		else:
			pyplot.plot(xx,yy,'s', color='#ccffcc')

ここの箇所ですが,これはグラフの縦横の表示範囲だけxxとyyの値をそれぞれ0.1と0.05ずつ変化させていき,その場所の座標値を特徴ベクトルとしKNNの出力結果によって色を塗ってます.刻み幅をもう少し細かくすれば,識別境界がもっと綺麗になるんでしょうけど,かなり時間がかかるのでやめます(笑).’s’はスクエア(四角)でプロットするという意味です.

ちなみに,k=1に設定した時のプロットはこちら↓

versicolorとvirginicaの境界付近がデータにフィットし過ぎですね(赤い丸で囲ったところ).これはkの値を増やせば(通常は)解決します.