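Background
Currently, Spark's JDBC writes to MySQL support only the table-level modes Append, Overwrite, ErrorIfExists, and Ignore. Sometimes we need row-level operations on a table, such as update. In other words, we need to generate a statement like this: insert into tb (id,name,age) values (?,?,?) on duplicate key update id=?,name=?,age=?;
Requirement: we want to neither break existing code nor introduce a new API. Adding a single option such as savemode=update should be enough.
Implementation
Meeting this requirement means changing the Spark source. First, create our own save mode, which simply adds Update to the existing ones: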
public enum I4SaveMode {
Append,
Overwrite,
ErrorIfExists,
Ignore,
Update
}
The JDBC data source is implemented mainly in JdbcRelationProvider, and the method to focus on is createRelation. In that method we can convert SaveMode into our own mode and pass it on to saveTable. The modified method looks like this (every change is commented):
override def createRelation(
sqlContext: SQLContext,
mode: SaveMode,
parameters: Map[String, String],
df: DataFrame): BaseRelation = {
val options = new JDBCOptions(parameters)
val isCaseSensitive = sqlContext.conf.caseSensitiveAnalysis
// Map Spark's SaveMode onto our own save mode
var saveMode = mode match {
case SaveMode.Overwrite => I4SaveMode.Overwrite
case SaveMode.Append => I4SaveMode.Append
case SaveMode.ErrorIfExists => I4SaveMode.ErrorIfExists
case SaveMode.Ignore => I4SaveMode.Ignore
}
// Key change: check for a saveMode=update option and switch to Update mode if it is set
val parameterLower = parameters.map(kv => (kv._1.toLowerCase, kv._2))
if (parameterLower.get("savemode").contains("update")) {
saveMode = I4SaveMode.Update
}
val conn = JdbcUtils.createConnectionFactory(options)()
try {
val tableExists = JdbcUtils.tableExists(conn, options)
if (tableExists) {
saveMode match {
case I4SaveMode.Overwrite =>
if (options.isTruncate && isCascadingTruncateTable(options.url) == Some(false)) {
// In this case, we should truncate table and then load.
truncateTable(conn, options.table)
val tableSchema = JdbcUtils.getSchemaOption(conn, options)
saveTable(df, tableSchema, isCaseSensitive, options, saveMode)
} else {
......
}
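The mode-resolution step above can be tried in isolation. The following is a minimal sketch, not the patched Spark code: the `Mode` hierarchy and the `resolve` function are hypothetical stand-ins for I4SaveMode and for the logic inside createRelation, and it assumes the option value must be exactly "update" while the key is matched case-insensitively, as in the excerpt.

```scala
object SaveModeSketch {
  // Hypothetical stand-ins for the I4SaveMode enum values.
  sealed trait Mode
  case object Append extends Mode
  case object Overwrite extends Mode
  case object ErrorIfExists extends Mode
  case object Ignore extends Mode
  case object Update extends Mode

  // Lower-case the option keys, then override the base mode only when
  // the "savemode" option is exactly "update".
  def resolve(base: Mode, parameters: Map[String, String]): Mode = {
    val lower = parameters.map { case (k, v) => (k.toLowerCase, v) }
    if (lower.get("savemode").contains("update")) Update else base
  }
}
```

Because the base mode is kept whenever the option is absent or has another value, existing jobs that never set the option behave as before.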
Next comes the saveTable method:
def saveTable(
df: DataFrame,
tableSchema: Option[StructType],
isCaseSensitive: Boolean,
options: JDBCOptions,
mode: I4SaveMode): Unit = {
......
// Pass the mode through so that the statement matches the save mode
val insertStmt = getInsertStatement(table, rddSchema, tableSchema, isCaseSensitive, dialect, mode)
.....
repartitionedDF.foreachPartition(iterator => savePartition(
getConnection, table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel, mode)
)
}
Here getInsertStatement builds the SQL statement, and then each partition is saved in turn. Let's first look at how statement construction changes (every change is commented):
def getInsertStatement(
table: String,
rddSchema: StructType,
tableSchema: Option[StructType],
isCaseSensitive: Boolean,
dialect: JdbcDialect,
mode: I4SaveMode): String = {
val columns = if (tableSchema.isEmpty) {
rddSchema.fields.map(x => dialect.quoteIdentifier(x.name)).mkString(",")
} else {
val columnNameEquality = if (isCaseSensitive) {
org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
} else {
org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
}
val tableColumnNames = tableSchema.get.fieldNames
rddSchema.fields.map { col =>
val normalizedName = tableColumnNames.find(f => columnNameEquality(f, col.name)).getOrElse {
throw new AnalysisException(s"""Column "${col.name}" not found in schema $tableSchema""")
}
dialect.quoteIdentifier(normalizedName)
}.mkString(",")
}
val placeholders = rddSchema.fields.map(_ => "?").mkString(",")
// Original: s"INSERT INTO $table ($columns) VALUES ($placeholders)"
// Update mode needs its own statement
mode match {
case I4SaveMode.Update =>
val duplicateSetting = rddSchema.fields.map(x => dialect.quoteIdentifier(x.name)).map(name => s"$name=?").mkString(",")
s"INSERT INTO $table ($columns) VALUES ($placeholders) ON DUPLICATE KEY UPDATE $duplicateSetting"
case _ => s"INSERT INTO $table ($columns) VALUES ($placeholders)"
}
}
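To see what the statement builder produces, here is a minimal standalone sketch. It is an illustration only: `UpsertSqlSketch`, `quote`, and `insertStatement` are hypothetical names, the schema is reduced to a list of column names, and backtick quoting is assumed in place of a real JdbcDialect.

```scala
object UpsertSqlSketch {
  // Assumed MySQL-style identifier quoting; the real code delegates to the dialect.
  def quote(name: String): String = s"`$name`"

  // Builds a plain INSERT, or an INSERT ... ON DUPLICATE KEY UPDATE when isUpdate is true.
  def insertStatement(table: String, fieldNames: Seq[String], isUpdate: Boolean): String = {
    val columns = fieldNames.map(quote).mkString(",")
    val placeholders = fieldNames.map(_ => "?").mkString(",")
    val base = s"INSERT INTO $table ($columns) VALUES ($placeholders)"
    if (isUpdate) {
      // One extra "col=?" per column, so the placeholder count doubles.
      val duplicateSetting = fieldNames.map(n => s"${quote(n)}=?").mkString(",")
      s"$base ON DUPLICATE KEY UPDATE $duplicateSetting"
    } else {
      base
    }
  }
}
```

For a table tb with columns id and name, update mode yields a statement with four placeholders instead of two, which is why the binding code has to supply each row's values twice.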
It only needs to check whether the mode is update to build the right SQL statement. Next, look at the savePartition method to see how the rows are actually saved:
def savePartition(
getConnection: () => Connection,
table: String,
iterator: Iterator[Row],
rddSchema: StructType,
insertStmt: String,
batchSize: Int,
dialect: JdbcDialect,
isolationLevel: Int): Iterator[Byte] = {
val conn = getConnection()
var committed = false
var finalIsolationLevel = Connection.TRANSACTION_NONE
if (isolationLevel != Connection.TRANSACTION_NONE) {
try {
val metadata = conn.getMetaData
if (metadata.supportsTransactions()) {
// Update to at least use the default isolation, if any transaction level
// has been chosen and transactions are supported
val defaultIsolation = metadata.getDefaultTransactionIsolation
finalIsolationLevel = defaultIsolation
if (metadata.supportsTransactionIsolationLevel(isolationLevel)) {
// Finally update to actually requested level if possible
finalIsolationLevel = isolationLevel
} else {
logWarning(s"Requested isolation level $isolationLevel is not supported; " +
s"falling back to default isolation level $defaultIsolation")
}
} else {
logWarning(s"Requested isolation level $isolationLevel, but transactions are unsupported")
}
} catch {
case NonFatal(e) => logWarning("Exception while detecting transaction support", e)
}
}
val supportsTransactions = finalIsolationLevel != Connection.TRANSACTION_NONE
try {
if (supportsTransactions) {
conn.setAutoCommit(false) // Everything in the same db transaction.
conn.setTransactionIsolation(finalIsolationLevel)
}
val stmt = conn.prepareStatement(insertStmt)
val setters = rddSchema.fields.map(f => makeSetter(conn, dialect, f.dataType))
val nullTypes = rddSchema.fields.map(f => getJdbcType(f.dataType, dialect).jdbcNullType)
val numFields = rddSchema.fields.length
try {
var rowCount = 0
while (iterator.hasNext) {
val row = iterator.next()
var i = 0
while (i < numFields) {
if (row.isNullAt(i)) {
stmt.setNull(i + 1, nullTypes(i))
} else {
setters(i).apply(stmt, row, i)
}
i = i + 1
}
stmt.addBatch()
rowCount += 1
if (rowCount % batchSize == 0) {
stmt.executeBatch()
rowCount = 0
}
}
if (rowCount > 0) {
stmt.executeBatch()
}
} finally {
stmt.close()
}
if (supportsTransactions) {
conn.commit()
}
committed = true
Iterator.empty
} catch {
case e: SQLException =>
val cause = e.getNextException
if (cause != null && e.getCause != cause) {
// If there is no cause already, set 'next exception' as cause. If cause is null,
// it *may* be because no cause was set yet
if (e.getCause == null) {
try {
e.initCause(cause)
} catch {
// Or it may be null because the cause *was* explicitly initialized, to *null*,
// in which case this fails. There is no other way to detect it.
// addSuppressed in this case as well.
case _: IllegalStateException => e.addSuppressed(cause)
}
} else {
e.addSuppressed(cause)
}
}
throw e
} finally {
if (!committed) {
// The stage must fail. We got here through an exception path, so
// let the exception through unless rollback() or close() want to
// tell the user about another problem.
if (supportsTransactions) {
conn.rollback()
}
conn.close()
} else {
// The stage must succeed. We cannot propagate any exception close() might throw.
try {
conn.close()
} catch {
case e: Exception => logWarning("Transaction succeeded, but closing failed", e)
}
}
}
}
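The general idea is that, before iterating over the partition's rows for insertion, a template of setters (one per column) is prepared from the data's schema. During iteration this template is simply applied to each row, which avoids checking the data type for every row.
In non-update mode: insert into tb (id,name,age) values (?,?,?)
In update mode: insert into tb (id,name,age) values (?,?,?) on duplicate key update id=?,name=?,age=?;
So the number of placeholders doubles, and in update mode each row's values have to be fed to the PreparedStatement a second time. The original makeSetter method looks like this: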
private def makeSetter(
conn: Connection,
dialect: JdbcDialect,
dataType: DataType): JDBCValueSetter = dataType match {
case IntegerType =>
(stmt: PreparedStatement, row: Row, pos: Int) =>
stmt.setInt(pos + 1, row.getInt(pos))
case LongType =>
(stmt: PreparedStatement, row: Row, pos: Int) =>
stmt.setLong(pos + 1, row.getLong(pos))
...
}
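We only need to add a relative-position parameter, offset, to control this (the JDBCValueSetter function type has to gain the extra Int parameter as well), so it becomes: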
private def makeSetter(
conn: Connection,
dialect: JdbcDialect,
dataType: DataType): JDBCValueSetter = dataType match {
case IntegerType =>
(stmt: PreparedStatement, row: Row, pos: Int, offset: Int) =>
stmt.setInt(pos + 1, row.getInt(pos - offset))
case LongType =>
(stmt: PreparedStatement, row: Row, pos: Int, offset: Int) =>
stmt.setLong(pos + 1, row.getLong(pos - offset))
...
}
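In non-update mode the offset is always 0. In update mode the offset is 0 while the placeholder index is below the number of fields (midField in the code below), and it equals the number of fields once the index reaches it. The modified savePartition method is: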
def savePartition(
getConnection: () => Connection,
table: String,
iterator: Iterator[Row],
rddSchema: StructType,
insertStmt: String,
batchSize: Int,
dialect: JdbcDialect,
isolationLevel: Int,
mode: I4SaveMode): Iterator[Byte] = {
...
// Check whether we are in update mode
val isUpdateMode = mode == I4SaveMode.Update
val stmt = conn.prepareStatement(insertStmt)
val setters: Array[JDBCValueSetter] = rddSchema.fields.map(f => makeSetter(conn, dialect, f.dataType))
val nullTypes = rddSchema.fields.map(f => getJdbcType(f.dataType, dialect).jdbcNullType)
val length = rddSchema.fields.length
// In update mode there are twice as many placeholders
val numFields = if (isUpdateMode) length * 2 else length
val midField = numFields / 2
try {
var rowCount = 0
while (iterator.hasNext) {
val row = iterator.next()
var i = 0
while (i < numFields) {
if (isUpdateMode) {
if (i < midField) {
// Update mode, first half of the placeholders (the VALUES part): offset is 0
if (row.isNullAt(i)) {
stmt.setNull(i + 1, nullTypes(i))
} else {
setters(i).apply(stmt, row, i, 0)
}
} else {
// Update mode, second half (the ON DUPLICATE KEY UPDATE part): offset is midField, i.e. the number of fields
if (row.isNullAt(i - midField)) {
stmt.setNull(i + 1, nullTypes(i - midField))
} else {
setters(i - midField).apply(stmt, row, i, midField)
}
}
} else {
if (row.isNullAt(i)) {
stmt.setNull(i + 1, nullTypes(i))
} else {
setters(i).apply(stmt, row, i, 0)
}
}
i = i + 1
}
...
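The index arithmetic in the loop above can be checked without a database. The sketch below is a simplified model under stated assumptions: an `Array[Any]` stands in for the PreparedStatement's parameter slots (0-based here, whereas JDBC parameters are 1-based), a `Seq[Any]` stands in for a Row, and `Setter`, `copySetter`, and `bind` are hypothetical names rather than Spark APIs.

```scala
object BindSketch {
  // Stand-in for the widened JDBCValueSetter: (target, row, pos, offset) => Unit.
  type Setter = (Array[Any], Seq[Any], Int, Int) => Unit

  // A single type-agnostic setter; the real makeSetter picks one per data type.
  val copySetter: Setter = (target, row, pos, offset) => target(pos) = row(pos - offset)

  // Fills the parameter slots for one row, visiting every column twice in update mode.
  def bind(row: Seq[Any], isUpdateMode: Boolean): Array[Any] = {
    val length = row.length
    val numFields = if (isUpdateMode) length * 2 else length
    val setters = Array.fill[Setter](length)(copySetter) // built once, reused per row
    val target = new Array[Any](numFields)
    var i = 0
    while (i < numFields) {
      // First half maps straight to the row; second half is shifted back by `length`.
      val offset = if (i < length) 0 else length
      val col = i - offset
      if (row(col) == null) target(i) = null
      else setters(col).apply(target, row, i, offset)
      i += 1
    }
    target
  }
}
```

Binding the row (1, "a", 30) in update mode fills six slots, with the three values repeated in the same order, which matches the doubled placeholders in the generated statement.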
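After modifying the source, rebuild and repackage it, then replace the corresponding jar in production. There is also a shortcut: create the same package name in your own project, make the changes, build a jar, and then replace the corresponding class files inside the production jar with the ones from your jar.
Usage
To use update mode: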
df.write.option("saveMode","update").jdbc(...)
References
https://blog.csdn.net/cjuexuan/article/details/52333970