一、Distinct aggregation 算法
包含 distinct 關(guān)鍵字的 aggregation 由 4 個(gè)物理執(zhí)行步驟組成。我們使用以下 query 來(lái)介紹:
val dataset = Seq(
(1, "a"), (1, "a"), (1, "a"), (2, "b"), (2, "b"), (3, "c"), (3, "c")
).toDF("nr", "letter")
dataset.groupBy($"nr").agg(functions.countDistinct("letter")).explain(true)
① partial aggregation 步驟
第一步是創(chuàng)建一個(gè) partial aggregate,此 partial aggregate 的 grouping key 將不僅包括 query 中定義的 grouping key(nr),還包含 distinct 的列(letter),效果如 group by nr、letter,執(zhí)行計(jì)劃如下:
HashAggregate(keys=[nr#5, letter#6], functions=[], output=[nr#5, letter#6])
+- LocalTableScan [nr#5, letter#6]
② partial merge aggregation 步驟
這一步將通過(guò) shuffle 將具有相同 grouping key(此處為 nr、letter)的數(shù)據(jù)劃分為同一分區(qū):
+- HashAggregate(keys=[nr#5, letter#6], functions=[], output=[nr#5, letter#6])
+- Exchange hashpartitioning(nr#5, letter#6, 200)
+- HashAggregate(keys=[nr#5, letter#6], functions=[], output=[nr#5, letter#6])
+- LocalTableScan [nr#5, letter#6]
③ partial aggregation for distinct 步驟
第三步,Spark 最終開始執(zhí)行聚合,執(zhí)行的是 partial aggregate:
+- HashAggregate(keys=[nr#5], functions=[partial_count(distinct letter#6)], output=[nr#5, count#18L])
+- HashAggregate(keys=[nr#5, letter#6], functions=[], output=[nr#5, letter#6])
+- Exchange hashpartitioning(nr#5, letter#6, 200)
+- HashAggregate(keys=[nr#5, letter#6], functions=[], output=[nr#5, letter#6])
+- LocalTableScan [nr#5, letter#6]
④ final aggregation 步驟
第四步,partial aggregate(第三步)的結(jié)果將合并到最終結(jié)果中,并進(jìn)行返回。它涉及 shuffle:
HashAggregate(keys=[nr#5], functions=[count(distinct letter#6)], output=[nr#5, count(DISTINCT letter)#12L])
+- Exchange hashpartitioning(nr#5, 200)
+- HashAggregate(keys=[nr#5], functions=[partial_count(distinct letter#6)], output=[nr#5, count#18L])
+- HashAggregate(keys=[nr#5, letter#6], functions=[], output=[nr#5, letter#6])
+- Exchange hashpartitioning(nr#5, letter#6, 200)
+- HashAggregate(keys=[nr#5, letter#6], functions=[], output=[nr#5, letter#6])
+- LocalTableScan [nr#5, letter#6]
我們用下面的這張圖來(lái)總結(jié)上述幾個(gè)步驟:

二、無(wú) Distinct aggregation 算法
無(wú) Distinct aggregation 會(huì)簡(jiǎn)單一些,僅包含兩個(gè)步驟,我們通過(guò)下面的例子來(lái)說(shuō)明:
val dataset = Seq(
(1, "a"), (1, "a"), (1, "a"), (2, "b"), (2, "b"), (3, "c"), (3, "c")
).toDF("nr", "letter")
dataset.groupBy($"nr").count().explain(true)
①、partial aggregations 步驟
第一步即進(jìn)行局部聚合:
HashAggregate(keys=[nr#5], functions=[partial_count(1)], output=[nr#5, count#17L])
+- PlanLater LocalRelation [nr#5]
②、final aggregation 步驟
第二步,毫無(wú)疑問(wèn),對(duì)部分結(jié)果進(jìn)行了最終匯總:
HashAggregate(keys=[nr#5], functions=[count(1)], output=[nr#5, count#12L])
+- HashAggregate(keys=[nr#5], functions=[partial_count(1)], output=[nr#5, count#17L])
+- PlanLater LocalRelation [nr#5]
三、Hash-based 和 Sort-based aggregation
上述兩種模式都會(huì)調(diào)用到 createAggregate 方法,該方法為以下 3 種策略創(chuàng)建物理執(zhí)行計(jì)劃:
- hash-based
- object-hash-based
- sort-based
這 3 中策略有一些共性。一個(gè) Spark Sql aggregation 主要由兩部分組成:
- 一個(gè) agg buffer(聚合緩沖區(qū):包含 grouping keys 和 agg value)
- 一個(gè) agg state(聚合狀態(tài):僅 agg value)
每次調(diào)用 GROUP BY key 并對(duì)其使用一些聚合時(shí),框架都會(huì)創(chuàng)建一個(gè)聚合緩沖區(qū),保留給定的聚合(GROUP BY key)。指定 key(COUNT,SUM等)所涉及的聚合都在此聚合緩沖區(qū)存儲(chǔ)其部分(partial)或最終聚合結(jié)果,稱為聚合狀態(tài)。該狀態(tài)的存儲(chǔ)格式取決于聚合:
- 對(duì)于 AVG,它將是2個(gè)值,一個(gè)是出現(xiàn)次數(shù),另一個(gè)是值的總和
- 對(duì)于 MIN,它將是到目前為止所看到的最小值
依此類推
hash-based 策略使用可變的、原始的、固定 size 的類型來(lái)作為 agg state,包括:
- NullType
- BooleanType
- ByteType
- ShortType
- IntegerType
- LongType
- FloatType
- DoubleType
- DateType
- TimestampType
這里的可變能力非常重要,因?yàn)?Spark 會(huì)直接修改該值(如對(duì)于 count 來(lái)說(shuō),遇到新的 row,就會(huì)把 count 的值(agg state)加上 1)。
對(duì)于 agg state 的值是其他類型的情況,使用 object-hash-based 策略,該策略自 2.2.0 版本引入,目的是為了解決 hash-based 策略的局限性(必須使用可變的、原始的、固定 size 的類型來(lái)作為 agg state)。在 2.2.0 之前,針對(duì) HashAggregateExec 不支持的其他類型執(zhí)行的聚合都會(huì)轉(zhuǎn)換為 sort-based 的策略。大部分情況下,sort-based 的性能會(huì)比 hash-based 的差,因?yàn)樵诰酆锨皶?huì)進(jìn)行額外的排序。通過(guò)參數(shù) spark.sql.execution.useObjectHashAggregateExec 來(lái)控制是否使用 object-hash-based 聚合,默認(rèn)為 true。我們通過(guò)下面的例子來(lái)理解 sort-based 和 object-hash-based 的區(qū)別:
查詢
val dataset2 = Seq(
(1, "a"), (1, "aa"), (1, "a"), (2, "b"), (2, "b"), (3, "c"), (3, "c")
).toDF("nr", "letter")
dataset2.groupBy("nr").agg(functions.collect_list("letter").as("collected_letters")).explain(true)

如你所見,上圖兩個(gè)物理執(zhí)行計(jì)劃均只進(jìn)行一次 shuffle,但 sort-based 聚合相對(duì)于 object-hash-based 額外多了兩次排序,帶來(lái)性能開銷。
另一個(gè)值得關(guān)注的點(diǎn)是,hash-based 和 object-hash-based 運(yùn)行過(guò)程中如果內(nèi)存不夠用,會(huì)切換成 sort-based 聚合。對(duì)于 object-hash-based 聚合,通過(guò)參數(shù) spark.sql.objectHashAggregate.sortBased.fallbackThreshold 控內(nèi)存中(一種 hashMap)最多持有多少個(gè) agg buffer(一個(gè) grouping key 的組合一個(gè)),若超過(guò)該值,則切換為 sort-based agg,該配置默認(rèn)值為 128。如果切換為 sort-based agg,會(huì)打印如下日志:
ObjectAggregationIterator: Aggregation hash map reaches threshold capacity (128 entries), spilling and falling back to sort based aggregation. You may change the threshold by adjust option spark.sql.objectHashAggregate.sortBased.fallbackThreshold
對(duì)于 hash-based,該值為 Integer.MaxValue