KV Cache
如下以GPT2结构第i层推理过程为例,分析KV Cahche计算过程,其中表示第i层decoder权重矩阵。
全量prompt阶段
第i层输入([b,s,h]),self-attention块 ([b,s,h])
KV Cahche计算:
第i层attention、mlp计算:
增量token阶段
增量推理时,当前生成词在第i层表示,推理时执行:更新KV Cache和计算第i层输出。
更新KV Cahce:
第i层计算过程:
KV Cache缓存机制如图:
KV Cache显存分析
KV cache的峰值显存占用大小: ,输入序列长度s,输出序列长度n,第一个2表示k/v cache,第二个2表示fp16占用2个字节,transformer模型的层数为l,隐藏层维度为h。
以GPT3(175B)为例分析KV Cache与模型参数大小,GPT3模型weight占用350GB(FP16),层数l=96,维度h=12888。
| bs | s+n | kv cache(GB) | kv cache/weight |
|---|---|---|---|
| 4 | 4096 | 75.5 | 0.22 |
| 16 | 4096 | 302 | 0.86 |
| 64 | 4096 | 1208 | 3.45 |
根据上述数据,随着batch增大和长度增大,KV Cahche开销快速增大,甚至超过模型参数本身。LLM的窗口长度不断增大,KV Cahche开销随之不断增大,优化KV Cahche非常必要。
优化KV Cache的必要性:
(1)不断增长的LLM窗口长度,与有限GPU显存资源之间矛盾;
(2)消费级显卡的显存较小,kv cache限制模型的batchsize;
(3)sora/sd3等文生视频、文生图模型,放弃u-net架构,转向支持diffusion transformer架构;kv cahce优化对这类AIGC模型同样能起到加速效果。
KV Cache优化方法
总结典型KV Cache优化手段如下。
1 共用KV Cache(MQA和GQA)
MQA(Multi Query Attention)多查询注意力是MHA多头注意力的变体。两者主要区别是MQA中不同头共享一组KV,每个头只保留查询参数Q。KV矩阵只有一份,大幅减少内存。
由于MQA改变注意力机制结构,模型需要从训练开始就支持MQA,或通过对已训练好的模型微调支持MQA,仅需约5%的原始数据量即可达到不过效果。Falcon、SantaCoder、StarCoder 等模型都采用了MQA机。
# Multi Head Attention
self.Wqkv = nn.Linear( # Multi-Head Attention 的创建方法
self.d_model,
3 * self.d_model, # Q、K和V 3 个矩阵, 所以是 3 * d_model
device=device
)
query, key, value = qkv.chunk(3, dim=2) # 每个 tensor 都是 (1, 512, 768)
# Multi Query Attention
self.Wqkv = nn.Linear( # Multi-Query Attention 的创建方法
d_model,
d_model + 2 * self.head_dim, # 只创建Q的头向量,所以是 1* d_model, 而K和V不再具备单独的头向量, 所以是 2 * self.head_dim
device=device,
)
query, key, value = qkv.split(
[self.d_model, self.head_dim, self.head_dim], # query -> (1, 512, 768), key -> (1, 512, 96), value -> (1, 512, 96)
dim=2
)
GQA (Grouped Query Attention,分组查询注意力),介于MHQ和MQA之间的折中方案。按查询头Q分组,每个组共享一个K和V。表达能力与推理性能兼顾。
MHA、MQA与GQA原理:
MQA与GQA性能对比:
GQA既保留了多头注意力的一定表达能力,又通过减少内存访问压力来加速推理速度。
2 窗口优化
当推理文本长度T大于训练最大长度L时,需要滑动窗口:
(1)固定窗口长度(图b)
代表是Longformer,实现简单,空间复杂度只有O(TL),但精度下降比较大。
(2)KV重计算(图c)
每次计算都重新计算长度为的 KV cache,由于重计算的存在,其精度可以保证,但是性能损失比较大。
(3)箭型attention窗口,基本原理和(StreamingLLM)[arxiv.org/pdf/2309.17…
3 量化与稀疏
通过量化与稀疏压缩 KV cache的显存消耗。
-
量化方法 主流推理框架都在逐步支持 KV cache 量化,如lmdeploy
-
稀疏方法 典型稀疏方式:
(H2O)[browse.arxiv.org/pdf/2306.14…
结果显示,KV cache稀疏到只有原来20%时仍然可以保持很高精度。
4 存储与计算优化
典型方法是vLLM的PagedAttention。
FlashDecoding 是在FlashAttention基础上对inference的优化,主要分三步:
(1)长文本下将KV分成更小且方便并行的chunk
(2)对每个chunk的KV,Q和他们进行之前一样的FlashAttention获取这个chunk的结果
(3)对每个chunk的结果进行reduce gif图如下:
StreamingLLM
StreamingLLM:简洁高效的“无限长度”,基本思想来源于上述窗口思想。
参考:
browse.arxiv.org/pdf/2306.14…
arxiv.org/pdf/2309.17…
arxiv.org/pdf/2305.13…
developer.nvidia.com/zh-cn/blog/…
zhuanlan.zhihu.com/p/659770503