作者王益 OneFlow社区编译 翻译杨婷 最近,我在处理PyTorch分布式和TorchRec相关的工作,为此,我开始学习PyTorch2。0。在业余时间,我也在跟着Alpa作者学习JAX和XLA。如今回顾这些技术,我发现它们的关注点似乎都是如下两个问题: 包含自动求导和并行在内的函数转换,例如vmap,pmap和pjit等;异构计算,CPU负责控制流,GPUTPU负责张量计算和集合通信。 本文档中的所有例子都支持在Colab中运行: 1函数转换 函数转换意为将一个程序转变成另一个程序,最常见的例子是自动求导(autograd)。自动求导采用用户编写的前向过程并创建后向过程,对于用户来说,编写自动求导通常都太过复杂。函数转换的主要难点在于:在编写函数转换算法时以何种方式表示输入和输出过程。 Theano:显式地构建IR Theano是最早的深度学习工具之一,也就是如今为人们所熟知的Aesara项目。Theano有一个允许用户在内存中将IR构建为数据结构的API,因此Theano可实现自动求导,并将结果输出为Python函数。 importaesarafromaesaraimporttensorasataat。dscalar(a)Defineplaceholders,whichhavenovalues。bat。dscalar(b)cabcnowcontainstheIRofanexpression。TTdcaesara。grad(c,a)ConverttheIRincintoanotherone,dcfdcaesara。function(〔a,b〕,dc)ConverttheIRintoaPythonfunction,assertfdc(1。5,2。5)2。5sowecancallit。 TensorFlow1。x:用于运行IR的虚拟机 TensorFlow1。x明确保留了构建IR的想法。若在TensorFlow中运行上述示例,结果不会有什么差别;但倘若在TensorFlow1。x中来运行,最大的差别在于:我们不会将后向IR转换为Python函数,并使用Python解释器来运行。相反,我们会在TensorFlowruntime中来运行。 importtensorflow。compat。v1astfTensorFlow1。xAPIimportnumpyasnptf。disableeagerexecution()atf。placeholder(tf。float32,shape())btf。placeholder(tf。float32,shape())cabdctf。gradients(c,〔a〕,stopgradients〔a,b〕)withtf。compat。v1。Session()assess:TensorFlowhasaruntimetoexecutetheIR,xnp。single(2)so,noconvertingitintoPythoncode。ynp。single(3)print(sess。run(dc,feeddict{a:x,b:y})) PyTorch1。x:没有前向IR PyTorch不会像Theano或TensorFlow那样将前向传播转换为IR。反之,PyTorch使用Python解释器来运行前向传播。这样做的弊端在于会在运行期间生成表示后向传播的IR,我们称之为Eager模式(动态图模式)。 importtorchatorch。tensor(1。0,requiresgradTrue)Thesearenotplaceholders,butvalues。btorch。tensor(2。0)cabEvaluatescandderivestheIRofthebackwardinc。gradfn。c。backward()Executesc。gradfn。print(c。grad) TensorFlow2。x:梯度带 TensorFlow2。x增加了一个像PyTorchAPI的Eager模式API。此API追踪前向传播如何运行名为梯度带(GradientTape)的IR。TensorFlow2。x可以从这个跟踪中找出后向传播。 importtensorflowastfatf。Variable(1。0)LikePyTorch,thesearevalues,notplacehodlers。btf。Variable(2。0)withtf。GradientTape()astape:cabdcdatape。gradient(c,a)print(dcda) JAX JAX不会向用户公开诸如梯度带等方面的低级别细节。简单说来,JAX的思维方式为:将输入和输出都用Python函数来表示。 importjaxa2。0b3。0jax。grad(jax。lax。mul)(a,b)Computecabw。r。t。a。Theresultisb3。jax。jit(jax。grad(jax。lax。mul))(a,b)jax。experimental。pjit(jax。grad(jax。lax。mul),devicemesh(ntpus))(a,b) 对于想要自己编写的函数转换的高级用户,他们可以调用makejaxpr等低级API来访问IR,称为JAXPR。 jax。makejaxpr(jax。lax。mul)(2。0,3。0)ReturnstheIRrepresentingjax。lax。mul(2,3)jax。makejaxpr(jax。grad(jax。lax。mul))(2。0,3。0)ReturnstheIRofgrad(mul)(2,3) FuncTorch FuncTorch和JAX类似,都是基于PyTorch的函数转换。 importtorch,functorchatorch。tensor(〔2。0〕)btorch。tensor(〔3。0〕)functorch。grad(torch。dot)(a,b) JAX的makejaxpr类似于functorch的makefx。 deff(a,b):returntorch。dot(a,b)Havetowrapthebuiltinfunctiondotintof。必须将内置函数dot转换成f。print(functorch。makefx(f)(a,b)。code)print(functorch。makefx(functorch。grad(f))(a,b)。code) TensorFlow2。x、JAX和functorch都为前向传递构建了一个IR,但PyTorchEager模式没有。IR不仅可用于自动求导,还可用于其他类型的函数转换。在下列例子中,functorch。compile。aotfunction调用了回调函数printcompilefn两次,分别用于前向和后向传播。 fromfunctorch。compileimportaotfunctionimporttorch。fxasfxdefprintcompilefn(fxmodule,args):print(fxmodule)returnfxmoduleaotfnaotfunction(torch。dot,printcompilefn)aotfn(a,b) 2高阶导数 PyTorch importtorchfromtorchimportautogradxtorch。tensor(1。,requiresgradTrue)y2x38firstderivativeautograd。grad(y,x,creategraphTrue)print(firstderivative)secondderivativeautograd。grad(firstderivative,x)print(secondderivative) TensorFlow2。x importtensorflowastfxtf。Variable(1。0)withtf。GradientTape()asoutertape:withtf。GradientTape()astape:y2x38dydxtape。gradient(y,x)print(dydx)d2ydx2outertape。gradient(dydx,x)print(d2ydx2) JAX deff(a):return2a38print(jax。grad(f)(1。0))print(jax。grad(jax。grad(f))(1。0)) 3动态控制流 动态控制流(dynamiccontrolflows)有两个层级:在CPU上运行的粗粒度级别和在GPUTPU上运行的细粒度级别。本部分主要介绍在CPU上运行的粗粒度级别的动态控制流。下面我们将用(ifelse)条件语句作为例子检验深度学习工具。 TensorFlow1。x 在TensorFlow1。x中,我们需要将条件语句显式构建到IR中。此时条件语句是一个特殊的运算符tf。cond。 deff1():returntf。multiply(a,17)deff2():returntf。add(b,23)rtf。cond(tf。less(a,b),f1,f2)withtf。compat。v1。Session()assess:TensorFlowhasaruntimetoexecutetheIR,print(sess。run(r,feeddict{a:x,b:y})) TensorFlow2。x TensorFlow2。x支持使用tf。cond和tf。whileloop显式构建控制流。此外,实验项目googletangent中有AutoGraph功能,它可以将Python控制流转换为tf。cond或tf。whileloop。此功能利用了Python解释器支持的函数和函数源代码。例如下面的g函数调用了Python的标准库将源代码解析为AST,然后调用SSA表单来理解控制流。 defg(x,y):iftf。reduceany(xy):returntf。multiply(x,17)returntf。add(y,23)convertedgtf。autograph。tograph(g)importinspectprint(inspect。getsource(convertedg)) JAX 由于部分Python语法很复杂,所以通过解析源代码来理解控制流就显得很困难,这就导致AutoGraph经常出错。但如果这种方法很简单,那么Python开发者社区也不会在构建Python编译器时失败这么多次了。正是由于有这种挑战的存在,必须要明确地将控制流构建到IR中。为此,JAX提供了jax。lax。cond和jax。lax。forloop函数。 jax。lax。cond(ab,lambda:a17,lambda:b23) 考虑到这一点,你可能会觉得我们可以使用递归算法。但是下面用于计算阶乘的递归无法用JAX跟踪。 deffactorial(r,x):returnjax。lax。cond(x1。0,lambda:r,lambda:factorial(rx,x1))factorial(1。0,3。0) 可能你还想调用factorial来计算3!6。但这会让递归深度超过最大值,因为递归不仅依赖于条件,还依赖于函数定义和调用。 PyTorch PyTorch最初是Pythonnative。正如前文所说,由于多功能调度机制,grad和vamp的函数转换都是即时的。值得注意的是: 相比Theano和TensorFlow构建IR后的函数转换,即时函数转换效率更高。在进行grad和vmap时,JAX也是即时函数转换。然而像pamp和pjit等更复杂的函数转换需要对整个计算过程进行概述,在这个过程中IR是必不可少的。 由于IR在pmap和pjit中的必要性,PyTorch社区最近添加了torch。condpytorchpytorch83154 4分布式计算 根据执行代码或IR的不同方式,在使用Python解释器或runtime时,有两种分布式计算方法。 PythonNative Theano和PyTorch采用了Pythonnative分布式计算方式。这种分布式训练工作包含多个Python解释器进程。这导致出现了以下结果。 打包和运行(Packandrun)。由于这些Python进程在不同的host上运行,因此我们需要打包用户程序和依赖项,并将它们发送到这些host上去运行。一直以来TorchX负责了这个打包过程。它支持例如Docker和torch。package等各种打包格式,并且可以与各种集群管理器配合使用,如Kubernetes和SLURM。单程序多数据(SPMD)。由于将用户程序发送到各种host上要依赖于打包,与其他权重较轻的方式(如通过RPC发送代码)相比,这种方式不太灵活,因此,我们通常只发送一个程序。当所有这些进程运行同一程序时,这个作业就变成了单程序多数据(SPMD)作业。 PythonnativeSPMD 下面是一个简单的SPMDPyTorch程序,我们可以在相同或不同的host上使用进程运行这个程序。在这个过程中,我们只需要调用allgather。真正的分布式训练程序会调用更高级别的API,例如torch。nn。parallel。DistributedDataParallel和torchrec。DistributedModelParallel,然后再调用低级API,例如allgather和allreduce。 importosimporttorchfromtorchimportdistributedasdistdefmain():usegputorch。cuda。isavailable()localrankint(os。environ。get(LOCALRANK,0))localworldsizeint(os。environ。get(LOCALWORLDSIZE,0))devicetorch。device(fcuda:{localrank}ifusegpuelsecpu)dist。initdistributed(backendnccl)lsttorch。tensor(〔localrank100〕)。to(device)placeholderrltlst〔torch。zeroslike(lst)forinrange(localworldsize)〕dist。allgather(rltlst,lst,asyncopFalse)print(Afterbroadcasting:,rltlst) PythonnativeNonSPMD PyTorch不仅限于SPMD式的分布式训练。它还通过torch。distributed。pipeline。sync。Pipe和PiPPyproject提供流水并行,其中流水并行的各个阶段在不同的设备上运行不同的程序。这些阶段常通过torch。rpc包来沟通。 分布式运行时机制 分布式TensorFlow作业由运行TensorFlowruntime程序的进程组成,而不是由Python解释器组成。此分布式运行时作业执行TensorFlowgraph(IR),它是由执行用户程序的Python解释器生成。 用户程序可以使用低级API(如tf。device)去指定作业要运行什么操作、在哪台设备和主机上运行等等。因为API有runtime,所以可以做到这一点。 withtf。device(job:bartask:0device:gpu:2):opscreatedherehavethefullyspecifieddeviceabove 与PyTorch一样,TensorFlow也为分布式训练提供了高级APItf。distributed。strategy,Keras和DTensor。 strategytf。distribute。MirroredStrategy()iftf。config。listphysicaldevices(GPU)elsetf。distribute。getstrategy()withstrategy。scope():modeltf。keras。Sequential(〔tf。keras。layers。Dense(1,inputshape(1,))〕)model。compile(lossmse,optimizersgd) 分布式运行时极大地方便了训练服务的维护,因为我们不再将用户程序打包到集群上运行。相反,我们打包运行时程序,因为相比用户程序,运行时程序更加统一。 混合理念 JAX支持Pythonnative和分布式运行时。 JAX提供例如vmap、pmap和pjit的函数转换,这可以将Python函数转换为分布式程序。 (本文经授权后由OneFlow社区编译,译文转载请联系获得授权。原文:https:quip。comY8qtAyV4EXRg) 欢迎Star、试用OneFlow最新版本:https:github。comOneflowInconeflow