XGBoost 算法示範(Spark / Scala)
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier
import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler}
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}
import scala.collection.mutable.ArrayBuffer
/**
 * XGBoost4J-Spark multi-class classification example.
 *
 * Works with the Iris dataset (https://archive.ics.uci.edu/ml/datasets/iris).
 * The program reads a header-less CSV file, converts the string label into a
 * numeric index, assembles the remaining columns into a feature vector,
 * trains an `XGBoostClassifier` on a 90% split and prints the predictions
 * for the remaining 10%.
 */
object SparkTraining {

  /**
   * Number of local Spark threads, reused as XGBoost's `num_workers`.
   *
   * NOTE: `num_workers` must be less than or equal to the number of threads
   * in `local[N]`, otherwise the job hangs. Deriving both values from this
   * single constant keeps them consistent.
   */
  private val NumWorkers = 5

  /**
   * Entry point: builds the local Spark session, runs the training pipeline
   * and always closes the session afterwards, even when a step fails.
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName(this.getClass.getSimpleName)
      .master(s"local[$NumWorkers]")
      .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .getOrCreate()

    try {
      val inputPath = "data/xgboost/sample.csv"

      // Schema of the input file: four numeric measurements plus the label.
      val schema = new StructType(Array(
        StructField("sepal length", DoubleType, true),
        StructField("sepal width", DoubleType, true),
        StructField("petal length", DoubleType, true),
        StructField("petal width", DoubleType, true),
        StructField("class", StringType, true)))

      // Read the dataset.
      //
      // +------------+-----------+------------+-----------+-----------+
      // |sepal length|sepal width|petal length|petal width|class      |
      // +------------+-----------+------------+-----------+-----------+
      // |5.1         |3.5        |1.4         |0.2        |Iris-setosa|
      // |4.9         |3.0        |1.4         |0.2        |Iris-setosa|
      // |4.7         |3.2        |1.3         |0.2        |Iris-setosa|
      // +------------+-----------+------------+-----------+-----------+
      val rawInput = spark.read.schema(schema).csv(inputPath)
      // rawInput.show(10, false)

      // XGBoost needs a numeric label: map "class" to "classIndex" and drop
      // the original string column. "classIndex" ends up as the LAST column,
      // which getColumnArray relies on.
      //
      // +------------+-----------+------------+-----------+----------+
      // |sepal length|sepal width|petal length|petal width|classIndex|
      // +------------+-----------+------------+-----------+----------+
      // |5.1         |3.5        |1.4         |0.2        |0.0       |
      // |4.9         |3.0        |1.4         |0.2        |0.0       |
      // |4.7         |3.2        |1.3         |0.2        |0.0       |
      // +------------+-----------+------------+-----------+----------+
      val stringIndexer = new StringIndexer()
        .setInputCol("class")
        .setOutputCol("classIndex")
        .fit(rawInput)
      val labelTransformed = stringIndexer.transform(rawInput).drop("class")
      labelTransformed.show(10, false)

      // Assemble every feature column into a single vector column.
      // Equivalent explicit form:
      //   .setInputCols(Array("sepal length", "sepal width", "petal length", "petal width"))
      //
      // +-----------------+----------+
      // |features         |classIndex|
      // +-----------------+----------+
      // |[5.1,3.5,1.4,0.2]|0.0       |
      // |[4.9,3.0,1.4,0.2]|0.0       |
      // |[4.7,3.2,1.3,0.2]|0.0       |
      // +-----------------+----------+
      val vectorAssembler = new VectorAssembler()
        .setInputCols(getColumnArray(labelTransformed))
        .setOutputCol("features")
      val xgbInput = vectorAssembler.transform(labelTransformed).select("features", "classIndex")
      xgbInput.show(10, false)

      // Training set / prediction set (no seed is given, so the split
      // differs between runs).
      val Array(train, test) = xgbInput.randomSplit(Array(0.9, 0.1))

      // The values are deliberately heterogeneous (Float, Int, String), hence
      // the explicit Map[String, Any].
      val xgbParam: Map[String, Any] = Map(
        "eta" -> 0.1f,
        "max_depth" -> 2,
        "objective" -> "multi:softprob",
        "num_class" -> 3,
        "num_round" -> 100,
        "num_workers" -> NumWorkers)

      // Create the classifier, pointing it at the feature vector and label.
      val xgbClassifier = new XGBoostClassifier(xgbParam)
        .setFeaturesCol("features")
        .setLabelCol("classIndex")

      // Train.
      val xgbClassificationModel = xgbClassifier.fit(train)

      // Predict.
      //
      // +--------------------+----------+--------------------+--------------------+----------+
      // |            features|classIndex|       rawPrediction|         probability|prediction|
      // +--------------------+----------+--------------------+--------------------+----------+
      // |[4.6,3.1,1.5,0.2,...|       0.0|[3.43588137626647...|[0.98977124691009...|       0.0|
      // |[5.0,2.3,3.3,1.0,...|       1.0|[-1.9347994327545...|[0.00610134331509...|       1.0|
      // |[5.8,2.7,5.1,1.9,...|       2.0|[-1.9347994327545...|[0.00450986577197...|       2.0|
      // +--------------------+----------+--------------------+--------------------+----------+
      val results = xgbClassificationModel.transform(test)
      results.show()
    } finally {
      // Release the session even if reading, training or prediction failed.
      spark.close()
    }
  }

  /**
   * Returns the names of all columns of `df` except the last one.
   *
   * The caller is expected to pass a DataFrame whose final column is the
   * label (`classIndex`), so the result is the list of feature columns.
   *
   * @param df DataFrame whose last column is the label column
   * @return a new array with every column name but the last; empty when
   *         `df` has zero or one column
   */
  def getColumnArray(df: DataFrame): Array[String] =
    df.columns.dropRight(1)
}
Maven 依賴
<!-- Build properties: compiler level and library versions for this example. -->
<properties>
<maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target>
<scala.version>2.11.8</scala.version>
<!-- NOTE(review): hadoop.version and spark.version are not referenced by any
     dependency in this fragment; presumably they are used by Spark/Hadoop
     dependencies omitted here - confirm against the full pom.xml. -->
<hadoop.version>2.6.0</hadoop.version>
<!-- Alternative Spark version, currently disabled. -->
<!-- <spark.version>2.1.0</spark.version>-->
<spark.version>2.3.0</spark.version>
<!-- Single version shared by both XGBoost4J artifacts below so they stay in sync. -->
<xgboost4j.version>0.81</xgboost4j.version>
</properties>
<dependencies>
<!-- xgboost4j: base XGBoost4J artifact. -->
<dependency>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost4j</artifactId>
<version>${xgboost4j.version}</version>
</dependency>
<!-- xgboost4j-spark: Spark integration; the Scala example imports
     ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier. -->
<dependency>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost4j-spark</artifactId>
<version>${xgboost4j.version}</version>
</dependency>
</dependencies>
遇到的問題
資料集 sample.csv
5.1,3.5,1.4,0.2,Iris-setosa
4.9,3.0,1.4,0.2,Iris-setosa
4.7,3.2,1.3,0.2,Iris-setosa
4.6,3.1,1.5,0.2,Iris-setosa
5.0,3.6,1.4,0.2,Iris-setosa
5.4,3.9,1.7,0.4,Iris-setosa
4.6,3.4,1.4,0.3,Iris-setosa
5.0,3.4,1.5,0.2,Iris-setosa
4.4,2.9,1.4,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
5.4,3.7,1.5,0.2,Iris-setosa
4.8,3.4,1.6,0.2,Iris-setosa
4.8,3.0,1.4,0.1,Iris-setosa
4.3,3.0,1.1,0.1,Iris-setosa
5.8,4.0,1.2,0.2,Iris-setosa
5.7,4.4,1.5,0.4,Iris-setosa
5.4,3.9,1.3,0.4,Iris-setosa
5.1,3.5,1.4,0.3,Iris-setosa
5.7,3.8,1.7,0.3,Iris-setosa
5.1,3.8,1.5,0.3,Iris-setosa
5.4,3.4,1.7,0.2,Iris-setosa
5.1,3.7,1.5,0.4,Iris-setosa
4.6,3.6,1.0,0.2,Iris-setosa
5.1,3.3,1.7,0.5,Iris-setosa
4.8,3.4,1.9,0.2,Iris-setosa
5.0,3.0,1.6,0.2,Iris-setosa
5.0,3.4,1.6,0.4,Iris-setosa
5.2,3.5,1.5,0.2,Iris-setosa
5.2,3.4,1.4,0.2,Iris-setosa
4.7,3.2,1.6,0.2,Iris-setosa
4.8,3.1,1.6,0.2,Iris-setosa
5.4,3.4,1.5,0.4,Iris-setosa
5.2,4.1,1.5,0.1,Iris-setosa
5.5,4.2,1.4,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
5.0,3.2,1.2,0.2,Iris-setosa
5.5,3.5,1.3,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
4.4,3.0,1.3,0.2,Iris-setosa
5.1,3.4,1.5,0.2,Iris-setosa
5.0,3.5,1.3,0.3,Iris-setosa
4.5,2.3,1.3,0.3,Iris-setosa
4.4,3.2,1.3,0.2,Iris-setosa
5.0,3.5,1.6,0.6,Iris-setosa
5.1,3.8,1.9,0.4,Iris-setosa
4.8,3.0,1.4,0.3,Iris-setosa
5.1,3.8,1.6,0.2,Iris-setosa
4.6,3.2,1.4,0.2,Iris-setosa
5.3,3.7,1.5,0.2,Iris-setosa
5.0,3.3,1.4,0.2,Iris-setosa
7.0,3.2,4.7,1.4,Iris-versicolor
6.4,3.2,4.5,1.5,Iris-versicolor
6.9,3.1,4.9,1.5,Iris-versicolor
5.5,2.3,4.0,1.3,Iris-versicolor
6.5,2.8,4.6,1.5,Iris-versicolor
5.7,2.8,4.5,1.3,Iris-versicolor
6.3,3.3,4.7,1.6,Iris-versicolor
4.9,2.4,3.3,1.0,Iris-versicolor
6.6,2.9,4.6,1.3,Iris-versicolor
5.2,2.7,3.9,1.4,Iris-versicolor
5.0,2.0,3.5,1.0,Iris-versicolor
5.9,3.0,4.2,1.5,Iris-versicolor
6.0,2.2,4.0,1.0,Iris-versicolor
6.1,2.9,4.7,1.4,Iris-versicolor
5.6,2.9,3.6,1.3,Iris-versicolor
6.7,3.1,4.4,1.4,Iris-versicolor
5.6,3.0,4.5,1.5,Iris-versicolor
5.8,2.7,4.1,1.0,Iris-versicolor
6.2,2.2,4.5,1.5,Iris-versicolor
5.6,2.5,3.9,1.1,Iris-versicolor
5.9,3.2,4.8,1.8,Iris-versicolor
6.1,2.8,4.0,1.3,Iris-versicolor
6.3,2.5,4.9,1.5,Iris-versicolor
6.1,2.8,4.7,1.2,Iris-versicolor
6.4,2.9,4.3,1.3,Iris-versicolor
6.6,3.0,4.4,1.4,Iris-versicolor
6.8,2.8,4.8,1.4,Iris-versicolor
6.7,3.0,5.0,1.7,Iris-versicolor
6.0,2.9,4.5,1.5,Iris-versicolor
5.7,2.6,3.5,1.0,Iris-versicolor
5.5,2.4,3.8,1.1,Iris-versicolor
5.5,2.4,3.7,1.0,Iris-versicolor
5.8,2.7,3.9,1.2,Iris-versicolor
6.0,2.7,5.1,1.6,Iris-versicolor
5.4,3.0,4.5,1.5,Iris-versicolor
6.0,3.4,4.5,1.6,Iris-versicolor
6.7,3.1,4.7,1.5,Iris-versicolor
6.3,2.3,4.4,1.3,Iris-versicolor
5.6,3.0,4.1,1.3,Iris-versicolor
5.5,2.5,4.0,1.3,Iris-versicolor
5.5,2.6,4.4,1.2,Iris-versicolor
6.1,3.0,4.6,1.4,Iris-versicolor
5.8,2.6,4.0,1.2,Iris-versicolor
5.0,2.3,3.3,1.0,Iris-versicolor
5.6,2.7,4.2,1.3,Iris-versicolor
5.7,3.0,4.2,1.2,Iris-versicolor
5.7,2.9,4.2,1.3,Iris-versicolor
6.2,2.9,4.3,1.3,Iris-versicolor
5.1,2.5,3.0,1.1,Iris-versicolor
5.7,2.8,4.1,1.3,Iris-versicolor
6.3,3.3,6.0,2.5,Iris-virginica
5.8,2.7,5.1,1.9,Iris-virginica
7.1,3.0,5.9,2.1,Iris-virginica
6.3,2.9,5.6,1.8,Iris-virginica
6.5,3.0,5.8,2.2,Iris-virginica
7.6,3.0,6.6,2.1,Iris-virginica
4.9,2.5,4.5,1.7,Iris-virginica
7.3,2.9,6.3,1.8,Iris-virginica
6.7,2.5,5.8,1.8,Iris-virginica
7.2,3.6,6.1,2.5,Iris-virginica
6.5,3.2,5.1,2.0,Iris-virginica
6.4,2.7,5.3,1.9,Iris-virginica
6.8,3.0,5.5,2.1,Iris-virginica
5.7,2.5,5.0,2.0,Iris-virginica
5.8,2.8,5.1,2.4,Iris-virginica
6.4,3.2,5.3,2.3,Iris-virginica
6.5,3.0,5.5,1.8,Iris-virginica
7.7,3.8,6.7,2.2,Iris-virginica
7.7,2.6,6.9,2.3,Iris-virginica
6.0,2.2,5.0,1.5,Iris-virginica
6.9,3.2,5.7,2.3,Iris-virginica
5.6,2.8,4.9,2.0,Iris-virginica
7.7,2.8,6.7,2.0,Iris-virginica
6.3,2.7,4.9,1.8,Iris-virginica
6.7,3.3,5.7,2.1,Iris-virginica
7.2,3.2,6.0,1.8,Iris-virginica
6.2,2.8,4.8,1.8,Iris-virginica
6.1,3.0,4.9,1.8,Iris-virginica
6.4,2.8,5.6,2.1,Iris-virginica
7.2,3.0,5.8,1.6,Iris-virginica
7.4,2.8,6.1,1.9,Iris-virginica
7.9,3.8,6.4,2.0,Iris-virginica
6.4,2.8,5.6,2.2,Iris-virginica
6.3,2.8,5.1,1.5,Iris-virginica
6.1,2.6,5.6,1.4,Iris-virginica
7.7,3.0,6.1,2.3,Iris-virginica
6.3,3.4,5.6,2.4,Iris-virginica
6.4,3.1,5.5,1.8,Iris-virginica
6.0,3.0,4.8,1.8,Iris-virginica
6.9,3.1,5.4,2.1,Iris-virginica
6.7,3.1,5.6,2.4,Iris-virginica
6.9,3.1,5.1,2.3,Iris-virginica
5.8,2.7,5.1,1.9,Iris-virginica
6.8,3.2,5.9,2.3,Iris-virginica
6.7,3.3,5.7,2.5,Iris-virginica
6.7,3.0,5.2,2.3,Iris-virginica
6.3,2.5,5.0,1.9,Iris-virginica
6.5,3.0,5.2,2.0,Iris-virginica
6.2,3.4,5.4,2.3,Iris-virginica
5.9,3.0,5.1,1.8,Iris-virginica