SparkSQL 优化jdbc外部数据源的读写

目录

  1. jdbc 参数解读
  2. 源码
  3. jdbc 读并发度优化
  4. jdbc 写并发度优化

jdbc 参数解读

Spark SQL还包括一个可以使用JDBC从其他数据库读取数据的数据源。与使用JdbcRDD相比,应优先使用此功能。这是因为结果作为DataFrame返回,它们可以在Spark SQL中轻松处理或与其他数据源连接。JDBC数据源也更易于使用Java或Python,因为它不需要用户提供ClassTag。

可以使用Data Sources API将远程数据库中的表加载为DataFrame或Spark SQL临时视图。用户可以在数据源选项中指定JDBC连接属性。user和password通常作为用于登录数据源的连接属性。除连接属性外,Spark还支持以下不区分大小写的选项:

属性名称 解释
url 要连接的JDBC URL
dbtable 读取或写入的JDBC表
query 指定查询语句
driver 用于连接到该URL的JDBC驱动类名
partitionColumn, lowerBound, upperBound 如果指定了这些选项,则必须全部指定。另外, numPartitions必须指定
numPartitions 表读写中可用于并行处理的最大分区数。这也确定了并发JDBC连接的最大数量。如果要写入的分区数超过此限制,我们可以通过coalesce(numPartitions)在写入之前进行调用将其降低到此限制
queryTimeout 默认为0,查询超时时间
fetchsize JDBC的获取大小,它确定每次要获取多少行。这可以帮助提高JDBC驱动程序的性能
batchsize 默认为1000,JDBC批处理大小,这可以帮助提高JDBC驱动程序的性能。
isolationLevel 事务隔离级别,适用于当前连接。它可以是一个NONEREAD_COMMITTEDREAD_UNCOMMITTEDREPEATABLE_READ,或SERIALIZABLE,对应于由JDBC的连接对象定义,缺省值为标准事务隔离级别READ_UNCOMMITTED。此选项仅适用于写作。
sessionInitStatement 在向远程数据库打开每个数据库会话之后,在开始读取数据之前,此选项将执行自定义SQL语句,使用它来实现会话初始化代码。
truncate 这是与JDBC writer相关的选项。当SaveMode.Overwrite启用时,就会清空目标表的内容,而不是删除和重建其现有的表。默认为false
pushDownPredicate 用于启用或禁用谓词下推到JDBC数据源的选项。默认值为true,在这种情况下,Spark将尽可能将过滤器下推到JDBC数据源。

函数示例:

val jdbcDF = sparkSession.sqlContext.read.format("jdbc")
.option("url", url)
.option("driver", "com.mysql.jdbc.Driver")
.option("dbtable", "table")
.option("user", "user")
.option("partitionColumn", "id")
.option("lowerBound", 1)
.option("upperBound", 10000)
.option("fetchsize", 100) //经测试,fetchsize和batchsize的大小对读写性能并没有变化
.option("xxx", "xxx")
.load()

从函数可以看出,option模式其实是一种开放接口,spark会根据具体的参数,做出相应的行为。

源码

  • SparkSession
/**
* Returns a [[DataFrameReader]] that can be used to read non-streaming data in as a
* `DataFrame`.
* {{{
* sparkSession.read.parquet("/path/to/file.parquet")
* sparkSession.read.schema(schema).json("/path/to/file.json")
* }}}
*
* @since 2.0.0
*/
def read: DataFrameReader = new DataFrameReader(self)
  • DataFrameReader
 // ...省略代码...
/**
*所有的数据由RDD的一个分区处理,如果你这个表很大,很可能会出现OOM
*可以使用DataFrameDF.rdd.partitions.size方法查看
*/
def jdbc(url: String, table: String, properties: Properties): DataFrame = {
assertNoSpecifiedSchema("jdbc")
this.extraOptions ++= properties.asScala
this.extraOptions += (JDBCOptions.JDBC_URL -> url, JDBCOptions.JDBC_TABLE_NAME -> table)
format("jdbc").load()
}
/**
* @param url 数据库url
* @param table 表名
* @param columnName 分区字段名
* @param lowerBound `columnName`的最小值,用于分区步长
* @param upperBound `columnName`的最大值,用于分区步长.
* @param numPartitions 分区数量
* @param connectionProperties 其他参数
* @since 1.4.0
*/
def jdbc(
url: String,
table: String,
columnName: String,
lowerBound: Long,
upperBound: Long,
numPartitions: Int,
connectionProperties: Properties): DataFrame = {
this.extraOptions ++= Map(
JDBCOptions.JDBC_PARTITION_COLUMN -> columnName,
JDBCOptions.JDBC_LOWER_BOUND -> lowerBound.toString,
JDBCOptions.JDBC_UPPER_BOUND -> upperBound.toString,
JDBCOptions.JDBC_NUM_PARTITIONS -> numPartitions.toString)
jdbc(url, table, connectionProperties)
}

/**
* @param predicates 每个分区的where条件
* 比如:"id <= 1000", "score > 1000 and score <= 2000"
* 将会分成两个分区
* @since 1.4.0
*/
def jdbc(
url: String,
table: String,
predicates: Array[String],
connectionProperties: Properties): DataFrame = {
assertNoSpecifiedSchema("jdbc")
val params = extraOptions.toMap ++ connectionProperties.asScala.toMap
val options = new JDBCOptions(url, table, params)
val parts: Array[Partition] = predicates.zipWithIndex.map { case (part, i) =>
JDBCPartition(part, i) : Partition
}
val relation = JDBCRelation(parts, options)(sparkSession)
sparkSession.baseRelationToDataFrame(relation)
}

jdbc 读取并发度优化

很多人在spark中使用默认提供的jdbc方法时,在数据库数据较大时经常发现任务 hang 住,其实是单线程任务过重导致,这时候需要提高读取的并发度。

  1. 单partition(无并发)

    调用函数

    def jdbc(url: String, table: String, properties: Properties): DataFrame

    使用:

    val url = "jdbc:mysql://mysqlHost:3306/database"
    val tableName = "table"

    // 设置连接用户&密码
    val prop = new java.util.Properties
    prop.setProperty("user","username")
    prop.setProperty("password","pwd")

    // 取得该表数据
    val jdbcDF = sqlContext.read.jdbc(url,tableName,prop)

    // 一些操作
    ....

    查看并发度

    jdbcDF.rdd.partitions.size # 结果返回 1

    该操作的并发度为1,你所有的数据都会在一个partition中进行操作,意味着无论你给的资源有多少,只有一个task会执行任务,执行效率可想而之,并且在稍微大点的表中进行操作分分钟就会OOM。

    更直观的说法是,达到千万级别的表就不要使用该操作,count操作就要等一万年,no zuo no die ,don’t to try !

    WARN TaskSetManager: Lost task 0.0 in stage 6.0 (TID 56, spark047219):
    java.lang.OutOfMemoryError: GC overhead limit exceeded
    at com.mysql.jdbc.MysqlIO.reuseAndReadPacket(MysqlIO.java:3380)
  2. 根据 id (整型)字段分区

    调用函数

    def jdbc(
    url: String,
    table: String,
    columnName: String, # 根据该字段分区,需要为整形,比如id等
    lowerBound: Long, # 分区的下界
    upperBound: Long, # 分区的上界
    numPartitions: Int, # 分区的个数
    connectionProperties: Properties): DataFrame

    使用:

    val url = "jdbc:mysql://mysqlHost:3306/database"
    val tableName = "table"

    val columnName = "colName"
    val lowerBound = 1,
    val upperBound = 10000000,
    val numPartitions = 10,

    // 设置连接用户&密码
    val prop = new java.util.Properties
    prop.setProperty("user","username")
    prop.setProperty("password","pwd")

    // 取得该表数据
    val jdbcDF = sqlContext.read.jdbc(url,tableName,columnName,lowerBound,upperBound,numPartitions,prop)

    // 一些操作
    ....

    查看并发度

    jdbcDF.rdd.partitions.size # 结果返回 10

    该操作将字段 colName 中1-10000000条数据分到10个partition中,使用很方便,缺点也很明显,只能使用整形数据字段作为分区关键字。

    3000w数据的表 count 跨集群操作只要2s。

  3. 根据时间字段分区

    调用函数

    jdbc(
    url: String,
    table: String,
    predicates: Array[String],
    connectionProperties: Properties): DataFrame

    下面以使用最多的时间字段分区为例:

    val url = "jdbc:mysql://mysqlHost:3306/database"
    val tableName = "table"

    // 设置连接用户&密码
    val prop = new java.util.Properties
    prop.setProperty("user","username")
    prop.setProperty("password","pwd")

    /**
    * 将9月16-12月15三个月的数据取出,按时间分为6个partition
    * 为了减少事例代码,这里的时间都是写死的
    * modified_time 为时间字段
    */

    val predicates =
    Array(
    "2015-09-16" -> "2015-09-30",
    "2015-10-01" -> "2015-10-15",
    "2015-10-16" -> "2015-10-31",
    "2015-11-01" -> "2015-11-14",
    "2015-11-15" -> "2015-11-30",
    "2015-12-01" -> "2015-12-15"
    ).map {
    case (start, end) =>
    s"cast(modified_time as date) >= date '$start' " + s"AND cast(modified_time as date) <= date '$end'"
    }

    // 取得该表数据
    val jdbcDF = sqlContext.read.jdbc(url,tableName,predicates,prop)

    // 一些操作

    查看并发度

    jdbcDF.rdd.partitions.size # 结果返回 6

    该操作的每个分区数据都由该段时间的分区组成,这种方式适合各种场景,较为推荐。

  4. id 取模方式分区

    sqlContext.read.jdbc(url,tableName, "id%200", 1, 1000000,400,prop)

    根据numPartitions确定合理的模值,可以尽量做到数据的连续,且写法简单,但是由于在ID字段上使用了函数计算,所以索引将失效,此时需要配合其他包含索引的where条件加以辅助,才能使查询性能最大化。

  5. 自定义处理方式

    def getPredicates = {    
    //1.获取表total数据。
    //2.按numPartitions均分,获得offset,可以确保每个分片的数据一致
    //3.获取每个分片内的最大最小ID,组装成条件数组

    。。。实现细节省略
    }

    sqlContext.read.jdbc(url,table, getPredicates,connectionProperties)

    通过自由组装方式,可以达到精确控制,但是实现成本较高。

数据读取分区的原理

无论使用哪种JDBC API,spark拉取数据最终都是以select语句来执行的,所以在自定义分区条件或者指定的long型column时,都需要结合表的索引来综合考虑,才能以更高性能并发读取数据库数据。

API中的columnName其实只会作为where条件进行简单的拼接,所以数据库中支持的语法,都可以使用。tableName的原理也一样,仅会作为from 后的内容进行拼接,所以也可以写一个子句传入tableName中,但依然要在保证性能的前提下。

不仅仅是取模操作,数据库语法支持的任何函数,都可以在API中传入使用,关键在于性能是否达到预期。

JDBC的读取性能受很多条件影响,需要根据不同的数据库,表,索引,数据量,spark集群的executor情况等综合考虑,线上环境的操作,建议进行读写分离,即读备库,写主库。

注意: 高并发度可以大幅度提高读取以及处理数据的速度,但是如果设置过高(大量的partition同时读取)也可能会将数据源数据库弄挂。

jdbc 写并发度优化

  1. jdbc 方式

    object BatchInsertMySQL {
    case class Person(name: String, age: Int)
    def main(args: Array[String]): Unit = {

    // 创建sparkSession对象
    val conf = new SparkConf()
    .setAppName("BatchInsertMySQL")
    val spark: SparkSession = SparkSession.builder()
    .config(conf)
    .getOrCreate()
    import spark.implicits._
    // MySQL连接参数
    val url = JDBCUtils.url
    val user = JDBCUtils.user
    val pwd = JDBCUtils.password

    // 创建Properties对象,设置连接mysql的用户名和密码
    val properties: Properties = new Properties()

    properties.setProperty("user", user) // 用户名
    properties.setProperty("password", pwd) // 密码
    properties.setProperty("driver", "com.mysql.jdbc.Driver")
    properties.setProperty("numPartitions","10")

    // 读取mysql中的表数据
    val testDF: DataFrame = spark.read.jdbc(url, "test", properties)
    println("testDF的分区数: " + testDF.rdd.partitions.size)
    testDF.createOrReplaceTempView("test")
    testDF.sqlContext.cacheTable("test")
    testDF.printSchema()

    val result =
    s"""-- SQL代码
    """.stripMargin

    val resultBatch = spark.sql(result).as[Person]
    println("resultBatch的分区数: " + resultBatch.rdd.partitions.size)

    // 批量写入MySQL
    // 此处最好对处理的结果进行一次重分区
    // 由于数据量特别大,会造成每个分区数据特别多
    resultBatch.repartition(400).foreachPartition(record => {

    val list = new ListBuffer[Person]
    record.foreach(person => {
    val name = Person.name
    val age = Person.age
    list.append(Person(name,age))
    })
    upsertDateMatch(list) //执行批量插入数据
    })
    // 批量插入MySQL的方法
    def upsertPerson(list: ListBuffer[Person]): Unit = {

    var connect: Connection = null
    var pstmt: PreparedStatement = null

    try {
    connect = JDBCUtils.getConnection()
    // 禁用自动提交
    connect.setAutoCommit(false)

    val sql = "REPLACE INTO `person`(name, age)" +
    " VALUES(?, ?)"

    pstmt = connect.prepareStatement(sql)

    var batchIndex = 0
    for (person <- list) {
    pstmt.setString(1, person.name)
    pstmt.setString(2, person.age)
    // 加入批次
    pstmt.addBatch()
    batchIndex +=1
    // 控制提交的数量,
    // MySQL的批量写入尽量限制提交批次的数据量,否则会把MySQL写挂!!!
    if(batchIndex % 1000 == 0 && batchIndex !=0){
    pstmt.executeBatch()
    pstmt.clearBatch()
    }
    }
    // 提交批次
    pstmt.executeBatch()
    connect.commit()
    } catch {
    case e: Exception =>
    e.printStackTrace()
    } finally {
    JDBCUtils.closeConnection(connect, pstmt)
    }
    }

    spark.close()
    }
    }
  2. df 方式

    sqlDF.coalesce(10)	// 并行度
    .write
    .mode(SaveMode.Overwrite) //覆盖模式
    .format("jdbc")
    .option("url", url)
    .option("dbtable", s"$db.$target")
    .option("user", user)
    .option("password", password)
    .option("driver",driver)
    .option("batchsize","2000") // 每个批次写入的数据量
    .option("truncate",true) // SaveMode.Overwrite 不删除而是清空表
    .save()
Author: Tunan
Link: http://yerias.github.io/2020/11/05/spark/36/
Copyright Notice: All articles in this blog are licensed under CC BY-NC-SA 4.0 unless stating additionally.