一文要約: KV Cache は過去のトークンから計算済みのキーとバリューを保存するため、モデルがそれらを再計算することはありません。ステップあたりの O(n²) の繰り返し作業を O(n) に変え、実測で5倍の推論高速化を実現します。
22.1 KV Cache が存在する理由
22.1.1 自己回帰的な無駄
第20章で自己回帰ループを確立しました。モデルは1トークンを生成し、それをシーケンスに追加し、フルの順伝播を再び実行します。この「フルの順伝播」が問題です。
あるエージェントが長いプルリクエストの説明に対して返答を書いているとします。プロンプトが400トークン、返答が300トークンだとしましょう。最適化なしでは、モデルは毎ステップ、増え続けるシーケンスに対して Attention を実行しなければなりません:
ステップ1: 401トークンに対して Attention を実行 → トークン1を出力
ステップ2: 402トークンに対して Attention を実行 → トークン2を出力
ステップ3: 403トークンに対して Attention を実行 → トークン3を出力
...
ステップ300: 700トークンに対して Attention を実行 → トークン300を出力
400個のプロンプトトークンが300回再処理されます。KV Cache が排除するのは、この無駄です。
22.1.2 Attention が実際に計算していること
復習しましょう:
これには2つの行列積が含まれます:
- — スコア:各ポジションが他のポジションにどれだけ注目するか
- — バリューの加重和
新しいトークンが来るたびに Q、K、V はすべて大きくなります。しかし、ここに重要な事実があります。古いトークンの K と V の行は変化しません。推論中、モデルの重みは固定されています。つまり、ステップ100におけるトークン1の K ベクトルと V ベクトルは、ステップ1のときと全く同じです。それらを再計算するのは純粋な無駄です。
22.1.3 具体的なウォークスルー
プロンプトを「The agent opened a pull request」(5トークン)とします。モデルは1語ずつ返答を生成します。
KV Cache なしでは、各ステップでシーケンス全体を再処理します:
ステップ1 — 最初の返答トークンを生成:
シーケンス: [The, agent, opened, a, pull, request] 長さ: 6
全6トークンの K と V を再計算する。
QK^T の形状: 6×6
ステップ2 — 2番目の返答トークンを生成:
シーケンス: [The, agent, opened, a, pull, request, token_1] 長さ: 7
全7トークンの K と V を再計算する — 「The」から「request」まで含めて。
QK^T の形状: 7×7
ステップ20 — 20番目の返答トークンを生成:
シーケンス: 長さ25
25トークンの K と V を再計算する — 元の5個のプロンプトトークンは20回再計算されている。
すべてのプロンプトトークンの K ベクトルと V ベクトルは、毎ステップゼロから計算されます。モデルの重みもトークンの埋め込みも変わらないため、毎回同じ結果になります。
22.1.4 どれくらいの無駄か
N トークンを生成するのに必要な Attention の行列積の回数を数えてみましょう:
| ステップ | シーケンス長 | のコスト(比例) |
|---|---|---|
| 1 | 1 | 1 |
| 2 | 2 | 4 |
| 3 | 3 | 9 |
| n | n | n² |
キャッシュなしで N ステップの合計:
KV Cache ありでは、各ステップでキャッシュ済みの K に対して1行の Q を計算するだけです。ステップあたりのコストは O(N)、合計は 。N = 1000 では、ステップあたりの Attention の作業量が約1000分の1になります。
22.2 何をキャッシュするのか、なぜか
22.2.1 4つのルール
- KV Cache は推論にのみ適用される。 学習中は対象シーケンス全体が既知なので、すべてのポジションを並列処理できます。キャッシュは不要です。
- KV Cache はデコーダブロックにのみ存在する。 エンコーダブロック(存在する場合)は入力を1度だけ並列処理します。自己回帰的ではありません。
- 2つの Attention 行列積を高速化する。 キャッシュ済みの K と V によって、行列の第2次元が増え続けるものから固定されたものに変わります。
- メモリを消費する。 無料の昼食はありません。キャッシュを保存するコストとして、シーケンス長に比例した GPU メモリが必要になります。
22.2.2 なぜ Q ではなく K と V だけキャッシュするのか
自己回帰的な生成では、最後のポジションの出力だけが重要です。その出力を計算するために、新しいトークンの Q はすべてのポジションの K と V に対して Attention を実行します。N+1番目のポジションの Q は毎ステップ新しく計算されます。キャッシュするものは何もありません。一方、過去のすべてのポジションの K と V は何度も再利用されます。キャッシュに値するのはそちらです。
こう考えると分かりやすいです。Query はモデルが今投げかけている新しい質問です。Keys と Values は参照している知識ベースです。知識ベースは机の上に置いておく。古い質問を引き出す必要はありません。
22.3 KV Cache なしと KV Cache あり
22.3.1 KV Cache なし
プロンプトを「The agent opened」(3トークン)とします。生成:
ステップ1 — 「The agent opened」の次のトークンを予測:
Q = [The, agent, opened] サイズ: 3×d
K = [The, agent, opened] サイズ: 3×d ← ゼロから計算
V = [The, agent, opened] サイズ: 3×d ← ゼロから計算
QK^T = 3×3 マトリクス
ステップ2 — 「The agent opened a」の次のトークンを予測:
Q = [The, agent, opened, a] サイズ: 4×d
K = [The, agent, opened, a] サイズ: 4×d ← The/agent/opened を再計算!
V = [The, agent, opened, a] サイズ: 4×d ← The/agent/opened を再計算!
QK^T = 4×4 マトリクス
K と V の最初の3行は両ステップで同一です。標準実装はそれを捨てて、毎回再計算します。
22.3.2 KV Cache あり
ステップ1 — フルのプロンプトを処理し、キャッシュを構築:
Q = [The, agent, opened] サイズ: 3×d
K = [The, agent, opened] サイズ: 3×d → キャッシュされる
V = [The, agent, opened] サイズ: 3×d → キャッシュされる
ステップ2 — 次のトークンを予測:
Q = [a] サイズ: 1×d ← 新しいトークンのみ
K = [The, agent, opened, a] サイズ: 4×d ← キャッシュを読み込み + 新しい行を追加
V = [The, agent, opened, a] サイズ: 4×d ← キャッシュを読み込み + 新しい行を追加
QK^T = 1×4 ベクトル ← クエリ行は1行だけ!
図中の色:グレー = キャッシュ済み(メモリから読み込み)、オレンジ = 新たに計算された K の行、グリーン = 新たに計算された V の行、ピンク = 因果マスクが適用されたポジション。
22.3.3 計算の節約
| ステップ n | キャッシュなし | キャッシュあり | 節約率 |
|---|---|---|---|
| 1 | 1×1 = 1 | 1×1 = 1 | 0% |
| 2 | 2×2 = 4 | 1×2 = 2 | 50% |
| 4 | 4×4 = 16 | 1×4 = 4 | 75% |
| 10 | 100 | 10 | 90% |
| 1000 | 1,000,000 | 1000 | 99.9% |
22.4 KV Cache のメモリコスト
22.4.1 計算式
KV Cache メモリ =
2 (K と V)
× batch_size
× context_length
× n_layers
× n_heads
× d_head
× bytes_per_element
bytes_per_element の値:
- FP32 → 4 バイト
- BF16 / FP16 → 2 バイト
- INT8 / FP8 → 1 バイト
22.4.2 計算例:Llama-7B
設定:n_layers = 32、n_heads = 32、d_head = 128、FP16、context_length = 4096、batch = 1。
KV Cache = 2 × 1 × 4096 × 32 × 32 × 128 × 2 バイト
= 2,147,483,648 バイト
≈ 2 GB
Llama-7B 自体は FP32 で約 14 GB です。4096トークンの会話1件に対して、それに加えて 2 GB の KV Cache が必要になります。
22.4.3 デプロイ時の制約
NVIDIA A10 GPU(24 GB メモリ)上で Llama-7B を提供するとします:
| 項目 | メモリ |
|---|---|
| モデルの重み(FP32) | 〜14 GB |
| KV Cache に使えるメモリ | 〜10 GB |
| 最大 KV Cache ウィンドウ | 〜2万トークン |
| コンテキスト4kで同時ユーザーの最大数 | 〜5 |
これが本番環境での現実の制約です。モデルのパラメータではなく、KV Cache のメモリが同時に提供できるユーザー数を左右することが多いです。この圧力こそが、第23章で扱う MQA と GQA を生み出した動機です。
22.5 マルチターン会話
22.5.1 積み上がる問題
チャットアプリは複数のターンにわたってコンテキストを蓄積します。スマートなキャッシュなしでは、各ターンでモデルが会話履歴全体を再処理することになります。
KV Cache があれば、会話履歴は1度だけ計算されます:
ターン1: [Q1][A1]
└── 計算してキャッシュ
ターン2: [Q1][A1] [Q2] [A2]
└── 再利用 ──┘ └── 新規
↑ この部分だけ新しい K, V が必要
ターン3: [Q1][A1][Q2][A2] [Q3] [A3]
└───── 再利用 ───┘ └── 新規
各ターンで追加されるのは、新しいトークンの K と V の行だけです。積み上がり方は2乗ではなく線形です。
22.5.2 メモリは二乗ではなく線形に増える
キャッシュ管理なしでは、マルチターン会話における Attention の総作業量は O(n²) です。ターン k のコンテキスト長は k に比例し、それに対する Attention のコストは O(k²) になります。KV Cache があれば、ターンあたり O(k) — 新しいトークンの K/V 行だけを計算します。
22.5.3 「タイプライター」効果
チャットインターフェースで見られる、トークンがストリーミングで1つずつ出てくるあの動作は、KV Cache の産物でもあります。各新しいトークンの生成コストが安いのは、K/V 計算の大部分がプリフィル段階で完了しているからです。KV Cache なしでは、シーケンスが長くなるにつれてトークンあたりのレイテンシが増加し続けます。
22.6 推論の2つのフェーズ
22.6.1 プリフィルとデコード
| フェーズ | プリフィル | デコード |
|---|---|---|
| 入力 | フルのプロンプト(一括) | 直前に生成されたトークン |
| 計算 | 全プロンプトトークンを並列処理 | 1トークンずつ順次処理 |
| KV Cache | 構築される | 1行ずつ拡張される |
| ボトルネック | 演算バウンド(多くのトークンを並列処理) | メモリ帯域幅バウンド |
22.6.2 なぜデコードはメモリバウンドなのか
プリフィル中、GPU は多くのトークンを同時に処理します。本当の意味でのバッチ行列積を行っています。デコード中は、1ステップで1トークンだけを処理します。GPU の演算リソースのほとんどがアイドル状態になり、制限要因は HBM からどれだけ速く KV Cache を読み出せるかになります。
だからこそデコードはメモリ帯域幅バウンドと呼ばれ、KV Cache サイズの削減(MQA、GQA、量子化によって)がデコードスループットに直接的な影響を与えます。
22.7 コード:高速化を測定する
import time
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
for use_cache in (True, False):
times = []
for _ in range(10):
start = time.time()
model.generate(
**tokenizer("The agent opened a pull request.", return_tensors="pt").to(device),
use_cache=use_cache,
max_new_tokens=1000,
)
times.append(time.time() - start)
print(f"{'KV キャッシュあり' if use_cache else 'KV キャッシュなし'}: "
f"{np.mean(times):.1f} ± {np.std(times):.1f} 秒")
期待される出力:
KV キャッシュあり: 11.x ± x.x 秒
KV キャッシュなし: 56.x ± x.x 秒
GPT-2 で約 5倍の高速化 です。〜1000トークンのシーケンスに対して数学が予測する値と一致しています。
22.7.1 キャッシュのデータ構造
# レイヤーごとの KV Cache 構造
n_layers = 96
key_cache = [[] for _ in range(n_layers)]
value_cache = [[] for _ in range(n_layers)]
# プリフィル: プロンプトを処理してキャッシュを構築する
for token in input_sequence:
for layer in range(n_layers):
key_cache[layer].append(compute_key(layer, token))
value_cache[layer].append(compute_value(layer, token))
# デコード: 新しいトークンを1つずつ生成する
for new_token in generation_loop:
for layer in range(n_layers):
query = compute_query(layer, new_token)
# キャッシュ経由で全履歴に対して Attention を実行する
output = attention(query, keys=key_cache[layer], values=value_cache[layer])
# キャッシュを拡張する
key_cache[layer].append(compute_key(layer, new_token))
value_cache[layer].append(compute_value(layer, new_token))
2点補足します。各レイヤーの射影の重みが異なるため、各レイヤーは独立したキャッシュを持ちます。また、本番フレームワークではキャッシュは Python のリストではなく、事前にアロケートされたテンソルです。
22.8 よくある質問
KV Cache は精度に影響しますか? いいえ。再計算された場合とまったく同じ K と V の値を保存しています。出力はキャッシュなし版とビット単位で同一です。
学習でも KV Cache を使いますか? いいえ。学習では完全に既知のシーケンスを因果マスクを使って並列処理します。自己回帰的な構造は逐次的な生成ではなくマスクによって実現されます。KV Cache の概念が意味を持つのは、逐次的なデコードの状況だけです。
どのモデルがサポートしていますか? デコーダーのみのモデル(GPT ファミリー、LLaMA、Mistral、Gemma、Qwen)と、デコード中のエンコーダーデコーダーモデルすべてです。エンコーダーのみのモデル(BERT、RoBERTa)は自己回帰的に生成しないため、必要ありません。
デメリットはありますか? 3つあります。リクエストあたりのメモリが増える、メモリ制約のあるハードウェアでの同時実行数が制限される、そしてプリフィルのレイテンシはキャッシュでは削減できない(最初のステップは依然としてフルのプロンプトを処理します)。
コンテキスト長の制限は KV Cache でより重要になりますか? はい。KV Cache なしでは、長いコンテキストはより多くの再計算を意味するだけです。KV Cache ありでは、長いコンテキストはより多くのメモリを意味します — キャッシュは線形に増えます。Llama-7B での 128k トークンのコンテキストには:
2 × 1 × 131072 × 32 × 32 × 128 × 2 バイト ≈ 64 GB
A100 80GB のモデル全体よりも多くなります。これが MQA、GQA、量子化された KV 表現を推し進めた根本的な圧力です。
22.9 KV Cache の最適化の方向性
KV Cache はすべての本番デプロイの標準です。しかし、そこから派生した最適化の一群があります:
KV ヘッド数を減らす(キャッシュするものを少なくする):
- MQA — すべてのクエリヘッドが1組の K/V を共有する
- GQA — クエリヘッドのグループが1組の K/V を共有する
- 第23章で解説します。
KV の精度を下げる(要素あたりのフットプリントを小さくする):
- キャッシュ済みの K と V を INT8 または FP8 に量子化する
- 品質の低下はほぼなく、メモリを約半分に削減できる
古い KV エントリを削除または圧縮する:
- スライディングウィンドウ Attention — 直近のウィンドウだけをキャッシュする
- StreamingLLM — 「Attention シンク」トークン + 直近のウィンドウを保持する
- PagedAttention(vLLM)— OS の仮想メモリのように KV ページを管理する
22.10 章のまとめ
| 概念 | ポイント |
|---|---|
| K と V をキャッシュする理由 | 古いトークンの K と V は推論中に一定。再計算するとトータルで O(N²) の無駄 |
| Q をキャッシュしない理由 | Q は現在の(新しい)トークンにのみ必要 |
| 計算の節約 | ステップあたりのコストが O(N²) から O(N) に低下;実測で5倍の高速化 |
| メモリコストの計算式 | 2 × batch × ctx × layers × heads × d_head × bytes |
| コンテキスト4kでの Llama-7B | バッチ要素1件あたり約 2 GB の KV Cache |
| プリフィル vs デコード | プリフィル:演算バウンド。デコード:メモリ帯域幅バウンド |
| Hugging Face の切り替え | use_cache=True(デフォルト) |
章末チェックリスト
この章を終えた後、以下のことができるようになっているはずです:
- キャッシュなしの自己回帰的な生成が冗長な計算を生む理由を説明できる。
- Q ではなく K と V だけをキャッシュする理由を説明できる。
- 特定のモデル設定に対して KV Cache のメモリを計算できる。
- プリフィルとデコードの2つのフェーズとそれぞれのボトルネックを説明できる。
- ベンチマークコードを実行して高速化を解釈できる。
参考文献
- Efficient Transformers: A Survey (Tay et al., 2020) — arXiv 2009.06732
- vLLM: Easy, Fast, and Cheap LLM Serving (PagedAttention) — github.com/vllm-project/vllm
- FlashAttention-2 (Dao, 2023) — Flash Decoding を介して KV Cache を効率的に組み合わせる
次章へ
KV Cache のメモリは Attention ヘッドの数に比例してスケールします。32ヘッドということは、レイヤーごとに 32組の K と 32組の V を保存することを意味します。自然な疑問が生まれます。32組すべてが独立している必要は本当にあるのでしょうか。
第23章では、MHA(各ヘッドが独自の K/V を持つ)から MQA(全ヘッドが1組の K/V を共有)、GQA(ヘッドのグループが K/V を共有)への進化を解説します。メモリ不足に陥ることなく、現代のモデルがより長いコンテキストを提供できるようにしたアーキテクチャの変化です。