GraphX Label Propagation算法改進(jìn)

label propagation算法介紹

標(biāo)簽傳播算法(label propagation)的核心思想非常簡單:相似的數(shù)據(jù)應(yīng)該具有相同的label。LP算法包括兩大步驟:1)構(gòu)造關(guān)系網(wǎng);2)標(biāo)簽傳播。

  • 算法具體步驟如下:
    1、初始時,給每個節(jié)點一個唯一的標(biāo)簽;
    2、每個節(jié)點使用其鄰居節(jié)點的標(biāo)簽中最多的標(biāo)簽來更新自身的標(biāo)簽。
    3、反復(fù)執(zhí)行步驟2,直到每個節(jié)點的標(biāo)簽都不再發(fā)生變化為止。

  • 一次迭代過程中一個節(jié)點標(biāo)簽的更新可以分為同步和異步兩種。所謂同步更新,即節(jié)點z在第t次迭代的label依據(jù)于它的鄰居節(jié)點在第t-1次迭代時所得的label;異步更新,即節(jié)點z在第t次迭代的label依據(jù)于第t次迭代已經(jīng)更新過label的節(jié)點和第t次迭代未更新過label的節(jié)點在第t-1次迭代時的label。

graphX自帶LP算法的缺陷

1、邊權(quán)重信息不參與計算過程;
2、標(biāo)簽傳播結(jié)果存在震蕩的問題(震蕩問題是所有基于BSP模式的框架普遍存在的問題)

關(guān)于graphx及BSP可見我另一篇文章 http://www.itdecent.cn/p/7190123ad329

邊權(quán)重與無向圖支持的改造

  • 基于pregel接口,重新實現(xiàn)了一套傳播sendMessage和mergeMessage方法
def sendMessage(e: EdgeTriplet[VertexId, Int]): Iterator[(VertexId, Map[VertexId, Long])] = {
    Iterator((e.srcId, Map(e.dstAttr -> e.attr)), (e.dstId, Map(e.srcAttr -> e.attr)))
  }
def mergeMessage(count1: Map[VertexId, Long], count2: Map[VertexId, Long])
  : Map[VertexId, Long] = {
    (count1.keySet ++ count2.keySet).map { i =>
      val count1Val = count1.getOrElse(i, 0L)
      val count2Val = count2.getOrElse(i, 0L)
      i -> (count1Val + count2Val)
    }(collection.breakOut)
  }

標(biāo)簽傳播震蕩問題改造

1、初始化每個節(jié)點屬性信息,先給每個節(jié)點分配不重復(fù)標(biāo)簽。如,節(jié)點1對應(yīng)標(biāo)簽1,節(jié)點i對應(yīng)標(biāo)簽i;

2、N個節(jié)點,同步找到對應(yīng)節(jié)點鄰居,獲取此節(jié)點鄰居標(biāo)簽,找到出現(xiàn)權(quán)重最高的標(biāo)簽,若權(quán)重最高的標(biāo)簽不止一個,則選擇標(biāo)簽值較大的標(biāo)簽賦值給當(dāng)前節(jié)點;

3、若本輪標(biāo)簽重標(biāo)記后,節(jié)點標(biāo)簽不再變化(或者達(dá)到設(shè)定的最大迭代次數(shù)),則迭代停止,否則重復(fù)第2步。迭代結(jié)果即為RS0;

4、當(dāng)?shù)?步結(jié)束后,以其結(jié)果RS0作為節(jié)點初始化信息,重新初始化每個節(jié)點屬性信息,并從第2步開始,再分別迭代1輪、2輪、3輪,結(jié)果分別存為 RS1、RS2和RS3;

5、綜合RS0、RS1、RS2和RS3的結(jié)果,得到最終每個節(jié)點的標(biāo)簽結(jié)果。如,節(jié)點i在RS0、RS1、RS2和RS3中的標(biāo)簽信息分別為(a、b、c、d),選擇其中計數(shù)最多的標(biāo)簽作為節(jié)點i的最終結(jié)果,若計數(shù)最多的標(biāo)簽不止一個,則選擇標(biāo)簽值最大的標(biāo)簽作為節(jié)點i最終的標(biāo)簽。

6、至此,label propagation算法結(jié)束,每個節(jié)點獲得的標(biāo)簽即為其最終歸屬的cluster的id,聚類結(jié)束。

效果對比(demo數(shù)據(jù))

graphx自帶label propagation
  • demo數(shù)據(jù)展示(邊權(quán)重表示點之間的親密度)


    image.gif
  • 期望的聚類結(jié)果
image.gif
  • graphx自帶LPA聚類結(jié)果(共分成4個cluster,不同顏色標(biāo)注)
image.gif
  • 改進(jìn)算法的聚類結(jié)果


    image.gif

效果對比(通過wifi連接獲取的關(guān)系數(shù)據(jù))

1、外賣標(biāo)簽,數(shù)據(jù)集中該標(biāo)簽占比0.3965。數(shù)據(jù)集共23137人。訓(xùn)練集16195人,其中帶標(biāo)簽6451人;測試集6942人。其中帶標(biāo)簽2723人。

a、graphx自帶lp:召回率0.0823,精確率0.5450
b、pregel實現(xiàn)改進(jìn)版lp:召回率0.2281,精確率0.4909

屏幕快照 2019-08-22 下午8.21.31.png

2、學(xué)前教育,數(shù)據(jù)集中該標(biāo)簽占比0.0281。數(shù)據(jù)集共23137人。訓(xùn)練集16195人,其中帶標(biāo)簽462人;測試集6942人。其中帶標(biāo)簽188人。

a、graphx自帶lp:召回率0.0,精確率0.0
b、pregel實現(xiàn)改進(jìn)版lp:召回率0.0426,精確率0.0952

屏幕快照 2019-08-22 下午8.21.58.png

3、炒股,數(shù)據(jù)集中該標(biāo)簽占比0.2192。數(shù)據(jù)集共23137人。訓(xùn)練集16195人,其中帶標(biāo)簽3499人;測試集6942人。其中帶標(biāo)簽1572人。

a、graphx自帶lp:召回率0.0204,精確率0.3721
b、pregel實現(xiàn)改進(jìn)版lp:召回率0.1501,精確率0.3940

屏幕快照 2019-08-22 下午8.22.24.png

4、游戲付費意愿用戶,數(shù)據(jù)集中該標(biāo)簽占比0.1312。數(shù)據(jù)集共23137人。訓(xùn)練集16195人,其中帶標(biāo)簽2137人;測試集6942人。其中帶標(biāo)簽898人。

a、graphx自帶lp:召回率0.0267,精確率0.2857
b、pregel實現(xiàn)改進(jìn)版lp:召回率0.1292,精確率0.2736

屏幕快照 2019-08-22 下午8.22.47.png

5、35歲+標(biāo)簽,數(shù)據(jù)集中該標(biāo)簽占比0.3227。數(shù)據(jù)集共23137人。訓(xùn)練集16195人,其中帶標(biāo)簽5204人;測試集6942人。其中帶標(biāo)簽2262人。

a、graphx自帶lp:召回率0.0469,精確率0.4953
b、pregel實現(xiàn)改進(jìn)版lp:召回率0.2604,精確率0.5285


屏幕快照 2019-08-22 下午8.23.10.png

完整代碼如下(scala)

package Graph.LPA

import org.apache.spark.graphx._
import org.apache.spark._
import org.apache.spark.sql.Row
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.LongType


object LPARevolution {
  def sendMessage(e: EdgeTriplet[VertexId, Int]): Iterator[(VertexId, Map[VertexId, Long])] = {
    Iterator((e.srcId, Map(e.dstAttr -> e.attr)), (e.dstId, Map(e.srcAttr -> e.attr)))
  }

  def mergeMessage(count1: Map[VertexId, Long], count2: Map[VertexId, Long])
  : Map[VertexId, Long] = {
    (count1.keySet ++ count2.keySet).map { i =>
      val count1Val = count1.getOrElse(i, 0L)
      val count2Val = count2.getOrElse(i, 0L)
      i -> (count1Val + count2Val)
    }(collection.breakOut)
  }

  // 更新點屬性
  def vertexProgram(vid: VertexId, attr: Long, message: Map[VertexId, Long]): VertexId = {
    if(message.isEmpty){
      attr
    }
    else{
      //            print(vid)
      //            println(" 接收到的消息:   ")
      //            println(message)
      //            println("最終選擇的是:")
      //            println(message.maxBy(_._2)._1)
      message.maxBy(_._2)._1  // 按照計數(shù)排序,然后取第一個
    }

  }


  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
    val sc = new SparkContext("yarn","lpa-revolution",conf)
    val hql = new HiveContext(sc)

    // 獲取邊數(shù)據(jù)gid,usertags,wifimac,ssid,geohash,day-int
    val edges = hql.sql("select cast(src as bigint), cast(dst as bigint)," +
      " cast(weight as int) from yangy.graph_edge_table_3day_zoom_weight_hz").rdd.
      map(row => Edge(row(0).asInstanceOf[Long], row(1).asInstanceOf[Long], row(2).asInstanceOf[Int]))

    // 獲取點數(shù)據(jù)id_2_label_table
    val users = hql.sql("select cast(id as bigint), user_tags from yangy.id_2_label_table_3day_zoom_weight_hz").
      rdd.map(row => (row(0).asInstanceOf[Long], row(1)))

//    val edges = sc.textFile("file:///home/yangy/data/xh_edge_20190530_8day_1_0.txt").
//      map{line =>
//        val fields = line.split(" ")
//        (Edge(fields(0).toLong, fields(1).toLong, fields(2).toInt))
//      }
//
//    val users = sc.textFile("file:///home/yangy/data/xh_vertex_with_label_1_0.txt").
//      map { line =>
//        val fields = line.split(" ")
//        (fields(0).toLong, fields(1).toLong)
//      }

    val graph = Graph(vertices = users, edges = edges)

    // 圖初始化
    val initGraph = graph.mapVertices { case (vid, attr) => vid }

    // 初始化msg
    val initialMessage = Map[VertexId, Long]()

    println("迭代結(jié)果:")

    // 分水嶺,開始解決社區(qū)震蕩&孤立點問題
    // ----------------------------------  迭代多輪  -------------------------------------
    val cluster1 = Pregel(initGraph, initialMessage, maxIterations = 100, activeDirection = EdgeDirection.Either)(
      vprog = vertexProgram,
      sendMsg = sendMessage,
      mergeMsg = mergeMessage)


    // =====================================================================
    // 優(yōu)雅代碼的核心部分,基于前面的結(jié)果初始化新的圖
    // 利用前面迭代結(jié)果重新初始化圖
    // 以此結(jié)果作為基礎(chǔ),后續(xù)在此基礎(chǔ)上繼續(xù)迭代
    val users_trans = cluster1.vertices
    val graph_trans = Graph(vertices = users_trans, edges = edges)
    val initGraph_trans = graph_trans.mapVertices { case (vid, attr) => attr}
    // ======================================================================


    // 在基礎(chǔ)數(shù)據(jù)上,額外迭代的輪數(shù)
    val cluster2 = Pregel(initGraph_trans, initialMessage, maxIterations = 1, activeDirection = EdgeDirection.Either)(
      vprog = vertexProgram,
      sendMsg = sendMessage,
      mergeMsg = mergeMessage)

    val cluster3 = Pregel(initGraph_trans, initialMessage, maxIterations = 2, activeDirection = EdgeDirection.Either)(
      vprog = vertexProgram,
      sendMsg = sendMessage,
      mergeMsg = mergeMessage)

    val cluster4 = Pregel(initGraph_trans, initialMessage, maxIterations = 3, activeDirection = EdgeDirection.Either)(
      vprog = vertexProgram,
      sendMsg = sendMessage,
      mergeMsg = mergeMessage)


    // 構(gòu)建label propagation結(jié)果dataframe
    val colNames = "id,group_id"
    val schema = StructType(colNames.split(",").map(column => StructField(column, LongType)))

    // 獲取每個id的分組信息,字段名是id, group_id
    val groupDf1 = hql.createDataFrame(cluster1.vertices.map(x=> Row(x._1, x._2)), schema)
    val groupDf2 = hql.createDataFrame(cluster2.vertices.map(x=> Row(x._1, x._2)), schema)
    val groupDf3 = hql.createDataFrame(cluster3.vertices.map(x=> Row(x._1, x._2)), schema)
    val groupDf4 = hql.createDataFrame(cluster4.vertices.map(x=> Row(x._1, x._2)), schema)


    val group_union_df = groupDf1.unionAll(groupDf2).unionAll(groupDf3).unionAll(groupDf4)

    // 選取合適的group_id,避免社區(qū)震蕩
    group_union_df.registerTempTable("group_union_table")

    // 獲取不震蕩的group歸屬信息
    val group_no_swing = hql.sql(
      """
        |select t2.id as id,
        |       t2.group_id as group_id
        |from
        |(
        |   select t1.id as id,
        |          t1.group_id as group_id,
        |          rank() over (partition by t1.id order by t1.cnt, t1.group_id desc) as rank
        |   from
        |   (
        |       select id,
        |              group_id,
        |              count(1) as cnt
        |       from group_union_table
        |       group by id,
        |                group_id
        |   ) t1
        |) t2
        |where t2.rank = 1
        |
      """.stripMargin)

    group_no_swing.write.mode("overwrite").
      saveAsTable("yangy.graphx_cluster_zoom_no_swing_hz_100_table")

    group_no_swing.show(20)

    sc.stop()
  }

}

** 原創(chuàng)內(nèi)容,若要轉(zhuǎn)載請聯(lián)系本人 **

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容