这里接着上次的解读jdbc数据源,现在我们自己实现一个text的外部数据源
创建DefaultSource类实现RelationProvider trait,注意这里的类名必须是DefaultSource,源码中写死了
// Entry point for the data source. The class MUST be named DefaultSource:
// Spark's DataSource resolution appends ".DefaultSource" to the format package name.
class DefaultSource extends RelationProvider {

  /**
   * Builds the relation backing this data source.
   * Requires the caller to supply a "path" option; fails fast otherwise.
   */
  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation =
    parameters
      .get("path")
      .fold[BaseRelation](throw new IllegalArgumentException("path is required ...")) { location =>
        new TextDataSourceRelation(sqlContext, location)
      }
}自定义Relation,继承BaseRelation和TableScan,拿到Schema和RDD[Row]
/**
 * Custom relation for a comma-separated text source.
 * BaseRelation supplies the schema; TableScan supplies the full-scan RDD[Row].
 *
 * @param context the active SQLContext
 * @param path    file or directory path handed in through the "path" option
 */
class TextDataSourceRelation(context: SQLContext, path: String) extends BaseRelation with TableScan with Logging {

  override def sqlContext: SQLContext = context

  // Fixed five-column schema; all columns nullable.
  override def schema: StructType = StructType {
    List(
      StructField("id", StringType, true),
      StructField("name", StringType, true),
      StructField("sex", StringType, true),
      StructField("sal", DoubleType, true),
      StructField("comm", DoubleType, true)
    )
  }

  /**
   * Full table scan: read the raw text, split each line on commas,
   * decode the "sex" column, coerce every token to its schema type,
   * and wrap each record in a Row.
   */
  override def buildScan(): RDD[Row] = {
    val fields = schema.fields
    val rawLines: RDD[String] = sqlContext.sparkContext.textFile(path)

    rawLines
      // tokenize: comma-split, then strip surrounding whitespace
      .map(line => line.split(",").map(_.trim))
      .map { tokens =>
        // walk positions so each token can be paired with its StructField
        val converted = tokens.indices.map { i =>
          val token = tokens(i)
          val field = fields(i)
          // sex column stores codes 1/2/other; translate to display text
          val normalized =
            if (field.name.equalsIgnoreCase("sex")) {
              token match {
                case "1" => "男"
                case "2" => "女"
                case _   => "未知"
              }
            } else {
              token
            }
          // coerce the string token to the declared column type
          Utils.caseTo(normalized, field.dataType)
        }
        Row.fromSeq(converted)
      }
  }
}自定义Utils类
object Utils {
  /**
   * Coerces a raw string token to the value expected by the given column type.
   *
   * Fix: the original match covered only Double/Long/String and threw
   * scala.MatchError for any other DataType (e.g. IntegerType). Unsupported
   * types now fall through and keep the raw string instead of crashing a task.
   *
   * @param value    raw (already trimmed) token from the text file
   * @param dataType target column type from the schema
   * @return the converted value (Double, Long, or the original String)
   */
  def caseTo(value: String, dataType: DataType): Any = {
    dataType match {
      case _: DoubleType => value.toDouble
      case _: LongType   => value.toLong
      case _: StringType => value
      // fall-through: unsupported types keep the raw string rather than MatchError
      case _             => value
    }
  }
}测试
object Test {
  /**
   * Smoke test: load the custom text data source via its package name
   * and print schema + contents.
   *
   * Fix: the SparkSession was never stopped, leaking the underlying
   * SparkContext (and its UI port/threads) on exit. Work is now wrapped
   * in try/finally so the session is always shut down.
   */
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .master("local[2]")
      .appName(this.getClass.getSimpleName)
      .getOrCreate()
    try {
      val textDF: DataFrame = spark.read.format("com.tunan.spark.sql.extds.text").load("tunan-spark-sql/extds")
      textDF.printSchema()
      textDF.show()
    } finally {
      // release the SparkContext even if the read/show fails
      spark.stop()
    }
  }
}输出结果
root
|-- id: string (nullable = true)
|-- name: string (nullable = true)
|-- sex: string (nullable = true)
|-- sal: double (nullable = true)
|-- comm: double (nullable = true)
ERROR TextDataSourceRelation: 进入buildScan方法
+---+----+----+-------+------+
| id|name| sex| sal| comm|
+---+----+----+-------+------+
| 1|张三| 男|10000.0|1000.0|
| 2|李四| 男|12000.0|2000.0|
| 3|王五| 女|12500.0|1000.0|
| 4|赵六|未知|20000.0|2000.0|
| 5|图南| 男|21000.0|1000.0|
| 6|小七| 女|10000.0|1500.0|
+---+----+----+-------+------+