In practice one often needs Spark to save a DataFrame to MySQL with "update on conflict, insert otherwise" semantics. Spark's native save only implements insert, so it throws an exception as soon as a unique constraint is violated. I have previously used two workarounds. The first uses foreachPartition, opening a connection inside each partition and inserting the rows by hand. The second creates a staging table and a trigger in MySQL: Spark appends the DataFrame to the staging table with SaveMode.Append, and the trigger propagates the rows into the real table as updates. Neither is satisfying: the former pollutes a lot of application code, while the latter pushes all the work onto MySQL, and because MySQL lacks PostgreSQL's "return null" trigger mechanism the staging table has to be purged periodically, which can stall transactions.
Scala shares much of its syntactic style with Java and Python, but implicits are unique to Scala. Implicit classes in particular play a role similar to JavaScript's prototype: they enrich existing classes, adding methods without modifying or recompiling the original source.
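As a minimal, Spark-free sketch of the mechanism (the `StringEnhance` name and `toSnakeCase` method are made up for illustration):

```scala
object ImplicitClassDemo {
  // Adds a method to the built-in String class without editing or
  // recompiling String itself: wherever this implicit class is in
  // scope, every String value gains `toSnakeCase`.
  implicit class StringEnhance(val s: String) extends AnyVal {
    def toSnakeCase: String =
      s.replaceAll("([a-z0-9])([A-Z])", "$1_$2").toLowerCase
  }
}
```

This is exactly the mechanism relied on later: importing the implicit class into scope is enough to make a new method callable on an existing type.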
Reading the source shows that DataFrameWriter's save method updates the logical plan through DataSource's planForWriting, and DataSource then writes to the database via the createRelation method of the class resolved from className (for jdbc this is JdbcRelationProvider). So little needs to change: add a class modeled on JdbcRelationProvider — in fact, simply extend it and override createRelation — and point className at that class when DataSource updates the logical plan.
package org.apache.spark.enhance

import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils._
import org.apache.spark.sql.execution.datasources.jdbc.{JdbcOptionsInWrite, JdbcRelationProvider, JdbcUtils}
import org.apache.spark.sql.sources.BaseRelation

class MysqlUpdateRelationProvider extends JdbcRelationProvider {
  override def createRelation(sqlContext: SQLContext, mode: SaveMode,
      parameters: Map[String, String], df: DataFrame): BaseRelation = {
    val options = new JdbcOptionsInWrite(parameters)
    val isCaseSensitive = sqlContext.sparkSession.sessionState.conf.caseSensitiveAnalysis
    val conn = JdbcUtils.createConnectionFactory(options)()
    try {
      val tableExists = JdbcUtils.tableExists(conn, options)
      if (tableExists) {
        mode match {
          case SaveMode.Overwrite =>
            if (options.isTruncate && isCascadingTruncateTable(options.url) == Some(false)) {
              // In this case, we should truncate table and then load.
              truncateTable(conn, options)
              val tableSchema = JdbcUtils.getSchemaOption(conn, options)
              JdbcUtilsEnhance.updateTable(df, tableSchema, isCaseSensitive, options)
            } else {
              // Otherwise, do not truncate the table, instead drop and recreate it
              dropTable(conn, options.table, options)
              createTable(conn, df, options)
              JdbcUtilsEnhance.updateTable(df, Some(df.schema), isCaseSensitive, options)
            }
          case SaveMode.Append =>
            val tableSchema = JdbcUtils.getSchemaOption(conn, options)
            JdbcUtilsEnhance.updateTable(df, tableSchema, isCaseSensitive, options)
          case SaveMode.ErrorIfExists =>
            throw new Exception(
              s"Table or view '${options.table}' already exists. " +
                s"SaveMode: ErrorIfExists.")
          case SaveMode.Ignore =>
            // With `SaveMode.Ignore` mode, if table already exists, the save operation is expected
            // to not save the contents of the DataFrame and to not change the existing data.
            // Therefore, it is okay to do nothing here and then just return the relation below.
        }
      } else {
        createTable(conn, df, options)
        JdbcUtilsEnhance.updateTable(df, Some(df.schema), isCaseSensitive, options)
      }
    } finally {
      conn.close()
    }
    createRelation(sqlContext, parameters)
  }
}
The code above is almost entirely copied from the parent class; the only change is that each call to saveTable has been replaced with JdbcUtilsEnhance.updateTable.
package org.apache.spark.enhance

import java.sql.Connection

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils._
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcOptionsInWrite}
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
import org.apache.spark.sql.types.StructType

object JdbcUtilsEnhance {
  def updateTable(df: DataFrame,
      tableSchema: Option[StructType],
      isCaseSensitive: Boolean,
      options: JdbcOptionsInWrite): Unit = {
    val url = options.url
    val table = options.table
    val dialect = JdbcDialects.get(url)
    val rddSchema = df.schema
    val getConnection: () => Connection = createConnectionFactory(options)
    val batchSize = options.batchSize
    val isolationLevel = options.isolationLevel
    val updateStmt = getUpdateStatement(table, rddSchema, tableSchema, isCaseSensitive, dialect)
    val repartitionedDF = options.numPartitions match {
      case Some(n) if n <= 0 => throw new IllegalArgumentException(
        s"Invalid value `$n` for parameter `${JDBCOptions.JDBC_NUM_PARTITIONS}` in table writing " +
          "via JDBC. The minimum value is 1.")
      case Some(n) if n < df.rdd.getNumPartitions => df.coalesce(n)
      case _ => df
    }
    repartitionedDF.rdd.foreachPartition(iterator => savePartition(
      getConnection, table, iterator, rddSchema, updateStmt, batchSize, dialect, isolationLevel,
      options)
    )
  }

  def getUpdateStatement(table: String,
      rddSchema: StructType,
      tableSchema: Option[StructType],
      isCaseSensitive: Boolean,
      dialect: JdbcDialect): String = {
    val columns = if (tableSchema.isEmpty) {
      rddSchema.fields.map(x => dialect.quoteIdentifier(x.name)).mkString(",")
    } else {
      val columnNameEquality = if (isCaseSensitive) {
        org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
      } else {
        org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
      }
      // The generated insert statement needs to follow rddSchema's column sequence and
      // tableSchema's column names. When appending data into some case-sensitive DBMSs like
      // PostgreSQL/Oracle, we need to respect the existing case-sensitive column names instead of
      // RDD column names for user convenience.
      val tableColumnNames = tableSchema.get.fieldNames
      rddSchema.fields.map { col =>
        val normalizedName = tableColumnNames.find(f => columnNameEquality(f, col.name)).getOrElse {
          throw new Exception(s"""Column "${col.name}" not found in schema $tableSchema""")
        }
        dialect.quoteIdentifier(normalizedName)
      }.mkString(",")
    }
    val placeholders = rddSchema.fields.map(_ => "?").mkString(",")
    s"""INSERT INTO $table ($columns) VALUES ($placeholders)
       |ON DUPLICATE KEY UPDATE
       |${columns.split(",").map(col => s"$col=VALUES($col)").mkString(",")}
       |""".stripMargin
  }
}
The two added update-related methods are likewise almost entirely copied from JdbcUtils.
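To see what getUpdateStatement actually emits without standing up Spark, its string assembly can be reduced to a plain function. This is a simplified sketch: `upsertStatement` is a hypothetical helper, it takes bare column-name lists instead of Spark schema types, and the backtick quoting assumes the MySQL dialect:

```scala
object UpsertSqlDemo {
  // Simplified stand-in for getUpdateStatement: given a table name and its
  // column names, build MySQL's INSERT ... ON DUPLICATE KEY UPDATE statement.
  def upsertStatement(table: String, columns: Seq[String]): String = {
    val cols = columns.map(c => s"`$c`").mkString(",")
    val placeholders = columns.map(_ => "?").mkString(",")
    val updates = columns.map(c => s"`$c`=VALUES(`$c`)").mkString(",")
    s"INSERT INTO $table ($cols) VALUES ($placeholders) ON DUPLICATE KEY UPDATE $updates"
  }
}
```

For the two-column table used later, `upsertStatement("foo", Seq("x", "y"))` produces `INSERT INTO foo (`x`,`y`) VALUES (?,?) ON DUPLICATE KEY UPDATE `x`=VALUES(`x`),`y`=VALUES(`y`)` — the same shape of statement the real method builds, which is what makes the rows update in place on a duplicate key instead of failing.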
Next comes the implicit class. Because nearly all of DataFrameWriter's fields are private, reflection is required.
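The reflection pattern itself is independent of Spark and can be tried on any class; here is a minimal sketch with a made-up Holder class standing in for DataFrameWriter:

```scala
object ReflectionDemo {
  // Stand-in for DataFrameWriter: the value we want is private.
  class Holder {
    private val secret: String = "jdbc"
    def masked: String = "*" * secret.length
  }

  def readSecret(h: Holder): String = {
    // Unlike getField, getDeclaredField also sees private members;
    // setAccessible(true) switches off the JVM access check so the
    // value can actually be read.
    val f = h.getClass.getDeclaredField("secret")
    f.setAccessible(true)
    f.get(h).asInstanceOf[String]
  }
}
```

The update method below applies the same getDeclaredField/setAccessible pair to DataFrameWriter's private extraOptions, df, source, and partitioningColumns fields.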
package org.apache.spark.enhance

import org.apache.spark.sql
import org.apache.spark.sql.{DataFrameWriter, Row, SaveMode}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.datasources.DataSource

object DataFrameWriterEnhance {
  implicit class DataFrameWriterMysqlUpdateEnhance(writer: DataFrameWriter[Row]) {
    def update(): Unit = {
      val extraOptionsField = writer.getClass.getDeclaredField("extraOptions")
      val dfField = writer.getClass.getDeclaredField("df")
      val sourceField = writer.getClass.getDeclaredField("source")
      val partitioningColumnsField = writer.getClass.getDeclaredField("partitioningColumns")
      extraOptionsField.setAccessible(true)
      dfField.setAccessible(true)
      sourceField.setAccessible(true)
      partitioningColumnsField.setAccessible(true)
      val extraOptions = extraOptionsField.get(writer).asInstanceOf[scala.collection.mutable.HashMap[String, String]]
      val df = dfField.get(writer).asInstanceOf[sql.DataFrame]
      val partitioningColumns = partitioningColumnsField.get(writer).asInstanceOf[Option[Seq[String]]]
      val logicalPlanField = df.getClass.getDeclaredField("logicalPlan")
      logicalPlanField.setAccessible(true)
      var logicalPlan = logicalPlanField.get(df).asInstanceOf[LogicalPlan]
      val session = df.sparkSession
      val dataSource = DataSource(
        sparkSession = session,
        className = "org.apache.spark.enhance.MysqlUpdateRelationProvider",
        partitionColumns = partitioningColumns.getOrElse(Nil),
        options = extraOptions.toMap)
      logicalPlan = dataSource.planForWriting(SaveMode.Append, logicalPlan)
      val qe = session.sessionState.executePlan(logicalPlan)
      SQLExecution.withNewExecutionId(session, qe)(qe.toRdd)
    }
  }
}
With this in place, the update method performs an upsert into MySQL. In the example below, assume the column x carries a unique constraint:
import org.apache.spark.enhance.DataFrameWriterEnhance._
import spark.implicits._

spark.sparkContext.parallelize(
  Seq(("x1", "test1"), ("x2", "test2")), 2
).toDF("x", "y")
  .write.format("jdbc").mode(SaveMode.Append)
  .options(Map(
    "url" -> config.database.url,
    "dbtable" -> "foo",
    "user" -> config.database.username,
    "password" -> config.database.password,
    "driver" -> config.database.driver
  )).update()
This approach barely pollutes application code, is broadly reusable, and touches nothing in Spark's catalyst, yet it meets the requirement. It also extends to any relational database, and with minor changes could even support stores such as Redis.