人工智能
神经网络简单介绍
遍地开花的 Attention,你真的懂吗?
注意力机制到底在做什么,Q/K/V怎么来的?一文读懂Attention注意力机制
FlashAttention图解(如何加速Attention)
FlashAttention2详解(性能比FlashAttention提升200%)
【Attention(4)】【QKV的自注意力机制】 主要思路(笔记)
-
+
home
FlashAttention图解(如何加速Attention)
最新FlashDecoding++ [](https://zhuanlan.zhihu.com/p/665595287) ## FlashAttention V2和V3版本详解: [](https://zhuanlan.zhihu.com/p/645376942)[](https://zhuanlan.zhihu.com/p/661478232) ## Motivation 当输入序列(sequence length)较长时,Transformer的计算过程缓慢且耗费内存,这是因为self-attention的time和memory complexity会随着sequence length的增加成二次增长。 标准Attention的中间结果\\mathbf{S}, \\mathbf{P}(见下文)通常需要通过高带宽内存(HBM)进行存取,两者所需内存空间复杂度为O(N^2)。本文分析: - FlashAttention: 对HBM访问的次数为O(N^2d^2M^{-1}) - Attention: 对HBM访问的次数为\\Omega\\left(N d+N^2\\right) 往往N \\gg d(例如GPT2中N=1024,d=64),因此FlashAttention会快很多。下图展示了两者在GPT-2上的Forward+Backward的GFLOPs、HBM、Runtime对比(A100 GPU):  GPU中存储单元主要有HBM和SRAM:HBM容量大但是访问速度慢,SRAM容量小却有着较高的访问速度。例如:A100 GPU有40-80GB的HBM,带宽为1.5-2.0TB/s;每108个流式多核处理器各有192KB的片上SRAM,带宽估计约为19TB/s。可以看出,片上的SRAM比HBM快一个数量级,但尺寸要小许多数量级。 **综上,FlashAttention目的不是节约FLOPs,而是减少对HBM的访问。重点是FlashAttention在训练和预测过程中的结果和标准Attention一样,对用户是无感的,而其他加速方法做不到这点。** 阅读本文需要了解的符号定义: - N: sequence length - d: head dimension - M: the size of SRAM - \\Omega:大于等于的数量级复杂度 - O:小于等于的数量级复杂度 - \\Theta:同数量级的复杂度 - o:小于的数量级复杂度 ## Method ### Attention 标准Attention输入为\\mathbf{Q}, \\mathbf{K}, \\mathbf{V} \\in \\mathbb{R}^{N \\times d},输出为\\mathbf{O} \\in \\mathbb{R}^{N \\times d},计算如下: \\mathbf{S}=\\mathbf{Q K}^{\\top} \\in \\mathbb{R}^{N \\times N}, \\quad \\mathbf{P}=\\operatorname{softmax}(\\mathbf{S}) \\in \\mathbb{R}^{N \\times N}, \\quad \\mathbf{O}=\\mathbf{P V} \\in \\mathbb{R}^{N \\times d}, 其中softmax操作是row-wise的,即每行都算一次softmax,一共计算N行。  计算流程图如下:  ### FlashAttention 建议先阅读这篇[知乎文章](https://zhuanlan.zhihu.com/p/582606847),重复内容不再赘述。   以GPT2和A100为例:A100的SRAM大小为192KB=196608B;GPT2中N=1024, d=64,对应的\\mathbf{Q}, \\mathbf{K}, \\mathbf{V}的维度为N \\times d = 1024 \\times 64,中间结果\\mathbf{S}, \\mathbf{P}的维度为N \\times N = 1024 \\times 1024。本例中FlashAttention的参数为: - B\_c = \\lceil 196608 / 4 / 64 \\rceil = 768;B\_r = \\min(768, 64) = 64 - T\_c = \\lceil 1024 / 768 \\rceil = 2;T\_r = \\lceil 1024 / 64 \\rceil = 16 对应的计算过程: - 每次外循环(outer loop,j)载入的\\mathbf{K}\_j, \\mathbf{V}\_j的大小为B\_c \\times d = 768 \\times d,一共循环T\_c = 2次  - 每次内循环(inner loop,i)载入的\\mathbf{Q}\_i, \\mathbf{O}\_i的大小为B\_r \\times d = 64 \\times d,一共循环T\_r = 16次(总次数还需要乘以外循环) - \\mathbf{S}\_{ij} = \\mathbf{Q}\_i \\times \\mathbf{K}\_j^T,即为(下标表示维度):C\_{64 \\times 768} = A\_{64 \\times d} \\times B\_{d \\times 768}。 - \\tilde{\\mathbf{P}}\_{ij} = \\text{row\_softmax}(\\mathbf{S}\_{ij}),\\tilde{\\mathbf{P}}\_{ij}表示和标准attention计算的\\mathbf{P}\_{ij}有区别,因为\\tilde{m}\_{ij} = \\text{row\_max}(\\mathbf{S}\_{ij})得到的最大值可能不是\\mathbf{S}第i行的最大值。\\tilde{\\mathbf{P}}\_{ij}的大小和\\mathbf{S}\_{ij}一样,都为B\_r \\times B\_c = 64 \\times 768。  - \\tilde{\\mathbf{P}}\_{ij}和\\mathbf{S}\_{ij}只是部分结果,如下图所示,外循环j是横向(特征维d)移动的,内循环i是纵向(序列维N)移动的。**换句话说,外循环在顺序计算特征,内循环在顺序计算序列。**  - \\mathbf{O}\_i的大小为B\_r \\times d,第二维d是满的(和最终\\mathbf{O}一样),**这意味着每次外循环都要重新更新当前批次中的特征**,即虽然第一次外循环\\tilde{P}\_{00} \\times V\_0和第二次外循环\\tilde{P}\_{01} \\times V\_1都会得到\\mathbf{O}\_0,但是第二次\\mathbf{O}\_0的是基于第一次\\mathbf{O}\_0重新生成的。  - \\text{diag}(...) 作用是将vector生成为一个对角矩阵,从而实现相同长度的两个vector进行element-wise相乘。 **Theorem 1.** FlashAttention的FLOPs为 O(N^2d) ,除了input和output,额外需要的内存为 O(N) 。 - Theorem 1的证明过程如下。 影响FLOPs的主要是matrix multiplication。在一次循环中: - Algorithm 1第9行:计算\\mathbf{Q}\_i \\mathbf{K}\_j^{\\top} \\in \\mathbb{R}^{B\_r \\times B\_c}。由于\\mathbf{Q}\_i \\in \\mathbb{R}^{B\_r \\times d},\\mathbf{K}\_j \\in \\mathbb{R}^{B\_c \\times d},因此一次计算需要的FLOPs为O(B\_r B\_c d)。 - Algorithm 1第12行:计算\\tilde{\\mathbf{P}}\_{i j} \\mathbf{V}\_j \\in \\mathbb{R}^{B\_r \\times d}。由于\\tilde{\\mathbf{P}}\_{i j} \\in \\mathbb{R}^{B\_r \\times B\_c},\\mathbf{V}\_j \\in \\mathbb{R}^{B\_c \\times d},因此一次计算需要的FLOPs为O(B\_r B\_c d)。 上述计算循环的总次数为T\_c T\_r=\\left\\lceil\\frac{N}{B\_c}\\right\\rceil\\left\[\\frac{N}{B\_r}\\right\\rceil,因此总的FLOPs为: O\\left(\\frac{N^2}{B\_c B\_r} B\_r B\_c d\\right)=O\\left(N^2 d\\right) **Theorem 2.** 如果SRAM的size M 满足 d \\leq M \\leq N d 。标准Attention对HBM访问的次数为 \\Omega\\left(N d+N^2\\right) ,而FlashAttention对HBM访问的次数为 O(N^2d^2M^{-1}) 。 - Theorem 2的证明过程如下。 需要从HBM读取的数据有: - Algorithm 1第6行:每次循环读取的\\mathbf{K}\_j, \\mathbf{V}\_j的size复杂度都为\\Theta(M),总size为\\Theta(Nd)。 - Algorithm 1第8行:每次循环读取的\\mathbf{Q}\_i, \\mathbf{O}\_i的size复杂度都为\\Theta(Nd),总次数为T\_c=\[\\frac{N}{B\_c}\] = \\Theta(\\frac{Nd}{M})。 FlashAttention对HBM总访问次数的复杂度为: \\Theta(Nd+NdT\_c) = \\Theta(NdTc) = \\Theta(N^2d^2M^{-1})
yg9538
Nov. 24, 2023, 8:07 p.m.
Forward the document
Save to Collection
Last
Next
Scan the QR code with your phone
Copy link
手机扫一扫转发分享
Copy link
Markdown文件
Word document
PDF document
PDF document (print)
share
link
type
password
Update password
Validity period