LLM Inference Walkthrough — Tensor Shapes and Core Formulas Across the Whole Pipeline

Lay out a Llama 3-style dense decoder-only model end to end, from embedding to sampling, noting where the common variants (MHA / MQA / GQA / MLA, RoPE / ALiBi, RMSNorm / LayerNorm, SwiGLU / GeGLU, Flash Attention / Paged Attention) plug into the math along the way — so that after reading you can draw this diagram from memory. One fact frames the whole article: this architecture has barely changed in nearly a decade; every "inference optimization" is a local surgery somewhere on the same skeleton.

Notation Conventions — Llama 3 8B as the Reference

The same notation is used throughout, with Llama 3 8B as the concrete example:

Symbol | Meaning | Example value (Llama 3 8B)
B | batch size | 2
S | prompt length | 10
L | number of layers | 32
H | hidden dim | 4096
V | vocab size | 128256
n_q | number of Q heads | 32
n_kv | number of KV heads (GQA) | 8
d | per-head dim = H / n_q | 128
I | FFN intermediate dim | 14336
T | number of generated tokens | 100
t | current decode step | 1..T

All shape annotations follow PyTorch convention, written as [B, ..., H]; weight matrices follow the "in × out" convention, i.e. W has shape [in, out].

Core Formulas Quick Reference — Embedding · Norm · Attn · FFN · LM Head

Embedding

\mathbf{x}_i = E[\text{token\_id}_i] \in \mathbb{R}^{H}

E has shape [V, H] and is essentially a lookup table. Many implementations share E with the LM head's W_lm (tied embeddings), saving memory and providing mild regularization; Llama models do not tie them by default.
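
A minimal PyTorch sketch of the lookup, plus what weight tying would look like (the `embed` / `lm_head` names are illustrative, not from any particular codebase):

```python
import torch
import torch.nn as nn

V, H = 128256, 4096
embed = nn.Embedding(V, H)                    # E with shape [V, H]
token_ids = torch.randint(0, V, (2, 10))      # input_ids [B, S]
x = embed(token_ids)                          # [B, S, H] = [2, 10, 4096]

# Tied embeddings (not Llama's default): reuse E as the output projection W_lm.
lm_head = nn.Linear(H, V, bias=False)
lm_head.weight = embed.weight                 # shares the same [V, H] storage
```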

Normalization: LayerNorm vs RMSNorm

Standard LayerNorm:

\text{LN}(\mathbf{x}) = \frac{\mathbf{x} - \mu}{\sqrt{\sigma^{2} + \epsilon}} \odot \boldsymbol{\gamma} + \boldsymbol{\beta}, \quad \mu = \tfrac{1}{H}\sum x_i,\ \sigma^{2} = \tfrac{1}{H}\sum (x_i - \mu)^{2}

RMSNorm (the mainstream choice for Llama / Mistral / Qwen):

\text{RMSNorm}(\mathbf{x}) = \frac{\mathbf{x}}{\sqrt{\frac{1}{H}\sum_{i=1}^{H} x_i^{2} + \epsilon}} \odot \boldsymbol{\gamma}

RMSNorm drops the mean and the β term, roughly halving both the norm's compute and its parameters; empirically the quality cost is nearly zero. Mainstream decoder architectures all use this in a pre-norm arrangement: the norm lives inside the residual branch, and the residual main path bypasses it.
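
A minimal sketch of the formula above (the eps value and dtype handling vary by model; this is illustrative rather than any specific library's module):

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """x / sqrt(mean(x^2) + eps) * gamma, as in the formula above; no mean subtraction, no beta."""
    def __init__(self, hidden_dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(hidden_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * rms * self.gamma

x = torch.randn(2, 10, 4096)
y = RMSNorm(4096)(x)   # shape preserved: [2, 10, 4096]
```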

Q/K/V Projection + Positional Encoding

Q = X W_Q, \quad K = X W_K, \quad V = X W_V

RoPE (Rotary Positional Embedding) applies a position-m rotation matrix to every two-dimensional pair of Q and K:

\text{RoPE}(\mathbf{q}_m, m) = \begin{pmatrix} \cos m\theta_k & -\sin m\theta_k \\ \sin m\theta_k & \phantom{-}\cos m\theta_k \end{pmatrix} \begin{pmatrix} q_{2k} \\ q_{2k+1} \end{pmatrix}, \quad \theta_k = \text{base}^{-2k/d}

base is usually 10000; long-context models (Llama 3.1 128K, Qwen2.5-1M) typically use YaRN / NTK-aware scaling to dynamically enlarge base or apply band-wise scaling to θ_k. RoPE acts only on Q and K, not on V — V is the weighted value and doesn't need positional information.
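
A sketch of the pairwise rotation from the formula above. It uses the interleaved (q_2k, q_2k+1) pairing; some implementations (e.g. Hugging Face's Llama code) use a half-split layout that is equivalent up to a permutation. Function and argument names are illustrative:

```python
import torch

def apply_rope(x: torch.Tensor, positions: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """Rotate each pair (x_{2k}, x_{2k+1}) by m * theta_k.  x: [B, heads, N, d], positions: [N]."""
    d = x.shape[-1]
    theta = base ** (-torch.arange(0, d, 2, dtype=torch.float32) / d)    # theta_k, shape [d/2]
    angles = positions.float()[:, None] * theta[None, :]                 # m * theta_k, [N, d/2]
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]                                  # even / odd element of each pair
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

q = torch.randn(2, 32, 10, 128)
q_rot = apply_rope(q, torch.arange(10))    # prefill: positions 0..9; at decode, pass [cache_len]
```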

ALiBi (used by BLOOM and MPT) takes a different route: it does not add position into embeddings; instead it adds a linear bias on the attention scores:

\text{scores}_{ij} = \frac{\mathbf{q}_i^{\top}\mathbf{k}_j}{\sqrt{d}} - m_h \cdot (i - j)

m_h is a per-head slope constant. ALiBi extrapolates naturally, but its ceiling is below RoPE + YaRN, so post-2024 it's essentially unused.

Scaled Dot-Product Attention

\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d}} + M\right) V

M is the causal mask, with the upper triangle set to −∞, ensuring position i can only attend to positions ≤ i. Dividing by √d keeps the softmax out of its saturated, near-zero-gradient region.
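
A reference implementation of exactly this formula; fused kernels (Flash Attention, discussed later) compute the same result without materializing the S × S score matrix:

```python
import math
import torch

def causal_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """softmax(Q K^T / sqrt(d) + M) V with an upper-triangular -inf mask.  q, k, v: [B, heads, S, d]."""
    d = q.shape[-1]
    scores = q @ k.transpose(-2, -1) / math.sqrt(d)                        # [B, heads, S, S]
    mask = torch.triu(torch.full(scores.shape[-2:], float("-inf")), diagonal=1)
    return torch.softmax(scores + mask, dim=-1) @ v                        # [B, heads, S, d]

q = k = v = torch.randn(2, 32, 10, 128)
out = causal_attention(q, k, v)   # [2, 32, 10, 128]
```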

Multi-Head Variants: MHA / MQA / GQA / MLA

Variant | n_kv | KV cache savings | Representative models
MHA | = n_q | 1× (baseline) | GPT-2/3, Llama 1/2 7B
GQA | grouped, < n_q | n_q / n_kv | Llama 3, Qwen2, Mistral
MQA | = 1 | n_q | PaLM, Falcon
MLA | low-rank compressed | ~n_q | DeepSeek V2/V3

GQA's math: there are n_kv groups of K and V, each group shared by n_q / n_kv Q heads; kernels typically implement this by broadcasting, not by actually replicating tensors.
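
A shape-level sketch of that grouping, with the broadcast dimension standing in for what fused kernels do; the causal mask is omitted for brevity:

```python
import torch

B, n_q, n_kv, S, d = 2, 32, 8, 10, 128
group = n_q // n_kv                                          # 4 Q heads per KV head
q = torch.randn(B, n_q, S, d)
k = torch.randn(B, n_kv, S, d)
v = torch.randn(B, n_kv, S, d)

# View Q as [B, n_kv, group, S, d]; give K/V a broadcast dim instead of materializing copies.
q_g = q.view(B, n_kv, group, S, d)
scores = q_g @ k.unsqueeze(2).transpose(-2, -1) / d ** 0.5   # [B, n_kv, group, S, S]
out = (torch.softmax(scores, dim=-1) @ v.unsqueeze(2)).reshape(B, n_q, S, d)
```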

MLA (Multi-Head Latent Attention) compresses K and V jointly into a low-rank latent vector c^KV:

\mathbf{c}^{KV}_i = W^{DKV} \mathbf{x}_i, \quad \mathbf{k}^{C}_i = W^{UK}\mathbf{c}^{KV}_i, \quad \mathbf{v}^{C}_i = W^{UV}\mathbf{c}^{KV}_i

Only c^KV (dimension d_c, typically 512) is cached rather than all heads' K and V. The cost is more complex attention computation (RoPE must go through a separate k^R branch), but the per-token cache shrinks from GQA's several KB to a few hundred bytes.

Output Projection + Residual

\mathbf{h} = \mathbf{x} + \text{Attn}(Q', K', V') \cdot W_O

Note that the residual adds the attention sublayer's input x (the value before pre-norm), not the post-norm value.

FFN Variants

Classic bilinear FFN (GPT-2):

\text{FFN}(\mathbf{x}) = \phi(\mathbf{x} W_1) W_2

φ is typically GeLU. Its tanh approximation:

\text{GeLU}(x) \approx 0.5 x \left(1 + \tanh\!\left[\sqrt{\tfrac{2}{\pi}}\left(x + 0.044715 x^{3}\right)\right]\right)

The GLU family (used by Llama, PaLM, Mistral):

\text{GLU}(\mathbf{x}) = \big(\phi(\mathbf{x} W_{\text{gate}}) \odot (\mathbf{x} W_{\text{up}})\big) W_{\text{down}}

For SwiGLU, φ is SiLU, defined as SiLU(z) = z · σ(z). SwiGLU has one more projection than the classic GeLU-FFN (three matrices vs two); to keep the parameter budget comparable, implementations typically set I ≈ (2/3) · 4H (4H being the GPT-2 convention), which for H = 4096 gives ≈ 10923. Llama 3 8B actually uses a larger I = 14336 = 3.5H.
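
A minimal sketch of the SwiGLU block with Llama 3 8B's dimensions (module and attribute names are illustrative):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLU(nn.Module):
    """(SiLU(x W_gate) ⊙ (x W_up)) W_down, as in the GLU formula above with φ = SiLU."""
    def __init__(self, hidden_dim: int = 4096, intermediate_dim: int = 14336):
        super().__init__()
        self.w_gate = nn.Linear(hidden_dim, intermediate_dim, bias=False)
        self.w_up = nn.Linear(hidden_dim, intermediate_dim, bias=False)
        self.w_down = nn.Linear(intermediate_dim, hidden_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))

x = torch.randn(2, 10, 4096)
y = SwiGLU()(x)   # [2, 10, 4096]; the [2, 10, 14336] intermediate never leaves the block
```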

Residual

\mathbf{x}_{\text{out}} = \mathbf{h} + \text{FFN}(\text{RMSNorm}(\mathbf{h}))

LM Head + Sampling

\text{logits} = \mathbf{x}_{\text{final}} W_{\text{lm}}, \quad \mathbf{p} = \text{softmax}(\text{logits}/T)

T here is the temperature (overloading the T used earlier for the number of generated tokens). Several logits transformations are typically applied before sampling:

Repetition penalties (repetition / frequency / presence penalty):

\text{logits}'_v = \text{logits}_v - \alpha \cdot \mathbb{1}[v \in \text{history}] - \beta \cdot \text{count}(v)

Top-k: keep the largest k logits, set the rest to −∞.

Top-p (nucleus): sort by probability descending, keep the smallest prefix whose cumulative probability reaches p.

Min-p: keep tokens with p_v ≥ p_min · p_max; friendlier to low-entropy distributions.

Typical-p: truncate by deviation from the conditional entropy, keeping the set where |−log p_v − H(p)| is small.

All of these truncations act on the logits or on the resulting probability distribution; none of them changes the skeleton of the formula.
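
A sketch of a typical temperature → top-k → top-p pipeline; the default thresholds here are illustrative, and real engines also fold in the repetition penalties above:

```python
import torch

def sample_next_token(logits: torch.Tensor, temperature=0.8, top_k=50, top_p=0.9):
    """Temperature -> top-k -> top-p truncation -> multinomial sample.  logits: [B, V]."""
    logits = logits / temperature
    # Top-k: keep the k largest logits, mask the rest to -inf.
    kth = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
    logits = logits.masked_fill(logits < kth, float("-inf"))
    # Top-p: drop the tail of the sorted cumulative distribution (always keep the top token).
    sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
    probs = torch.softmax(sorted_logits, dim=-1)
    remove = probs.cumsum(dim=-1) - probs > top_p
    sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
    logits = torch.full_like(logits, float("-inf")).scatter(-1, sorted_idx, sorted_logits)
    return torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1)   # [B, 1]

next_token = sample_next_token(torch.randn(2, 128256))
```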

Prefill Stage Shape Transitions — S Tokens Walk Through Once

Input: input_ids [B, S] = [2, 10]. The figure below shows the forward pass of one Transformer layer, repeated 32 times; the shape stays at [B, S, H] = [2, 10, 4096] throughout, with residual edges shown as orange dashed lines.

[Figure] input_ids [2, 10] → Embedding lookup E[V, H] → x [2, 10, 4096] → (× 32 layers) RMSNorm → Q/K/V proj + RoPE: Q [2, 32, 10, 128], K/V [2, 8, 10, 128] (GQA) → write KV cache cache[l][:, :, 0:10, :] → attention: scores [2, 32, 10, 10] + causal mask → ⊕ residual → h [2, 10, 4096] → RMSNorm → SwiGLU FFN, intermediate [2, 10, 14336] → ⊕ residual → x_out [2, 10, 4096] → final RMSNorm → take last position [2, 4096] → LM head, W_lm [H, V] → logits [2, 128256] → sampling (temperature / top-k / top-p) → next_token [2, 1] (the first output token).
Prefill — shape transitions and core formulas of a single Transformer layer. Orange dashed lines are pre-norm residuals; ⊕ marks residual merges; the left bracket marks ”× 32 layers.”

After prefill, the KV cache state: each layer holds K and V for the first 10 positions.

Decode Stage Shape Transitions — Step t Processes Only 1 Token

Prior state: cache_len = S + t − 1 positions filled.

Input: input_ids [B, 1] = [2, 1] (the 1 token generated at the previous step).

[Figure] input_ids [2, 1] (the token generated at the previous step) → Embedding → x [2, 1, 4096] → (× 32 layers) RMSNorm → Q/K/V proj for the 1 new token: Q_new [2, 32, 1, 128], K_new/V_new [2, 8, 1, 128] → RoPE at pos = cache_len → write cache[l][:, :, cache_len, :], cache_len += 1 → read K_full/V_full [2, 8, cache_len, 128] from cache → attention (no causal mask): scores [2, 32, 1, cache_len] → ⊕ residual → h [2, 1, 4096] → RMSNorm → SwiGLU FFN, intermediate [2, 1, 14336] → ⊕ residual → x_out [2, 1, 4096] → final RMSNorm → LM head → logits [2, 128256] → sampling → next_token [2, 1], fed back as the next step's input.
Decode — step t processes only 1 token, with core formulas; the KV Cache keeps growing; the blue dashed line shows the generated next_token feeding back as the next step’s input.

Prefill vs Decode Shape Comparison — GEMM vs GEMV · Compute vs Bandwidth

Position | Prefill | Decode (per step)
input_ids | [B, S] | [B, 1]
after embedding | [B, S, H] | [B, 1, H]
Q | [B, n_q, S, d] | [B, n_q, 1, d]
K_new / V_new | [B, n_kv, S, d] | [B, n_kv, 1, d]
K_full / V_full (from cache) | same as K_new | [B, n_kv, cache_len, d]
attention scores | [B, n_q, S, S] | [B, n_q, 1, cache_len]
attention output | [B, n_q, S, d] | [B, n_q, 1, d]
FFN intermediate | [B, S, I] | [B, 1, I]
logits | [B, V] (last position) | [B, V]
Operation type | GEMM (matrix × matrix) | GEMV (matrix × vector)
Bottleneck | compute | memory bandwidth

This table is the starting point for understanding every inference acceleration effort: prefill is like training’s forward, compute-bound; decode is a chain of GEMVs, memory-bound, with most time spent fetching weights into SMs. The optimization directions of the two are worlds apart.

KV Cache Shape and Growth — Few KB Per Token · Hundreds of MB at Long Context

One pair of caches per layer:

K_{\text{cache}}, V_{\text{cache}} \in \mathbb{R}^{B \times n_{kv} \times S_{\max} \times d}

Per-token, per-layer cache size (fp16):

2 \times n_{kv} \times d \times 2\ \text{bytes} = 2 \times 8 \times 128 \times 2 = 4\ \text{KB}

Per token across the full 32-layer model: 4 KB × 32 = 128 KB/token. A 4096-token request: 128 KB × 4096 ≈ 512 MB.
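
The same arithmetic as a tiny script, using the Llama 3 8B numbers from this article (per sequence, i.e. B = 1):

```python
# KV cache back-of-the-envelope for the Llama 3 8B-style config used throughout (fp16).
n_layers, n_kv, d, bytes_per_elem = 32, 8, 128, 2

per_token_per_layer = 2 * n_kv * d * bytes_per_elem        # K and V: 4096 B = 4 KB
per_token = per_token_per_layer * n_layers                  # 128 KB / token
request_4k = per_token * 4096                               # ~512 MB for a 4096-token request

print(per_token_per_layer / 1024, "KB per token per layer")
print(per_token / 1024, "KB per token, full model")
print(request_4k / 2**20, "MB for 4096 tokens")
```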

Several engineering optimizations target this footprint: GQA / MLA shrink n_kv or the cached dimension at the architecture level; Paged Attention allocates the cache in fixed-size blocks so S_max never has to be reserved up front and fragmentation disappears; KV-cache quantization (fp8 / int8) halves or quarters the bytes per token.

Per-Step Compute and Memory Cost — Llama 3 8B fp16 · H100 Knee ~330 FLOPs/byte

The shape diagrams above show shapes but not magnitudes. Ninety percent of inference-optimization discussion boils down to "how many FLOPs does this step cost, and how many bytes does it move," so the tables below lay out each step's cost directly.

Using Llama 3 8B, fp16, B = 1 as the baseline; for prefill take S = 2048; for decode take cache_len = 2048 (a step around the 2048th generated token).

Reference hardware knee: H100 SXM fp16 theoretical compute ~989 TFLOPS, HBM bandwidth ~3 TB/s, so the roofline knee is AI* ≈ 330 FLOPs/byte. Above it a kernel is compute-bound, below it memory-bound.

Weight Distribution

Component | Shape | fp16 size | Full model (× 32 layers)
Embedding E | [V, H] | 1.0 GB | 1.0 GB
W_Q | [H, H] | 32 MB | 1.0 GB
W_K | [H, n_kv·d] | 8 MB | 256 MB
W_V | [H, n_kv·d] | 8 MB | 256 MB
W_O | [H, H] | 32 MB | 1.0 GB
W_gate | [H, I] | 117 MB | 3.7 GB
W_up | [H, I] | 117 MB | 3.7 GB
W_down | [I, H] | 117 MB | 3.7 GB
RMSNorm γ (2 per layer) | [H] × 2 | 16 KB | 500 KB
LM head W_lm | [H, V] | 1.0 GB | 1.0 GB
Total | | ~432 MB / layer | ~16 GB

Full-model fp16 weights are ~16 GB; the "floor price" of every forward pass is to scan these 16 GB from HBM. At H100's 3 TB/s that scan alone takes 16 / 3000 ≈ 5.3 ms — the physical lower bound on single-request decode latency per token.

Per-Layer, Per-Step Compute / Memory I/O

Compare the same layer's substeps under prefill (N = S) and decode (N = 1). "Weight HBM" is the weight bytes fetched from VRAM; "KV HBM" is the KV-cache bytes read/written. Intermediate activations are assumed fused into kernels and not counted separately.

Step | Prefill FLOPs (S = 2048) | Decode FLOPs (S = 1) | Weight HBM | KV HBM
RMSNorm | 5BSH ≈ 42 MF | 20 KF | γ 8 KB | —
Q_proj | 2BSH² ≈ 68.7 GF | 33.5 MF | W_Q 32 MB | —
K_proj (+ write cache) | 2BSH·n_kv·d ≈ 17.2 GF | 8.4 MF | W_K 8 MB | write 4 MB (P) / 2 KB (D)
V_proj (+ write cache) | 17.2 GF | 8.4 MF | W_V 8 MB | write 4 MB (P) / 2 KB (D)
RoPE | ~50 MF | 25 KF | — | —
Attn QKᵀ | 2B·n_q·N·L_k·d ≈ 34.4 GF | 16.8 MF | — | read 4 MB (D)
softmax | ~700 MF | 260 KF | — | —
Attn ·V | 34.4 GF | 16.8 MF | — | read 4 MB (D)
W_O | 2BSH² ≈ 68.7 GF | 33.5 MF | W_O 32 MB | —
RMSNorm | 42 MF | 20 KF | γ 8 KB | —
W_gate | 2BSHI ≈ 241 GF | 117 MF | W_gate 117 MB | —
W_up | 241 GF | 117 MF | W_up 117 MB | —
SiLU + gate | ~90 MF | 45 KF | — | —
W_down | 2BSIH ≈ 241 GF | 117 MF | W_down 117 MB | —
Per-layer total | ~960 GFLOPs | ~470 MFLOPs | ~432 MB | write 8 MB (P) / read 8 MB (D)

A few direct conclusions: the three FFN matrices account for roughly three quarters of the per-layer FLOPs and about 80% of the per-layer weight bytes; in decode, every substep moves far more bytes than it computes (≈470 MFLOPs against ≈432 MB of weights, about 1 FLOP per byte); and the attention score / value products are the only terms that grow with cache_len, while everything else is constant per step.

One Full Forward Pass

Adding 32 layers + embedding + LM head:

Stage | FLOPs | HBM I/O | Arithmetic intensity | Bottleneck
Prefill, S = 2048, B = 1 | ~31 TFLOPs | ~14 GB (weights) + 256 MB (KV write) | ~2200 FLOPs/byte | compute
Decode step, cache_len = 2048, B = 1 | ~15 GFLOPs | ~14 GB (weights) + 256 MB (KV read) | ~1.05 FLOPs/byte | bandwidth
LM head (prefill, last position only) | ~1 GFLOP | 1 GB | ~1 FLOPs/byte | bandwidth
LM head (decode) | ~1 GFLOP | 1 GB | ~1 FLOPs/byte | bandwidth

Decode's 1.05 FLOPs/byte is about 2.5 orders of magnitude below H100's knee of 330 — meaning ideal single-request decode compute utilization is only 1.05 / 330 ≈ 0.3%. This is the mathematical basis for continuous batching: push B to 32 and the same weight read is amortized across 32 requests; arithmetic intensity scales by 32×, and decode throughput grows almost linearly until the attention portion or compute itself becomes the wall.

Mental-Math Rules

Two rules cover 90% of inference performance estimation:

  1. FLOPs ≈ 2PN: P is the parameter count (~8B); N is the total number of tokens this forward pass processes. Each parameter participates in one MAC per token (1 MAC = 2 FLOPs). E.g., prefill S = 2048: 2 × 8B × 2048 ≈ 33 TFLOPs, matching the itemized sum of ~31 TFLOPs.
  2. Weight HBM I/O ≈ 2P bytes (fp16): one forward pass scans the model once, about 16 GB.

Arithmetic intensity is essentially 2PN / 2P = N — the total number of tokens participating in this forward pass. Prefill has S · B tokens; decode has only B. This single number directly determines why prefill and decode have different bottlenecks.
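
A back-of-the-envelope roofline sketch built only from these two rules and the H100 reference numbers quoted earlier (assumed figures, not measurements):

```python
# Roofline estimate from the two rules: FLOPs ~ 2*P*N, weight bytes ~ 2*P (fp16), AI ~ N.
P = 8e9                 # parameters (Llama 3 8B)
PEAK_FLOPS = 989e12     # H100 SXM fp16, per the reference numbers above
HBM_BW = 3e12           # bytes / s

def step_estimate(n_tokens: int):
    flops = 2 * P * n_tokens
    weight_bytes = 2 * P
    ai = flops / weight_bytes                                 # arithmetic intensity = n_tokens
    seconds = max(flops / PEAK_FLOPS, weight_bytes / HBM_BW)  # whichever roofline wall binds
    return ai, seconds * 1e3                                  # (FLOPs/byte, ms)

print(step_estimate(2048))   # prefill S=2048: AI=2048, ~33 ms, compute-bound
print(step_estimate(1))      # decode,  B=1:   AI=1,    ~5.3 ms, bandwidth-bound
```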

Compute Complexity Overview — With vs Without KV Cache Differ by Three Orders of Magnitude

Prefill (process S tokens at once):

\text{FLOPs} \sim \underbrace{O(L \cdot S \cdot H^{2})}_{\text{linear layers}} + \underbrace{O(L \cdot S^{2} \cdot H)}_{\text{attention}}

Decode per step (process 1 token, history length cache_len):

\text{FLOPs} \sim \underbrace{O(L \cdot H^{2})}_{\text{linear layers, constant}} + \underbrace{O(L \cdot \text{cache\_len} \cdot H)}_{\text{attention, linear in cache}}

Total complexity to generate T tokens:

\text{FLOPs}_{\text{total}} \sim O\!\left(L \cdot T \cdot H^{2} + L \cdot T \cdot (S + T) \cdot H\right)

Without a KV cache, every step re-runs the forward pass over the entire prefix, and the attention term alone grows to O(L · (S+T)³ · H) — a massive difference.

Reading FLOPs and bandwidth together is even more illuminating — that's what the previous section's table shows: prefill pushes against the compute ceiling; decode pushes against the bandwidth ceiling; continuous batching's point is to fuse N requests' decode GEMVs into one skinny GEMM, amortizing the weight-fetch cost across N requests, with throughput rising almost linearly until compute or the attention portion becomes the bottleneck.

How Engineering Optimizations Plug into the Formulas — Flash Attn / Spec Decode / Continuous Batch

Flash Attention: mathematically equivalent to standard attention — the formulas don't change a single character. Engineering-wise it fuses softmax and matmul into one kernel, updating softmax's running statistics (max, sum) in a streaming manner over blocks, avoiding writing the S × S attention matrix back to HBM. Complexity is unchanged; memory drops from O(S²) to O(S); the speedup comes mainly from reduced HBM access. FA-2 shifted the partition granularity from heads to query blocks; FA-3 on H100/H200 adds warpgroup MMA + producer-consumer async pipelining.
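
The streaming-statistics update is the heart of the trick. A single-query, single-head sketch of the online softmax (block size and names are illustrative; real kernels tile both Q and K/V and run this per thread block):

```python
import torch

def online_softmax_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, block: int = 128):
    """Stream over K/V blocks keeping running (max, sum, output); numerically equal to
    softmax(q K^T / sqrt(d)) V without ever materializing the full score vector at once."""
    d = q.shape[-1]
    m = torch.tensor(float("-inf"))     # running max of the scores
    s = torch.tensor(0.0)               # running softmax denominator
    o = torch.zeros(d)                  # running unnormalized output
    for start in range(0, k.shape[0], block):
        kb, vb = k[start:start + block], v[start:start + block]
        scores = kb @ q / d ** 0.5                      # [block]
        m_new = torch.maximum(m, scores.max())
        scale = torch.exp(m - m_new)                    # rescale the old accumulators
        p = torch.exp(scores - m_new)
        s = s * scale + p.sum()
        o = o * scale + p @ vb
        m = m_new
    return o / s

q, k, v = torch.randn(128), torch.randn(2048, 128), torch.randn(2048, 128)
out = online_softmax_attention(q, k, v)    # matches torch.softmax(k @ q / 128**0.5, 0) @ v
```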

Flash Decoding: at decode time, Flash Attention's Q has only one row, so kernel parallelism is too low. Flash Decoding splits the cache_len dimension of K and V into chunks processed in parallel, then does a final log-sum-exp reduction. The formula is the same softmax, just split into two passes.

Speculative Decoding: a small "draft" model generates k tokens sequentially, then the large model verifies them with one prefill over the k positions. Acceptance rule:

\text{accept with prob } \min\!\left(1, \frac{p_{\text{target}}(x)}{p_{\text{draft}}(x)}\right)

The crux is fusing k decode GEMVs into one k-length GEMM, turning the large model's memory-bound regime back toward compute-bound. With an expected k̄ accepted tokens per step, throughput scales by k̄ (minus draft overhead). Variants: Medusa (multi-head prediction), EAGLE (feature-level draft), Lookahead Decoding (no draft model).
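
A sketch of the accept/reject loop implied by the rule above, following the standard speculative-sampling construction (tensor names are illustrative):

```python
import torch

def accept_draft_tokens(p_target: torch.Tensor, p_draft: torch.Tensor, draft_tokens: torch.Tensor):
    """p_target, p_draft: [k, V] distributions at the k drafted positions; draft_tokens: [k].
    Accept token x with prob min(1, p_target(x) / p_draft(x)); on rejection, resample from
    the normalized residual (p_target - p_draft)^+ and stop."""
    accepted = []
    for i in range(draft_tokens.shape[0]):
        x = draft_tokens[i]
        ratio = torch.clamp(p_target[i, x] / p_draft[i, x], max=1.0)
        if torch.rand(()) < ratio:
            accepted.append(x)
        else:
            residual = torch.clamp(p_target[i] - p_draft[i], min=0.0)
            accepted.append(torch.multinomial(residual / residual.sum(), 1).squeeze(0))
            break
    return torch.stack(accepted)
```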

Continuous Batching (vLLM, TGI): instead of static batches that pad requests to a common length and admit new work only at batch boundaries, scheduling happens per step at the request level. Each step picks a batch of requests in the same phase (prefill or decode) and releases finished ones immediately. Mathematically each request is independent; only the scheduling order changes. Original paper: Orca (OSDI '22).

Chunked Prefill: split long prompts’ prefill into chunks and mix them with decode requests in the same step, reducing decode latency jitter. No formula changes. The core scheduling primitive of SARATHI / DistServe.

One Sentence Spanning the Whole Process — From Token ID to Next Token

Input token IDs → look up embedding → through L layers (pre-RMSNorm → attention with RoPE → residual → pre-RMSNorm → SwiGLU FFN → residual) → final RMSNorm → LM head → logits → sampling. In prefill, S tokens pass through in parallel, producing the first token plus a full KV cache; in decode, each step inputs 1 token; at attention it reads historical K and V from the cache, while all other operations are per-token independent.

Key Engineering Invariants — 6 Rules Worth Memorizing

  1. The main-line tensor shape is always [B, S_current, H]. The residual structure preserves this dimension; wherever H seems to change, it has either been split into heads inside attention or lifted to I inside the FFN, and it returns to H on exit.
  2. K and V, once computed, never change: they are linear projections W_K, W_V of the already-fixed input x, and the causal structure ensures later positions cannot reach back and modify earlier representations. This is the mathematical basis for the KV cache.
  3. Attention is the only cross-token operation; all others (norm, projection, FFN, activation) are per-token independent. So only K and V — the inputs to cross-token operations — need caching; everything else can be computed and discarded immediately.
  4. Decode's scores shape is [B, n_q, 1, cache_len]. The "1" is the Q side (the current new token), and the cache_len dim is eliminated by the weighted sum with V, returning to a single row.
  5. Decode's non-attention compute per step is constant; only attention grows linearly with cache length. So the true reason "generation gets slower over time" is that attention's cache_len keeps growing, and the growing KV cache adds to the bytes each step must pull through HBM.
  6. Prefill uses GEMM; decode uses GEMV. This one-letter difference dictates that every inference engine has two kernel sets, two scheduling strategies. Internalize this and no inference-optimization paper will lose you.

Internalize these six and you'll find that Flash Attention, PagedAttention, MLA, speculative decoding — they're all local optimizations at some spot on this skeleton, while the skeleton itself has barely changed in nearly a decade.

References — Formulas · Papers · Engineering Blogs

Architecture and Core Operators

Positional Encoding

Multi-Head Variants (KV Sharing / Compression)

Flash Attention Series

Inference Engines and Serving Schedulers

Speculative Decoding Family

KV Cache Compression / Quantization

Representative Open-Source Model Technical Reports

Hardware / Roofline

Other Long-Form Articles / Tutorials