When writing a custom aggregate function in Spark, you need to implement the 8 methods of UserDefinedAggregateFunction:
- inputSchema: the data type of the input
- bufferSchema: the data type used by the intermediate aggregation buffer
- dataType: the return type of the function
- deterministic: whether the function is deterministic, i.e. always returns the same result for the same input
- initialize: initializes the buffer for each group
- update: how a group's aggregate is updated when a new value arrives for that group
- merge: because Spark is distributed, the data for a single group may be partially aggregated (via update) on different nodes; the partial aggregates of that group must then be merged across nodes
- evaluate: how a group's final aggregate value is produced from the intermediate buffer
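The lifecycle of these methods can be sketched in plain Scala, with no Spark dependency. This is a hypothetical model of a row-counting UDAF (the object and method names below just mirror the list above; they are not Spark APIs):

```scala
// A Spark-free model of the UDAF lifecycle for a row counter:
// initialize -> update (per row, per partition) -> merge (across partitions) -> evaluate.
object UdafLifecycleSketch {
  def initialize: Int = 0                                   // empty buffer for a group
  def update(buffer: Int, input: String): Int = buffer + 1  // one more row seen
  def merge(b1: Int, b2: Int): Int = b1 + b2                // combine partial counts
  def evaluate(buffer: Int): Int = buffer                   // final result for the group

  // Each partition folds its rows into a local buffer,
  // then the partial buffers are merged and evaluated.
  def countRows(partitions: Seq[Seq[String]]): Int = {
    val partials = partitions.map(_.foldLeft(initialize)(update))
    evaluate(partials.reduce(merge))
  }

  def main(args: Array[String]): Unit = {
    // Two "partitions" of rows, 5 rows total
    println(countRows(Seq(Seq("tom", "mary"), Seq("test", "test", "yangql")))) // prints 5
  }
}
```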
Example code:
package com.spark.sql

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

/**
 * Created by Administrator on 2017/3/13.
 * User-defined aggregate function: counts the rows in each group.
 */
class StrCountUDAF extends UserDefinedAggregateFunction {

  // Data type of the input
  override def inputSchema: StructType = {
    StructType(Array(
      StructField("str", StringType, true)
    ))
  }

  // Data type of the intermediate aggregation buffer
  override def bufferSchema: StructType = {
    StructType(Array(
      StructField("count", IntegerType, true)
    ))
  }

  // Return type of the function
  override def dataType: DataType = {
    IntegerType
  }

  // Whether the function always produces the same result for the same input
  override def deterministic: Boolean = {
    true
  }

  // Initializes the buffer for each group
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0
  }

  // How a group's aggregate is updated when a new value arrives for that group
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getAs[Int](0) + 1
  }

  // Because Spark is distributed, the data for one group may be partially
  // aggregated (via update) on different nodes; those partial aggregates
  // must then be merged into a single value
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getAs[Int](0) + buffer2.getAs[Int](0)
  }

  // Produces a group's final aggregate value from the intermediate buffer
  override def evaluate(buffer: Row): Any = {
    buffer.getAs[Int](0)
  }
}
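One consequence of the distributed merge step: merge should be commutative and associative, because Spark makes no guarantee about the order in which partial buffers from different nodes are combined. A quick Spark-free check for the count buffer (MergeOrderCheck is a hypothetical helper, not part of the example above):

```scala
object MergeOrderCheck {
  // Same merge logic as StrCountUDAF, on plain Ints
  def merge(b1: Int, b2: Int): Int = b1 + b2

  def main(args: Array[String]): Unit = {
    val partials = List(3, 1, 2) // partial counts from three nodes
    // Reduce the partials in every possible order; a correct merge
    // yields a single distinct total regardless of order.
    val totals = partials.permutations.map(_.reduce(merge)).toSet
    println(totals) // prints Set(6)
  }
}
```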
- Using the aggregate function
package com.spark.sql

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.types._

object UDAF extends App {

  val conf = new SparkConf()
    .setMaster("local")
    .setAppName("DailyUVFunction")
  val sc = new SparkContext(conf)
  val sqlContext = new SQLContext(sc)

  // Import implicit conversions
  import sqlContext.implicits._

  // Build some sample user data
  val names = Array("tom", "yangql", "mary", "test", "test")
  val namesRDD = sc.parallelize(names)

  // Convert the RDD to a DataFrame
  val namesRowRDD = namesRDD.map(name => Row(name))
  val structType = StructType(Array(
    StructField("name", StringType, true)
  ))
  val namesDF = sqlContext.createDataFrame(namesRowRDD, structType)

  // Register a temporary view
  namesDF.createOrReplaceTempView("names")

  // Register the custom aggregate function
  sqlContext.udf.register("strCount", new StrCountUDAF)

  // Use the custom function
  val df = sqlContext.sql("select name, strCount(name) from names group by name")
  df.show()
}
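What the query computes can be modeled in plain Scala: group the same sample names and count the rows in each group, no Spark required (GroupCountSketch is a hypothetical name for this illustration):

```scala
object GroupCountSketch {
  def main(args: Array[String]): Unit = {
    val names = Array("tom", "yangql", "mary", "test", "test")
    // Equivalent of: select name, strCount(name) from names group by name
    val counts = names.groupBy(identity).map { case (n, rows) => (n, rows.length) }
    // Sort by name for a stable display order ("test" appears twice, the rest once)
    counts.toSeq.sortBy(_._1).foreach(println)
  }
}
```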