最近因為項目需求，需要在 Weka 上實現流形距離計算。由於 Weka 沒有提供流形學習的包，而 Smile 提供了，於是我參照 Smile 中等距映射（Isomap）的實現，重寫了一個可在 Weka 上使用的流形距離計算類。
歐氏距離是最常用的距離度量，但當數據集不具有全局線性結構時，歐氏距離就不再是一種合理的距離度量，此時一般借助數據的拓撲流形結構來度量高維的非線性數據。這類方法通常用於對數據進行降維，也被稱為流形學習。
定義1:
流形上兩點 $x_1, x_2$ 之間的線段長度定義為 $L(x_1, x_2) = e^{d(x_1, x_2)/\sigma} - 1$，其中 $d(x_1, x_2)$ 是兩點間的歐氏距離，$\sigma > 0$ 是尺度參數。
定義2:
將數據點看作無向帶權圖 $G=(V, E)$ 的頂點，$V$ 是頂點集合，$E$ 是邊的集合。令 $P_{ij}$ 表示圖上連接數據點 $x_i$ 與 $x_j$ 的所有路徑的集合，則 $x_i$ 與 $x_j$ 的流形距離為 $MD(x_i, x_j) = \min_{p \in P_{ij}} \sum_{k=1}^{|p|-1} L(p_k, p_{k+1})$，其中 $p_k$ 表示路徑 $p$ 上的第 $k$ 個頂點，$|p|$ 為路徑上的頂點數。
算法流程:
for i = 1,2,3...m do
確定 x_i 的 k 個最近鄰（不含自身）
將 x_i 與這 k 個最近鄰的距離設為定義1中的線段長度 L，與自身的距離設為 0，與其他點的距離設為 -1（表示不相連）
將這些數值填入鄰接矩陣
end
根據鄰接矩陣構建一個無向帶權圖對象
使用 Dijkstra 最短路徑算法求出圖上任意兩點間的最短距離
ManifoldDistance.java
import weka.core.EuclideanDistance;
import weka.core.Instances;
import java.util.*;
/**
 * Manifold (geodesic) distance between the instances of a Weka data set.
 *
 * <p>Every instance is linked to its {@code k} nearest neighbours under Weka's
 * {@link EuclideanDistance}; the length of such a link is {@code exp(d / sigma) - 1}, where
 * {@code d} is the Euclidean distance. The manifold distance between two instances is the
 * length of the shortest path between them in the resulting undirected weighted graph,
 * found with Dijkstra's algorithm.
 *
 * <p>Call {@link #build()} before {@link #getDistance(String, String)}. Instances are
 * addressed by their zero-based index in {@code data}, written as a decimal string.
 */
public class ManifoldDistance {
    /** Matrix entry meaning "these two instances are not directly linked". */
    private static final double NO_EDGE = -1.0;

    private final Instances data;
    private final int k;
    private final double sigma;
    /** Symmetric adjacency matrix computed by {@link #build()}; {@code null} until then. */
    private double[][] matrix;
    private Graph graph = new Graph();

    /**
     * Creates a manifold-distance calculator.
     *
     * @param data  the data set whose instances become the graph vertices
     * @param k     number of nearest neighbours each instance is linked to; must be at least 1
     * @param sigma the scale parameter σ of the link-length formula; must be positive
     * @throws NullPointerException     if {@code data} is null
     * @throws IllegalArgumentException if {@code k < 1} or {@code sigma} is not positive
     */
    public ManifoldDistance(Instances data, int k, double sigma) {
        this.data = Objects.requireNonNull(data, "data");
        if (k < 1) {
            throw new IllegalArgumentException("k must be at least 1, got " + k);
        }
        // Written as !(sigma > 0) so that NaN is rejected as well.
        if (!(sigma > 0)) {
            throw new IllegalArgumentException("sigma must be positive, got " + sigma);
        }
        this.k = k;
        this.sigma = sigma;
    }

    /** @return the data set this calculator was created with */
    public Instances getData() {
        return data;
    }

    /** @return the number of nearest neighbours each instance is linked to */
    public int getK() {
        return k;
    }

    /** @return the scale parameter σ of the link-length formula */
    public double getSigma() {
        return sigma;
    }

    /**
     * Returns the adjacency matrix computed by {@link #build()} (the internal array, not a
     * copy), or {@code null} if {@code build()} has not been called yet. Entry {@code [i][j]}
     * is 0 on the diagonal, the link length {@code exp(d / sigma) - 1} for linked pairs (which
     * is 0 for duplicate instances), and -1 for pairs that are not directly linked.
     *
     * @return the adjacency matrix, or {@code null} before {@link #build()}
     */
    public double[][] getMatrix() {
        return matrix;
    }

    /**
     * Builds the symmetric kNN adjacency matrix of {@link #data}.
     *
     * @return an {@code n x n} matrix as described in {@link #getMatrix()}
     */
    private double[][] constructWeightMatrix() {
        int num = this.data.numInstances();
        double[][] weightMatrix = new double[num][num];
        // Start from "nothing is linked"; an explicit sentinel keeps a genuine zero-length
        // link (duplicate instances) distinguishable from an unset entry.
        for (int i = 0; i < num; i++) {
            Arrays.fill(weightMatrix[i], NO_EDGE);
            weightMatrix[i][i] = 0.0;
        }
        EuclideanDistance euclidean = new EuclideanDistance(this.data);
        for (int i = 0; i < num; i++) {
            // Link length from instance i to every OTHER instance; excluding i itself keeps it
            // from occupying one of the k neighbour slots.
            Map<Integer, Double> lengths = new HashMap<>();
            for (int j = 0; j < num; j++) {
                if (j != i) {
                    double dist = euclidean.distance(this.data.instance(i), this.data.instance(j));
                    lengths.put(j, Math.exp(dist / this.sigma) - 1);
                }
            }
            // Write both [i][n] and [n][i] so the matrix (and hence the graph) is symmetric:
            // a pair is linked if either instance is among the other's k nearest neighbours.
            for (int n : nearestNeighbor(lengths)) {
                double length = lengths.get(n);
                weightMatrix[i][n] = length;
                weightMatrix[n][i] = length;
            }
        }
        return weightMatrix;
    }

    /**
     * Selects the {@code k} closest candidates.
     *
     * @param lengths link length from the current instance to each candidate, keyed by index
     * @return indices of the (at most) {@code k} candidates with the smallest link length,
     *         closest first
     */
    private List<Integer> nearestNeighbor(Map<Integer, Double> lengths) {
        List<Map.Entry<Integer, Double>> candidates = new ArrayList<>(lengths.entrySet());
        // Ascending by length so the nearest points come first; ties broken by index so the
        // selection does not depend on map iteration order.
        candidates.sort((a, b) -> {
            int byLength = Double.compare(a.getValue(), b.getValue());
            return byLength != 0 ? byLength : Integer.compare(a.getKey(), b.getKey());
        });
        int count = Math.min(this.k, candidates.size());
        List<Integer> neighbours = new ArrayList<>(count);
        for (int i = 0; i < count; i++) {
            neighbours.add(candidates.get(i).getKey());
        }
        return neighbours;
    }

    /**
     * Computes the adjacency matrix and (re)builds the undirected weighted graph from it.
     * Performs {@code n^2} Euclidean distance evaluations for {@code n} instances.
     */
    public void build() {
        this.matrix = constructWeightMatrix();
        int num = this.matrix.length;
        // Fresh graph so that calling build() again does not keep stale vertices.
        this.graph = new Graph();
        for (int i = 0; i < num; i++) {
            List<Vertex> neighbours = new ArrayList<>();
            for (int j = 0; j < num; j++) {
                // >= 0 keeps zero-length links between duplicate instances; the matrix is
                // symmetric, so each edge ends up exactly once in each endpoint's list.
                if (i != j && this.matrix[i][j] >= 0) {
                    neighbours.add(new Vertex(Integer.toString(j), this.matrix[i][j]));
                }
            }
            this.graph.addVertex(Integer.toString(i), neighbours);
        }
    }

    /**
     * Returns the manifold distance between two instances, i.e. the length of the shortest
     * path between them in the kNN graph, and prints that path to standard output.
     *
     * @param start index of the first instance as a decimal string, e.g. {@code "10"}
     * @param end   index of the second instance as a decimal string
     * @return the shortest-path length; 0 when both ids denote the same instance;
     *         {@link Double#POSITIVE_INFINITY} when no path exists (the kNN graph may be
     *         disconnected)
     * @throws IllegalStateException    if {@link #build()} has not been called
     * @throws IllegalArgumentException if either id is not a valid instance index
     */
    public double getDistance(String start, String end) {
        if (this.matrix == null) {
            throw new IllegalStateException("build() must be called before getDistance()");
        }
        int from = parseIndex(start);
        int to = parseIndex(end);
        // Normalise to the exact keys registered by build() (e.g. "010" -> "10").
        String startKey = Integer.toString(from);
        String endKey = Integer.toString(to);
        // The graph returns the path from end back towards start, excluding start itself; an
        // empty result for two different instances means end is unreachable. Copy it so a
        // list owned by the graph is never mutated here.
        List<String> path = new ArrayList<>(this.graph.getShortestPath(startKey, endKey));
        if (path.isEmpty() && from != to) {
            return Double.POSITIVE_INFINITY;
        }
        path.add(startKey);
        Collections.reverse(path);
        double mDist = 0.0;
        for (int i = 0; i + 1 < path.size(); i++) {
            int m = Integer.parseInt(path.get(i));
            int n = Integer.parseInt(path.get(i + 1));
            mDist += this.matrix[m][n];
        }
        System.out.println("shortest path:" + path);
        return mDist;
    }

    /**
     * Parses an instance id and checks that it lies in {@code [0, n)}.
     *
     * @param id decimal instance index
     * @return the parsed index
     * @throws IllegalArgumentException if {@code id} is null, not an integer, or out of range
     *         ({@link NumberFormatException} is a subclass)
     */
    private int parseIndex(String id) {
        int index = Integer.parseInt(id);
        if (index < 0 || index >= this.matrix.length) {
            throw new IllegalArgumentException(
                    "instance index out of range [0, " + this.matrix.length + "): " + id);
        }
        return index;
    }
}
Graph.java
import java.util.*;
/**
 * Weighted graph stored as adjacency lists keyed by vertex id.
 *
 * <p>Each list entry is a {@link Vertex} whose id is the neighbour and whose distance is the
 * edge length. Edges are stored exactly as given to {@link #addVertex(String, List)}; an
 * undirected graph is obtained by listing each edge under both endpoints.
 */
class Graph {
    private final Map<String, List<Vertex>> vertices;

    /** Creates an empty graph. */
    public Graph() {
        this.vertices = new HashMap<>();
    }

    /**
     * Registers a vertex together with its outgoing edges, replacing any previous entry for
     * the same id.
     *
     * @param character the vertex id
     * @param vertex    the outgoing edges: neighbour id and edge length per entry
     */
    public void addVertex(String character, List<Vertex> vertex) {
        this.vertices.put(character, vertex);
    }

    /**
     * Runs Dijkstra's algorithm from {@code start} and returns the shortest path to
     * {@code finish}. Edge lengths are assumed non-negative, as Dijkstra's algorithm requires.
     *
     * @param start  id of the source vertex
     * @param finish id of the target vertex
     * @return the vertex ids of the path in reverse order: {@code finish} first, then its
     *         predecessors, stopping just before {@code start} (which is not included). The list
     *         is empty when {@code start} equals {@code finish}, when {@code finish} is not
     *         reachable, or when either id is not a registered vertex. The returned list is a
     *         new, mutable list.
     */
    public List<String> getShortestPath(String start, String finish) {
        if (!vertices.containsKey(start) || !vertices.containsKey(finish)) {
            return new ArrayList<>();
        }
        final Map<String, Double> distances = new HashMap<>();
        final Map<String, String> previous = new HashMap<>();
        final Set<String> settled = new HashSet<>();
        // Lazy-deletion priority queue: instead of a decrease-key (which PriorityQueue lacks),
        // a fresh entry is pushed on every improvement and outdated entries are skipped when
        // polled.
        final PriorityQueue<Vertex> queue = new PriorityQueue<>();
        distances.put(start, 0.0);
        queue.add(new Vertex(start, 0.0));

        while (!queue.isEmpty()) {
            Vertex current = queue.poll();
            String id = current.getId();
            if (!settled.add(id)) {
                continue; // stale entry: this vertex was already settled with a shorter distance
            }
            if (id.equals(finish)) {
                final List<String> path = new ArrayList<>();
                // start never receives a predecessor, so the walk stops right before it.
                for (String step = finish; previous.containsKey(step); step = previous.get(step)) {
                    path.add(step);
                }
                return path;
            }
            double base = current.getDistance();
            for (Vertex edge : vertices.getOrDefault(id, Collections.<Vertex>emptyList())) {
                String next = edge.getId();
                if (settled.contains(next)) {
                    continue;
                }
                double candidate = base + edge.getDistance();
                Double known = distances.get(next);
                if (known == null || candidate < known) {
                    distances.put(next, candidate);
                    previous.put(next, id);
                    queue.add(new Vertex(next, candidate));
                }
            }
        }
        // Queue exhausted without settling finish: it is unreachable from start.
        return new ArrayList<>();
    }
}
Vertex.java
/**
* Created by Administrator on 2017/3/14.
*/
class Vertex implements Comparable<Vertex> {
private String id;
private Double distance;
public Vertex(String id, Double distance) {
super();
this.id = id;
this.distance = distance;
}
public String getId() {
return id;
}
public Double getDistance() {
return distance;
}
public void setId(String id) {
this.id = id;
}
public void setDistance(Double distance) {
this.distance = distance;
}
@Override
public int hashCode() {
final int prime = 31;
int result = 1;
result = prime * result
+ ((distance == null) ? 0 : distance.hashCode());
result = prime * result + ((id == null) ? 0 : id.hashCode());
return result;
}
@Override
public boolean equals(Object obj) {
if (this == obj)
return true;
if (obj == null)
return false;
if (getClass() != obj.getClass())
return false;
Vertex other = (Vertex) obj;
if (distance == null) {
if (other.distance != null)
return false;
} else if (!distance.equals(other.distance))
return false;
if (id == null) {
if (other.id != null)
return false;
} else if (!id.equals(other.id))
return false;
return true;
}
@Override
public String toString() {
return "Vertex [id=" + id + ", distance=" + distance + "]";
}
@Override
public int compareTo(Vertex o) {
if (this.distance < o.distance)
return -1;
else if (this.distance > o.distance)
return 1;
else
return this.getId().compareTo(o.getId());
}
}
Demo.java
import weka.core.Instances;
import java.io.FileReader;
import java.io.IOException;
/**
 * Command-line demo of {@link ManifoldDistance}.
 *
 * <p>Loads an ARFF file, builds the kNN graph with {@code k = 20} and {@code sigma = 2},
 * prints the adjacency matrix row by row, then prints the manifold distance between
 * instances 10 and 71 in both directions. The data set therefore needs at least 72 instances.
 */
public class Demo {
    /**
     * @param args optional; {@code args[0]} is the ARFF file to load
     *             (default {@code Test/Manifold/cpu.arff})
     * @throws IOException if the file cannot be read
     */
    public static void main(String[] args) throws IOException {
        String arffPath = args.length > 0 ? args[0] : "Test/Manifold/cpu.arff";
        Instances data;
        // try-with-resources so the file handle is released even if parsing fails.
        try (FileReader reader = new FileReader(arffPath)) {
            data = new Instances(reader);
        }
        ManifoldDistance manifold = new ManifoldDistance(data, 20, 2);
        manifold.build();
        for (double[] row : manifold.getMatrix()) {
            for (double value : row) {
                System.out.print(value + " ");
            }
            System.out.println();
        }
        // Both directions are printed; the graph is built symmetrically, so they are
        // expected to agree.
        System.out.println(manifold.getDistance("10", "71"));
        System.out.println(manifold.getDistance("71", "10"));
    }
}