基于 SparkGraphx 實現(xiàn)適用于位置信息的DBScan聚類

基于 SparkGraphx 實現(xiàn)的 DBScan聚類

關(guān)于DBScan算法的詳細介紹請參見維基百科

https://en.wikipedia.org/wiki/DBSCAN

Graphx 實現(xiàn)Dbscan 圖解


圖解

1.上圖中藍色的點代表我們需要聚類的樣本點,假設(shè)我們將DBScan的兩個參數(shù):距離 (Eps)設(shè)為1,最小集群點數(shù)(minPts)設(shè)為 4,則根據(jù)聚類規(guī)則,上圖的A、B部分則會分別被聚為一類,C、D部分則會被視為離群點。

2.而Graphx的作用就是將兩個距離滿足條件的點連成邊,然后再將這些邊連成一個個的連通圖,最后再計算各個圖內(nèi)的點數(shù)是否滿足設(shè)定的最小集群點數(shù)。根據(jù)聚類規(guī)則我們就可以完成聚類,抽象出來就如上圖所示。

3.代碼實現(xiàn)過程如下
本文所使用的是經(jīng)緯度數(shù)據(jù),因此在使用距離計算的時候,用的是經(jīng)緯度距離的計算方法(球面距離),在實現(xiàn)過程中也使用了Geohash算法(相關(guān)介紹有很多,這篇帖子就很好)進行了相關(guān)優(yōu)化。

  /**
      * 參數(shù)校驗
      */
    if (args.length != 4) {
      println(
        """
          |參數(shù):
          |dbinput   輸入路徑
          |eps       鄰域半徑
          |minpts    最小密集點數(shù)
          |dboutput  輸出路徑
        """.stripMargin)
      System.exit(3)
    }
    val Array(dbinput, eps, minpts, output) = args

    val spark = SparkSession.builder()
      .appName(s"${this.getClass.getSimpleName}")
      .master("local[*]")
      .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .config("spark.shuffle.consolidateFiles", "true")
      .config("spark.io.compression.codec", "snappy")
      .getOrCreate()

    import spark.implicits._

  //  加載數(shù)據(jù)
  val dbdata = spark.read.option("inferSchema", true).csv(dbinput)


  // 計算經(jīng)緯度距離
    def distanceBetweenPoints(lon1: Double, lon2: Double, lat1: Double, lat2: Double): Double = {
      require(lon1 >= -180 && lon1 <= 180)
      require(lon2 >= -180 && lon2 <= 180)
      require(lat1 >= -90 && lat1 <= 90)
      require(lat2 >= -90 && lat2 <= 90)
      val R = 6371009d // average radius of the earth in metres
      val dLat = toRadians(lat2 - lat1)
      val dLng = toRadians(lon2 - lon1)
      val latA = toRadians(lat1)
      val latB = toRadians(lat2)
      // The actual haversine formula. a and c are well known value names in the formula.
      val a = sin(dLat / 2) * sin(dLat / 2) +
        sin(dLng / 2) * sin(dLng / 2) * cos(latA) * cos(latB)
      val c = 2 * atan2(sqrt(a), sqrt(1 - a))
      // 默認返回千米
      (R * c) / 1000D
    }

    // 經(jīng)緯度距離 sparksql udf 
    val lonLatDistance = udf((lon1: Double, lon2: Double, lat1: Double, lat2: Double) => {
      distanceBetweenPoints(lon1, lon2, lat1, lat2)
    })

此部分是結(jié)合GeoHash算法做的一點優(yōu)化,主要是根據(jù)dbscan的距離參數(shù)預(yù)先對數(shù)據(jù)進行分組,筆者水平有限,只想到了這個數(shù)據(jù)分區(qū)的方法。


GeoHash Code 精度對照
   // 根據(jù)geohash算法對經(jīng)緯度數(shù)據(jù)做分區(qū)
    val scope = udf((lon: Double, lat: Double) => {
      // geohash
      val geohash = GeoHash.encodeHash(lat, lon,
        // 計算geohash的最優(yōu)分區(qū)位數(shù)
        MLUtils.geoLength(eps.toDouble))
      val neighbours: Array[String] = GeoHash.neighbours(geohash).toArray().map(_.toString)
      Seq(geohash) ++ neighbours
    })
//  將原始的經(jīng)緯度數(shù)據(jù)按照相同的分組進行 join聚合
  val localbase = dbdata
      .toDF("lon", "lat")
      .where($"lon".isNotNull and $"lat".isNotNull)
      .withColumn("id", hash($"lon", $"lat"))

    val ll = localbase
      .withColumn("scopes", scope($"lon", $"lat"))
      .withColumn("scope", explode($"scopes"))
      .drop("scopes").cache()

    val ll2 = ll.toDF("lon2", "lat2", "id2", "scope")

    val data = ll.join(ll2, "scope").where($"id" =!= $"id2")
      .withColumn("distance", lonLatDistance($"lon", $"lon2", $"lat", $"lat2"))
 

    //構(gòu)建邊Edge[Int]
    val lv: RDD[(VertexId, VertexId)] = data
      .filter($"distance" <= eps.toDouble)  // 篩選出滿足距離條件的點
      .select($"id", $"id2").rdd
      .map(row => {
        val id = row.getAs[Int]("id").toLong
        val id2 = row.getAs[Int]("id2").toLong
        (id, id2)
      })

    val le = lv.map { ids => Edge(ids._1, ids._2, 0) } // 根據(jù)點構(gòu)建邊

    // 構(gòu)建圖
    val graph = Graph(lv, le)
    val gcc = graph.connectedComponents().vertices
    val joined = gcc.join(lv)
      .map(tp => {
        (tp._2._1, Seq(tp._2._2))
      }).reduceByKey(_ ++ _)   // 聚合每個聯(lián)通圖的點
      .map(tp => {
        (tp._2.distinct, tp._2.distinct.length)
      }).filter(_._2 >= minpts.toInt)   // 篩選出滿足最小聚類點數(shù)的連通圖

    val clust = joined.toDF("clu", "ct")
      .withColumn("cluid", hash($"clu"))
      .withColumn("id", explode($"clu"))   

    val dbres: DataFrame = localbase.join(clust, Seq("id"), "left")
      .na.fill(0).drop("clu", "ct")  // 離群點的聚類id以0標識

    // 保存聚類結(jié)果
    dbres.repartition(1).write.option("header", true)
      .mode("overwrite")
      .csv(output) 

在本案例中,eps設(shè)為30km,minPts設(shè)為 5,聚類結(jié)果的可視化如下 ,紅圈的就是兩個簇類,其余的都是離群點


聚類結(jié)果可視化

本案例的數(shù)據(jù)鏈接 https://pan.baidu.com/s/1EaA7oGAmiJ2m4oXPLppsdg

用此方法實現(xiàn)的DBScan聚類在大數(shù)據(jù)集上運行效率較低,還有很多可以優(yōu)化的地方,也有很多可以擴展的地方,如有不當(dāng)之處,歡迎指正

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容