One-sentence summary: Multi-Head Attention splits Attention into several heads, each learning a different relationship pattern, then merges all views into a single representation.
11.1 Why Multiple Heads?
11.1.1 The Limit of a Single Head
Chapter 10 covered the complete Attention computation — but that was single-head Attention: one set of Q, K, V matrices, one relationship map.
Single-head Attention can learn one attention pattern at a time. Language needs more than one.
Take this sentence:
"The agent opened a pull request because the test suite was green."
Fully understanding it requires tracking several kinds of relationships simultaneously:
- Syntactic: "opened" takes "agent" as its subject
- Coreference: what does "it" refer to? (not in this sentence, but typical in longer text)
- Causal: "because" connects the PR opening to the green tests
- Positional: "opened" and "pull request" are neighbors
A single head cannot specialize in all of these at once.
11.1.2 The Solution: Multiple Heads in Parallel
Multi-Head Attention's core idea: run several Attention computations in parallel, each in a smaller subspace, so each head can specialize.
Head 1: might focus on syntactic structure (subject-verb-object)
Head 2: might focus on coreference (pronouns and nouns)
Head 3: might focus on local proximity (neighboring tokens)
Head 4: might focus on semantic similarity (related concepts)
...
Then all heads' outputs are merged into a single representation.
11.1.3 An Analogy
Think of code review from multiple teammates:
- One reviewer checks correctness
- One checks naming and style
- One checks test coverage
- One checks security implications
Each brings a different lens. You merge all their comments into your final understanding. Multi-Head Attention does exactly this — but the "lenses" are learned, not manually assigned.
11.2 Splitting Into Heads
11.2.1 The Dimension Split
The key operation in Multi-Head Attention is splitting the model dimension across heads.
Using K (Key) as the example, with:
- d_model = 512
- num_heads = 4
- Therefore d_key = d_model / num_heads = 512 / 4 = 128
The split unfolds as:
Original K: [batch_size, ctx_length, d_model]
= [4, 16, 512]
↓
Split: [batch_size, ctx_length, num_heads, d_key]
= [4, 16, 4, 128]
↓
Transpose: [batch_size, num_heads, ctx_length, d_key]
= [4, 4, 16, 128]
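The three steps above can be sketched directly in PyTorch. This is a shape-only demo with random values, using the chapter's example numbers:

```python
import torch

batch_size, ctx_length, d_model, num_heads = 4, 16, 512, 4
d_key = d_model // num_heads                          # 512 / 4 = 128

K = torch.randn(batch_size, ctx_length, d_model)      # [4, 16, 512]
K = K.view(batch_size, ctx_length, num_heads, d_key)  # [4, 16, 4, 128]
K = K.transpose(1, 2)                                 # [4, 4, 16, 128]
print(K.shape)
```

Note that `view` only reinterprets the last axis; no data moves until the transpose.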
11.2.2 Why Transpose?
The transpose brings num_heads to the second axis, giving the shape [batch, num_heads, seq_len, d_key]. This means:
- For each sequence in the batch
- We have num_heads independent Attention computations
- Each one processes seq_len positions
- Each position uses a d_key-dimensional vector
With this layout, every head can compute Attention independently, without interfering with the others.
11.2.3 The Same Split Applies to Q, K, and V
Q: [4, 16, 512] → [4, 4, 16, 128]
K: [4, 16, 512] → [4, 4, 16, 128]
V: [4, 16, 512] → [4, 4, 16, 128]
We now have 4 sets of (Q, K, V), ready for 4 independent Attention computations.
11.2.4 Two Equivalent Implementations
There are two ways to think about the split, and they are mathematically equivalent:
Conceptual view: each head has its own small Wq, Wk, Wv matrices. Head h computes Q_h = X @ Wq_h with a [d_model, d_key] matrix.
Practical view: one large Wq generates the full Q of shape [batch, seq, d_model], then we reshape and split along the last dimension into num_heads slices.
Real implementations use the practical view because a single large matrix multiplication is more GPU-efficient than many small ones. The GPU prefers large, contiguous operations over many small scattered ones.
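A small sketch can confirm the equivalence. Here the conceptual per-head matrices are taken to be column slices of one large Wq (an illustrative assumption; trained weights are simply stored as the big matrix), and both views produce the same Q values:

```python
import torch

torch.manual_seed(0)
d_model, num_heads = 512, 4
d_key = d_model // num_heads
X = torch.randn(4, 16, d_model)

# Practical view: one large matrix, one matmul, then split the last axis.
Wq = torch.randn(d_model, d_model)
Q_big = (X @ Wq).view(4, 16, num_heads, d_key)

# Conceptual view: each head's [d_model, d_key] matrix is a column slice of Wq.
for h in range(num_heads):
    Wq_h = Wq[:, h * d_key:(h + 1) * d_key]   # [512, 128]
    Q_h = X @ Wq_h                            # [4, 16, 128]
    assert torch.allclose(Q_h, Q_big[..., h, :], atol=1e-3)
```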
11.3 Computing All Heads in Parallel
11.3.1 Each Head Is Independent
After the split, every head executes the same Attention formula independently:
For each head h = 1, 2, 3, 4:
scores_h = Q_h @ K_h^T [4, 16, 128] @ [4, 128, 16] = [4, 16, 16]
weights_h = softmax(scores_h / sqrt(d_key))
output_h = weights_h @ V_h [4, 16, 16] @ [4, 16, 128] = [4, 16, 128]
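The per-head loop above can be written as one batched computation: the leading [batch, heads] axes broadcast, so all four heads run in a single call. A shape-only sketch with random values:

```python
import torch
import torch.nn.functional as F

batch, heads, seq, d_key = 4, 4, 16, 128
Q = torch.randn(batch, heads, seq, d_key)
K = torch.randn(batch, heads, seq, d_key)
V = torch.randn(batch, heads, seq, d_key)

# One batched call covers all heads at once.
scores = Q @ K.transpose(-2, -1) / d_key ** 0.5   # [4, 4, 16, 16]
weights = F.softmax(scores, dim=-1)               # each row sums to 1
output = weights @ V                              # [4, 4, 16, 128]
print(output.shape)
```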
11.3.2 Dimension Tracking
Q @ K^T for one head:
Q: [4, 4, 16, 128]
batch heads seq d_key
K^T: [4, 4, 128, 16]
batch heads d_key seq
Q @ K^T: [4, 4, 16, 16]
batch heads seq seq
Softmax(Q @ K^T) @ V:
Attention Weights: [4, 4, 16, 16]
batch heads seq seq
V: [4, 4, 16, 128]
batch heads seq d_key
Output: [4, 4, 16, 128]
batch heads seq d_key
11.3.3 What the Parallelism Gets You
The total computation is the same as a single-head Attention with width 512. But running four heads of width 128 means each head operates in a smaller, more focused subspace. Each head can develop a clean specialization instead of trying to capture every relationship pattern in one large matrix.
11.4 Merging the Heads Back
11.4.1 Concatenation
After all heads compute their output, we concatenate them back into the full model dimension:
Head outputs: [4, 4, 16, 128]
batch heads seq d_key
↓
Transpose: [4, 16, 4, 128]
batch seq heads d_key
↓
Concatenate: [4, 16, 512]
batch seq d_model
The concatenation operation just merges the last two dimensions:
- 4 heads × 128 dimensions = 512 dimensions
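A minimal sketch of the merge; the final checks confirm that each position's merged vector is just the four head outputs laid side by side:

```python
import torch

out = torch.randn(4, 4, 16, 128)                   # [batch, heads, seq, d_key]
merged = out.transpose(1, 2)                       # [4, 16, 4, 128]
merged = merged.contiguous().view(4, 16, 4 * 128)  # [4, 16, 512]

# Position 0's merged vector: head 0's output, then head 1's, and so on.
assert torch.equal(merged[0, 0, :128], out[0, 0, 0])
assert torch.equal(merged[0, 0, 128:256], out[0, 1, 0])
```

The `contiguous()` call is required because `transpose` returns a non-contiguous view, and `view` needs contiguous memory.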
11.4.2 The Output Projection Wo
Concatenation is mechanical. It puts the heads' outputs next to each other but does not let them interact. That is what Wo is for:
A @ Wo
[4, 16, 512] @ [512, 512] = [4, 16, 512]
Wo is a learned projection matrix. Its job:
- Mix information across heads — what each head learned can now influence the others
- Project the concatenated representation into a unified space
- Let the model decide how to weight each head's contribution
11.4.3 Why Wo Matters
Without Wo, Head 1's output and Head 3's output occupy disjoint 128-dimensional slices of the concatenated 512-dimensional vector, and nothing connects them. Wo provides one round of cross-head communication before passing the result to the next block.
11.5 Comparing the Outputs: Before and After Wo
11.5.1 A vs A @ Wo
Before Wo (A):
- Shape: [4, 16, 512]
- Values: the raw concatenation of all heads' output vectors
After Wo (A @ Wo):
- Shape: [4, 16, 512]
- Values: a mixed, projected representation
Same shape. Different content. The post-Wo output is what flows into the residual connection, then LayerNorm, then the FFN.
11.6 Full Multi-Head Attention Flow
11.6.1 End to End
Input X [batch, seq, d_model]
↓
Generate Q, K, V (via Wq, Wk, Wv)
↓
Split into heads [batch, num_heads, seq, d_key]
↓
Compute Attention independently per head
↓
Concatenate [batch, seq, d_model]
↓
Output projection (@ Wo)
↓
Output [batch, seq, d_model]
11.6.2 PyTorch Implementation
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_key = d_model // num_heads

        # Four learnable weight matrices
        self.Wq = nn.Linear(d_model, d_model)
        self.Wk = nn.Linear(d_model, d_model)
        self.Wv = nn.Linear(d_model, d_model)
        self.Wo = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape

        # 1. Generate Q, K, V
        Q = self.Wq(x)  # [batch, seq, d_model]
        K = self.Wk(x)
        V = self.Wv(x)

        # 2. Split into heads
        Q = Q.view(batch_size, seq_len, self.num_heads, self.d_key)
        K = K.view(batch_size, seq_len, self.num_heads, self.d_key)
        V = V.view(batch_size, seq_len, self.num_heads, self.d_key)

        # Transpose: [batch, num_heads, seq, d_key]
        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)

        # 3. Attention per head
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_key ** 0.5)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attention_weights = F.softmax(scores, dim=-1)
        attention_output = torch.matmul(attention_weights, V)

        # 4. Merge heads
        attention_output = attention_output.transpose(1, 2)  # [batch, seq, heads, d_key]
        attention_output = attention_output.contiguous().view(
            batch_size, seq_len, self.d_model
        )

        # 5. Output projection
        output = self.Wo(attention_output)
        return output
```
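As a point of comparison, PyTorch ships this computation as the built-in `nn.MultiheadAttention`. A quick self-attention call (with `batch_first=True` so shapes follow this chapter's [batch, seq, d_model] convention) confirms the shapes:

```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=512, num_heads=4, batch_first=True)
x = torch.randn(4, 16, 512)           # [batch, seq, d_model]

out, weights = mha(x, x, x)           # self-attention: Q = K = V = x
print(out.shape)                      # [4, 16, 512]
print(weights.shape)                  # [4, 16, 16] (averaged over heads by default)
```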
11.7 Key Numbers
11.7.1 Parameter Count
Multi-Head Attention has four weight matrices:
| Matrix | Shape | Parameters |
|---|---|---|
| Wq | [d_model, d_model] | d_model² |
| Wk | [d_model, d_model] | d_model² |
| Wv | [d_model, d_model] | d_model² |
| Wo | [d_model, d_model] | d_model² |
Total: 4 × d_model²
For GPT-2 Small (d_model = 768): 4 × 768² ≈ 2.36 million parameters per Attention layer.
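The count is easy to verify. This sketch counts weights only; the chapter's 4 × d_model² figure ignores the bias terms that nn.Linear adds by default:

```python
import torch.nn as nn

d_model = 768  # GPT-2 Small

# Wq, Wk, Wv, Wo as bias-free linear layers
layers = [nn.Linear(d_model, d_model, bias=False) for _ in range(4)]
total = sum(p.numel() for layer in layers for p in layer.parameters())
print(total)  # 2359296, i.e. about 2.36 million
```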
11.7.2 Common Configurations
| Model | d_model | num_heads | d_key |
|---|---|---|---|
| GPT-2 Small | 768 | 12 | 64 |
| GPT-2 Medium | 1024 | 16 | 64 |
| GPT-2 Large | 1280 | 20 | 64 |
| GPT-3 | 12288 | 96 | 128 |
| LLaMA-7B | 4096 | 32 | 128 |
Notice: d_key stays at 64 or 128 across a wide range of model sizes. Bigger models add more heads rather than making each head wider.
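The pattern in the table can be checked in a few lines (numbers copied from the table above):

```python
configs = {
    "GPT-2 Small":  (768, 12),
    "GPT-2 Medium": (1024, 16),
    "GPT-2 Large":  (1280, 20),
    "GPT-3":        (12288, 96),
    "LLaMA-7B":     (4096, 32),
}
for name, (d_model, num_heads) in configs.items():
    print(f"{name}: d_key = {d_model // num_heads}")
```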
11.8 What Do the Heads Actually Learn?
11.8.1 Observed Patterns from Research
Researchers have identified recurring patterns in trained Attention heads:
| Head type | Pattern | Example |
|---|---|---|
| Positional | Attends to nearby fixed offsets | always look one position back |
| Syntactic | Tracks subject-verb-object | verb attends to its subject |
| Semantic | Groups related concepts | synonyms attend to each other |
| Coreference | Resolves pronoun references | "it" attends to the noun it replaces |
| Delimiter | Tracks sentence boundaries | attends to punctuation |
11.8.2 A Practical Example
For "The agent merged the pull request after review":
Head 1 (positional): "merged" mainly attends to "agent" (adjacent subject)
Head 2 (syntactic): "merged" mainly attends to "agent" (grammatical subject)
Head 3 (semantic): "pull request" and "review" attend to each other
Head 4 (coreference): not active here (no pronouns)
11.8.3 Head Redundancy
Not all heads are equally important. Research shows:
- Some heads can be pruned with minimal performance loss
- Some heads learn redundant patterns
- But keeping more heads generally improves robustness and reduces training sensitivity
The right number of heads is empirical. There is no closed-form answer.
11.9 Multi-Head vs Single-Head
11.9.1 Compute Comparison
For d_model = 512, num_heads = 8, d_key = 64:
Single head (with d_key = 512):
- Q @ K^T: [seq, 512] @ [512, seq] → O(seq² × 512)
Eight heads (with d_key = 64 each):
- Each head: [seq, 64] @ [64, seq] → O(seq² × 64)
- Total: 8 × O(seq² × 64) = O(seq² × 512)
Same total compute. Different capability.
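The equality is simple arithmetic, and a short sketch makes it concrete (counting multiply-accumulates in the Q @ K^T step only):

```python
seq, d_model, num_heads = 16, 512, 8
d_key = d_model // num_heads                 # 64

single_head = seq * seq * d_model            # one [seq, 512] @ [512, seq] product
multi_head = num_heads * seq * seq * d_key   # eight [seq, 64] @ [64, seq] products
print(single_head, multi_head)               # identical
```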
11.9.2 Why Not More Heads?
More heads means smaller d_key:
d_key = d_model / num_heads
If d_key gets too small, each head has too few dimensions to represent a useful subspace. Empirically, d_key = 64 or d_key = 128 is the practical sweet spot.
11.10 Part 3 Wrap-Up
This chapter closes Part 3: Attention Mechanisms. Here is what we covered across the four chapters:
| Chapter | Topic | Core idea |
|---|---|---|
| Chapter 8 | Linear Transforms | Matrix multiplication as projection and similarity |
| Chapter 9 | Attention Geometry | Dot product as a similarity measure |
| Chapter 10 | Q, K, V | The three roles and the full computation |
| Chapter 11 | Multi-Head | Parallel views; concatenation and Wo |
The complete Multi-Head Attention formula:
MultiHead(Q, K, V) = Concat(head_1, ..., head_h) @ Wo
Where:
- head_i = Attention(Q @ Wq_i, K @ Wk_i, V @ Wv_i)
- Attention(Q, K, V) = softmax(Q @ K^T / sqrt(d_key)) @ V
Chapter Checklist
After this chapter, you should be able to:
- Explain why a single Attention head has limitations.
- Derive the relationship d_key = d_model / num_heads.
- Trace dimension changes through split, compute, and merge.
- Explain what Wo does after concatenation.
- Describe the kinds of patterns different heads can learn.
See You in the Next Chapter
Part 4 introduces the full Transformer block architecture — assembling all the components we have built so far.
The next chapter, Chapter 12, ties up the remaining conceptual thread: what does the Attention output actually represent, and what are the two things training is adjusting simultaneously?