KNN - 基于tensorflow實(shí)現(xiàn)程序

本文之編寫程序涉及到API介紹,程序的完整實(shí)現(xiàn),具體算法原理請(qǐng)查看之前所寫的KNN算法介紹

一、基礎(chǔ)準(zhǔn)備

1、python 基礎(chǔ)

2、numpy 基礎(chǔ)

np.argmax
返回?cái)?shù)組的最大值

>>> a = np.arange(6).reshape(2,3)
>>> a.argmax()
5
>>> a.argmax(0)
array([1, 1, 1])
>>> a.argmax(1)
array([2, 2])

3、tensorflow 基礎(chǔ)

tf.arg_max
求數(shù)組最大值的下標(biāo),如axis=0,代表第一維度,8<2:0,1>3:1,2>4:1 =[0,1,1],如axis=1,代表第二維度,[8,1,2]:0,[2,3,4]:2 = [0 2]

data = tf.constant([[8,1,2],[2,3,4]])
sess = tf.Session()
print(sess.run(tf.arg_max(data,0)))
# >> [0 1 1]
print(sess.run(tf.arg_max(data,1)))
# >>[0 2]

tf.arg_min
與arg_min相反

二、完整程序

import tensorflow as tf
import numpy as np

def file2Mat(testFileName, parammterNumber):
    fr = open(testFileName)
    lines = fr.readlines()
    lineNums = len(lines)
    resultMat = np.zeros((lineNums, parammterNumber))
    classLabelVector = []
    for i in range(lineNums):
        line = lines[i].strip()
        itemMat = line.split('\t')
        resultMat[i, :] = itemMat[0:parammterNumber]
        classLabelVector.append(itemMat[-1])
    fr.close()
    return resultMat, classLabelVector;

# 為了防止某個(gè)屬性對(duì)結(jié)果產(chǎn)生很大的影響,所以有了這個(gè)優(yōu)化,比如:10000,4.5,6.8 10000就對(duì)結(jié)果基本起了決定作用
def autoNorm(dataSet):
    minVals = dataSet.min(0)
    maxVals = dataSet.max(0)
    ranges = maxVals - minVals
    normMat = np.zeros(np.shape(dataSet))
    size = normMat.shape[0]
    normMat = dataSet - np.tile(minVals, (size, 1))
    normMat = normMat / np.tile(ranges, (size, 1))
    return normMat, minVals, ranges

if __name__=='__main__':

    trainigSetFileName = 'data\\datingTrainingSet.txt'
    testFileName = 'data\\datingTestSet.txt'

    # 讀取訓(xùn)練數(shù)據(jù)
    trianingMat, classLabel = file2Mat(trainigSetFileName, 3)
    # 都數(shù)據(jù)進(jìn)行歸一化的處理
    autoNormTrianingMat, minVals, ranges = autoNorm(trianingMat)
    # 讀取測(cè)試數(shù)據(jù)
    testMat, testLabel = file2Mat(testFileName, 3)
    autoNormTestMat = []
    for i in range(len(testLabel)):
        autoNormTestMat.append((testMat[i] - minVals) / ranges)

    # 循環(huán)迭代計(jì)算每一個(gè)測(cè)試數(shù)據(jù)的預(yù)測(cè)值,并且和真正的值進(jìn)行對(duì)比,并計(jì)算精確度。該算法比較經(jīng)典的是不需要提前訓(xùn)練,直接在測(cè)試階段進(jìn)行識(shí)別。
    traindata_tensor=tf.placeholder('float',[None,3])
    testdata_tensor=tf.placeholder('float',[3])

    distance = tf.sqrt(tf.reduce_sum(tf.pow(tf.add(traindata_tensor, tf.negative(testdata_tensor)), 2), reduction_indices=1))
    pred = tf.arg_min(distance,0)
    test_num=1
    accuracy=0
    init=tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        for i in range(test_num):
            print(sess.run(distance,feed_dict={traindata_tensor:autoNormTrianingMat,testdata_tensor:autoNormTestMat[i]}))
            idx=sess.run(pred,feed_dict={traindata_tensor:autoNormTrianingMat,testdata_tensor:autoNormTestMat[i]})
            print(idx)

            print('test No.%d,the real label %d, the predict label %d'%(i,np.argmax(testLabel[i]),np.argmax(classLabel[idx])))
            if np.argmax(testLabel[i])==np.argmax(classLabel[idx]):
                accuracy+=1
        print("result:%f"%(1.0*accuracy/test_num))
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容