一文要約: MHA は各ヘッドが独自の K と V を持つ(最も表現力が高く、メモリも最大)。MQA はすべてのヘッドが一つの K/V を共有する(最小メモリだが品質がやや落ちる)。GQA はグループ共有で中間をとる。現代の LLM の大半はこの GQA を採用しています。
23.1 KV キャッシュが生み出した問題
23.1.1 メモリの計算
第22章で確認したとおり、プロダクション推論において KV キャッシュは欠かせません。しかし Multi-Head Attention では各ヘッドが独自の K と V をキャッシュするため、積み重なると大変なことになります。
典型的な 7B パラメータモデルの場合 — 32 層、32 ヘッド、ヘッド次元 128、FP16:
リクエストごとの KV キャッシュ =
32 層 × 32 ヘッド × 2 (K と V) × seq_len × 128 × 2 バイト
seq_len = 1024 のとき:
32 × 32 × 2 × 1024 × 128 × 2 = 536 MB
これは 1 ユーザー、1 会話、1024 トークンで 536 MB です。100 人が同時に 4096 トークンを使うとスケールすると:
536 MB × 4 (4096/1024) × 100 ユーザー ≈ 200 GB
KV キャッシュだけで 200 GB。A100 が丸二枚つぶれる計算です。業界が「ヘッドごとに独立した K/V は本当に必要なのか」と問い始めた理由がわかります。
23.1.2 根本的なトレードオフ
MHA はもともと学習向けに設計されています。各ヘッドが異なるパターンを捉えるために独立した射影を学ぶ、それが強みです。しかし推論では、その独立した射影がコストになります。レイヤーごと、ヘッドごと、トークンごとに K/V を保存しなければなりません。
学習は表現力を求める。サービングは効率を求める。MQA と GQA はそのトレードオフから生まれたアーキテクチャです。
23.1.3 三つのメカニズム
| メカニズム | 正式名称 | 核心的なアイデア |
|---|---|---|
| MHA | Multi-Head Attention | 各ヘッドが独立した K、V を持つ |
| MQA | Multi-Query Attention | すべてのヘッドが一つの K、V を共有する |
| GQA | Grouped-Query Attention | グループ単位でヘッドが一つの K/V を共有する |
23.2 MHA: ベースライン
23.2.1 構造
n_heads ヘッドの標準 MHA では、各ヘッドがそれぞれ:
- 独自の 射影
- 独自の 射影
- 独自の 射影
を持ちます。KV キャッシュにはレイヤーごとに n_heads 個の K テンソルと n_heads 個の V テンソルが格納されます。
23.2.2 マルチヘッドが役立つ理由
異なるヘッドは本当に異なるものを学習します。「エージェントがレビュアーをタグ付けした、なぜなら PR が緊急だったから」という文を考えてみましょう:
- ヘッド 1 は構文上の主語-動詞を追う: エージェント → タグ付けした
- ヘッド 2 は代名詞の解消を追う: 「その PR」← どの PR?
- ヘッド 3 は因果推論を追う: タグ付けした → なぜなら → 緊急
- ヘッド 4 は直近性を追い、直近のトークンに強く注目する
独立した K/V 射影によって、各ヘッドがトークン履歴に対して独自の「視点」を構築できます。これが MHA の強みです。
23.2.3 MHA のコード形状
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, n_heads):
super().__init__()
self.n_heads = n_heads
self.head_dim = d_model // n_heads
# すべてのヘッドを含む d_model 射影
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model) # n_heads 個の独立した K 射影
self.W_v = nn.Linear(d_model, d_model) # n_heads 個の独立した V 射影
self.W_o = nn.Linear(d_model, d_model)
def forward(self, x, kv_cache=None):
B, T, C = x.shape
q = self.W_q(x).view(B, T, self.n_heads, self.head_dim)
k = self.W_k(x).view(B, T, self.n_heads, self.head_dim)
v = self.W_v(x).view(B, T, self.n_heads, self.head_dim)
# KV キャッシュは n_heads 組の K と V を格納する
23.2.4 プロダクションでの問題
32 ヘッドのモデルでは、各レイヤーの KV キャッシュに 64 テンソル(32 K + 32 V)が必要です。長いコンテキストや高い並行数では、これがリクエストを捌ける量の制約になります。
ツール呼び出しを組み合わせたエージェントシステムで 16k コンテキストを使う場合を具体的に計算すると:
セッションごとの KV キャッシュ (Llama-7B、16k ctx、FP16):
32 層 × 32 ヘッド × 2 (K+V) × 16384 × 128 × 2 バイト ≈ 8 GB
1 アクティブセッションで 8 GB です。モデルの重みを読み込んだ後に 40 GB の GPU が残っていれば、長いコンテキストセッションをせいぜい 4 つしか並行で捌けません。チームで使うことを考えると、KV キャッシュのメモリを削減したいという圧力がよくわかります。
23.3 MQA: すべてを集約する
23.3.1 核心的なアイデア
Multi-Query Attention(Shazeer, 2019)はシンプルですが大胆な選択をします: すべてのクエリヘッドが一つの K と一つの V を共有する。
- Q は引き続き
n_heads個の独立した射影を持つ - K は 1 個の射影
- V は 1 個の射影
KV キャッシュに格納されるのは、クエリヘッドの数にかかわらず レイヤーごとに 2 テンソルだけです。
23.3.2 メモリの節約
同じ 7B モデルで 1024 トークンの場合:
MHA KV キャッシュ = 32 層 × 32 ヘッド × 2 × 1024 × 128 × 2 = 536 MB
MQA KV キャッシュ = 32 層 × 1 ヘッド × 2 × 1024 × 128 × 2 = 16.75 MB
97% 削減。MHA で 5 ユーザーしか捌けなかった同じ GPU が、MQA では約 160 ユーザーに対応できます。
23.3.3 MQA のコード形状
class MultiQueryAttention(nn.Module):
def __init__(self, d_model, n_heads):
super().__init__()
self.n_heads = n_heads
self.head_dim = d_model // n_heads
self.W_q = nn.Linear(d_model, d_model) # n_heads 分フル
self.W_k = nn.Linear(d_model, self.head_dim) # 1 ヘッドのみ!
self.W_v = nn.Linear(d_model, self.head_dim) # 1 ヘッドのみ!
self.W_o = nn.Linear(d_model, d_model)
def forward(self, x, kv_cache=None):
B, T, C = x.shape
q = self.W_q(x).view(B, T, self.n_heads, self.head_dim)
k = self.W_k(x).view(B, T, 1, self.head_dim)
v = self.W_v(x).view(B, T, 1, self.head_dim)
# k, v は [B,T,1,head_dim] から [B,T,n_heads,head_dim] にブロードキャストされる
23.3.4 コスト
すべてのクエリヘッドに同じ K/V 参照を強いることで、各ヘッドがトークン履歴の独立した視点を構築する力が制限されます。MQA は多くのタスクでうまく機能しますが、多様な長距離パターン捕捉が必要なタスクでは品質の低下が見られます。Google の PaLM は MQA を採用しましたが、フロンティアスケールでの品質低下はコミュニティに受け入れがたいものでした。
23.4 GQA: 現実的な中間点
23.4.1 核心的なアイデア
Grouped-Query Attention(Ainslie et al., 2023)はハイパーパラメータを一つ導入します: n_kv_heads、つまり K/V グループの数です。
クエリヘッドは n_kv_heads 個のグループに分けられます。同じグループ内のすべてのクエリヘッドが一つの K 射影と一つの V 射影を共有します。
形式的には:
n_heads— Q ヘッドの数n_kv_heads— KV グループの数n_rep = n_heads / n_kv_heads— グループあたりの Q ヘッド数
特殊なケース:
n_kv_heads = n_heads→ MHA(各ヘッドが独立)n_kv_heads = 1→ MQA(すべてのヘッドが共有)1 < n_kv_heads < n_heads→ GQA
23.4.2 GQA のコード形状
class GroupedQueryAttention(nn.Module):
def __init__(self, d_model, n_heads, n_kv_heads):
super().__init__()
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.n_rep = n_heads // n_kv_heads
self.head_dim = d_model // n_heads
self.W_q = nn.Linear(d_model, n_heads * self.head_dim)
self.W_k = nn.Linear(d_model, n_kv_heads * self.head_dim)
self.W_v = nn.Linear(d_model, n_kv_heads * self.head_dim)
self.W_o = nn.Linear(d_model, d_model)
def forward(self, x, kv_cache=None):
B, T, C = x.shape
q = self.W_q(x).view(B, T, self.n_heads, self.head_dim)
k = self.W_k(x).view(B, T, self.n_kv_heads, self.head_dim)
v = self.W_v(x).view(B, T, self.n_kv_heads, self.head_dim)
# K と V を Q のヘッド数に合わせて展開する
k = self.repeat_kv(k) # [B, T, n_heads, head_dim]
v = self.repeat_kv(v)
# ここから先は MHA と同じ Attention の計算
def repeat_kv(self, x):
"""各 KV グループを n_rep 回繰り返して Q ヘッド数に合わせる。"""
B, T, n_kv, head_dim = x.shape
if self.n_rep == 1:
return x
# [B, T, n_kv, head_dim] -> [B, T, n_kv, n_rep, head_dim]
x = x.unsqueeze(3).expand(B, T, n_kv, self.n_rep, head_dim)
# -> [B, T, n_heads, head_dim]
return x.reshape(B, T, self.n_heads, head_dim)
23.4.2b 学習と推論の違い
一つ理解しておくべき微妙な点があります。学習時は KV キャッシュを使いません(シーケンス全体をカジュアルマスキングで並列処理します)。そのため学習における GQA のメリットは、K と V の射影行列が小さくなるというパラメータ数の削減だけです。小さいですが、ゼロではありません。
推論時はメリットがずっと大きくなります。デコードはメモリ帯域幅に律速されます(第22章 22.6.2 節)。生成される各トークンは KV キャッシュ全体を HBM から読み込みます。KV キャッシュが小さければ、より多くが SRAM に収まり、トークンごとの HBM 読み込みが減り、スループットが上がります。GQA の 4× や 8× のメモリ削減は、ほぼそのままデコード速度の向上につながります。
23.4.3 repeat_kv の幾何学的イメージ
n_heads = 8、n_kv_heads = 2 のとき:
元の K/V の形状: [B, T, 2, head_dim]
KV グループ 0 KV グループ 1
repeat_kv 後: [B, T, 8, head_dim]
Q0 Q1 Q2 Q3 Q4 Q5 Q6 Q7
↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓
KV0 KV0 KV0 KV0 KV1 KV1 KV1 KV1
Q ヘッド 0〜3 は KV グループ 0 を共有し、Q ヘッド 4〜7 は KV グループ 1 を共有します。計算上これはテンソルの繰り返しであり、独立した射影ではありません。パラメータは追加されません。
23.5 三者比較
23.5.1 メモリの数値
7B モデル、1024 トークンシーケンス、FP16 の場合:
MHA(32 KV ヘッド):
32 × 32 × 2 × 1024 × 128 × 2 バイト = 536 MB
GQA(8 KV ヘッド):
32 × 8 × 2 × 1024 × 128 × 2 バイト = 134 MB
MQA(1 KV ヘッド):
32 × 1 × 2 × 1024 × 128 × 2 バイト = 16.75 MB
| メカニズム | KV ヘッド | KV キャッシュ | MHA 比 |
|---|---|---|---|
| MHA | 32 | 536 MB | 100% |
| GQA | 8 | 134 MB | 25% |
| MQA | 1 | 16.75 MB | 3.1% |
GQA は MHA のメモリコスト 25% で、MHA に近い品質を保ちます。MQA は 3.1% まで到達しますが、品質の代償がより大きくなります。
23.5.2 品質と効率のトレードオフ
GQA 論文のベンチマークから:
- GQA-G8(8 グループ)は品質が MHA に近い
- GQA-G8 の推論時間は MQA に近い
- 8 グループを超えて KV グループを増やしても品質改善はすぐに頭打ちになる
重要な実験的知見があります。学習済みの MHA モデルでは、異なるヘッドの K と V の表現が驚くほど似ていることが多いのです。多くのヘッドがほぼ冗長な射影を学びます。だからグループ内で K/V を共有してもそれほど損をしない — 失う多様性はもともとそれほど有用なシグナルを提供していなかったからです。
この知見にはアーキテクチャ上の示唆があります。既存の MHA チェックポイントを変換するのではなく最初から設計するなら、GQA で直接学習すれば、モデルは最初から KV 容量を効率的に使うことを学べます。MHA における冗長性は、学習中に K/V 表現をヘッド間で差別化するインセンティブがないことの副産物でもあります。
23.5.3 サービング並行数への影響
上記のメモリ数値は、同時に何人のユーザーに対応できるかを直接決定します。A100 80GB GPU で 7B モデルを FP16 でロードすると(14 GB)、KV キャッシュ用に約 66 GB 残ります:
| メカニズム | セッションあたり KV (4k ctx) | 最大並行セッション数 |
|---|---|---|
| MHA | 536 MB × 4 = 2.1 GB | 約 31 |
| GQA(8 ヘッド) | 134 MB × 4 = 536 MB | 約 123 |
| MQA | 16.75 MB × 4 = 67 MB | 約 984 |
GQA は同じハードウェア予算で MHA に対して並行ユーザー数をほぼ 4 倍にします。これがビジネスケースを一つの表で示したものです。
23.5.4 完全なトレードオフ表
| メカニズム | 品質 | 推論速度 | KV メモリ | 使いどころ |
|---|---|---|---|---|
| MHA | 最高 | 最遅 | 最大 | 研究、小規模モデル、学習オンリーの設定 |
| MQA | やや劣る | 最速 | 最小 | エッジ/モバイル、極端なスループット要件 |
| GQA | MHA に近い | MQA に近い | 中程度 | プロダクションのほぼすべて |
23.6 現代のモデルが採用するもの
23.6.1 業界は GQA に収束している
| モデル | パラメータ | Q ヘッド | KV ヘッド | グループサイズ |
|---|---|---|---|---|
| Llama-2 7B | 7B | 32 | 32 | 1 (MHA) |
| Llama-2 70B | 70B | 64 | 8 | 8 |
| Llama-3 8B | 8B | 32 | 8 | 4 |
| Llama-3 70B | 70B | 64 | 8 | 8 |
| Mistral 7B | 7B | 32 | 8 | 4 |
| Qwen-1.5 7B | 7B | 32 | 32 | 1 (MHA) |
| Qwen-1.5 32B | 32B | 40 | 8 | 5 |
| Qwen-2 7B | 7B | 28 | 4 | 7 |
いくつか観察できます。小さいモデルが MHA に留まることがあるのは、絶対的なメモリコストが管理可能で、表現力をフル活用したいからです。大きいモデルはほぼ例外なく GQA を選びます。KV ヘッド 8 というのが一般的なスイートスポットです。
23.6.2 なぜ 8 KV ヘッドなのか
研究によると、品質は 1 から 8 KV ヘッドにかけて急速に改善し、そこから頭打ちになります。一方、8 は一般的なテンソル並列構成(2、4、または 8 GPU)に均等に割り切れるため、KV ヘッドをデバイス間できれいに分散できます。経験的にも優れており、運用上も便利という組み合わせです。
23.6.3 マルチ GPU でのメリット
テンソル並列サービング(例: 4 GPU)の場合:
MHA(32 ヘッド):
GPU 0: Q ヘッド 0–7, K ヘッド 0–7, V ヘッド 0–7
GPU 1: Q ヘッド 8–15, K ヘッド 8–15, V ヘッド 8–15
...
GQA(32 Q ヘッド、8 KV ヘッド):
GPU 0: Q ヘッド 0–7, K ヘッド 0–1, V ヘッド 0–1
GPU 1: Q ヘッド 8–15, K ヘッド 2–3, V ヘッド 2–3
...
各 GPU の KV キャッシュが 4× 小さくなります。多数の並行リクエストを実行するときに重要です。
23.7 MHA チェックポイントを GQA に変換する
既に学習済みの MHA モデルがあれば、Google の GQA 論文が提案した uptraining を使えます:
- 重みを平均化する — マージする K/V ヘッドの各グループについて、それらの射影行列の平均を取る
- 継続学習する — 元の学習データの約 5% で短いファインチューニングを実行する
- 品質を回復させる — 平均化された重みが合理的な初期化になっているため、モデルは素早く適応する
def convert_mha_to_gqa(k_weights, n_heads, n_kv_heads, head_dim, d_model):
"""K(または V)射影行列のグループを平均化する。"""
group_size = n_heads // n_kv_heads
# k_weights の形状: [d_model, n_heads * head_dim]
k = k_weights.reshape(d_model, n_heads, head_dim)
# グループ化して平均: [d_model, n_kv_heads, group_size, head_dim]
k_grouped = k.reshape(d_model, n_kv_heads, group_size, head_dim)
k_gqa = k_grouped.mean(dim=2) # [d_model, n_kv_heads, head_dim]
return k_gqa.reshape(d_model, n_kv_heads * head_dim)
これが機能するのは、先述の経験的冗長性のためです。平均化された結果はすでに、各グループの共有 K/V がどうあるべきかの合理的な近似になっています。
GQA 論文では、元の学習データのわずか 5% で品質ギャップのほとんどを回復できると報告されています。これにより uptraining が現実的になります。高品質な MHA モデルを一度学習すれば、効率的なサービング用の GQA モデルに安価に変換できます。
23.8 Flash Attention と GQA の組み合わせ
Flash Attention 2 はカーネルに直接ネイティブ GQA サポートを追加しました。これは効率の観点で重要です。
GQA を認識しない Flash Attention 実装では、タイルループの前に repeat_kv で K と V を展開する必要があり、HBM により大きなテンソルが作られます。ネイティブ GQA 対応では、カーネルが各 Q ブロックを KV グループインデックス(group_idx = q_head_idx // n_rep)にマッピングし、展開されたテンソルを実体化せずに正しい K/V タイルをロードします。
その効果: Flash Attention の IO 効率と GQA の小さな KV フットプリントの両方を、繰り返し操作による余分なメモリオーバーヘッドなしに得られます。F.scaled_dot_product_attention を呼び出すか、vLLM や TensorRT-LLM のような GQA 対応実装を使うと、この最適化は通常自動的に適用されます。
23.10 よくある誤解
「GQA はヘッド数を増やした MQA に過ぎない」 — 厳密には違います。GQA はパラメータ化されたファミリーです。MHA と MQA は両極端であり、GQA はその間のスペクトル全体です。設計上の重要な選択は n_kv_heads です。
「KV ヘッドは少ないほど常によい」 — 品質と効率のトレードオフは現実に存在します。32 から 8 KV ヘッドに減らすと、品質損失を最小限にメモリを 4× 削減できます。8 から 1 に減らすとさらに 8× 削減できますが、より目立った品質劣化が伴います。最適な設定はサービング制約と品質要件に依存します。
「GQA は推論にしか影響しない」 — GQA はパラメータ数もわずかに削減します(小さな K と V 射影行列)。これにより学習が速くなり、モデルファイルサイズも小さくなる可能性があります。その影響は小さいですが、ゼロではありません。32 ヘッドから 8 KV ヘッドに変更した 7B モデルの場合: K と V の射影行列は [d_model, d_model] から [d_model, d_model/4] に縮小し、総パラメータの約 6% が節約されます。
「すべてのモデルが GQA を使うべき」 — メモリが豊富な環境にデプロイされる非常に小さいモデル(7B 未満)では、MHA の追加表現力がオーバーヘッドに見合う場合があります。特定の n_kv_heads にコミットする前には、必ず品質を測定してください。
「Flash Attention と GQA は競合する」 — 補完関係にあります。Flash Attention 2 はネイティブ GQA サポートを追加しました。タイル計算中に repeat_kv を内部で処理するため、IO 効率の高いカーネルと小さな KV フットプリントの両方を同時に得られます。
23.11 章のまとめ
MHA (Multi-Head Attention)
n_kv_heads = n_heads
各ヘッドが独立した Q、K、V を持つ
KV キャッシュ: レイヤーごとに 2 × n_heads テンソル
最高品質、最大メモリ
MQA (Multi-Query Attention)
n_kv_heads = 1
すべてのヘッドが一つの K、V を共有
KV キャッシュ: レイヤーごとに 2 テンソル
最小メモリ、スケールでの品質リスクあり
GQA (Grouped-Query Attention)
1 < n_kv_heads < n_heads
グループ単位でヘッドが K/V を共有
KV キャッシュ: レイヤーごとに 2 × n_kv_heads テンソル
MHA に近い品質、MQA に近い効率
選択ガイド
| 状況 | 推奨 | 理由 |
|---|---|---|
| 研究・学習重視 | MHA | 最大の表現力 |
| 大規模プロダクションサービング | GQA(8 KV ヘッド) | 最良の品質-効率バランス |
| エッジ/モバイル/極限の効率 | MQA | 最小メモリフットプリント |
| 不確かな場合 | GQA(8 KV ヘッド) | 安全で実証済みのデフォルト |
23.11.1 モデルの設定ファイルを読む
最新のモデルの Hugging Face config.json には、両方のフィールドが記載されています。Llama-3 8B の場合:
{
"num_attention_heads": 32,
"num_key_value_heads": 8
}
num_attention_heads が n_heads(Q ヘッド数)で、num_key_value_heads が n_kv_heads です。グループサイズは 32 / 8 = 4 — 各 KV ペアが 4 つのクエリヘッドで共有されます。num_key_value_heads が num_attention_heads と等しければ MHA、1 ならば MQA です。
チャプターチェックリスト
この章を終えた後、以下ができるようになっているはずです:
- 長いコンテキストや高い並行数で MHA の KV キャッシュがボトルネックになる理由を説明できる。
- MHA、MQA、GQA をそれぞれ一文で説明できる。
- モデルの次元数が与えられたとき、各メカニズムの KV キャッシュサイズを計算できる。
- 大きなメモリ節約にもかかわらず GQA の品質損失が小さい理由を説明できる。
- モデルの設定ファイルを読んで
n_headsとn_kv_headsを特定できる。 -
repeat_kvを実装できる。 - 各メカニズムのもとで、ある GPU が何人の並行ユーザーに対応できるかを推定できる。
参考文献
- Fast Transformer Decoding: One Write-Head is All You Need (Shazeer, 2019) — MQA 原著論文
- GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints (Ainslie et al., 2023)
- Llama 2: Open Foundation and Fine-Tuned Chat Models (Meta, 2023) — モデルスケールを超えた MHA→GQA の移行を示す
次の章へ
GQA によって保存する K/V データを減らすことができました。しかしキャッシュが小さくなっても、各トークンはウィンドウ内のすべての他のトークンに Attention を当てます — 完全 Attention の二次コストは残ったままです。
第24章では、その要件を完全に捨てたらどうなるかを探ります。Sparse Attention はシーケンスの選ばれた一部にしか Attention を当てないようにすることで、計算量を O(N) に近づけます。そして Infini Attention はさらに踏み込み、固定サイズの圧縮メモリを使って際限なく成長するコンテキストを扱います。次章でお会いしましょう。