FlashAttention图解(如何加速Attention)


最新FlashDecoding++

FlashAttention V2和V3版本详解:

Motivation

当输入序列(sequence length)较长时,Transformer的计算过程缓慢且耗费内存,这是因为self-attention的time和memory complexity会随着sequence length的增加成二次增长。

标准Attention的中间结果\mathbf{S}, \mathbf{P}(见下文)通常需要通过高带宽内存(HBM)进行存取,两者所需内存空间复杂度为O(N^2)。本文分析:

  • FlashAttention: 对HBM访问的次数为O(N^2d^2M^{-1})
  • Attention: 对HBM访问的次数为\Omega\left(N d+N^2\right)

往往N \gg d(例如GPT2中N=1024,d=64),因此FlashAttention会快很多。下图展示了两者在GPT-2上的Forward+Backward的GFLOPs、HBM、Runtime对比(A100 GPU):

invalid image (图片无法加载)

GPU中存储单元主要有HBM和SRAM:HBM容量大但是访问速度慢,SRAM容量小却有着较高的访问速度。例如:A100 GPU有40-80GB的HBM,带宽为1.5-2.0TB/s;每108个流式多核处理器各有192KB的片上SRAM,带宽估计约为19TB/s。可以看出,片上的SRAM比HBM快一个数量级,但尺寸要小许多数量级。

综上,FlashAttention目的不是节约FLOPs,而是减少对HBM的访问。重点是FlashAttention在训练和预测过程中的结果和标准Attention一样,对用户是无感的,而其他加速方法做不到这点。

阅读本文需要了解的符号定义:

  • N: sequence length
  • d: head dimension
  • M: the size of SRAM
  • \Omega:大于等于的数量级复杂度
  • O:小于等于的数量级复杂度
  • \Theta:同数量级的复杂度
  • o:小于的数量级复杂度

Method

Attention

标准Attention输入为\mathbf{Q}, \mathbf{K}, \mathbf{V} \in \mathbb{R}^{N \times d},输出为\mathbf{O} \in \mathbb{R}^{N \times d},计算如下:

\mathbf{S}=\mathbf{Q K}^{\top} \in \mathbb{R}^{N \times N}, \quad \mathbf{P}=\operatorname{softmax}(\mathbf{S}) \in \mathbb{R}^{N \times N}, \quad \mathbf{O}=\mathbf{P V} \in \mathbb{R}^{N \times d},

其中softmax操作是row-wise的,即每行都算一次softmax,一共计算N行。

invalid image (图片无法加载)

计算流程图如下:

invalid image (图片无法加载)

FlashAttention

建议先阅读这篇知乎文章,重复内容不再赘述。

invalid image (图片无法加载)

invalid image (图片无法加载)

以GPT2和A100为例:A100的SRAM大小为192KB=196608B;GPT2中N=1024, d=64,对应的\mathbf{Q}, \mathbf{K}, \mathbf{V}的维度为N \times d = 1024 \times 64,中间结果\mathbf{S}, \mathbf{P}的维度为N \times N = 1024 \times 1024。本例中FlashAttention的参数为:

  • B_c = \lceil 196608 / 4 / 64 \rceil = 768;B_r = \min(768, 64) = 64
  • T_c = \lceil 1024 / 768 \rceil = 2;T_r = \lceil 1024 / 64 \rceil = 16

对应的计算过程:

  • 每次外循环(outer loop,j)载入的\mathbf{K}_j, \mathbf{V}_j的大小为B_c \times d = 768 \times d,一共循环T_c = 2次

invalid image (图片无法加载)

  • 每次内循环(inner loop,i)载入的\mathbf{Q}_i, \mathbf{O}_i的大小为B_r \times d = 64 \times d,一共循环T_r = 16次(总次数还需要乘以外循环)

  • \mathbf{S}_{ij} = \mathbf{Q}_i \times \mathbf{K}_j^T,即为(下标表示维度):C_{64 \times 768} = A_{64 \times d} \times B_{d \times 768}。

  • \tilde{\mathbf{P}}_{ij} = \text{row_softmax}(\mathbf{S}_{ij}),\tilde{\mathbf{P}}_{ij}表示和标准attention计算的\mathbf{P}_{ij}有区别,因为\tilde{m}_{ij} = \text{row_max}(\mathbf{S}_{ij})得到的最大值可能不是\mathbf{S}第i行的最大值。\tilde{\mathbf{P}}_{ij}的大小和\mathbf{S}_{ij}一样,都为B_r \times B_c = 64 \times 768。

invalid image (图片无法加载)

  • \tilde{\mathbf{P}}_{ij}和\mathbf{S}_{ij}只是部分结果,如下图所示,外循环j是横向(特征维d)移动的,内循环i是纵向(序列维N)移动的。换句话说,外循环在顺序计算特征,内循环在顺序计算序列。

invalid image (图片无法加载)

  • \mathbf{O}_i的大小为B_r \times d,第二维d是满的(和最终\mathbf{O}一样),这意味着每次外循环都要重新更新当前批次中的特征,即虽然第一次外循环\tilde{P}_{00} \times V_0和第二次外循环\tilde{P}_{01} \times V_1都会得到\mathbf{O}_0,但是第二次\mathbf{O}_0的是基于第一次\mathbf{O}_0重新生成的。

invalid image (图片无法加载)

  • \text{diag}(...) 作用是将vector生成为一个对角矩阵,从而实现相同长度的两个vector进行element-wise相乘。

Theorem 1. FlashAttention的FLOPs为 O(N^2d) ,除了input和output,额外需要的内存为 O(N) 。

  • Theorem 1的证明过程如下。
    影响FLOPs的主要是matrix multiplication。在一次循环中:

  • Algorithm 1第9行:计算\mathbf{Q}_i \mathbf{K}_j^{\top} \in \mathbb{R}^{B_r \times B_c}。由于\mathbf{Q}_i \in \mathbb{R}^{B_r \times d},\mathbf{K}_j \in \mathbb{R}^{B_c \times d},因此一次计算需要的FLOPs为O(B_r B_c d)。

  • Algorithm 1第12行:计算\tilde{\mathbf{P}}_{i j} \mathbf{V}_j \in \mathbb{R}^{B_r \times d}。由于\tilde{\mathbf{P}}_{i j} \in \mathbb{R}^{B_r \times B_c},\mathbf{V}_j \in \mathbb{R}^{B_c \times d},因此一次计算需要的FLOPs为O(B_r B_c d)。

上述计算循环的总次数为T_c T_r=\left\lceil\frac{N}{B_c}\right\rceil\left[\frac{N}{B_r}\right\rceil,因此总的FLOPs为:
O\left(\frac{N^2}{B_c B_r} B_r B_c d\right)=O\left(N^2 d\right)

Theorem 2. 如果SRAM的size M 满足 d \leq M \leq N d 。标准Attention对HBM访问的次数为 \Omega\left(N d+N^2\right) ,而FlashAttention对HBM访问的次数为 O(N^2d^2M^{-1}) 。

  • Theorem 2的证明过程如下。
    需要从HBM读取的数据有:

  • Algorithm 1第6行:每次循环读取的\mathbf{K}_j, \mathbf{V}_j的size复杂度都为\Theta(M),总size为\Theta(Nd)。

  • Algorithm 1第8行:每次循环读取的\mathbf{Q}_i, \mathbf{O}_i的size复杂度都为\Theta(Nd),总次数为T_c=[\frac{N}{B_c}] = \Theta(\frac{Nd}{M})。

FlashAttention对HBM总访问次数的复杂度为:
\Theta(Nd+NdT_c) = \Theta(NdTc) = \Theta(N^2d^2M^{-1})


yg9538 2023年11月24日 20:07 收藏文档