一文で言うと: QKV Attention は何かを予測しているのではなく、各トークンの埋め込みベクトルを文脈の中で意味を持つように継続的に調整しているのです。
12.1 短いけれど実り多い章
本章で Multi-Head Attention の話題に区切りをつけます。これまでの章で、出力テンソル A を一歩ずつ導出してきました。それがどのように計算されるかは、すでに明確にイメージできているはずです。
多くの解説で省略されがちなのは、A が実際には何を 意味する のか、そしてそれを生み出すときにモデルが何をしているのか、という点です。
取り上げるのは三つです。
- Concatenate: 複数のヘッドを一つのテンソルに統合する
- 線形変換 Wo: 最後の行列積
- 学習の本質: QKV が実際に何を調整しているのか
ここがクリアになると、Layer Normalization と残差接続の理解がぐっと楽になります。
12.2 A の形: 4 次元テンソル
12.2.1 Attention の可視化
図は Attention の中核プロセスを示しています。
- Q (Query): 現在処理中のトークンに対するクエリベクトル
- K (Key): 文脈に含まれるすべてのトークンのキー行列
- V (Value): 文脈に含まれるすべてのトークンの値行列
Q は K の各行と内積をとって Attention の重みを生み出し、その重みで V の各行をブレンドします。このブレンド結果が、当該トークンに対する出力です。つまり、Q が K に問い合わせ、関連する位置を見つけ、V から取り出して集約する わけです。
12.2.2 形状の分解
Multi-Head Attention の出力 A は 4 次元テンソルです。
A: [4, 4, 16, 128]
│ │ │ └── ヘッドごとの次元 (d_head = 512 / 4 = 128)
│ │ └────── シーケンス長 (seq_len = 16)
│ └────────── ヘッド数 (num_heads = 4)
└───────────── バッチサイズ (batch_size = 4)
分解してみましょう。
- 最初の 4: バッチに含まれる 4 本のシーケンス
- 二つ目の 4: 512 次元のモデル幅を 4 つのヘッドに分割したもの
- 16: シーケンスごとに 16 個のトークン
- 128: 各ヘッドの部分空間の次元
実際の計算は [16, 128] のスライス単位で行われます。シーケンスごと、ヘッドごとに 1 枚ずつです。4 次元という形状は単なるパッケージングに過ぎません。
┌─────────────────────────────────────┐
│ Sequence 1 │
│ ┌──────┬──────┬──────┬──────┐ │
│ │ Head1│ Head2│ Head3│ Head4│ │
│ │16×128│16×128│16×128│16×128│ │
│ └──────┴──────┴──────┴──────┘ │
├─────────────────────────────────────┤
│ Sequence 2 (同じ構造) │
├─────────────────────────────────────┤
│ Sequence 3 (同じ構造) │
├─────────────────────────────────────┤
│ Sequence 4 (同じ構造) │
└─────────────────────────────────────┘
12.2.3 実モデルのサイズ
GPT-2 (1.17 億パラメータ) の場合:
d_model = 768、num_heads = 12、d_head = 64
LLaMA-7B の場合:
d_model = 4096、num_heads = 32、d_head = 128
同じ 4 次元構造ですが、実用ではずっと大きいということです。
12.3 Concatenate: ヘッドを再び 1 つに
12.3.1 統合の操作
Concatenate(よく "concat" と呼ばれます)は、複数のヘッドを 1 つのテンソルに組み立て直すステップです。第 11 章で行った分割の逆操作にあたります。
concat 前: [4, 4, 16, 128] → 4 つのヘッド、各 128 次元
concat 後: [4, 16, 512] → 1 つにまとまった 512 次元テンソル
最後の 2 つの次元 [4, 128] を [512] に戻します。それだけ ── reshape 操作にすぎません。
12.3.2 そもそもなぜ分割して統合するのか
私が最初に学んだときに抱いた疑問は、「なぜベクトルを分割して、それぞれで Attention を計算してから貼り合わせるのか? その遠回りで何が得られるのか?」というものでした。
答えは 多視点の表現 です。
各ヘッドはモデル全体次元のうち別々の部分空間で動作します。パラメータは共有しません。学習が進むにつれて、ヘッドはそれぞれ専門性を獲得していく傾向があります。
- ヘッド 1 は構文に敏感なパターンを学ぶかもしれません
- ヘッド 2 は意味的類似性を学ぶかもしれません
- ヘッド 3 は位置の近さを学ぶかもしれません
- ヘッド 4 は話題の連続性を学ぶかもしれません
分割することがこの専門化を促し、統合によってその専門化された視点が一つの表現にまとめられます。
覚えておきたいトレードオフ: ヘッドが多いほど表現能力は豊かになりますが、パラメータも計算量も増えます。経験則的にちょうどよいのは
d_head = 64またはd_head = 128のあたりです。最適なヘッド数を与える理論式は存在せず、実験的に調整されています。
12.4 Wo: 最後の線形変換
12.4.1 Wo とは
連結のあと、最後にもう一度行列積を行います。
Wo の形: [512, 512]
演算: A @ Wo → 最終出力
Wo (出力用の重み行列) は Wq, Wk, Wv と構造的にまったく同じです。
- 形:
[d_model, d_model]=[512, 512] - 初期化: ランダム
- 種別: 学習可能なパラメータ
12.4.2 重み共有のルール
ここは私自身が学んだときに混乱した部分なので、はっきり書いておきます。
1 つの Transformer ブロック内: すべてのヘッドが 1 つの Wq、1 つの Wk、1 つの Wv を共有し、Wo は 1 つだけ存在します。各ヘッドは独立したモジュールではありません。射影行列を共有し(第 11 章の reshape のトリックを使って)、Wo がその出力を再結合するのです。
Transformer ブロック間: 各ブロックは独立した Wq, Wk, Wv, Wo の組を持ちます。12 ブロックのモデルなら、これらの行列の組が 12 セットあり、それぞれ少しずつ異なるものを学習します。
Block 1: Wq₁, Wk₁, Wv₁, Wo₁ ← 1 つ目のセット
Block 2: Wq₂, Wk₂, Wv₂, Wo₂ ← 2 つ目のセット
...
Block 12: Wq₁₂, Wk₁₂, Wv₁₂, Wo₁₂ ← 12 番目のセット
各ブロックは自分の重みで自分の Attention を計算し、表現を一段階ずつ深めていきます。
12.4.3 PyTorch では
PyTorch の nn.MultiheadAttention を使う場合、重み行列は内部で扱われます。
self.attn = nn.MultiheadAttention(embed_dim=512, num_heads=4)
# 内部では:
# self.attn.in_proj_weight → Wq, Wk, Wv をまとめてパック
# self.attn.out_proj.weight → これが Wo
Hugging Face の transformers ライブラリはさらにこれを包み込みますが、同じ 4 つの行列がそこにあります。
12.5 Q × K が実際に計算しているもの
12.5.1 スコア行列
Q と K の積が何を生み出すか、改めて見ていきましょう。最初のバッチ、最初のヘッドを取ると、16×16 の正方行列が得られます。
Token1 Token2 Token3 ... Token16
Token1 [ 0.20 0.10 0.05 ... 0.01 ]
Token2 [ 0.15 0.30 0.10 ... 0.02 ]
Token3 [ 0.08 0.12 0.25 ... 0.03 ]
...
Token16 [ 0.01 0.02 0.01 ... 0.40 ]
12.5.2 幾何学的な直観
この行列の読み方:
- 各行: 1 つのトークンの Attention の視点
- 各列: 1 つのトークンが他から見える度合い
- 各セル: 行のトークンから列のトークンへの Attention の重み
Softmax を通すと、各行の合計は 1 になります。シーケンス全体に対する確率分布になるわけです。
つまり Q × K は 「すべてのトークンについて、他のすべてのトークンに何 % の注意を割くべきか」 を計算しているのです。
行列の形にしているおかげで、これが効率よく行えます。ループを回さずに、すべての対の関係を一度に計算できるのです。
12.5.3 具体例
「The agent merged the pull request after review.」をモデルが処理する場面を想像してください。
"merged" から見ると:
- "merged" → "agent": おそらく 30% (動作の主語)
- "merged" → "pull request": おそらく 25% (動作の目的語)
- "merged" → "review": おそらく 20% (動作の文脈)
- "merged" → 残りのトークン: 残り 25%
これらの割合は Q × K から得られます。これがモデルに「どこを見ればよいか」を伝えるのです。
12.6 V の役割: 内容に Attention を適用する
12.6.1 V は内容の運び手
Q × K のスコア行列は「地図」のようなものです。どこを見るかは示してくれますが、それ自身は内容を持っていません。内容のほうは V です。
今回の構成では:
- シーケンスに 16 トークン
- 各トークンが 128 次元の V ベクトル(このヘッドの部分空間内)を持つ
12.6.2 掛け算の意味
スコア行列に V を掛けます。
(Q × K) × V → 形は依然として 16×128、変化なし
これがコアの操作です。Attention の割合を使って、各トークンのベクトルを更新する のです。
各トークンの出力ベクトルは、すべての V ベクトルの重み付き和になります。重みが Attention のスコアです。高い Attention を受け取ったトークンほど、その V の内容を出力に多く寄与させます。
最初のトークン埋め込みは意味を持たないランダム初期化からスタートしました。この操作を経て、そして数千回の学習ステップを経て、これらの値は意味を持つようになります ── 各トークンが 周囲のシーケンスという文脈の中で 何を表すのかを符号化していくのです。
12.7 学習: 同時に調整される 2 つの対象
12.7.1 一歩ずつ進む
学習中、各順伝播は小さな調整を行います。次の順伝播でさらに調整を行い、それが何万ステップも(大規模モデルではさらに多く)続いていきます。
12.7.2 トークン埋め込み更新の具体例
学習コーパスに「agent」という語を含むシーケンスがたくさんある状況を想像してください。
最初の学習ステップ: "agent" のトークン埋め込みはランダムで、意味のある値はありません。
最初の順伝播と逆伝播のあと、その埋め込みは「文脈の中で "agent" の次に何が来るかを予測する助けになる方向」へわずかに動きます。
2 ステップ目: 次に "agent" が現れたとき、ステップ 1 で 更新された 埋め込みを使います。再びわずかな調整が入ります。
N ステップ目: "agent" の埋め込みは豊かな情報を符号化するようになります。「これは agent という単語である」だけでなく、「行動を取る主体であり、エージェント的な文脈に現れ、'opened' や 'merged' のような動詞が続くことが多い」までも表現するようになります。
初期の "agent" 埋め込み: [ランダムな数値]
ステップ 1 後: [わずかに調整された値]
ステップ 2 後: [もう少し意味のある値]
...
ステップ N 後: [意味的に豊かな値]
これこそが、なぜ Embedding (埋め込み) と呼ぶのかの理由です。単語が意味のあるベクトル空間に埋め込まれていくのです。
12.7.3 更新される 2 種類のパラメータ
QKV Attention は、2 つの異なるパラメータの集合を同時に洗練します。
第 1: トークン埋め込み
- トークン ID をベクトルへ写像するルックアップテーブル
- 各トークンのベクトルが文脈中の意味を捉えるように更新される
- すべてのレイヤーで共有(トークン ID あたり 1 つの埋め込み)
第 2: 重み行列
- Wq, Wk, Wv, Wo ── Attention 内部の線形変換
- Attention 機構が有用な関係を見つけるように更新される
- ブロックごとに別個(12 ブロックのモデルなら 12 セット)
この 2 つは互いを引き上げ合います。トークン埋め込みが良くなれば Attention のスコアも良くなる。重み行列が良くなればトークン埋め込みの更新も良くなる。両者は一緒に収束していくのです。
12.8 全体像: Multi-Head Attention の働き
12.8.1 アーキテクチャ内での位置付け
各 Transformer ブロックの中の Multi-Head Attention モジュールは、次のことを行います。
- トークン埋め込みを更新する: このモジュールを通るたびに、シーケンス内で周囲に何があるかに応じて埋め込みベクトルを調整する
- 重み行列を更新する: Wq, Wk, Wv, Wo はすべて学習されるパラメータで、誤差逆伝播で改善されていく
12.8.2 パラメータ数
d_model = 512 の 12 ブロックモデルの場合:
- 各ブロック: 4 つの重み行列 × 512² = 4 × 262,144 = 1,048,576 パラメータ
- 12 ブロック: 12 × 1,048,576 ≈ 1,260 万パラメータ (Attention のみ)
- これに加えて: トークン埋め込みテーブル、FFN レイヤー、Layer Norm、出力射影
十分な学習ステップを経ると、これらのパラメータは整合的な文章生成を可能にする値に落ち着きます。
12.8.3 モデル規模の参考
| モデル | レイヤー数 | d_model | 総パラメータ |
|---|---|---|---|
| GPT-2 Small | 12 | 768 | 1.17 億 |
| GPT-2 Medium | 24 | 1024 | 3.45 億 |
| GPT-2 Large | 36 | 1280 | 7.74 億 |
| LLaMA-7B | 32 | 4096 | 70 億 |
| LLaMA-70B | 80 | 8192 | 700 億 |
章末チェックリスト
本章を終えたあと、次のことができるようになっているはずです。
- 4 次元の形
[batch, heads, seq_len, d_head]と各次元の意味を説明できる。 - Concatenate が何をしているか、なぜ分割してから統合するのが有用なのかを説明できる。
- Wo が何をするか、ブロック内とブロック間で重みがどう共有されるかを説明できる。
- Q × K が計算するもの (Attention の割合) を説明できる。
- (Q × K) × V が何をしているか (Attention の重みを使ってトークンベクトルを更新する) を説明できる。
- 学習が同時に調整する 2 つの対象 (トークン埋め込みと重み行列) を説明できる。
次章でお会いしましょう
本章はこれで十分です。Attention テンソルの各出力次元が ── どのヘッド、どのトークン、どの部分空間を表しているか ── 図を見なくても説明できるなら、第 13 章の準備は整っています。
第 13 章では残差接続と Dropout を扱います。深い Transformer を安定して学習させるためのエンジニアリング上のコツです。Attention が何を生み出すかが分かった今、残差接続というパターンがなぜ理にかなっているのか、すぐに腹落ちするはずです。それでは、また次の章でお会いしましょう。