Custom External Data Sources in Spark

Translated from: Extending Spark Datasource API: write a custom spark datasource

Data Source API

Basic Interfaces

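  • BaseRelation: represents the relation (table) of the underlying data source from which a DataFrame is produced. It defines how the schema is produced; in other words, it is the relation of the data source.
  • RelationProvider: takes a list of parameters and returns a BaseRelation object.
  • TableScan: performs a full scan of the data described by the schema and returns an unfiltered RDD.
  • DataSourceRegister: defines a short name for the data source.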

Providers

  • SchemaRelationProvider: lets the user supply a custom schema.
  • CreatableRelationProvider: lets the user create a new relation from a DataFrame.

Scans

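  • PrunedScan: a scan that drops the columns that are not needed.
  • PrunedFilteredScan: a scan that drops the columns that are not needed and also filters on column values.
  • CatalystScan: an experimental interface for a more direct connection to the query planner.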

Relations

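  • InsertableRelation: inserts data. It makes three assumptions: 1. the data passed to the insert method matches the schema defined in BaseRelation; 2. the schema does not change; 3. all fields of the inserted data are nullable.
  • HadoopFsRelation: a data source backed by a Hadoop file system.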

Output Interfaces

These are used when working with HadoopFsRelation.

Preparation

Data Format

A text file is used as the data source. Fields within a line are separated by commas, and records are separated by newlines. The format is:

// id, name, gender (1 = male, 0 = female), salary, expenses
10001,Alice,0,30000,12000
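
As a rough illustration of this layout, the sketch below parses one record into a typed value using plain Scala. The `Employee` case class and the `LineParser.parseLine` helper are hypothetical names introduced here and are not part of the original project; the sketch also assumes that lines starting with `//` are comments to be skipped.

```scala
// Hypothetical helper types for illustration only; not part of the original project.
final case class Employee(id: Int, name: String, gender: String, salary: Long, expenses: Long)

object LineParser {
  // Parses "id,name,gender,salary,expenses"; returns None for blank, comment, or malformed lines.
  def parseLine(line: String): Option[Employee] = {
    val trimmed = line.trim
    if (trimmed.isEmpty || trimmed.startsWith("//")) None
    else trimmed.split(",").map(_.trim) match {
      case Array(id, name, gender, salary, expenses) =>
        // Try turns number-format errors into None instead of throwing
        scala.util.Try(
          Employee(id.toInt, name, if (gender.toInt == 1) "Male" else "Female", salary.toLong, expenses.toLong)
        ).toOption
      case _ => None
    }
  }
}
```

The data source built below performs the same steps (split on commas, map the gender code to a string, convert each field to its declared type), but does so against a Spark schema rather than a case class.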

Creating the Project

Create a Maven project in IDEA and add the Spark and Scala dependencies.

<properties>
    <scala.version>2.11.8</scala.version>
    <spark.version>2.2.0</spark.version>
</properties>

<repositories>
    <repository>
        <id>scala-tools.org</id>
        <name>Scala-Tools Maven2 Repository</name>
        <url>http://scala-tools.org/repo-releases</url>
    </repository>
</repositories>

<pluginRepositories>
    <pluginRepository>
        <id>scala-tools.org</id>
        <name>Scala-Tools Maven2 Repository</name>
        <url>http://scala-tools.org/repo-releases</url>
    </pluginRepository>
</pluginRepositories>

<dependencies>
    <dependency>
        <groupId>org.scala-lang</groupId>
        <artifactId>scala-library</artifactId>
        <version>${scala.version}</version>
    </dependency>
    <dependency>
        <groupId>junit</groupId>
        <artifactId>junit</artifactId>
        <version>4.4</version>
    </dependency>
    <dependency>
        <groupId>org.specs</groupId>
        <artifactId>specs</artifactId>
        <version>1.2.5</version>
        <scope>test</scope>
    </dependency>
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-sql_2.11</artifactId>
        <version>${spark.version}</version>
    </dependency>

</dependencies>

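Writing the Custom Data Source

Creating the Schema

To define a custom schema, you must create a class named DefaultSource (Spark's source code requires this name; with any other name it reports that the DefaultSource class cannot be found).
The class also needs to extend RelationProvider and SchemaRelationProvider. RelationProvider creates the relation for the data, and SchemaRelationProvider lets the schema be specified explicitly.
In DefaultSource.scala, when a path is supplied, create the corresponding relation, which reads the file at that path.

Code for DefaultSource.scala: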

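package com.edu.spark.text

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.{BaseRelation, RelationProvider, SchemaRelationProvider}
import org.apache.spark.sql.types.StructType

class DefaultSource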
    extends RelationProvider
    with SchemaRelationProvider {
  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = {
    createRelation(sqlContext,parameters,null)
  }

  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String], schema: StructType): BaseRelation = {
    val path=parameters.get("path")
    path match {
      case Some(p) => new TextDataSourceRelation(sqlContext,p,schema)
      case _ => throw new IllegalArgumentException("Path is required for custom-datasource format!!")
    }
  }
}

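When writing the relation, extend BaseRelation and override schema to describe the custom data source. For Parquet/CSV/JSON files the schema can be obtained directly from the files.
The relation also implements Serializable so that it can be sent over the network.

Code for TextDataSourceRelation.scala: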

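import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructField, StructType}

class TextDataSourceRelation(override val sqlContext : SQLContext, path : String, userSchema : StructType)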
    extends BaseRelation
      with Serializable {
  override def schema: StructType = {
    if(userSchema!=null){
      userSchema
    }else{
      StructType(
        StructField("id",IntegerType,false) ::
        StructField("name",StringType,false) ::
        StructField("gender",StringType,false) ::
        StructField("salary",LongType,false) ::
        StructField("expenses",LongType,false) :: Nil)
    }
  }
}

With the code above in place, a quick test can check whether the correct schema is returned.
The test reads the file with sqlContext.read, passes the package of the custom data source to format, and calls printSchema() to verify the schema.

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

object TestApp extends App {
  println("Application Started...")
  val conf=new SparkConf().setAppName("spark-custom-datasource")
  val spark=SparkSession.builder().config(conf).master("local[2]").getOrCreate()
  val df=spark.sqlContext.read.format("com.edu.spark.text").load("/Users/Downloads/data")
  println("output schema...")
  df.printSchema()
  println("Application Ended...")
}

The printed schema:

output schema...
root
 |-- id: integer (nullable = false)
 |-- name: string (nullable = false)
 |-- gender: string (nullable = false)
 |-- salary: long (nullable = false)
 |-- expenses: long (nullable = false)

The printed schema matches the one defined in the relation.
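
The default schema is used only when the caller supplies none. As a rough sketch of the other path, assuming the DefaultSource and TextDataSourceRelation classes above are on the classpath under `com.edu.spark.text`, a schema passed through `DataFrameReader.schema` should reach the three-argument `createRelation` and be returned by `schema`. The object name `CustomSchemaExample` and the choice of nullable fields are assumptions made for illustration.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructField, StructType}

// Hypothetical example object; not part of the original project.
object CustomSchemaExample {
  // Same columns as the default schema, but declared nullable (an arbitrary choice for this sketch).
  val customSchema: StructType = StructType(Seq(
    StructField("id", IntegerType, nullable = true),
    StructField("name", StringType, nullable = true),
    StructField("gender", StringType, nullable = true),
    StructField("salary", LongType, nullable = true),
    StructField("expenses", LongType, nullable = true)
  ))

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("spark-custom-datasource")
      .master("local[2]")
      .getOrCreate()

    // Supplying a schema routes the call through SchemaRelationProvider.createRelation
    val df = spark.read
      .format("com.edu.spark.text")
      .schema(customSchema)
      .load("/Users/Downloads/data")

    df.printSchema()
    spark.stop()
  }
}
```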

Reading Data

To read data, TextDataSourceRelation needs to extend TableScan and implement the buildScan() method.
This method returns the data as an RDD of Row objects, where each Row represents one record.
The file is read with wholeTextFiles from the given path, which returns pairs of (file name, content).
After reading, each line is split on commas, the gender field is converted from its numeric code to the corresponding string, and every value is cast to the type declared in the schema.

The conversion code:

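import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StringType}

object Util {
  // Converts a raw string field to the Scala type matching the given Spark SQL data type
  def castTo(value : String, dataType : DataType): Any = {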
    dataType match {
      case _ : IntegerType => value.toInt
      case _ : LongType => value.toLong
      case _ : StringType => value
    }
  }
}

Code implementing TableScan:

override def buildScan(): RDD[Row] = {
    println("TableScan: buildScan called...")
    val schemaFields = schema.fields
    // Reading the file's content
    val rdd = sqlContext.sparkContext.wholeTextFiles(path).map(f => f._2)
    
    val rows = rdd.map(fileContent => {
      val lines = fileContent.split("\n")
      val data = lines.map(line => line.split(",").map(word => word.trim).toSeq)
      val tmp = data.map(words => words.zipWithIndex.map{
        case (value, index) =>
          val colName = schemaFields(index).name
          Util.castTo(
            if (colName.equalsIgnoreCase("gender")) {
              if(value.toInt == 1) {
                "Male"
              } else {
                "Female"
              }
            } else {
              value
            }, schemaFields(index).dataType)
      })
      tmp.map(s => Row.fromSeq(s))
    })
    rows.flatMap(e => e)
}

Code to test that the data can be read:

object TestApp extends App {
  println("Application Started...")
  val conf=new SparkConf().setAppName("spark-custom-datasource")
  val spark=SparkSession.builder().config(conf).master("local[2]").getOrCreate()
  val df=spark.sqlContext.read.format("com.edu.spark.text").load("/Users/Downloads/data")
  df.show()
  println("Application Ended...")
}

The resulting data:

+-----+---------------+------+------+--------+
|   id|           name|gender|salary|expenses|
+-----+---------------+------+------+--------+
|10002|    Alice Heady|Female| 20000|    8000|
|10003|    Jenny Brown|Female| 30000|  120000|
|10004|     Bob Hayden|  Male| 40000|   16000|
|10005|    Cindy Heady|Female| 50000|   20000|
|10006|     Doug Brown|  Male| 60000|   24000|
|10007|Carolina Hayden|Female| 70000|  280000|
+-----+---------------+------+------+--------+

Writing Data

This example supports two output formats: a custom format and JSON.
Extend CreatableRelationProvider and implement its createRelation method.

class DefaultSource
    extends RelationProvider
    with SchemaRelationProvider
    with CreatableRelationProvider {
  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = {
    createRelation(sqlContext,parameters,null)
  }

  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String], schema: StructType): BaseRelation = {
    val path=parameters.get("path")
    path match {
      case Some(p) => new TextDataSourceRelation(sqlContext,p,schema)
      case _ => throw new IllegalArgumentException("Path is required for custom-datasource format!!")
    }
  }

  override def createRelation(sqlContext: SQLContext, mode: SaveMode, parameters: Map[String, String], data: DataFrame): BaseRelation = {
    val path = parameters.getOrElse("path", "./output/") //can throw an exception/error, it's just for this tutorial
    val fsPath = new Path(path)
    val fs = fsPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration)

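    // Check the target once so that ErrorIfExists and Ignore only react to an existing path
    val pathExists = fs.exists(fsPath)
    mode match {
      case SaveMode.Append =>
        sys.error("Append mode is not supported by " + this.getClass.getCanonicalName)
      case SaveMode.Overwrite =>
        fs.delete(fsPath, true)
      case SaveMode.ErrorIfExists if pathExists =>
        sys.error("Given path: " + path + " already exists!!")
      case SaveMode.Ignore if pathExists =>
        // Leave the existing data untouched and skip the write
        return createRelation(sqlContext, parameters, data.schema)
      case _ => () // nothing exists at the path, proceed with the write
    }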

    val formatName = parameters.getOrElse("format", "customFormat")
    formatName match {
      case "customFormat" => saveAsCustomFormat(data, path, mode)
      case "json" => saveAsJson(data, path, mode)
      case _ => throw new IllegalArgumentException(formatName + " is not supported!!!")
    }
    createRelation(sqlContext, parameters, data.schema)
  }
  private def saveAsJson(data : DataFrame, path : String, mode: SaveMode): Unit = {
    /**
      * Here, I am using the dataframe's Api for storing it as json.
      * you can have your own apis and ways for saving!!
      */
    data.write.mode(mode).json(path)
  }

  private def saveAsCustomFormat(data : DataFrame, path : String, mode: SaveMode): Unit = {
    /**
      * Here, I am  going to save this as simple text file which has values separated by "|".
      * But you can have your own way to store without any restriction.
      */
    val customFormatRDD = data.rdd.map(row => {
      row.toSeq.map(value => value.toString).mkString("|")
    })
    customFormatRDD.saveAsTextFile(path)
  }
}

Test code:

object TestApp extends App {
  println("Application Started...")
  val conf=new SparkConf().setAppName("spark-custom-datasource")
  val spark=SparkSession.builder().config(conf).master("local[2]").getOrCreate()
  val df=spark.sqlContext.read.format("com.edu.spark.text").load("/Users/Downloads/data")
  // save the data
  df.write.options(Map("format" -> "customFormat")).mode(SaveMode.Overwrite).format("com.edu.spark.text").save("/Users//Downloads/out_custom/")
  df.write.options(Map("format" -> "json")).mode(SaveMode.Overwrite).format("com.edu.spark.text").save("/Users//Downloads/out_json/")
  df.write.mode(SaveMode.Overwrite).format("com.edu.spark.text").save("/Users//Downloads/out_none/")
  println("Application Ended...")
}

Output:

Custom format:

10002|Alice Heady|Female|20000|8000
10003|Jenny Brown|Female|30000|120000
10004|Bob Hayden|Male|40000|16000
10005|Cindy Heady|Female|50000|20000
10006|Doug Brown|Male|60000|24000
10007|Carolina Hayden|Female|70000|280000

JSON format:

{"id":10002,"name":"Alice Heady","gender":"Female","salary":20000,"expenses":8000}
{"id":10003,"name":"Jenny Brown","gender":"Female","salary":30000,"expenses":120000}
{"id":10004,"name":"Bob Hayden","gender":"Male","salary":40000,"expenses":16000}
{"id":10005,"name":"Cindy Heady","gender":"Female","salary":50000,"expenses":20000}
{"id":10006,"name":"Doug Brown","gender":"Male","salary":60000,"expenses":24000}
{"id":10007,"name":"Carolina Hayden","gender":"Female","salary":70000,"expenses":280000}

Pruning Columns

Extend PrunedScan and implement buildScan so that only the required columns are returned.

override def buildScan(requiredColumns: Array[String]): RDD[Row] = {
    println("PrunedScan: buildScan called...")
    
    val schemaFields = schema.fields
    // Reading the file's content
    val rdd = sqlContext.sparkContext.wholeTextFiles(path).map(f => f._2)
    
    val rows = rdd.map(fileContent => {
      val lines = fileContent.split("\n")
      val data = lines.map(line => line.split(",").map(word => word.trim).toSeq)
      val tmp = data.map(words => words.zipWithIndex.map{
        case (value, index) =>
          val colName = schemaFields(index).name
          val castedValue = Util.castTo(if (colName.equalsIgnoreCase("gender")) {if(value.toInt == 1) "Male" else "Female"} else value,
            schemaFields(index).dataType)
          if (requiredColumns.contains(colName)) Some(castedValue) else None
      })
    
      tmp.map(s => Row.fromSeq(s.filter(_.isDefined).map(value => value.get)))
    })
    
    rows.flatMap(e => e)
}
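
One detail worth keeping in mind: Spark reads the returned rows positionally against `requiredColumns`, so the values should appear in that order. The buildScan above emits them in schema order, which lines up only when the query requests columns in the same order as the schema. A minimal sketch of selecting by name instead, in plain Scala; `ColumnPruning.prune` and its parameters are hypothetical names introduced for illustration.

```scala
// Hypothetical helper; not part of the original project.
object ColumnPruning {
  // Returns the values of `requiredColumns`, in the requested order.
  // `fieldNames` lists the column names in schema order;
  // `values` holds one parsed record in that same order.
  def prune(fieldNames: Seq[String], values: Seq[Any], requiredColumns: Array[String]): Seq[Any] = {
    val indexByName = fieldNames.zipWithIndex.toMap
    requiredColumns.toSeq.map(name => values(indexByName(name)))
  }
}
```

Inside buildScan, the result of such a helper could be passed to Row.fromSeq in place of the filtered Option sequence.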

Test code:

object TestApp extends App {
  println("Application Started...")
  val conf=new SparkConf().setAppName("spark-custom-datasource")
  val spark=SparkSession.builder().config(conf).master("local[2]").getOrCreate()
  val df=spark.sqlContext.read.format("com.edu.spark.text").load("/Users/Downloads/data")
  //select some specific columns
  df.createOrReplaceTempView("test")
  spark.sql("select id, name, salary from test").show()
  println("Application Ended...")
}

Output:

+-----+---------------+------+
|   id|           name|salary|
+-----+---------------+------+
|10002|    Alice Heady| 20000|
|10003|    Jenny Brown| 30000|
|10004|     Bob Hayden| 40000|
|10005|    Cindy Heady| 50000|
|10006|     Doug Brown| 60000|
|10007|Carolina Hayden| 70000|
+-----+---------------+------+

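Filtering

Extend PrunedFilteredScan and implement buildScan so that only the required columns are returned and the filters passed in by Spark are applied.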

override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
    println("PrunedFilterScan: buildScan called...")
    
    println("Filters: ")
    filters.foreach(f => println(f.toString))
    
    var customFilters: Map[String, List[CustomFilter]] = Map[String, List[CustomFilter]]()
    filters.foreach( f => f match {
      case EqualTo(attr, value) =>
        println("EqualTo filter is used!!" + "Attribute: " + attr + " Value: " + value)
    
        /**
          * With only a few filters implemented, keeping a list per attribute may not seem to
          * make much sense: an attribute can only be equal to one value at a time, so why
          * store the same kind of filter more than once?
          * The list becomes useful when there is more than one filter on the same attribute.
          * Take the condition below as an example:
          * attr > 5 && attr < 10
          * For such cases it is better to keep a list.
          * You can add more filters to this code and try them. Only EqualTo and GreaterThan
          * are handled here to illustrate the concept.
          */
        customFilters = customFilters ++ Map(attr -> {
          customFilters.getOrElse(attr, List[CustomFilter]()) :+ new CustomFilter(attr, value, "equalTo")
        })
      case GreaterThan(attr,value) =>
        println("GreaterThan Filter is used!!"+ "Attribute: " + attr + " Value: " + value)
        customFilters = customFilters ++ Map(attr -> {
          customFilters.getOrElse(attr, List[CustomFilter]()) :+ new CustomFilter(attr, value, "greaterThan")
        })
      case _ => println("filter: " + f.toString + " is not implemented by us!!")
    })
    
    val schemaFields = schema.fields
    // Reading the file's content
    val rdd = sqlContext.sparkContext.wholeTextFiles(path).map(f => f._2)
    
    val rows = rdd.map(file => {
      val lines = file.split("\n")
      val data = lines.map(line => line.split(",").map(word => word.trim).toSeq)
    
      val filteredData = data.map(s => if (customFilters.nonEmpty) {
        var includeInResultSet = true
        s.zipWithIndex.foreach {
          case (value, index) =>
            val attr = schemaFields(index).name
            val filtersList = customFilters.getOrElse(attr, List())
            if (filtersList.nonEmpty) {
              if (CustomFilter.applyFilters(filtersList, value, schema)) {
              } else {
                includeInResultSet = false
              }
            }
        }
        if (includeInResultSet) s else Seq()
      } else s)
    
      val tmp = filteredData.filter(_.nonEmpty).map(s => s.zipWithIndex.map {
        case (value, index) =>
          val colName = schemaFields(index).name
          val castedValue = Util.castTo(if (colName.equalsIgnoreCase("gender")) {
            if (value.toInt == 1) "Male" else "Female"
          } else value,
            schemaFields(index).dataType)
          if (requiredColumns.contains(colName)) Some(castedValue) else None
      })
    
      tmp.map(s => Row.fromSeq(s.filter(_.isDefined).map(value => value.get)))
    })
    
    rows.flatMap(e => e)
}

Test code:

object TestApp extends App {
  println("Application Started...")
  val conf=new SparkConf().setAppName("spark-custom-datasource")
  val spark=SparkSession.builder().config(conf).master("local[2]").getOrCreate()
  val df=spark.sqlContext.read.format("com.edu.spark.text").load("/Users/Downloads/data")
  //filter data
  df.createOrReplaceTempView("test")
  spark.sql("select id,name,gender from test where salary == 50000").show()

  println("Application Ended...")
}

Output:

+-----+-----------+------+
|   id|       name|gender|
+-----+-----------+------+
|10005|Cindy Heady|Female|
+-----+-----------+------+

Registering the Custom Data Source

Have DefaultSource implement shortName from DataSourceRegister.

The implementation:

override def shortName(): String = "udftext"

Then, under the resources directory, create the file META-INF/services/org.apache.spark.sql.sources.DataSourceRegister with the following content:

com.edu.spark.text.DefaultSource

Test code:

object TestApp extends App {
  println("Application Started...")
  val conf=new SparkConf().setAppName("spark-custom-datasource")
  val spark=SparkSession.builder().config(conf).master("local[2]").getOrCreate()
  val df=spark.sqlContext.read.format("udftext").load("/Users/Downloads/data")
  df.show()
  println("Application Ended...")
}

Output:

+-----+---------------+------+------+--------+
|   id|           name|gender|salary|expenses|
+-----+---------------+------+------+--------+
|10002|    Alice Heady|Female| 20000|    8000|
|10003|    Jenny Brown|Female| 30000|  120000|
|10004|     Bob Hayden|  Male| 40000|   16000|
|10005|    Cindy Heady|Female| 50000|   20000|
|10006|     Doug Brown|  Male| 60000|   24000|
|10007|Carolina Hayden|Female| 70000|  280000|
+-----+---------------+------+------+--------+

An implicit extension of DataFrameReader gives the custom data source a shorthand method:

object ReaderObject {
  implicit class UDFTextReader(val reader: DataFrameReader) extends AnyVal{
    def udftext(path:String) = reader.format("udftext").load(path)
  }
}

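Test code (the implicit class wrapping DataFrameReader must be imported):

object TestApp extends App {
  // Brings the implicit UDFTextReader into scope; adjust the prefix if ReaderObject lives in another package
  import ReaderObject._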
  println("Application Started...")
  val conf=new SparkConf().setAppName("spark-custom-datasource")
  val spark=SparkSession.builder().config(conf).master("local[2]").getOrCreate()
  val df=spark.sqlContext.read.udftext("/Users/Downloads/data")
  df.show()
  println("Application Ended...")
}

The output is the same as above and is not repeated here.

Appendix: CustomFilter.scala

case class CustomFilter(attr : String, value : Any, filter : String)
object CustomFilter {
  def applyFilters(filters : List[CustomFilter], value : String, schema : StructType): Boolean = {
    var includeInResultSet = true

    val schemaFields = schema.fields
    val index = schema.fieldIndex(filters.head.attr)
    val dataType = schemaFields(index).dataType
    val castedValue = Util.castTo(value, dataType)

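    filters.foreach(f => {
      val givenValue = Util.castTo(f.value.toString, dataType)
      // Every filter must accept the value, so the results are combined with &&
      f.filter match {
        case "equalTo" => {
          includeInResultSet = includeInResultSet && (castedValue == givenValue)
          println("custom equalTo filter is used!!")
        }
        case "greaterThan" => {
          // castTo returns Any, so compare according to the runtime type
          includeInResultSet = includeInResultSet && ((castedValue, givenValue) match {
            case (a: Int, b: Int) => a > b
            case (a: Long, b: Long) => a > b
            case (a: String, b: String) => a > b
            case _ => false
          })
          println("custom greaterThan filter is used!!")
        }
        case _ => throw new UnsupportedOperationException("this filter is not supported!!")
      }
    })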

    includeInResultSet
  }
}
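
When several filters target the same attribute (for example `attr > 5 && attr < 10`, as mentioned in the comment in buildScan), a value should be kept only if every filter accepts it. A minimal sketch of that rule in plain Scala follows; `SimpleFilter`, `passesAll`, and the extra `lessThan` operator are hypothetical and restricted to Long values to keep the example short.

```scala
// Hypothetical, simplified stand-in for CustomFilter; not part of the original project.
object FilterSketch {
  final case class SimpleFilter(attr: String, value: Long, op: String)

  // A value passes only if every filter in the list accepts it (logical AND).
  // An empty filter list accepts every value.
  def passesAll(filters: List[SimpleFilter], value: Long): Boolean =
    filters.forall { f =>
      f.op match {
        case "equalTo"     => value == f.value
        case "greaterThan" => value > f.value
        case "lessThan"    => value < f.value
        case other         => throw new UnsupportedOperationException(s"unsupported filter: $other")
      }
    }
}
```

Using forall avoids the mutable flag and makes it harder for a later filter to overwrite the result of an earlier one.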

References

  1. Extending Spark Datasource API: write a custom spark datasource (this post is mainly a translation of that article)
  2. How to create a custom Spark SQL data source (using Parboiled2)
  3. spark-custom-datasource-example
  4. Implement REST DataSource using Spark DataSource API
  5. spark-custom-datasource
  6. spark-excel