人工智能
神经网络简单介绍
遍地开花的 Attention,你真的懂吗?
注意力机制到底在做什么,Q/K/V怎么来的?一文读懂Attention注意力机制
FlashAttention图解(如何加速Attention)
FlashAttention2详解(性能比FlashAttention提升200%)
【Attention(4)】【QKV的自注意力机制】 主要思路(笔记)
-
+
首页
FlashAttention图解(如何加速Attention)
最新FlashDecoding++ [](https://zhuanlan.zhihu.com/p/665595287) ## FlashAttention V2和V3版本详解: [](https://zhuanlan.zhihu.com/p/645376942)[](https://zhuanlan.zhihu.com/p/661478232) ## Motivation 当输入序列(sequence length)较长时,Transformer的计算过程缓慢且耗费内存,这是因为self-attention的time和memory complexity会随着sequence length的增加成二次增长。 标准Attention的中间结果\\mathbf{S}, \\mathbf{P}(见下文)通常需要通过高带宽内存(HBM)进行存取,两者所需内存空间复杂度为O(N^2)。本文分析: - FlashAttention: 对HBM访问的次数为O(N^2d^2M^{-1}) - Attention: 对HBM访问的次数为\\Omega\\left(N d+N^2\\right) 往往N \\gg d(例如GPT2中N=1024,d=64),因此FlashAttention会快很多。下图展示了两者在GPT-2上的Forward+Backward的GFLOPs、HBM、Runtime对比(A100 GPU): ![](http://yg9538.kmgy.top/img/2023/11/24/2023-11-24_200736_1960850.9068105534675462.jpeg) GPU中存储单元主要有HBM和SRAM:HBM容量大但是访问速度慢,SRAM容量小却有着较高的访问速度。例如:A100 GPU有40-80GB的HBM,带宽为1.5-2.0TB/s;每108个流式多核处理器各有192KB的片上SRAM,带宽估计约为19TB/s。可以看出,片上的SRAM比HBM快一个数量级,但尺寸要小许多数量级。 **综上,FlashAttention目的不是节约FLOPs,而是减少对HBM的访问。重点是FlashAttention在训练和预测过程中的结果和标准Attention一样,对用户是无感的,而其他加速方法做不到这点。** 阅读本文需要了解的符号定义: - N: sequence length - d: head dimension - M: the size of SRAM - \\Omega:大于等于的数量级复杂度 - O:小于等于的数量级复杂度 - \\Theta:同数量级的复杂度 - o:小于的数量级复杂度 ## Method ### Attention 标准Attention输入为\\mathbf{Q}, \\mathbf{K}, \\mathbf{V} \\in \\mathbb{R}^{N \\times d},输出为\\mathbf{O} \\in \\mathbb{R}^{N \\times d},计算如下: \\mathbf{S}=\\mathbf{Q K}^{\\top} \\in \\mathbb{R}^{N \\times N}, \\quad \\mathbf{P}=\\operatorname{softmax}(\\mathbf{S}) \\in \\mathbb{R}^{N \\times N}, \\quad \\mathbf{O}=\\mathbf{P V} \\in \\mathbb{R}^{N \\times d}, 其中softmax操作是row-wise的,即每行都算一次softmax,一共计算N行。 ![](http://yg9538.kmgy.top/img/2023/11/24/2023-11-24_200736_2009510.1382590292831618.jpeg) 计算流程图如下: ![](http://yg9538.kmgy.top/img/2023/11/24/2023-11-24_200736_2035230.9269051570416373.jpeg) ### FlashAttention 建议先阅读这篇[知乎文章](https://zhuanlan.zhihu.com/p/582606847),重复内容不再赘述。 ![](http://yg9538.kmgy.top/img/2023/11/24/2023-11-24_200736_2502670.43021302752694446.jpeg) ![](http://yg9538.kmgy.top/img/2023/11/24/2023-11-24_200736_3285980.6050048501226372.jpeg) 以GPT2和A100为例:A100的SRAM大小为192KB=196608B;GPT2中N=1024, d=64,对应的\\mathbf{Q}, \\mathbf{K}, \\mathbf{V}的维度为N \\times d = 1024 \\times 64,中间结果\\mathbf{S}, \\mathbf{P}的维度为N \\times N = 1024 \\times 1024。本例中FlashAttention的参数为: - B\_c = \\lceil 196608 / 4 / 64 \\rceil = 768;B\_r = \\min(768, 64) = 64 - T\_c = \\lceil 1024 / 768 \\rceil = 2;T\_r = \\lceil 1024 / 64 \\rceil = 16 对应的计算过程: - 每次外循环(outer loop,j)载入的\\mathbf{K}\_j, \\mathbf{V}\_j的大小为B\_c \\times d = 768 \\times d,一共循环T\_c = 2次 ![](http://yg9538.kmgy.top/img/2023/11/24/2023-11-24_200736_2581540.5225601918884349.jpeg) - 每次内循环(inner loop,i)载入的\\mathbf{Q}\_i, \\mathbf{O}\_i的大小为B\_r \\times d = 64 \\times d,一共循环T\_r = 16次(总次数还需要乘以外循环) - \\mathbf{S}\_{ij} = \\mathbf{Q}\_i \\times \\mathbf{K}\_j^T,即为(下标表示维度):C\_{64 \\times 768} = A\_{64 \\times d} \\times B\_{d \\times 768}。 - \\tilde{\\mathbf{P}}\_{ij} = \\text{row\_softmax}(\\mathbf{S}\_{ij}),\\tilde{\\mathbf{P}}\_{ij}表示和标准attention计算的\\mathbf{P}\_{ij}有区别,因为\\tilde{m}\_{ij} = \\text{row\_max}(\\mathbf{S}\_{ij})得到的最大值可能不是\\mathbf{S}第i行的最大值。\\tilde{\\mathbf{P}}\_{ij}的大小和\\mathbf{S}\_{ij}一样,都为B\_r \\times B\_c = 64 \\times 768。 ![](http://yg9538.kmgy.top/img/2023/11/24/2023-11-24_200736_2777140.28447508316064196.png) - \\tilde{\\mathbf{P}}\_{ij}和\\mathbf{S}\_{ij}只是部分结果,如下图所示,外循环j是横向(特征维d)移动的,内循环i是纵向(序列维N)移动的。**换句话说,外循环在顺序计算特征,内循环在顺序计算序列。** ![](http://yg9538.kmgy.top/img/2023/11/24/2023-11-24_200736_2747650.4178013660514406.jpeg) - \\mathbf{O}\_i的大小为B\_r \\times d,第二维d是满的(和最终\\mathbf{O}一样),**这意味着每次外循环都要重新更新当前批次中的特征**,即虽然第一次外循环\\tilde{P}\_{00} \\times V\_0和第二次外循环\\tilde{P}\_{01} \\times V\_1都会得到\\mathbf{O}\_0,但是第二次\\mathbf{O}\_0的是基于第一次\\mathbf{O}\_0重新生成的。 ![](http://yg9538.kmgy.top/img/2023/11/24/2023-11-24_200736_3171850.3079710434619485.jpeg) - \\text{diag}(...) 作用是将vector生成为一个对角矩阵,从而实现相同长度的两个vector进行element-wise相乘。 **Theorem 1.** FlashAttention的FLOPs为 O(N^2d) ,除了input和output,额外需要的内存为 O(N) 。 - Theorem 1的证明过程如下。 影响FLOPs的主要是matrix multiplication。在一次循环中: - Algorithm 1第9行:计算\\mathbf{Q}\_i \\mathbf{K}\_j^{\\top} \\in \\mathbb{R}^{B\_r \\times B\_c}。由于\\mathbf{Q}\_i \\in \\mathbb{R}^{B\_r \\times d},\\mathbf{K}\_j \\in \\mathbb{R}^{B\_c \\times d},因此一次计算需要的FLOPs为O(B\_r B\_c d)。 - Algorithm 1第12行:计算\\tilde{\\mathbf{P}}\_{i j} \\mathbf{V}\_j \\in \\mathbb{R}^{B\_r \\times d}。由于\\tilde{\\mathbf{P}}\_{i j} \\in \\mathbb{R}^{B\_r \\times B\_c},\\mathbf{V}\_j \\in \\mathbb{R}^{B\_c \\times d},因此一次计算需要的FLOPs为O(B\_r B\_c d)。 上述计算循环的总次数为T\_c T\_r=\\left\\lceil\\frac{N}{B\_c}\\right\\rceil\\left\[\\frac{N}{B\_r}\\right\\rceil,因此总的FLOPs为: O\\left(\\frac{N^2}{B\_c B\_r} B\_r B\_c d\\right)=O\\left(N^2 d\\right) **Theorem 2.** 如果SRAM的size M 满足 d \\leq M \\leq N d 。标准Attention对HBM访问的次数为 \\Omega\\left(N d+N^2\\right) ,而FlashAttention对HBM访问的次数为 O(N^2d^2M^{-1}) 。 - Theorem 2的证明过程如下。 需要从HBM读取的数据有: - Algorithm 1第6行:每次循环读取的\\mathbf{K}\_j, \\mathbf{V}\_j的size复杂度都为\\Theta(M),总size为\\Theta(Nd)。 - Algorithm 1第8行:每次循环读取的\\mathbf{Q}\_i, \\mathbf{O}\_i的size复杂度都为\\Theta(Nd),总次数为T\_c=\[\\frac{N}{B\_c}\] = \\Theta(\\frac{Nd}{M})。 FlashAttention对HBM总访问次数的复杂度为: \\Theta(Nd+NdT\_c) = \\Theta(NdTc) = \\Theta(N^2d^2M^{-1})
yg9538
2023年11月24日 20:07
转发文档
收藏文档
上一篇
下一篇
手机扫码
复制链接
手机扫一扫转发分享
复制链接
Markdown文件
PDF文档
PDF文档(打印)
分享
链接
类型
密码
更新密码