LLM系列:KVCache及优化方法

1,643 阅读4分钟

KV Cache

如下以GPT2结构第i层推理过程为例,分析KV Cahche计算过程,其中WQiWKiWViWOiWupiWdowniW_Q^i,W_K^i,W_V^i,W_O^i,W_{up}^i,W_{down}^i表示第i层decoder权重矩阵。

全量prompt阶段

第i层输入xix^i([b,s,h]),self-attention块 xKixVixQix_K^i,x_V^i,x_Q^i([b,s,h])

KV Cahche计算:

xKi=xiWKi[b,s,h]x_K^i = x^i \cdot W_K^i([b,s,h]) xVi=xiWVi[b,s,h]x_V^i = x^i \cdot W_V^i([b,s,h])

第i层attention、mlp计算:

xQi=xiWQi[b,s,h]x_Q^i = x^i \cdot W_Q^i([b,s,h])

xouti=softmax(xQixKiTh)xViWOi+xi([b,s,h])x_{out}^i=softmax(\frac{x_Q^i{x_K^i}^T}{\sqrt{h}})\cdot x_V^i \cdot W_O^i + x^i ([b,s,h])

xi+1=fgelu(xoutiWupi)Wdowni+xouti([b,s,h])x^{i+1}=f_{gelu}(x_{out}^i \cdot W_{up}^i) \cdot W_{down}^i + x_{out}^i ([b,s,h])

增量token阶段

增量推理时,当前生成词在第i层表示ti([b,1,h])t^i ([b,1,h]),推理时执行:更新KV Cache和计算第i层输出。

更新KV Cahce:

xKiConcat(xKi,tiWKi)([b,s+1,h])x_K^i \leftarrow Concat(x_K^i, t^i \cdot W_K^i) ([b,s+1, h])
xViConcat(xVi,tiWVi)([b,s+1,h])x_V^i \leftarrow Concat(x_V^i, t^i \cdot W_V^i) ([b,s+1, h])

第i层计算过程:

tQi=tiWQi([b,1,h])t_Q^i = t^i \cdot W_Q^i ([b,1,h]) touti=softmax(tQixKiTh)xViWOi+ti([b,1,h])t_{out}^i = softmax(\frac{t_Q^i{x_K^i}^T}{\sqrt{h}})\cdot x_V^i \cdot W_O^i + t^i ([b,1,h])

ti+1=fgelu(toutiWupi)Wdowni+touti([b,1,h])t^{i+1}=f_{gelu}(t_{out}^i \cdot W_{up}^i) \cdot W_{down}^i + t_{out}^i ([b,1,h])

KV Cache缓存机制如图:

image.png

KV Cache显存分析

KV cache的峰值显存占用大小: b(s+n)hl22=4blh(s+n)b(s+n)∗h∗l∗2∗2=4blh(s+n),输入序列长度s,输出序列长度n,第一个2表示k/v cache,第二个2表示fp16占用2个字节,transformer模型的层数为l,隐藏层维度为h。

以GPT3(175B)为例分析KV Cache与模型参数大小,GPT3模型weight占用350GB(FP16),层数l=96,维度h=12888。

bss+nkv cache(GB)kv cache/weight
4409675.50.22
1640963020.86
64409612083.45

根据上述数据,随着batch增大和长度增大,KV Cahche开销快速增大,甚至超过模型参数本身。LLM的窗口长度不断增大,KV Cahche开销随之不断增大,优化KV Cahche非常必要。

优化KV Cache的必要性:
(1)不断增长的LLM窗口长度,与有限GPU显存资源之间矛盾;
(2)消费级显卡的显存较小,kv cache限制模型的batchsize;
(3)sora/sd3等文生视频、文生图模型,放弃u-net架构,转向支持diffusion transformer架构;kv cahce优化对这类AIGC模型同样能起到加速效果。

KV Cache优化方法

总结典型KV Cache优化手段如下。

1 共用KV Cache(MQA和GQA)

MQA(Multi Query Attention)多查询注意力是MHA多头注意力的变体。两者主要区别是MQA中不同头共享一组KV,每个头只保留查询参数Q。KV矩阵只有一份,大幅减少内存。 由于MQA改变注意力机制结构,模型需要从训练开始就支持MQA,或通过对已训练好的模型微调支持MQA,仅需约5%的原始数据量即可达到不过效果。Falcon、SantaCoder、StarCoder 等模型都采用了MQA机。

# Multi Head Attention
self.Wqkv = nn.Linear(     # Multi-Head Attention 的创建方法
    self.d_model,
    3 * self.d_model,     # Q、K和V 3 个矩阵, 所以是 3 * d_model
    device=device
)
query, key, value = qkv.chunk(3, dim=2)      # 每个 tensor 都是 (1, 512, 768)

# Multi Query Attention
self.Wqkv = nn.Linear(       # Multi-Query Attention 的创建方法
    d_model,
    d_model + 2 * self.head_dim,    # 只创建Q的头向量,所以是 1* d_model, 而K和V不再具备单独的头向量, 所以是 2 * self.head_dim
    device=device,
)
query, key, value = qkv.split(
    [self.d_model, self.head_dim, self.head_dim],    # query -> (1, 512, 768), key   -> (1, 512, 96), value -> (1, 512, 96)
    dim=2
)

GQA (Grouped Query Attention,分组查询注意力),介于MHQ和MQA之间的折中方案。按查询头Q分组,每个组共享一个K和V。表达能力与推理性能兼顾。

MHA、MQA与GQA原理:

image.png

MQA与GQA性能对比:

image.png

GQA既保留了多头注意力的一定表达能力,又通过减少内存访问压力来加速推理速度。

2 窗口优化

当推理文本长度T大于训练最大长度L时,需要滑动窗口:
(1)固定窗口长度(图b)
代表是Longformer,实现简单,空间复杂度只有O(TL),但精度下降比较大。
(2)KV重计算(图c)
每次计算都重新计算长度为的 KV cache,由于重计算的存在,其精度可以保证,但是性能损失比较大。
(3)箭型attention窗口,基本原理和(StreamingLLM)[arxiv.org/pdf/2309.17…

image.png

3 量化与稀疏

通过量化与稀疏压缩 KV cache的显存消耗。

  • 量化方法 主流推理框架都在逐步支持 KV cache 量化,如lmdeploy

  • 稀疏方法 典型稀疏方式:

image.png

(H2O)[browse.arxiv.org/pdf/2306.14…

结果显示,KV cache稀疏到只有原来20%时仍然可以保持很高精度。

image.png

4 存储与计算优化

典型方法是vLLM的PagedAttention。

FlashDecoding 是在FlashAttention基础上对inference的优化,主要分三步:

(1)长文本下将KV分成更小且方便并行的chunk
(2)对每个chunk的KV,Q和他们进行之前一样的FlashAttention获取这个chunk的结果
(3)对每个chunk的结果进行reduce gif图如下:

v2-13fcb10493400523013dcfe55cc9b846_b.gif

StreamingLLM

StreamingLLM:简洁高效的“无限长度”,基本思想来源于上述窗口思想。

参考:
browse.arxiv.org/pdf/2306.14…
arxiv.org/pdf/2309.17…
arxiv.org/pdf/2305.13…
developer.nvidia.com/zh-cn/blog/…
zhuanlan.zhihu.com/p/659770503