Understanding External Data Sources from the JDBC Perspective

Interfaces

First, let's get to know a few interfaces: BaseRelation, TableScan/PrunedScan/PrunedFilteredScan, InsertableRelation, and RelationProvider. Their roles are explained through the source code below.

// Represents an abstract data source: a collection of rows with a known schema (a relation).
abstract class BaseRelation {
  def sqlContext: SQLContext
  def schema: StructType // the schema *

  def sizeInBytes: Long = sqlContext.conf.defaultSizeInBytes

  def needConversion: Boolean = true

  def unhandledFilters(filters: Array[Filter]): Array[Filter] = filters
}

// Scans the whole table and returns the data as an RDD[Row].
@InterfaceStability.Stable
trait TableScan {
  def buildScan(): RDD[Row]
}

// Scans the table with column pruning and returns the data as an RDD[Row].
@InterfaceStability.Stable
trait PrunedScan {
  def buildScan(requiredColumns: Array[String]): RDD[Row]
}

// Scans the table with column pruning and filter pushdown, returning the data as an RDD[Row].
@InterfaceStability.Stable
trait PrunedFilteredScan {
  def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row]
}

// Implemented when the relation supports inserting data; overwrite decides whether existing data is replaced.
@InterfaceStability.Stable
trait InsertableRelation {
  def insert(data: DataFrame, overwrite: Boolean): Unit
}

// Produces a new Relation for a custom data source type.
trait RelationProvider {

  // Creates a new Relation from the given parameters.
  def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation
}
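To see how these pieces fit together, here is a minimal sketch of a hypothetical custom data source (the class names and the hard-coded rows are made up for illustration): the provider creates the relation, the relation declares its schema, and TableScan produces the rows.

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.sources.{BaseRelation, RelationProvider, TableScan}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

// Hypothetical provider: the framework calls createRelation with the options map.
class DemoRelationProvider extends RelationProvider {
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String]): BaseRelation = new DemoRelation(sqlContext)
}

// Hypothetical relation serving two hard-coded rows.
class DemoRelation(override val sqlContext: SQLContext) extends BaseRelation with TableScan {
  // The schema this source promises to the framework.
  override def schema: StructType = StructType(Seq(
    StructField("id", IntegerType, nullable = false),
    StructField("name", StringType, nullable = true)))

  // TableScan: return the whole table as an RDD[Row].
  override def buildScan(): RDD[Row] =
    sqlContext.sparkContext.parallelize(Seq(Row(1, "a"), Row(2, "b")))
}

It could then be loaded with something like spark.read.format("com.example.DemoRelationProvider").load() (package name hypothetical).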

The JDBC implementation

JdbcRelationProvider (where everything starts and ends)

class JdbcRelationProvider extends CreatableRelationProvider
  with RelationProvider with DataSourceRegister {

  override def shortName(): String = "jdbc" // the short name of the source

  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String]): BaseRelation = { // all options are passed in as a Map
    val jdbcOptions = new JDBCOptions(parameters) // match the user options against the known JDBC options
    val resolver = sqlContext.conf.resolver // resolver for column-name matching (case-insensitive by default)
    val timeZoneId = sqlContext.conf.sessionLocalTimeZone // the session time zone
    val schema = JDBCRelation.getSchema(resolver, jdbcOptions) // resolve the schema from the options
    val parts = JDBCRelation.columnPartition(schema, resolver, timeZoneId, jdbcOptions) // compute the partitions
    JDBCRelation(schema, parts, jdbcOptions)(sqlContext.sparkSession) // build the JDBCRelation (a BaseRelation)
  }
}
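Because the provider registers the short name "jdbc" via DataSourceRegister, this is the class the framework lands in when you run something like the following (connection details are placeholders):

val df = spark.read
  .format("jdbc")
  .option("url", "jdbc:mysql://hadoop:3306/")             // URL from the example below
  .option("dbtable", "access_dw.dws_ad_phone_type_dist")  // table from the example below
  .option("user", "root")                                 // placeholder credentials
  .option("password", "******")
  .load()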

getSchema

def getSchema(resolver: Resolver, jdbcOptions: JDBCOptions): StructType = {
  val tableSchema = JDBCRDD.resolveTable(jdbcOptions) // resolve the table from the options and get its schema
  jdbcOptions.customSchema match { // pattern match on the optional customSchema option
    case Some(customSchema) => JdbcUtils.getCustomSchema(
      tableSchema, customSchema, resolver) // return the user-customized schema
    case None => tableSchema // return the schema resolved directly from the table
  }
}
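The customSchema branch corresponds to the reader option of the same name: a DDL-formatted string that overrides the types resolved from the database, with unlisted columns keeping their resolved types. For example (the cnt column here is hypothetical):

.option("customSchema", "phoneSystemType STRING, cnt DECIMAL(38, 0)")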

resolveTable (Phase 1: getting the Schema)

def resolveTable(options: JDBCOptions): StructType = { // resolve the table and get its schema from the options
  val url = options.url // the url, e.g. jdbc:mysql://hadoop:3306/
  val table = options.tableOrQuery // the table, e.g. access_dw.dws_ad_phone_type_dist
  val dialect = JdbcDialects.get(url) // the dialect, here MySQLDialect
  val conn: Connection = JdbcUtils.createConnectionFactory(options)() // open a connection
  try {
    val statement = conn.prepareStatement(dialect.getSchemaQuery(table)) // the schema query, e.g. SELECT * FROM access_dw.dws_ad_phone_type_dist WHERE 1=0
    try {
      statement.setQueryTimeout(options.queryTimeout) // set the query timeout
      val rs = statement.executeQuery() // execute the query and get a ResultSet (no rows, but full metadata)
      try {
        JdbcUtils.getSchema(rs, dialect, alwaysNullable = true) // derive the schema from rs; continued in getSchema below
      } finally {
        rs.close()
      }
    } finally {
      statement.close()
    }
  } finally {
    conn.close()
  }
}
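The getSchemaQuery trick is worth pausing on: SELECT * FROM <table> WHERE 1=0 never returns any rows, but the driver still populates the ResultSet metadata, which is all that is needed here. A minimal plain-JDBC sketch of the same idea (connection details are placeholders):

import java.sql.DriverManager

val conn = DriverManager.getConnection("jdbc:mysql://hadoop:3306/", "root", "******") // placeholder credentials
try {
  val stmt = conn.prepareStatement("SELECT * FROM access_dw.dws_ad_phone_type_dist WHERE 1=0")
  val rs = stmt.executeQuery()          // returns immediately: zero rows
  val meta = rs.getMetaData             // but the column metadata is fully populated
  println(meta.getColumnCount)          // number of columns in the table
  println(meta.getColumnTypeName(1))    // e.g. VARCHAR for the first column
} finally {
  conn.close()
}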

getSchema (Phase 1: getting the Schema)

def getSchema(
    resultSet: ResultSet, // the ResultSet returned by the schema query (carries the table structure)
    dialect: JdbcDialect, // the MySQL dialect in this example
    alwaysNullable: Boolean = false): StructType = {
  val rsmd = resultSet.getMetaData // the table's metadata
  val ncols = rsmd.getColumnCount // the number of columns
  val fields = new Array[StructField](ncols) // an array of StructField to collect the fields
  var i = 0
  while (i < ncols) { // loop over every column
    val columnName = rsmd.getColumnLabel(i + 1) // column name, e.g. phoneSystemType
    val dataType = rsmd.getColumnType(i + 1) // JDBC type code, e.g. 12
    val typeName = rsmd.getColumnTypeName(i + 1) // type name, e.g. VARCHAR
    val fieldSize = rsmd.getPrecision(i + 1) // field size, e.g. 64
    val fieldScale = rsmd.getScale(i + 1) // scale, e.g. 0
    val isSigned = { // whether the column is signed
      try {
        rsmd.isSigned(i + 1) // e.g. false
      } catch {
        // Workaround for HIVE-14684:
        case e: SQLException if
          e.getMessage == "Method not supported" &&
            rsmd.getClass.getName == "org.apache.hive.jdbc.HiveResultSetMetaData" => true
      }
    }
    val nullable = if (alwaysNullable) { // whether the column is nullable
      true
    } else {
      rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls
    }
    val metadata = new MetadataBuilder().putLong("scale", fieldScale)
    val columnType =
      dialect.getCatalystType(dataType, typeName, fieldSize, metadata).getOrElse(
        getCatalystType(dataType, fieldSize, fieldScale, isSigned)) // map to a Catalyst type, e.g. StringType
    fields(i) = StructField(columnName, columnType, nullable) // build the StructField from name, type and nullability, and store it
    i = i + 1
  }
  new StructType(fields) // build the StructType from all the StructFields -- this is the final Schema
}
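For the VARCHAR(64) column traced in the comments above (alwaysNullable = true, so nullable is forced to true), the loop produces the equivalent of:

StructField("phoneSystemType", StringType, nullable = true)
// and, if that were the only column, the final schema would be
StructType(Array(StructField("phoneSystemType", StringType, nullable = true)))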

JDBCRelation (holds the Phase 1 Schema, kicks off Phase 2)

private[sql] case class JDBCRelation(
    override val schema: StructType, // the Schema from Phase 1
    parts: Array[Partition], // the computed partitions
    jdbcOptions: JDBCOptions)(@transient val sparkSession: SparkSession)
  extends BaseRelation // extends BaseRelation, so it must provide the Schema
  with PrunedFilteredScan // supports scanning with column pruning and filter pushdown
  with InsertableRelation { // supports inserting data

  override def sqlContext: SQLContext = sparkSession.sqlContext

  override val needConversion: Boolean = false

  // Checks whether JDBCRDD.compileFilter can accept each input filter; the ones it cannot are returned as unhandled.
  override def unhandledFilters(filters: Array[Filter]): Array[Filter] = {
    if (jdbcOptions.pushDownPredicate) {
      filters.filter(JDBCRDD.compileFilter(_, JdbcDialects.get(jdbcOptions.url)).isEmpty)
    } else {
      filters
    }
  }

  // Build the scan
  override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { // requiredColumns: the columns to read; filters: the pushed-down filters
    // Rely on type erasure to pass RDD[InternalRow] back as RDD[Row]
    JDBCRDD.scanTable(
      sparkSession.sparkContext, // the SparkContext
      schema, // the Schema
      requiredColumns, // the columns to read
      filters, // the pushed-down filters
      parts, // the partitions
      jdbcOptions).asInstanceOf[RDD[Row]] // cast the result to RDD[Row]
  }

  override def insert(data: DataFrame, overwrite: Boolean): Unit = {
    data.write
      .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append)
      .jdbc(jdbcOptions.url, jdbcOptions.tableOrQuery, jdbcOptions.asProperties)
  }

  override def toString: String = {
    val partitioningInfo = if (parts.nonEmpty) s" [numPartitions=${parts.length}]" else ""
    // Credentials should not be included in the plan output; table information is sufficient.
    s"JDBCRelation(${jdbcOptions.tableOrQuery})" + partitioningInfo
  }
}
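Column pruning and filter pushdown are what make PrunedFilteredScan worthwhile: only requiredColumns appear in the SELECT list, and any filter that compileFilter knows how to translate becomes part of the WHERE clause instead of being evaluated in Spark. A sketch of the effect on the example table (the cnt column and the literal are hypothetical):

val pruned = df
  .filter(df("phoneSystemType") === "ios")   // -> pushed down as WHERE (`phoneSystemType` = 'ios')
  .select("phoneSystemType", "cnt")          // -> requiredColumns = [phoneSystemType, cnt]
// The query sent over JDBC then looks roughly like:
//   SELECT `phoneSystemType`,`cnt` FROM access_dw.dws_ad_phone_type_dist WHERE (`phoneSystemType` = 'ios')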

scanTable (Phase 2: getting the RDD[Row])

def scanTable(
    sc: SparkContext,
    schema: StructType,
    requiredColumns: Array[String],
    filters: Array[Filter],
    parts: Array[Partition],
    options: JDBCOptions): RDD[InternalRow] = {
  val url = options.url // the url supplied by the client
  val dialect = JdbcDialects.get(url) // the dialect
  val quotedColumns = requiredColumns.map(colName => dialect.quoteIdentifier(colName)) // quote the required columns
  new JDBCRDD( // build and return the RDD[InternalRow]
    sc,
    JdbcUtils.createConnectionFactory(options),
    pruneSchema(schema, requiredColumns),
    quotedColumns,
    filters,
    parts,
    url,
    options)
}
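pruneSchema is the column-pruning half of the story: it keeps only the StructFields named in requiredColumns, in the requested order. A minimal sketch of the idea (not the exact Spark source, and it assumes unique column names):

def pruneSchemaSketch(schema: StructType, columns: Array[String]): StructType = {
  // index the fields by name, then pick them out in the requested order
  val byName = schema.fields.map(f => f.name -> f).toMap
  new StructType(columns.map(byName))
}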

JDBCRDD (Phase 2: getting the RDD[Row])

// An RDD representing a table in a database accessed via JDBC.
private[jdbc] class JDBCRDD(
    sc: SparkContext,
    getConnection: () => Connection,
    schema: StructType,
    columns: Array[String],
    filters: Array[Filter],
    partitions: Array[Partition],
    url: String,
    options: JDBCOptions)
  extends RDD[InternalRow](sc, Nil) {

  // Retrieve the list of partitions corresponding to this RDD.
  override def getPartitions: Array[Partition] = partitions

  // `columns`, rendered as a single String ready to be injected into the SQL query.
  private val columnList: String = {
    val sb = new StringBuilder()
    columns.foreach(x => sb.append(",").append(x))
    if (sb.isEmpty) "1" else sb.substring(1)
  }

  // `filters`, rendered as a WHERE clause ready to be injected into the SQL query.
  private val filterWhereClause: String =
    filters
      .flatMap(JDBCRDD.compileFilter(_, JdbcDialects.get(url)))
      .map(p => s"($p)").mkString(" AND ")

  // Combine the pushed-down filter clause with the current partition's where clause, if present.
  private def getWhereClause(part: JDBCPartition): String = {
    if (part.whereClause != null && filterWhereClause.length > 0) {
      "WHERE " + s"($filterWhereClause)" + " AND " + s"(${part.whereClause})"
    } else if (part.whereClause != null) {
      "WHERE " + part.whereClause
    } else if (filterWhereClause.length > 0) {
      "WHERE " + filterWhereClause
    } else {
      ""
    }
  }
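  // Putting the two pieces together, each task ends up issuing something roughly like
  // (illustrative only; the partitioning column `id` is hypothetical):
  //   SELECT `phoneSystemType` FROM access_dw.dws_ad_phone_type_dist
  //   WHERE ((`phoneSystemType` = 'ios')) AND (id < 100 or id is null)
  // where the first conjunct is filterWhereClause and the second is part.whereClause.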

  // Runs the SQL query against the JDBC driver.
  override def compute(thePart: Partition, context: TaskContext): Iterator[InternalRow] = {
    var closed = false
    var rs: ResultSet = null
    var stmt: PreparedStatement = null
    var conn: Connection = null

    def close() {
      if (closed) return
      try {
        if (null != rs) {
          rs.close()
        }
      } catch {
        case e: Exception => logWarning("Exception closing resultset", e)
      }
      try {
        if (null != stmt) {
          stmt.close()
        }
      } catch {
        case e: Exception => logWarning("Exception closing statement", e)
      }
      try {
        if (null != conn) {
          if (!conn.isClosed && !conn.getAutoCommit) {
            try {
              conn.commit()
            } catch {
              case NonFatal(e) => logWarning("Exception committing transaction", e)
            }
          }
          conn.close()
        }
        logInfo("closed connection")
      } catch {
        case e: Exception => logWarning("Exception closing connection", e)
      }
      closed = true
    }

    context.addTaskCompletionListener[Unit]{ context => close() }

    val inputMetrics = context.taskMetrics().inputMetrics
    val part = thePart.asInstanceOf[JDBCPartition]
    conn = getConnection()
    val dialect = JdbcDialects.get(url)
    import scala.collection.JavaConverters._
    dialect.beforeFetch(conn, options.asProperties.asScala.toMap)

    // Executes a generic SQL statement (or PL/SQL block) before reading the table/query over JDBC.
    // Use this to initialize the database session, e.g. for optimization and/or troubleshooting.
    options.sessionInitStatement match {
      case Some(sql) =>
        val statement = conn.prepareStatement(sql)
        logInfo(s"Executing sessionInitStatement: $sql")
        try {
          statement.setQueryTimeout(options.queryTimeout)
          statement.execute() // in the end this is just plain JDBC
        } finally {
          statement.close()
        }
      case None =>
    }
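    // (Elided in this excerpt: the full source assembles the SELECT statement from columnList,
    // getWhereClause(part) and the table name, runs it through stmt/rs, and wraps the ResultSet
    // as rowsIterator, which is what feeds the CompletionIterator below.)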
    // Return the rows of this partition as an Iterator[InternalRow], closing the resources on completion.
    CompletionIterator[InternalRow, Iterator[InternalRow]](
      new InterruptibleIterator(context, rowsIterator), close())
  }
}

Walking through the whole flow with a debugger, it really boils down to two steps:

  1. Query the metadata over JDBC to get the Schema
  2. Query the data over JDBC to get the RDD[Row]

The final construction of the DataFrame is handled by the framework.
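Roughly speaking, once createRelation has returned, the read path wraps the BaseRelation for us; conceptually it boils down to:

// conceptual sketch of what the framework does with the returned relation
val df = spark.baseRelationToDataFrame(relation) // relation is the JDBCRelation built above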

Author: Tunan
Link: http://yerias.github.io/2019/10/16/spark/16/
Copyright Notice: All articles in this blog are licensed under CC BY-NC-SA 4.0 unless otherwise stated.