機器學習之KNN

2021-09-30 20:08:11 字數 3913 閱讀 9178

約會問題為例,展示整個演算法工作流程

注:這裡使用的是python語言

import numpy as np

import operator

def read_data(file_path):

fr = open(file_path)

file_lines = fr.readlines()

data_num = len(file_lines)

dataset = np.zeros((data_num, 3))

class_vec =

index = 0

for line in file_lines:

line_split = (line.strip()).split('\t')

dataset[index, :] = line_split[:-1]

index += 1

return dataset, class_vec

def data_classify(dataset, class_vec):

type1 = [, , ]

type2 = [, , ]

type3 = [, , ]

labels = list(set(class_vec))

for i in range(len(dataset)):

if class_vec[i] == labels[0]:

elif class_vec[i] == labels[1]:

elif class_vec[i] == labels[2]:

return type1, type2, type3, labels

def view_data(dataset, class_vec):

from matplotlib import pyplot as plt

plt.rcparams['font.sans-serif'] = ['simhei']

plt.rcparams['axes.unicode_minus'] = false

type1, type2, type3, labels = data_classify(dataset, class_vec)

fig1 = plt.figure(figsize=(30, 10))

axe1 = plt.subplot(1, 3, 1)

axe1.scatter(type1[0], type1[1])

axe1.scatter(type2[0], type2[1])

axe1.scatter(type3[0], type3[1])

plt.legend(labels)

axe2 = plt.subplot(1, 3, 2)

axe2.scatter(type1[0], type1[2])

axe2.scatter(type2[0], type2[2])

axe2.scatter(type3[0], type3[2])

plt.legend(labels)

axe3 = plt.subplot(1, 3, 3)

axe3.scatter(type1[1], type1[2])

axe3.scatter(type2[1], type2[2])

axe3.scatter(type3[2], type3[2])

plt.legend(labels)

plt.show()

def split_dataset(dataset, class_vec):

ratio = 0.8

stop_index = int(dataset.shape[0] * ratio)

train_dataset = dataset[: stop_index, :]

test_dataset = dataset[stop_index:, :]

train_class = class_vec[: stop_index]

test_class = class_vec[stop_index :]

return train_dataset, test_dataset, train_class, test_class

def data_normalize(train_dataset):

maxv = train_dataset.max(0)

minv = train_dataset.min(0)

diff = maxv -minv

# print(diff, type(diff))

data_normal = (train_dataset - np.tile(minv, (train_dataset.shape[0],1))) / np.tile(diff, (train_dataset.shape[0],1))

# print(data_normal[:10,:])

return data_normal, maxv, minv

def knn(train_dataset, train_class, input_data, k):

difference = np.sum((train_dataset - input_data) ** 2, 1)

sorted_index = difference.argsort()

class_count = {}

for i in range(k):

voted_label = train_class[sorted_index[i]]

class_count[voted_label] = class_count.get(voted_label, 0) + 1

sorted_class = sorted(class_count.items(), key = operator.itemgetter(1), reverse=true)

return sorted_class[0][0]

def main():

file_path = "datingtestset.txt"

dataset, class_vec = read_data(file_path)

print('dataset = \n', dataset[:5, :])

print(type(class_vec), 'class_vec = \n', class_vec[:20])

view_data(dataset, class_vec)

print('畫圖完成!')

train_dataset, test_dataset, train_class, test_class = split_dataset(dataset, class_vec)

train_data_normal, maxv, minv = data_normalize(train_dataset)

error_num = 0

for i in range(len(test_dataset)):

data = test_dataset[i]

data_normal = (data - minv) / (maxv - minv)

label = knn(train_data_normal, train_class, data_normal, 3)

print('估計出的標籤為%s, 實際的標籤為:%s' % (label, test_class[i]))

if str(label) != str(test_class[i]):

error_num += 1

print('*' * 20, '上面一行錯誤', '*' * 20)

print('錯誤率:%.2f' % (float(error_num / len(test_dataset))))

機器學習之KNN

knn主要應用於文字分類 聚類分析 分析 降維等 中心思想是採用測量不同特徵值之間的距離方法進行分類 演算法非常簡單,不過這是乙個監督演算法,訓練資料需要經過人工標記。演算法中心思想是 計算候選樣本到所有訓練樣本之間的距離,選取k個最近距離資料中出現次數最多的分類作為新樣本的類別。from nump...

機器學習之KNN

knn分類演算法 k nearest neighbors classification 即k近鄰演算法 給定乙個訓練資料集,對新的輸入例項,在訓練資料集中找到與該例項最鄰近的k個例項,這k個例項的多數屬於某個類,就把該輸入例項分類到這個類中。核心思想 要確定測試樣本屬於哪一類,就尋找所有訓練樣本中與...

機器學習之KNN

以下部落格主要由兩部分構成。一是理論講解,而是 實現 因為工程上使用knn的頻率不是很高,所以 不是目的,一些 中的技巧就顯得很重要了 首先knn是什麼?k nearest neighbors knn 以下的均來自 貪心科技 不是打廣告,純粹是尊重智財權。因為便於投票分類 怎麼選擇合適的k,一般會用...