一文要約: Multi-Head Attention は Attention を複数のヘッドに分割し、それぞれが異なる関係パターンを学習したうえで、すべての視点を一つの表現に統合します。
11.1 なぜ複数のヘッドが必要なのか
11.1.1 単一ヘッドの限界
第10章では Attention の完全な計算を扱いましたが、それは 単一ヘッドの Attention でした。Q、K、V の組が一つだけ、関係マップも一つだけです。
単一ヘッドの Attention では、一度に一つの Attention パターンしか学習できません。しかし言語が必要とするのはそれ以上です。
次の文を見てみましょう:
"The agent opened a pull request because the test suite was green."
この文を完全に理解するには、いくつかの種類の関係を同時に追跡する必要があります:
- 構文的: "opened" は "agent" を主語に取る
- 共参照: "it" は何を指しているか? (この文には含まれませんが、長い文章では典型的)
- 因果的: "because" は PR の作成とテストの成功を結びつける
- 位置的: "opened" と "pull request" は隣接している
単一のヘッドでは、これらすべてに同時に特化することはできません。
11.1.2 解決策: 複数のヘッドを並列に
Multi-Head Attention の核心となるアイデアは、複数の Attention 計算を並列に実行し、それぞれをより小さな部分空間で行うことで、各ヘッドが特化できるようにすることです。
Head 1: 構文構造に注目するかもしれない (主語-動詞-目的語)
Head 2: 共参照に注目するかもしれない (代名詞と名詞)
Head 3: 局所的な近接性に注目するかもしれない (隣接トークン)
Head 4: 意味的類似性に注目するかもしれない (関連概念)
...
そして、すべてのヘッドの出力を一つの表現に統合します。
11.1.3 アナロジー
複数のチームメイトによるコードレビューを思い浮かべてみてください:
- ある人は正しさをチェックする
- ある人は命名とスタイルをチェックする
- ある人はテストカバレッジをチェックする
- ある人はセキュリティへの影響をチェックする
それぞれが異なる視点を持ち込みます。あなたはすべてのコメントを統合して最終的な理解にまとめます。Multi-Head Attention はまさにこれを行いますが、「視点」は手動で割り当てられるのではなく、学習されるという点が異なります。
11.2 ヘッドへの分割
11.2.1 次元の分割
Multi-Head Attention の鍵となる操作は、モデル次元をヘッド間で分割することです。
K (Key) を例として、以下の設定で考えます:
d_model = 512num_heads = 4- したがって
d_key = d_model / num_heads = 512 / 4 = 128
分割は次のように展開されます:
元の K: [batch_size, ctx_length, d_model]
= [4, 16, 512]
↓
分割: [batch_size, ctx_length, num_heads, d_key]
= [4, 16, 4, 128]
↓
転置: [batch_size, num_heads, ctx_length, d_key]
= [4, 4, 16, 128]
11.2.2 なぜ転置するのか
転置によって num_heads を第2軸に移動させ、形状を [batch, num_heads, seq_len, d_key] にします。これは次のことを意味します:
- バッチ内の各シーケンスについて
num_heads個の独立した Attention 計算を持つ- それぞれが
seq_len個の位置を処理する - 各位置は
d_key次元のベクトルを使用する
このレイアウトにより、すべてのヘッドが他に干渉することなく独立して Attention を計算できます。
11.2.3 同じ分割を Q、K、V に適用する
Q: [4, 16, 512] → [4, 4, 16, 128]
K: [4, 16, 512] → [4, 4, 16, 128]
V: [4, 16, 512] → [4, 4, 16, 128]
これで 4 組の (Q, K, V) が揃い、4 つの独立した Attention 計算の準備が整いました。
11.2.4 二つの等価な実装
分割の考え方には二通りあり、両者は数学的に等価です:
概念的な見方: 各ヘッドが独自の小さな Wq、Wk、Wv 行列を持つ。ヘッド h は [d_model, d_key] の行列を使って Q_h = X @ Wq_h を計算する。
実用的な見方: 一つの大きな Wq が [batch, seq, d_model] の形状の Q 全体を生成し、その後最後の次元に沿って num_heads 個のスライスにリシェイプして分割する。
実際の実装では実用的な見方を採用します。なぜなら、一つの大きな行列乗算のほうが、多数の小さな乗算よりも GPU 効率が良いからです。GPU は多数の小さく散らばった演算よりも、大きく連続した演算を好みます。
11.3 すべてのヘッドを並列に計算する
11.3.1 各ヘッドは独立している
分割後、すべてのヘッドが同じ Attention の式を独立に実行します:
各ヘッド h = 1, 2, 3, 4 について:
scores_h = Q_h @ K_h^T [4, 16, 128] @ [4, 128, 16] = [4, 16, 16]
weights_h = softmax(scores_h / sqrt(d_key))
output_h = weights_h @ V_h [4, 16, 16] @ [4, 16, 128] = [4, 16, 128]
11.3.2 次元の追跡
一つのヘッドにおける Q @ K^T:
Q: [4, 4, 16, 128]
batch heads seq d_key
K^T: [4, 4, 128, 16]
batch heads d_key seq
Q @ K^T: [4, 4, 16, 16]
batch heads seq seq
Softmax(Q @ K^T) @ V:
Attention Weights: [4, 4, 16, 16]
batch heads seq seq
V: [4, 4, 16, 128]
batch heads seq d_key
Output: [4, 4, 16, 128]
batch heads seq d_key
11.3.3 並列化が得るもの
総計算量は、幅 512 の単一ヘッド Attention と同じです。しかし、幅 128 のヘッドを 4 つ走らせるということは、各ヘッドがより小さく焦点の絞られた部分空間で動作することを意味します。各ヘッドは、すべての関係パターンを一つの大きな行列で捉えようとする代わりに、明確な専門性を発展させることができます。
11.4 ヘッドを統合して戻す
11.4.1 連結 (Concat)
すべてのヘッドが出力を計算した後、それらを連結して完全なモデル次元に戻します:
ヘッド出力: [4, 4, 16, 128]
batch heads seq d_key
↓
転置: [4, 16, 4, 128]
batch seq heads d_key
↓
連結: [4, 16, 512]
batch seq d_model
連結操作は単に最後の二つの次元をマージするだけです:
- 4 ヘッド × 128 次元 = 512 次元
11.4.2 出力射影 Wo
連結は機械的な操作です。各ヘッドの出力を隣り合わせに並べるだけで、それらを相互作用させることはありません。それを行うのが Wo です:
A @ Wo
[4, 16, 512] @ [512, 512] = [4, 16, 512]
Wo は学習される射影行列です。その役割は:
- ヘッド間で情報を混合する。各ヘッドが学んだことが他のヘッドに影響を与えられるようになる
- 連結された表現を統一された空間に射影する
- 各ヘッドの貢献度をどのように重み付けするかをモデルに決定させる
11.4.3 なぜ Wo が重要なのか
Wo がなければ、ヘッド 1 の出力とヘッド 3 の出力は 512 次元のベクトルの異なる領域に置かれ、両者を結びつけるものは何もありません。Wo は、結果を次のブロックに渡す前に、ヘッド間の通信を一度行います。
11.5 出力の比較: Wo の前と後
11.5.1 A と A @ Wo
Wo 適用前 (A):
- 形状: [16, 512]
- 値: すべてのヘッドの出力ベクトルを生のまま連結したもの
Wo 適用後 (A @ Wo):
- 形状: [16, 512]
- 値: 混合され射影された表現
形状は同じ。中身は異なる。Wo 適用後の出力が、残差接続、LayerNorm、FFN へと流れていきます。
11.6 Multi-Head Attention の全体フロー
11.6.1 エンドツーエンド
入力 X [batch, seq, d_model]
↓
Q, K, V を生成 (Wq, Wk, Wv 経由)
↓
ヘッドに分割 [batch, num_heads, seq, d_key]
↓
ヘッドごとに独立に Attention を計算
↓
連結 [batch, seq, d_model]
↓
出力射影 (@ Wo)
↓
出力 [batch, seq, d_model]
11.6.2 PyTorch 実装
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.d_key = d_model // num_heads
# 4 つの学習可能な重み行列
self.Wq = nn.Linear(d_model, d_model)
self.Wk = nn.Linear(d_model, d_model)
self.Wv = nn.Linear(d_model, d_model)
self.Wo = nn.Linear(d_model, d_model)
def forward(self, x, mask=None):
batch_size, seq_len, _ = x.shape
# 1. Q, K, V を生成
Q = self.Wq(x) # [batch, seq, d_model]
K = self.Wk(x)
V = self.Wv(x)
# 2. ヘッドに分割
Q = Q.view(batch_size, seq_len, self.num_heads, self.d_key)
K = K.view(batch_size, seq_len, self.num_heads, self.d_key)
V = V.view(batch_size, seq_len, self.num_heads, self.d_key)
# 転置: [batch, num_heads, seq, d_key]
Q = Q.transpose(1, 2)
K = K.transpose(1, 2)
V = V.transpose(1, 2)
# 3. ヘッドごとの Attention
scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_key ** 0.5)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attention_weights = F.softmax(scores, dim=-1)
attention_output = torch.matmul(attention_weights, V)
# 4. ヘッドを統合
attention_output = attention_output.transpose(1, 2) # [batch, seq, heads, d_key]
attention_output = attention_output.contiguous().view(
batch_size, seq_len, self.d_model
)
# 5. 出力射影
output = self.Wo(attention_output)
return output
11.7 重要な数値
11.7.1 パラメータ数
Multi-Head Attention には 4 つの重み行列があります:
| 行列 | 形状 | パラメータ数 |
|---|---|---|
| Wq | [d_model, d_model] | d_model² |
| Wk | [d_model, d_model] | d_model² |
| Wv | [d_model, d_model] | d_model² |
| Wo | [d_model, d_model] | d_model² |
合計: 4 × d_model²
GPT-2 Small (d_model = 768) の場合: 4 × 768² ≈ Attention レイヤーあたり約 236 万パラメータ。
11.7.2 一般的な構成
| モデル | d_model | n_heads | d_key |
|---|---|---|---|
| GPT-2 Small | 768 | 12 | 64 |
| GPT-2 Medium | 1024 | 16 | 64 |
| GPT-2 Large | 1280 | 20 | 64 |
| GPT-3 | 12288 | 96 | 128 |
| LLaMA-7B | 4096 | 32 | 128 |
注目すべき点: d_key は幅広いモデルサイズにわたって 64 または 128 のままです。大規模なモデルは、各ヘッドを広くするのではなく、ヘッドの数を増やすのです。
11.8 ヘッドは実際に何を学習するのか
11.8.1 研究で観察されたパターン
研究者たちは、訓練済みの Attention ヘッドに繰り返し現れるパターンを特定してきました:
| ヘッドのタイプ | パターン | 例 |
|---|---|---|
| 位置的 | 近くの固定オフセットに注目する | 常に 1 つ前の位置を見る |
| 構文的 | 主語-動詞-目的語を追跡する | 動詞がその主語に注目する |
| 意味的 | 関連する概念をグループ化する | 同義語が互いに注目する |
| 共参照的 | 代名詞の参照を解決する | "it" がそれが置き換える名詞に注目する |
| 区切り | 文の境界を追跡する | 句読点に注目する |
11.8.2 実用的な例
"The agent merged the pull request after review" の場合:
Head 1 (位置的): "merged" は主に "agent" に注目 (隣接する主語)
Head 2 (構文的): "merged" は主に "agent" に注目 (文法的な主語)
Head 3 (意味的): "pull request" と "review" が互いに注目
Head 4 (共参照的): ここでは活性化しない (代名詞がない)
11.8.3 ヘッドの冗長性
すべてのヘッドが等しく重要というわけではありません。研究によると:
- 一部のヘッドは性能の低下を最小限に抑えながら剪定 (プルーニング) できる
- 一部のヘッドは冗長なパターンを学習している
- しかし、より多くのヘッドを保持することは一般に頑健性を高め、訓練の感度を下げる
ヘッドの適切な数は経験的に決まります。閉じた形の答えはありません。
11.9 Multi-Head と Single-Head の比較
11.9.1 計算量の比較
d_model = 512、n_heads = 8、d_key = 64 の場合:
単一ヘッド (d_key = 512 の場合):
- Q @ K^T: [seq, 512] @ [512, seq] → O(seq² × 512)
8 ヘッド (各 d_key = 64):
- 各ヘッド: [seq, 64] @ [64, seq] → O(seq² × 64)
- 合計: 8 × O(seq² × 64) = O(seq² × 512)
総計算量は同じ。能力は異なる。
11.9.2 なぜもっとヘッドを増やさないのか
ヘッドが増えると d_key が小さくなります:
d_key = d_model / n_heads
d_key が小さくなりすぎると、各ヘッドは有用な部分空間を表現するための次元が足りなくなります。経験的には、d_key = 64 または d_key = 128 が実用上のスイートスポットです。
11.10 第3部のまとめ
この章で第3部「Attention メカニズム」が締めくくられます。4 つの章で扱った内容を振り返ります:
| 章 | テーマ | 中核となるアイデア |
|---|---|---|
| 第8章 | 線形変換 | 射影と類似度としての行列乗算 |
| 第9章 | Attention の幾何 | 類似度尺度としての内積 |
| 第10章 | Q, K, V | 3 つの役割と完全な計算 |
| 第11章 | Multi-Head | 並列の視点。連結と Wo |
Multi-Head Attention の完全な式:
ここで:
章末チェックリスト
この章を読んだ後、あなたは次のことができるはずです:
- 単一の Attention ヘッドに限界がある理由を説明できる。
-
d_key = d_model / n_headsという関係を導出できる。 - 分割、計算、統合を通じた次元の変化を追跡できる。
- 連結後に Wo が何を行うかを説明できる。
- 異なるヘッドが学習しうるパターンの種類を説明できる。
次の章でお会いしましょう
第12章は、第3部の締めくくりとして残された概念の糸を結びます。Attention の出力は実際に何を表しているのか、そして訓練が同時に調整している二つのものとは何か、という問いに答えていきます。
そのあと第4部がすぐに続きます。これまで構築してきたすべての部品 (トークナイゼーション、位置エンコーディング、Attention、FFN) を組み合わせ、Transformer ブロックの完全なアーキテクチャへと組み上げていきます。
ここまでお疲れさまでした。複数のヘッドという考え方が、最初は抽象的に見えても、次元の流れを一度追ってしまえばすっと腑に落ちるはずです。次の章でまたお会いしましょう。