天天看点

[源码解析] 深度学习分布式训练框架 horovod (21) --- 之如何恢复训练

本文以 PyTorch on Horovod 为切入点,分析一下 Horovod 弹性训练的恢复流程,具体涉及知识点有:ElasticSampler与PyTorch 原生DistributedSampler 的区别,Horovod 弹性训练如何恢复等。

目录

[源码解析] 深度学习分布式训练框架 horovod (21) --- 之如何恢复训练

0x00 摘要

0x01 总论

0x02 Sampler

2.1 PyTorch Distributed Optimizer

2.1.1 定义

2.1.2 问题点

2.2 ElasticSampler

2.2.1 定义

2.2.2 弹性方案

2.2.2.1 常规流程

2.2.2.2 异常处理

2.2.1 如何使用

2.2.1.1 主体代码

2.2.1.2 训练代码

0x03 保存和定期检查

3.1 定期保存

3.2 异常处理

3.3 Commit

0x04 State

4.1 恢复训练

4.2 TorchState

4.3 设置 handler

4.4 SamplerStateHandler

4.5 保存

4.6 HostsUpdatedInterrupt

4.7 HorovodInternalError

4.8 ElasticSampler.iter

0xFF 参考

本文以 PyTorch on Horovod 为切入点,分析一下 Horovod 弹性训练的恢复流程,具体涉及知识点有:

ElasticSampler与PyTorch 原生DistributedSampler 的区别,Horovod 弹性训练如何恢复等。

本系列其他文章链接如下:

[源码解析] 深度学习分布式训练框架 Horovod (1) --- 基础知识

[源码解析] 深度学习分布式训练框架 horovod (2) --- 从使用者角度切入

[源码解析] 深度学习分布式训练框架 horovod (3) --- Horovodrun背后做了什么

[源码解析] 深度学习分布式训练框架 horovod (4) --- 网络基础 & Driver

[源码解析] 深度学习分布式训练框架 horovod (5) --- 融合框架

[源码解析] 深度学习分布式训练框架 horovod (6) --- 后台线程架构

[源码解析] 深度学习分布式训练框架 horovod (7) --- DistributedOptimizer

[源码解析] 深度学习分布式训练框架 horovod (8) --- on spark

[源码解析] 深度学习分布式训练框架 horovod (9) --- 启动 on spark

[源码解析] 深度学习分布式训练框架 horovod (10) --- run on spark

[源码解析] 深度学习分布式训练框架 horovod (11) --- on spark --- GLOO 方案

[源码解析] 深度学习分布式训练框架 horovod (12) --- 弹性训练总体架构

[源码解析] 深度学习分布式训练框架 horovod (13) --- 弹性训练之 Driver

[源码解析] 深度学习分布式训练框架 horovod (14) --- 弹性训练发现节点 & State

[源码解析] 深度学习分布式训练框架 horovod (15) --- 广播 & 通知

[源码解析] 深度学习分布式训练框架 horovod (16) --- 弹性训练之Worker生命周期

[源码解析] 深度学习分布式训练框架 horovod (17) --- 弹性训练之容错

[源码解析] 深度学习分布式训练框架 horovod (18) --- kubeflow tf-operator

[源码解析] 深度学习分布式训练框架 horovod (19) --- kubeflow MPI-operator

[源码解析] 深度学习分布式训练框架 horovod (20) --- Elastic Training Operator

本文缘起于一个兄弟的留言:

请问在弹性训练中,如果节点数目发生变化,数据怎么重新划分呢?比如一个epoch还没有进行完,这时添加了新节点,新数据重新划分的话,当前内存中用旧数据训练的模型还有效吗?

我恰好在分析PyTorch分布式的时候也有类似疑问,所以就回头再看看Horovod是如何实现的。

我们之前对于 Horovod 的分析和示例大多以 TensorFlow 为例。大家对各种框架如何在Horovod之中适配的总体逻辑和思路应该有了一个大致的认识,所以我们本部分主要看看一些PyTorch 相关的特殊之处。

使用PyTorch做切入的另外一个原因是:在恢复训练这个流程上,PyTorch相关部分确实相对清晰明确。

在 horovod/torch/elastic/ 目录下,有两个文件 :state.py 和 sampler.py。既然是弹性相关,所以我们先来看看其特殊之处。

在 horovod/torch/elastic/sampler.py 之中,有一个 ElasticSampler 类,我们看看具体针对弹性做了哪些处理。

因为 ElasticSampler 类之中注明,它的实现非常类似<code>DistributedSampler</code>,也就是 PyTorch 原生的实现,所以我们要先看看 <code>DistributedSampler</code>。

<code>DistributedSampler</code>代码位于:torch/distributed/optim/optimizer.py。

总结一下DistributedSampler的分配方法是:每段连续的 <code>num_replicas</code> 个数据被拆成一个一个,分给 <code>num_replicas</code> 个进程,这样就达到了不重叠不交叉的目的,但也要注意的是:这样每个进程拿到的数据是不连续的。

<code>__iter__</code> 代码的一个技术细节是 本worker如何遍历?

<code>indices = indices[self.rank:self.total_size:self.num_replicas]</code>

这里,num_replicas 实际就是rank的总数,起始位置是self.rank,结束位置是总数据长度,按照num_replicas(就是world size)作为步长来递增,所以这里每个worker就会严格返回自己rank对应的那部分数据序号。

我们用一个例子来看看,比如:

得到:

具体代码如下:

DistributedSampler 如果直接用到 弹性训练,是有一定问题的,让我们分析一下,有几个问题:

如果用户已经训练了5轮,那么就意味着已经使用了前面5个批次的数据。假设此时加入了新的worker节点,那么就应该恢复训练。那么对于已经使用过的前面 5 个批次的数据,按说就不应该再次被用来训练了。

问题1: 恢复训练之后,应该怎么去除已经处理的数据index?

如果加入或者减少节点,如果告诉 Sampler,我们需要更改提取规则,最起码,num_replicas 需要被更新,以后按照新的 num_replicas 进行提取,比如原来5个节点,num_replicas = 5,现在6个节点,num_replicas 应该为 6。

问题2: 恢复训练之后,何时调用 <code>__iter__</code>以进行新的训练?

问题3: 恢复训练之后,何时修改 num_replicas?

我们看看 DistributedSampler 就会发现,其<code>__iter__</code>之中,没有任何保存状态的相关信息。即如果重新开始训练,依然会从全体数据中提取,而非从剩余数据中提取。也没有发现对后面两个问题的解决办法。

因此,很难利用 DistributedSampler进行弹性训练,所以 Horovod 就使用 ElasticSampler 来解决这个问题。

从注释中我们可以看到,ElasticSampler 自称与 DistributedSampler 非常类似。我们随后针对两个类代码比较可以看到,功能基本一致。

但是有两个新加入的变量值得注意,即:

定义如下:

具体弹性方案就围绕之前提到的两个变量来进行。

我们回忆其注释中提到的如何使用:

我们可以推导出来其内在逻辑:

进行本 epoch 训练。

当使用 <code>__iter__</code> 获取下一批次数据时候,<code>self.indices = self.remaining_indices[:]</code> 就会 只从未训练的数据里面提取。

每处理一个批次数据 之后,用户使用 <code>record_batch</code> 或者 <code>record_indices</code> 来把已经训练完的数据批次信息保存在 <code>processed_indices</code>。这样就记录了已经训练完的数据。

如果产生了问题,或者有节点变更,则:

会调用 reset 函数,reset 会把已经训练完的数据 <code>processed_indices</code> 从总数据中移除,剩下的 <code>self.remaining_indice</code>就是没有训练的数据。

恢复训练, 只从未训练的数据里面提取。

当完成这个epoch 之后,会调用 <code>set_epoch</code> 来重置 <code>processed_indices</code>,也会调用 reset 方法进行清零。

具体功能代码是:

在 horovod/torch/elastic/state.py 之中,当重新训练时候,会调用到 ElasticSampler 的 load_state_dict 方法。

而 load_state_dict 之中,会调用 reset,这样就把已经训练完的数据移除,得到的数据都是没有经过训练的。

所以重新训练时候,本epoch之内,不会用已经训练的数据再次重复训练。

我们后续会详细分析这个流程。

ElasticSampler 的使用如下,代码位于:examples/elastic/pytorch/pytorch_imagenet_resnet50_elastic.py。

本节我们主要介绍如何使用,就是正常使用/处理流程,后续会介绍异常处理,这里省略部分次要代码。

主体代码主要注意就是使用ElasticSampler分别配置了两个弹性采样器。

以下代码是具体训练代码。

某一个epoch具体逻辑(正常处理)如下:

如果是最初运行,则调用reset进行初始化,其中会依据 dataset 长度构建一个 index list。用这个index list 减去 processed_indices ,就得到了本次epoch应该处理的数据 index,赋值给 remaining_indices,就是剩下来应该处理的数据index;

在 <code>__iter__</code> 函数中,调用 <code>self.indices = self.remaining_indices[:]</code> ,这样 indices 就可以用来做迭代提取;

训练函数中,调用 iter(indices) 进行迭代提取,然后调用 record_indices 把本次使用过的index 更新到 processed_indices 之中。processed_indices 就记录了目前使用的所有index;

epoch 结束之后,调用 set_epoch 进行重置,即给 processed_indices 清零,调用 reset 重置 remaining_indices;

Hovorod 建议用户定周期性调用 state.commit() 来把状态(state)备份到内存。

定期备份非常有用。在某些worker发生意外错误时,定期备份可以避免因为状态被损坏而在重新训练时候无法恢复现场。比如,如果一个worker刚好在更新参数过程中突然出错,此时部分梯度更新完毕,部分梯度可能只更新到一半,这个状态是不可逆转而又无法继续。因此,当此状态发生时,会抛出一个 HorovodInternalError 异常,当 hvd.elastic.run 捕获到这个异常后,会利用最新一次commit中恢复所有状态。

因为commit状态代价高昂(比如如参数量太大会导致耗时过长),所以需要在"每个batch的处理时间"与"如果出错,训练需要从多久前的状态恢复"之间选取一个平衡点。比如,如果你每训练10个batches就commit一次,你就把复制时间降低了10倍。但是当发生错误时,你需要回滚到10个batches前的状态。

Elastic Horowod可以通过执行我们称之为“优雅地移除worker”操作来避免这些回滚。如果driver进程发现主机已可用或标记为删除,它将向所有workers推送一个通知。于是在下次调用state.commit()或更轻量级的state.check_host_updates()时,一个HostsUpdatedInterrupt异常将被抛出。此异常的处理方式与“HorovodInternalError”类似,只是参数状态不会还原到上次commit,而是从当前实时参数中恢复。

一般来说,如果你的硬件设施是可靠与稳定的,并且你的编排系统会在任务节点移除时提供足够的告警,你就可低频次调用 state.commit() 函数,同时只在每个batch结束时调用相对不耗时的 state.check_host_updates() 来检查节点变更情况。

具体示例代码如下:

我们可以看到,HorovodInternalError 和 HostsUpdatedInterrupt 这两个异常最大的区别:

HorovodInternalError 异常:当 hvd.elastic.run 捕获到这个异常后,会利用最新一次commit中恢复所有状态。

HostsUpdatedInterrupt 异常:处理方式与“HorovodInternalError”类似,只是参数状态不会还原到上次commit,而是从当前实时参数中恢复。

之所以要强调这个,因为后面就要介绍如何做到不同恢复。

在用户调用 State.commit 的时候,有两个动作:一个是保存状态。一个是调用 check_host_updates 检查更新。

这里 save 就会调用到 State 的 save 操作,结合本文,就是下面要介绍的 TorchState 的 save 操作。

另外,check_host_updates 会抛出HostsUpdatedInterrupt异常。HostsUpdatedInterrupt 异常里面,是否需要 sync,从下面 check_host_updates 代码可以看出来,就是如果节点数目有变化了,就需要sync。HostUpdateResult.removed 数值为1,这里其实可以改进,HostUpdateResult.removed 在目前这个情况之下,设定过细了。

我们接下来介绍异常处理逻辑,具体围绕着 State 来介绍。对于State,我们先回忆一下其在恢复训练时候的逻辑。

重新训练时候,会抛出两种异常:

如果是 ring allreduce 相关,就转为抛出异常 HorovodInternalError(e)。

如果当驱动进程通过节点发现脚本发现一个节点被标记为新增或者移除时,会抛出异常 HostsUpdatedInterrupt。

然后会进行如下处理:

逻辑如下:

因为这里涉及了大量的state操作,所以我们接下来要看看 TorchState:

首先,我们要看看 TorchState 如何使用。当调用时候,使用如下方法来生成一个TorchState:

其次,我们看看 TorchState 的定义,这里的 sync,restore,reset方法就在恢复训练中被调用。

在初始化函数 <code>__init__</code> 之中,会设置 handler,以我们的调用为例,就是 train_sampler,val_sampler这两个对应的sampler会配置对应的handler,即SamplerStateHandler。

TorchState 继承了 ObjectState,ObjectState 继承了 State,所以前面提到的 commit 代码中的 self.save(),就会调用到TorchState.save,而这里又会调用到 SamplerStateHandler.save。

基类代码中有:

上节中,我们可以看到,无论是reset,还是restore,都会调用到 _handlers 来进行处理,所以我们需要进一步分析。

首先就是如何设置handler。具体参见如下代码,主要是通过一个全局配置 _handler_registry 来指定哪个 handler 处理哪种类型实例,比如这里有 <code>(ElasticSampler, SamplerStateHandler)</code>,就代表着 SamplerStateHandler 是用来处理 ElasticSampler的 handler。

既然知道了 ElasticSampler 由 SamplerStaeHandler 处理,就来分析一下 SamplerStateHandler。

初始化之后,self.value 就是 sampler,针对我们之前的分析,就是ElasticSampler。

SamplerStateHandler 具体代码是,这里需要注意的是:初始化时候,会把ElasticSampler的状态保存起来,以后如果出错,会用此来恢复。

同时,save 也会被调用,用来恢复,我们马上就会分析。

SamplerStateHandler 的 基类是:

我们拓展一下save相关操作序列。

TorchState 继承了 ObjectState,ObjectState 继承了 State,所以:

前面提到的 commit 代码中的 self.save(),就会调用到TorchState.save。

而TorchState.save又会调用到 SamplerStateHandler.save。

SamplerStateHandler.save 会保存 ElasticSampler 的属性和数据,就是保存了 ElasticSampler 的 epoch 和 processed_indices。

这样,在定期 commit 的时候,就定期保存了模型的状态和 ElasticSampler 的状态,这些会在恢复训练中用到。具体下图所示:

只看静态定义,还是很难理解,需要分析动态流程。因为有两种异常,所以我们分开剖析。

回忆一下两个异常最大的区别:

如果当驱动进程通过节点发现脚本发现一个节点被标记为新增或者移除时,会抛出异常 HostsUpdatedInterrupt。此时不是关键异常,因此可以继续训练本epoch,只是从后续训练数据中,移除本epoch已经处理的数据。因此可以做到 参数状态不会还原到上次commit,而是从当前实时参数中恢复。

下面代码之中,我们只保留 HostsUpdatedInterrupt 相关代码。

发生异常之后,

1)HostsUpdatedInterrupt 表示本 epoch 需要继续训练,所以进行异常处理,其中只是会:

1.1) 记录本异常处理是否需要同步 :skip_sync = e.skip_sync。

2)这个步骤主要是重启 hvd,对worker数目进行更改。具体是调用 State 自身的 reset() 方法(代码位于<code>horovod/torch/elastic/__init__.py</code>),其中会:

2.1) 调用 shutdown() 来结束本次任务。

2.2) 调用 init(),从而调用_basics.init,最终重新建立 MPI 相关 context,所以 hvd.size() 就根据最新的worker数目进行了更改。后续 <code>ElasticSampler.__iter__</code> 之中会相应修改num_replicas。

3)这个步骤是把已经训练完的数据移除,得到的数据都是没有经过训练的。如果需要同步,则会调用 state.sync() ,其会调用 SamplerStateHandler.sync 方法,其内部会:

3.1) SamplerStateHandler会利用集合通信从所有worker中收集processed_indices,赋予给 world_processed_indices,这就是所有workers 已经处理过的数据 index。

3.2) 调用 ElasticSampler.state_dict方法,得到本地 ElasticSampler.epoch 和 ElasticSampler.processed_indices 的引用。然后将 world_processed_indices 赋值给 state_dict['processed_indices'],这样,本地 ElasticSampler.processed_indices 就是所有workers 已经处理过的数据 index。

3.3) <code>self.value.load_state_dict(broadcast_object(state_dict))</code> 有两步操作:

广播,这样在同步之后,所有worker都有同样的 state_dict['processed_indices'] 数据了。

load_state_dict 会再调用一次 ElasticSampler.reset,此次 reset 会更改 <code>num_replicas</code>,也会从总数据中去除<code>processed_indices</code>,得到新的 <code>remaining_indices</code>, 从而 后续 <code>__iter__</code> 之中,就会相应对提取index 的策略进行相应更改。

4)所以这样就把已经训练完的数据移除,所以得到的 remaining_indices 数据都是没有经过训练的。所以重新训练时候,本epoch之内,不会用已经训练的数据再次重复训练,而是从当前实时参数中恢复。

重新训练会调用 return func(state, *args, **kwargs) 进行训练,这里会处理 <code>ElasticSampler.__iter__</code> 。

具体逻辑如下:

手机如下:

[源码解析] 深度学习分布式训练框架 horovod (21) --- 之如何恢复训练

如果是 ring allreduce 相关,就转为抛出异常 HorovodInternalError(e)。HorovodInternalError 是关键异常,此时本 epoch 现有状态其实意义不大,应该利用最新一次commit中恢复所有状态。

下面代码之中,我们只保留 HorovodInternalError 相关代码。

HorovodInternalError 和 HostsUpdatedInterrupt 的代码路径几乎一样,只是多了一步 state.restore() 。

这里为啥也要走查看节点变化这个代码路径呢?因为Horovod是定期检查节点变化,所以可能产生HorovodInternalError时候,也有节点变化了,只是还没有发现而已,所以可以一并处理了。

具体逻辑为:

1)HorovodInternalError 表示本 epoch 需要恢复训练,所以先进行异常处理:

1.1)state.restore() 会调用 SamplerStateHandler.restore(这里是与HostsUpdatedInterrupt处理差异之处)。

进而调用 ElasticSampler.load_state_dict方法,会用在<code>SamplerStateHandler.__init__</code> 或者<code>SamplerStateHandler.save</code> 之中原始保存的数据来恢复 ElasticSampler。保存的数据就是 processed_indices 和 epoch。

ElasticSampler.load_state_dict方法 进而会调用 ElasticSampler.reset方法,使用 processed_indices 把已经训练完的数据移除,最新得到的 remaining_indices 数据都是没有经过训练的(针对上次保存的 processed_indices 来说)。

1.2) 记录本异常处理需要同步 : skip_sync = False。

2)这个步骤主要是重启 hvd。调用 State 自身的 reset() 方法(代码位于<code>horovod/torch/elastic/__init__.py</code>),其中会:

2.2) 调用 init(),从而调用_basics.init,最终重新建立 MPI 相关 context。

3)这个步骤是把已经训练完的数据移除,得到的数据都是没有经过训练的。因为这里需要同步,所以会调用 state.sync() ,其会调用 SamplerStateHandler.sync 方法,其内部会:

3.1) SamplerStateHandler会利用集合通信从所有worker中收集processed_indices,赋予给 world_processed_indices,这就是所有workers 已经处理过的数据 index。需要注意的是:因为是使用在<code>__init__</code> 或者 <code>save</code>之中原始保存的数据来恢复,所以其实这一步是恢复到上次commit状态。

3.3) 这里 <code>self.value.load_state_dict(broadcast_object(state_dict))</code> 有两步操作:

4)这样就是恢复到epoch 上次 commit 的状态进行训练。

具体逻辑如下图:

[源码解析] 深度学习分布式训练框架 horovod (21) --- 之如何恢复训练

到目前为止,我们还有一个问题没有仔细分析,就是何时调用 <code>ElasticSampler.__iter__</code>

我们仔细梳理一下:

以下是弹性训练总体逻辑:

弹性逻辑使用注解来封装了full_train,所以 func 就是 full_train。

我们看看 train 的主要代码:

所以我们可以理出来总体逻辑:

当出错恢复时候,train 会再次被调用,调用时候就会使用 enumerate(train_loader)调用到 <code>ElasticSampler.__iter__</code>。

num_replicas 在之前 reset 时候已经被设置,所以此时就是根据新的 world size 和 remaining_indices 重新确定提取数据的策略。

具体逻辑如下,其中

1)在 reset 之中设置了num_replicas。

2)在 <code>ElasticSampler.__iter__</code> 之中根据新的 world size 和 remaining_indices 重新确定提取数据的策略。

[源码解析] 深度学习分布式训练框架 horovod (21) --- 之如何恢复训练

至此,弹性训练如何恢复就分析完毕,以后可能结合 Pytorch 分布式 optimizer 来继续分析。

PyTorch 中文手册(2)-自动求导

pytorch中优化器optimizer.param_groups

PyTorch学习笔记6--案例2:PyTorch神经网络(MNIST CNN)

https://github.com/chenyuntc/pytorch-book