基于spark實現(xiàn)TFIDF

上一段實習的時候用spark手寫了一個tfidf,下面貼上代碼并和spark中的源碼進行比較。
輸入文本(demo):

文檔1:a b c d e f g
文檔2:a b c d e f 
文檔3:a b c d e
文檔4:a b c d
文檔5:a b c 
文檔6:a b 
文檔7:a

輸出結果:

代碼分析
主要有以下幾個步驟:

  1. 讀取文件到JavaRDD<String>中
  2. mapToPair將每行文本映射為doc <標題 : 單詞[]>中,后者為分詞后的單詞數(shù)組
  3. mapValues獲取每個文檔的詞頻
  4. 將文檔數(shù)進行廣播,用于計算idf
  5. 類似于wordCount, 先將doc中的每個文本對應的去重單詞出現(xiàn)次數(shù)置為1,然后aggregateByKey統(tǒng)計每個單詞出現(xiàn)的文檔數(shù),用對應的求idf的公式,就可以求出idf了
  6. 將表示每個詞idf的RDD<map> collect到driver,再進行廣播,進行每個文檔的tfIdf計算
  7. 最后寫入輸出文件

和spark Mllib中tf-idf實現(xiàn)方法的對比
源碼中也是將tf計算和idf計算分隔開的,tf計算時也是用了HashMap但是使用了hash函數(shù)(hashcode取余numfeatures)將詞映射到了一個int作為Key.在計算idf時每個文檔使用了一個詞語大小的向量來保存每個詞是否出現(xiàn)過,累加這些向量就得到了整個數(shù)據(jù)集中每個詞語出現(xiàn)的文檔數(shù),即IDF,再利用公式計算,不過源碼中使用的是log即以e為底而不是以10為底。

源碼中也是用廣播的形式將TF和IDF聯(lián)系起來

public class GenerateTags {

    public static void main(String[] args) throws IOException{
        SparkConf conf = new SparkConf().setMaster("local").setAppName("test");
//        SparkConf conf = new SparkConf().setAppName("video-tags");
        JavaSparkContext sc = new JavaSparkContext(conf);
        System.setProperty("hadoop.home.dir", "D:\\winutils");
        JavaRDD<String> lines = sc.textFile("C:\\Users\\YANGXIN\\Desktop\\test.txt");

        //得到每個文檔標題和對應的詞串
        JavaPairRDD<String, String[]> docs = lines.mapToPair(new PairFunction<String, String, String[]>() {
            @Override
            public Tuple2<String, String[]> call(String s) throws Exception {
                String[] doc = s.split(":");
                String title = doc[0];
                String[] words = doc[1].split(" ");
                return new Tuple2<String, String[]>(title, words);
            }
        });

        //得到每個文檔的詞頻
        JavaPairRDD<String, Map<String, Double>> docTF = docs.mapValues(new Function<String[], Map<String, Double>>() {
            @Override
            public Map<String, Double> call(String[] strings) throws Exception {
                Map<String, Double> map = new HashMap<String, Double>();
                int sum = strings.length;
                for(String str : strings){
                    double cnt = map.containsKey(str) ? map.get(str) : 1;
                    map.put(str, cnt);
                }
                for(String str : map.keySet()){
                    map.replace(str, map.get(str) / sum);
                }
                return map;
            }
        });

        //文檔數(shù)
        final Broadcast<Long> docCnt = sc.broadcast(docs.count());

        //得到每個詞的idf值
        JavaPairRDD<String, Integer> ones = docs.flatMapToPair(new PairFlatMapFunction<Tuple2<String, String[]>, String, Integer>() {
            @Override
            public Iterable<Tuple2<String, Integer>> call(Tuple2<String, String[]> stringTuple2) throws Exception {
                List<Tuple2<String, Integer>> list = new ArrayList<Tuple2<String, Integer>>();
                Set<String> set = new HashSet<String>();
                for(String str : stringTuple2._2()){
                    set.add(str);
                }
                for(String str : set){
                    list.add(new Tuple2<>(str, 1));
                }
                return list;
            }
        });

        //每個單詞在多少個文檔中出現(xiàn)了
        JavaPairRDD<String, Integer> wordDocCnt= ones.aggregateByKey(0, new Function2<Integer, Integer, Integer>() {
            @Override
            public Integer call(Integer integer, Integer integer2) throws Exception { //同partition下的處理
                return integer + integer2;
            }
        }, new Function2<Integer, Integer, Integer>() {
            @Override
            public Integer call(Integer integer, Integer integer2) throws Exception { //不同partition下的處理
                return integer + integer2;
            }
        });

        JavaPairRDD<String, Double> wordIdf = wordDocCnt.mapValues(new Function<Integer, Double>() {
            @Override
            public Double call(Integer integer) throws Exception {
                return Math.log10((docCnt.getValue() + 1) * 1.0 / (integer + 1));  //計算逆文檔頻率
            }
        });

        //廣播idf值,進行tf-idf計算
        Map<String, Double> idfs = wordIdf.collectAsMap();
        final Broadcast<Map<String, Double>> idfMap = sc.broadcast(idfs);

        //計算每個文檔的tf-idf向量
        JavaPairRDD<String, TreeMap<Double, String>> TfIdf = docTF.mapValues(new Function<Map<String, Double>, TreeMap<Double, String>>() {
            @Override
            public TreeMap<Double, String> call(Map<String, Double> stringDoubleMap) throws Exception {
                TreeMap<Double, String> map = new TreeMap<Double, String>();
                for(Map.Entry<String, Double> entry : stringDoubleMap.entrySet()){
                    String word = entry.getKey();
                    Double tf = entry.getValue();
                    Double idf = idfMap.getValue().get(word);
                    map.put(tf * idf, word);
                }
                return map;
            }
        });

        TfIdf.saveAsTextFile("C:\\Users\\YANGXIN\\Desktop\\result");
    }

參考文獻:
https://github.com/endymecy/spark-ml-source-analysis/blob/master/%E7%89%B9%E5%BE%81%E6%8A%BD%E5%8F%96%E5%92%8C%E8%BD%AC%E6%8D%A2/TF-IDF.md

最后編輯于
?著作權歸作者所有,轉載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務。

相關閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容