本论文从分布式系统的角度开展针对当前一些机器学习平台的研究,综述了这些平台所使用的架构设计,对这些平台在通信和控制上的瓶颈、容错性和开发难度进行分析和对比,并对分布式机器学习平台的未来研究工作提出了一些建议。文中的工作由MuratDemirbas教授与他的研究生KuoZhang和SalemAlqahtani共同完成。
机器学习,特别是深度学习,已在语音识别、图像识别和自然语言处理以及近期在推荐及搜索引擎等领域上取得了革命性的成功。这些技术在无人驾驶、数字医疗系统、CRM、广告、物联网等领域具有很好的应用前景。当然,是资金引领和驱动了技术的加速推进,使得我们在近期看到了一些机器学习平台的推出。
考虑到训练中所涉及的数据集和模型的规模十分庞大,机器学习平台通常是分布式平台,部署了数十个乃至数百个并行运行的计算节点对模型做训练。据估计在不远的将来,数据中心的大多数任务都会是机器学习任务。
我来自于分布式系统研究领域,因此我们考虑从分布式系统的角度开展针对这些机器学习平台的研究,分析这些平台在通信和控制上的瓶颈。我们还考虑了这些平台的容错性和易编程性。
我们从设计方法上将机器学习平台划分为三个基本类别,分别是:基本数据流、参数-服务器模型和高级数据流。
下面我们将对每类方法做简要介绍,以ApacheSpark为例介绍基本数据流,以PMLS(Petuum)为例介绍参数服务器模型,而高级数据流则使用TensorFlow和MXNet为例。我们对比了上述各平台的性能并给出了一系列的评估结果。要了解详细的评估结果,可参考我们的论文。遗憾的是,作为一个小型研究团队,我们无法开展大规模的评估。
在本篇博文的最后,我给出了一些结论性要点,并对分布式机器学习平台的未来研究工作提出了一些建议。对这些分布式机器学习平台已有一定了解的读者,可以直接跳到本文结尾。
Spark
在Spark中,计算被建模为一种有向无环图(DAG),图中的每个顶点表示一个RDD,每条边表示了RDD上的一个操作。RDD由一系列被切分的对象(Partition)组成,这些被切分的对象在内存中存储并完成计算,也会在Shuffle过程中溢出(Overflow)到磁盘上
在DAG中,一条从顶点A到B的有向边E,表示了RDDB是在RDDA上执行操作E的结果。操作分为转换(Transformation)和动作(Action)两类。转换操作(例如map、filter和join)应用于某个RDD上,转换操作的输出是一个新的RDD。
Spark用户将计算建模为DAG,该DAG表示了在RDD上执行的转换和动作。DAG进而被编译为多个Stage。每个Stage执行为一系列并行运行的任务(Task),每个分区(Partition)对应于一个任务。这里,有限(Narrow)的依赖关系将有利于计算的高效执行,而宽泛(Wide)的依赖关系则会引入瓶颈,因为这样的依赖关系引入了通信密集的Shuffle操作,这打断了操作流。
Spark的分布式执行是通过将DAGStage划分到不同的计算节点实现的。上图清晰地展示了这种主机(master)-工作者(worker)架构。驱动器(Driver)包含有两个调度器(Scheduler)组件,即DAG调度器和任务调度器。调度器对工作者分配任务,并协调工作者。
Spark是为通用数据处理而设计的,并非专用于机器学习任务。要在Spark上运行机器学习任务,可以使用MLlibforSpark。如果采用基本设置的Spark,那么模型参数存储在驱动器节点上,在每次迭代后通过工作者和驱动器间的通信更新参数。如果是大规模部署机器学习任务,那么驱动器可能无法存储所有的模型参数,这时就需要使用RDD去容纳所有的参数。这将引入大量的额外开销,因为为了容纳更新的模型参数,需要在每次迭代中创建新的RDD。更新模型会涉及在机器和磁盘间的数据Shuffle,进而限制了Spark的扩展性。这正是基本数据流模型(即DAG)的短板所在。Spark并不能很好地支持机器学习中的迭代运算。
PMLS
PMLS是专门为机器学习任务而设计的。它引入了称为参数-服务器(Parameter-Server,PS)的抽象,这种抽象是为了支持迭代密集的训练过程。
PS(在图中以绿色方框所示)以分布式key-value数据表形式存在于内存中,它是可复制和分片的。每个节点(node)都是模型中某个分片的主节点(参数空间),并作为其它分片的二级节点或复制节点。这样PS在节点数量上的扩展性很好。
PS节点存储并更新模型参数,并响应来自于工作者的请求。工作者从自己的本地PS拷贝上请求最新的模型参数,并在分配给它们的数据集分区上执行计算。
PMLS也采用了SSP(StaleSynchronousParallelism)模型。相比于BSP(BulkSynchronousParellelism)模型,SSP放宽了每一次迭代结束时各个机器需做同步的要求。为实现同步,SSP允许工作者间存在一定程度上的不同步,工业机器人维修,并确保了最快的工作者不会领先最慢的工作者s轮迭代以上。由于处理过程处于误差所允许的范围内,这种非严格的一致性模型依然适用于机器学习。我曾经发表过一篇博文专门介绍这一机制。
TensorFlow
Google给出了一个基于分布式机器学习平台的参数服务器模型,称为DistBelief(此处是我对DistBelief论文的综述)。就我所知,大家对DistBelief的不满意之处主要在于,它在编写机器学习应用时需要混合一些底层代码。Google想使其任一雇员都可以在无需精通分布式执行的情况下编写机器学习代码。正是出于同一原因,Google对大数据处理编写了MapReduce框架。
TensorFlow是一种设计用于实现这一目标的平台。它采用了一种更高级的数据流处理范式,其中表示计算的图不再需要是DAG,图中可以包括环,并支持可变状态。我认为TensorFlow的设计在一定程度上受到了Naiad设计理念的影响。
TensorFlow将计算表示为一个由节点和边组成的有向图。节点表示计算操作或可变状态(例如Variable),边表示节点间通信的多维数组,这种多维数据称为Tensor。TensorFlow需要用户静态地声明逻辑计算图,并通过将图重写和划分到机器上实现分布式计算。需说明的是,MXNet,特别是DyNet,使用了一种动态定义的图。这简化了编程,并提高了编程的灵活性。
如上图所示,在TensorFlow中,分布式机器学习训练使用了参数-服务器方法。当在TensorFlow中使用PS抽象时,就使用了参数-服务器和数据并行。TensorFlow声称可以完成更复杂的任务,但是这需要用户编写代码以通向那些未探索的领域。
MXNet
MXNet是一个协同开源项目,源自于在2015年出现的CXXNet、Minverva和Purines等深度学习项目。类似于TensorFlow,MXNet也是一种数据流系统,支持具有可变状态的有环计算图,并支持使用参数-服务器模型的训练计算。同样,MXNet也对多个CPU/GPU上的数据并行提供了很好的支持,并可实现模型并行。MXNet支持同步的和异步的训练计算。下图显示了MXNet的主要组件。其中,运行时依赖引擎分析计算过程中的依赖关系,对不存在相互依赖关系的计算做并行处理。MXNet在运行时依赖引擎之上提供了一个中间层,用于计算图和内存的优化。
MXNet使用检查点机制支持基本的容错,提供了对模型的save和load操作。save操作将模型参数写入到检查点文件,load操作从检查点文件中读取模型参数。