This article shows how to implement the Vision Transformer (ViT) with JAX/Flax, and how to train it with JAX/Flax.

Vision Transformer

When implementing the Vision Transformer, it helps to keep the architecture figure from the paper in mind. The ViT forward pass described in the paper is:

1. Extract patches from the input image and flatten them into vectors.
2. Project the patch vectors to the dimension processed by the Transformer encoder.
3. Prepend a learnable embedding (the [class] token) and add a position embedding.
4. Encode the sequence with the Transformer encoder.
5. Feed the [class] token output into an MLP head for classification.

Implementation details

Below, we build each module with JAX/Flax.

1. Image to flattened image patches

The following code extracts image patches from the input images. The extraction is implemented as a convolution with a kernel size of patch_size × patch_size and a stride of patch_size × patch_size, so that patches do not overlap.

import math
from typing import Optional

import jax
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
from jax import random


class Patches(nn.Module):
    patch_size: int
    embed_dim: int

    def setup(self):
        self.conv = nn.Conv(
            features=self.embed_dim,
            kernel_size=(self.patch_size, self.patch_size),
            strides=(self.patch_size, self.patch_size),
            padding='VALID'
        )

    def __call__(self, images):
        patches = self.conv(images)
        b, h, w, c = patches.shape
        # Flatten the spatial grid of patches into a sequence
        patches = jnp.reshape(patches, (b, h * w, c))
        return patches

2 and 3. Linear projection of flattened patches, [CLS] token, and position embedding

The Transformer encoder uses the same size, hidden_dim, across all of its layers, so the patch vectors created above are projected to hidden_dim-dimensional vectors. As in BERT, a [CLS] token is prepended to the sequence, and a learnable position embedding is added to retain positional information.

class PatchEncoder(nn.Module):
    hidden_dim: int

    @nn.compact
    def __call__(self, x):
        assert x.ndim == 3
        n, seq_len, _ = x.shape
        # Project to hidden_dim
        x = nn.Dense(self.hidden_dim)(x)
        # Add [CLS] token
        cls = self.param('cls_token', nn.initializers.zeros, (1, 1, self.hidden_dim))
        cls = jnp.tile(cls, (n, 1, 1))
        x = jnp.concatenate([cls, x], axis=1)
        # Add position embedding (stddev 0.02, from BERT)
        pos_embed = self.param(
            'position_embedding',
            nn.initializers.normal(stddev=0.02),
            (1, seq_len + 1, self.hidden_dim)
        )
        return x + pos_embed

4. Transformer encoder

As the architecture figure shows, the encoder consists of alternating multi-head self-attention (MSA) and MLP layers. A LayerNorm (LN) is applied before each MSA and MLP block, and a residual connection is added after each block.

class TransformerEncoder(nn.Module):
    embed_dim: int
    hidden_dim: int
    n_heads: int
    drop_p: float
    mlp_dim: int

    def setup(self):
        self.mha = MultiHeadSelfAttention(self.hidden_dim, self.n_heads, self.drop_p)
        self.mlp = MLP(self.mlp_dim, self.drop_p)
        self.layer_norm = nn.LayerNorm(epsilon=1e-6)

    def __call__(self, inputs, train=True):
        # Attention block
        x = self.layer_norm(inputs)
        x = self.mha(x, train)
        x = inputs + x
        # MLP block
        y = self.layer_norm(x)
        y = self.mlp(y, train)
        return x + y

The MLP is a two-layer network whose activation function is GELU. Following the paper, Dropout is applied after each Dense layer.

class MLP(nn.Module):
    mlp_dim: int
    drop_p: float
    out_dim: Optional[int] = None

    @nn.compact
    def __call__(self, inputs, train=True):
        actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim
        x = nn.Dense(features=self.mlp_dim)(inputs)
        x = nn.gelu(x)
        x = nn.Dropout(rate=self.drop_p, deterministic=not train)(x)
        x = nn.Dense(features=actual_out_dim)(x)
        x = nn.Dropout(rate=self.drop_p, deterministic=not train)(x)
        return x

Multi-head self-attention (MSA)

q, k, and v are shaped [B, N, T, D]; the attention weights and attention outputs are computed per head exactly as in the single-head case, and the result is then reshaped back to the original dimension [B, T, C], where C = N × D.

class MultiHeadSelfAttention(nn.Module):
    hidden_dim: int
    n_heads: int
    drop_p: float

    def setup(self):
        self.q_net = nn.Dense(self.hidden_dim)
        self.k_net = nn.Dense(self.hidden_dim)
        self.v_net = nn.Dense(self.hidden_dim)
        self.proj_net = nn.Dense(self.hidden_dim)
        self.att_drop = nn.Dropout(self.drop_p)
        self.proj_drop = nn.Dropout(self.drop_p)

    def __call__(self, x, train=True):
        B, T, C = x.shape  # batch_size, seq_length, hidden_dim
        N, D = self.n_heads, C // self.n_heads  # num_heads, head_dim
        q = self.q_net(x).reshape(B, T, N, D).transpose(0, 2, 1, 3)  # (B, N, T, D)
        k = self.k_net(x).reshape(B, T, N, D).transpose(0, 2, 1, 3)
        v = self.v_net(x).reshape(B, T, N, D).transpose(0, 2, 1, 3)

        # weights (B, N, T, T)
        weights = jnp.matmul(q, jnp.swapaxes(k, -2, -1)) / math.sqrt(D)
        normalized_weights = nn.softmax(weights, axis=-1)

        # attention (B, N, T, D)
        attention = jnp.matmul(normalized_weights, v)
        attention = self.att_drop(attention, deterministic=not train)

        # gather heads back to (B, T, N*D)
        attention = attention.transpose(0, 2, 1, 3).reshape(B, T, N * D)

        # output projection
        out = self.proj_drop(self.proj_net(attention), deterministic=not train)
        return out
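The head bookkeeping is easy to verify with a quick shape check. The following is a minimal illustrative sketch; the dummy batch shape and the PRNG seed are assumptions for the example, not values from the article:

# Minimal shape check for MultiHeadSelfAttention (illustrative sketch).
x = jnp.ones((2, 5, 192))  # assumed dummy input: (batch, tokens, hidden_dim)
mha = MultiHeadSelfAttention(hidden_dim=192, n_heads=3, drop_p=0.1)
# train=False makes dropout deterministic, so only a 'params' rng is needed
variables = mha.init({'params': jax.random.PRNGKey(0)}, x, train=False)
out = mha.apply(variables, x, train=False)
print(out.shape)  # (2, 5, 192): split to (B, N, T, D) internally, merged back to (B, T, C)

The output shape matches the input shape, which is what allows the residual connection in TransformerEncoder to add the two directly.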
5. Classification with the [CLS] embedding

Finally, the MLP head (classification head).

class ViT(nn.Module):
    patch_size: int
    embed_dim: int
    hidden_dim: int
    n_heads: int
    drop_p: float
    num_layers: int
    mlp_dim: int
    num_classes: int

    def setup(self):
        self.patch_extracter = Patches(self.patch_size, self.embed_dim)
        self.patch_encoder = PatchEncoder(self.hidden_dim)
        self.dropout = nn.Dropout(self.drop_p)
        self.transformer_encoder = TransformerEncoder(
            self.embed_dim, self.hidden_dim, self.n_heads, self.drop_p, self.mlp_dim)
        self.cls_head = nn.Dense(features=self.num_classes)

    def __call__(self, x, train=True):
        x = self.patch_extracter(x)
        x = self.patch_encoder(x)
        x = self.dropout(x, deterministic=not train)
        for i in range(self.num_layers):
            x = self.transformer_encoder(x, train)
        # MLP head on the [CLS] token
        x = x[:, 0]
        x = self.cls_head(x)
        return x

Training with JAX/Flax

Now that the model is built, the next step is to train it with JAX/Flax.

Dataset

Here we use CIFAR-10 from torchvision directly. (The constants DATA_MEANS, DATA_STD, IMAGE_SIZE, CROP_SCALES, CROP_RATIO, SEED, and BATCH_SIZE are configuration values assumed to be defined elsewhere in the accompanying code.) First, some utility functions:

from collections import defaultdict

import optax
import torch
from flax.training import train_state
from torchvision import transforms
from torchvision.datasets import CIFAR10
from tqdm import tqdm


def image_to_numpy(img):
    img = np.array(img, dtype=np.float32)
    img = (img / 255. - DATA_MEANS) / DATA_STD
    return img

def numpy_collate(batch):
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    elif isinstance(batch[0], (tuple, list)):
        transposed = zip(*batch)
        return [numpy_collate(samples) for samples in transposed]
    else:
        return np.array(batch)

Then the data loaders for training and testing:

test_transform = image_to_numpy
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomResizedCrop((IMAGE_SIZE, IMAGE_SIZE), scale=CROP_SCALES, ratio=CROP_RATIO),
    image_to_numpy
])
# The validation set should not use the augmentation.
train_dataset = CIFAR10('data', train=True, transform=train_transform, download=True)
val_dataset = CIFAR10('data', train=True, transform=test_transform, download=True)
train_set, _ = torch.utils.data.random_split(
    train_dataset, [45000, 5000], generator=torch.Generator().manual_seed(SEED))
_, val_set = torch.utils.data.random_split(
    val_dataset, [45000, 5000], generator=torch.Generator().manual_seed(SEED))
test_set = CIFAR10('data', train=False, transform=test_transform, download=True)

train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=BATCH_SIZE, shuffle=True, drop_last=True,
    num_workers=2, persistent_workers=True, collate_fn=numpy_collate,
)
val_loader = torch.utils.data.DataLoader(
    val_set, batch_size=BATCH_SIZE, shuffle=False, drop_last=False,
    num_workers=2, persistent_workers=True, collate_fn=numpy_collate,
)
test_loader = torch.utils.data.DataLoader(
    test_set, batch_size=BATCH_SIZE, shuffle=False, drop_last=False,
    num_workers=2, persistent_workers=True, collate_fn=numpy_collate,
)

Initializing the model

Initialize the ViT model:

def initialize_model(
    seed=42,
    patch_size=16, embed_dim=192, hidden_dim=192,
    n_heads=3, drop_p=0.1, num_layers=12, mlp_dim=768, num_classes=10
):
    main_rng = jax.random.PRNGKey(seed)
    x = jnp.ones(shape=(5, 32, 32, 3))
    # ViT
    model = ViT(
        patch_size=patch_size,
        embed_dim=embed_dim,
        hidden_dim=hidden_dim,
        n_heads=n_heads,
        drop_p=drop_p,
        num_layers=num_layers,
        mlp_dim=mlp_dim,
        num_classes=num_classes
    )
    main_rng, init_rng, drop_rng = random.split(main_rng, 3)
    params = model.init({'params': init_rng, 'dropout': drop_rng}, x, train=True)['params']
    return model, params, main_rng

vit_model, vit_params, vit_rng = initialize_model()

Creating the TrainState

A common pattern in Flax is to create a class that manages the training state, including the step count, the optimizer state, and the model parameters. Specifying the model's forward pass as apply_fn (which corresponds to model.apply) also shortens the argument lists in the training loop.

def create_train_state(model, params, learning_rate):
    optimizer = optax.adam(learning_rate)
    return train_state.TrainState.create(
        apply_fn=model.apply,
        tx=optimizer,
        params=params
    )

state = create_train_state(vit_model, vit_params, 3e-4)
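As a quick sanity check before training, a single forward pass through the fresh state should yield one logit per CIFAR-10 class. A minimal sketch, reusing the dummy batch shape from initialize_model:

# Sanity-check forward pass (illustrative sketch).
x = jnp.ones((5, 32, 32, 3))  # dummy CIFAR-10-sized batch
logits = state.apply_fn({'params': state.params}, x, train=False)  # train=False: dropout disabled
print(logits.shape)  # (5, 10): one logit per class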
Training loop

Here logger and save_model are assumed to be a TensorBoard-style summary writer and a checkpointing helper defined elsewhere.

def train_model(train_loader, val_loader, state, rng, num_epochs=100):
    best_eval = 0.0
    for epoch_idx in tqdm(range(1, num_epochs + 1)):
        state, rng = train_epoch(train_loader, epoch_idx, state, rng)
        if epoch_idx % 10 == 0:
            eval_acc = eval_model(val_loader, state, rng)
            logger.add_scalar('val/acc', eval_acc, global_step=epoch_idx)
            if eval_acc >= best_eval:
                best_eval = eval_acc
                save_model(state, step=epoch_idx)
            logger.flush()
    # Evaluate after training
    test_acc = eval_model(test_loader, state, rng)
    print(f'test_acc: {test_acc}')

def train_epoch(train_loader, epoch_idx, state, rng):
    metrics = defaultdict(list)
    for batch in tqdm(train_loader, desc='Training', leave=False):
        state, rng, loss, acc = train_step(state, rng, batch)
        metrics['loss'].append(loss)
        metrics['acc'].append(acc)
    for key in metrics.keys():
        arg_val = np.stack(jax.device_get(metrics[key])).mean()
        logger.add_scalar('train/' + key, arg_val, global_step=epoch_idx)
        print(f'[epoch {epoch_idx}] {key}: {arg_val}')
    return state, rng

Evaluation

Evaluation loops over the data loader with eval_step, the evaluation counterpart of train_step (a forward pass with train=False, without the gradient computation) that returns each batch's accuracy.

def eval_model(data_loader, state, rng):
    # Run the model on all images of a data loader and return the average accuracy
    correct_class, count = 0, 0
    for batch in data_loader:
        rng, acc = eval_step(state, rng, batch)
        correct_class += acc * batch[0].shape[0]
        count += batch[0].shape[0]
    eval_acc = (correct_class / count).item()
    return eval_acc

Train step

In train_step we define the loss function, compute the gradients of the loss with respect to the model parameters, and update the parameters from those gradients. jax.value_and_grad computes the loss together with its gradients with respect to state.params, and apply_gradients updates the TrainState. The cross-entropy loss is computed from the logits produced by apply_fn (identical to model.apply), which was specified when the TrainState was created.

@jax.jit
def train_step(state, rng, batch):
    loss_fn = lambda params: calculate_loss(params, state, rng, batch, train=True)
    # Get loss, gradients for the loss, and other outputs of the loss function
    (loss, (acc, rng)), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
    # Update parameters
    state = state.apply_gradients(grads=grads)
    return state, rng, loss, acc

Computing the loss

def calculate_loss(params, state, rng, batch, train):
    imgs, labels = batch
    rng, drop_rng = random.split(rng)
    logits = state.apply_fn({'params': params}, imgs, train=train, rngs={'dropout': drop_rng})
    loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=labels).mean()
    acc = (logits.argmax(axis=-1) == labels).mean()
    return loss, (acc, rng)

Results

The training results are shown below. On a standard GPU in Colab Pro, training takes about 1.5 hours.

test_acc: 0.7704000473022461

Author: satojkovic