XGBoost4J:可移植的分布式 XGBoost(支持 Spark、Flink 和 Dataflow)
引言
XGBoost 是一个为树增强(tree boosting)而设计并优化的库。梯度增强树模型(Gradient boosting trees model)最初由 Friedman 等人提出。通过支持多线程和引入正则化,XGBoost 提供了更高的计算能力和更准确的预测。在 Kaggle 上举办的机器学习挑战赛中,超过半数的获胜方案采用了 XGBoost(不完整列表)。XGBoost 为 C++、R、python、Julia 和 Java 用户提供了原生接口。它被用于数据探索和生产环境,以解决现实世界的机器学习问题。
分布式 XGBoost 在最近发表的论文中有所描述。简而言之,XGBoost 系统的运行速度比现有分布式机器学习替代方案快几个数量级,并且使用的资源也少得多。非常欢迎读者参考该论文了解更多详情。
尽管目前已取得巨大成功,但我们的最终目标之一是让 XGBoost 更广泛地应用于所有生产场景。基于 Java 虚拟机 (JVM) 的编程语言和数据处理/存储系统在大数据生态系统中扮演着重要角色。Hadoop、Spark 以及最近推出的Flink 是非常有用的通用大规模数据处理解决方案。
另一方面,机器学习和深度学习的兴起激发了许多优秀的机器学习库。许多这些机器学习库(例如 XGBoost/MxNet)需要新的计算抽象和原生支持(例如用于 GPU 计算的 C++)。它们通常也效率高得多。
通用数据处理框架与更专业的机器学习库/系统之间的实现基础存在差距,这阻碍了这两类系统之间的平滑连接,从而给最终用户带来了不必要的麻烦。用户的常见工作流程是利用 Spark/Flink 等系统进行数据预处理/清洗,通过文件系统将结果传递给 XGBoost/MxNet)等机器学习系统,然后进行后续的机器学习阶段。这种跨越两类系统的过程给用户带来了一定的不便,并给基础设施的运维人员带来了额外的开销。
我们希望兼得两全其美,因此我们可以将 Spark 和 Flink 等数据处理框架与最佳的分布式机器学习解决方案结合使用。为了解决这一问题,我们推出了全新打造的 XGBoost4J,即适用于 JVM 平台的 XGBoost。我们的目标是提供简洁的 Java/Scala API,并与最流行的基于 JVM 语言开发的数据处理系统集成。
机器学习中的 Unix 哲学
XGBoost 和 XGBoost4J 遵循 Unix 哲学。XGBoost 在一件事上做到极致——树增强,并且被设计为与其他系统协同工作。我们坚信机器学习解决方案不应受限于特定的语言或平台。
具体来说,用户将能够在 Spark 和 Flink 中使用分布式 XGBoost,未来还可能支持更多框架。我们以可移植的方式设计了 API,以便它可以轻松移植到云提供的其他数据流框架。XGBoost4J 与其他 XGBoost 库共享其核心,这意味着数据科学家可以使用 R/python 读取和可视化分布式训练的模型。这也意味着用户可以从单机版本开始探索,该版本已经可以处理数亿个样本。
系统概览
在下图中,我们描述了 XGBoost4J 的整体架构。XGBoost4J 提供了调用 XGBoost 库核心功能的 Java/Scala API。最重要的是,它不仅支持单机模型训练,还提供了一个抽象层,屏蔽了底层数据处理引擎的差异,并将训练扩展到分布式服务器。
通过调用 XGBoost4J API,用户可以将模型训练扩展到集群。XGBoost4J 调用 Spark/Flink 任务中运行的 XGBoost 工作进程实例,并在集群中运行它们。分布式模型训练任务与 XGBoost4J 运行时环境之间的通信通过 [Rabit] (https://github.com/dmlc/rabit) 进行。
借助 XGBoost4J 的抽象,用户可以构建一个统一的数据分析应用程序,涵盖从提取-转换-加载 (ETL)、数据探索、机器学习模型训练到最终数据产品服务的所有阶段。下图展示了一个基于 Apache Spark 构建的应用示例。该应用程序将 XGBoost 无缝嵌入到处理管道中,并通过 Spark 的分布式内存层与其他基于 Spark 的处理阶段交换数据。
单机训练演练
在本节中,我们将通过示例介绍 XGBoost4J 的 API。我们将使用 Scala 进行演示,但我们也为 Java 用户提供了完整的 API。
要开始模型训练和评估,我们需要准备训练集和测试集
val trainMax = new DMatrix("../../demo/data/agaricus.txt.train")
val testMax = new DMatrix("../../demo/data/agaricus.txt.test")
准备好数据后,我们可以训练模型
val params = new mutable.HashMap[String, Any]()
params += "eta" -> 1.0
params += "max_depth" -> 2
params += "objective" -> "binary:logistic"
val watches = new mutable.HashMap[String, DMatrix]
watches += "train" -> trainMax
watches += "test" -> testMax
val round = 2
// train a model
val booster = XGBoost.train(trainMax, params.toMap, round, watches.toMap)
然后我们评估模型
val predicts = booster.predict(testMax)
predict
可以输出预测结果,您可以定义自定义评估方法来得出您自己的指标(请参阅示例:Java 中的自定义评估指标, [Scala 中的自定义评估指标] (https://github.com/dmlc/xgboost/blob/master/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/CustomObjective.scala))。
使用分布式数据流框架进行分布式模型训练
本次 XGBoost4J 版本中最令人兴奋的部分是与分布式数据流框架的集成。最流行的数据处理框架都属于此类,例如 Apache Spark, [Apache Flink] (http://flink.apache.org/) 等。在本部分,我们将介绍如何构建包含数据预处理和分布式模型训练的统一数据分析应用程序,并结合 Spark 和 Flink 进行演示。(目前,我们仅为与 Spark 和 Flink 的集成提供了 Scala API)
与单机训练类似,我们需要准备训练和测试数据集。
Spark 示例
在 Spark 中,数据集表示为弹性分布式数据集 (RDD),我们可以利用 Spark 的分布式工具解析 libSVM 文件并将其封装为 RDD
val trainRDD = MLUtils.loadLibSVMFile(sc, inputTrainPath).repartition(args(1).toInt)
我们继续训练模型
val xgboostModel = XGBoost.train(trainRDD, paramMap, numRound, numWorkers)
下一步是评估模型,您可以在本地或以分布式方式进行预测
// testSet is an RDD containing testset data represented as
// org.apache.spark.mllib.regression.LabeledPoint
val testSet = MLUtils.loadLibSVMFile(sc, inputTestPath)
// local prediction
// import methods in DataUtils to convert Iterator[org.apache.spark.mllib.regression.LabeledPoint]
// to Iterator[ml.dmlc.xgboost4j.LabeledPoint] in automatic
import DataUtils._
xgboostModel.predict(new DMatrix(testSet.collect().iterator)
// distributed prediction
xgboostModel.predict(testSet)
Flink 示例
在 Flink 中,我们将训练数据表示为 Flink 的DataSet
val trainData = MLUtils.readLibSVM(env, "/path/to/data/agaricus.txt.train")
模型训练可以按如下方式进行
val xgboostModel = XGBoost.train(trainData, paramMap, round)
训练和预测。
// testData is a Dataset containing testset data represented as
// org.apache.flink.ml.math.Vector.LabeledVector
val testData = MLUtils.readLibSVM(env, "/path/to/data/agaricus.txt.test")
// local prediction
xgboostModel.predict(testData.collect().iterator)
// distributed prediction
xgboostModel.predict(testData.map{x => x.vector})
路线图
这是 XGBoost4J 包的第一个版本,我们正在积极推进,以便在下一个版本中带来更多吸引人的特性。您可以在XGBoost4J 路线图中查看我们的进展。
尽管我们正在尽最大努力保持 API 的最小变动,但仍然可能存在不兼容的更改。
进一步阅读
如果您对 XGBoost 了解更多感兴趣,可以在以下资源中找到丰富的信息
致谢
我们要特别感谢 Zixuan Huang,他是 XGBoost for Java 的早期开发者。