下面介绍一些分布式通信原语

一对多

  • Broadcast(广播,总共有p张卡,把一张卡上的数据发送到其他p-1卡上)
  • Scatter(打散,把一张卡上的数据分成p份,将p-1份数据分别发送到其他p-1张卡上)

多对一

  • Reduce(将所有p张卡上的数据相加(或者相乘等满足交换律和结合律的操作))
  • Gather(收集,反向scatter)

多对多

  • All Reduce
    • Reduce + Broadcast
    • Reduce-Scatter + All-gather
  • All Gather=Gather + Broadcast
  • Reduce Scatter(数据x卡的矩阵上来看是对行求和+乘对角矩阵)
    • 是All Reduce但是将每个机器只有一部分Reduce后的数据,而不是所有机器有全部数据。因此从这个意义上来说是Scatter
    • Ring all reduce=Ring Reduce scatter + Ring all gather(这里的reduce scatter和all gather都是通过环的方式做
  • All to All(从数据x卡的矩阵上来看就是转置)

ring all-reduce的通信量

实现时分为两步,ring all reduce = ring Reduce-Scatter + ring All-gather
假设单张卡上总数据量为V,有p张卡。卡之间连成一个环,并且只能顺时针单向通信。

第一步ring reduce-scatter将每个设备上的数据V分成p份,每份数据大小为V/p,分别reduce到各个设备上。

过程如下图,图中p=4
每一次通信,每个设备都向相邻的设备发送V/p的数据,令带宽为,那么如果发送数据量够大不考虑时延的话,通信时间为
一共需要p-1次通信,总通信时间为

做完ring reduce-scatter后,每张卡上都有一部分reduce后的正确结果。需要分别将大小为V/p的正确结果发送到其他设备上。
第二步ring all-gather的过程如下

通信时间和ring reduce-scatter一样
所以ring all-reduce的时间为

理论上来说,要想完成reduce-scatter,每张卡需要向其他p-1张卡发送自己的大小为V/p的数据。(同时接收其他p-1张卡的V/p的数据)因此理论上最小的时间为 .

基本上只和V和相关。随着p不断增加,每个设备上需要reduce的数据分片V/p也会变小,每次传输的数据分片也会变小,在这种情况下,传输时延可能会成为主要的问题。上面的通信时间分析因为忽略了时延,因此也就不成立了

显存占用:每个设备上依然需要存大小为V的数据

代码实现

我去,真挺好。有些不应该分类在这里,但是先放在这里再说了

Pytorch - 分布式训练极简体验 - 颜挺帅的文章 - 知乎

Pytorch - 分布式通信原语(附源码) - 颜挺帅的文章 - 知乎

Pytorch - 手写allreduce分布式训练(附源码) - 颜挺帅的文章 - 知乎

Pytorch - 算子间并行极简实现(附源码) - 颜挺帅的文章 - 知乎

Pytorch - 多机多卡极简实现(附源码) - 颜挺帅的文章 - 知乎


微信公众平台