Implementing the Isomap Manifold Distance Computation in Weka

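Recently a project required me to compute manifold distances in Weka. Weka does not ship a manifold-learning package, but smile does, so I used smile's isometric mapping (Isomap) as the basis for a manifold distance class that can be used with Weka.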

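Euclidean distance is the most commonly used distance measure. However, when a dataset has no global linear structure, Euclidean distance no longer reflects how far apart the data points really are. In that case the topological manifold structure of the data is generally used to measure distances in high-dimensional nonlinear data. This approach is typically used for dimensionality reduction and is also known as manifold learning.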
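To make this concrete, here is a small illustrative sketch (not from the original post) that samples points along a semicircle of radius 1. It compares the straight-line Euclidean distance between the two endpoints with the length of a walk that only steps between neighboring samples. The class name `ArcVsChord` and the sample count of 100 are hypothetical choices, and plain Euclidean step lengths are used here only to show the geometry, not Definition 1 below.

```java
public class ArcVsChord {
    // Euclidean distance between two 2-D points.
    static double euclid(double[] a, double[] b) {
        return Math.hypot(a[0] - b[0], a[1] - b[1]);
    }

    // n + 1 evenly spaced points on the upper half of the unit circle.
    static double[][] semicircle(int n) {
        double[][] pts = new double[n + 1][];
        for (int i = 0; i <= n; i++) {
            double t = Math.PI * i / n; // angle from 0 to pi
            pts[i] = new double[]{Math.cos(t), Math.sin(t)};
        }
        return pts;
    }

    // Length of the walk that only steps between consecutive (neighboring) samples.
    static double pathLength(double[][] pts) {
        double total = 0.0;
        for (int i = 0; i + 1 < pts.length; i++) {
            total += euclid(pts[i], pts[i + 1]);
        }
        return total;
    }

    public static void main(String[] args) {
        double[][] pts = semicircle(100);
        System.out.println("straight line: " + euclid(pts[0], pts[pts.length - 1]));
        System.out.println("along the arc: " + pathLength(pts));
    }
}
```

The straight-line distance is about 2, while the walk along the curve is about 3.14 and approaches π as more samples are used. A distance that follows the data's neighborhoods captures this difference; plain Euclidean distance does not.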

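Definition 1:
The length of the segment between two points $x_1, x_2$ on the manifold is defined as $L(x_1, x_2) = \exp\left(d(x_1, x_2)/\sigma\right) - 1$, where $d(x_1, x_2)$ is the Euclidean distance between the two points and $\sigma$ is a scale parameter.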
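As a quick check of Definition 1, the sketch below computes $L$ for two raw feature vectors. It is illustrative only: the class and method names are hypothetical, and it uses unnormalized Euclidean distance, whereas Weka's `EuclideanDistance` (used in the code further down) normalizes attribute ranges by default, so its numbers will differ. It uses `Math.expm1`, which computes $e^x - 1$ and is more accurate than `Math.exp(x) - 1` when $x$ is small.

```java
public class EdgeLength {
    // Plain (unnormalized) Euclidean distance between two vectors of equal length.
    static double euclidean(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double diff = a[i] - b[i];
            sum += diff * diff;
        }
        return Math.sqrt(sum);
    }

    // Definition 1: L(x1, x2) = exp(d(x1, x2) / sigma) - 1.
    static double segmentLength(double[] x1, double[] x2, double sigma) {
        return Math.expm1(euclidean(x1, x2) / sigma);
    }

    public static void main(String[] args) {
        double[] a = {0.0, 0.0};
        double[] b = {3.0, 4.0}; // Euclidean distance 5
        System.out.println(segmentLength(a, b, 5.0)); // e^1 - 1
    }
}
```

With $d = 5$ and $\sigma = 5$ the result is $e - 1 \approx 1.718$. Note that $L(x, x) = 0$ and that $L$ grows exponentially with $d$, so a smaller $\sigma$ penalizes long jumps more heavily.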
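Definition 2:
Treat the data points as a weighted undirected graph $G = (V, E)$, where $V$ is the vertex set and $E$ is the edge set. Let $P_{ij}$ denote the set of all paths on the graph between the data points $x_i$ and $x_j$. The manifold distance between $x_i$ and $x_j$ is

$$MD(x_i, x_j) = \min_{p \in P_{ij}} \sum_{k=1}^{|p|-1} L(p_k, p_{k+1})$$

where $p_k$ is the $k$-th vertex on path $p$ and $|p|$ is the number of vertices on that path.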
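Because $L$ grows exponentially with $d$, a path made of several short hops can be shorter than one long jump, which is what lets $MD$ follow the manifold. The hypothetical sketch below evaluates Definition 2 by brute force over a hand-picked list of candidate paths between the 1-D points 0 and 2 with $\sigma = 1$: the direct hop costs $e^2 - 1 \approx 6.39$, while going through the point at 1 costs $2(e - 1) \approx 3.44$, so the minimum is the two-hop path. Enumerating every path is only feasible on toy examples, which is why the algorithm below uses Dijkstra's algorithm instead.

```java
import java.util.Arrays;
import java.util.List;

public class PathLength {
    // Definition 1 for 1-D points: L(a, b) = exp(|a - b| / sigma) - 1.
    static double segment(double a, double b, double sigma) {
        return Math.expm1(Math.abs(a - b) / sigma);
    }

    // Sum of L(p_k, p_{k+1}) over consecutive vertices of one path.
    static double pathLength(double[] path, double sigma) {
        double total = 0.0;
        for (int k = 0; k + 1 < path.length; k++) {
            total += segment(path[k], path[k + 1], sigma);
        }
        return total;
    }

    // Definition 2, restricted to an explicit list of candidate paths: take the minimum.
    static double manifoldDistance(List<double[]> candidatePaths, double sigma) {
        double best = Double.POSITIVE_INFINITY;
        for (double[] p : candidatePaths) {
            best = Math.min(best, pathLength(p, sigma));
        }
        return best;
    }

    public static void main(String[] args) {
        double sigma = 1.0;
        double[] direct = {0.0, 2.0};         // jump straight from 0 to 2
        double[] viaMiddle = {0.0, 1.0, 2.0}; // step through the point at 1
        System.out.println("direct:     " + pathLength(direct, sigma));
        System.out.println("via middle: " + pathLength(viaMiddle, sigma));
        System.out.println("MD:         " + manifoldDistance(Arrays.asList(direct, viaMiddle), sigma));
    }
}
```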

Algorithm:

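for i = 1, 2, 3, ..., m do
    find the k nearest neighbors of x_i
    set the distance from x_i to each of its k nearest neighbors with the formula of Definition 1, the distance to itself to 0, and the distance to every other point to -1 (no edge)
    add these values to the adjacency matrix
end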

Build a weighted undirected graph object from the adjacency matrix.
Use Dijkstra's algorithm to find the shortest distance between any two points on the graph.
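The steps above can be sketched in plain Java, without Weka or the Graph/Vertex classes below, which is handy for checking the logic on toy data. This is an illustrative sketch, not the author's code. Assumptions: points are given as `double[][]` rows of raw (unnormalized) features; the class and method names are hypothetical; an edge is kept if either point is among the other's k nearest neighbors (the matrix is written symmetrically); and the priority queue uses lazy deletion (stale entries are skipped) instead of the remove-and-reinsert decrease-key used in Graph.java. Running `dijkstra` once per source gives all pairwise manifold distances.

```java
import java.util.Arrays;
import java.util.PriorityQueue;

public class KnnDijkstraSketch {
    static double euclidean(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double diff = a[i] - b[i];
            sum += diff * diff;
        }
        return Math.sqrt(sum);
    }

    // Adjacency matrix: L(x_i, x_j) for k-NN pairs, 0 on the diagonal, -1 = no edge.
    static double[][] buildMatrix(double[][] x, int k, double sigma) {
        int m = x.length;
        double[][] w = new double[m][m];
        for (double[] row : w) {
            Arrays.fill(row, -1.0);
        }
        for (int i = 0; i < m; i++) {
            w[i][i] = 0.0;
            double[] d = new double[m];
            Integer[] order = new Integer[m];
            for (int j = 0; j < m; j++) {
                d[j] = euclidean(x[i], x[j]);
                order[j] = j;
            }
            // Sort indices by ascending distance from x_i.
            Arrays.sort(order, (a, b) -> Double.compare(d[a], d[b]));
            int taken = 0;
            for (int idx = 0; idx < m && taken < k; idx++) {
                int j = order[idx];
                if (j == i) {
                    continue; // a point is not its own neighbor
                }
                double len = Math.expm1(d[j] / sigma); // Definition 1
                w[i][j] = len;
                w[j][i] = len; // undirected graph: keep the matrix symmetric
                taken++;
            }
        }
        return w;
    }

    // Single-source Dijkstra over the matrix; entries < 0 mean "no edge".
    static double[] dijkstra(double[][] w, int source) {
        int m = w.length;
        double[] dist = new double[m];
        Arrays.fill(dist, Double.POSITIVE_INFINITY);
        dist[source] = 0.0;
        // Queue entries are {distance, vertex}.
        PriorityQueue<double[]> pq = new PriorityQueue<>((a, b) -> Double.compare(a[0], b[0]));
        pq.add(new double[]{0.0, source});
        while (!pq.isEmpty()) {
            double[] top = pq.poll();
            int u = (int) top[1];
            if (top[0] > dist[u]) {
                continue; // stale entry: a shorter distance was already found
            }
            for (int v = 0; v < m; v++) {
                if (v == u || w[u][v] < 0) {
                    continue;
                }
                double alt = dist[u] + w[u][v];
                if (alt < dist[v]) {
                    dist[v] = alt;
                    pq.add(new double[]{alt, v});
                }
            }
        }
        return dist; // unreachable vertices stay at +infinity
    }

    public static void main(String[] args) {
        double[][] points = {{0.0}, {1.0}, {2.0}, {3.0}};
        double[][] w = buildMatrix(points, 1, 1.0);
        // From point 0 the walk goes 0-1-2-3, costing 3 * (e - 1) instead of e^3 - 1.
        System.out.println(Arrays.toString(dijkstra(w, 0)));
    }
}
```

If the k-NN graph splits into several components, points in different components get an infinite manifold distance, so k has to be large enough to keep the graph connected.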

ManifoldDistance.java

import weka.core.EuclideanDistance;
import weka.core.Instances;

import java.util.*;

/**
 * Created by Administrator on 2017/3/15.
 */
public class ManifoldDistance {
    private final Instances data;
    private final int k;
    private final double sigma;
    private double[][] matrix;
    private Graph graph = new Graph();

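    /**
     * Constructor of the manifold-learning distance class.
     *
     * @param data  the dataset (Weka Instances) to compute distances on
     * @param k     the number of nearest neighbors (k of KNN) used to build the graph
     * @param sigma the parameter σ of the segment-length formula
     */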
    public ManifoldDistance(Instances data, int k, double sigma) {
        this.data = data;
        this.k = k;
        this.sigma = sigma;
    }

    public Instances getData() {
        return data;
    }

    public int getK() {
        return k;
    }

    public double getSigma() {
        return sigma;
    }

    public double[][] getMatrix() {
        return matrix;
    }

    /**
     * Builds the adjacency matrix of data: L(x_i, x_j) for k-NN pairs,
     * 0 on the diagonal and -1 where there is no edge.
     *
     * @return      the adjacency matrix as double[][]
     */
    private double[][] constructWeightMatrix() {
        int num = this.data.numInstances();
        double[][] weight_matrix = new double[num][num];
        // Start with "no edge" everywhere, so a genuine length of 0.0 (duplicate points)
        // is never confused with a missing edge.
        for (double[] row : weight_matrix) {
            Arrays.fill(row, -1.0);
        }
        EuclideanDistance calculateDistance = new EuclideanDistance(this.data);

        for (int i = 0; i < num; i++) {
            weight_matrix[i][i] = 0.0;
            // Segment length L(x_i, x_j) = exp(d / sigma) - 1 to every other instance.
            HashMap<Integer, Double> temp = new HashMap<>();
            for (int j = 0; j < num; j++) {
                if (i != j) {
                    double dist = calculateDistance.distance(this.data.instance(i), this.data.instance(j));
                    temp.put(j, Math.exp(dist / this.sigma) - 1);
                }
            }

            // Keep only the k nearest neighbors; write both [i][n] and [n][i]
            // so that the graph is undirected.
            for (int n : nearestNeighbor(temp)) {
                weight_matrix[i][n] = temp.get(n);
                weight_matrix[n][i] = temp.get(n);
            }
        }
        return weight_matrix;
    }

    /**
     * Finds the k nearest neighbors.
     *
     * @param temp  segment lengths from the current instance i to all other instances (i excluded)
     * @return      indices of the k nearest neighbors
     */
    private ArrayList<Integer> nearestNeighbor(HashMap<Integer, Double> temp){
        ArrayList<Integer> index = new ArrayList<>();
        ArrayList<Map.Entry<Integer, Double>> list = new ArrayList<>(temp.entrySet());
        // Ascending order: the smallest lengths belong to the nearest neighbors.
        list.sort((o1, o2) -> o1.getValue().compareTo(o2.getValue()));

        for (Map.Entry<Integer, Double> entry : list) {
            if (index.size() >= this.k) {
                break;
            }
            index.add(entry.getKey());
        }
        return index;
    }

    /**
     * Builds the adjacency matrix and the corresponding weighted undirected graph.
     */
    public void build(){
        this.matrix = constructWeightMatrix();

        int num = this.matrix.length;

        HashMap<String, List<Vertex>> edge = new HashMap<>();
        for (int i = 0; i < num; i++){
            edge.put(Integer.toString(i), new ArrayList<>());
        }

        // The matrix is symmetric, so visit each pair once (j > i) and add the edge in both
        // directions. Off-diagonal entries >= 0 are edges; -1 means "not connected".
        for (int i = 0; i < num; i++){
            for (int j = i + 1; j < num; j++){
                if (this.matrix[i][j] >= 0){
                    edge.get(Integer.toString(i)).add(new Vertex(Integer.toString(j), this.matrix[i][j]));
                    edge.get(Integer.toString(j)).add(new Vertex(Integer.toString(i), this.matrix[i][j]));
                }
            }
        }

        for (Map.Entry<String, List<Vertex>> e : edge.entrySet()){
            this.graph.addVertex(e.getKey(), e.getValue());
        }
    }

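    /**
     * Returns the Dijkstra shortest (manifold) distance between two instances on the graph.
     *
     * @param start     index of the start instance, as a string
     * @param end       index of the end instance, as a string
     * @return          the shortest distance, or Double.POSITIVE_INFINITY if end is unreachable
     */
    public double getDistance(String start, String end){
        List<String> path = this.graph.getShortestPath(start, end);
        path.add(start);
        Collections.reverse(path);

        // Sum the edge lengths along the path; a missing edge (-1) means no valid path exists.
        double mDist = 0.0;
        for (int i = 0; i < path.size() - 1; i++){
            int m = Integer.parseInt(path.get(i));
            int n = Integer.parseInt(path.get(i + 1));
            if (this.matrix[m][n] < 0) {
                return Double.POSITIVE_INFINITY;
            }
            mDist += this.matrix[m][n];
        }
        if (!path.get(path.size() - 1).equals(end)) {
            return Double.POSITIVE_INFINITY; // the search never reached end
        }

        System.out.println("shortest path:" + path);
        return mDist;
    }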
}

Graph.java

import java.util.*;

/**
 * Created by Administrator on 2017/3/14.
 */

class Graph {

    private final Map<String, List<Vertex>> vertices;

    public Graph() {
        this.vertices = new HashMap<>();
    }

    public void addVertex(String character, List<Vertex> vertex) {
        this.vertices.put(character, vertex);
    }

    /**
     * Dijkstra's shortest path from start to finish.
     *
     * @return the vertex ids from finish back towards start (start itself excluded);
     *         an empty list if start equals finish or if finish is unreachable
     */
    public List<String> getShortestPath(String start, String finish) {
        final Map<String, Double> distances = new HashMap<>();
        final Map<String, Vertex> previous = new HashMap<>();
        PriorityQueue<Vertex> nodes = new PriorityQueue<>();

        for(String vertex : vertices.keySet()) {
            double init = Objects.equals(vertex, start) ? 0.0 : Double.POSITIVE_INFINITY;
            distances.put(vertex, init);
            nodes.add(new Vertex(vertex, init));
            previous.put(vertex, null);
        }

        while (!nodes.isEmpty()) {
            Vertex smallest = nodes.poll();
            if (Objects.equals(smallest.getId(), finish)) {
                // Walk the predecessor chain back towards start.
                final List<String> path = new ArrayList<>();
                while (previous.get(smallest.getId()) != null) {
                    path.add(smallest.getId());
                    smallest = previous.get(smallest.getId());
                }
                return path;
            }

            // All remaining vertices are unreachable from start.
            if (distances.get(smallest.getId()).isInfinite()) {
                break;
            }

            for (Vertex neighbor : vertices.get(smallest.getId())) {
                double alt = distances.get(smallest.getId()) + neighbor.getDistance();
                if (alt < distances.get(neighbor.getId())) {
                    distances.put(neighbor.getId(), alt);
                    previous.put(neighbor.getId(), smallest);

                    // Decrease-key: remove the queued entry, update it, and re-insert it.
                    for(Vertex n : nodes) {
                        if (Objects.equals(n.getId(), neighbor.getId())) {
                            nodes.remove(n);
                            n.setDistance(alt);
                            nodes.add(n);
                            break;
                        }
                    }
                }
            }
        }
        // finish was never reached.
        return new ArrayList<>();
    }
}

Vertex.java

/**
 * Created by Administrator on 2017/3/14.
 */

class Vertex implements Comparable<Vertex> {

    private String id;
    private Double distance;

    public Vertex(String id, Double distance) {
        super();
        this.id = id;
        this.distance = distance;
    }

    public String getId() {
        return id;
    }

    public Double getDistance() {
        return distance;
    }

    public void setId(String id) {
        this.id = id;
    }

    public void setDistance(Double distance) {
        this.distance = distance;
    }

    @Override
    public int hashCode() {
        final int prime = 31;
        int result = 1;
        result = prime * result
                + ((distance == null) ? 0 : distance.hashCode());
        result = prime * result + ((id == null) ? 0 : id.hashCode());
        return result;
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj)
            return true;
        if (obj == null)
            return false;
        if (getClass() != obj.getClass())
            return false;
        Vertex other = (Vertex) obj;
        if (distance == null) {
            if (other.distance != null)
                return false;
        } else if (!distance.equals(other.distance))
            return false;
        if (id == null) {
            if (other.id != null)
                return false;
        } else if (!id.equals(other.id))
            return false;
        return true;
    }

    @Override
    public String toString() {
        return "Vertex [id=" + id + ", distance=" + distance + "]";
    }

    @Override
    public int compareTo(Vertex o) {
        if (this.distance < o.distance)
            return -1;
        else if (this.distance > o.distance)
            return 1;
        else
            return this.getId().compareTo(o.getId());
    }

}

Demo.java

import weka.core.Instances;

import java.io.FileReader;
import java.io.IOException;

/**
 * Created by Administrator on 2017/3/15.
 */
public class Demo {
    public static void main(String[] args) throws IOException {
        Instances data = new Instances(new FileReader("Test/Manifold/cpu.arff"));
        ManifoldDistance manifold = new ManifoldDistance(data, 20, 2);
        manifold.build();
        for (double[] aMtx : manifold.getMatrix()) {
            for(double v : aMtx){
                System.out.print(v + "   ");
            }
            System.out.println();
        }

        System.out.println(manifold.getDistance("10", "71"));
        System.out.println(manifold.getDistance("71", "10"));
    }
}
© Copyright belongs to the author. For reprints or content cooperation, please contact the author.
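Platform statement: the content of this article (including any images or videos) was uploaded and published by the author and represents only the author's views. Jianshu is an information publishing platform and only provides information storage services.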
