前言
需求:業(yè)務(wù)需求要求求出score的最大值(max),最小值(min),均值(mean),標(biāo)準(zhǔn)差(stddev),中位數(shù)。需求的前四個(gè)值Spark自帶函數(shù)可以解決,唯獨(dú)中位數(shù)沒有,所以需要自定義一個(gè)聚合函數(shù)。
實(shí)現(xiàn)方法以及代碼
自定義函數(shù)需要繼承UserDefinedAggregateFunction
class MiddleValueUDAF extends UserDefinedAggregateFunction{
// 輸入?yún)?shù)的數(shù)據(jù)類型
override def inputSchema: StructType = {
DataTypes.createStructType(util.Arrays
.asList((DataTypes.createStructField("score",DataTypes.StringType,true))))
}
/**
*
* 更新 可以認(rèn)為一個(gè)一個(gè)地將組內(nèi)的字段值傳遞進(jìn)來 實(shí)現(xiàn)拼接的邏輯
* buffer.getInt(0)獲取的是上一次聚合后的值
* 相當(dāng)于map端的combiner,combiner就是對每一個(gè)map task的處理結(jié)果進(jìn)行一次小聚合
* 大聚和發(fā)生在reduce端.
* 這里即是:在進(jìn)行聚合的時(shí)候,每當(dāng)有新的值進(jìn)來,對分組后的聚合如何進(jìn)行計(jì)算
*/
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
buffer.update(0,Integer.valueOf(buffer.get(0).toString)+Integer.valueOf(input.get(0).toString))
buffer.update(0,buffer.get(0)+","+input.get(0).toString)
}
// buffer中的數(shù)據(jù)類型
override def bufferSchema: StructType = {
DataTypes.createStructType(util.Arrays
.asList((DataTypes.createStructField("summ",DataTypes.StringType,true))))
}
/**
* 合并其他部分結(jié)果
* 合并 update操作,可能是針對一個(gè)分組內(nèi)的部分?jǐn)?shù)據(jù),在某個(gè)節(jié)點(diǎn)上發(fā)生的 但是可能一個(gè)分組內(nèi)的數(shù)據(jù),會(huì)分布在多個(gè)節(jié)點(diǎn)上處理
* 此時(shí)就要用merge操作,將各個(gè)節(jié)點(diǎn)上分布式拼接好的串,合并起來
* buffer1.getInt(0) : 大聚合的時(shí)候 上一次聚合后的值
* buffer2.getInt(0) : 這次計(jì)算傳入進(jìn)來的update的結(jié)果
* 這里即是:最后在分布式節(jié)點(diǎn)完成后需要進(jìn)行全局級別的Merge操作
* 也可以是一個(gè)節(jié)點(diǎn)里面的多個(gè)executor合并
*/
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1.update(0,Integer.valueOf(buffer1.get(0).toString)+Integer.valueOf(buffer2.get(0).toString))
buffer1.update(0,buffer1.get(0)+","+buffer2.get(0).toString)
}
//初始化
override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer.update(0,"")
}
// 確保一致性 一般用true,用以標(biāo)記針對給定的一組輸入,UDAF是否總是生成相同的結(jié)果
override def deterministic: Boolean = {
true
}
//計(jì)算邏輯
override def evaluate(buffer: Row): Any = {
val intArray = buffer.get(0).toString.replaceAll(",,",",").substring(1)
val list = intArray.split(",").map(_.toDouble).toList.sorted
val len = list.size
var mid = 0d
if (len % 2 == 0)
mid = (list(len / 2 - 1) + list(len / 2)) / 2
else
mid = list(len / 2)
mid
}
// 返回值的類型
override def dataType: DataType = {
DataTypes.DoubleType
}