티스토리 뷰

반응형

1. KNN 이란?

지도 학습의 한 종류로, 레이블이 있는 데이터를 사용하여 분류 작업을 하는 알고리즘이다. 

데이터로부터 거리가 가까운 k개의 다른 데이터의 레이블을 참조하여 분류한다.

유일하게 필요한 것은 다음과 같다

거리를 재는 방법
서로 가까운 점들은 유사하다는 가정

장점

알고리즘이 간단하여 구현하기 쉽다.

단점

차원(벡터)의 크기가 크면 계산량이 많아진다. 
현상의 원인을 파악하는 데는 큰 도움이 되지 않는다.

2. 알고리즘 순서

1) 특정한 점과 다른 점들과의 거리 측정 
2) 가까운 k 개의 점 선별 
3) 선별된 점들의 레이블을 보고 다수결(majority vote)로 새로운 데이터 포인트의 레이블을 정한다

2_1. 특정한 점과 다른 점들과의 거리 측정

두 점 사이의 거리를 측정하는 방법은 간단하다. 두 점이 각각 (x1, y1), (x2, y2) 로 되어 있으면 두 점 사이의 거리는 다음과 같다.

d=(x1x2)2+(y1y2)2

d=(x1x2)2+(y1y2)2

numpy 를 사용하여 구현해보도록 하겠다.

In [174]:
import numpy as np

def distance(p1, p2):
    return np.sqrt(np.sum(np.power((p2-p1),2)))

p1 = np.array([1,3])    
p2 = np.array([2,6])
    
print(distance(p1, p2)) # (1,3) 과 (2,6) 간의 거리 계산
3.16227766017

(1,3) 과 (2,6) 간의 거리는 3.16227766017 이다

2_2. 가까운 k 개의 점 선별

iris data 를 가지고 진행할 예정으로, 제일 마지막 virginica 데이터와 나머지 데이터간의 distance 를 구한다 
iris data 에 대한 자세한 설명은 생략하도록 한다. 
iris.data 는 아래와 같이 구분되어 있다.
0~49 : setosa 
50~99 : versicolor 
100~149 : virginica

In [175]:
from sklearn.datasets import load_iris

data = load_iris()

distance_result = np.zeros(150)

for i in  range(len(data.data)):
    distance_result[i] = distance(data.data[len(data.data)-1],data.data[i] )
    
print(distance_result)
[ 4.14004831  4.15331193  4.29883705  4.14969878  4.17372735  3.81837662
  4.21781934  4.0607881   4.30232495  4.10609303  4.03236903  4.02243707
  4.21781934  4.63141447  4.33358974  4.11339276  4.17851648  4.1024383
  3.80657326  4.0607881   3.81182371  4.00624512  4.62817459  3.73898382
  3.76430604  3.95221457  3.89615195  4.03236903  4.11096096  4.03608721
  4.00374824  3.91535439  4.18927201  4.22492603  4.10609303  4.3150898
  4.17252921  4.10609303  4.38748219  4.03980198  4.21307489  4.37492857
  4.39203825  3.84057287  3.67151195  4.14125585  4.01123422  4.24028301
  4.04598566  4.14125585  1.25299641  0.86023253  1.06770783  1.4525839
  0.86023253  0.83066239  0.67082039  2.28910463  1.          1.47648231
  2.23830293  0.9486833   1.58113883  0.60827625  1.61245155  1.14017543
  0.73484692  1.3190906   1.08627805  1.50665192  0.36055513  1.24096736
  0.73484692  0.77459667  1.07238053  1.06770783  1.04880885  0.81240384
  0.68556546  1.84390889  1.64316767  1.76635217  1.37840488  0.37416574
  0.83666003  0.75498344  0.9486833   1.17898261  1.15758369  1.36747943
  1.08166538  0.67082039  1.3190906   2.27596134  1.11355287  1.1
  1.05356538  0.99498744  2.40624188  1.15325626  1.24498996  0.33166248
  1.47309199  0.64807407  1.00498756  2.28691933  1.27279221  1.84661853
  1.17473401  1.88148877  0.66332496  0.6244998   1.02956301  0.58309519
  0.64031242  0.76157731  0.72111026  2.56904652  2.62488095  0.8660254
  1.28452326  0.45825757  2.42487113  0.53851648  1.08627805  1.59373775
  0.46904158  0.28284271  0.79372539  1.48996644  1.81659021  2.52388589
  0.83666003  0.53851648  0.78102497  2.11896201  0.96436508  0.64807407
  0.31622777  1.09087121  1.12249722  1.12249722  0.33166248  1.3190906
  1.25698051  0.9486833   0.65574385  0.64031242  0.76811457  0.        ]

실제 자신과의 거리는 0 이 나왔고, virginica 는 100~149 (마지막)에 분포되어 있다. 
따라서 distance 가 대체적으로 뒤로 갈수록 작아짐을 볼 수 있다.

np.argsort 함수는 sorting 하여 index 를 반환한다

In [176]:
print(np.argsort(distance_result)) ## sorting 시 index 값 반환. 첫번째 index 인 149 는 자기 자신이다.
[149 127 138 142 101  70  83 121 126 123 133 113  63 111 147 114 103 137
 146 110  56  91  78 116  72  66  85 115 148  73 134 128  77  55  84 132
  51  54 119  61  86 145 136  97  58 104 112  76  96  75  52  74  90  68
 124 139  95  94 140 141  65  99  88 108  87  71 100  50 144 106 120 143
  92  67  89  82  53 102  59 129  69  62 125  64  80  81 130  79 107 109
 135  60  93 105  57  98 122 131 117 118  44  23  24  18  20   5  43  26
  31  25  30  21  46  11  10  27  29  39  48  19   7  17   9  34  37  28
  15   0  49  45   3   1  36   4  16  32  40   6  12  33  47   2   8  35
  14  41  38  42  22  13]

p 로부터 거리가 가까운 k 개의 points 의 target 값을 반환한다. 
target 값 ==> 0:setosa , 1:versicolor , 2:virginica

In [177]:
def find_nearest_neighbors(p, points, k): 
    
    distance_result = np.zeros(len(points))
    
    for i in  range(len(points)):
        distance_result[i] = distance(points[i],p )
        
    sorted_index = np.argsort(distance_result)[:k]  
    
    result = np.zeros(k)
    
    for i in range(k):
        result[i] = data.target[sorted_index[i]]
        
    return result # target 값 ==> 0:setosa    , 1:versicolor    , 2:virginica  
In [178]:
target_result = find_nearest_neighbors(data.data[149], data.data[:149], 5)
print(target_result)
[ 2.  2.  2.  2.  1.]

제일 마지막 virginica 데이터와 나머지 데이터간의 distance 중 가까운 점 5개의 target 값은 위와 같다.
2(virginica) 가 4개, versicolor 가 1개임을 알 수 있다.

2_3. 다수결(majority vote)로 새로운 데이터 포인트의 레이블을 정한다.

In [179]:
from collections import Counter

def majority_vote(target_result): 
    
    vote_counts = Counter(target_result)
    
    #print(type(vote_counts.most_common(1)))
    #print(vote_counts.most_common(1))
    
    #vote_counts 의 가장 빈번한 값중 0번째 index 의 target 값을 반환
    #가장 빈번한 값이 중복된 경우의 처리는 하지 않았다.
    
    # most_common([n])
    #Return a list of the n most common elements and their counts from the most common to the least. 
    
    return vote_counts.most_common(1)[0][0] 
In [180]:
majority_vote(target_result)
Out[180]:
2.0

주변 k 개의 점 중 2 가 가장 많다. 의 의미로 볼 수 있겠다.

2_4. 위의 함수들을 참조하여 KNN 함수 구현

In [181]:
def knn_predict(p, points , k=5):
    
    species = {0: "setosa",  1:"versicolor", 2:"virginica"}
    
    target_result = find_nearest_neighbors(p, points, k)
    
    species_result = majority_vote(target_result)
    
    return species[species_result]
In [182]:
print("최종결과 : " + knn_predict(data.data[149], data.data[:149], 5))
최종결과 : virginica

나머지 종들의 data 도 넣어 확인해본다.

In [183]:
print("최종결과 : " + knn_predict(data.data[0], data.data, 5)) 
# 여기서는 distance 측정에 자신의 data 가 포함되게 넣어주었다. 귀찮아서 ㅠㅠ
최종결과 : setosa
In [184]:
print("최종결과 : " + knn_predict(data.data[80], data.data, 5))
# 역시 distance 측정에 자신의 data 가 포함되게 넣어주었다. 귀찮아서 ㅠㅠ
최종결과 : versicolor

역시 tutorial 데이터라 그런지, 결과가 잘 나온다. 
어설프지만 KNN 의 동작을 이해하는데 도움이 되고자 구현해보았다. 
다음에는 파이썬에서 제공하는 라이브러리를 이용하여 같은 결과가 나오는지 확인해보고자 한다.


반응형

'머신러닝' 카테고리의 다른 글

의사결정나무(decisiontree)_2  (0) 2018.05.14
의사결정나무(decisiontree)_1  (0) 2018.05.12
iris data 를 이용한 KNN 구현해보기_2  (0) 2018.05.08
[tensorflow] MNIST Data 설명  (0) 2017.11.22
소프트맥스(SoftMax)  (0) 2017.11.22