One-sentence summary: MHA gives every head its own K and V (most expressive, most memory); MQA gives all heads one shared K/V (least memory, some quality loss); GQA splits the difference with group-sharing, and that is what most modern LLMs ship.
23.1 The Problem KV Cache Created
23.1.1 The Memory Math
Chapter 22 established that KV Cache is non-negotiable for production inference. But every head in Multi-Head Attention caches its own K and V, and that accumulates fast.
For a typical 7B-parameter model — 32 layers, 32 heads, head dimension 128, FP16:
KV Cache per request =
32 layers × 32 heads × 2 (K and V) × seq_len × 128 × 2 bytes
At seq_len = 1024:
32 × 32 × 2 × 1024 × 128 × 2 = 536 MB
That is 536 MB for one user, one conversation, 1024 tokens. Scale to 100 concurrent users at 4096 tokens each:
536 MB × 4 (4096/1024) × 100 users ≈ 200 GB
200 GB of KV Cache. That is more than two fully loaded A100s just for the cache. This is why the industry started asking whether all those independent K/V heads are actually necessary.
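The arithmetic above is easy to script. A minimal sketch (the helper name `kv_cache_bytes` is mine, not from any library) that reproduces both figures:

```python
def kv_cache_bytes(n_layers, n_kv_heads, seq_len, head_dim, dtype_bytes=2):
    """KV Cache size in bytes: one K and one V tensor per KV head per layer."""
    return n_layers * n_kv_heads * 2 * seq_len * head_dim * dtype_bytes

# One request, 1024 tokens, 7B-class MHA model (32 KV heads):
per_request = kv_cache_bytes(32, 32, 1024, 128)
print(per_request / 1e6)  # 536.870912 -> the 536 MB above

# 100 concurrent users at 4096 tokens each:
print(100 * kv_cache_bytes(32, 32, 4096, 128) / 1e9)  # ~215 GB, the "≈ 200 GB" above
```

The same function computes every cache size in this chapter; only `n_kv_heads` changes between MHA, GQA, and MQA.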
23.1.2 The Root Tension
MHA was designed for training: every head learns a different projection, capturing different patterns. That is a feature. During inference, those independent projections become a burden — we must store one K/V set per head per layer per token.
Training wants expressiveness. Serving wants efficiency. MQA and GQA are the architectures born from that tension.
23.1.3 Three Mechanisms
| Mechanism | Full name | Core idea |
|---|---|---|
| MHA | Multi-Head Attention | Every head has independent K, V |
| MQA | Multi-Query Attention | All heads share one K, one V |
| GQA | Grouped-Query Attention | Groups of heads share one K/V each |
23.2 MHA: The Baseline
23.2.1 Structure
In standard MHA with n_heads heads, every head has:
- its own Q projection
- its own K projection
- its own V projection
KV Cache stores n_heads K tensors and n_heads V tensors per layer.
23.2.2 Why Multiple Heads Help
Different heads genuinely learn different things. In a sentence like "The agent tagged the reviewer because the PR was urgent":
- Head 1 might track syntactic subject-verb: agent → tagged
- Head 2 might track pronoun resolution: "the PR" ← which PR?
- Head 3 might track causal reasoning: tagged → because → urgent
- Head 4 might track recency, attending heavily to recent tokens
Independent K/V projections let each head build its own "perspective" on the token history. That is MHA's strength.
23.2.3 MHA Code Shape
```python
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        # Full d_model projections → contains all heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)  # n_heads independent K projections
        self.W_v = nn.Linear(d_model, d_model)  # n_heads independent V projections
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, kv_cache=None):
        B, T, C = x.shape
        q = self.W_q(x).view(B, T, self.n_heads, self.head_dim)
        k = self.W_k(x).view(B, T, self.n_heads, self.head_dim)
        v = self.W_v(x).view(B, T, self.n_heads, self.head_dim)
        # KV Cache stores n_heads sets of K and V
```
23.2.4 The Problem in Production
For a 32-head model, each layer needs 64 tensors in the KV Cache (32 K + 32 V). At long context or high concurrency, this becomes the binding constraint on how many requests you can serve.
An agent system doing tool-augmented reasoning at 16k context makes the problem concrete:
KV Cache per session (Llama-7B, 16k ctx, FP16):
32 layers × 32 heads × 2 (K+V) × 16384 × 128 × 2 bytes ≈ 8.6 GB
Roughly 8.6 GB per active session. On a GPU with 40 GB available (after loading the model weights), you can serve perhaps 4 concurrent long-context sessions. Scale that to a team, and you see the pressure to reduce KV Cache memory.
23.3 MQA: Collapse Everything
23.3.1 The Core Idea
Multi-Query Attention (Shazeer, 2019) makes a simple but aggressive choice: all query heads share a single K and a single V.
- Q still has `n_heads` independent projections
- K has 1 projection
- V has 1 projection
KV Cache now stores 2 tensors per layer regardless of how many query heads exist.
23.3.2 Memory Savings
For the same 7B model at 1024 tokens:
MHA KV Cache = 32 layers × 32 heads × 2 × 1024 × 128 × 2 = 536 MB
MQA KV Cache = 32 layers × 1 head × 2 × 1024 × 128 × 2 = 16.75 MB
97% reduction. The same GPU that serves 5 concurrent users with MHA can now serve ~160 users with MQA.
23.3.3 MQA Code Shape
```python
class MultiQueryAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.W_q = nn.Linear(d_model, d_model)        # full n_heads
        self.W_k = nn.Linear(d_model, self.head_dim)  # only 1 head!
        self.W_v = nn.Linear(d_model, self.head_dim)  # only 1 head!
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, kv_cache=None):
        B, T, C = x.shape
        q = self.W_q(x).view(B, T, self.n_heads, self.head_dim)
        k = self.W_k(x).view(B, T, 1, self.head_dim)
        v = self.W_v(x).view(B, T, 1, self.head_dim)
        # k, v broadcast from shape [B,T,1,head_dim] to [B,T,n_heads,head_dim]
```
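The broadcast in that last comment can be demonstrated outside the model. A NumPy sketch (shapes chosen arbitrarily for illustration) showing that a single shared K head serves all query heads without ever being copied:

```python
import numpy as np

B, T, n_heads, head_dim = 1, 4, 8, 16
rng = np.random.default_rng(0)
q = rng.standard_normal((B, n_heads, T, head_dim))
k = rng.standard_normal((B, 1, T, head_dim))  # one shared K head

# matmul broadcasts the size-1 head axis across all 8 query heads:
scores = q @ k.transpose(0, 1, 3, 2)
print(scores.shape)  # (1, 8, 4, 4)
```

Every query head gets its own score matrix, but all of them were computed against the same K — which is exactly why the cache only needs one copy.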
23.3.4 The Cost
Forcing all query heads to consult the same K/V reference material limits each head's ability to build an independent view of the token history. MQA works well for many tasks but shows measurable quality degradation on tasks requiring diverse long-range pattern capture. Google's PaLM adopted MQA; the broader community found the quality loss hard to accept at frontier scale.
23.4 GQA: The Practical Middle Ground
23.4.1 Core Idea
Grouped-Query Attention (Ainslie et al., 2023) introduces one hyperparameter: n_kv_heads, the number of K/V groups.
Query heads are divided into n_kv_heads groups. All query heads within a group share one K projection and one V projection.
Formally:
- `n_heads` — number of Q heads
- `n_kv_heads` — number of KV groups
- `n_rep = n_heads / n_kv_heads` — Q heads per group
Special cases:
- `n_kv_heads = n_heads` → MHA (every head independent)
- `n_kv_heads = 1` → MQA (all heads share)
- `1 < n_kv_heads < n_heads` → GQA
23.4.2 GQA Code Shape
```python
class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model, n_heads, n_kv_heads):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.n_rep = n_heads // n_kv_heads
        self.head_dim = d_model // n_heads
        self.W_q = nn.Linear(d_model, n_heads * self.head_dim)
        self.W_k = nn.Linear(d_model, n_kv_heads * self.head_dim)
        self.W_v = nn.Linear(d_model, n_kv_heads * self.head_dim)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, kv_cache=None):
        B, T, C = x.shape
        q = self.W_q(x).view(B, T, self.n_heads, self.head_dim)
        k = self.W_k(x).view(B, T, self.n_kv_heads, self.head_dim)
        v = self.W_v(x).view(B, T, self.n_kv_heads, self.head_dim)
        # Expand K and V to match Q head count
        k = self.repeat_kv(k)  # [B, T, n_heads, head_dim]
        v = self.repeat_kv(v)
        # Attention proceeds identically to MHA from here

    def repeat_kv(self, x):
        """Repeat each KV group n_rep times to match Q head count."""
        B, T, n_kv, head_dim = x.shape
        if self.n_rep == 1:
            return x
        # [B, T, n_kv, head_dim] -> [B, T, n_kv, n_rep, head_dim]
        x = x.unsqueeze(3).expand(B, T, n_kv, self.n_rep, head_dim)
        # -> [B, T, n_heads, head_dim]
        return x.reshape(B, T, self.n_heads, head_dim)
```
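A standalone NumPy version of `repeat_kv` (mirroring the PyTorch method above) makes the head-to-group mapping easy to check on toy data:

```python
import numpy as np

def repeat_kv(x, n_rep):
    """[B, T, n_kv, head_dim] -> [B, T, n_kv * n_rep, head_dim]."""
    B, T, n_kv, head_dim = x.shape
    if n_rep == 1:
        return x
    # Insert a repeat axis and broadcast: no data is copied until the reshape.
    x = np.broadcast_to(x[:, :, :, None, :], (B, T, n_kv, n_rep, head_dim))
    return x.reshape(B, T, n_kv * n_rep, head_dim)

# Two KV groups tagged 0.0 and 1.0, expanded for 8 query heads (n_rep = 4):
kv = np.array([0.0, 1.0]).reshape(1, 1, 2, 1)
print(repeat_kv(kv, 4).reshape(-1))  # [0. 0. 0. 0. 1. 1. 1. 1.]
```

The output matches the diagram in the next subsection: the first four query heads see group 0, the last four see group 1.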
23.4.2b Training vs Inference
One nuance worth understanding: during training, KV Cache is not used (the whole sequence is processed in parallel with causal masking). So GQA's benefit during training is just the reduced parameter count from smaller K and V projection matrices — small but real.
During inference, the benefit is much larger. Decode is memory-bandwidth-bound (Chapter 22, Section 22.6.2): each generated token reads the entire KV Cache from HBM. A smaller KV Cache means fewer bytes streamed from HBM per generated token, and therefore higher decode throughput. GQA's 4× or 8× reduction in cache size translates almost directly into faster decode.
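A back-of-envelope roofline makes that claim concrete. All numbers below are assumptions (A100-class bandwidth, 7B weights in FP16, 4k-token sessions), not measurements:

```python
hbm_bw   = 2.0e12   # bytes/s HBM bandwidth (assumed, A100-class)
weights  = 14e9     # 7B params in FP16, streamed once per decode step
sessions = 64       # concurrent sessions batched together
kv_mha   = 2.1e9    # per-session KV Cache at 4k ctx, MHA (from the math in 23.1)
kv_gqa   = 0.536e9  # per-session KV Cache at 4k ctx, GQA with 8 KV heads

# Time per decode step if purely bandwidth-bound:
step_mha = (weights + sessions * kv_mha) / hbm_bw
step_gqa = (weights + sessions * kv_gqa) / hbm_bw
print(sessions / step_mha)  # ~860 tokens/s ceiling with MHA
print(sessions / step_gqa)  # ~2650 tokens/s ceiling with GQA
```

At batch 64 the KV Cache, not the weights, dominates the bytes moved per step — which is why shrinking it translates almost directly into decode throughput.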
23.4.3 Geometric View of repeat_kv
For n_heads = 8, n_kv_heads = 2:
```
Original K/V shape: [B, T, 2, head_dim]    (KV group 0, KV group 1)

After repeat_kv:    [B, T, 8, head_dim]

  Q0   Q1   Q2   Q3   Q4   Q5   Q6   Q7
   ↓    ↓    ↓    ↓    ↓    ↓    ↓    ↓
  KV0  KV0  KV0  KV0  KV1  KV1  KV1  KV1
```
Q heads 0–3 share KV group 0; Q heads 4–7 share KV group 1. Computationally this is a tensor repeat, not a separate projection — no parameters are added.
23.5 Three-Way Comparison
23.5.1 Memory Numbers
For a 7B model, 1024-token sequence, FP16:
MHA (32 KV heads):
32 × 32 × 2 × 1024 × 128 × 2 bytes = 536 MB
GQA (8 KV heads):
32 × 8 × 2 × 1024 × 128 × 2 bytes = 134 MB
MQA (1 KV head):
32 × 1 × 2 × 1024 × 128 × 2 bytes = 16.75 MB
| Mechanism | KV heads | KV Cache | vs MHA |
|---|---|---|---|
| MHA | 32 | 536 MB | 100% |
| GQA | 8 | 134 MB | 25% |
| MQA | 1 | 16.75 MB | 3.1% |
GQA hits 25% of MHA's memory cost while retaining close to MHA-level quality. MQA gets to 3.1% but pays a steeper quality price.
23.5.2 Quality vs Efficiency
From the GQA paper's benchmarks:
- GQA-G8 (8 groups) sits close to MHA in quality
- GQA-G8 inference time is close to MQA
- Quality improvements from adding more KV groups plateau quickly beyond 8
An important empirical finding: in trained MHA models, different heads' K and V representations are often surprisingly similar. Many heads learn near-redundant projections. That is why sharing K/V within a group loses relatively little — the diversity you give up was not providing much signal to begin with.
This finding has an architectural implication: if you are designing a model from scratch rather than converting an existing MHA checkpoint, you can train directly with GQA and the model learns to use its KV capacity efficiently from the start. The redundancy in MHA is partly an artifact of having no incentive to differentiate K/V representations across heads during training.
23.5.3 Serving Concurrency Impact
The memory numbers above directly determine how many users you can serve simultaneously. On an A100 80GB GPU with a 7B model loaded at FP16 (14 GB), roughly 66 GB remains for KV Cache:
| Mechanism | Per-session KV at 4k ctx | Max concurrent sessions |
|---|---|---|
| MHA | 536 MB × 4 = 2.1 GB | ~31 |
| GQA (8 heads) | 134 MB × 4 = 536 MB | ~123 |
| MQA | 16.75 MB × 4 = 67 MB | ~984 |
GQA roughly quadruples your concurrent user capacity versus MHA at the same hardware budget. That is the business case in one table.
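The table can be reproduced in a few lines. The helper name and the 66 GB free-memory figure are taken from the text; small differences from the table are rounding:

```python
def max_sessions(free_gb, n_layers, n_kv_heads, ctx, head_dim, dtype_bytes=2):
    """How many sessions' KV Caches fit in the given free memory."""
    per_session = n_layers * n_kv_heads * 2 * ctx * head_dim * dtype_bytes
    return int(free_gb * 1e9) // per_session

for name, kv_heads in [("MHA", 32), ("GQA", 8), ("MQA", 1)]:
    print(name, max_sessions(66, 32, kv_heads, 4096, 128))
# MHA 30, GQA 122, MQA 983 — within rounding of the table above
```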
23.5.4 The Full Tradeoff Table
| Mechanism | Quality | Inference speed | KV memory | When to use |
|---|---|---|---|---|
| MHA | Highest | Slowest | Largest | Research, small models, training-only settings |
| MQA | Some loss | Fastest | Smallest | Edge/mobile, extreme throughput requirements |
| GQA | Near-MHA | Near-MQA | Medium | Almost everything in production |
23.6 What Modern Models Ship
23.6.1 The Industry Has Converged on GQA
| Model | Params | Q heads | KV heads | Group size |
|---|---|---|---|---|
| Llama-2 7B | 7B | 32 | 32 | 1 (MHA) |
| Llama-2 70B | 70B | 64 | 8 | 8 |
| Llama-3 8B | 8B | 32 | 8 | 4 |
| Llama-3 70B | 70B | 64 | 8 | 8 |
| Mistral 7B | 7B | 32 | 8 | 4 |
| Qwen-1.5 7B | 7B | 32 | 32 | 1 (MHA) |
| Qwen-1.5 32B | 32B | 40 | 8 | 5 |
| Qwen-2 7B | 7B | 28 | 4 | 7 |
A few observations: smaller models sometimes stick to MHA because the absolute memory cost is manageable and the full expressiveness matters; larger models almost universally go GQA; 8 KV heads is a common sweet spot.
23.6.2 Why 8 KV Heads?
Research shows quality improves rapidly from 1 to 8 KV heads and then flattens. Meanwhile, 8 divides evenly into common tensor-parallel configurations (2, 4, or 8 GPUs), so the KV heads can be distributed cleanly across devices. It is both empirically good and operationally convenient.
23.6.3 Multi-GPU Benefit
In tensor-parallel serving (e.g., 4 GPUs):
MHA with 32 heads:
GPU 0: Q heads 0–7, K heads 0–7, V heads 0–7
GPU 1: Q heads 8–15, K heads 8–15, V heads 8–15
...
GQA with 32 Q heads, 8 KV heads:
GPU 0: Q heads 0–7, K heads 0–1, V heads 0–1
GPU 1: Q heads 8–15, K heads 2–3, V heads 2–3
...
Each GPU's KV Cache is 4× smaller. This matters when running many concurrent requests.
23.7 Converting MHA Checkpoints to GQA
If you already have a trained MHA model, Google's GQA paper proposed uptraining:
- Average weights — for each group of K/V heads to be merged, take the mean of their projection matrices
- Continue training — run a short fine-tuning pass on about 5% of the original training data
- Recover quality — the model adapts quickly because the averaged weights are already a reasonable initialization
```python
def convert_mha_to_gqa(k_weights, n_heads, n_kv_heads, head_dim, d_model):
    """Average groups of K (or V) projection matrices."""
    group_size = n_heads // n_kv_heads
    # k_weights shape: [d_model, n_heads * head_dim]
    k = k_weights.reshape(d_model, n_heads, head_dim)
    # Group and average: [d_model, n_kv_heads, group_size, head_dim]
    k_grouped = k.reshape(d_model, n_kv_heads, group_size, head_dim)
    k_gqa = k_grouped.mean(dim=2)  # [d_model, n_kv_heads, head_dim]
    return k_gqa.reshape(d_model, n_kv_heads * head_dim)
```
This works because of the empirical redundancy mentioned earlier: the averaged result is already a reasonable approximation of what each group's shared K/V should look like.
The GQA paper reports recovering most of the quality gap with just 5% of the original training data. This makes uptraining practical: you invest in training a high-quality MHA model once, then cheaply convert it to a GQA model for efficient serving.
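A NumPy re-sketch of `convert_mha_to_gqa` (the PyTorch version above uses `dim=`; NumPy uses `axis=`) lets you verify on toy shapes that each output block is exactly the mean of its group's heads:

```python
import numpy as np

def convert_mha_to_gqa(k_weights, n_heads, n_kv_heads, head_dim, d_model):
    """Average consecutive groups of head projections, NumPy version."""
    group_size = n_heads // n_kv_heads
    k = k_weights.reshape(d_model, n_kv_heads, group_size, head_dim)
    return k.mean(axis=2).reshape(d_model, n_kv_heads * head_dim)

d_model, n_heads, n_kv_heads, head_dim = 8, 4, 2, 2
w = np.random.default_rng(0).standard_normal((d_model, n_heads * head_dim))
w_gqa = convert_mha_to_gqa(w, n_heads, n_kv_heads, head_dim, d_model)

print(w_gqa.shape)  # (8, 4)
# Group 0's output equals the mean of head 0 and head 1 weight columns:
print(np.allclose(w_gqa[:, :head_dim], (w[:, 0:2] + w[:, 2:4]) / 2))  # True
```

Note the grouping is by consecutive heads: heads 0–1 form group 0, heads 2–3 form group 1, matching the `repeat_kv` layout used at inference time.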
23.8 Flash Attention and GQA Together
Flash Attention 2 added native GQA support directly in the kernel. This matters for efficiency.
Without GQA awareness, a Flash Attention implementation would need to expand K and V via repeat_kv before the tiled loop — creating a larger tensor in HBM. With native GQA, the kernel maps each Q block to its KV group index (group_idx = q_head_idx // n_rep) and loads the right K/V tile without materializing the expanded tensor.
The effect: you get both Flash Attention's IO efficiency and GQA's smaller KV footprint, with no extra memory overhead from the repeat operation. When you call F.scaled_dot_product_attention or use a GQA-aware implementation like vLLM or TensorRT-LLM, this optimization is typically applied automatically.
23.10 Common Misconceptions
"GQA is just MQA with more heads." Not exactly. GQA is a parameterized family. MHA and MQA are the two extremes; GQA is the whole spectrum between them. The key design choice is n_kv_heads.
"Fewer KV heads is always better." The quality-efficiency tradeoff is real. Going from 32 to 8 KV heads cuts memory 4× with minimal quality loss; going from 8 to 1 cuts memory another 8× but with more visible degradation. The optimal setting depends on your serving constraints and quality bar.
"GQA only affects inference." GQA also reduces parameter count slightly (smaller K and V projection matrices), which can speed up training and reduce model file size. The effect is small but not zero. For a 7B model with 32 heads going to 8 KV heads: the K and V projection matrices shrink from [d_model, d_model] to [d_model, d_model/4], saving about 6% of total parameters.
"Every model should use GQA." For very small models (sub-7B) deployed in memory-abundant settings, MHA's extra expressiveness may be worth the overhead. Always measure quality before committing to a specific n_kv_heads.
"Flash Attention and GQA conflict." They are complementary. Flash Attention 2 added native GQA support: it handles the repeat_kv internally during the tiled computation, so you get both the IO-efficient kernel and the smaller KV footprint simultaneously.
23.11 Chapter Summary
MHA (Multi-Head Attention)
n_kv_heads = n_heads
Each head has independent Q, K, V
KV Cache: 2 × n_heads tensors per layer
Best quality, highest memory
MQA (Multi-Query Attention)
n_kv_heads = 1
All heads share one K, one V
KV Cache: 2 tensors per layer
Lowest memory, quality risk at scale
GQA (Grouped-Query Attention)
1 < n_kv_heads < n_heads
Groups of heads share K/V
KV Cache: 2 × n_kv_heads tensors per layer
Near-MHA quality, near-MQA efficiency
Selection Guide
| Situation | Recommendation | Reason |
|---|---|---|
| Research / training focus | MHA | Maximum expressiveness |
| Large-scale production serving | GQA (8 KV heads) | Best quality-efficiency balance |
| Edge / mobile / extreme efficiency | MQA | Minimum memory footprint |
| Uncertain | GQA with 8 KV heads | Safe, empirically validated default |
23.11.1 Reading a Model Config
The Hugging Face config.json for any modern model will list both fields. For Llama-3 8B:
```json
{
  "num_attention_heads": 32,
  "num_key_value_heads": 8
}
```
num_attention_heads is n_heads (Q heads). num_key_value_heads is n_kv_heads. The group size is 32 / 8 = 4 — each KV pair is shared by 4 query heads. If num_key_value_heads equals num_attention_heads, the model uses MHA. If it equals 1, it uses MQA.
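The classification rule in that paragraph fits in a few lines; the JSON literal here just mirrors the Llama-3 8B fields above:

```python
import json

cfg = json.loads('{"num_attention_heads": 32, "num_key_value_heads": 8}')
n_heads, n_kv = cfg["num_attention_heads"], cfg["num_key_value_heads"]

if n_kv == n_heads:
    kind = "MHA"
elif n_kv == 1:
    kind = "MQA"
else:
    kind = "GQA"
print(kind, "group size:", n_heads // n_kv)  # GQA group size: 4
```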
Chapter Checklist
After this chapter, you should be able to:
- Explain why MHA's KV Cache becomes a bottleneck at long context or high concurrency.
- Describe MHA, MQA, and GQA in one sentence each.
- Calculate KV Cache size for each mechanism given model dimensions.
- Explain why GQA quality loss is small despite large memory savings.
- Read a model config and identify its `n_heads` and `n_kv_heads`.
- Implement `repeat_kv`.
- Estimate how many concurrent users a given GPU can serve under each mechanism.
Further Reading
- Fast Transformer Decoding: One Write-Head is All You Need (Shazeer, 2019) — MQA original paper
- GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints (Ainslie et al., 2023)
- Llama 2: Open Foundation and Fine-Tuned Chat Models (Meta, 2023) — shows MHA→GQA transition across model scales
See You in the Next Chapter
GQA reduces the amount of K/V data we store. But even with a smaller cache, every token still attends to every other token within the window — the quadratic cost of full Attention remains.
Chapter 24 explores what happens when you drop that requirement entirely. Sparse Attention lets each token attend to only a chosen subset of the sequence, pushing complexity toward O(N). And Infini Attention goes further, using a fixed-size compressed memory to handle context that grows without bound.