一文要約: Flash Attention は、GPU のメインメモリに N×N のスコアマトリクスを書き出すことを避けるために再実装された、完全に正確な Attention です。少し余分な演算をする代わりに、大幅なメモリ転送削減を実現します。


21.1 なぜ標準 Attention は遅くなるのか

21.1.1 不思議な現象

NVIDIA A100 GPU で Transformer を学習しているとします。A100 の理論スループットは 312 TFLOPS(FP16)です。この数字を見れば、猛烈に速いはずです。ところが、シーケンス長が伸びると学習が急激に遅くなり、OOM エラーも頻発するようになります。

さらに奇妙なことに、GPU のメモリに余裕があるときでも、GPU 使用率が問題なさそうなときでも、Attention がボトルネックになります。

答えは、多くの人が見落とすところに隠れています。メモリ帯域幅です。

21.1.2 Attention のメモリ問題

標準 Attention は次のように計算されます:

Attention(Q,K,V)=softmax ⁣(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right)V

シーケンス長を NN、トークンあたりの次元数を dd とすると:

  • QQ の形状:[N,d][N, d]
  • KK の形状:[N,d][N, d]
  • QKTQK^T の形状:[N,N][N, N]

N=4096N = 4096 のとき、スコアマトリクスには 4096×4096=16,777,2164096 \times 4096 = 16{,}777{,}216 個の要素が入ります。FP16 では約 32 MB — ヘッド1つ、サンプル1つ分です。

現実的な学習ランに合わせてスケールアップすると(32ヘッド、バッチサイズ8、順伝播+逆伝播):

32×8×32 MB=8 GB32 \times 8 \times 32\ \text{MB} = 8\ \text{GB}

Attention のスコアマトリクスだけで 8 GB です。そして、学習の毎ステップでこのデータを GPU のメインメモリ経由で動かさなければなりません。

QK マトリクスと GPU メモリ階層の関係を示した図

21.1.3 GPU のメモリ階層

GPU のメモリは、単一のフラットなプールではありません。階層構造を持っています:

レベル名称容量帯域幅備考
オンチップSRAM(L1/L2/共有メモリ)〜20 MB〜19 TB/s非常に速い、非常に小さい
デバイスHBM(高帯域幅メモリ)〜40〜80 GB〜1.5〜3 TB/sGPU メインメモリ
ホストCPU DRAM〜1 TB〜12.8 GB/s大容量だが大幅に遅い

SRAM は HBM の 約20倍速です。

机の上(SRAM)と、部屋の向こうにある本棚(HBM)を想像してください。机の上にあるものはすぐに使えます。本棚の本を使うには、歩いて取りに行き、持ち帰らなければなりません。

標準 Attention はこのように動作します:

  1. Q,KQ, K を HBM から読み込む → QKTQK^T を計算する → HBM に書き戻す
  2. QKTQK^T を HBM から読み込む → Softmax を計算する → HBM に書き戻す
  3. Softmax の結果と VV を HBM から読み込む → 出力を計算する → HBM に書き戻す

HBM への往復ごとに時間が消費されます。これが本当のボトルネックです。


21.2 標準 Attention と Flash Attention:数字で比較する

21.2.1 標準 PyTorch の動作

# 標準 Attention
def standard_attention(Q, K, V):
    # ステップ1: QK^T を計算し、結果を HBM に保存する
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

    # ステップ2: scores  HBM から読み込み、softmax を計算し、HBM に書き戻す
    attention_weights = torch.softmax(scores, dim=-1)

    # ステップ3: オプションの Dropout  さらに HBM の往復が発生する
    attention_weights = dropout(attention_weights)

    # ステップ4: weights  V  HBM から読み込み、出力を計算する
    output = torch.matmul(attention_weights, V)
    return output
PyTorch と FlashAttention のスループットベンチマーク比較

GPT-2 サイズの Attention での比較:

  • PyTorch 標準:〜15 ms。Matmul・Dropout・Softmax・Mask が個別カーネルとして分散実行
  • Flash Attention:〜3 ms。すべてを1つの fused kernel に統合

HBM の往復を排除するだけで、5倍の高速化を実現しています。

21.2.2 メモリ計算量

メモリの比較はさらに明確です:

実装Attention マトリクスのメモリ
標準 AttentionO(N2)O(N^2) — 完全なスコアマトリクスを HBM に保持
Flash AttentionO(N)O(N) — 入力と出力のみ、中間マトリクスなし

N=2048N = 2048 のとき、Flash Attention は中間メモリを約 2048分の1 に削減します。N=4096N = 4096 ではその比率がさらに倍になります。


21.3 タイリング:コアアイデア

21.3.1 直感的な理解

100×100 の巨大な掛け算の表を計算する必要があるとします。ナイーブなアプローチ:

  1. 表全体を計算し、大きな紙に書き出す。
  2. 各行を後処理する(Softmax)。
  3. 計算を続ける。

Flash Attention はこう考えます:

  1. 大きな表を 10×10 のタイルに切り分ける。
  2. スクラッチパッド(SRAM)上で、1つのタイルを丸ごと処理する。
  3. 表全体を書き出すことなく、最終結果を積み上げる。
SRAM タイリング:K・Q・V の小さなブロックを SRAM に読み込んで完全に処理する

各タイルに対して SRAM 内で行う処理の手順:

1. Q_block @ K_block^T
2. 因果マスクを適用する
3. オンライン Softmax(21.4節参照)
4. オプションの Dropout
5. V_block と掛け算する
6. 出力 O_i に積み上げる

HBM に戻す必要があるのは最終出力 OO だけです。巨大な中間 N×NN \times N マトリクスは、メモリ上に一切存在しません。

21.3.2 タイルのサイズはどのくらいか

A100 では、ストリーミングマルチプロセッサあたりの SRAM は約 192 KB です。同時に4つのものを収める必要があります:

  • QQ の1ブロック:Br×dB_r \times d
  • KK の1ブロック:Bc×dB_c \times d
  • VV の1ブロック:Bc×dB_c \times d
  • 出力の1ブロック:Br×dB_r \times d

ブロックサイズの式:

min(Br,Bc)=M4d\min(B_r, B_c) = \left\lceil\frac{M}{4d}\right\rceil

ここで MM は SRAM サイズ、dd はモデルの次元数です。M=192 KB=192×1024 bytesM = 192\text{ KB} = 192 \times 1024 \text{ bytes}d=512d = 512(モデルの幅)とすると:

192×10244×512=196,6082,048=96=96\left\lceil\frac{192 \times 1024}{4 \times 512}\right\rceil = \left\lceil\frac{196{,}608}{2{,}048}\right\rceil = \lceil 96 \rceil = 96

実際にはメモリアライメントの都合で 64 に切り下げるため、典型的な A100 実装では Br=Bc=64B_r = B_c = 64 になります。

A100 のストリーミングマルチプロセッサとブロックサイズの導出

21.3.3 コアループ

アルゴリズム: FlashAttention(簡略版)

入力:  Q, K, V  HBM 
出力:  O  HBM 

for j = 1 to T_c:          # K, V ブロックのアウターループ
    K_j, V_j  SRAM に読み込む

    for i = 1 to T_r:      # Q ブロックのインナーループ
        Q_i, O_i, l_i, m_i  HBM から SRAM に読み込む

        S_ij = Q_i @ K_j^T             # スコアブロック
        m_i を更新(実行中の最大値)
        l_i を更新(実行中の分母)
        O_i += rescaled_P_ij @ V_j     # 出力を積み上げる

        O_i, l_i, m_i  HBM に書き戻す

「rescaled」の部分を担当するのが、オンライン Softmax です。


21.4 オンライン Softmax:全体を見ずに Softmax を計算する

21.4.1 問題

標準 Softmax:

softmax(x)i=exij=1Kexj\text{softmax}(x)_i = \frac{e^{x_i}}{\sum_{j=1}^{K} e^{x_j}}
オンライン Softmax:ブロックをまたいで実行中の最大値と分母を追跡する

分母はすべての要素を合計します。ところがタイリング中は、1度に1ブロックしか見えません。どうやって正確に Softmax を計算すればよいでしょうか?

21.4.2 オンライン更新ルール

3つの実行中の値を管理します:

  • m(x)m(x):これまでに見た最大値
  • f(x)f(x):分子項のベクトル(スケール補正済みの指数)
  • l(x)l(x):分子項の合計(分母の累積値)

ブロック x(1)x^{(1)} を処理した後に新しいブロック x(2)x^{(2)} が来たとき:

1. 実行中の最大値を更新する:

m(x)=max ⁣(m(x(1)),m(x(2)))m(x) = \max\!\bigl(m(x^{(1)}),\, m(x^{(2)})\bigr)

2. 過去の分子をスケール補正する:

f(x)=[em(x(1))m(x)f(x(1)),    em(x(2))m(x)f(x(2))]f(x) = \Bigl[e^{m(x^{(1)}) - m(x)} \cdot f(x^{(1)}),\;\; e^{m(x^{(2)}) - m(x)} \cdot f(x^{(2)})\Bigr]

3. 分母を更新する:

l(x)=em(x(1))m(x)l(x(1))+em(x(2))m(x)l(x(2))l(x) = e^{m(x^{(1)}) - m(x)} \cdot l(x^{(1)}) + e^{m(x^{(2)}) - m(x)} \cdot l(x^{(2)})

4. 最終 Softmax:

softmax(x)=f(x)l(x)\text{softmax}(x) = \frac{f(x)}{l(x)}

21.4.3 なぜ最大値を追跡するのか

emoldmnewe^{m_\text{old} - m_\text{new}} の補正係数は、数値安定性のためです。

大きな xx に対して exe^x を計算するとオーバーフローします。標準的な対処法は、指数計算の前に最大値を引くことです:

softmax(x)i=eximax(x)jexjmax(x)\text{softmax}(x)_i = \frac{e^{x_i - \max(x)}}{\sum_{j} e^{x_j - \max(x)}}

タイリング計算では、各ブロックがそれぞれのローカル最大値を持ちます。新しいブロックが来てグローバル最大値が更新されたとき、emoldmnewe^{m_\text{old} - m_\text{new}} を使ってそれまでの累積値を遡って補正します。

21.4.4 計算例

完全な行:[3.01,  0.09,  2.48,  1.95][3.01,\; 0.09,\; 2.48,\; 1.95]。2つのブロックに分割します。

ブロック1[3.01,  0.09][3.01,\; 0.09]

  • m(1)=3.01m^{(1)} = 3.01
  • f(1)=[e0,  e2.92]=[1,  0.053]f^{(1)} = [e^{0},\; e^{-2.92}] = [1,\; 0.053]
  • l(1)=1.053l^{(1)} = 1.053

ブロック2[2.48,  1.95][2.48,\; 1.95]

  • m(2)=2.48m^{(2)} = 2.48
  • 新しいグローバル最大値:m=max(3.01,2.48)=3.01m = \max(3.01, 2.48) = 3.01(変わらず)
  • 補正係数:ブロック1は e3.013.01=1e^{3.01 - 3.01} = 1;ブロック2は e2.483.01=0.59e^{2.48 - 3.01} = 0.59
  • l=1×1.053+0.59×(1+e0.53)1.99l = 1 \times 1.053 + 0.59 \times (1 + e^{-0.53}) \approx 1.99

最初の要素の Softmax:

softmax(3.01)=11.9950.25%\text{softmax}(3.01) = \frac{1}{1.99} \approx 50.25\%

直接計算すると 50.28% になります。わずかな差は例題の丸めによるもので、実際の計算は厳密に正確です。

ブロックごとの累積処理を可視化した図

21.5 FlashAttention の完全なアルゴリズム

アルゴリズム: FLASHATTENTION

入力:  Q, K, V  R^{N×d}  HBM に;オンチップ SRAM のサイズ M

1.  B_c = ceil(M / 4d)、B_r = min(ceil(M / 4d), d) を設定する
2.  O = 0、l = 0、m = -  HBM に初期化する

3.  Q  T_r = ceil(N / B_r) 個のブロックに分割する
4.  K, V  T_c = ceil(N / B_c) 個のブロックに分割する

5.  for j = 1 to T_c:
6.      K_j, V_j  HBM から SRAM に読み込む
7.      for i = 1 to T_r:
8.          Q_i, O_i, l_i, m_i  HBM から SRAM に読み込む
9.          S_ij  = Q_i @ K_j^T
10.         m̃_ij  = rowmax(S_ij)
11.         P̃_ij  = exp(S_ij  m̃_ij)、l̃_ij = rowsum(P̃_ij)
12.         m_i_new = max(m_i, m̃_ij)
13.         l_i_new = exp(m_i  m_i_new) × l_i + exp(m̃_ij  m_i_new) × l̃_ij
14.         O_i     = diag(l_i_new)^{-1} × (diag(l_i) × exp(m_i  m_i_new) × O_i
                         + exp(m̃_ij  m_i_new) × P̃_ij × V_j)
15.         O_i, l_i_new, m_i_new  HBM に書き戻す
16. O を返す
SRAM と HBM のデータフローをアノテーション付きで示した FlashAttention アルゴリズム

21.5.1 IO 計算量

標準 Attention:

  • S=QKTS = QK^T を書き込む:O(N2)O(N^2)
  • 読み込み、Softmax、書き込み:O(N2)O(N^2)
  • Softmax + V を読み込み、出力を書き込む:O(N2+Nd)O(N^2 + Nd)
  • HBM トラフィック合計O(N2+Nd)O(N^2 + Nd)

Flash Attention:

  • 各 K/V ブロックをアウターループで1回読み込む:TcT_c 回の O(Bcd)O(B_c d) 読み込み
  • 各 Q/O ブロックをインナーループで読み書きする:Tr×TcT_r \times T_c 回の O(Brd)O(B_r d)
  • HBM トラフィック合計O ⁣(N2d2M)O\!\left(\frac{N^2 d^2}{M}\right)

MdM \gg d のとき、Flash Attention の IO 計算量は O(N2d/M)O(N^2 d / M) に近づきます。標準パスに対して M/dM/d 倍の改善です。


21.6 Flash Attention 1 と Flash Attention 2

FA1 と FA2 の並列化戦略とスループット比較

21.6.1 FA1 が達成したこと

Flash Attention 1(2022年)はこのアイデアを証明し、実際の高速化を実現しました。インナーループの並列化は、出力アキュムレータを共有するワーカー間の同期によって制約されていました。

21.6.2 FA2 が追加したこと

Flash Attention 2(2023年)は3つの重要な変更を加えました:

  1. より良いワーク分割 — ストリーミングマルチプロセッサ間の同期を減らし、ハードウェアをより均一に活用する
  2. MQA と GQA のネイティブサポート — 第23章で扱うヘッド共有パターンを直接処理する
  3. 非行列積演算の削減 — レジスタのスピルが減り、パイプラインがクリーンになる

A100 80GB SXM4 でのパフォーマンス:

設定PyTorchFA1FA2
シーケンス長 2k、head_dim 64〜50 TFLOPS〜120 TFLOPS〜175 TFLOPS
シーケンス長 4k、head_dim 64〜45 TFLOPS〜110 TFLOPS〜170 TFLOPS
シーケンス長 8k、head_dim 128〜40 TFLOPS〜100 TFLOPS〜165 TFLOPS

FA2 は、メモリバウンドな演算において A100 のピークスループットの 50〜70% に達します。これは非常に優秀な数字です。


21.7 実際の使い方

21.7.1 インストール

pip install flash-attn --no-build-isolation

21.7.2 直接 API を使う

import torch
from flash_attn import flash_attn_func

batch_size, seq_len, num_heads, head_dim = 2, 4096, 32, 128

# 形状: [batch, seq_len, num_heads, head_dim]
q = torch.randn(batch_size, seq_len, num_heads, head_dim,
                dtype=torch.float16, device='cuda')
k = torch.randn_like(q)
v = torch.randn_like(q)

output = flash_attn_func(q, k, v, causal=True)

21.7.3 PyTorch 2.0+ の組み込み機能

PyTorch 2.0 は scaled_dot_product_attention を追加しました。入力が条件を満たす場合、自動的に Flash Attention にディスパッチされます:

import torch.nn.functional as F

output = F.scaled_dot_product_attention(
    query, key, value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=True,
)
# PyTorch はハードウェアと入力の形状に応じて、
# Flash Attention、Memory Efficient Attention、
# または標準パスを自動選択します。

21.7.4 Hugging Face Transformers

from transformers import AutoModelForCausalLM
import torch

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
)

21.7.5 知っておくべき制限事項

ハードウェア:Flash Attention には最新の NVIDIA GPU が必要です。Ampere(A100)以降が最良の結果を出します。古いカードや NVIDIA 以外のハードウェアは標準パスにフォールバックします。

逆伝播:Flash Attention は完全なスコアマトリクスを保存しないため、逆伝播でそれを再計算する必要があります。余分な演算はIO の節約に比べると安価で、エンドツーエンドの学習では依然として 2〜4倍のアドバンテージがあります。

非標準マスク:カスタムの Attention マスク(スパース、スライディングウィンドウ、任意のパターン)には特別な処理が必要な場合があります。FA2 はすでに因果マスク・パディングマスク・MQA/GQA をすぐに使える形でサポートしています。


21.8 章のまとめ

概念ポイント
ボトルネック演算スループットではなく、HBM 帯域幅
SRAM vs HBMSRAM 〜19 TB/s;HBM 〜1.5 TB/s;SRAM は約20倍速
タイリングQ/K/V の小さなブロックを SRAM で処理し、N×N マトリクスを一切書き出さない
オンライン Softmax実行中の最大値・分子・分母を追跡し、グローバル最大値が更新されたら過去のブロックを補正する
メモリ計算量標準:O(N2)O(N^2);Flash:O(N)O(N)
FA1 → FA2より良い並列化、MQA/GQA のネイティブサポート、FA1 比 1.5〜2倍
エンドツーエンドの高速化学習で 2〜4倍;Attention カーネル単体で 5倍

章末チェックリスト

この章を終えた後、以下のことができるようになっているはずです:

  • TFLOPS ではなく HBM 帯域幅が Attention のボトルネックである理由を説明できる。
  • タイリングが何をするものか、なぜ N×N マトリクスのマテリアライズを回避できるかを説明できる。
  • オンライン Softmax の更新ルールを順を追って説明できる。
  • 標準 Attention と Flash Attention のメモリ計算量を言える。
  • Flash Attention が近似ではなく厳密であることを説明できる。
  • FA1 と FA2 を比較できる。

参考文献

  • FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness (Dao et al., 2022) — arXiv 2205.14135
  • FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning (Dao, 2023) — arXiv 2307.08691
  • Memory Efficient Attention (xFormers)
  • PagedAttention (vLLM) — KV Cache 管理のための補完的アプローチ

次章へ

Flash Attention は個々の Attention 計算をより安価にします。しかし、自己回帰的な生成では、別の問題が残っています。モデルは毎ステップ、古いトークンの K と V を再計算し続けています。

第22章では KV Cache でその問題を解決します。これは Flash Attention の自然なパートナーであり、高速な推論のための2本柱のうちの1本です。

このページを引用する
Zhang, Wayland (2026). 第21章: Flash Attention - メモリを意識した Attention. In Transformer アーキテクチャ:直感から実装まで. https://waylandz.com/llm-transformer-book-ja/chapter-21-flash-attention
@incollection{zhang2026transformer_ja_chapter_21_flash_attention,
  author = {Zhang, Wayland},
  title = {第21章: Flash Attention - メモリを意識した Attention},
  booktitle = {Transformer アーキテクチャ:直感から実装まで},
  year = {2026},
  url = {https://waylandz.com/llm-transformer-book-ja/chapter-21-flash-attention}
}