LLM 推理全过程维度变化与核心公式

把 Llama 3 风格的 dense decoder-only 模型从 embedding 到采样的整条路径一次写清楚,顺便把常见变体(MHA / MQA / GQA / MLA、RoPE / ALiBi、RMSNorm / LayerNorm、SwiGLU / GeGLU、Flash Attention / Paged Attention)的数学位置都嵌进去 — 读完之后自己能把这张图默写出来。整篇围绕一个事实展开:这套结构二十年没怎么变过,所有”推理优化”都是在同一张骨架的某个位置做局部手术。

参数符号约定 — Llama 3 8B 为基准

全篇沿用同一套符号,举例用 Llama 3 8B:

符号含义示例值(Llama 3 8B)
BBbatch size2
SSprompt 长度10
LL层数32
HHhidden dim4096
VV词表大小128256
nqn_qQ 头数32
nkvn_{kv}KV 头数(GQA)8
dd每头维度 =H/nq= H/n_q128
IIFFN 中间维度14336
TT生成 token 数100
ttdecode 当前步数1..T1..T

所有 shape 标注都按 PyTorch 习惯写成 [B, ..., H];权重矩阵按”输入维 × 输出维”约定写成 WR[in,out]W \in \mathbb{R}^{[\text{in}, \text{out}]}

核心公式速查 — Embedding · Norm · Attn · FFN · LM Head

Embedding

xi=E[token_idi]RH\mathbf{x}_i = E[\text{token\_id}_i] \in \mathbb{R}^{H}

EE 的 shape 是 [V,H][V, H],本质是一张查找表。许多实现把 EE 与 LM Head 的 WlmW_{\text{lm}} 共享(tied embedding),省显存也略微正则;Llama 系列默认不共享

归一化:LayerNorm vs RMSNorm

标准 LayerNorm:

LN(x)=xμσ2+ϵγ+β,μ=1Hxi, σ2=1H(xiμ)2\text{LN}(\mathbf{x}) = \frac{\mathbf{x} - \mu}{\sqrt{\sigma^{2} + \epsilon}} \odot \boldsymbol{\gamma} + \boldsymbol{\beta}, \quad \mu = \tfrac{1}{H}\sum x_i,\ \sigma^{2} = \tfrac{1}{H}\sum (x_i - \mu)^{2}

RMSNorm(Llama / Mistral / Qwen 主流选择):

RMSNorm(x)=x1Hi=1Hxi2+ϵγ\text{RMSNorm}(\mathbf{x}) = \frac{\mathbf{x}}{\sqrt{\frac{1}{H}\sum_{i=1}^{H} x_i^{2} + \epsilon}} \odot \boldsymbol{\gamma}

RMSNorm 省了均值、也省了 β\boldsymbol{\beta},计算量和参数都约减半;实证发现对质量几乎无损。所有主流推理引擎都按 pre-norm 组织:norm 在残差分支内部,残差主干不经过 norm。

Q/K/V 投影 + 位置编码

Q=XWQ,K=XWK,V=XWVQ = X W_Q, \quad K = X W_K, \quad V = X W_V

RoPE(Rotary Positional Embedding)把位置 mm 的旋转矩阵乘在每两维一组的 Q、K 上:

RoPE(qm,m)=(cosmθksinmθksinmθkcosmθk)(q2kq2k+1),θk=base2k/d\text{RoPE}(\mathbf{q}_m, m) = \begin{pmatrix} \cos m\theta_k & -\sin m\theta_k \\ \sin m\theta_k & \phantom{-}\cos m\theta_k \end{pmatrix} \begin{pmatrix} q_{2k} \\ q_{2k+1} \end{pmatrix}, \quad \theta_k = \text{base}^{-2k/d}

base\text{base} 通常取 10000;长上下文模型(Llama 3.1 128K、Qwen2.5-1M)往往用 YaRN / NTK-aware scalingbase\text{base} 动态拉大或对 θk\theta_k 做频段分层缩放。RoPE 只作用于 Q、K,不作用于 V — V 是被加权的值,不需要位置信息。

ALiBi(BLOOM、MPT 用)走的是另一条路:不在嵌入里加位置,而是在 attention 的分数上加一个线性偏置:

scoresij=qikjdmh(ij)\text{scores}_{ij} = \frac{\mathbf{q}_i^{\top}\mathbf{k}_j}{\sqrt{d}} - m_h \cdot (i - j)

mhm_h 是每头一个的斜率常数。ALiBi 的好处是天然外推,但上限不如 RoPE+YaRN 漂亮,所以 2024 年之后基本没人用了

Scaled Dot-Product Attention

Attention(Q,K,V)=softmax ⁣(QKd+M)V\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d}} + M\right) V

MM 是 causal mask,上三角为 -\infty,保证位置 ii 只能看 i\le i 的位置。除 d\sqrt{d} 是为了抑制 softmax 进入梯度近乎为零的区域。

多头变体:MHA / MQA / GQA / MLA

变体nkvn_{kv}KV Cache 节省代表模型
MHA=nq= n_qGPT-2/3, Llama 1/2 7B
GQA分组 <nq< n_qnq/nkvn_q / n_{kv}Llama 3, Qwen2, Mistral
MQA=1= 1nqn_qPaLM, Falcon
MLA低秩压缩nq\sim n_qDeepSeek V2/V3

GQA 的数学形式是:nkvn_{kv} 组 K、V,每组被 nq/nkvn_q/n_{kv} 个 Q 头共享;在核函数里通常用广播实现,不真的复制张量

MLA(Multi-Head Latent Attention)则用低秩压缩把 K、V 共同压成潜向量 cKV\mathbf{c}^{KV}

ciKV=WDKVxi,kiC=WUKciKV,viC=WUVciKV\mathbf{c}^{KV}_i = W^{DKV} \mathbf{x}_i, \quad \mathbf{k}^{C}_i = W^{UK}\mathbf{c}^{KV}_i, \quad \mathbf{v}^{C}_i = W^{UV}\mathbf{c}^{KV}_i

只需缓存 cKV\mathbf{c}^{KV}(维度 dcd_c,通常 512)而不是所有头的 K、V。代价是 attention 计算变复杂(RoPE 要单独走 kR\mathbf{k}^{R} 分支),但单 token cache 从 GQA 的几 KB 进一步压到数百字节。

Output 投影 + 残差

h=x+Attn(Q,K,V)WO\mathbf{h} = \mathbf{x} + \text{Attn}(Q', K', V') \cdot W_O

注意残差加的是 attention 子层输入 x\mathbf{x}pre-norm 之前的值),不是 norm 之后的。

FFN 变体

经典双线性 FFN(GPT-2):

FFN(x)=ϕ(xW1)W2\text{FFN}(\mathbf{x}) = \phi(\mathbf{x} W_1) W_2

ϕ\phi 通常取 GeLU。GeLU 的近似形式:

GeLU(x)0.5x(1+tanh ⁣[2π(x+0.044715x3)])\text{GeLU}(x) \approx 0.5 x \left(1 + \tanh\!\left[\sqrt{\tfrac{2}{\pi}}\left(x + 0.044715 x^{3}\right)\right]\right)

GLU 家族(Llama、PaLM、Mistral 都在用):

GLU(x)=(ϕ(xWgate)(xWup))Wdown\text{GLU}(\mathbf{x}) = \big(\phi(\mathbf{x} W_{\text{gate}}) \odot (\mathbf{x} W_{\text{up}})\big) W_{\text{down}}

其中 SiLU 的定义是 SiLU(z)=zσ(z)\text{SiLU}(z) = z \cdot \sigma(z)。SwiGLU 比经典 GeLU-FFN 多一个投影(三矩阵 vs 两矩阵),为了参数预算对齐,实现里一般把 II 设成 234H\tfrac{2}{3} \cdot 4H4H4H 是 GPT-2 的惯例),Llama 3 8B 的 I=1433623440961.3I = 14336 \approx \tfrac{2}{3}\cdot 4\cdot 4096 \cdot 1.3

残差

xout=h+FFN(RMSNorm(h))\mathbf{x}_{\text{out}} = \mathbf{h} + \text{FFN}(\text{RMSNorm}(\mathbf{h}))

LM Head + 采样

logits=xfinalWlm,p=softmax(logits/T)\text{logits} = \mathbf{x}_{\text{final}} W_{\text{lm}}, \quad \mathbf{p} = \text{softmax}(\text{logits}/T)

TT 是 temperature。采样前通常叠几层 logits 变换:

重复惩罚(repetition / frequency / presence penalty):

logitsv=logitsvα1[vhistory]βcount(v)\text{logits}'_v = \text{logits}_v - \alpha \cdot \mathbb{1}[v \in \text{history}] - \beta \cdot \text{count}(v)

Top-k:只保留最大 kk 个 logits,其它置 -\infty

Top-p(nucleus):按概率降序累加,保留累积概率 p\le p 的集合。

Min-p:保留 pvpminpmaxp_v \ge p_{\min} \cdot p_{\max} 的集合,对低熵分布更友好。

Typical-p:基于与条件熵的偏差做截断,保留 logpvH(p)|-\log p_v - H(\mathbf{p})| 小的集合。

所有截断都作用在概率分布本身之前/之后,不改变公式的骨架

Prefill 阶段维度流转 — S 个 token 一次性走一遍

输入:input_ids [B, S] = [2, 10]

input_ids [2, 10]
Embedding · 查表 E[V, H]
x [2, 10, 4096][B, S, H]
进入 Layer 1 · 重复 32 层
RMSNorm
x_norm [2, 10, 4096]
Q/K/V 投影
Q_flat [2, 10, 4096]H = n_q · dK_flat [2, 10, 1024]V_flat [2, 10, 1024]n_kv · d · GQA 窄 4 倍
reshape + transpose 成多头
Q [2, 32, 10, 128][B, n_q, S, d]K [2, 8, 10, 128][B, n_kv, S, d]V [2, 8, 10, 128]
RoPE · 作用于 Q, K
Q, K shape 不变
写入 KV Cache
KV_Cache[layer][:, :, 0:10, :] = K, V
Attention · Q @ KT / √d · GQA 广播 K/V 到 n_q 头
scores [2, 32, 10, 10][B, n_q, S, S]
+ causal mask · softmax
attn_weights [2, 32, 10, 10]
@ V
attn_out [2, 32, 10, 128]
transpose + reshape
[2, 10, 4096]
× W_O
[2, 10, 4096]
+ 残差
h [2, 10, 4096]
RMSNorm
FFN · W_gate / W_up 升维 · SwiGLU 门控 · W_down 降维
中间张量 [2, 10, 14336]
+ 残差
x_out [2, 10, 4096]
退出第 32 层 · shape 始终 [B, S, H]
Final RMSNorm
[2, 10, 4096]
只取最后一个位置 · 其他位置不进 LM Head
[2, 4096]
× W_lm [H, V]
logits [2, 128256]
temperature / top-p / 采样
next_token [2, 1]第一个输出 token

Prefill 结束后,KV Cache 状态:每层已填入前 10 个位置。

Decode 阶段维度流转 — 第 t 步只走 1 个 token

前置状态:已有 cache_len=S+t1\text{cache\_len} = S + t - 1 个位置。

输入:input_ids [B, 1] = [2, 1](上一步生成的 1 个 token)。

input_ids [2, 1]
Embedding
x [2, 1, 4096][B, 1, H] · 注意 S = 1
进入 Layer l · l = 1..32
RMSNorm
[2, 1, 4096]
Q/K/V 投影 · 只算 1 个新 token
Q_new [2, 32, 1, 128]K_new [2, 8, 1, 128]V_new [2, 8, 1, 128]
RoPE · 位置 = cache_len
shape 不变
写入 KV Cache 的下一个位置
KV_Cache[:, :, cache_len, :] = K_new, V_newcache_len += 1
从 cache 读全部历史 K, V
K_full [2, 8, cache_len, 128]V_full [2, 8, cache_len, 128]
Attention · Q_new @ K_fullT / √d · GQA 广播
scores [2, 32, 1, cache_len]Q 侧只有 1 行
softmax · 无需 causal mask · 天然只看历史
attn_weights [2, 32, 1, cache_len]
@ V_full
attn_out [2, 32, 1, 128]cache_len 维被求和消掉
reshape + W_O
[2, 1, 4096]
+ 残差
[2, 1, 4096]
RMSNorm
FFN · 只处理 1 个 token
中间张量 [2, 1, 14336]
+ 残差
[2, 1, 4096]
退出第 32 层 · shape 始终 [B, 1, H]
Final RMSNorm
[2, 1, 4096]
× W_lm
logits [2, 1, 128256] → [2, 128256]
采样
next_token [2, 1]
循环 · 作为下一步的输入

Prefill vs Decode 维度对照 — GEMM vs GEMV · 算力 vs 带宽

位置PrefillDecode 每步
input_ids[B,S][B, S][B,1][B, 1]
embedding 后[B,S,H][B, S, H][B,1,H][B, 1, H]
Q[B,nq,S,d][B, n_q, S, d][B,nq,1,d][B, n_q, 1, d]
K_new / V_new[B,nkv,S,d][B, n_{kv}, S, d][B,nkv,1,d][B, n_{kv}, 1, d]
K_full / V_full(从 cache)同 K_new[B,nkv,cache_len,d][B, n_{kv}, \text{cache\_len}, d]
attention scores[B,nq,S,S][B, n_q, S, S][B,nq,1,cache_len][B, n_q, 1, \text{cache\_len}]
attention 输出[B,nq,S,d][B, n_q, S, d][B,nq,1,d][B, n_q, 1, d]
FFN 中间[B,S,I][B, S, I][B,1,I][B, 1, I]
logits[B,V][B, V](取最后位置)[B,V][B, V]
运算性质GEMM(矩 × 矩)GEMV(矩 × 量)
瓶颈算力内存带宽

这张表是理解所有推理加速工作的起点:prefill 像训练的 forward,compute-bound;decode 是一串 GEMV,memory-bound,绝大部分时间在往 SM 里搬权重。两段的优化方向天差地别

KV Cache 的形状与增长 — 单 token 几 KB · 长 context 几百 MB

每层一对 cache:

Kcache,VcacheRB×nkv×Smax×dK_{\text{cache}}, V_{\text{cache}} \in \mathbb{R}^{B \times n_{kv} \times S_{\max} \times d}

每个 token、每层的 cache 大小(fp16):

2×nkv×d×2 bytes=2×8×128×2=4 KB2 \times n_{kv} \times d \times 2\ \text{bytes} = 2 \times 8 \times 128 \times 2 = 4\ \text{KB}

每个 token、全模型(32 层):4 KB×32=128 KB/token4\ \text{KB} \times 32 = 128\ \text{KB} / \text{token}。一个 4096-token 请求:128 KB×4096512 MB128\ \text{KB} \times 4096 \approx 512\ \text{MB}

几种工程优化:

每步的计算量与内存开销 — Llama 3 8B fp16 · H100 拐点 ~330 FLOPs/byte

前面的维度图只说 shape,不说量级。推理优化 90% 的讨论都在算”这一步要花多少 FLOPs、搬多少字节”,所以这里直接把每步的开销拍成表。

以 Llama 3 8B、fp16、B=1B=1 为基准,prefill 取 S=2048S=2048,decode 取 cache_len=2048\text{cache\_len}=2048(即生成到第 2048 个 token 时的某一步)。

参考硬件拐点:H100 SXM fp16 理论算力 ~989 TFLOPs,HBM 带宽 ~3 TB/s,roofline 拐点 AI330 FLOPs/byte\text{AI}^{*} \approx 330\ \text{FLOPs/byte}。高于它是 compute-bound,低于它是 memory-bound

权重分布

组件Shapefp16 大小全模型(×32 层)
Embedding EE[V,H][V, H]1.0 GB1.0 GB
WQW_Q[H,H][H, H]32 MB1.0 GB
WKW_K[H,nkvd][H, n_{kv}d]8 MB256 MB
WVW_V[H,nkvd][H, n_{kv}d]8 MB256 MB
WOW_O[H,H][H, H]32 MB1.0 GB
WgateW_{\text{gate}}[H,I][H, I]117 MB3.7 GB
WupW_{\text{up}}[H,I][H, I]117 MB3.7 GB
WdownW_{\text{down}}[I,H][I, H]117 MB3.7 GB
RMSNorm γ\boldsymbol{\gamma}(每层 2 份)[H]×2[H]\times 216 KB500 KB
LM head WlmW_{\text{lm}}[H,V][H, V]1.0 GB1.0 GB
合计~432 MB / 层~16 GB

全模型 fp16 权重约 16 GB,每次 forward 的”底价”就是把这 16 GB 从 HBM 里扫一遍。H100 @ 3 TB/s 下 =16/30005.3 ms= 16/3000 \approx 5.3\ \text{ms}这就是单请求 decode 的物理下限

每层每步的计算 / 内存读写

对同一层在 prefill(N=SN=S)和 decode(N=1N=1)下的各个子步做对照。“Weight HBM” 是要从显存搬的权重字节,“KV HBM” 是要读/写的 KV Cache 字节。中间 activation 默认被 kernel 融合,不单独算。

步骤Prefill FLOPs(S=2048)Decode FLOPs(S=1)Weight HBMKV HBM
RMSNorm5BSH5BSH ≈ 42 MF20 KFγ\boldsymbol{\gamma} 8 KB
QprojQ_{\text{proj}}2BSH22BSH^{2} ≈ 68.7 GF33.5 MFWQW_Q 32 MB
KprojK_{\text{proj}}(+写 cache)2BSHnkvd2BSH \cdot n_{kv}d ≈ 17.2 GF8.4 MFWKW_K 8 MBW 4 MB / 2 KB
VprojV_{\text{proj}}(+写 cache)17.2 GF8.4 MFWVW_V 8 MBW 4 MB / 2 KB
RoPE~50 MF25 KF
Attn QKQK^{\top}2BnqNLkd2B n_q N L_k d ≈ 34.4 GF16.8 MFR 4 MB(decode)
softmax~700 MF260 KF
Attn V\cdot V34.4 GF16.8 MFR 4 MB(decode)
WOW_O2BSH22BSH^{2} ≈ 68.7 GF33.5 MFWOW_O 32 MB
RMSNorm42 MF20 KFγ\boldsymbol{\gamma} 8 KB
WgateW_{\text{gate}}2BSHI2BSHI ≈ 241 GF117 MFWgateW_{\text{gate}} 117 MB
WupW_{\text{up}}241 GF117 MFWupW_{\text{up}} 117 MB
SiLU + gate~90 MF45 KF
WdownW_{\text{down}}2BSIH2BSIH ≈ 241 GF117 MFWdownW_{\text{down}} 117 MB
每层合计~960 GFLOPs~470 MFLOPs~432 MBW 8 MB(P)/ R 8 MB(D)

几个直接结论:

全模型一次 forward

把 32 层 + embedding + LM head 加起来:

阶段FLOPsHBM I/OArithmetic Intensity瓶颈
Prefill S=2048, B=1~31 TFLOPs~14 GB(权重)+ 256 MB(KV 写)~2200 FLOPs/byte算力
Decode step, cache_len=2048, B=1~15 GFLOPs~14 GB(权重)+ 256 MB(KV 读)~1.05 FLOPs/byte带宽
LM head(prefill 只算最后一位)~1 GFLOP1 GB~1 FLOPs/byte带宽
LM head(decode)~1 GFLOP1 GB~1 FLOPs/byte带宽

Decode 的 1.05 FLOPs/byte 比 H100 拐点 330 低 两个半数量级 — 意味着理想情况下单请求 decode 的算力利用率只有 1.05/3300.3%1.05/330 \approx 0.3\%。这就是 continuous batching 的数学依据:把 BB 拉到 32,同一批权重 read 被 32 个请求共享,arithmetic intensity 直接 ×32,decode 吞吐几乎线性增长,直到 attention 部分或算力先撞墙。

心算口诀

两条规则覆盖 90% 的推理性能估算:

  1. FLOPs ≈ 2PN2 P NPP 是参数量(~8B),NN 是这次 forward 处理的 token 总数。每个参数被每个 token 各用一次 MAC,一次 MAC 算 2 FLOPs。例如 prefill S=2048S=2048: 2×8B×204833 TFLOPs2 \times 8\text{B} \times 2048 \approx 33\ \text{TFLOPs},与分项加总的 31 TFLOPs 吻合。
  2. 权重 HBM I/O ≈ 2P2 P bytes(fp16):一次 forward 就是把模型扫一遍,约 16 GB。

Arithmetic intensity 本质是 2PN2P=N\frac{2 P N}{2 P} = N — forward 里一共参与的 token 数。Prefill 有 SBS \cdot B 个 token,decode 只有 BB 个。这一个数字直接决定了 prefill / decode 瓶颈不同的根源。

计算复杂度总览 — 有 KV Cache vs 没 KV Cache 差三个数量级

Prefill(一次性处理 SS 个 token):

FLOPsO(LSH2)线性层+O(LS2H)attention\text{FLOPs} \sim \underbrace{O(L \cdot S \cdot H^{2})}_{\text{线性层}} + \underbrace{O(L \cdot S^{2} \cdot H)}_{\text{attention}}

Decode 每步(处理 1 个 token,历史 cache_len\text{cache\_len}):

FLOPsO(LH2)线性层,恒定+O(Lcache_lenH)attention,随 cache 线性增长\text{FLOPs} \sim \underbrace{O(L \cdot H^{2})}_{\text{线性层,恒定}} + \underbrace{O(L \cdot \text{cache\_len} \cdot H)}_{\text{attention,随 cache 线性增长}}

生成 TT 个 token 总复杂度:

FLOPstotalO ⁣(LTH2+LT(S+T)H)\text{FLOPs}_{\text{total}} \sim O\!\left(L \cdot T \cdot H^{2} + L \cdot T \cdot (S + T) \cdot H\right)

对比无 KV Cache:O(L(S+T)3)O(L \cdot (S+T)^{3}),差异巨大。

把 FLOPs 和带宽一起看就更直观 — 这正是上一节那张表展示的:prefill 挑算力上限,decode 挑带宽上限;continuous batching 的意义就是把 NN 个请求的 decode 拼成一个大 GEMV,搬权重的成本被 NN 个请求摊薄,吞吐线性上涨直到算力或 attention 部分成为瓶颈。

工程优化怎么嵌进公式 — Flash Attn / Spec Decode / Continuous Batch

Flash Attention:数学上完全等价于标准 attention,公式一字不变。工程上把 softmax 和 matmul 融成一个 kernel,按 block 流式更新 softmax 的运行统计量(max、sum),避免把 S×SS \times S 的 attention 矩阵写回 HBM。复杂度不变,显存占用从 O(S2)O(S^{2}) 降到 O(S)O(S),速度提升主要来自 HBM 访问减少。FA-2 把切分粒度从 head 改到 query block;FA-3 在 H100/H200 上叠加 warpgroup MMA + producer-consumer 异步流水。

Flash Decoding:Flash Attention 在 decode 时 Q 只有 1 行,kernel 并行度不够。Flash Decoding 把 K、V 的 cache_len\text{cache\_len} 维切成多块并行,再做一次 log-sum-exp 归约。公式还是同一个 softmax,只是拆成两段算

Speculative Decoding:用一个小的 draft 模型连续生成 kk 个 token,再用大模型一次 prefill 那 kk 个位置做验证。接受规则是

accept with prob min ⁣(1,ptarget(x)pdraft(x))\text{accept with prob } \min\!\left(1, \frac{p_{\text{target}}(x)}{p_{\text{draft}}(x)}\right)

关键是把 decode 的 kk 次 GEMV 合并成一次 kk 长度的 GEMM,把大模型压在 memory-bound 边界的时间重新变成 compute-bound。期望每步接受 kˉ\bar k 个 token,吞吐放大 kˉ\bar k 倍(减去 draft 开销)。变种:Medusa(多头预测)、EAGLE(特征级 draft)、Lookahead Decoding(无 draft 模型)。

Continuous Batching(vLLM、TGI):不在 prefill 边界 pad 到齐,而是请求级的 step 调度。每个 step 选一批正处于同一 phase(prefill 或 decode)的请求拼 batch,完成一个释放一个。数学上各请求互不影响,只是编排顺序变了。原始论文是 OSDI’22 的 Orca。

Chunked Prefill:把长 prompt 的 prefill 拆成多段,和 decode 请求混进同一个 step,减少 decode 请求的等待抖动。公式无变化。SARATHI / DistServe 等系统的核心调度原语。

一句话串起整个过程 — 从 token ID 到下一个 token

输入 token ID → 查 embedding → 过 LL 层(Pre-RMSNorm → Attention 带 RoPE → 残差 → Pre-RMSNorm → SwiGLU FFN → 残差) → Final RMSNorm → LM Head → logits → 采样。Prefill 时 SS 个 token 并行走一遍,产物是第一个 token + 完整 KV Cache;Decode 时每步只输入 1 个 token,attention 处从 cache 读历史 K、V,其他所有操作都是逐 token 独立的

关键工程不变量 — 值得背下来的 6 条

  1. 主线张量 shape 始终是 [B,Scurrent,H][B, S_{\text{current}}, H]。残差结构保证维度不变,一旦你在某一步看到 HH 变了,要么是在 attention 内部展成多头,要么是在 FFN 内部升到 II出子层就回到 HH
  2. K、V 一旦算出就不再变。因为它们只是线性投影 WK,WVW_K, W_V 作用在已经定了的输入 x\mathbf{x} 上,而 causal 结构保证更晚的位置不会反过来改更早位置的表示。这是 KV Cache 成立的数学基础。
  3. Attention 是唯一跨 token 的操作,其他所有操作(norm、投影、FFN、激活)都逐 token 独立。所以只需要缓存跨 token 操作需要的 K、V;其他都可以即时算完扔掉。
  4. Decode 的 scores shape 是 [B,nq,1,cache_len][B, n_q, 1, \text{cache\_len}]。“1” 是 Q 侧(当前新 token),cache_len\text{cache\_len} 维在和 VV 做加权求和时被消掉,输出又回到 1 行。
  5. Decode 每步非 attention 部分的计算量恒定,只有 attention 随 cache 长度线性增长。所以真正”越生成越慢”的本质,是 attention 的 cache_len\text{cache\_len} 越来越长,加上 KV Cache 把 memory footprint 顶到 HBM 带宽上限。
  6. Prefill 用 GEMM,decode 用 GEMV。这一字之差决定了所有推理引擎都要做两套 kernel、两套调度策略。理解这一条,后面看任何推理优化论文都不会迷路。

如果把这十条装进脑子里,再去读 Flash Attention、PagedAttention、MLA、投机解码这些论文,会发现它们都是在这张骨架的某个位置做局部优化,而骨架本身二十年没怎么变过。

参考资料 — 公式 · 论文 · 工程 blog

架构与核心算子

位置编码

多头变体(KV 共享 / 压缩)

Flash Attention 系列

推理引擎与服务调度

Speculative Decoding 家族

KV Cache 压缩 / 量化

代表性开源模型技术报告

硬件 / Roofline

其他长文 / 教程