Implementing MySQL update in Spark

Background

Spark's MySQL writer currently supports only the table-level modes Append, Overwrite, ErrorIfExists and Ignore. Sometimes we need row-level operations such as update, i.e. we need to generate a statement like: insert into tb (id,name,age) values (?,?,?) on duplicate key update id=?,name=?,age=?;

Requirement: the goal is to leave existing code working as-is and to introduce no new API; enabling the behavior should take nothing more than an extra option such as savemode=update.

Implementation

To satisfy this, patching the Spark source is unavoidable. First, define our own save mode, which simply adds an Update value:

public enum I4SaveMode {
    Append,
    Overwrite,
    ErrorIfExists,
    Ignore,
    Update
}

The JDBC data source is implemented mainly in JdbcRelationProvider; the method we need is createRelation. There we can map Spark's SaveMode onto our own mode and carry it into saveTable. The modified method follows (every change is commented):

   override def createRelation(
                                   sqlContext: SQLContext,
                                   mode: SaveMode,
                                   parameters: Map[String, String],
                                   df: DataFrame): BaseRelation = {
        val options = new JDBCOptions(parameters)
        val isCaseSensitive = sqlContext.conf.caseSensitiveAnalysis
        // map Spark's SaveMode onto our own I4SaveMode
        var saveMode = mode match {
                case SaveMode.Overwrite => I4SaveMode.Overwrite
                case SaveMode.Append => I4SaveMode.Append
                case SaveMode.ErrorIfExists => I4SaveMode.ErrorIfExists
                case SaveMode.Ignore => I4SaveMode.Ignore
            }
        // the key change: look for a savemode=update option and switch to Update mode
        val parameterLower = parameters.map(kv => (kv._1.toLowerCase,kv._2))
        if(parameterLower.keySet.contains("savemode")){
            saveMode = if(parameterLower.get("savemode").get.equals("update")) I4SaveMode.Update else saveMode
        }
        val conn = JdbcUtils.createConnectionFactory(options)()
        try {
            val tableExists = JdbcUtils.tableExists(conn, options)
            if (tableExists) {
                saveMode match {
                    case I4SaveMode.Overwrite =>
                        if (options.isTruncate && isCascadingTruncateTable(options.url) == Some(false)) {
                            // In this case, we should truncate table and then load.
                            truncateTable(conn, options.table)
                            val tableSchema = JdbcUtils.getSchemaOption(conn, options)
                            saveTable(df, tableSchema, isCaseSensitive, options, saveMode)
                        } else {
                        ......
    }
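The mode resolution above boils down to a case-insensitive lookup in the options map. A minimal, Spark-free sketch of that logic (the `Enumeration` object here is a Scala stand-in for the article's Java enum, and `resolveSaveMode` is a name of mine, not part of the patch):

```scala
// Scala stand-in for the patched build's I4SaveMode enum.
object I4SaveMode extends Enumeration {
  val Append, Overwrite, ErrorIfExists, Ignore, Update = Value
}

// Case-insensitive lookup of the "savemode" option, falling back to the
// mode Spark already resolved (Append, Overwrite, ...).
def resolveSaveMode(parameters: Map[String, String],
                    fallback: I4SaveMode.Value): I4SaveMode.Value = {
  val lower = parameters.map { case (k, v) => (k.toLowerCase, v) }
  lower.get("savemode") match {
    case Some("update") => I4SaveMode.Update
    case _              => fallback
  }
}
```

Lowercasing the keys first is what lets users write `saveMode`, `SaveMode` or `savemode` interchangeably.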

Next comes the saveTable method:

def saveTable(
      df: DataFrame,
      tableSchema: Option[StructType],
      isCaseSensitive: Boolean,
      options: JDBCOptions,
      mode: I4SaveMode): Unit = {
    ...
    // forward the mode so the statement builder can emit the update variant
    val insertStmt = getInsertStatement(table, rddSchema, tableSchema, isCaseSensitive, dialect, mode)
    ...
    repartitionedDF.foreachPartition(iterator => savePartition(
      getConnection, table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel, mode)
    )
  }

getInsertStatement builds the SQL statement, after which each partition is iterated over and saved. First, the changes to statement construction (every change is commented):

def getInsertStatement(
      table: String,
      rddSchema: StructType,
      tableSchema: Option[StructType],
      isCaseSensitive: Boolean,
      dialect: JdbcDialect,
      mode: I4SaveMode): String = {
    val columns = if (tableSchema.isEmpty) {
      rddSchema.fields.map(x => dialect.quoteIdentifier(x.name)).mkString(",")
    } else {
      val columnNameEquality = if (isCaseSensitive) {
        org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
      } else {
        org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
      } 
      val tableColumnNames = tableSchema.get.fieldNames
      rddSchema.fields.map { col =>
        val normalizedName = tableColumnNames.find(f => columnNameEquality(f, col.name)).getOrElse {
          throw new AnalysisException(s"""Column "${col.name}" not found in schema $tableSchema""")
        }
        dialect.quoteIdentifier(normalizedName)
      }.mkString(",")
    } 
    val placeholders = rddSchema.fields.map(_ => "?").mkString(",")
    // originally: s"INSERT INTO $table ($columns) VALUES ($placeholders)"
    // update mode needs an ON DUPLICATE KEY UPDATE clause appended
    mode match {
      case I4SaveMode.Update =>
        val duplicateSetting = rddSchema.fields
          .map(x => dialect.quoteIdentifier(x.name))
          .map(name => s"$name=?")
          .mkString(",")
        s"INSERT INTO $table ($columns) VALUES ($placeholders) ON DUPLICATE KEY UPDATE $duplicateSetting"
      case _ => s"INSERT INTO $table ($columns) VALUES ($placeholders)"
    }
  }
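Stripped of the Spark types, the statement builder reduces to a small pure function. In this sketch, backtick quoting stands in for `dialect.quoteIdentifier`, and `insertStatement` is an illustrative name of mine:

```scala
// Given the column names, produce either a plain INSERT or the
// ON DUPLICATE KEY UPDATE variant used in update mode.
def insertStatement(table: String, columns: Seq[String], update: Boolean): String = {
  val quoted = columns.map(c => s"`$c`")           // quoteIdentifier stand-in
  val cols = quoted.mkString(",")
  val placeholders = columns.map(_ => "?").mkString(",")
  val base = s"INSERT INTO $table ($cols) VALUES ($placeholders)"
  if (update) {
    val dup = quoted.map(c => s"$c=?").mkString(",")
    s"$base ON DUPLICATE KEY UPDATE $dup"
  } else base
}
```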

Statement construction only needs to check whether we are in update mode. Now for savePartition, which does the actual writing:

 def savePartition(
      getConnection: () => Connection,
      table: String,
      iterator: Iterator[Row],
      rddSchema: StructType,
      insertStmt: String,
      batchSize: Int,
      dialect: JdbcDialect,
      isolationLevel: Int): Iterator[Byte] = {
    val conn = getConnection()
    var committed = false

    var finalIsolationLevel = Connection.TRANSACTION_NONE
    if (isolationLevel != Connection.TRANSACTION_NONE) {
      try {
        val metadata = conn.getMetaData
        if (metadata.supportsTransactions()) {
          // Update to at least use the default isolation, if any transaction level
          // has been chosen and transactions are supported
          val defaultIsolation = metadata.getDefaultTransactionIsolation
          finalIsolationLevel = defaultIsolation
          if (metadata.supportsTransactionIsolationLevel(isolationLevel))  {
            // Finally update to actually requested level if possible
            finalIsolationLevel = isolationLevel
          } else {
            logWarning(s"Requested isolation level $isolationLevel is not supported; " +
                s"falling back to default isolation level $defaultIsolation")
          }
        } else {
          logWarning(s"Requested isolation level $isolationLevel, but transactions are unsupported")
        }
      } catch {
        case NonFatal(e) => logWarning("Exception while detecting transaction support", e)
      }
    }
    val supportsTransactions = finalIsolationLevel != Connection.TRANSACTION_NONE

    try {
      if (supportsTransactions) {
        conn.setAutoCommit(false) // Everything in the same db transaction.
        conn.setTransactionIsolation(finalIsolationLevel)
      }
      val stmt = conn.prepareStatement(insertStmt)
      val setters = rddSchema.fields.map(f => makeSetter(conn, dialect, f.dataType))
      val nullTypes = rddSchema.fields.map(f => getJdbcType(f.dataType, dialect).jdbcNullType)
      val numFields = rddSchema.fields.length

      try {
        var rowCount = 0
        while (iterator.hasNext) {
          val row = iterator.next()
          var i = 0
          while (i < numFields) {
            if (row.isNullAt(i)) {
              stmt.setNull(i + 1, nullTypes(i))
            } else {
              setters(i).apply(stmt, row, i)
            }
            i = i + 1
          }
          stmt.addBatch()
          rowCount += 1
          if (rowCount % batchSize == 0) {
            stmt.executeBatch()
            rowCount = 0
          }
        }
        if (rowCount > 0) {
          stmt.executeBatch()
        }
      } finally {
        stmt.close()
      }
      if (supportsTransactions) {
        conn.commit()
      }
      committed = true
      Iterator.empty
    } catch {
      case e: SQLException =>
        val cause = e.getNextException
        if (cause != null && e.getCause != cause) {
          // If there is no cause already, set 'next exception' as cause. If cause is null,
          // it *may* be because no cause was set yet
          if (e.getCause == null) {
            try {
              e.initCause(cause)
            } catch {
              // Or it may be null because the cause *was* explicitly initialized, to *null*,
              // in which case this fails. There is no other way to detect it.
              // addSuppressed in this case as well.
              case _: IllegalStateException => e.addSuppressed(cause)
            }
          } else {
            e.addSuppressed(cause)
          }
        }
        throw e
    } finally {
      if (!committed) {
        // The stage must fail.  We got here through an exception path, so
        // let the exception through unless rollback() or close() want to
        // tell the user about another problem.
        if (supportsTransactions) {
          conn.rollback()
        }
        conn.close()
      } else {
        // The stage must succeed.  We cannot propagate any exception close() might throw.
        try {
          conn.close()
        } catch {
          case e: Exception => logWarning("Transaction succeeded, but closing failed", e)
        }
      }
    }
  }

The overall idea: before iterating over a partition's rows, build an array of setters from the schema, so that each row just applies those prepared setters instead of re-checking data types row by row.
Without update mode: insert into tb (id,name,age) values (?,?,?)
With update mode: insert into tb (id,name,age) values (?,?,?) on duplicate key update id=?,name=?,age=?;
That is, there are twice as many placeholders, so in update mode every row has to be fed to the PreparedStatement a second time. The original makeSetter looks like this:

private def makeSetter(
      conn: Connection,
      dialect: JdbcDialect,
      dataType: DataType): JDBCValueSetter = dataType match {
    case IntegerType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setInt(pos + 1, row.getInt(pos))
    case LongType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setLong(pos + 1, row.getLong(pos))
    ...
  }
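The setter-array technique can be demonstrated without JDBC: build one closure per column from the schema, then apply the same closures to every row. In this hypothetical sketch, `Setter` and the string-building "statement" are stand-ins of mine, not Spark's API:

```scala
// One closure per column, built once from the schema, applied to every row.
type Setter = (StringBuilder, Seq[Any], Int) => Unit

def makeSetters(types: Seq[String]): Seq[Setter] = types.map {
  case "int" =>
    (out: StringBuilder, row: Seq[Any], pos: Int) =>
      out.append(s"setInt(${pos + 1}, ${row(pos)});"): Unit
  case _ =>
    (out: StringBuilder, row: Seq[Any], pos: Int) =>
      out.append(s"setString(${pos + 1}, ${row(pos)});"): Unit
}

val setters = makeSetters(Seq("int", "string"))   // built once per partition
val out = new StringBuilder
val row = Seq(7, "bob")
setters.zipWithIndex.foreach { case (s, i) => s(out, row, i) }  // applied per row
```

The type dispatch happens once per partition rather than once per cell, which is the point of the pattern.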

All we need is an extra relative-position parameter, offset, giving:

private def makeSetter(
      conn: Connection,
      dialect: JdbcDialect,
      dataType: DataType): JDBCValueSetter = dataType match {
    case IntegerType =>
      (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) =>
        stmt.setInt(pos + 1, row.getInt(pos - offset))
    case LongType =>
      (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) =>
        stmt.setLong(pos + 1, row.getLong(pos - offset))
    ...

In non-update mode the offset is always 0. In update mode the offset is 0 while the placeholder index is below the field count, and equals the field count beyond it. The modified savePartition:
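The index arithmetic is easy to get wrong, so it helps to isolate it. This sketch (the function name is mine) maps a placeholder position to the row field it should read, exactly as the offset in makeSetter does:

```scala
// For an update statement over n fields there are 2n placeholders.
// Placeholder i (0-based) reads row field i for i < n, and field i - n
// for i >= n; the subtraction is what the offset parameter performs.
def fieldIndexFor(i: Int, numFields: Int, updateMode: Boolean): Int =
  if (updateMode && i >= numFields) i - numFields else i
```

So for three fields in update mode, placeholders 0..5 read fields 0, 1, 2, 0, 1, 2: the row is fed to the statement twice.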

def savePartition(
      getConnection: () => Connection,
      table: String,
      iterator: Iterator[Row],
      rddSchema: StructType,
      insertStmt: String,
      batchSize: Int,
      dialect: JdbcDialect,
      isolationLevel: Int,
      mode: I4SaveMode): Iterator[Byte] = {
    ...
    // are we in update mode?
    val isUpdateMode = mode == I4SaveMode.Update
    val stmt = conn.prepareStatement(insertStmt)
    val setters: Array[JDBCValueSetter] = rddSchema.fields.map(f => makeSetter(conn, dialect, f.dataType))
    val nullTypes = rddSchema.fields.map(f => getJdbcType(f.dataType, dialect).jdbcNullType)
    val length = rddSchema.fields.length
    // update mode has twice as many placeholders
    val numFields = if (isUpdateMode) length * 2 else length
    val midField = numFields / 2
    try {
      var rowCount = 0
      while (iterator.hasNext) {
        val row = iterator.next()
        var i = 0
        while (i < numFields) {
          if (isUpdateMode) {
            if (i < midField) {
              // first half of the placeholders: offset is 0
              if (row.isNullAt(i)) {
                stmt.setNull(i + 1, nullTypes(i))
              } else {
                setters(i).apply(stmt, row, i, 0)
              }
            } else {
              // second half: offset is midField, i.e. the field count
              if (row.isNullAt(i - midField)) {
                stmt.setNull(i + 1, nullTypes(i - midField))
              } else {
                setters(i - midField).apply(stmt, row, i, midField)
              }
            }
          } else {
            if (row.isNullAt(i)) {
              stmt.setNull(i + 1, nullTypes(i))
            } else {
              setters(i).apply(stmt, row, i, 0)
            }
          }
          i = i + 1
        }
      ...

With the source patched, rebuild and repackage, then replace the corresponding jar in production. There is a shortcut: create a project with the same package names, compile the modified sources into a jar, and overwrite just those class files inside the production jar.

Usage

To write in update mode:

df.write.option("saveMode","update").jdbc(...)

References

https://blog.csdn.net/cjuexuan/article/details/52333970

My GitHub
