上一段實習的時候用spark手寫了一個tfidf,下面貼上代碼并和spark中的源碼進行比較。
輸入文本(demo):
文檔1:a b c d e f g
文檔2:a b c d e f
文檔3:a b c d e
文檔4:a b c d
文檔5:a b c
文檔6:a b
文檔7:a
輸出結果:

代碼分析
主要有以下幾個步驟:
- 讀取文件到JavaRDD<String>中
- mapToPair將每行文本映射為doc <標題 : 單詞[]>中,后者為分詞后的單詞數(shù)組
- mapValues獲取每個文檔的詞頻
- 將文檔數(shù)進行廣播,用于計算idf
- 類似于wordCount, 先將doc中的每個文本對應的去重單詞出現(xiàn)次數(shù)置為1,然后aggregateByKey統(tǒng)計每個單詞出現(xiàn)的文檔數(shù),用對應的求idf的公式,就可以求出idf了
- 將表示每個詞idf的RDD<map> collect到driver,再進行廣播,進行每個文檔的tfIdf計算
- 最后寫入輸出文件
和spark Mllib中tf-idf實現(xiàn)方法的對比
源碼中也是將tf計算和idf計算分隔開的,tf計算時也是用了HashMap但是使用了hash函數(shù)(hashcode取余numfeatures)將詞映射到了一個int作為Key.在計算idf時每個文檔使用了一個詞語大小的向量來保存每個詞是否出現(xiàn)過,累加這些向量就得到了整個數(shù)據(jù)集中每個詞語出現(xiàn)的文檔數(shù),即IDF,再利用公式計算,不過源碼中使用的是log即以e為底而不是以10為底。
源碼中也是用廣播的形式將TF和IDF聯(lián)系起來
public class GenerateTags {
public static void main(String[] args) throws IOException{
SparkConf conf = new SparkConf().setMaster("local").setAppName("test");
// SparkConf conf = new SparkConf().setAppName("video-tags");
JavaSparkContext sc = new JavaSparkContext(conf);
System.setProperty("hadoop.home.dir", "D:\\winutils");
JavaRDD<String> lines = sc.textFile("C:\\Users\\YANGXIN\\Desktop\\test.txt");
//得到每個文檔標題和對應的詞串
JavaPairRDD<String, String[]> docs = lines.mapToPair(new PairFunction<String, String, String[]>() {
@Override
public Tuple2<String, String[]> call(String s) throws Exception {
String[] doc = s.split(":");
String title = doc[0];
String[] words = doc[1].split(" ");
return new Tuple2<String, String[]>(title, words);
}
});
//得到每個文檔的詞頻
JavaPairRDD<String, Map<String, Double>> docTF = docs.mapValues(new Function<String[], Map<String, Double>>() {
@Override
public Map<String, Double> call(String[] strings) throws Exception {
Map<String, Double> map = new HashMap<String, Double>();
int sum = strings.length;
for(String str : strings){
double cnt = map.containsKey(str) ? map.get(str) : 1;
map.put(str, cnt);
}
for(String str : map.keySet()){
map.replace(str, map.get(str) / sum);
}
return map;
}
});
//文檔數(shù)
final Broadcast<Long> docCnt = sc.broadcast(docs.count());
//得到每個詞的idf值
JavaPairRDD<String, Integer> ones = docs.flatMapToPair(new PairFlatMapFunction<Tuple2<String, String[]>, String, Integer>() {
@Override
public Iterable<Tuple2<String, Integer>> call(Tuple2<String, String[]> stringTuple2) throws Exception {
List<Tuple2<String, Integer>> list = new ArrayList<Tuple2<String, Integer>>();
Set<String> set = new HashSet<String>();
for(String str : stringTuple2._2()){
set.add(str);
}
for(String str : set){
list.add(new Tuple2<>(str, 1));
}
return list;
}
});
//每個單詞在多少個文檔中出現(xiàn)了
JavaPairRDD<String, Integer> wordDocCnt= ones.aggregateByKey(0, new Function2<Integer, Integer, Integer>() {
@Override
public Integer call(Integer integer, Integer integer2) throws Exception { //同partition下的處理
return integer + integer2;
}
}, new Function2<Integer, Integer, Integer>() {
@Override
public Integer call(Integer integer, Integer integer2) throws Exception { //不同partition下的處理
return integer + integer2;
}
});
JavaPairRDD<String, Double> wordIdf = wordDocCnt.mapValues(new Function<Integer, Double>() {
@Override
public Double call(Integer integer) throws Exception {
return Math.log10((docCnt.getValue() + 1) * 1.0 / (integer + 1)); //計算逆文檔頻率
}
});
//廣播idf值,進行tf-idf計算
Map<String, Double> idfs = wordIdf.collectAsMap();
final Broadcast<Map<String, Double>> idfMap = sc.broadcast(idfs);
//計算每個文檔的tf-idf向量
JavaPairRDD<String, TreeMap<Double, String>> TfIdf = docTF.mapValues(new Function<Map<String, Double>, TreeMap<Double, String>>() {
@Override
public TreeMap<Double, String> call(Map<String, Double> stringDoubleMap) throws Exception {
TreeMap<Double, String> map = new TreeMap<Double, String>();
for(Map.Entry<String, Double> entry : stringDoubleMap.entrySet()){
String word = entry.getKey();
Double tf = entry.getValue();
Double idf = idfMap.getValue().get(word);
map.put(tf * idf, word);
}
return map;
}
});
TfIdf.saveAsTextFile("C:\\Users\\YANGXIN\\Desktop\\result");
}