> 昨天有位大哥問小弟一個(gè)Spark問題,他們想在不停Spark程序的情況下動(dòng)態(tài)更新UDF的邏輯,他一問我這個(gè)問題的時(shí)候,本豬心里一驚,Spark**還能這么玩?我出于程序員的本能回復(fù)他肯定不行,但今天再回過來頭想了一想,昨天腦子肯定進(jìn)水了,回復(fù)太膚淺了,既然Spark可以通過編程方式注冊(cè)UDF,當(dāng)然把那位大哥的代碼邏輯使用反射加載進(jìn)去再調(diào)用不就行了?這不就是JVM的優(yōu)勢(shì)么,怪自己的反射沒學(xué)到家,說搞就搞起。
## 分析過程
我會(huì)說這波分析過程很無聊,你還會(huì)看么?

跟著本豬看一個(gè)`Spark`注冊(cè)`UDF`的例子
```
spark.udf.register(name, (a1: String) => a1.toUpperCase)
```
點(diǎn)擊`register`的源碼進(jìn)去看
```
一個(gè)`A1`:參數(shù)類型,`RT`:返回類型
def register[RT: TypeTag, A1: TypeTag](name: String, func: Function1[A1, RT]): UserDefinedFunction = {
? ? val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
? ? val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: Nil).toOption
? ? def builder(e: Seq[Expression]) = if (e.length == 1) {
? ? ? ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true)
? ? } else {
? ? ? ...
? ? }
? ? ...
? }
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](name: String, func: Function2[A1, A2, RT]): UserDefinedFunction
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](name: String, func: Function3[A1, A2, A3, RT]): UserDefinedFunction = {
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](name: String, func: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = {
...
```
**1.** func是上面方法的重點(diǎn),既然想要?jiǎng)討B(tài)`UDF`邏輯代碼,那我們把`Function1`這個(gè)函數(shù)實(shí)現(xiàn)不就可以了?再利用JVM反射的技術(shù)調(diào)用,完美。
**2.** 順便還看出了在scala-2.10.x版本中`case class`的元素是不能超過 $\color{#FF0000}{22}$ 個(gè)的。
---
上面的`UDF`注冊(cè)的原型其實(shí)是
```
val udf = new Function1[String,String] {
? ? ? override def apply(a1: String): String = {
? ? ? ? a1.toUpperCase
? ? ? }
}
spark.udf.register(name, udf)
```
到這里我有一個(gè) **膚淺** 并且 **大膽** 的想法,我把那位大哥的代碼放到apply方法里面調(diào)用不就行了?

```
val udf = new Function1[String,String] {
? ? ? override def apply(a1: String): String = {
? ? ? ? //method.invoke(instance) //使用反射加載代碼,把大哥動(dòng)態(tài)邏輯方法method拿出來調(diào)用。
? ? ? }
}
```
**1.** 但是還有一些問題要解決,我不能強(qiáng)制我的老大哥只能傳遞一個(gè)參數(shù)吧,那也太年輕不懂事了,至少讓他可以隨意傳 $\color{#FF0000}{22}$ 參數(shù)。
**2.** 唯一的解決方法,就是要控制`Function1`到`Function22`函數(shù)的動(dòng)態(tài)生成,找了半天沒發(fā)現(xiàn)`Function`的動(dòng)態(tài)生成,然后還發(fā)現(xiàn)Spark也是根據(jù)參數(shù)長度生成`FunctionN`的,真**刷新本豬的三觀呀。
**3.** 既然實(shí)現(xiàn)方式找到了,那就簡單了,只要通過反射就能 **上知天文,下知地理** 。
---
既然是`Spark`,肯定要用`Scala`去寫反射了。
```
case class ClassInfo(clazz: Class[_], instance: Any, defaultMethod: Method, methods: Map[String, Method], func:String) {
? def invoke[T](args: Object*): T = {
? ? defaultMethod.invoke(instance, args: _*).asInstanceOf[T]
? }
}
object ClassCreateUtils extends Logging{
? private val clazzs = new util.HashMap[String, ClassInfo]()
? private val classLoader = scala.reflect.runtime.universe.getClass.getClassLoader
? private val toolBox = universe.runtimeMirror(classLoader).mkToolBox()
? def apply(func: String): ClassInfo = this.synchronized {
? ? var clazz = clazzs.get(func)
? ? if (clazz == null) {
? ? ? val (className, classBody) = wrapClass(func)
? ? ? val zz = compile(prepareScala(className, classBody))
? ? ? val defaultMethod = zz.getDeclaredMethods.head
? ? ? val methods = zz.getDeclaredMethods
? ? ? clazz = ClassInfo(
? ? ? ? zz,
? ? ? ? zz.newInstance(),
? ? ? ? defaultMethod,
? ? ? ? methods = methods.map { m => (m.getName, m) }.toMap,
? ? ? ? func
? ? ? )
? ? ? clazzs.put(func, clazz)
? ? ? logInfo(s"dynamic load class => $clazz")
? ? }
? ? clazz
? }
? def compile(src: String): Class[_] = {
? ? val tree = toolBox.parse(src)
? ? toolBox.compile(tree).apply().asInstanceOf[Class[_]]
? }
? def prepareScala(className: String, classBody: String): String = {
? ? classBody + "\n" + s"scala.reflect.classTag[$className].runtimeClass"
? }
? def wrapClass(function: String): (String, String) = {
? ? val className = s"dynamic_class_${UUID.randomUUID().toString.replaceAll("-", "")}"
? ? val classBody =
? ? ? s"""
? ? ? ? |class $className{
? ? ? ? |? $function
? ? ? ? |}
? ? ? ? ? ? """.stripMargin
? ? (className, classBody)
? }
}
```
上面的代碼是小弟給大佬寫好的,不用大佬親自動(dòng)手了。

使用方法就灰常簡單了我的大佬們。
```
val infos = ClassCreateUtils(
? ? ? """
? ? ? ? |def apply(name:String)=name.toUpperCase
? ? ? """.stripMargin
)
println(infos.defaultMethod.invoke(infos.instance,"dounine 本豬會(huì)一點(diǎn)點(diǎn) spark"))
# 輸出結(jié)果不用猜也知道是
DOUNINE 本豬會(huì)一點(diǎn)點(diǎn) SPARK
# 也可以手動(dòng)指定方法
println(infos.methods("apply").invoke(infos.instance,"dounine 本豬會(huì)一點(diǎn)點(diǎn) spark"))
```
根據(jù)反射的方法信息生成`FunctionN`
```
object ScalaGenerateFuns {
? def apply(func: String): (AnyRef, Array[DataType], DataType) = {
? ? val (argumentTypes, returnType) = getFunctionReturnType(func)
? ? (generateFunction(func, argumentTypes.length), argumentTypes, returnType)
? }
? //獲取方法的參數(shù)類型及返回類型
? private def getFunctionReturnType(func: String): (Array[DataType], DataType) = {
? ? val classInfo = ClassCreateUtils(func)
? ? val method = classInfo.defaultMethod
? ? val dataType = JavaTypeInference.inferDataType(method.getReturnType)._1
? ? (method.getParameterTypes.map(JavaTypeInference.inferDataType).map(_._1), dataType)
? }
? //生成22個(gè)Function
? def generateFunction(func: String, argumentsNum: Int): AnyRef = {
? ? lazy val instance = ClassCreateUtils(func).instance
? ? lazy val method = ClassCreateUtils(func).methods("apply")
? ? argumentsNum match {
? ? ? case 0 => new (() => Any) with Serializable with Logging {
? ? ? ? override def apply(): Any = {
? ? ? ? ? try {
? ? ? ? ? ? method.invoke(instance)
? ? ? ? ? } catch {
? ? ? ? ? ? case e: Exception =>
? ? ? ? ? ? ? logError(e.getMessage)
? ? ? ? ? }
? ? ? ? }
? ? ? }
? ? ? case 1 => new (Object => Any) with Serializable with Logging {
? ? ? ? override def apply(v1: Object): Any = {
? ? ? ? ? try {
? ? ? ? ? ? method.invoke(instance, v1)
? ? ? ? ? } catch {
? ? ? ? ? ? case e: Exception =>
? ? ? ? ? ? ? e.printStackTrace()
? ? ? ? ? ? ? logError(e.getMessage)
? ? ? ? ? ? ? null
? ? ? ? ? }
? ? ? ? }
? ? ? }
? ? ? case 2 => new ((Object, Object) => Any) with Serializable with Logging {
? ? ? ? override def apply(v1: Object, v2: Object): Any = {
? ? ? ? ? try {
? ? ? ? ? ? method.invoke(instance, v1, v2)
? ? ? ? ? } catch {
? ? ? ? ? ? case e: Exception =>
? ? ? ? ? ? ? logError(e.getMessage)
? ? ? ? ? ? ? null
? ? ? ? ? }
? ? ? ? }
? ? ? }
? ? ? //... 麻煩大佬自己去寫剩下的20個(gè)了,這里裝不下了,不然瀏覽器會(huì)崩潰的,然后電腦會(huì)重啟的,為了大佬的電腦著想。
}
```
## 前戲我們都做完了,高潮的環(huán)節(jié)來了。

我們最后再照著`register`的實(shí)現(xiàn)方式,把我們動(dòng)態(tài)`Function`注冊(cè)給`Spark`
```
1. val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
2. val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: Nil).toOption
3. def builder(e: Seq[Expression]) = if (e.length == 1) {
? ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil),
? ? Some(name), nullable, udfDeterministic = true)
}
4. functionRegistry.createOrReplaceTempFunction(name, builder)
```
**1.** 這句代碼比較好理解,就是獲取`RT`返回值類型,就是我們的`returnType`
**2.** 就是參數(shù)類型,對(duì)應(yīng)的修改如下
```
val inputTypes = Try(argumentTypes.toList).toOption
```
**3.**? 剛開始看到這個(gè)時(shí)候,我是一臉???,后來看源碼才發(fā)現(xiàn)`builder`是一種自定類型,源碼如下
```
type FunctionBuilder = Seq[Expression] => Expression
```
改造方式如下
```
def builder(e: Seq[Expression]) = ScalaUDF(rf, returnType, e, inputTypes.getOrElse(Nil), Some(name))
```
**4.** 看到這句的時(shí)候我以為簡單了,直接使用`spark.sessionState.functionRegistry`發(fā)現(xiàn)編譯不過,看到`private[sql]`這個(gè)作用域的時(shí)候有點(diǎn)崩潰,本來是想用下面的方式注冊(cè)的。
```
val udf = UserDefinedFunction(rf, returnType, inputTypes).withName(name)
spark.udf.register(name, udf)
```
是小弟我想太多了,另辟捷徑,做了那么多工作難道就白費(fèi)了?

## 發(fā)現(xiàn)下面這句代碼,瞬間找到了家的方向。
```
functionRegistry.registerFunction(new FunctionIdentifier(name), builder)
```
---
## 人生巔峰
到此,大豬的分析與編碼已經(jīng)完成,下面是今天給大哥的解決方案。
方法實(shí)現(xiàn)可以通過查詢sql得到,或者接口都渴以。
```
? ? val spark = SparkSession
? ? ? .builder()
? ? ? .appName("test")
? ? ? .master("local[*]")
? ? ? .getOrCreate()
? ? val name = "hello"
? ? val (fun, argumentTypes, returnType) = ScalaSourceUDF(
? ? ? """
? ? ? ? |def apply(name:String)=name+" => hi"
? ? ? ? |""".stripMargin)
? ? val inputTypes = Try(argumentTypes.toList).toOption
? ? def builder(e: Seq[Expression]) = ScalaUDF(fun, returnType, e, inputTypes.getOrElse(Nil), Some(name))
? ? spark.sessionState.functionRegistry.registerFunction(new FunctionIdentifier(name), builder)
? ? val rdd = spark
? ? ? .sparkContext
? ? ? .parallelize(Array(("dounine", "20")))
? ? ? .map(x => Row.fromSeq(Array(x._1, x._2)))
? ? val types = StructType(
? ? ? Array(
? ? ? ? StructField("name", StringType),
? ? ? ? StructField("age", StringType)
? ? ? )
? ? )
? ? spark.createDataFrame(rdd, types).createTempView("log")
? ? spark.sql("select hello(name) from log").show(false)
```
真打臉,昨天還說不行的。

---
