背景
组织:谷歌研究,加州大学伯克利分校
论文地址:
https://www.aminer.cn/pub/5e5e189993d709897ce1ddbc
包括的会议:2020年国际学习资源中心
纸张代码:
https://github . com/Google/trax/tree/master/trax/models/重整器
摘要
基于Transformer的各种巨型模型往往可以在各种自然语言处理任务中取得最佳效果,但这些模型的训练成本往往过高,尤其是对于长序列文本。因此,本文提出了两种基于Transformer的模型改进技术,称为重整器。首先用局部敏感哈希代替原来的点乘注意,使得空之间的复杂度从O降低到O,其中L表示文本序列的长度。其次,将标准残差替换为反向残差层,使得激活值只需要存储一次,而不是N次,其中N代表网络层的数量。最终结果表明,重整器的性能与Transformer相当,在长序列上具有更高的存储效率和更快的速度。
介绍
训练Transformer模型真的需要大量资源并且效率很低吗?以现有最大Transformer层为例,Transformer层中的参数量为0.5B,需要2GB内存。对于由64Ktokens组成的序列,如果嵌入层的大小为1024,批处理大小为8,则激活值需要64K * 1K * 8=0.5B浮点数来存储,这又需要2GB内存。如果每层的内存占用只是上面提到的,那么使用Transformer在单个加速器上处理64K长度的序列是很容易的。此外,只需要17GB的内存来训练BERT的整个语料库。然而,现实并非如此。为什么在真实环境中,这些模型甚至不能在单台机器上进行微调?
这是因为上面只考虑了单层参数的内存占用和输入激活值的内存消耗,而忽略了Transformer内存占用的主要问题:
-需要存储激活值进行反向传播,因此n层模型的内存占用是单层的n倍;
-因为中间全连接层的深度d_通常远大于注意力激活层的深度d_需要占用大量内存;
-长度为l的序列的关注时间和空之间的复杂度为o,那么对于64K令牌的序列,内存将被耗尽。
因此,本文提出了重整器模型来解决上述问题,具体采用了以下方案:
-可逆层,在整个模型中仅使用单个副本可以消除层数因子n。
-前馈层单独激活,分块处理,消除了d_ factor的影响,减少了前馈层的内存占用。
-使用基于局部敏感哈希的近似关注度计算,关注层的O因子变为O,使得在长序列上进行处理成为可能。
重整器模型对以下三个任务进行实验:合成任务、文本任务和图像生成任务。实验结果表明,该重整器相当于变压器,但速度更快,存储效率更高。
本地敏感散列注意
点击并加倍关注:
使用点乘的标准Transformer的注意事项,查询和键的维度为d_k,值的维度为d _ v,查询首先乘以键,然后除以根d_k,再输入softmax得到值的权重,最后将权重乘以值得到最终结果。在实际操作过程中,以矩阵方式进行批量操作,查询构成矩阵Q,键构成矩阵K,值构成矩阵v,以上过程总结如下:
龙注意了:
上述注意操作并行执行h次,然后输出维度为d_v的输出结果。然后拼接这些结果,做一个投影运算得到最终结果。所谓多头关注。
高效的记忆注意力:
我们先来计算一下上述注意机制消耗的内存。假设q,k和v的维数为。QK^T的维度是。当长度=64k时,即使batch_size=1,64k*64k矩阵按32位浮点数存储也需要16GB内存。鉴于此,在长序列上使用Transformer是不切实际的。然而,应当注意,QK^T矩阵不需要存储在存储器中,并且可以为每个查询单独计算注意力。使用反向传播时,再次重新计算梯度。虽然用这种方法计算注意力效率不高,但占用的内存与长度成正比。这种方法在本文中被用作充分关注的基线。
q、k和v来自哪里?
上面讨论了q、K、V,但一般我们只得到大小为A的激活值,即令牌嵌入形成的句子向量。然后,为了从a中获取q、k和v,Transformer使用三个不同的线性层将a投影为q、k和v。对于使用本地敏感哈希注意的模型,我们希望查询和键相同。它只需要A投射到Q,A投射到K的线性变换参数相同,而A投射到V的参数不同。这样,它就变成了一个共享的QK变压器。实验表明,即使增加归一化项d_k,共享QK也不会影响变压器的性能。
分散注意力:
在LSH注意中,假设q、k和v的大小为0,并且仍然使用之前介绍的多头注意机制。那么QK^T的面积就是。由于softmax的计算结果主要取决于取值最大的部分,所以对于每次查询,我们只需要关注k中最接近查询的点。当k的长度为64k时,本文对于每个查询只考虑最新的32或64个键。这样会更有效率,那么如何找到最近的钥匙呢?
本地敏感哈希:
本地敏感哈希可用于查找高纬度空的最近邻居。每个向量x由哈希函数h映射,如果近向量高概率得到相同的哈希,而远向量没有,那么这样的哈希称为位置敏感哈希。在这个例子中,我们实际上只要求邻居的向量具有相同的高概率哈希值,哈希桶也具有相同的大概率大小。
具体来说,使用图1所示的随机投影方法:
上图中的角度LSH是一种常用的LSH算法,它将点投射到一个单位球体上,该球体被分成预定义的区域,每个区域都有一个特定的代码。然后,一系列随机旋转的点定义了这些点所属的桶。下面用一个简单的2D例子来说明这一点,https://miro.medium.com/max/1052/1 * bj8d4k 05 gz8 or-aqmhyyyva . gif。
LSH注意:
综合考虑上述LSH策略和散列注意,首先在位置I重写单个查询的常规注意:
图片来源:https://towards tasciety . com/图解-重整器-393575ac6ba0
有两个点,投影到一个单位圆上,以不同的角度随机旋转三次。可以观察到,它们不太可能共享同一个散列桶。在下面的示例中,可以看到在3次随机旋转后,两个非常接近的点将位于同一个散列桶中:
https://miro . medium . com/max/1052/1 * aarg6a 26 kqbilekt 43 fxlw . gif
角LSH最近邻搜索的简化动画:两点非常接近的情况。
图片来源:https://towards tasciety . com/图解-重整器-393575ac6ba0
其中P_i表示位置I处查询所需的关注集,z表示softmax等分区函数中的归一化项。为了写得清楚,这里省略了比例项root d_k。
对于批处理操作,当不在P_i中的元素被屏蔽时,一般注意事项定义如下:
也就是说,对于不可顾的位置,m为正无穷大,那么q_i* k_j减去正无穷大再进行exp运算,结果为0。这样,不需要为每个位置I设置单独的P_i..
在LSH注意中,查询中位置I可以参与的限制集P_i被限制为一个散列桶。图2显示了完全关注和散列关注之间的比较。
图a:在常规注意机制中,黑点代表softmax中的主导地位。注意编码器的注意,否则q_3不能注意k_6。另外,注意力充分的注意力矩阵一般都是稀疏的,但这种稀疏性在计算中并没有用到,因此可以用来降低时间空之间的复杂度。
图B:计算查询和关键字所属的散列桶。然后按照桶排序,同一个桶按照原来的位置排序得到图b,可以看到在同一个桶中,可能有多个查询,但是键很少。例如,图中有三个蓝色的桶,它们都附着在同一个键上。由于相似的物品很可能落入同一个桶中,所以只有在每个桶中注意才能近似完全注意。
图c:为了缓解桶中q和k的不平衡,本文通过制作$ k _ = frac { { left u q u right u| } $,使h=h,即使使用了共享-QK注意。然后,根据桶序列号对查询进行排序,在每个桶中,它们仍然根据原始位置大小进行排序。得到图C..对比图b和图c,可以看到竖轴的k变成了q,此时可以保证对角线都是专心的,桶中q和k的个数是一样的。在排序的关注矩阵中,同一个桶的值将聚集在对角线附近。注意,图中的对角点是空的中心,因为虽然正常情况下,Q会尝试自身位置的值,在share-QK的实现中,如果它尝试自身,它的值会极大,而其他的值会极小。通过softmax后,其他都是0,所以本身就是1。因此,为了避免这种情况,Q不会关注自身位置的价值,除非只有自身能够关注。
图d:即使Q=K,还是会出现一个问题:桶多桶少。例如,在一个极端的情况下,有两个桶,一个桶占据所有的键,另一个桶是空,所以LSH注意没有影响。因此,在图c的基础上,增加了组块的操作。对输入进行排序后,每个桶的平均大小为$m=frac}}$。这里假设桶中的数字增加到平均值的两倍的概率足够低。对于存储桶中的每个查询,您可以尝试将自己和密钥与上一个存储桶中的哈希值相同。
总之,LSH注意做了以下两件事:
首先,找到Q和K矩阵的LSH哈希。
其次,在同一个哈希桶中计算k和q向量的标准关注度。
更具体地说,它可以分为以下五个步骤:
首先,让输入序列查询=键
第二,做LSH桶,即进行哈希计算,得到每个查询和关键字所属的桶。
第三,根据桶号对查询进行排序,在同一个桶内,根据查询的原始位置进行排序。
第四,对于排序后的新序列,执行组块分割
第五,对于每个查询,只关注自己和自己之前的块,关注这些候选集中同一个桶的键。
多轮LSH注意:
LSH是近似的,也就是说,它不能保证相似的输入可以在同一个桶里。为了缓解这一问题,采用了多轮LSH关注。即多次重复上述过程,使相似物品尽可能高概率落入同一桶,尽量避免相似物品落入不同桶。详情见附件a。
可逆层
如上所述,注意力的复杂性可以降低到与序列长度成线性比例。但是参数的复杂度还是很高的,那么如何进一步降低呢?在这里我们开始尝试解决前面介绍中提到的第二个和第三个问题,即大量编码器和解码器层的深度,以及全连接层FFN。
可逆剩余网络
RevNet的思想是每一层的激活都可以从下一层的激活中导出,所以不需要在内存中存储激活。在原始剩余层中,公式y=x+F输出激活。其中f是一个剩余函数。在RevNet中,输入x分为x_1和x_2两部分,然后通过不同的残差函数得到输出y_1和y _ 2:f和g:
根据以下结构,从输出中获得输入:
可逆变压器
那么如何将RevNet引入Transformer呢?关注层和FFN层通过ResNet连接,减少了内存消耗。具体来说,让f函数成为关注层,g函数成为FFN层。应当注意,层归一化被包括在剩余块中。
通过这种方式,使用可逆变压器消除了在每一层存储激活值的需要,从而避免了n _ 1项..可逆层可以代替标准剩余层,只存储一次激活,而不是训练过程中的N次。
组块
以上消除了n_l项的影响,深层网络仍然占用大量内存。FFN的中隐层纬度通常很大,比如d_=4k以上。因为FFN的计算与序列中的位置无关,所以可以将计算分成C块来减少内存的使用。虽然这个操作可以并行处理,但是一次只计算一个块,并且用时间来交换内存空。
此外,可逆操作和反向传播操作也分块处理。除了FFN,对于大词汇量的模型,输出端的对数概率被分成块,序列的每个部分的损失被计算一次。
实验结果
在图像生成任务imagenet64和enwik8-64K上进行了实验,评估了可逆层、共享查询密钥和LSH注意对内存、准确性和速度的影响。
可逆层和共享查询密钥的影响;
图3的左边部分验证了共享查询键的影响。从困惑曲线可以看出,共享QK注意并不逊色于常规注意。在enwik8数据集上收敛更快。换句话说,使用共享的QK注意力不会牺牲准确性。
图3右侧部分验证了可逆层的影响。在实验中,可逆层和常规Transformer的参数相同,学习曲线看起来几乎相同。这些结果表明,可逆变压器可以节省内存,而不牺牲准确性。
LSH关注的影响:
如图4所示,可以看到随着哈希数的增加,准确率也提高了。
更大的重整器模型:
图5显示了在envik8和imagenet64上具有不同层的重整器的性能。下图为大重整器随楼层数变化的指标结果,20层仍无压力。下图显示了不同序列长度下普通注意力和LSH注意力的速度比较。当序列很长时,LSH有显著的优势。
摘要
重整器将Transformer的建模能力与可以在长序列上高效执行的架构相结合,因此即使在处理大型模型时,它也可以使用更少的内存。这将有助于大规模和大规模的参数化Transformer模型变得更加广泛可用。此外,处理长序列的能力为重整器在许多发电任务中的应用开辟了道路。除了生成非常长的连贯文本,重整器还可以将Transformer模型的能力应用到其他领域,如时间序列预测、音乐、图像等。
研究生院:华中科技大学
研究兴趣:机器阅读理解、文本生成等。
在微信官方账号对话框回复“2020科技趋势”,获取《2020科技趋势报告》完整PDF!
微信官方账号回复了对话框中的“AI女神”,获得了完整版《人工智能世界最具影响力女学者报告》!
在微信官方账号对话框回复“AI10”,获得“谈人工智能未来十年”主题演讲PPT!
在微信官方账号对话框回复“GNN”,获得主题演讲《图神经网络与认知推理》PPT!
在微信官方账号对话框回复“学术搜索”,获得完整版“人工智能学术搜索”报告!
微信官方账号对话框回复“AI指数”,获得《2019人工智能指数报告》完整PDF!
在微信官方账号对话框回复“3D视觉”,获取3D视觉技术白皮书完整PDF!