【機械学習 no.2】Pythonで実装した最近傍法をK最近傍法に改良

シェアする

いきなりヘビの写真で失礼します.私はヘビが嫌いでアオダイショウでも触れないくらいですが,Pythonが話題なので仕方ないです.ま,クモよりはヘビの方がマシですが(苦笑)

それはともかく,前回の最近傍法のプログラムはこちらの記事で紹介してます.

今回は前回の最近傍法のプログラムをK最近傍法に改良したプログラムです.いきなりですが,改良したプログラムを以下に示します.

import numpy as np
import pandas as pd
from collections import Counter

def knn(train, test, cl, k=1):
    dist = np.sum((test-train)**2, axis=1)  
    k_arg = np.argsort(dist)[0:k]
    cnt = Counter(cl[k_arg])
    argmin_label,dummy = cnt.most_common()[0]
    return argmin_label

df = pd.read_csv('https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data', header=None)

x = df.iloc[:,0:4].values
label = df.iloc[:,4].values

correct = 0;
result = []
for n in range(150):
	ind = np.ones(150, dtype=bool)
	ind[n] = False
	
	train = x[ind]
	cl = label[ind]
	test = x[n]
	
	r = knn(train, test, cl, k=21)
	result.append(r)
	if r==label[n]:
		correct += 1

print('correct classification: {}%'.format(100*correct/150))
result = np.array(result)
print(pd.crosstab(result, label))

プログラムの中で今回初めて使用した関数を説明します.まず,collections.Counterモジュールですが,これはK最近傍法の多数決をする際に便利です.参考にしたのは以下のページ.

このcollectionsは,要素が重複しているリストの中で,各要素の出現頻度数を調べるものです.たとえば,以下の実行例で説明します.

>>> data = ['dog','dog','cat','lion','wolf','cat','dog']
>>> cnt = Counter(data)
>>> cnt.most_common()
[('dog', 3), ('cat', 2), ('lion', 1), ('wolf', 1)]
>>> label, num = cnt.most_common()[0]
>>> label
'dog'
>>> num
3
>>> 

dog, dog, cat, lion, wolf, cat, dogという文字列の要素が入ったリストを考えます.この中で最も要素数の多いのはdog,次がcatとなります.most_common()という関数を使うと,要素数の多い順に個数と共に表示されます(4行目).最も数の多かった文字列とその個数を取得しているのが5行目.Pythonだとこんな書き方も出来るんですね.すごいです.

これ,多数決をする場合とても便利.K最近傍法だけでなくアンサンブル学習でも役立つと思います.

また,今回のプログラムは前回に加えて,クロス集計もすることにしました.参考にしたのはこちらのページ

pandasの中にcrosstabという関数があるんですね.これ,Rのtable関数と同じ.これもまたすごく便利です.

さて,このk=21に設定した時のプログラムの実行結果はこちらです

$ python knn.py
correct classification: 98.0%
col_0            Iris-setosa  Iris-versicolor  Iris-virginica
row_0                                                        
Iris-setosa               50                0               0
Iris-versicolor            0               48               1
Iris-virginica             0                2              49
$ 

表の見方ですが,setosaは50個あるデータすべてを正しく識別してますが,versicolorデータは2個だけvirginicaと誤識別,virginicaのデータは1個だけversicolorと誤識別してるという意味です.

データ全体の識別率は,(50+48+49)/150 = 0.98となります.

ふむ,やっとPythonに慣れてきたかな.