PyTorch – How Elastic Training Works
PyTorch 1.9.0 introduced torchrun as a replacement for the torch.distributed.launch used in earlier versions. On top of the functionality of torch.distributed.launch, torchrun adds two main features:
Failover: when a worker fails during training, all workers are automatically restarted and training continues;
Elastic: nodes can be dynamically added to or removed from the job.
Elastic training code follows essentially the same structure as regular DDP code; you only need to add the following two things on top of a DDP script:
Checkpoint handling: every time a node is added or removed, all workers are killed and then restarted to continue training. The training code therefore has to save its training state, so that after a restart it can resume from where it left off.
Hyperparameter adjustment: a change in the number of nodes changes the global batch size, so the learning rate usually needs a corresponding adjustment to keep model quality unaffected (see the sketch after this list).
The full code is at the bottom of Part 2.
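To make the hyperparameter point concrete, one common convention (not something torchrun enforces) is to scale the learning rate linearly with the current world size, since the global batch size grows with the number of workers. A minimal sketch, where base_lr and base_world_size are hypothetical reference values, not part of the original example:

```python
import os

import torch.optim as optim


def build_optimizer(model, base_lr=0.001, base_world_size=4):
    """Linear LR scaling sketch: when the elastic job restarts with a different
    number of workers, the global batch size changes proportionally, so the
    learning rate is scaled by the same factor. base_lr and base_world_size
    are hypothetical reference values for illustration only."""
    world_size = int(os.environ.get("WORLD_SIZE", "1"))  # set by torchrun for each worker
    scaled_lr = base_lr * world_size / base_world_size
    return optim.SGD(model.parameters(), lr=scaled_lr)
```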
Once the elastic training code is written, we can launch the elastic training job with torchrun:
--nnodes=1:3: the training job accepts a minimum of 1 and a maximum of 3 nodes in the distributed run;
--nproc_per_node=4: each node runs 4 processes;
--max_restarts=3: the maximum number of restarts of the worker group; note that node failure, node scale-down and node scale-up all trigger a restart;
--rdzv_id=1: a unique job id, shared by all nodes;
--rdzv_backend: the rendezvous backend implementation; c10d and etcd are supported out of the box; rendezvous handles communication and coordination between the nodes;
--rdzv_endpoint: the rendezvous address, which should be the host IP and port of one of the nodes.
```shell
torchrun \
    --nnodes=1:3 \
    --nproc_per_node=4 \
    --max_restarts=3 \
    --rdzv_id=1 \
    --rdzv_backend=c10d \
    --rdzv_endpoint="192.0.0.1:1234" \
    train_elastic.py
```

3 Overall Architecture
The architecture of elastic scheduling is shown in the figure above; the key role is the elastic agent. Every node runs an elastic agent process that manages all workers on that node.
When we launch an elastic training job with the torchrun command:
First, the elastic agent triggers the rendezvous flow; rendezvous coordinates and synchronizes all elastic agents, and the call blocks until at least min elastic agents have joined;
Then, the elastic agent starts all workers on the current node;
Finally, the elastic agent monitors the state of all workers on the current node and handles them accordingly (for example, restarting a worker).
4 Elastic Agent
In this section we look at the implementation of the Elastic Agent in detail. In the PyTorch code the Elastic Agent consists of the following classes:
ElasticAgent is the abstract base class;
SimpleElasticAgent provides a more complete agent interface and implements part of it;
LocalElasticAgent implements the remaining interfaces.
The call flow of the Elastic Agent in the code is as follows (a condensed sketch follows this list):
torch.distributed.launcher.api:launch_agent() is the entry point of the elastic training logic;
First, it builds a RendezvousParameters object describing the parameters needed for the rendezvous call, e.g. min_nodes/max_nodes/endpoint;
Then, it builds a WorkerSpec describing the workers to be started on the current node, e.g. max_restarts/entrypoint;
Next, it constructs a LocalElasticAgent object;
Finally, it calls LocalElasticAgent's run() interface to start the workers on the current node and begin elastic training.
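The flow above can be condensed into the following sketch. It paraphrases torch.distributed.launcher.api.launch_agent() as of PyTorch 1.11 with most parameters and all error handling omitted; treat the abbreviated constructor arguments as an illustration rather than the exact API surface.

```python
from torch.distributed.elastic.agent.server.api import WorkerSpec
from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent
from torch.distributed.elastic.rendezvous import RendezvousParameters
from torch.distributed.elastic.rendezvous.registry import get_rendezvous_handler


def launch_agent_sketch(entrypoint, args, nproc_per_node=4):
    # 1) Describe the rendezvous: min/max nodes, backend, endpoint and job id.
    rdzv_parameters = RendezvousParameters(
        backend="c10d",
        endpoint="192.0.0.1:1234",
        run_id="1",
        min_nodes=1,
        max_nodes=3,
    )
    # 2) Describe the workers to start on this node.
    spec = WorkerSpec(
        role="default",
        local_world_size=nproc_per_node,
        entrypoint=entrypoint,
        args=tuple(args),
        rdzv_handler=get_rendezvous_handler(rdzv_parameters),
        max_restarts=3,
        monitor_interval=5,
    )
    # 3) Build the agent that manages the workers on this node.
    agent = LocalElasticAgent(spec=spec, start_method="spawn")
    # 4) run(): rendezvous, start the workers, then monitor them until done.
    return agent.run()
```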
The run() interface of the Elastic Agent consists of two main pieces of logic:
SimpleElasticAgent._initialize_workers(): first calls _rendezvous() to wait until at least min nodes have joined, then calls _start_workers() to start the worker processes on the current node;
a monitoring while loop that watches the state of the processes started in the previous step and reacts as follows:
if the process group state is SUCCEEDED: call _exit_barrier() to wait for the agents on all nodes to respond, then exit;
if the process group state is UNHEALTHY or FAILED: if the retry count is below _remaining_restarts, restart all workers; otherwise stop all workers and exit;
if the process group state is HEALTHY: check whether any node is waiting to join; if so, restart the workers (note: restarting workers means stopping all workers first and then calling _initialize_workers() again).
A condensed sketch of this loop follows.
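The sketch below is a condensed reading of SimpleElasticAgent._invoke_run() in PyTorch 1.11, not the verbatim source; logging, metrics and several edge cases are omitted.

```python
import time

from torch.distributed.elastic.agent.server.api import WorkerState


def invoke_run_sketch(agent, worker_group, monitor_interval=5):
    # Rendezvous with the other agents, then start the local worker processes.
    agent._initialize_workers(worker_group)
    while True:
        time.sleep(monitor_interval)
        run_result = agent._monitor_workers(worker_group)
        state = run_result.state
        if state == WorkerState.SUCCEEDED:
            agent._exit_barrier()  # wait for the agents on the other nodes
            return run_result
        elif state in (WorkerState.UNHEALTHY, WorkerState.FAILED):
            if agent._remaining_restarts > 0:
                agent._remaining_restarts -= 1
                agent._restart_workers(worker_group)  # stop all, then _initialize_workers
            else:
                agent._stop_workers(worker_group)
                return run_result
        elif state == WorkerState.HEALTHY:
            # A new node is waiting to join: trigger a new rendezvous round.
            if worker_group.spec.rdzv_handler.num_nodes_waiting() > 0:
                agent._restart_workers(worker_group)
```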
5 Rendezvous
5.1 Basic Concepts
PyTorch's Rendezvous implementation involves quite a few concepts; we first introduce them one by one, which will make the implementation much easier to follow.
The first is _RendezvousState. Every ElasticAgent stores a copy of _RendezvousState and synchronizes it with the others when necessary (a condensed sketch follows the field list). _RendezvousState stores the following:
round: The current round of the rendezvous.
complete: A boolean value indicating whether the current round of the rendezvous is complete.
deadline: The time at which the current round of the rendezvous will be considered complete if it is still waiting for nodes to join.
closed: A boolean value indicating whether the rendezvous is closed.
participants: A dictionary of the participants and their corresponding ranks.
wait_list: A set of nodes that are waiting to participate in the next round of the rendezvous.
last_heartbeats: A dictionary containing each node's last heartbeat time.
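Condensed into a data structure, the state looks roughly as follows. This is a simplified paraphrase of _RendezvousState in torch.distributed.elastic.rendezvous.dynamic_rendezvous; the real class keys participants and heartbeats by a node descriptor object, approximated here by plain strings.

```python
from dataclasses import dataclass, field
from datetime import datetime
from typing import Dict, Optional, Set


@dataclass
class RendezvousStateSketch:
    """Each ElasticAgent keeps a copy of this state and syncs it with the
    other agents through the rendezvous backend (simplified paraphrase)."""
    round: int = 0                               # current rendezvous round
    complete: bool = False                       # is the current round complete?
    deadline: Optional[datetime] = None          # when a still-forming round times out
    closed: bool = False                         # has the rendezvous been closed?
    participants: Dict[str, int] = field(default_factory=dict)          # node -> rank
    wait_list: Set[str] = field(default_factory=set)                    # nodes waiting for the next round
    last_heartbeats: Dict[str, datetime] = field(default_factory=dict)  # node -> last heartbeat time
```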
So how is _RendezvousState synchronized across all ElasticAgents? For this, PyTorch introduces the concept of a Store. There are three store types, TCPStore, FileStore and HashStore; in the elastic training scenario TCPStore is used by default.
Typical TCPStore usage looks like this:
It is a classic server-client setup: we start a server on process 1 and a client on process 2, and data can be written and read through the TCPStore set and get interfaces.
In the Rendezvous implementation, it is exactly this TCPStore that is used to set and get the _RendezvousState.
```python
import torch.distributed as dist
from datetime import timedelta

# Run on process 1 (server)
server_store = dist.TCPStore("127.0.0.1", 1234, 2, True, timedelta(seconds=30))

# Run on process 2 (client)
client_store = dist.TCPStore("127.0.0.1", 1234, 2, False)

# Use any of the store methods from either the client or server after initialization
server_store.set("first_key", "first_value")
client_store.get("first_key")
```

In PyTorch's Rendezvous implementation, C10dRendezvousBackend wraps the TCPStore and exposes set_state and get_state interfaces that make state manipulation convenient. (Note: PyTorch also provides EtcdRendezvousBackend, which synchronizes the _RendezvousState through etcd.)
The main implementation of C10dRendezvousBackend is shown below; it is easy to see that get_state and set_state are simply calls to the store interface.
```python
class C10dRendezvousBackend(RendezvousBackend):
    def get_state(self) -> Optional[Tuple[bytes, Token]]:
        """See base class."""
        base64_state: bytes = self._call_store("get", self._key)
        return self._decode_state(base64_state)

    def set_state(
        self, state: bytes, token: Optional[Token] = None
    ) -> Optional[Tuple[bytes, Token, bool]]:
        """See base class."""
        base64_state_str: str = b64encode(state).decode()

        if token:
            # Shortcut if we know for sure that the token is not valid.
            if not isinstance(token, bytes):
                result = self.get_state()
                if result is not None:
                    tmp = *result, False
                    # Python 3.6 does not support tuple unpacking in return
                    # statements.
                    return tmp
                return None

            token = token.decode()
        else:
            token = self._NULL_SENTINEL

        base64_state: bytes = self._call_store("compare_set", self._key, token, base64_state_str)

        state_token_pair = self._decode_state(base64_state)
        if state_token_pair is None:
            return None

        new_state, new_token = state_token_pair

        # The C10d Store's compare_set method does not offer an easy way to find out
        # whether our write attempt was successful. As a brute-force solution we
        # perform a bitwise comparison of our local state and the remote state.
        return new_state, new_token, new_state == state

    def _call_store(self, store_op: str, *args, **kwargs) -> Any:
        try:
            return getattr(self._store, store_op)(*args, **kwargs)
        except (ValueError, RuntimeError, TimeoutError) as exc:
            raise RendezvousConnectionError(
                "The connection to the C10d store has failed. See inner exception for details."
            ) from exc
```

On top of RendezvousBackend, PyTorch introduces a higher-level concept, **_RendezvousStateHolder**, which provides interfaces for fetching, syncing and marking updates to the _RendezvousState; these interfaces are all implemented by calling the backend's set_state and get_state.
_RendezvousStateHolder is defined as follows:
```python
class _RendezvousStateHolder(ABC):
    """Holds the shared rendezvous state synced with other nodes."""

    def state(self) -> _RendezvousState:
        """Gets the local state."""

    def sync(self) -> Optional[bool]:
        """Reads or writes the latest state.

        Returns:
            A boolean value indicating whether the local state, in case marked
            as dirty, was successfully synced with other nodes.
        """

    def mark_dirty(self) -> None:
        """Marks the local state as dirty."""
```

With that, the basic Rendezvous machinery is in place: the state lives in _RendezvousState and is synchronized through _RendezvousStateHolder. One piece is still missing: how the rendezvous state actually gets changed. That is done jointly by the _RendezvousXXXOp classes and the _RendezvousOpExecutor.
PyTorch provides _RendezvousExitOp, _RendezvousJoinOp, _RendezvousCloseOp and _RendezvousKeepAliveOp, corresponding to an ElasticAgent exiting, joining, closing the rendezvous and keeping its heartbeat alive. Each OP looks at its own type and the current contents of _RendezvousState and returns an action; the _RendezvousOpExecutor then executes that action.
For example, _RendezvousExitOp corresponds to an ElasticAgent exiting:
if the current node is still in the participants list, it returns REMOVE_FROM_PARTICIPANTS; on receiving this action the _RendezvousOpExecutor runs the _remove_from_participants logic;
if the current node is not in the participants list, it returns FINISH, for which the _RendezvousOpExecutor does nothing.
```python
class _RendezvousExitOp:
    """Represents a rendezvous exit operation."""

    def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action:
        if ctx.node in ctx.state.participants:
            if time.monotonic() > deadline:
                return _Action.ERROR_TIMEOUT
            return _Action.REMOVE_FROM_PARTICIPANTS
        return _Action.FINISH
```

The core interface of _DistributedRendezvousOpExecutor is as follows:
run() is the single entry point for executing a rendezvous OP;
the other methods implement the actions that the OPs return. These actions essentially all modify the contents of _RendezvousState; for example, _mark_rendezvous_closed sets the closed field of _RendezvousState to True.
```python
class _DistributedRendezvousOpExecutor:
    # (method bodies omitted; only the interface is shown)
    def run(
        self,
        state_handler: Callable[[_RendezvousContext, float], _Action],
        deadline: float,
    ) -> None: ...

    def _keep_alive(self) -> None: ...

    def _add_to_participants(self) -> None: ...

    def _add_to_wait_list(self) -> None: ...

    def _remove_from_participants(self) -> None: ...

    def _remove_from_wait_list(self) -> None: ...

    def _mark_rendezvous_complete(self) -> None: ...

    def _mark_rendezvous_closed(self) -> None:
        self._state.closed = True
```

The last concept to introduce is RendezvousHandler, the topmost external interface of the Rendezvous system; the ElasticAgent uses it to coordinate across all nodes. PyTorch provides three implementations, DynamicRendezvousHandler, EtcdRendezvousHandler and StaticTCPRendezvous; here we only look at DynamicRendezvousHandler.
The core interface of RendezvousHandler is next_rendezvous(); the ElasticAgent calls it to wait until at least min nodes have joined. Its implementation is covered later.
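For orientation, in PyTorch 1.11 next_rendezvous() returns a (store, rank, world_size) triple that SimpleElasticAgent._rendezvous() consumes; a minimal hedged sketch of that call site:

```python
def rendezvous_call_sketch(spec):
    # In PyTorch 1.11, next_rendezvous() blocks until at least min nodes have
    # joined and then returns the shared store plus this agent's group rank
    # and the group world size (i.e. the number of participating nodes).
    store, group_rank, group_world_size = spec.rdzv_handler.next_rendezvous()
    print(f"joined rendezvous as agent {group_rank} of {group_world_size}")
    # The store is then used to exchange per-node worker counts and to assign
    # global ranks to the local workers (details omitted here).
    return store, group_rank, group_world_size
```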
The concepts introduced above can be described with the relation diagram below.
5.2 Implementation Logic
Now that we are familiar with the basic Rendezvous concepts, we can look at the implementation logic.
First, let's look at DynamicRendezvousHandler.next_rendezvous() (note: the ElasticAgent coordinates across nodes by calling this interface). DynamicRendezvousHandler.next_rendezvous() consists of five steps (a sketch follows this list):
DynamicRendezvousHandler._stop_heartbeats(): stops the heartbeat to the TCPStore, by calling the cancel() interface of the _PeriodicTimer;
Execute the Exit OP: the exit logic; if the current node is already in participants, it is first removed from the participants list of _RendezvousState;
Execute the Join OP (the figure below only shows the common case; the source handles some additional special cases):
add the node itself to the participants list of _RendezvousState;
send heartbeats to the TCPStore and wait for at least min nodes to join;
once the number of participants in _RendezvousState reaches min, mark the rendezvous as complete;
at this point the Join OP is done and returns a FINISH action to the _RendezvousOpExecutor;
DynamicRendezvousHandler._start_heartbeats(): starts the heartbeat, implemented by having the _PeriodicTimer periodically execute a _RendezvousKeepAliveOp, which in turn updates last_heartbeats in _RendezvousState;
DynamicRendezvousHandler._get_world(): reads the current rank and world_size from _RendezvousState.
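Put together, the five steps can be paraphrased as the sketch below. It is a condensed reading of DynamicRendezvousHandler.next_rendezvous() in PyTorch 1.11; the private names (_op_executor, _RendezvousExitOp, _RendezvousJoinOp) follow that version, and closing/timeout handling and the shared-store setup are left out.

```python
import time

from torch.distributed.elastic.rendezvous.dynamic_rendezvous import (
    _RendezvousExitOp,
    _RendezvousJoinOp,
)


def next_rendezvous_sketch(handler, join_timeout=600.0):
    handler._stop_heartbeats()                   # 1) cancel the keep-alive timer
    deadline = time.monotonic() + join_timeout
    handler._op_executor.run(_RendezvousExitOp(), deadline)  # 2) leave the previous round, if any
    handler._op_executor.run(_RendezvousJoinOp(), deadline)  # 3) join: add ourselves to participants
                                                             #    and wait for at least min nodes
    handler._start_heartbeats()                  # 4) periodic _RendezvousKeepAliveOp updates
                                                 #    last_heartbeats in the shared state
    rank, world_size = handler._get_world()      # 5) read our rank / world size from the state
    return rank, world_size
```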
Next, let's see how the Rendezvous OPs are executed. As mentioned above, OPs all go through the _DistributedRendezvousOpExecutor.run() interface (a sketch follows this list):
the main flow is wrapped in a while loop, which only exits when the OP's action is FINISH;
first, _BackendRendezvousStateHolder.sync() is called to synchronize the _RendezvousState across all nodes;
if the current node has local changes, C10dRendezvousBackend.set_state() is called to publish them; otherwise C10dRendezvousBackend.get_state() fetches the latest state;
if a newer state was fetched, the state stored on the current node is updated;
then the OP to be executed is invoked; it returns an ACTION, and the _DistributedRendezvousOpExecutor performs keep_alive / add_to_participants / add_to_wait_list etc. according to that ACTION.
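The loop described above can be sketched as follows. This paraphrases _DistributedRendezvousOpExecutor.run() in PyTorch 1.11; the action names follow the _Action enum in torch.distributed.elastic.rendezvous.dynamic_rendezvous, while the dirty-state bookkeeping and the remaining actions (SYNC, ERROR_*, wait-list removal, close) are simplified away.

```python
from torch.distributed.elastic.rendezvous.dynamic_rendezvous import _Action, _RendezvousContext


def op_executor_run_sketch(executor, op, deadline):
    action = None
    while action != _Action.FINISH:
        # 1) Sync the local _RendezvousState with the other nodes through the backend:
        #    set_state() if we have local changes, otherwise get_state().
        executor._state_holder.sync()
        executor._state = executor._state_holder.state
        # 2) Ask the OP what to do given the freshly synced state.
        ctx = _RendezvousContext(executor._node, executor._state, executor._settings)
        action = op(ctx, deadline)
        # 3) Apply the action by mutating the shared state, then mark it dirty so
        #    the next sync() publishes the change (simplified).
        if action == _Action.KEEP_ALIVE:
            executor._keep_alive()
        elif action == _Action.ADD_TO_PARTICIPANTS:
            executor._add_to_participants()
        elif action == _Action.ADD_TO_WAIT_LIST:
            executor._add_to_wait_list()
        elif action == _Action.REMOVE_FROM_PARTICIPANTS:
            executor._remove_from_participants()
        elif action == _Action.MARK_RENDEZVOUS_COMPLETE:
            executor._mark_rendezvous_complete()
        if action != _Action.FINISH:
            executor._state_holder.mark_dirty()
```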
6 Failover
Failover covers two situations:
the ElasticAgent process is healthy, but a worker process fails;
the ElasticAgent process itself exits abnormally.
6.1 Worker Fail
In the worker-fail case, the abnormal state of the worker process is caught by the ElasticAgent; the logic lives in SimpleElasticAgent._invoke_run().
This interface keeps monitoring the state of all worker processes on the current node in a loop; when a process misbehaves, it enters the UNHEALTHY/FAILED handling path.
If the current retry count is below _remaining_restarts, a worker restart is initiated.
The restart logic is also straightforward:
first stop all workers on the current node;
then go through _initialize_workers() again to do the rendezvous and start the workers.
```python
def _restart_workers(self, worker_group: WorkerGroup) -> None:
    """
    Restarts (stops, rendezvous, starts) all local workers in the group.
    """
    role = worker_group.spec.role
    log.info(f"[{role}] Stopping worker group")
    self._stop_workers(worker_group)
    worker_group.state = WorkerState.STOPPED
    self._initialize_workers(worker_group)
```

6.2 ElasticAgent Fail
First, let's see how elastic training behaves when a node fails. There are two nodes, node0 and node1; they start distributed training together, and after some time we kill node1.
This is the log on node1:
```
[763] epoch 14 (rank = 4, local_rank = 0) loss = 1.2388396263122559
[765] epoch 14 (rank = 6, local_rank = 2) loss = 1.4543075561523438
[766] epoch 14 (rank = 7, local_rank = 3) loss = 1.0290627479553223
[764] epoch 14 (rank = 5, local_rank = 1) loss = 1.1143463850021362
^CTraceback (most recent call last):
Traceback (most recent call last):
  File "/opt/conda/bin/torchrun", line 33, in <module>
    sys.exit(load_entry_point('torch==1.11.0', 'console_scripts', 'torchrun')())
  File "/opt/conda/lib/python3.8/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 345, in wrapper
    return f(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/distributed/run.py", line 724, in main
    run(args)
  File "/opt/conda/lib/python3.8/site-packages/torch/distributed/run.py", line 715, in run
    elastic_launch(
  File "/opt/conda/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 131, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/opt/conda/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 236, in launch_agent
    result = agent.run()
  File "/opt/conda/lib/python3.8/site-packages/torch/distributed/elastic/metrics/api.py", line 125, in wrapper
    result = f(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/distributed/elastic/agent/server/api.py", line 709, in run
    result = self._invoke_run(role)
  File "/opt/conda/lib/python3.8/site-packages/torch/distributed/elastic/agent/server/api.py", line 850, in _invoke_run
    time.sleep(monitor_interval)
  File "/opt/conda/lib/python3.8/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 60, in _terminate_process_handler
    raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval)
torch.distributed.elastic.multiprocessing.api.SignalException: Process 759 got signal: 2
```

Below is the log on node0, from which we can draw the following conclusions:
when an Elastic Agent exits, the processes managed by the other, surviving Elastic Agents fail as well, because the remaining processes can no longer perform collective communication normally;
the surviving Elastic Agents restart their local workers following the UNHEALTHY/FAILED handling logic. If the failed Elastic Agent does not come back, the remaining Elastic Agents form a new worker group and continue training; if the failed Elastic Agent is restarted (for example via the job restart mechanism in Kubernetes), it rejoins the training job.
```
# 1) node0 and node1 are training together
...
[11762] epoch 14 (rank = 2, local_rank = 2) loss = 1.1763713359832764
[11760] epoch 14 (rank = 0, local_rank = 0) loss = 1.324049949645996

# 2) node1 is killed, so the next collective communication raises an error
[E ProcessGroupNCCL.cpp:406] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data. To avoid this inconsistency, we are taking the entire process down.
terminate called after throwing an instance of 'std::runtime_error'
  what(): NCCL error: unhandled system error, NCCL version 21.0.3
ncclSystemError: System call (socket, malloc, munmap, etc) failed.

# 3) stop the other three processes
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 11761 closing signal SIGTERM
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 11762 closing signal SIGTERM
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 11763 closing signal SIGTERM
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: -6) local_rank: 0 (pid: 11760) of binary: /opt/conda/bin/python

# 4) go through _initialize_workers again
[11828] Initializing process group with: {MASTER_ADDR: iZ2ze9q3ftqtxtqlkrk6tuZ, MASTER_PORT: 40539, WORLD_SIZE: 4, LOCAL_WORLD_SIZE: 4}
[11825] Initializing process group with: {MASTER_ADDR: iZ2ze9q3ftqtxtqlkrk6tuZ, MASTER_PORT: 40539, WORLD_SIZE: 4, LOCAL_WORLD_SIZE: 4}
[11826] Initializing process group with: {MASTER_ADDR: iZ2ze9q3ftqtxtqlkrk6tuZ, MASTER_PORT: 40539, WORLD_SIZE: 4, LOCAL_WORLD_SIZE: 4}
[11827] Initializing process group with: {MASTER_ADDR: iZ2ze9q3ftqtxtqlkrk6tuZ, MASTER_PORT: 40539, WORLD_SIZE: 4, LOCAL_WORLD_SIZE: 4}
[11827] (rank = 2, local_rank = 2) train worker starting...
[11828] (rank = 3, local_rank = 3) train worker starting...
[11825] (rank = 0, local_rank = 0) train worker starting...
[11826] (rank = 1, local_rank = 1) train worker starting...

# 5) node0 continues distributed training on its own
load checkpoint from checkpoint.pt
load checkpoint from checkpoint.pt
load checkpoint from checkpoint.pt
load checkpoint from checkpoint.pt
[11826] epoch 14 (rank = 1, local_rank = 1) loss = 0.839302122592926
[11828] epoch 14 (rank = 3, local_rank = 3) loss = 0.8971960544586182
[11825] epoch 14 (rank = 0, local_rank = 0) loss = 1.3382269144058228
```

7 Scale Up/Down
Scale-down can be understood as the scenario above where an Elastic Agent exits and is not restarted, so we won't repeat it here.
Scale-up deserves a closer look; the scale-up flow can still be described with the figure above:
when a new node wants to join, it cannot enter the rendezvous that the current elastic job has already established, so the node is added to the wait_list of _RendezvousState;
when the ElasticAgent and its worker processes are all running normally, the monitor returns a HEALTHY state; the ElasticAgent then checks the wait_list in _RendezvousState, and if it is non-empty it triggers a worker restart to start a new rendezvous round that includes the waiting node; this is how the new node joins the worker group.
Part 2: Code
The famous physicist and Nobel laureate Richard Feynman had written on his office blackboard: "What I cannot create, I do not understand." Programmers have a similar slogan: "show me the code." So I plan to write a series of articles on distributed training that turn the usual abstract concepts into code, making every example executable, verifiable and reproducible, and to share the source so we can learn from each other.
After some research I found that PyTorch offers very good abstractions and complete interfaces for distributed training, so this series uses PyTorch as the main framework; many of the examples come from the PyTorch documentation and have been debugged and extended on that basis.
Finally, since there is already plenty of theory about distributed training online, theory is not the focus of this series; the focus is on the code.
PyTorch - A minimal hands-on introduction to distributed training: https://zhuanlan.zhihu.com/p/477073906
PyTorch - Distributed communication primitives (with source code): https://zhuanlan.zhihu.com/p/478953028
PyTorch - Hand-written allreduce distributed training (with source code): https://zhuanlan.zhihu.com/p/482557067
PyTorch - A minimal implementation of inter-operator parallelism (with source code): https://zhuanlan.zhihu.com/p/483640235
PyTorch - A minimal multi-node multi-GPU implementation (with source code): https://zhuanlan.zhihu.com/p/486130584
1. Introduction
PyTorch 1.9.0 introduced torchrun as a replacement for the torch.distributed.launch used in earlier versions. On top of the functionality of torch.distributed.launch, torchrun adds two main features:
Failover: when a worker fails during training, all workers are automatically restarted and training continues;
Elastic: nodes can be dynamically added or removed; this article shows with an example how elastic training is used.
In this example we first start a worker group with 4 GPUs on Node0; after it has trained for a while, we start 4 more GPU workers on Node1, which join the workers on Node0 to form a new worker group, turning the job into a 2-node, 8-GPU distributed training run.
2. Building the Model
A simple fully connected neural network model:
```python
class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))
```

3. Checkpoint Handling
Every time a node is added or removed, all workers are killed and then restarted to continue training. The training code therefore has to save its training state, so that after a restart it can resume from where it left off.
The information that typically needs to be saved is:
model: the model parameters;
optimizer: the optimizer parameters;
epoch: which epoch the run is currently on.
The save and load code is shown below:
torch.save: serializes a Python object with pickle and writes it to a local file;
torch.load: deserializes a file written by torch.save and loads it into memory;
model.state_dict(): holds every layer of the model and its parameters;
optimizer.state_dict(): holds the optimizer's parameters.
```python
def save_checkpoint(epoch, model, optimizer, path):
    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimize_state_dict": optimizer.state_dict(),
    }, path)

def load_checkpoint(path):
    checkpoint = torch.load(path)
    return checkpoint
```

4. Training Code
The initialization logic is as follows:
Lines 1-3: print the key environment variables of the current worker, used later when presenting the results;
Lines 5-8: create the model, optimizer and loss function;
Lines 10-12: initialize the training parameters;
Lines 14-19: if a checkpoint exists, load it and restore model, optimizer and first_epoch.
```python
local_rank = int(os.environ["LOCAL_RANK"])
rank = int(os.environ["RANK"])
print(f"[{os.getpid()}] (rank = {rank}, local_rank = {local_rank}) train worker starting...")

model = ToyModel().cuda(local_rank)
ddp_model = DDP(model, [local_rank])
loss_fn = nn.MSELoss()
optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
optimizer.zero_grad()
max_epoch = 100
first_epoch = 0
ckp_path = "checkpoint.pt"

if os.path.exists(ckp_path):
    print(f"load checkpoint from {ckp_path}")
    checkpoint = load_checkpoint(ckp_path)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimize_state_dict"])
    first_epoch = checkpoint["epoch"]
```

The training logic:
Line 1: epochs run from first_epoch to max_epoch, so that a restarted worker continues from its original epoch;
Line 2: to make the effect of dynamically adding a node visible, a sleep is added to slow training down;
Lines 3-8: the model training steps;
Line 9: for simplicity, a checkpoint is saved every epoch, storing the current epoch, model and optimizer.
```python
for i in range(first_epoch, max_epoch):
    time.sleep(1)  # slow down training so the effect of dynamically adding a node is visible
    outputs = ddp_model(torch.randn(20, 10).to(local_rank))
    labels = torch.randn(20, 5).to(local_rank)
    loss = loss_fn(outputs, labels)
    loss.backward()
    print(f"[{os.getpid()}] epoch {i} (rank = {rank}, local_rank = {local_rank}) loss = {loss.item()}\n")
    optimizer.step()
    save_checkpoint(i, model, optimizer, ckp_path)
```

5. How to Launch
Since we use torchrun to start the multi-node multi-GPU job, there is no need to spawn multiple processes ourselves (torchrun launches our Python script as a process); we simply call the train function written above, wrapped with the DistributedDataParallel initialization and teardown calls.
The code below shows how the train interface above is invoked.
```python
def run():
    env_dict = {
        key: os.environ[key]
        for key in ("MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "LOCAL_WORLD_SIZE")
    }
    print(f"[{os.getpid()}] Initializing process group with: {env_dict}")
    dist.init_process_group(backend="nccl")
    train()
    dist.destroy_process_group()

if __name__ == "__main__":
    run()
```

This example uses torchrun to run the multi-node multi-GPU distributed training job (note: torch.distributed.launch has been deprecated by PyTorch, so avoid using it). The launch script is described below (note: both node0 and node1 are started with this script):
--nnodes=1:3: the training job accepts a minimum of 1 and a maximum of 3 nodes in the distributed run;
--nproc_per_node=4: each node runs 4 processes;
--max_restarts=3: the maximum number of restarts of the worker group; note that node failure, node scale-down and node scale-up all trigger a restart;
--rdzv_id=1: a unique job id, shared by all nodes;
--rdzv_backend: the rendezvous backend implementation; c10d and etcd are supported out of the box; rendezvous handles communication and coordination between the nodes;
--rdzv_endpoint: the rendezvous address, which should be the host IP and port of one of the nodes.
```shell
torchrun \
    --nnodes=1:3 \
    --nproc_per_node=4 \
    --max_restarts=3 \
    --rdzv_id=1 \
    --rdzv_backend=c10d \
    --rdzv_endpoint="192.0.0.1:1234" \
    train_elastic.py
```

6. Result Analysis
Code: BetterDL - train_elastic.py: https://github.com/tingshua-yts/BetterDL/blob/master/test/pytorch/DDP/train_elastic.py
Environment: two machines with 4 V100 GPUs each
image: pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime
gpu: v100

First, run the launch script on node0:
```shell
torchrun \
    --nnodes=1:3 \
    --nproc_per_node=4 \
    --max_restarts=3 \
    --rdzv_id=1 \
    --rdzv_backend=c10d \
    --rdzv_endpoint="192.0.0.1:1234" \
    train_elastic.py
```

This produces the following output:
Lines 2-5: the job starts as single-node, 4-GPU training, so WORLD_SIZE is 4 and LOCAL_WORLD_SIZE is also 4;
Lines 6-9: four ranks participate in the distributed training, rank0 to rank3;
Lines 10-18: rank0 to rank3 all start training from epoch 0.
```
/workspace/DDP# sh run_elastic.sh
[4031] Initializing process group with: {MASTER_ADDR: 192.0.0.1, MASTER_PORT: 44901, WORLD_SIZE: 4, LOCAL_WORLD_SIZE: 4}
[4029] Initializing process group with: {MASTER_ADDR: 192.0.0.1, MASTER_PORT: 44901, WORLD_SIZE: 4, LOCAL_WORLD_SIZE: 4}
[4030] Initializing process group with: {MASTER_ADDR: 192.0.0.1, MASTER_PORT: 44901, WORLD_SIZE: 4, LOCAL_WORLD_SIZE: 4}
[4032] Initializing process group with: {MASTER_ADDR: 192.0.0.1, MASTER_PORT: 44901, WORLD_SIZE: 4, LOCAL_WORLD_SIZE: 4}
[4029] (rank = 0, local_rank = 0) train worker starting...
[4030] (rank = 1, local_rank = 1) train worker starting...
[4032] (rank = 3, local_rank = 3) train worker starting...
[4031] (rank = 2, local_rank = 2) train worker starting...
[4101] epoch 0 (rank = 1, local_rank = 1) loss = 0.9288564920425415
[4103] epoch 0 (rank = 3, local_rank = 3) loss = 0.9711472988128662
[4102] epoch 0 (rank = 2, local_rank = 2) loss = 1.0727070569992065
[4100] epoch 0 (rank = 0, local_rank = 0) loss = 0.9402943253517151
[4100] epoch 1 (rank = 0, local_rank = 0) loss = 1.0327017307281494
[4101] epoch 1 (rank = 1, local_rank = 1) loss = 1.4485043287277222
[4103] epoch 1 (rank = 3, local_rank = 3) loss = 1.0959293842315674
[4102] epoch 1 (rank = 2, local_rank = 2) loss = 1.0669530630111694
...
```

Then run the same script on node1:
```shell
torchrun \
    --nnodes=1:3 \
    --nproc_per_node=4 \
    --max_restarts=3 \
    --rdzv_id=1 \
    --rdzv_backend=c10d \
    --rdzv_endpoint="192.0.0.1:1234" \
    train_elastic.py
```

The output on node1 is as follows:
Lines 2-5: with node1 added, the job is now 2-node, 8-GPU distributed training, so WORLD_SIZE=8 and LOCAL_WORLD_SIZE=4;
Lines 6-9: the workers on node1 get ranks rank4 to rank7;
Lines 13-20: node1 joined while node0's workers were training at epoch 35, so it resumes training from epoch 35.
```
/workspace/DDP# sh run_elastic.sh
[696] Initializing process group with: {MASTER_ADDR: 192.0.0.1, MASTER_PORT: 42913, WORLD_SIZE: 8, LOCAL_WORLD_SIZE: 4}
[697] Initializing process group with: {MASTER_ADDR: 192.0.0.1, MASTER_PORT: 42913, WORLD_SIZE: 8, LOCAL_WORLD_SIZE: 4}
[695] Initializing process group with: {MASTER_ADDR: 192.0.0.1, MASTER_PORT: 42913, WORLD_SIZE: 8, LOCAL_WORLD_SIZE: 4}
[694] Initializing process group with: {MASTER_ADDR: 192.0.0.1, MASTER_PORT: 42913, WORLD_SIZE: 8, LOCAL_WORLD_SIZE: 4}
[697] (rank = 7, local_rank = 3) train worker starting...
[695] (rank = 5, local_rank = 1) train worker starting...
[694] (rank = 4, local_rank = 0) train worker starting...
[696] (rank = 6, local_rank = 2) train worker starting...
load checkpoint from checkpoint.pt
load checkpoint from checkpoint.pt
load checkpoint from checkpoint.pt
load checkpoint from checkpoint.pt
[697] epoch 35 (rank = 7, local_rank = 3) loss = 1.1888569593429565
[694] epoch 35 (rank = 4, local_rank = 0) loss = 0.8916441202163696
[695] epoch 35 (rank = 5, local_rank = 1) loss = 1.5685604810714722
[696] epoch 35 (rank = 6, local_rank = 2) loss = 1.11683189868927
[696] epoch 36 (rank = 6, local_rank = 2) loss = 1.3724170923233032
[694] epoch 36 (rank = 4, local_rank = 0) loss = 1.061527967453003
[695] epoch 36 (rank = 5, local_rank = 1) loss = 0.96876460313797
[697] epoch 36 (rank = 7, local_rank = 3) loss = 0.8060566782951355
...
```

The output on node0 is as follows:
Lines 6-9: node0's workers were at epoch 35 when the training script was run on node1, which asked to join the job;
Lines 10-13: all workers are restarted; with node1 added, the job is now 2-node, 8-GPU distributed training, so WORLD_SIZE=8 and LOCAL_WORLD_SIZE=4;
Lines 14-17: the workers on node0 now get ranks rank0 to rank3;
Lines 18-21: the checkpoint is loaded;
Lines 22-30: training continues from the model, optimizer and epoch stored in the checkpoint.
```
...
[4100] epoch 35 (rank = 0, local_rank = 0) loss = 1.0746158361434937
[4101] epoch 35 (rank = 1, local_rank = 1) loss = 1.1712706089019775
[4103] epoch 35 (rank = 3, local_rank = 3) loss = 1.1774182319641113
[4102] epoch 35 (rank = 2, local_rank = 2) loss = 1.0898035764694214
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 4100 closing signal SIGTERM
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 4101 closing signal SIGTERM
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 4102 closing signal SIGTERM
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 4103 closing signal SIGTERM
[4164] Initializing process group with: {MASTER_ADDR: 192.0.0.1, MASTER_PORT: 42913, WORLD_SIZE: 8, LOCAL_WORLD_SIZE: 4}
[4165] Initializing process group with: {MASTER_ADDR: 192.0.0.1, MASTER_PORT: 42913, WORLD_SIZE: 8, LOCAL_WORLD_SIZE: 4}
[4162] Initializing process group with: {MASTER_ADDR: 192.0.0.1, MASTER_PORT: 42913, WORLD_SIZE: 8, LOCAL_WORLD_SIZE: 4}
[4163] Initializing process group with: {MASTER_ADDR: 192.0.0.1, MASTER_PORT: 42913, WORLD_SIZE: 8, LOCAL_WORLD_SIZE: 4}
[4162] (rank = 0, local_rank = 0) train worker starting...
[4163] (rank = 1, local_rank = 1) train worker starting...
[4164] (rank = 2, local_rank = 2) train worker starting...
[4165] (rank = 3, local_rank = 3) train worker starting...
load checkpoint from checkpoint.pt
load checkpoint from checkpoint.pt
load checkpoint from checkpoint.pt
load checkpoint from checkpoint.pt
[4165] epoch 35 (rank = 3, local_rank = 3) loss = 1.3437936305999756
[4162] epoch 35 (rank = 0, local_rank = 0) loss = 1.5693414211273193
[4163] epoch 35 (rank = 1, local_rank = 1) loss = 1.199862003326416
[4164] epoch 35 (rank = 2, local_rank = 2) loss = 1.0465545654296875
[4163] epoch 36 (rank = 1, local_rank = 1) loss = 0.9741991758346558
[4162] epoch 36 (rank = 0, local_rank = 0) loss = 1.3609280586242676
[4164] epoch 36 (rank = 2, local_rank = 2) loss = 0.9585908055305481
[4165] epoch 36 (rank = 3, local_rank = 3) loss = 0.9169824123382568
...
```