Interfaces. First, get to know a handful of abstractions: the abstract class BaseRelation, and the traits TableScan/PrunedScan/PrunedFilteredScan, InsertableRelation, and RelationProvider. Their roles are easiest to read straight from the source:
abstract class BaseRelation {
  def sqlContext: SQLContext
  def schema: StructType
  def sizeInBytes: Long = sqlContext.conf.defaultSizeInBytes
  def needConversion: Boolean = true
  def unhandledFilters(filters: Array[Filter]): Array[Filter] = filters
}

@InterfaceStability.Stable
trait TableScan {
  def buildScan(): RDD[Row]
}

@InterfaceStability.Stable
trait PrunedScan {
  def buildScan(requiredColumns: Array[String]): RDD[Row]
}

@InterfaceStability.Stable
trait PrunedFilteredScan {
  def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row]
}

@InterfaceStability.Stable
trait InsertableRelation {
  def insert(data: DataFrame, overwrite: Boolean): Unit
}

trait RelationProvider {
  def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation
}
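To see how these pieces fit together, here is a minimal sketch of a custom data source built on RelationProvider + BaseRelation + TableScan. DemoRelationProvider, DemoRelation and the hard-coded rows are made up for illustration; only the trait contracts come from Spark. It would be loaded with spark.read.format("<fully qualified provider class>").load().

// A minimal sketch of a custom data source; names and data are hypothetical.
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.sources.{BaseRelation, RelationProvider, TableScan}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

class DemoRelationProvider extends RelationProvider {
  // Called by the framework when spark.read.format(...).load() resolves this source
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String]): BaseRelation =
    new DemoRelation(sqlContext)
}

class DemoRelation(override val sqlContext: SQLContext)
  extends BaseRelation with TableScan {

  // Stage 1: tell Spark the schema of this source
  override def schema: StructType =
    StructType(Seq(StructField("id", IntegerType), StructField("name", StringType)))

  // Stage 2: produce the data as an RDD[Row]
  override def buildScan(): RDD[Row] =
    sqlContext.sparkContext.parallelize(Seq(Row(1, "a"), Row(2, "b")))
}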
JDBC implementation. JdbcRelationProvider (the starting point, and where everything comes together):
class JdbcRelationProvider extends CreatableRelationProvider
  with RelationProvider with DataSourceRegister {

  override def shortName(): String = "jdbc"

  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String]): BaseRelation = {
    val jdbcOptions = new JDBCOptions(parameters)
    val resolver = sqlContext.conf.resolver
    val timeZoneId = sqlContext.conf.sessionLocalTimeZone
    val schema = JDBCRelation.getSchema(resolver, jdbcOptions)
    val parts = JDBCRelation.columnPartition(schema, resolver, timeZoneId, jdbcOptions)
    JDBCRelation(schema, parts, jdbcOptions)(sqlContext.sparkSession)
  }
}
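The short name "jdbc" registered above is what makes an ordinary read resolve to this createRelation; the options map becomes the parameters argument. A typical trigger looks like this (connection details are placeholders):

// spark is an existing SparkSession; URL, table, user and password are placeholders.
val df = spark.read
  .format("jdbc")
  .option("url", "jdbc:mysql://localhost:3306/test")
  .option("dbtable", "person")
  .option("user", "root")
  .option("password", "***")
  .load()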
getSchema
def getSchema(resolver: Resolver, jdbcOptions: JDBCOptions): StructType = {
  val tableSchema = JDBCRDD.resolveTable(jdbcOptions)
  jdbcOptions.customSchema match {
    case Some(customSchema) => JdbcUtils.getCustomSchema(tableSchema, customSchema, resolver)
    case None => tableSchema
  }
}
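The customSchema branch is driven by the JDBC option of the same name: if supplied, it overrides the types resolved from the database, matched column by column via the resolver; otherwise the resolved tableSchema is used as-is. A sketch of how it is passed in (column names and types are placeholders):

// Placeholder connection details; customSchema only overrides the listed columns' types.
val df = spark.read
  .format("jdbc")
  .option("url", "jdbc:postgresql://localhost/test")
  .option("dbtable", "person")
  .option("customSchema", "id DECIMAL(38, 0), name STRING")
  .load()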
resolveTable (Stage 1: get the schema)
def resolveTable(options: JDBCOptions): StructType = {
  val url = options.url
  val table = options.tableOrQuery
  val dialect = JdbcDialects.get(url)
  val conn: Connection = JdbcUtils.createConnectionFactory(options)()
  try {
    val statement = conn.prepareStatement(dialect.getSchemaQuery(table))
    try {
      statement.setQueryTimeout(options.queryTimeout)
      val rs = statement.executeQuery()
      try {
        JdbcUtils.getSchema(rs, dialect, alwaysNullable = true)
      } finally {
        rs.close()
      }
    } finally {
      statement.close()
    }
  } finally {
    conn.close()
  }
}
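The SQL used here comes from the dialect. Unless a dialect overrides it, JdbcDialect.getSchemaQuery builds a query that returns no rows but still carries full ResultSet metadata, which is all resolveTable needs; paraphrased from the Spark source, it is roughly:

// Default implementation in JdbcDialect (paraphrased):
def getSchemaQuery(table: String): String = s"SELECT * FROM $table WHERE 1=0"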
getSchema (Stage 1: get the schema)
def getSchema(
    resultSet: ResultSet,
    dialect: JdbcDialect,
    alwaysNullable: Boolean = false): StructType = {
  val rsmd = resultSet.getMetaData
  val ncols = rsmd.getColumnCount
  val fields = new Array[StructField](ncols)
  var i = 0
  while (i < ncols) {
    val columnName = rsmd.getColumnLabel(i + 1)
    val dataType = rsmd.getColumnType(i + 1)
    val typeName = rsmd.getColumnTypeName(i + 1)
    val fieldSize = rsmd.getPrecision(i + 1)
    val fieldScale = rsmd.getScale(i + 1)
    val isSigned = {
      try {
        rsmd.isSigned(i + 1)
      } catch {
        case e: SQLException if e.getMessage == "Method not supported" &&
          rsmd.getClass.getName == "org.apache.hive.jdbc.HiveResultSetMetaData" => true
      }
    }
    val nullable = if (alwaysNullable) {
      true
    } else {
      rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls
    }
    val metadata = new MetadataBuilder().putLong("scale", fieldScale)
    val columnType =
      dialect.getCatalystType(dataType, typeName, fieldSize, metadata).getOrElse(
        getCatalystType(dataType, fieldSize, fieldScale, isSigned))
    fields(i) = StructField(columnName, columnType, nullable)
    i = i + 1
  }
  new StructType(fields)
}
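Note the two-level type mapping: the dialect gets the first chance via getCatalystType, and only if it returns None does the generic getCatalystType fallback run. A sketch of a custom dialect override; the "jdbc:demo" URL prefix and the FLOAT-to-DoubleType choice are purely illustrative:

import java.sql.Types
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
import org.apache.spark.sql.types.{DataType, DoubleType, MetadataBuilder}

// Hypothetical dialect: force FLOAT columns to DoubleType; None falls back to the default mapping.
object DemoDialect extends JdbcDialect {
  override def canHandle(url: String): Boolean = url.startsWith("jdbc:demo")

  override def getCatalystType(
      sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] =
    if (sqlType == Types.FLOAT) Some(DoubleType) else None
}

JdbcDialects.registerDialect(DemoDialect)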
JDBCRelation (Stage 1: get the schema)
private[sql] case class JDBCRelation(
    override val schema: StructType,   // the schema obtained in stage 1
    parts: Array[Partition],           // the computed partitions
    jdbcOptions: JDBCOptions)(@transient val sparkSession: SparkSession)
  extends BaseRelation        // a BaseRelation, so it must expose a schema
  with PrunedFilteredScan     // a scan with column pruning and filter pushdown
  with InsertableRelation {

  override def sqlContext: SQLContext = sparkSession.sqlContext

  override val needConversion: Boolean = false

  override def unhandledFilters(filters: Array[Filter]): Array[Filter] = {
    if (jdbcOptions.pushDownPredicate) {
      filters.filter(JDBCRDD.compileFilter(_, JdbcDialects.get(jdbcOptions.url)).isEmpty)
    } else {
      filters
    }
  }

  override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
    JDBCRDD.scanTable(
      sparkSession.sparkContext,
      schema,
      requiredColumns,
      filters,
      parts,
      jdbcOptions).asInstanceOf[RDD[Row]]
  }

  override def insert(data: DataFrame, overwrite: Boolean): Unit = {
    data.write
      .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append)
      .jdbc(jdbcOptions.url, jdbcOptions.tableOrQuery, jdbcOptions.asProperties)
  }

  override def toString: String = {
    val partitioningInfo = if (parts.nonEmpty) s" [numPartitions=${parts.length}]" else ""
    s"JDBCRelation(${jdbcOptions.tableOrQuery})" + partitioningInfo
  }
}
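Because the relation is a PrunedFilteredScan, a query like the one below reaches buildScan with the projection pruned to the selected column and the comparison delivered as a sources.GreaterThan filter; any filter that compileFilter cannot translate comes back from unhandledFilters and is re-evaluated by Spark after the scan. Connection details and column names are placeholders:

import org.apache.spark.sql.functions.col

val names = spark.read
  .format("jdbc")
  .option("url", "jdbc:mysql://localhost:3306/test")   // placeholder connection details
  .option("dbtable", "person")
  .load()
  .filter(col("age") > 30)
  .select(col("name"))

names.explain()   // the pushed-down predicate appears as PushedFilters in the scan node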
scanTable (Stage 2: get the RDD[Row])
def scanTable(
    sc: SparkContext,
    schema: StructType,
    requiredColumns: Array[String],
    filters: Array[Filter],
    parts: Array[Partition],
    options: JDBCOptions): RDD[InternalRow] = {
  val url = options.url
  val dialect = JdbcDialects.get(url)
  val quotedColumns = requiredColumns.map(colName => dialect.quoteIdentifier(colName))
  new JDBCRDD(
    sc,
    JdbcUtils.createConnectionFactory(options),
    pruneSchema(schema, requiredColumns),
    quotedColumns,
    filters,
    parts,
    url,
    options)
}
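pruneSchema is not shown above; it simply keeps the requested columns, in the requested order. A rough sketch of the idea, assuming every requested column exists in the schema:

import org.apache.spark.sql.types.StructType

// Sketch only: build a lookup by name, then reassemble the schema from the requested columns.
def pruneSchemaSketch(schema: StructType, columns: Array[String]): StructType = {
  val byName = schema.fields.map(f => f.name -> f).toMap
  new StructType(columns.map(byName))
}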
JDBCRDD (Stage 2: get the RDD[Row])
private[jdbc] class JDBCRDD(
    sc: SparkContext,
    getConnection: () => Connection,
    schema: StructType,
    columns: Array[String],
    filters: Array[Filter],
    partitions: Array[Partition],
    url: String,
    options: JDBCOptions)
  extends RDD[InternalRow](sc, Nil) {

  override def getPartitions: Array[Partition] = partitions

  private val columnList: String = {
    val sb = new StringBuilder()
    columns.foreach(x => sb.append(",").append(x))
    if (sb.isEmpty) "1" else sb.substring(1)
  }

  private val filterWhereClause: String = filters
    .flatMap(JDBCRDD.compileFilter(_, JdbcDialects.get(url)))
    .map(p => s"($p)").mkString(" AND ")

  private def getWhereClause(part: JDBCPartition): String = {
    if (part.whereClause != null && filterWhereClause.length > 0) {
      "WHERE " + s"($filterWhereClause)" + " AND " + s"(${part.whereClause})"
    } else if (part.whereClause != null) {
      "WHERE " + part.whereClause
    } else if (filterWhereClause.length > 0) {
      "WHERE " + filterWhereClause
    } else {
      ""
    }
  }

  override def compute(thePart: Partition, context: TaskContext): Iterator[InternalRow] = {
    var closed = false
    var rs: ResultSet = null
    var stmt: PreparedStatement = null
    var conn: Connection = null

    def close() {
      if (closed) return
      try {
        if (null != rs) {
          rs.close()
        }
      } catch {
        case e: Exception => logWarning("Exception closing resultset", e)
      }
      try {
        if (null != stmt) {
          stmt.close()
        }
      } catch {
        case e: Exception => logWarning("Exception closing statement", e)
      }
      try {
        if (null != conn) {
          if (!conn.isClosed && !conn.getAutoCommit) {
            try {
              conn.commit()
            } catch {
              case NonFatal(e) => logWarning("Exception committing transaction", e)
            }
          }
          conn.close()
        }
        logInfo("closed connection")
      } catch {
        case e: Exception => logWarning("Exception closing connection", e)
      }
      closed = true
    }

    context.addTaskCompletionListener[Unit] { context => close() }

    val inputMetrics = context.taskMetrics().inputMetrics
    val part = thePart.asInstanceOf[JDBCPartition]
    conn = getConnection()
    val dialect = JdbcDialects.get(url)
    import scala.collection.JavaConverters._
    dialect.beforeFetch(conn, options.asProperties.asScala.toMap)

    options.sessionInitStatement match {
      case Some(sql) =>
        val statement = conn.prepareStatement(sql)
        logInfo(s"Executing sessionInitStatement: $sql")
        try {
          statement.setQueryTimeout(options.queryTimeout)
          statement.execute()
        } finally {
          statement.close()
        }
      case None =>
    }

    // (elided in this excerpt: building the per-partition SELECT from columnList and
    // getWhereClause(part), executing it, and wrapping the ResultSet as rowsIterator)

    CompletionIterator[InternalRow, Iterator[InternalRow]](
      new InterruptibleIterator(context, rowsIterator), close())
  }
}
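Putting columnList and getWhereClause together, each task issues one SELECT for its partition. As an illustration only (column, filter, and partition bounds are made up, and identifier quoting depends on the dialect), the statement compute prepares looks roughly like:

// With columnList = "name", a pushed filter age > 30, and a partition whose
// whereClause is "id >= 100 AND id < 200":
val sqlText = """SELECT name FROM person WHERE ("age" > 30) AND (id >= 100 AND id < 200)"""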
Stepping through the whole flow in a debugger, it really boils down to two steps:
Step 1: query the metadata over JDBC to get the schema.
Step 2: query the data over JDBC to get the RDD[Row].
The final creation of the DataFrame is handled by the framework.
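End to end, both stages are driven by a single read (connection details and bounds are placeholders). columnPartition turns the four partitioning options into the per-partition WHERE clauses that JDBCRDD later appends to each task's query:

val df = spark.read
  .format("jdbc")
  .option("url", "jdbc:mysql://localhost:3306/test")
  .option("dbtable", "person")
  .option("partitionColumn", "id")
  .option("lowerBound", "0")
  .option("upperBound", "1000")
  .option("numPartitions", "4")
  .load()    // stage 1: createRelation runs here and resolveTable fetches the schema

df.show()    // stage 2: each partition runs its own SELECT and the rows become the DataFrame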