LLM Inference Walkthrough — Tensor Shapes and Core Formulas Across the Whole Pipeline
Lay out a Llama 3-style dense decoder-only model end to end, from embedding to sampling, noting where the common variants plug into the math (MHA / MQA / GQA / MLA, RoPE / ALiBi, RMSNorm / LayerNorm, SwiGLU / GeGLU, Flash Attention / Paged Attention) along the way — so that after reading you can draw this diagram from memory. One fact frames the whole article: this skeleton has barely changed since the 2017 Transformer; every "inference optimization" is a local surgery somewhere on the same skeleton.
Notation Conventions — Llama 3 8B as the Reference
The same notation is used throughout, with Llama 3 8B as the concrete example:
| Symbol | Meaning | Example value (Llama 3 8B) |
|---|---|---|
| $B$ | batch size | 2 |
| $S$ | prompt length | 10 |
| $L$ | number of layers | 32 |
| $H$ | hidden dim | 4096 |
| $V$ | vocab size | 128256 |
| $n_q$ | number of Q heads | 32 |
| $n_{kv}$ | number of KV heads (GQA) | 8 |
| $d_h$ | per-head dim | 128 |
| $d_{ff}$ | FFN intermediate dim | 14336 |
| $T$ | number of generated tokens | 100 |
| $t$ | current decode step | — |
All shape annotations follow PyTorch convention, written as [B, ..., H]; weight matrices follow "in × out", written as $W \in [\text{in}, \text{out}]$.
Core Formulas Quick Reference — Embedding · Norm · Attn · FFN · LM Head
Embedding
$W_{\text{emb}}$ has shape $[V, H] = [128256, 4096]$ and is essentially a lookup table: $x = W_{\text{emb}}[\text{input\_ids}]$, shape $[B, S, H]$. Many implementations share $W_{\text{emb}}$ with the LM Head's $W_{\text{lm}}$ (tied embedding), saving memory and providing mild regularization; Llama models do not share by default.
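A minimal PyTorch sketch of the lookup and of weight tying (dimension values from the notation table; the tying lines illustrate the general trick, not Llama's default):

```python
import torch

B, S, H, V = 2, 10, 4096, 128256

emb = torch.nn.Embedding(V, H)           # W_emb: [V, H]
input_ids = torch.randint(0, V, (B, S))  # [B, S]
x = emb(input_ids)                       # [B, S, H] -- pure row lookup, no matmul

# Tied embedding: the LM head reuses W_emb instead of a separate [H, V] matrix.
lm_head = torch.nn.Linear(H, V, bias=False)
lm_head.weight = emb.weight              # shared storage; Llama does NOT do this by default
```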
Normalization: LayerNorm vs RMSNorm
Standard LayerNorm:

$$\mathrm{LayerNorm}(x) = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \odot \gamma + \beta$$

RMSNorm (the mainstream choice for Llama / Mistral / Qwen):

$$\mathrm{RMSNorm}(x) = \frac{x}{\sqrt{\tfrac{1}{H}\sum_{i=1}^{H} x_i^2 + \epsilon}} \odot \gamma$$

RMSNorm drops the mean subtraction and the bias $\beta$, roughly halving both compute and parameters; empirically the quality cost is nearly zero. All mainstream inference engines organize this as pre-norm: norm lives inside the residual branch, and the residual main path bypasses it.
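A minimal RMSNorm sketch (the eps value is chosen for illustration):

```python
import torch

class RMSNorm(torch.nn.Module):
    def __init__(self, hidden: int, eps: float = 1e-5):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden))  # gamma only, no beta
        self.eps = eps

    def forward(self, x):  # x: [B, S, H]
        # No mean subtraction: scale by the inverse root-mean-square over the last dim.
        inv_rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return x * inv_rms * self.weight

x = torch.randn(2, 10, 4096)
print(RMSNorm(4096)(x).shape)  # [2, 10, 4096]
# For comparison, torch.nn.LayerNorm(4096) also subtracts the mean and adds a beta.
```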
Q/K/V Projection + Positional Encoding
The projections $Q = xW_Q \in [B, S, n_q d_h]$, $K = xW_K \in [B, S, n_{kv} d_h]$, $V = xW_V \in [B, S, n_{kv} d_h]$ are reshaped into heads of dimension $d_h$. RoPE (Rotary Positional Embedding) then applies a position-dependent rotation matrix to every two-dimensional group of Q and K:

$$\begin{pmatrix} q'_{2i} \\ q'_{2i+1} \end{pmatrix} = \begin{pmatrix} \cos m\theta_i & -\sin m\theta_i \\ \sin m\theta_i & \cos m\theta_i \end{pmatrix} \begin{pmatrix} q_{2i} \\ q_{2i+1} \end{pmatrix}, \qquad \theta_i = \theta_{\text{base}}^{-2i/d_h}$$

$\theta_{\text{base}}$ is usually 10000; long-context models (Llama 3.1 128K, Qwen2.5-1M) typically use YaRN / NTK-aware scaling to dynamically enlarge the base or apply band-wise scaling to the frequencies $\theta_i$. RoPE acts only on Q and K, not on V — V is the weighted value and doesn't need positional information.
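A minimal RoPE sketch in the interleaved-pair formulation with base 10000; real implementations precompute the cos/sin tables, and some use a half-split layout instead:

```python
import torch

def rope(x, positions, base: float = 10000.0):
    # x: [B, n_heads, S, d_h]; positions: [S]
    d_h = x.shape[-1]
    inv_freq = base ** (-torch.arange(0, d_h, 2, dtype=torch.float32) / d_h)  # theta_i, [d_h/2]
    angles = positions[:, None].float() * inv_freq[None, :]                   # m * theta_i, [S, d_h/2]
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]        # even / odd dims form the 2-D pairs
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin       # rotate each pair by m * theta_i
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

q = torch.randn(2, 32, 10, 128)
q_rot = rope(q, torch.arange(10))              # applied to Q and K only, never V
```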
ALiBi (used by BLOOM and MPT) takes a different route: it does not add position into embeddings; instead it adds a linear bias to the attention scores:

$$\text{score}_{ij} = \frac{q_i \cdot k_j}{\sqrt{d_h}} - m \cdot (i - j)$$

$m$ is a per-head slope constant. ALiBi extrapolates naturally, but its ceiling is below RoPE+YaRN, so post-2024 it's essentially unused.
Scaled Dot-Product Attention

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_h}} + M\right)V$$

$M$ is a causal mask, with the upper triangle set to $-\infty$, ensuring position $i$ can only see positions $j \le i$. Dividing by $\sqrt{d_h}$ keeps softmax out of the near-zero-gradient region.
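A direct, non-fused sketch of the formula above; production kernels (Flash Attention) compute the same thing without ever materializing the [S, S] score matrix:

```python
import math
import torch

def causal_attention(q, k, v):
    # q, k, v: [B, n_heads, S, d_h]
    S, d_h = q.shape[-2], q.shape[-1]
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_h)              # [B, n_heads, S, S]
    mask = torch.triu(torch.ones(S, S, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(mask, float("-inf"))               # position i sees only j <= i
    return torch.softmax(scores, dim=-1) @ v                       # [B, n_heads, S, d_h]

q = k = v = torch.randn(2, 32, 10, 128)
print(causal_attention(q, k, v).shape)  # [2, 32, 10, 128]
```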
Multi-Head Variants: MHA / MQA / GQA / MLA
| Variant | KV heads | KV cache savings | Representative models |
|---|---|---|---|
| MHA | $n_{kv} = n_q = 32$ | 1× | GPT-2/3, Llama 1/2 7B |
| GQA | $n_{kv} = 8$ (grouped) | $n_q / n_{kv} = 4\times$ | Llama 3, Qwen2, Mistral |
| MQA | $n_{kv} = 1$ | $n_q = 32\times$ | PaLM, Falcon |
| MLA | joint low-rank compressed latent | largest (caches only a $d_c$-dim latent) | DeepSeek V2/V3 |
GQA's math: $n_{kv} = 8$ groups of K and V, each group shared by $n_q / n_{kv} = 4$ Q heads; kernels typically implement this by broadcasting, not by actually replicating tensors.
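A sketch of that grouping with the Llama 3 8B head counts — here the broadcast is materialized with `repeat_interleave` for clarity, whereas fused kernels just remap head indices:

```python
import torch

B, S, n_q, n_kv, d_h = 2, 10, 32, 8, 128
q = torch.randn(B, n_q, S, d_h)
k = torch.randn(B, n_kv, S, d_h)                 # only 8 KV heads are projected and cached

# Each KV head serves n_q // n_kv = 4 query heads.
k_exp = k.repeat_interleave(n_q // n_kv, dim=1)  # [B, 32, S, d_h] -- an actual copy here;
                                                 # real kernels index the group instead
scores = q @ k_exp.transpose(-2, -1)             # [B, 32, S, S]
```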
MLA (Multi-Head Latent Attention) compresses K and V jointly into a low-rank latent vector $c_t^{KV}$:

$$c_t^{KV} = x_t W_{DKV} \in \mathbb{R}^{d_c}, \qquad K_t = c_t^{KV} W_{UK}, \quad V_t = c_t^{KV} W_{UV}$$

Only $c_t^{KV}$ (dimension $d_c$, typically 512) is cached rather than all heads' K and V. The cost is more complex attention computation (RoPE must go through a separate decoupled branch), but per-token cache shrinks from GQA's several KB to a few hundred bytes.
Output Projection + Residual
$$x' = x + \mathrm{Attn}\big(\mathrm{RMSNorm}(x)\big)\,W_O, \qquad W_O \in [n_q d_h, H]$$

Note that the residual adds the attention sublayer's input $x$ (the value before pre-norm), not the post-norm value.
FFN Variants
Classic two-matrix FFN (GPT-2):

$$\mathrm{FFN}(x) = \sigma(xW_1 + b_1)\,W_2 + b_2, \qquad W_1 \in [H, 4H],\ W_2 \in [4H, H]$$

$\sigma$ is typically GeLU. The GeLU tanh approximation:

$$\mathrm{GeLU}(x) \approx 0.5\,x\left(1 + \tanh\!\left[\sqrt{2/\pi}\,\big(x + 0.044715\,x^3\big)\right]\right)$$

The GLU family (used by Llama, PaLM, Mistral) replaces the first matrix with a gated pair:

$$\mathrm{GLU\text{-}FFN}(x) = \big(\sigma(xW_{\text{gate}}) \odot xW_{\text{up}}\big)\,W_{\text{down}}$$

- $\sigma = \mathrm{sigmoid}$: GLU
- $\sigma = \mathrm{GeLU}$: GeGLU
- $\sigma = \mathrm{SiLU}$ (Swish): SwiGLU
- $\sigma = \mathrm{ReLU}$: ReGLU

where SiLU is defined as $\mathrm{SiLU}(x) = x \cdot \mathrm{sigmoid}(x)$. SwiGLU has one more projection than the classic GeLU-FFN (three matrices vs two); to align parameter budget, implementations typically set $d_{ff} \approx \tfrac{8}{3}H$ ($4H$ is the GPT-2 convention), and Llama 3 8B's $d_{ff} = 14336$ ($3.5H$, a bit larger than the $\tfrac{8}{3}H$ rule).
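A minimal SwiGLU FFN sketch with Llama 3 8B's dimensions:

```python
import torch
import torch.nn.functional as F

class SwiGLU(torch.nn.Module):
    def __init__(self, hidden: int = 4096, d_ff: int = 14336):
        super().__init__()
        self.w_gate = torch.nn.Linear(hidden, d_ff, bias=False)
        self.w_up   = torch.nn.Linear(hidden, d_ff, bias=False)
        self.w_down = torch.nn.Linear(d_ff, hidden, bias=False)

    def forward(self, x):  # [B, S, H]
        # SiLU(x) = x * sigmoid(x); the gate multiplies the up-projection elementwise.
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))  # back to [B, S, H]

print(SwiGLU()(torch.randn(2, 10, 4096)).shape)  # [2, 10, 4096]
```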
Residual:

$$x_{\text{out}} = x' + \mathrm{FFN}\big(\mathrm{RMSNorm}(x')\big)$$
LM Head + Sampling
$$\text{logits} = h\,W_{\text{lm}} \in [B, S, V], \qquad p = \mathrm{softmax}(\text{logits}/\tau)$$

$\tau$ is temperature. Several logits transformations are typically applied before sampling:

Repetition penalties (repetition / frequency / presence penalty): already-generated tokens have their logits scaled down by a factor $\rho > 1$ (repetition penalty) or shifted down by $\alpha_{\text{freq}} \cdot \text{count}_i + \alpha_{\text{pres}} \cdot \mathbb{1}[\text{count}_i > 0]$ (frequency / presence penalty).

Top-k: keep the $k$ largest logits, set the rest to $-\infty$.

Top-p (nucleus): sort by probability descending, keep the smallest set whose cumulative probability $\ge p$.

Min-p: keep tokens with $p_i \ge p_{\min} \cdot \max_j p_j$; friendlier to low-entropy distributions.

Typical-p: truncate by deviation from the conditional entropy, keeping the set where $\big|{-\log p_i} - H(p)\big|$ is small.

All of these act on the logits or on the probability distribution itself; none of them changes the skeleton of the formula.
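A sketch of temperature + top-k + top-p applied to one step's logits; the ordering and default values here are illustrative, and engines differ in the details:

```python
import torch

def sample(logits, temperature=0.8, top_k=50, top_p=0.9):
    # logits: [B, V] for the last position
    logits = logits / temperature
    # Top-k: keep the k largest logits, mask out the rest.
    kth = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
    logits = logits.masked_fill(logits < kth, float("-inf"))
    # Top-p: drop the tail once cumulative probability of better tokens exceeds p.
    sorted_logits, idx = torch.sort(logits, descending=True, dim=-1)
    probs = torch.softmax(sorted_logits, dim=-1)
    cum_before = probs.cumsum(dim=-1) - probs          # mass of strictly better-ranked tokens
    sorted_logits = sorted_logits.masked_fill(cum_before > top_p, float("-inf"))
    logits = torch.full_like(logits, float("-inf")).scatter(-1, idx, sorted_logits)
    return torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1)  # [B, 1]

next_token = sample(torch.randn(2, 128256))
```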
Prefill Stage Shape Transitions — S Tokens Walk Through Once
Input: input_ids [B, S] = [2, 10]. Each of the 32 layers applies the same forward pass; the shape stays at [B, S, H] = [2, 10, 4096] throughout, with the residual edges carrying the pre-norm input around each sublayer.
After prefill, the KV Cache state: each layer holds K and V for the first 10 positions, shape $[B, n_{kv}, S, d_h] = [2, 8, 10, 128]$ per tensor.
Decode Stage Shape Transitions — Step t Processes Only 1 Token
Prior state: $S + t - 1$ positions already filled in the cache (the 10-token prompt plus the previously generated tokens).
Input: input_ids [B, 1] = [2, 1] (the 1 token generated at the previous step).
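A single-layer decode-step sketch: append this step's K/V to the cache, then attend with a one-row Q. For brevity the GQA broadcast is omitted (it assumes $n_{kv} = n_q = 8$ heads); shapes otherwise follow the table in the next section:

```python
import math
import torch

B, n_kv, d_h, cache_len = 2, 8, 128, 10       # cache already holds the 10 prompt positions

k_cache = torch.randn(B, n_kv, cache_len, d_h)
v_cache = torch.randn(B, n_kv, cache_len, d_h)

q_new = torch.randn(B, n_kv, 1, d_h)          # one new token on the Q side
k_new = torch.randn(B, n_kv, 1, d_h)
v_new = torch.randn(B, n_kv, 1, d_h)

k_cache = torch.cat([k_cache, k_new], dim=2)  # [2, 8, 11, 128]
v_cache = torch.cat([v_cache, v_new], dim=2)

scores = q_new @ k_cache.transpose(-2, -1) / math.sqrt(d_h)  # [2, 8, 1, 11] -- no mask needed,
out = torch.softmax(scores, dim=-1) @ v_cache                # the new token may see everything
print(out.shape)                                             # [2, 8, 1, 128]
```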
Prefill vs Decode Shape Comparison — GEMM vs GEMV · Compute vs Bandwidth
| Position | Prefill | Decode (per step) |
|---|---|---|
| input_ids | $[B, S] = [2, 10]$ | $[B, 1] = [2, 1]$ |
| after embedding | $[B, S, H] = [2, 10, 4096]$ | $[B, 1, H]$ |
| Q | $[B, n_q, S, d_h] = [2, 32, 10, 128]$ | $[B, n_q, 1, d_h]$ |
| K_new / V_new | $[B, n_{kv}, S, d_h] = [2, 8, 10, 128]$ | $[B, n_{kv}, 1, d_h]$ |
| K_full / V_full (from cache) | same as K_new | $[B, n_{kv}, S + t, d_h]$ |
| attention scores | $[B, n_q, S, S]$ | $[B, n_q, 1, S + t]$ |
| attention output | $[B, S, H]$ | $[B, 1, H]$ |
| FFN intermediate | $[B, S, d_{ff}] = [2, 10, 14336]$ | $[B, 1, d_{ff}]$ |
| logits | $[B, 1, V]$ (last position) | $[B, 1, V]$ |
| Operation type | GEMM (matrix × matrix) | GEMV (matrix × vector) |
| Bottleneck | compute | memory bandwidth |
This table is the starting point for understanding every inference acceleration effort: prefill is like training’s forward, compute-bound; decode is a chain of GEMVs, memory-bound, with most time spent fetching weights into SMs. The optimization directions of the two are worlds apart.
KV Cache Shape and Growth — Few KB Per Token · Hundreds of MB at Long Context
One pair of caches per layer:

$$K_{\text{cache}}, V_{\text{cache}} \in [B, n_{kv}, S_{\max}, d_h]$$

Per-token, per-layer cache size (fp16):

$$2\ (\text{K and V}) \times n_{kv} \times d_h \times 2\ \text{bytes} = 2 \times 8 \times 128 \times 2 = 4\ \text{KB}$$

Per-token across the full 32-layer model: $4\ \text{KB} \times 32 = 128\ \text{KB}$. A 4096-token request: $4096 \times 128\ \text{KB} = 512\ \text{MB}$.
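The same arithmetic in a few lines of Python (fp16, Llama 3 8B numbers):

```python
n_layers, n_kv, d_h, bytes_fp16 = 32, 8, 128, 2

per_token_per_layer = 2 * n_kv * d_h * bytes_fp16        # K and V: 4 KiB
per_token_full_model = per_token_per_layer * n_layers    # 128 KiB
request_4k = 4096 * per_token_full_model                 # 512 MiB

print(per_token_per_layer, per_token_full_model / 2**10, request_4k / 2**20)
# 4096 bytes, 128.0 KiB, 512.0 MiB
```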
Several engineering optimizations:
- Paged Attention (vLLM): split the cache into fixed-size blocks (typically 16 tokens), with a block table mapping virtual to physical addresses, eliminating fragmentation. The formulas don't change; only the tensor layout and access pattern change (see the toy block-table sketch after this list).
- Sliding Window Attention (Mistral): keep only the most recent $W$ tokens of K and V (Mistral 7B uses $W = 4096$). The cache cap drops from $O(S)$ to $O(W)$, at the cost of information truncation, with long-range dependencies relayed through cross-layer stacking.
- INT8 / FP8 KV Cache: quantize the fp16 cache down to int8 or even fp8, with per-channel or per-token quantization; the error is controllable and the cache footprint shrinks to 1/2–1/4 of fp16. Representative work: KIVI / KVQuant.
- KV compression / eviction (H2O, StreamingLLM, SnapKV): drop unimportant positions based on attention weights; used at very long context lengths.
- MLA: mentioned earlier — modifies cache shape at the model-structure level, not as a postprocess.
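A toy version of the Paged Attention bullet above — logical KV blocks of 16 tokens mapped to arbitrary physical blocks through a per-request block table; the names and layout are illustrative, not vLLM's actual data structures:

```python
import torch

BLOCK, n_kv, d_h = 16, 8, 128
num_phys_blocks = 1024

# One physical pool shared by all requests: [num_blocks, BLOCK, n_kv, d_h]
k_pool = torch.zeros(num_phys_blocks, BLOCK, n_kv, d_h, dtype=torch.float16)

# Per-request block table: logical block index -> physical block index.
block_table = [17, 3, 511]  # a 3-block (up to 48-token) request, physically non-contiguous

def read_k(pos: int) -> torch.Tensor:
    """Fetch the cached K vectors for absolute position `pos` of this request."""
    phys = block_table[pos // BLOCK]
    return k_pool[phys, pos % BLOCK]  # [n_kv, d_h]

print(read_k(37).shape)  # position 37 lives in logical block 2 -> physical block 511
```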
Per-Step Compute and Memory Cost — Llama 3 8B fp16 · H100 Knee ~330 FLOPs/byte
The shape diagrams above show shapes but not magnitudes. 90% of inference optimization discussion is about “how many FLOPs does this step cost, how many bytes move,” so let’s lay each step’s cost into tables directly.
Using Llama 3 8B, fp16, as the baseline; for prefill take $S = 2048$; for decode take cache length $S_{\text{cache}} = 2048$ (a step during generation around the 2048th token).
Reference hardware knee: H100 SXM fp16 theoretical compute ~989 TFLOPs, HBM bandwidth ~3 TB/s, roofline knee $\approx 989\ \text{TFLOPs} / 3\ \text{TB/s} \approx 330$ FLOPs/byte. Above it is compute-bound, below is memory-bound.
Weight Distribution
| Component | Shape | fp16 size | full model (× 32 layers) |
|---|---|---|---|
| Embedding | $[128256, 4096]$ | 1.0 GB | 1.0 GB |
| $W_Q$ | $[4096, 4096]$ | 32 MB | 1.0 GB |
| $W_K$ | $[4096, 1024]$ | 8 MB | 256 MB |
| $W_V$ | $[4096, 1024]$ | 8 MB | 256 MB |
| $W_O$ | $[4096, 4096]$ | 32 MB | 1.0 GB |
| $W_{\text{gate}}$ | $[4096, 14336]$ | 117 MB | 3.7 GB |
| $W_{\text{up}}$ | $[4096, 14336]$ | 117 MB | 3.7 GB |
| $W_{\text{down}}$ | $[14336, 4096]$ | 117 MB | 3.7 GB |
| RMSNorm (2 per layer) | $[4096] \times 2$ | 16 KB | 500 KB |
| LM head | $[4096, 128256]$ | 1.0 GB | 1.0 GB |
| Total | | ~432 MB / layer | ~16 GB |
Full-model fp16 weights are ~16 GB; the "floor price" of every forward pass is to scan these 16 GB from HBM. At H100's 3 TB/s, that scan alone takes $16\ \text{GB} / 3\ \text{TB/s} \approx 5.3\ \text{ms}$ — this is the physical lower bound of single-request decode latency per token.
Per-Layer, Per-Step Compute / Memory I/O
Compare the same layer's substeps under prefill ($S = 2048$) and decode ($S = 1$, cache length 2048). "Weight HBM" is the weight bytes fetched from VRAM; "KV HBM" is the KV-cache bytes read/written. Intermediate activations are assumed fused into kernels and not counted separately.
| Step | Prefill FLOPs (S=2048) | Decode FLOPs (S=1) | Weight HBM | KV HBM |
|---|---|---|---|---|
| RMSNorm | ≈ 42 MF | 20 KF | 8 KB | — |
| $xW_Q$ | ≈ 68.7 GF | 33.5 MF | 32 MB | — |
| $xW_K$ (+write cache) | ≈ 17.2 GF | 8.4 MF | 8 MB | W 4 MB / 2 KB |
| $xW_V$ (+write cache) | 17.2 GF | 8.4 MF | 8 MB | W 4 MB / 2 KB |
| RoPE | ~50 MF | 25 KF | — | — |
| Attn $QK^\top$ | ≈ 34.4 GF | 16.8 MF | — | R 4 MB (decode) |
| softmax | ~700 MF | 260 KF | — | — |
| Attn $\cdot V$ | 34.4 GF | 16.8 MF | — | R 4 MB (decode) |
| $\text{out}\,W_O$ | ≈ 68.7 GF | 33.5 MF | 32 MB | — |
| RMSNorm | 42 MF | 20 KF | 8 KB | — |
| $xW_{\text{gate}}$ | ≈ 241 GF | 117 MF | 117 MB | — |
| $xW_{\text{up}}$ | 241 GF | 117 MF | 117 MB | — |
| SiLU + gate | ~90 MF | 45 KF | — | — |
| $xW_{\text{down}}$ | ≈ 241 GF | 117 MF | 117 MB | — |
| Per-layer total | ~960 GFLOPs | ~470 MFLOPs | ~432 MB | W 8 MB (P) / R 8 MB (D) |
A few direct conclusions:
- FFN is the real protagonist. $W_{\text{gate}} / W_{\text{up}} / W_{\text{down}}$ consume ~75% of FLOPs and ~80% of weight bandwidth. MoE, sparse activation, and FFN quantization all target this block.
- The 4 attention projections (Q/K/V/O) account for ~18%; the actual $QK^\top$ and $\mathrm{softmax} \cdot V$ only ~7% — in prefill, attention itself isn't the bottleneck; the projections are.
- Decode's KV reads are 8 MB per layer; at $S_{\text{cache}} = 2048$ this is only ~2% of the weight reads. But once context stretches to 64K or 128K, it grows tens of times, overtaking weight bandwidth as the new bottleneck (this is why Paged Attention, sliding window, and KV quantization exist).
One Full Forward Pass
Adding 32 layers + embedding + LM head:
| Stage | FLOPs | HBM I/O | Arithmetic Intensity | Bottleneck |
|---|---|---|---|---|
| Prefill S=2048, B=1 | ~31 TFLOPs | ~14 GB (weights) + 256 MB (KV write) | ~2200 FLOPs/byte | compute |
| Decode step, cache_len=2048, B=1 | ~15 GFLOPs | ~14 GB (weights) + 256 MB (KV read) | ~1.05 FLOPs/byte | bandwidth |
| LM head (prefill, last position only) | ~1 GFLOP | 1 GB | ~1 FLOPs/byte | bandwidth |
| LM head (decode) | ~1 GFLOP | 1 GB | ~1 FLOPs/byte | bandwidth |
Decode's 1.05 FLOPs/byte is 2.5 orders of magnitude below H100's knee of ~330 — meaning ideal single-request decode compute utilization is only $1.05 / 330 \approx 0.3\%$. This is the mathematical basis for continuous batching: push $B$ to 32 so the same weight read is amortized across 32 requests; arithmetic intensity scales by 32×, and decode throughput grows almost linearly until the attention portion or compute itself becomes the wall.
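A back-of-envelope sketch of that scaling, using the per-step decode numbers from the table above and ignoring the KV-read term (which is per-request and does not amortize):

```python
def decode_arith_intensity(batch: int, flops_per_token=15e9, weight_bytes=14e9):
    # Weights are read once per step regardless of batch; FLOPs scale with batch size.
    return batch * flops_per_token / weight_bytes

for b in (1, 8, 32, 128):
    print(b, round(decode_arith_intensity(b), 1))  # 1.1, 8.6, 34.3, 137.1 FLOPs/byte
```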
Mental-Math Rules
Two rules cover 90% of inference performance estimation:
- FLOPs ≈ $2 P N_{\text{tok}}$: $P$ is the parameter count (~8B); $N_{\text{tok}}$ is the total number of tokens this forward pass processes. Each parameter is used once per token (one MAC = 2 FLOPs). E.g., prefill $S = 2048$: $2 \times 8\text{B} \times 2048 \approx 33$ TFLOPs, matching the itemized sum of ~31 TFLOPs.
- Weight HBM I/O ≈ $2P$ bytes (fp16): one forward pass scans the model once, about 16 GB.
Arithmetic intensity is essentially $N_{\text{tok}}$ — the total number of tokens participating in this forward. Prefill has $B \times S = 2048$ tokens; decode has only $B = 1$. This single number directly determines why prefill and decode have different bottlenecks.
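Both rules in a few lines, taking $P$ as 8B:

```python
P = 8e9                                # parameter count

def forward_flops(n_tokens):           # rule 1: 2 FLOPs per parameter per token
    return 2 * P * n_tokens

def weight_bytes():                    # rule 2: one fp16 scan of the model
    return 2 * P

print(forward_flops(2048) / 1e12)          # ~33 TFLOPs for a 2048-token prefill
print(forward_flops(1) / weight_bytes())   # ~1 FLOP/byte for single-request decode
```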
Compute Complexity Overview — With vs Without KV Cache Differ by Three Orders of Magnitude
Prefill (process $S$ tokens at once):

$$O\big(L \cdot (S \cdot H^2 + S^2 \cdot H)\big)$$

Decode per step (process 1 token, history length $S + t$):

$$O\big(L \cdot (H^2 + (S + t) \cdot H)\big)$$

Total complexity to generate $T$ tokens:

$$O\big(L \cdot (T \cdot H^2 + T \cdot (S + T) \cdot H)\big)$$

Without KV cache, every step re-runs the full prefix, multiplying the projection/FFN term by the prefix length: roughly $O\big(L \cdot T \cdot (S + T) \cdot H^2\big)$ — a factor of about $S + T$, i.e. three orders of magnitude at a few thousand tokens of context.
Reading FLOPs and bandwidth together is even more illuminating — that's what the previous section's table shows: prefill challenges the compute ceiling; decode challenges the bandwidth ceiling; continuous batching's point is to fuse many requests' decode GEMVs into one GEMM, amortizing the weight-fetch cost across requests, with throughput rising almost linearly until compute or the attention portion becomes the bottleneck.
How Engineering Optimizations Plug into the Formulas — Flash Attn / Spec Decode / Continuous Batch
Flash Attention: mathematically equivalent to standard attention — the formulas don't change a single character. Engineering-wise it fuses softmax and matmul into one kernel, updating softmax's running statistics (max, sum) in a streaming manner over blocks, avoiding writing the attention matrix back to HBM. Complexity unchanged; memory drops from $O(S^2)$ to $O(S)$; the speedup comes mainly from reduced HBM access. FA-2 shifted the partition granularity from heads to query blocks; FA-3 on H100/H200 adds warpgroup MMA + producer-consumer async pipelining.
Flash Decoding: at decode, Flash Attention's Q has only one row, so kernel parallelism is too low. Flash Decoding splits the KV sequence dimension into chunks processed in parallel and then does a final log-sum-exp reduction. The formula is the same softmax, just split into two passes.
Speculative Decoding: a small "draft" model generates $\gamma$ tokens sequentially, then the large model verifies them with one prefill-like forward over those $\gamma$ positions. Acceptance rule: accept draft token $x$ with probability $\min\!\big(1, p_{\text{target}}(x) / p_{\text{draft}}(x)\big)$; on rejection, resample from $\mathrm{norm}\big(\max(0,\ p_{\text{target}} - p_{\text{draft}})\big)$, which preserves the target distribution exactly.
The crux is fusing $\gamma$ decode GEMVs into one $\gamma$-length GEMM, turning the large model's memory-bound regime back toward compute-bound. With $E[\#\text{accepted}]$ tokens accepted per step, throughput scales by roughly that factor (minus draft overhead). Variants: Medusa (multi-head prediction), EAGLE (feature-level draft), Lookahead Decoding (no draft model).
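A sketch of the acceptance rule for a single draft token, with the two distributions assumed already computed; the real loop verifies all $\gamma$ draft positions against one target-model forward:

```python
import torch

def verify_one(p_target: torch.Tensor, p_draft: torch.Tensor, x: int):
    """Accept draft token x with prob min(1, p_target[x] / p_draft[x]);
    on rejection, resample from the normalized residual max(0, p_target - p_draft)."""
    if torch.rand(()) < torch.clamp(p_target[x] / p_draft[x], max=1.0):
        return x, True
    residual = torch.clamp(p_target - p_draft, min=0.0)
    residual = residual / residual.sum()
    return torch.multinomial(residual, 1).item(), False

V = 128256
p_t = torch.softmax(torch.randn(V), dim=-1)   # target model's distribution at this position
p_d = torch.softmax(torch.randn(V), dim=-1)   # draft model's distribution
token, accepted = verify_one(p_t, p_d, x=42)
```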
Continuous Batching (vLLM, TGI): instead of padding requests so a whole batch enters and exits together, schedule at the per-step, per-request level. Each step picks a batch of requests in the same phase (prefill or decode) and releases finished ones immediately. Mathematically each request is independent; only the scheduling order changes. Original paper: OSDI '22's Orca.
Chunked Prefill: split long prompts’ prefill into chunks and mix them with decode requests in the same step, reducing decode latency jitter. No formula changes. The core scheduling primitive of SARATHI / DistServe.
One Sentence Spanning the Whole Process — From Token ID to Next Token
Input token IDs → look up embedding → through $L = 32$ layers (Pre-RMSNorm → Attention with RoPE → residual → Pre-RMSNorm → SwiGLU FFN → residual) → Final RMSNorm → LM Head → logits → sampling. In prefill, the $S$ prompt tokens pass through in parallel, producing the first output token + the full KV Cache; in decode, each step inputs 1 token; at attention it reads historical K and V from the cache, while all other operations are per-token independent.
Key Engineering Invariants — 6 Rules Worth Memorizing
- The main-line tensor shape is always $[B, S, H]$ (or $[B, 1, H]$ at decode). Residual structure preserves the $H$ dimension; whenever the shape looks different somewhere, either it's been split into heads inside attention or lifted to $d_{ff}$ inside the FFN, and it's back to $H$ on exit.
- K and V, once computed, never change. They're linear projections of the already-fixed input $x_{\le t}$, and the causal structure ensures later positions cannot reach back to modify earlier representations. This is the mathematical basis for the KV Cache.
- Attention is the only cross-token operation; all others (norm, projection, FFN, activation) are per-token independent. So only K and V — the inputs to cross-token operations — need caching; everything else can be computed and discarded immediately.
- Decode's scores shape is $[B, n_q, 1, S + t]$. The "1" is the Q side (the current new token), and the $S + t$ dim is eliminated when weighted-summing with $V$, returning to one row.
- Decode's non-attention compute per step is constant; only attention grows linearly with cache length. So the true reason "generation gets slower over time" is that attention's $S + t$ keeps growing, plus the KV Cache pushing the memory footprint against HBM bandwidth limits.
- Prefill uses GEMM; decode uses GEMV. This one-letter difference dictates that every inference engine has two kernel sets, two scheduling strategies. Internalize this and no inference-optimization paper will lose you.
Internalize these six and you'll find that Flash Attention, PagedAttention, MLA, speculative decoding — they're all local optimizations at some spot on this skeleton, while the skeleton itself has barely changed since 2017.
References — Formulas · Papers · Engineering Blogs
Architecture and Core Operators
- Vaswani et al., “Attention Is All You Need” (NeurIPS 2017) — the original Transformer paper. arxiv.org/abs/1706.03762
- Shazeer, “GLU Variants Improve Transformer” (2020) — source of SwiGLU / GeGLU / ReGLU. arxiv.org/abs/2002.05202
- Zhang & Sennrich, “Root Mean Square Layer Normalization” (NeurIPS 2019) — RMSNorm. arxiv.org/abs/1910.07467
- Hendrycks & Gimpel, “Gaussian Error Linear Units (GELUs)” (2016) — GeLU definition and approximation. arxiv.org/abs/1606.08415
Positional Encoding
- Su et al., “RoFormer: Enhanced Transformer with Rotary Position Embedding” (2021) — RoPE. arxiv.org/abs/2104.09864
- Press et al., “Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation” (ICLR 2022) — ALiBi. arxiv.org/abs/2108.12409
- Peng et al., “YaRN: Efficient Context Window Extension of Large Language Models” (2023) — long-context RoPE scaling. arxiv.org/abs/2309.00071
- bloc97 & emozilla, “NTK-Aware Scaled RoPE” — discussion of early NTK-aware open-source work. reddit / LocalLLaMA
Multi-Head Variants (KV Sharing / Compression)
- Shazeer, “Fast Transformer Decoding: One Write-Head is All You Need” (2019) — MQA. arxiv.org/abs/1911.02150
- Ainslie et al., “GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints” (EMNLP 2023) — GQA. arxiv.org/abs/2305.13245
- DeepSeek-AI, “DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model” (2024) — MLA introduced. arxiv.org/abs/2405.04434
- DeepSeek-AI, “DeepSeek-V3 Technical Report” (2024) — MLA + MoE engineering. arxiv.org/abs/2412.19437
Flash Attention Series
- Dao et al., “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness” (NeurIPS 2022) — FA-1. arxiv.org/abs/2205.14135
- Dao, “FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning” (2023) — FA-2. arxiv.org/abs/2307.08691
- Shah et al., “FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-Precision” (NeurIPS 2024) — FA-3 / Hopper. arxiv.org/abs/2407.08608
- Dao et al., “Flash-Decoding for long-context inference” (Stanford / Together blog, 2023) — Flash Decoding. crfm.stanford.edu
Inference Engines and Serving Schedulers
- Kwon et al., “Efficient Memory Management for Large Language Model Serving with PagedAttention” (SOSP 2023) — vLLM / PagedAttention. arxiv.org/abs/2309.06180
- Yu et al., “Orca: A Distributed Serving System for Transformer-Based Generative Models” (OSDI 2022) — original Continuous Batching paper. usenix.org/osdi22
- Agrawal et al., “SARATHI: Efficient LLM Inference by Piggybacking Decodes with Chunked Prefills” (2023) — chunked prefill. arxiv.org/abs/2308.16369
- Zhong et al., “DistServe: Disaggregating Prefill and Decoding for Goodput-optimized Large Language Model Serving” (OSDI 2024) — prefill/decode disaggregation. arxiv.org/abs/2401.09670
- vLLM main repo (PagedAttention engineering implementation). github.com/vllm-project/vllm
- Hugging Face Text Generation Inference (TGI). github.com/huggingface/text-generation-inference
- NVIDIA TensorRT-LLM documentation (FA / in-flight batching). nvidia.github.io/TensorRT-LLM
Speculative Decoding Family
- Leviathan, Kalman, Matias, “Fast Inference from Transformers via Speculative Decoding” (ICML 2023). arxiv.org/abs/2211.17192
- Chen et al., “Accelerating Large Language Model Decoding with Speculative Sampling” (DeepMind, 2023) — concurrent independent work. arxiv.org/abs/2302.01318
- Cai et al., “Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads” (2024). arxiv.org/abs/2401.10774
- Li et al., “EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty” (ICML 2024). arxiv.org/abs/2401.15077
- Fu et al., “Lookahead Decoding: Breaking the Sequential Dependency of LLM Inference” (2024). arxiv.org/abs/2402.02057
KV Cache Compression / Quantization
- Xiao et al., “Efficient Streaming Language Models with Attention Sinks” (ICLR 2024) — StreamingLLM. arxiv.org/abs/2309.17453
- Zhang et al., “H2O: Heavy-Hitter Oracle for Efficient Generative Inference of Large Language Models” (NeurIPS 2023). arxiv.org/abs/2306.14048
- Li et al., “SnapKV: LLM Knows What You are Looking for Before Generation” (2024). arxiv.org/abs/2404.14469
- Liu et al., “KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache” (ICML 2024). arxiv.org/abs/2402.02750
- Hooper et al., “KVQuant: Towards 10 Million Context Length LLM Inference with KV Cache Quantization” (NeurIPS 2024). arxiv.org/abs/2401.18079
Representative Open-Source Model Technical Reports
- Meta AI, “The Llama 3 Herd of Models” (2024) — Llama 3 family. arxiv.org/abs/2407.21783
- Jiang et al., “Mistral 7B” (2023) — Sliding Window Attention. arxiv.org/abs/2310.06825
- Qwen Team, “Qwen2.5 Technical Report” (2024). arxiv.org/abs/2412.15115
- Workshop et al., “BLOOM: A 176B-Parameter Open-Access Multilingual Language Model” (2022) — ALiBi application. arxiv.org/abs/2211.05100
- MosaicML, “MPT-7B” (2023) — ALiBi long context. databricks.com / MPT-7B
Hardware / Roofline
- NVIDIA, “H100 Tensor Core GPU Architecture Whitepaper” (2022). resources.nvidia.com
- Williams, Waterman, Patterson, “Roofline: An Insightful Visual Performance Model for Multicore Architectures” (CACM 2009) — original Roofline paper. dl.acm.org
- Hoffmann et al., “Training Compute-Optimal Large Language Models” (Chinchilla, 2022) — training-side version of the FLOPs rule of thumb. arxiv.org/abs/2203.15556
Other Long-Form Articles / Tutorials
- Lilian Weng, “The Transformer Family Version 2.0”. lilianweng.github.io
- Horace He, “Making Deep Learning Go Brrrr From First Principles” — three bottleneck classes: compute / memory / overhead. horace.io/brrr_intro
- kipply, "Transformer Inference Arithmetic" — mental math for inference FLOPs / KV cache. kipp.ly
- Jay Mody, “LLM inference, in detail” — another take on tensor shape flow. jaykmody.com