Spark 实现 MySQL upsert

基于 MySQL 表的唯一键，为 Spark DataFrame/Dataset 实现“有则更新、无则插入”（upsert）功能：写入时生成 `INSERT ... ON DUPLICATE KEY UPDATE` 语句。

基于 Spark 2.4.3、Scala 2.11.8

工具类 DataFrameWriterEnhance

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
package com.xxx.utils

import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcOptionsInWrite, JdbcRelationProvider, JdbcUtils}
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql._

import java.sql.Connection

object DataFrameWriterEnhance {

  /**
   * Reads a private field of `target` via reflection.
   *
   * Spark 2.4 exposes no public accessor for the `DataFrameWriter` / `Dataset` state used
   * below, so the field names passed in are tied to the Spark 2.4.x class layout.
   *
   * @throws NoSuchFieldException if the runtime class of `target` declares no field `name`
   */
  private def readPrivateField[T](target: AnyRef, name: String): T = {
    val field = target.getClass.getDeclaredField(name)
    field.setAccessible(true)
    field.get(target).asInstanceOf[T]
  }

  /**
   * Adds an `update()` action to `DataFrameWriter[Row]` that writes the DataFrame through
   * [[MysqlUpdateRelationProvider]], i.e. as `INSERT ... ON DUPLICATE KEY UPDATE`.
   */
  implicit class DataFrameWriterMysqlUpdateEnhance(writer: DataFrameWriter[Row]) {

    /**
     * Executes the upsert write.
     *
     * The format and save mode configured on the writer are ignored: the data source is always
     * [[MysqlUpdateRelationProvider]] and the plan is always built with `SaveMode.Append`.
     * All options set on the writer (`url`, `dbtable`, `user`, `password`, `numPartitions`,
     * ...) are forwarded to the provider; `showSql=true` prints the generated statement.
     */
    def update(): Unit = {
      val extraOptions = readPrivateField[scala.collection.Map[String, String]](
        writer, "org$apache$spark$sql$DataFrameWriter$$extraOptions")
      val df = readPrivateField[DataFrame](writer, "df")
      val partitioningColumns = readPrivateField[Option[Seq[String]]](writer, "partitioningColumns")
      val logicalPlan = readPrivateField[LogicalPlan](df, "logicalPlan")

      val session = df.sparkSession
      val dataSource = DataSource(
        sparkSession = session,
        // Binary name of the nested class, e.g. com.xxx.utils.DataFrameWriterEnhance$MysqlUpdateRelationProvider
        className = classOf[MysqlUpdateRelationProvider].getName,
        partitionColumns = partitioningColumns.getOrElse(Nil),
        options = extraOptions.toMap)
      val writePlan = dataSource.planForWriting(SaveMode.Append, logicalPlan)
      val qe = session.sessionState.executePlan(writePlan)
      // Materializing `toRdd` inside a new execution id is what actually runs the write command.
      SQLExecution.withNewExecutionId(session, qe)(qe.toRdd)
    }
  }

  /**
   * JDBC relation provider that writes rows with `INSERT ... ON DUPLICATE KEY UPDATE`
   * (MySQL upsert) instead of a plain `INSERT`. Everything else is inherited from
   * `JdbcRelationProvider`.
   */
  class MysqlUpdateRelationProvider extends JdbcRelationProvider {

    /**
     * Handles each `SaveMode` against the target table and writes `df` with the upsert
     * statement.
     *
     * When the table does not exist, or is dropped and recreated for `Overwrite`, it is created
     * from the DataFrame schema only. The new table therefore has no unique key, and the
     * "upsert" degenerates into plain inserts.
     */
    override def createRelation(sqlContext: SQLContext, mode: SaveMode, parameters: Map[String, String], df: DataFrame): BaseRelation = {
      val options = new JdbcOptionsInWrite(parameters)
      val isCaseSensitive = sqlContext.sparkSession.sessionState.conf.caseSensitiveAnalysis
      val conn = JdbcUtils.createConnectionFactory(options)()
      try {
        val tableExists = JdbcUtils.tableExists(conn, options)
        if (tableExists) {
          mode match {
            case SaveMode.Overwrite =>
              if (options.isTruncate && JdbcUtils.isCascadingTruncateTable(options.url).contains(false)) {
                // In this case, we should truncate table and then load.
                JdbcUtils.truncateTable(conn, options)
                val tableSchema = JdbcUtils.getSchemaOption(conn, options)
                updateTable(df, tableSchema, isCaseSensitive, options)
              } else {
                // Otherwise, do not truncate the table, instead drop and recreate it
                JdbcUtils.dropTable(conn, options.table, options)
                JdbcUtils.createTable(conn, df, options)
                updateTable(df, Some(df.schema), isCaseSensitive, options)
              }

            case SaveMode.Append =>
              val tableSchema = JdbcUtils.getSchemaOption(conn, options)
              updateTable(df, tableSchema, isCaseSensitive, options)

            case SaveMode.ErrorIfExists =>
              throw new Exception(
                s"Table or view '${options.table}' already exists. " +
                  s"SaveMode: ErrorIfExists.")

            case SaveMode.Ignore =>
            // With `SaveMode.Ignore` mode, if table already exists, the save operation is expected
            // to not save the contents of the DataFrame and to not change the existing data.
            // Therefore, it is okay to do nothing here and then just return the relation below.
          }
        } else {
          JdbcUtils.createTable(conn, df, options)
          updateTable(df, Some(df.schema), isCaseSensitive, options)
        }
      } finally {
        conn.close()
      }

      createRelation(sqlContext, parameters)
    }

    /**
     * Writes `df` to `options.table` using the upsert statement from [[getUpdateStatement]],
     * delegating the per-partition batching to `JdbcUtils.savePartition`.
     *
     * @param df              rows to write
     * @param tableSchema     schema of the existing table, used to map column-name case;
     *                        `None` means the DataFrame's own column names are used
     * @param isCaseSensitive whether DataFrame/table column names are matched case-sensitively
     * @param options         JDBC write options; `showSql=true` prints the statement on the driver
     * @throws IllegalArgumentException if `numPartitions` is set to a value below 1
     */
    def updateTable(df: DataFrame,
                    tableSchema: Option[StructType],
                    isCaseSensitive: Boolean,
                    options: JdbcOptionsInWrite): Unit = {
      val url = options.url
      val table = options.table
      val dialect = JdbcDialects.get(url)
      val rddSchema = df.schema
      val getConnection: () => Connection = JdbcUtils.createConnectionFactory(options)
      val batchSize = options.batchSize
      val isolationLevel = options.isolationLevel

      val updateStmt = getUpdateStatement(table, rddSchema, tableSchema, isCaseSensitive, dialect)
      // Only echo the SQL when explicitly requested, instead of unconditionally writing to stdout.
      val showSql = options.parameters.get("showSql").exists(_.trim.equalsIgnoreCase("true"))
      if (showSql) println(updateStmt)

      val repartitionedDF = options.numPartitions match {
        case Some(n) if n <= 0 => throw new IllegalArgumentException(
          s"Invalid value `$n` for parameter `${JDBCOptions.JDBC_NUM_PARTITIONS}` in table writing " +
            "via JDBC. The minimum value is 1.")
        // numPartitions only caps write parallelism; it never increases the partition count.
        case Some(n) if n < df.rdd.partitions.length => df.coalesce(n)
        case _ => df
      }
      repartitionedDF.rdd.foreachPartition(iterator => JdbcUtils.savePartition(
        getConnection, table, iterator, rddSchema, updateStmt, batchSize, dialect, isolationLevel,
        options)
      )
    }

    /**
     * Builds the parameterized MySQL upsert statement:
     * {{{
     * INSERT INTO table (c1,c2,...) VALUES (?,?,...)
     * ON DUPLICATE KEY UPDATE
     * c1=VALUES(c1),c2=VALUES(c2),...
     * }}}
     * Columns follow the DataFrame's column order. When `tableSchema` is given, each name is
     * replaced by the table's spelling of it, matched case-sensitively or not per
     * `isCaseSensitive`.
     *
     * @throws Exception if a DataFrame column has no counterpart in `tableSchema`
     */
    def getUpdateStatement(table: String,
                           rddSchema: StructType,
                           tableSchema: Option[StructType],
                           isCaseSensitive: Boolean,
                           dialect: JdbcDialect): String = {
      val quotedColumns: Seq[String] = tableSchema match {
        case None =>
          rddSchema.fields.toSeq.map(field => dialect.quoteIdentifier(field.name))
        case Some(schema) =>
          val columnNameEquality = if (isCaseSensitive) {
            org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
          } else {
            org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
          }
          // The generated insert statement needs to follow rddSchema's column sequence and
          // tableSchema's column names. When appending data into some case-sensitive DBMSs like
          // PostgreSQL/Oracle, we need to respect the existing case-sensitive column names instead of
          // RDD column names for user convenience.
          val tableColumnNames = schema.fieldNames
          rddSchema.fields.toSeq.map { field =>
            val normalizedName = tableColumnNames.find(f => columnNameEquality(f, field.name)).getOrElse {
              throw new Exception(s"""Column "${field.name}" not found in schema $tableSchema""")
            }
            dialect.quoteIdentifier(normalizedName)
          }
      }
      // Keep the identifiers as a list: re-splitting a joined string on "," would break any
      // column name that itself contains a comma.
      val columns = quotedColumns.mkString(",")
      val placeholders = quotedColumns.map(_ => "?").mkString(",")
      val assignments = quotedColumns.map(c => s"$c=VALUES($c)").mkString(",")
      s"""INSERT INTO $table ($columns) VALUES ($placeholders)
         |ON DUPLICATE KEY UPDATE
         |$assignments
         |""".stripMargin
    }
  }
}

工具类 MysqlUtils

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
package com.xxx.utils

import com.xxx.utils.DataFrameWriterEnhance.DataFrameWriterMysqlUpdateEnhance
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{NullType, ShortType}
import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}

object MysqlUtils {

  /**
   * Upserts `rawDF` into the MySQL table `tableName` via `DataFrameWriter.update()`
   * (`INSERT ... ON DUPLICATE KEY UPDATE`), so rows that collide on a unique key are updated
   * and all other rows are inserted.
   *
   * Connection settings are read from the Spark conf keys
   * `spark.job.mysql.<database>.url`, `spark.job.mysql.<database>.username` and
   * `spark.job.mysql.<database>.password`.
   *
   * @param rawDF     rows to write; column names are expected to match the table's columns
   * @param database  logical name used to look up the connection settings in the Spark conf
   * @param tableName target table (passed as the JDBC `dbtable` option)
   * @param spark     session whose conf holds the connection settings
   */
  def upsert(rawDF: DataFrame, database: String, tableName: String)(implicit spark: SparkSession): Unit = {
    // Give columns of NullType (e.g. built from `lit(null)`) a concrete type; their values stay
    // null. Presumably needed because the JDBC writer cannot map NullType to a SQL type.
    val df = rawDF.schema.fields.foldLeft(rawDF) { (acc, field) =>
      if (field.dataType == NullType) acc.withColumn(field.name, col(field.name).cast(ShortType))
      else acc
    }

    df.write
      .format("jdbc")
      .mode(SaveMode.Append)
      .option("driver", "com.mysql.jdbc.Driver")
      .option("url", spark.conf.get(s"spark.job.mysql.${database}.url"))
      .option("user", spark.conf.get(s"spark.job.mysql.${database}.username"))
      .option("password", spark.conf.get(s"spark.job.mysql.${database}.password"))
      .option("dbtable", tableName)
      .option("useSSL", "false")
      .option("showSql", "false")
      // Write through a single partition (one connection). Presumably chosen so that concurrent
      // upserts do not contend for the same unique-key rows; raise with care.
      .option("numPartitions", "1")
      .update()
  }


}

使用

在 spark-submit 启动脚本中通过 `--conf` 加入 MySQL 连接配置（键名格式为 `spark.job.mysql.<database>.url/username/password`）：

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
# MySQL connection settings are passed as spark.job.mysql.<database>.* confs
# (here <database> = test); MysqlUtils.upsert reads them with spark.conf.get.
# Values are quoted so URLs/passwords containing spaces or glob characters survive.
# Replace the jar path with your application jar.
spark-submit \
--master yarn \
--deploy-mode cluster \
--executor-memory 3G \
--num-executors 5 \
--executor-cores 4 \
--driver-memory 3G \
--conf spark.job.mysql.test.url="${jdbc_url}" \
--conf spark.job.mysql.test.username="${jdbc_username}" \
--conf spark.job.mysql.test.password="${jdbc_password}" \
--class TestMysqlUpsert \
path/to/your-app.jar

使用范例

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
import com.xxx.utils.MysqlUtils
import org.apache.spark.sql.SparkSession

/** Example job: upserts two rows into `test_unique_key` of the `test` MySQL database. */
object TestMysqlUpsert {
  def main(args: Array[String]): Unit = {
    implicit val spark: SparkSession = SparkSession.builder().enableHiveSupport().getOrCreate()
    import spark.implicits._

    try {
      // Connection settings are looked up under spark.job.mysql.test.* in the Spark conf.
      val database = "test"
      val arr = Array((1, 11, "name1", 11111), (2, 22, "name2", 22222))
      val df = spark.sparkContext.parallelize(arr)
        .toDF("key_one", "key_two", "val_one", "val_two")

      // (key_one, key_two) is the table's unique key, so re-running this updates the same rows.
      MysqlUtils.upsert(df, database, "test_unique_key")
    } finally {
      // Release the session even if the write fails.
      spark.close()
    }
  }
}

test_unique_key表结构

1
2
3
4
5
6
7
-- Target table for the upsert example. The UNIQUE KEY `uk` on (key_one, key_two) is what lets
-- INSERT ... ON DUPLICATE KEY UPDATE update an existing row instead of inserting a duplicate.
-- utf8mb4 (instead of MySQL's 3-byte `utf8` alias) allows val_one to hold any Unicode character.
CREATE TABLE `test_unique_key` (
  `key_one` int(11) NOT NULL DEFAULT '0',
  `key_two` int(11) NOT NULL DEFAULT '0',
  `val_one` varchar(50) DEFAULT NULL,
  `val_two` int(11) NOT NULL DEFAULT '0',
  UNIQUE KEY `uk` (`key_one`,`key_two`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='test';

参考

CSDN：Spark Upsert 写入 MySQL（Scala 增强，无需额外依赖）

Buy me a coffee~
hpkaiq 支付宝
hpkaiq 微信
0%