Sorry, your browser cannot access this site
This page requires browser support (enable) JavaScript
Learn more >

在图像识别任务中使用纯Transformer

摘要

在CV中self-attention通常会与CNN一起使用,本文提出,只需要使用单独的Transformer模型就能对该任务进行有效的处理,并且在庞大的数据集上训练后,迁移到小数据集上时,效果将达到SOTA

引言

对于Transformer模型而言,self-attention需要对输入的每一个元素两两之间进行交互,假设元素长度为N,那么这一过程则需要$O(n^2)$的复杂度。在NLP领域中,类似BERT的模型目前最大只能支持到N=512的输入。而CV中图像检测、识别的数据中,大的图片将达到800*800的像素,这就意味着我们需要找到一种合适的方法,将2d的图片投影到1d,并使其数据量相对减少。

一些思路是将CNN与self-attention混用,即将CNN的特征图作为Transformer的输入[1]

另一些直接放弃CNN的方法是将2d的图片矩阵才分为H*W两个1d的向量,即先对高做一次self-attention, 再对宽做一次self-attention,这样就只进行了两次顺序的self-attention[2 stand-alone attention],这一方法虽然高效,但使用了特殊的注意力机制,无法使用硬件加速,因此无法将模型做的太大。

本文的思路是将一张图片转化为一个一个patch组成的图片。即认为一个patch是16像素,那么一个224 * 224的图片,我们可以认为他是由 14 * 14个16 * 16个的小patch构成的。于是将16 * 16个patch作为transformer的输入。

本文训练时使用有监督训练。

与本工作类似的有ICLR 2020的论文[3],该片文章只使用了2*2的patch

本文表示如果不加约束条件,ViT与规格相同的ResNet相比,效果稍弱,原因在于ResNet包含了一些ViT所不具备的先验知识,这些先验知识包括:

  1. locality:局部性,即相邻的区域会有相邻的特征,CNN由滑动窗口组成,因此天然具备这一特性
  2. translation equivariance:平移不变性,即同样的物体无论移动到哪里,只要是同样的输入,遇到了同样的卷积核,输出将不变。

本文在ImageNet、ImageNet-ReaL、CIFAR-100、VTAB数据集上进行了实验来论证上述推论。

结论

本文除了在产生patch和计算位置编码时引入了图像的归纳偏置,除此之外就再也没有引入图像特有的归纳偏置,即不需要对vision由先验知识。

相关工作

NLP领域中,在但数据集上训练好,再用小数据集微调的模型,例如BERT使用的是分离子监督学习,即完形填空式——讲句子中某些词划去然后让模型预测;GPT使用language modeling,即已经有一个句子,预测下一个词是什么。上述训练方法都是自监督训练。

另一个与本文相似的工作是Image GPT(iGPT)使用无监督的学习来训练的生成模型,该模型同样使用Transformer,但该模型在ImageNet上的准确率只能达到72%

方法

模型上本文尽可能采用original Transformer

ViT模型

ViT(Vision Transformer)解析

如图所示的结构中,模型先讲图片分为不同的patch,然后经过一个线性投影层,这个过程中,为了保证patch之间的位置关系,加入了位置编码Position

Embedding。

然后输入到Tranformer Encoder,得到很多输出。

进行分类任务时,该模型借鉴了BERT的思想,BERT中提到使用Extra Learnable Embeding,CLS。

ViT同样使用这一方法,讲这一字符融入Position embedding中(图中用*表示),放在位置信息为0的位置。由于self-attention中每个位置都会与余下所有位置进行交互,因此该位置将会从其他位置中学习到有用的信息。因此,只需要用这一项的输出进行分类即可,如图MLP Head即用于分类的简单MLP。

ViT模型相对简洁,而它的特殊之处,是如何将一张图片变为一系列的token。

图像处理

例如:

我们使用维度为244 * 244 * 3的彩色图片作为输入,让每个Patch的维度大小为16 * 16

那么得到的token数$N = \frac{224^2}{16^2} = 196$

每个token的大小为16 * 16 * 3 = 768

模型中的线性投影层则是一个全连接层,此处使用E表示,E的维度是768 * 768

于是输入X经过线性投影层的变换后的计算过程为:

$X E = [196 * 768] * [768 * 768] = [196 * 768] $

注:此处为[]外的为矩阵

最终我们得到的便是一系列1d的token,而不是2d的图片,最后计算上加入的CLS大小1*768

最终得到的输入将是197 * 768规模

位置信息的长度将被设定为768,直接使用加法操作加到token中,因此这一步不会改变token的维度,仍然是197 * 768

输入Encoder

接下来我们将这些输入Transformer Encoder。

首先经过Layer Norm层后,输入将被分为q,k,v进行多头自注意力模块,此时的q,k,v都是197 * 768

但Base 版的ViT使用了12头的注意力机制,因此每个头输入将变为197 * (768/2) = 197 * 64

最后将12个头的输出进行拼接得到197 * 768

经过另一个Layer Norm后,MLP层会将维度进行放大(一般是放大四倍),变为197 * 3072,然后再缩小投射回去,变为197 * 768

位置编码

本文使用了一个标准可学习的1D位置编码,即BERT中的位置编码,本文作者尝试了别的编码形式例如2D-aware位置编码,但最后发现结果并无显著变化。

消融实验

分类特征

本文使用CLS做分类,本文使用的分类模块是一个以tanh作为非线性激活函数的MLP。

此处的CLS是从NLP借鉴来的,但CV领域使用CNN进行分类时,通常对最后得到的feature map进行globally average-pooling(GAP) ,得到一个1d向量然后进行分类。

本文最先做的实验是,使用GAP代替CLS。

即对于Transformer Encoder的所有输出进行一个GAP,然后以此来分类,而不是添加一个CLS。

实验结果是两种方法效果一致。

位置编码

本文主要讨论了三种位置编码:

  1. 1D:本文使用得到NLP领域中常用的位置编码,每个位置编码长度为D
  2. 2D:使用矩阵来表示位置,分为X-embedding和Y-embedding每个位置编码长度为D/2,每个输入将分别得到一个X-和一个Y-embedding,拼接后得到D维度的embedding。
  3. relative positional embedding:使用各个块的相对位置来表示位置

但消融实验结果表明使用哪种位置编码并不会影响模型精度和效率。本文认为,由于使用了patch作为基本单位,并不是使用像素作为输入,因此想要知道patch之间的位置关系要比像素之间的位置关系要容易,因此位置编码的类型并不会影响实验结果。

归纳偏置

Transformer要比CNN少很多图像上的归纳偏置。

关于混合CNN与ViT的混合网络:

本文使用ResNet得到的14 * 14的特征图替换ViT中的14 * 14个patch,来进行实验。

Fine-Tuning

使用预训练好的ViT模型在小规模的数据集上进行微调时,理论上只要硬件允许,可以使用更大的图片(> 224 * 224)进行微调,但是随着图片的增大,如果patch size不变,那将导致输入维度增大。硬件允许的条件下,增大的输入是可行的,但训练好的位置编码将无法使用,因为patch改变了。

本文发现,这种情况下,只需要对位置编码进行2D差值即可,本文使用torch自带的interpolate函数即可完成。

实验

该部分使用了与训练的ResNet和各种尺寸的ViT进行了对比。

该部分实验中特别提到预训练的ResNet、ViT、Hybrid(ResNet混合ViT)进行不同数据大小的对比,该实验先采用JFT 300M(都采用大数据集进行预训练屏蔽数据规模对模型的影响)进行与训练,然后在ImagNet-Real、Pets、flowers、cifar20和cifar100上进行实验然后取平均,右图是ImageNet-Real上单独的结果:

image-20230216140736013

如图可以发现,在小数据集上Hybrid模型效果要优于纯ViT和ResNet,在同等的计算复杂度下,ViT要优于ResNet,但在较低的计算复杂度下,Hybrid的效果表现较好。另外,随着计算复杂度的上升,ViT的效果逐渐超过Hybrid,关于为什么CNN抽离出的特征为什么没有帮助Transformer进行更好的学习,本文并未给出对这一现象的解释。

自监督训练

本文模仿了BERT中的mask方法,即mask掉句子中的一些单词,让模型预测出该词语,本文提出了一个patch-mask,即mask掉一些patch,然后让网络重建这些patch。但是该方法仍然比最好的有监督的预训练低了4%

feature work

本文希望引入CV中的对比学习来使ViT可以进行自监督学习。随后的MoCo v3,DINO就是使用Contrastive训练的ViT

评论