1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
| package com.xxx.utils
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcOptionsInWrite, JdbcRelationProvider, JdbcUtils}
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql._
import java.sql.Connection
object DataFrameWriterEnhance {
/**
 * Adds an `update()` terminal operation to `DataFrameWriter[Row]`.
 *
 * `update()` routes the write through the nested `MysqlUpdateRelationProvider`
 * instead of the data source configured on the writer. The writer's options,
 * DataFrame and partitioning columns are read from its private fields via
 * reflection.
 *
 * NOTE(review): the reflected field names (`extraOptions`, `df`,
 * `partitioningColumns`, `logicalPlan`) are Spark internals and are presumably
 * tied to the Spark version this file was written against -- confirm when
 * upgrading Spark.
 */
implicit class DataFrameWriterMysqlUpdateEnhance(writer: DataFrameWriter[Row]) {

  /**
   * Reads the declared (possibly private) field `fieldName` from `target`.
   *
   * @throws NoSuchFieldException if the runtime class of `target` declares no such field
   */
  private def readPrivateField[T](target: AnyRef, fieldName: String): T = {
    val field = target.getClass.getDeclaredField(fieldName)
    field.setAccessible(true)
    field.get(target).asInstanceOf[T]
  }

  /**
   * Plans and executes the write of the writer's DataFrame through
   * `MysqlUpdateRelationProvider`.
   *
   * The plan is always built with `SaveMode.Append`; any save mode set on the
   * writer is not consulted here.
   */
  def update(): Unit = {
    // The compiler-mangled name addresses DataFrameWriter's private `extraOptions` member.
    val extraOptions = readPrivateField[scala.collection.Map[String, String]](
      writer, "org$apache$spark$sql$DataFrameWriter$$extraOptions")
    val df = readPrivateField[DataFrame](writer, "df")
    val partitioningColumns = readPrivateField[Option[Seq[String]]](writer, "partitioningColumns")
    val sourcePlan = readPrivateField[LogicalPlan](df, "logicalPlan")

    val session = df.sparkSession
    val dataSource = DataSource(
      sparkSession = session,
      // Runtime class name of the enclosing object followed by the nested
      // provider's simple name, i.e. the binary name of the nested class.
      className = s"${DataFrameWriterEnhance.getClass.getName}MysqlUpdateRelationProvider",
      partitionColumns = partitioningColumns.getOrElse(Nil),
      options = extraOptions.toMap)

    val writePlan = dataSource.planForWriting(SaveMode.Append, sourcePlan)
    val qe = session.sessionState.executePlan(writePlan)
    // Accessing `toRdd` inside a new execution id is what actually triggers the write.
    SQLExecution.withNewExecutionId(session, qe)(qe.toRdd)
  }
}
class MysqlUpdateRelationProvider extends JdbcRelationProvider {
/**
 * Writes `df` to the JDBC table described by `parameters` using the upsert
 * statement produced by `updateTable`, then returns the read relation for
 * the same parameters.
 *
 * Behaviour by state of the target table:
 *  - table missing: it is created from `df`, then the rows are written;
 *  - `Overwrite`: the table is truncated when the `truncate` option is set and
 *    the URL's dialect reports non-cascading truncation, otherwise dropped and
 *    recreated; the rows are written in both cases;
 *  - `Append`: rows are written against the existing table schema;
 *  - `ErrorIfExists`: an `Exception` is thrown;
 *  - `Ignore`: nothing is written.
 *
 * The `ignoreNull` entry of `parameters` (default `"false"`) is parsed as a
 * boolean and forwarded to `updateTable`. The JDBC connection opened here is
 * always closed before returning.
 */
override def createRelation(sqlContext: SQLContext, mode: SaveMode, parameters: Map[String, String], df: DataFrame): BaseRelation = {
  val options = new JdbcOptionsInWrite(parameters)
  val isCaseSensitive = sqlContext.sparkSession.sessionState.conf.caseSensitiveAnalysis

  // Shared write step; `ignoreNull` is parsed each time a write actually happens.
  def upsert(targetSchema: Option[StructType]): Unit =
    updateTable(df, targetSchema, isCaseSensitive, options, parameters.getOrElse("ignoreNull", "false").toBoolean)

  val conn = JdbcUtils.createConnectionFactory(options)()
  try {
    if (!JdbcUtils.tableExists(conn, options)) {
      JdbcUtils.createTable(conn, df, options)
      upsert(Some(df.schema))
    } else {
      mode match {
        case SaveMode.Overwrite
          if options.isTruncate && JdbcUtils.isCascadingTruncateTable(options.url) == Some(false) =>
          // Truncation is safe here: empty the table, then load into its existing schema.
          JdbcUtils.truncateTable(conn, options)
          upsert(JdbcUtils.getSchemaOption(conn, options))

        case SaveMode.Overwrite =>
          // Truncation not requested or not known to be non-cascading: drop and recreate.
          JdbcUtils.dropTable(conn, options.table, options)
          JdbcUtils.createTable(conn, df, options)
          upsert(Some(df.schema))

        case SaveMode.Append =>
          upsert(JdbcUtils.getSchemaOption(conn, options))

        case SaveMode.ErrorIfExists =>
          throw new Exception(
            s"Table or view '${options.table}' already exists. " +
              s"SaveMode: ErrorIfExists.")

        case SaveMode.Ignore =>
          // The table already exists, so neither the DataFrame is saved nor
          // the existing data touched; only the relation below is returned.
          ()
      }
    }
  } finally {
    conn.close()
  }
  createRelation(sqlContext, parameters)
}
/**
 * Writes every partition of `df` to the table named in `options` using the
 * statement built by `getUpdateStatement`.
 *
 * The generated statement is printed to standard output before the write.
 * When `options.numPartitions` is set and smaller than the DataFrame's current
 * partition count, the DataFrame is coalesced to that many partitions first.
 *
 * @param df              rows to write
 * @param tableSchema     schema of the target table, if known
 * @param isCaseSensitive whether column names are matched case-sensitively
 * @param options         JDBC write options (url, table, batch size, isolation level, ...)
 * @param ignoreNull      forwarded to `getUpdateStatement`
 * @throws IllegalArgumentException if `options.numPartitions` is defined and not positive
 */
def updateTable(df: DataFrame,
                tableSchema: Option[StructType],
                isCaseSensitive: Boolean,
                options: JdbcOptionsInWrite,
                ignoreNull: Boolean = false): Unit = {
  val targetTable = options.table
  val jdbcDialect = JdbcDialects.get(options.url)
  val dfSchema = df.schema
  val connectionFactory: () => Connection = JdbcUtils.createConnectionFactory(options)
  val rowsPerBatch = options.batchSize
  val txIsolation = options.isolationLevel

  val upsertSql =
    getUpdateStatement(targetTable, dfSchema, tableSchema, isCaseSensitive, jdbcDialect, ignoreNull)
  println(upsertSql)

  // Only ever reduce the number of partitions; never increase it.
  val dfToWrite = options.numPartitions.fold(df) { n =>
    if (n <= 0) {
      throw new IllegalArgumentException(
        s"Invalid value `$n` for parameter `${JDBCOptions.JDBC_NUM_PARTITIONS}` in table writing " +
          "via JDBC. The minimum value is 1.")
    } else if (n < df.rdd.partitions.length) {
      df.coalesce(n)
    } else {
      df
    }
  }

  dfToWrite.rdd.foreachPartition { rows =>
    JdbcUtils.savePartition(
      connectionFactory, targetTable, rows, dfSchema, upsertSql, rowsPerBatch, jdbcDialect,
      txIsolation, options)
  }
}
/**
 * Builds a MySQL `INSERT ... ON DUPLICATE KEY UPDATE` statement with one `?`
 * placeholder per field of `rddSchema`.
 *
 * Columns follow the field order of `rddSchema`. When `tableSchema` is given,
 * each field name is replaced by the matching table column name (matched
 * case-sensitively or not according to `isCaseSensitive`) before being quoted
 * by `dialect`; otherwise the field names are quoted as they are.
 *
 * @param table           target table name, inserted into the SQL verbatim
 * @param rddSchema       schema of the rows being written
 * @param tableSchema     schema of the existing table, if known
 * @param isCaseSensitive whether column names are compared case-sensitively
 * @param dialect         JDBC dialect used to quote identifiers
 * @param ignoreNull      when true, the inserted row is aliased as `data_new`
 *                        and a column keeps its current value whenever the new
 *                        value is null; when false, every column is overwritten
 *                        via `VALUES(col)`
 * @return the SQL text, terminated by a newline
 * @throws IllegalArgumentException if `rddSchema` has no fields
 * @throws Exception if a field of `rddSchema` has no counterpart in `tableSchema`
 */
def getUpdateStatement(table: String,
                       rddSchema: StructType,
                       tableSchema: Option[StructType],
                       isCaseSensitive: Boolean,
                       dialect: JdbcDialect,
                       ignoreNull: Boolean): String = {
  // An empty column list cannot form valid SQL; fail here with a clear message.
  require(rddSchema.fields.nonEmpty,
    s"Cannot build an upsert statement for table $table: the DataFrame has no columns")

  val quotedColumns: Array[String] = tableSchema match {
    case None =>
      rddSchema.fields.map(field => dialect.quoteIdentifier(field.name))
    case Some(schema) =>
      val columnNameEquality = if (isCaseSensitive) {
        org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
      } else {
        org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
      }
      // The generated statement follows rddSchema's column sequence but uses
      // tableSchema's column names, so that the existing (possibly
      // case-sensitive) column names are respected instead of the RDD column names.
      val tableColumnNames = schema.fieldNames
      rddSchema.fields.map { col =>
        val normalizedName = tableColumnNames.find(f => columnNameEquality(f, col.name)).getOrElse {
          throw new Exception(s"""Column "${col.name}" not found in schema $tableSchema""")
        }
        dialect.quoteIdentifier(normalizedName)
      }
  }

  val columns = quotedColumns.mkString(",")
  val placeholders = quotedColumns.map(_ => "?").mkString(",")

  // Assignments are derived from the column array itself (never by re-splitting
  // the joined string), so identifiers that contain a comma stay intact.
  if (ignoreNull) {
    val assignments = quotedColumns
      .map(col => s"$col=IF(data_new.$col is null,$table.$col,data_new.$col)")
      .mkString(",")
    s"""INSERT INTO $table ($columns) VALUES ($placeholders) AS data_new
       |ON DUPLICATE KEY UPDATE
       |$assignments
       |""".stripMargin
  } else {
    val assignments = quotedColumns.map(col => s"$col=VALUES($col)").mkString(",")
    s"""INSERT INTO $table ($columns) VALUES ($placeholders)
       |ON DUPLICATE KEY UPDATE
       |$assignments
       |""".stripMargin
  }
}
}
}
|