分享
Deploy029 GQA MHA MLA
输入“/”快速插入内容
Deploy029 GQA MHA MLA
用户5190
用户5190
用户3543
用户3543
2024年11月18日修改
💡
生成阶段,通过缓存KV Cache可以减少重复计算,但是内存有限,当缓存token长度过长,会导致内存爆掉,因而通过减少KV Cache的方式,减少内存占用。
1.
KV cache
以下引自:
transformer之KV Cache
•
原理是什么
最本质的原理是避免重复计算,将需要重复计算的结果进行缓存,需要缓存的值为历史token对应的KV值,所以叫KV Cache。
•
为什么只需要KV
生成阶段
,输入新的token需要先计算其Q值,然后计算其与历史token K的注意力值,最后与历史token V值进行加权即得到结果,所以只需要缓存历史token的KV值。
•
为什么会存在重复计算
首先,生成式模型每生成一个新token都需要调用整个模型进行一次推理,历史token计算得到的中间激活值在Decoder架构的模型中每次推理时都是一样的,所以可以进行缓存。
这是因为Decoder架构中,当前token只用之前的token计算得到注意力值,通过Causal Mask实现,换句话说,在推理的时候前面已经生成的字符不需要与后面的字符产生attention,从而使得前面已经计算的K和V可以缓存起来。
总的来说,KV Cache是以空间换时间的做法,通过使用快速的缓存存取,减少了重复计算。(注意,只有decoder结构的模型可用,因为有mask attention的存在,使得前面的token可以不用关注后面的token)。
以生成阶段的QK计算过程来看,不同的计算方式计算量的差异,来分析K/V cache对计算量的影响:
,
,
•
方法一
计算量=1xdxd+1xdxd+1xnxd
该方法计算量最小,但是需要缓存历史x信息,如果每次使用index方式索引embedding矩阵替代KV,倒是可以节省缓存,但这种方法效果可能会受损
33%
•
方法二
计算量=dxdxd+1xdxd+1xnxd
该方法计算量最大,但是可以通过
离线计算
方式,这样可极大减小计算量,但是不适用于Rope相对位置编码,只
适用于绝对位置编码
,因为相对编码矩阵无法融入权重矩阵中
33%
•
方法三
计算量=1xdxd+dxnxd+1xnxd
该方法计算量适中,如果配合KV cache缓存
计算量为1xdxd+1xnxd,计算量最小,并且
适用于相对位置编码
33%
但是,用了KV Cache之后也不是立刻万事大吉。
我们简单算一下,对于输入长度为 s ,层数为 L ,hidden size为 d 的模型,需要缓存的参数量为2×L×s×d
如果使用的是半精度浮点数,那么总共所需的空间就是2×2×L×s×d。
以Llama2 7B为例,有 L=32 , d=4096 ,那么每个token所需的缓存空间就是524,288 bytes,约520K,当 s=1024 时,则需要536,870,912 bytes,超过500M的空间。
下面介绍attention优化算法来减少KV缓存空间。
2.
Multi-Query Attention & Grouped-Query Attention
以下引自:
理解Attention:从起源到MHA,MQA和GQA
•
MQA
◦
直接让所有Attention Head共享同一个K、V
◦
MQA直接将KV Cache减少到了原来的1/n
◦
效果方面,目前看来大部分任务的损失都比较有限,且MQA的支持者相信这部分损失可以通过进一步训练来弥补回。
50%
•
GQA
◦
由于使用一套共享的 K、V 效果不好,将同一个group内的Q共享同一套 K、V ,不同group的Q所用的 K、V 不同
◦
GQA使用者:Meta开源的
LLAMA2-70B
,以及
LLAMA3
全系列,
DeepSeek-V1
、
Yi
、
ChatGLM2
/3
50%
3.
Multi-head Latent Attention
以下引自:
大模型KV Cache节省神器MLA学习笔记(包含推理时的矩阵吸收分析)
DeepSeek2 paper
缓存与效果的极限拉扯:从MHA、MQA、GQA到MLA