来源:互联网 更新时间:2026-06-02 07:46
很多人第一次跑本地大模型,会以为显存主要被模型参数吃掉。这当然没错——一个 7B 模型即使用 FP16,也要十几 GB 级别的参数显存。
但进入真实推理后,你会发现另一个东西涨得很快。prompt 越长,KV Cache 越大;batch 越大,KV Cache 越大;上下文窗口越长,KV Cache 越大;并发请求越多,KV Cache 越难管理。
模型参数是加载时就基本固定的,而 KV Cache 是生成过程中随着请求、长度和 batch 持续增长的。这也是为什么服务端推理框架会认真做 KV Cache 管理——vLLM 的 PagedAttention、Hugging Face 的 DynamicCache/StaticCache/QuantizedCache,本质上都在处理同一类问题:怎么让历史 K/V 既能被快速读取,又不要把显存撑爆。
GQA 正好站在这个问题的中心。
一句话概括:GQA 通过减少 KV Head 的数量,直接压缩了 KV Cache 的体积。
Decoder-only 大模型推理时,一般分成两个阶段:
| 阶段 | 输入 | 主要动作 |
|---|---|---|
| Prefill | 完整 prompt | 一次性计算 prompt 的每层 K/V,并写入 cache |
| Decode | 当前新 token | 只算新 token 的 Q/K/V,用新 Q 查询历史 K/V |
Hugging Face 的缓存文档也强调过:自回归生成是一个 token 一个 token 往后预测,KV Cache 会保存过去 token 在注意力层里的 K/V,后续 token 可以复用它们,避免重复计算。
上一篇文章里,我们用的 MHA 张量形状是:
q.shape == [batch, num_heads, seq_len, head_dim]
k.shape == [batch, num_heads, seq_len, head_dim]
v.shape == [batch, num_heads, seq_len, head_dim]
每一层要缓存的是历史 token 的 k 和 v:
past_k.shape == [batch, num_heads, past_len, head_dim]
past_v.shape == [batch, num_heads, past_len, head_dim]
注意,这里缓存的是每一层的 K/V。一个 32 层模型,就有 32 份这样的缓存。所以 KV Cache 的显存可以粗略估算为:
KV Cache bytes = batch_size * seq_len * num_layers * 2 * num_kv_heads * head_dim * bytes_per_element
这里的 2 表示 K 和 V 两份。公式里最容易被忽略的是 num_kv_heads。
在 MHA 里:num_kv_heads = num_query_heads。而在 GQA 里:num_kv_heads < num_query_heads。这就是 GQA 能省显存的入口。
假设有一个简化配置:
batch_size = 1
seq_len = 8192
num_layers = 32
num_query_heads = 32
head_dim = 128
dtype = fp16 # 2 bytes
如果是传统 MHA:num_kv_heads = 32,KV Cache 大约是:
1 * 8192 * 32 * 2 * 32 * 128 * 2 bytes = 4 GiB
如果换成 GQA,假设 num_kv_heads = 8,KV Cache 大约是:
1 * 8192 * 32 * 2 * 8 * 128 * 2 bytes = 1 GiB
同样的 Query Head 数量,同样的上下文长度,只是把 KV Head 从 32 降到 8,缓存就变成原来的四分之一。
如果是 MQA:num_kv_heads = 1,KV Cache 会进一步降到:
128 MiB
这只是一个教学估算,真实框架还要受到 allocator、block size、padding、并发调度、量化和 kernel 实现的影响。但作为面试和工程理解,这个公式足够抓住核心。
可以用一张表先记住:
| 结构 | Query Head | KV Head | 直觉 |
|---|---|---|---|
| MHA | 多个 | 和 Query 一样多 | 每个 Q head 独享一组 K/V |
| MQA | 多个 | 1 个 | 所有 Q head 共享同一组 K/V |
| GQA | 多个 | 介于 1 和 Query Head 之间 | 一组 Q head 共享一组 K/V |
假设:num_query_heads = 32,num_kv_heads = 8,则 group_size = num_query_heads // num_kv_heads = 4。
那么 GQA 的意思是:前 4 个 Q head(0123)共享一个 KV Head,接下来 4 个(4567)共享下一个,以此类推。它不像 MQA 那样把所有 Query Head 都压到同一个 KV Head 上,也不像 MHA 那样每个 Query Head 都保留独立 K/V。
GQA 原论文的动机也在这里:MQA 可以显著提升 decoder 推理速度,但可能带来质量下降;GQA 使用介于 1 和 Query Head 数之间的 KV Head 数量,在效果和推理效率之间做折中。
MHA 的投影通常是:
q_proj: hidden_dim -> num_q_heads * head_dim
k_proj: hidden_dim -> num_q_heads * head_dim
v_proj: hidden_dim -> num_q_heads * head_dim
GQA 的投影变成:
q_proj: hidden_dim -> num_q_heads * head_dim
k_proj: hidden_dim -> num_kv_heads * head_dim
v_proj: hidden_dim -> num_kv_heads * head_dim
也就是说,Q 还是很多头,K/V 变少了。
假设:batch = 2,seq_len = 5,num_q_heads = 32,num_kv_heads = 8,head_dim = 128,那么:
q.shape == [2, 32, 5, 128]
k.shape == [2, 8, 5, 128]
v.shape == [2, 8, 5, 128]
但 attention 计算时,q @ k.transpose(-2, -1) 要求 head 维度能对齐。一个教学版做法是把 K/V 按组展开:
k_expanded.shape == [2, 32, 5, 128]
v_expanded.shape == [2, 32, 5, 128]
PyTorch 的 scaled_dot_product_attention(enable_gqa=True) 文档里也展示了类似逻辑:启用 GQA 时,会按 Query Head 和 KV Head 的比例对 key/value 做 repeat_interlea ve。但要注意,真实高性能实现不一定真的物理复制 K/V。服务端推理更关心 cache 布局、访存和 kernel 的实现方式。
下面这份代码只保留核心逻辑,适合面试讲法:
import math
import torch
from torch import nn
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
# x: [B, H_kv, T, D]
if n_rep == 1:
return x
batch, num_kv_heads, seq_len, head_dim = x.shape
x = x[:, :, None, :, :]
x = x.expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
return x.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)
class GroupedQueryAttention(nn.Module):
def __init__(self,
hidden_dim: int,
num_q_heads: int,
num_kv_heads: int,
dropout: float = 0.0,
):
super().__init__()
assert hidden_dim % num_q_heads == 0
assert num_q_heads % num_kv_heads == 0
self.hidden_dim = hidden_dim
self.num_q_heads = num_q_heads
self.num_kv_heads = num_kv_heads
self.head_dim = hidden_dim // num_q_heads
self.num_groups = num_q_heads // num_kv_heads
self.q_proj = nn.Linear(hidden_dim, num_q_heads * self.head_dim)
self.k_proj = nn.Linear(hidden_dim, num_kv_heads * self.head_dim)
self.v_proj = nn.Linear(hidden_dim, num_kv_heads * self.head_dim)
self.o_proj = nn.Linear(num_q_heads * self.head_dim, hidden_dim)
self.dropout = nn.Dropout(dropout)
def _split_heads(self, x: torch.Tensor, num_heads: int) -> torch.Tensor:
batch, seq_len, _ = x.shape
x = x.view(batch, seq_len, num_heads, self.head_dim)
return x.transpose(1, 2) # [B, H, T, D]
def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
batch, heads, seq_len, head_dim = x.shape
x = x.transpose(1, 2).contiguous()
return x.view(batch, seq_len, heads * head_dim)
def forward(self,
x: torch.Tensor,
attn_mask: torch.Tensor | None = None,
past_key_value: tuple[torch.Tensor, torch.Tensor] | None = None,
use_cache: bool = False,
):
q = self._split_heads(self.q_proj(x), self.num_q_heads)
k = self._split_heads(self.k_proj(x), self.num_kv_heads)
v = self._split_heads(self.v_proj(x), self.num_kv_heads)
if past_key_value is not None:
past_k, past_v = past_key_value
k = torch.cat([past_k, k], dim=2)
v = torch.cat([past_v, v], dim=2)
present_key_value = (k, v) if use_cache else None
k_for_attn = repeat_kv(k, self.num_groups)
v_for_attn = repeat_kv(v, self.num_groups)
scores = q @ k_for_attn.transpose(-2, -1)
scores = scores / math.sqrt(self.head_dim)
if attn_mask is not None:
scores = scores.masked_fill(attn_mask, float("-inf"))
weights = torch.softmax(scores, dim=-1)
weights = self.dropout(weights)
out = weights @ v_for_attn
out = self._merge_heads(out)
out = self.o_proj(out)
return out, weights, present_key_value
测试一下形状:
x = torch.randn(2, 5, 4096)
gqa = GroupedQueryAttention(
hidden_dim=4096,
num_q_heads=32,
num_kv_heads=8,
)
out, weights, cache = gqa(x, use_cache=True)
print(out.shape) # [2, 5, 4096]
print(weights.shape) # [2, 32, 5, 5]
print(cache[0].shape) # [2, 8, 5, 128]
print(cache[1].shape) # [2, 8, 5, 128]
关键点在最后两行。注意力权重仍然是 32 个 Query Head:weights.shape == [2, 32, 5, 5],但缓存里只有 8 个 KV Head:cache[0].shape == [2, 8, 5, 128],cache[1].shape == [2, 8, 5, 128]。这就是 GQA 在 KV Cache 上省显存的直接体现。
PyTorch 的 torch.nn.functional.scaled_dot_product_attention 已经有 enable_gqa 参数。
一个最小示例:
import torch
import torch.nn.functional as F
query = torch.randn(2, 32, 5, 128, device="cuda", dtype=torch.float16)
key = torch.randn(2, 8, 5, 128, device="cuda", dtype=torch.float16)
value = torch.randn(2, 8, 5, 128, device="cuda", dtype=torch.float16)
out = F.scaled_dot_product_attention(
query, key, value,
is_causal=True,
enable_gqa=True,
)
print(out.shape) # [2, 32, 5, 128]
官方文档里有两个约束很重要:
number_of_heads_query % number_of_heads_key_value == 0
number_of_heads_key == number_of_heads_value
也就是说:
enable_gqa 目前仍是实验特性,后端支持和张量类型有限制。还有一个容易踩坑的点:PyTorch 这个函数里的布尔 attn_mask 语义,和一些 MHA 接口的 padding mask 语义相反。scaled_dot_product_attention 里 True 表示参与 attention,迁移代码时要小心。
如果只做一次完整 forward,而且不使用 KV Cache,GQA 对峰值显存的影响没有 KV Cache 场景那么直观。
真正的收益集中在自回归 decode:
每一步都要读历史 K/V
历史越长,读得越多
并发越高,cache 越多
KV Head 越少,cache 越小
Hugging Face 的优化文档也提到,减少 KV 向量数量只有在使用 KV Cache 的自回归解码场景里才特别有意义,因为 decode 阶段会反复读取历史 K/V,内存带宽很容易成为瓶颈。
所以可以这样理解:
| 场景 | GQA 价值 |
|---|---|
| 训练全序列并行 | 不是主要优化目标 |
| Prefill | 可以减少写入 cache 的 K/V 体积 |
| Decode | 最关键,减少每步读取的历史 K/V |
| 长上下文服务 | 价值更明显 |
| 高并发服务 | 价值更明显 |
这也是为什么讲 GQA 时,不能只画 attention 公式,要把它放回推理服务的 KV Cache 场景里看。
GQA 解决的是:单个 token 的 K/V 体积更小。PagedAttention 解决的是:大量 token 的 K/V 如何更高效地组织和管理。二者不是同一层优化,但会一起影响推理效率。
vLLM 的 PagedAttention 文档里提到,key/value cache 会被拆成 block,每个 block 存固定数量 token 的 cache。这样做的目标是用更适合服务端调度的方式管理 KV Cache,而不是把每个请求都当成一大段连续显存。
可以把它们放到同一张图里:
GQA:减少每个 token 的 KV 体积
PagedAttention:管理很多 token 的 KV 存放方式
Quantized Cache:降低每个元素的字节数
Offloaded Cache:把部分 cache 放到 CPU
如果只看单次模型结构,GQA 像是 attention 结构变化。如果从推理系统看,GQA 是 KV Cache 成本控制的一环。
num_kv_heads,忘了改投影层输出维度GQA 里 Q/K/V 的 projection 输出维度不一样:
q_proj -> num_q_heads * head_dim
k_proj -> num_kv_heads * head_dim
v_proj -> num_kv_heads * head_dim
如果还把 K/V 投影到 num_q_heads * head_dim,cache 就没有省下来。
num_q_heads 不能整除 num_kv_headsGQA 要按组共享 K/V,所以通常要求:num_q_heads % num_kv_heads == 0,否则每组 Query Head 没法均匀映射到 KV Head。
教学代码为了看懂,会在 attention 前做 repeat_kv。但 cache 里应该保留较少的 KV Head:cache_k.shape == [B, H_kv, T, D]。如果把展开后的 K/V 存进去:cache_k.shape == [B, H_q, T, D],显存又回到 MHA 级别了。
KV Cache 不只是占显存。Decode 每一步都要读取历史 K/V,所以内存带宽也会成为瓶颈。GQA 的价值不只是少存,也包括少读。
GQA 是效果和效率的折中。GQA 原论文的结论是,GQA 相比 MQA 更能保留 MHA 的质量,同时接近 MQA 的速度收益。但具体效果仍然取决于模型、训练方式、上采样策略和任务。工程上不要把结构变化理解成“免费优化”——它通常是在模型设计或训练阶段就确定好的。
如果面试官问:“GQA 和 MHA 有什么区别?”
可以这样回答:GQA 的核心差异在于,Query Head 的数量多于 Key/Value Head 的数量,多个 Query Head 会共享同一组 K/V。而 MHA 里每个 Query Head 都有独立的 K/V。
如果继续问:“为什么能省显存?”
可以接:因为 KV Cache 的大小与 num_kv_heads 直接成正比。在相同的 Query Head 数量和序列长度下,GQA 只需要缓存更少的 K/V Head,所以显存占用更小。
如果问:“GQA、MQA 怎么区分?”
可以答:MQA 是所有 Query Head 共享一个 KV Head,极端省显存但可能损失效果;GQA 是折中方案,将 Query Head 分成若干组,每组共享一个 KV Head。
如果问:“代码里最容易错在哪里?”
可以答:最容易错的是投影层的输出维度改错,以及 cache 里意外存了展开后的 K/V。核心约束是 num_q_heads % num_kv_heads == 0。
| 问题 | 关键回答 |
|---|---|
| GQA 改了什么? | Query Head 多,KV Head 少 |
| 为什么能省显存? | KV Cache 大小和 num_kv_heads 成正比 |
| MHA 的 KV Head 数? | 通常等于 Query Head 数 |
| MQA 的 KV Head 数? | 1 个 |
| GQA 的 KV Head 数? | 介于 1 和 Query Head 数之间 |
| 代码核心约束? | num_q_heads % num_kv_heads == 0 |
| cache 里存什么? | 未展开的 K/V,形状是 [B, H_kv, T, D] |
| attention 前做什么? | 把 K/V 按组映射到 Query Head |
| 最适合讲的场景? | 长上下文、自回归 decode、高并发推理 |
| PyTorch 接口? | scaled_dot_product_attention(..., enable_gqa=True) |
GQA 可以用三句话记住:
所以,学 GQA 不要只记住一个缩写。真正要记住的是这条线:
MHA 张量形状 -> KV Cache 显存公式 -> KV Head 数量 -> Decode 访存压力 -> GQA
这条线讲清楚了,GQA、MQA、KV Cache、长上下文推理优化,就能串起来。
torch.nn.functional.scaled_dot_product_attention下饭影视APP下载安装指南
灵宝派对手游下载安装地址推荐
和平精英如何做到压枪稳-和平精英怎样才能压枪稳
下载浏览器app下载安装选择推荐
初中英语同步课文跟读APP推荐|免费下载高口碑跟读软件排行榜
4D采矿者官网在哪下载 最新官方下载安装地址
阅读app安卓版下载推荐
免费影视剧APP推荐
碎片人偶Vragmeet官网在哪下载 最新官方下载安装地址
儿子穿新中式现身大会堂 马斯克罕见用中文回应:他正在学习普通话
Elysium Above 履云录官网在哪下载 最新官方下载安装地址
好用的手环阅读app下载安装
名单曝光!库克、马斯克等将随团到访中国 黄仁勋不在其中
人声接近真人!OpenAI一口气更新三款超强语音AI
短视频软件推荐
短剧《情绪超市》剧情介绍
苹果macOS 27将优化界面设计并测试AI驱动的Safari标签页自动分组功能
免费看电影的软件推荐
售价约3200元!暴力熊推出预开盖版Ultra 7 270K Plus:支持直触芯片散热方案
《梦幻西游》出道人金价走势解析-云游道人影响解析
手机号码测吉凶
本站所有软件,都由网友上传,如有侵犯你的版权,请发邮件haolingcc@hotmail.com 联系删除。 版权所有 Copyright@2012-2013 haoling.cc