k近鄰分類器

??K近鄰分類器是機器識別中很常用的一種分類方法,以前在做單樣本人臉識別的時候常用的最近鄰分類方法就是其中k=1的特殊情況。以前都是用matlab寫的代碼,最近老師動員大家一起來學(xué)python,然后我們就都從K近鄰算法開始學(xué)習(xí)編程了。

K近鄰算法

??該方法的思路是:如果一個樣本在特征空間中的k個最相似(即特征空間中最鄰近)的樣本中的大多數(shù)屬于某一個類別,則該樣本也屬于這個類別。其實也很好理解,就是現(xiàn)在有n個訓(xùn)練樣本,分別對應(yīng)c個類,現(xiàn)在有一個未知的測試樣本,要找到這個測試樣本的類別。K近鄰的方法就是計算出每個訓(xùn)練樣本和測試樣本的距離,找到其中最近的K個樣本對應(yīng)的類別,并統(tǒng)計每個類出現(xiàn)的次數(shù),出現(xiàn)的最多的類即使該測試樣本對應(yīng)的類。

Python代碼

??這個代碼就是隨便寫寫了,我只用了歐氏距離。其實python有自帶的KNN的函數(shù),所以還是推薦大家直接調(diào)用函數(shù)吧。代碼若有錯誤望指正。

    def knn(trn,label,x,k):
    import numpy as np
    dist=np.zeros(numsamples)
    for i in range(numsamples):
        dist[i] = np.sqrt(np.sum(np.square(trn[i] - x)))
    print (dist)

    index=np.argsort(dist)
    print (index)

    numclass=np.max(label)+1
    times=np.zeros(numclass)
    for i in range(k):
        times[label[index[i]]]=times[label[index[i]]]+1
        print(times)

    ind=np.argmax(times)
    return (label[ind])

測試:

    trn1=np.random.randn(10,2)
    trn2=np.random.randn(10,2)
    trn=np.concatenate((trn1,trn2))
    x=np.random.randn(1,2)
    label=[0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1]
    k=3
    result=knn(trn,label,x,k)
    print (resut)

matlab代碼

??matlab我把這個分成了兩個部分,計算距離和分類,因為可選的參數(shù)有點多。
首先是計算距離的函數(shù)(discompute.m):

    function [D]=discompute(X_trn,X_tst,method)
    %%%%%%%%%%%%%%%%%%%%%
    %   這個函數(shù)是來計算訓(xùn)練樣本和測試樣本之間的距離
    %   X_trn是訓(xùn)練樣本,X_tst是測試樣本,都是cell矩陣,一個cell存放一張圖片
    %   method有四個值可供選擇:chi-square(卡方距離)、manhatton(曼哈頓距離)、cosine(余弦距離)、euclidean(歐氏距離)
    %   返回值D即距離,是cell矩陣,一個cell存放一個測試樣本和所有訓(xùn)練樣本的距離
    %   對于D中的每個cell,行對應(yīng)著不同的塊,列對應(yīng)著不同的訓(xùn)練樣本
    
    if nargin<3 method='chi-square';end
    D=cell(length(X_tst),1);
    switch lower(method)
        case 'chi-square'
            for j=1:length(X_tst)
                x1=X_tst{j};
                for i=1:length(X_trn)
                    x2=X_trn{i};
                    A=(x1-x2).^2./(x1+x2);
                    in=find(isnan(A)==1);
                    A(in)=0;
                    D{j}(:,i)=sum(A,2);     % Chi-square distance
                end
            end
        case 'manhattan'
            for j=1:length(X_tst)
                x1=X_tst{j};
                for i=1:length(X_trn)
                    x2=X_trn{i};
                    D{j}(:,i)=sum(abs(x1-x2),2); %Mahatton distance
                end
            end
        case 'cosine'
            for j=1:length(X_tst)
                x1=X_tst{j};
    %              [x1]=normlizedata(x1,'1-norm');
                for i=1:length(X_trn)
                    x2=X_trn{i};
    %                    [x2]=normlizedata(x2,'1-norm');
                    D{j}(:,i)=1-(diag(x1*x2')./(sqrt(diag(x1*x1')).*sqrt(diag(x2*x2')))); % Cosine distance
                end
            end
        case 'euclidean'
            for j=1:length(X_tst)
                x1=X_tst{j};
                for i=1:length(X_trn)
                    x2=X_trn{i};
                    D{j}(:,i)=(sum(((x1-x2).^2),2)); %Euclidean distance
                end
            end
    end

然后是分類(disclassify.m):

`function [out]=distclassify(D,Y,method,Layer)
    %%%%%%%%%%%%%%%%%%
    %   這個函數(shù)是用來進行分類的
    %   參數(shù)D是距離(參見discompute.m),Y是所有訓(xùn)練樣本的標(biāo)簽
    %   method有四個值可選:vote(投票)、min_dist(最小值)、max_dist(最大值)和sum_dist(總和)。
    %   out返回所有測試樣本的標(biāo)簽,是一個行向量
    
    if nargin<2 || nargin>4
        help Distclassify
    else
        numclass=max(Y);
        numtest=length(D);
        if nargin<4 Layer=floor(log(numclass)/log(2));end 
        if nargin<3 method='vote';end
        switch lower(method)
            case 'vote'
                A=zeros(numtest,numclass);
                for i=1:length(D)
                    [~,d]=min(D{i}');
                    for j=1:numclass
                        A(i,j)=sum(Y(j)==d);
                    end
                end
                [~,out]=max(A');
            case 'min_dist'
                A=zeros(numtest,numclass);
                for i=1:numtest
                    A(i,:)=min(D{i});
                end
                [~,out]=min(A');
            case 'max_dist'
                A=zeros(numtest,numclass);
                for i=1:numtest
                    A(i,:)=max(D{i});
                end
                [~,out]=min(A');
            case 'sum_dist'
                A=zeros(numtest,numclass);
                for i=1:length(D)
                    A(i,:)=sum(D{i});
                end
                [~,out]=min(A');
          
        end
    end
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容