一言まとめ:
model.pyとは、これまでの章で扱ってきたコンポーネント — Embedding、位置エンコーディング、Multi-Head Attention、FFN、LayerNorm — を PyTorch で配線するだけのものです。各クラスはひとつの数式に直接対応します。
完全なコードリポジトリ: github.com/waylandzhang/Transformer-from-scratch
18.1 コードを書く前に: 全体像
18.1.1 何を実装するのか
Model (完全なモデル)
├── Token Embedding
├── Positional Encoding
├── N × TransformerBlock
│ ├── LayerNorm
│ ├── Multi-Head Attention
│ ├── LayerNorm
│ └── Feed Forward Network
├── 最後の LayerNorm
└── Output Linear (語彙への射影)
18.1.2 ファイル構成
すべてを単一の model.py に収めます。
# model.py の全体構成
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
class FeedForwardNetwork(nn.Module): # FFN
...
class Attention(nn.Module): # 単一の Attention ヘッド
...
class MultiHeadAttention(nn.Module): # Multi-Head Attention
...
class TransformerBlock(nn.Module): # Transformer ブロック
...
class Model(nn.Module): # 完全なモデル
...
私は高水準ライブラリに手を伸ばす前に、一度はこのファイルをゼロから書いてみることをおすすめします。そのあとならライブラリは「魔法の幕」ではなく、純粋な生産性ツールに見えてくるはずです。
18.2 Feed Forward Network
18.2.1 FFN 構造の復習
第15章の通り、FFN は2層の全結合ネットワークです。
入力 [batch, seq, d_model]
|
Linear1: d_model -> d_model × 4 (4倍に拡張)
|
ReLU 活性化
|
Linear2: d_model × 4 -> d_model (元の幅に戻す)
|
Dropout
|
出力 [batch, seq, d_model]
4倍への拡張は、各位置で表現容量を増やしてから再び圧縮することを意味します。FFN こそが、モデルの「知識」が格納される場所です。
18.2.2 コード
# Feed Forward Network の定義
class FeedForwardNetwork(nn.Module):
def __init__(self, d_model, dropout):
super().__init__()
self.d_model = d_model
self.dropout = dropout
self.ffn = nn.Sequential(
nn.Linear(self.d_model, self.d_model * 4), # 4倍に拡張
nn.ReLU(), # 活性化
nn.Linear(self.d_model * 4, self.d_model), # 元の幅に戻す
nn.Dropout(self.dropout) # 正則化
)
def forward(self, x):
return self.ffn(x)
18.2.3 コードの読み解き
| コード | 役割 | 形状の変化 |
|---|---|---|
nn.Linear(d_model, d_model * 4) | 第1線形層 | [B,T,512] → [B,T,2048] |
nn.ReLU() | 非線形性 | 変化なし |
nn.Linear(d_model * 4, d_model) | 第2線形層 | [B,T,2048] → [B,T,512] |
nn.Dropout(dropout) | 正則化のためのランダム dropout | 変化なし |
18.3 Attention (単一ヘッド)
18.3.1 Attention 数式の復習
コードでは以下を実装する必要があります。
- 線形射影で Q, K, V を生成する
- Attention スコア Q @ K^T を計算する
- √d_k でスケーリングする
- Causal Mask を適用する (未来位置への参照を防ぐ)
- Softmax 正規化
- V を掛けて出力を得る
18.3.2 コード
# 単一ヘッドの Scaled Dot Product Attention
class Attention(nn.Module):
def __init__(self, d_model, head_size, context_length, dropout):
super().__init__()
self.d_model = d_model
self.head_size = head_size
self.context_length = context_length
self.dropout = dropout
# Q, K, V の線形射影
self.Wq = nn.Linear(self.d_model, self.head_size, bias=False)
self.Wk = nn.Linear(self.d_model, self.head_size, bias=False)
self.Wv = nn.Linear(self.d_model, self.head_size, bias=False)
# Causal Mask: 下三角行列
self.register_buffer('mask', torch.tril(torch.ones(self.context_length, self.context_length)))
self.dropout = nn.Dropout(self.dropout)
def forward(self, x):
B, T, C = x.shape # Batch, Time (seq_len), Channels (d_model)
# 1. Q, K, V を生成
q = self.Wq(x) # [B, T, head_size]
k = self.Wk(x) # [B, T, head_size]
v = self.Wv(x) # [B, T, head_size]
# 2. Attention スコア Q @ K^T を計算してスケーリング
weights = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_size)
# weights: [B, T, T]
# 3. Causal Mask を適用 (未来位置を -inf に)
weights = weights.masked_fill(self.mask[:T, :T] == 0, float('-inf'))
# 4. Softmax 正規化
weights = F.softmax(weights, dim=-1)
# 5. Dropout
weights = self.dropout(weights)
# 6. V を掛ける
output = weights @ v # [B, T, head_size]
return output
18.3.3 重要なコードのポイント
Causal Mask の仕組み:
self.register_buffer('mask', torch.tril(torch.ones(context_length, context_length)))
torch.tril は下三角行列を生成します。
[[1, 0, 0, 0],
[1, 1, 0, 0],
[1, 1, 1, 0],
[1, 1, 1, 1]]
位置 i は位置 0 から i までしか参照できません。これこそが Causal の意味するところです — 時間における因果関係が保たれます。
なぜ register_buffer を使うのか?
このマスクは学習可能なパラメータではなく、更新もされません。しかしモデルがどのデバイス (CPU でも GPU でも) にあろうと一緒に移動する必要があります。register_buffer はまさにそのためのツールです。パラメータではないが永続的なテンソルとして扱えます。
18.4 Multi-Head Attention
18.4.1 マルチヘッドの考え方
Multi-Head Attention とは、複数の単一ヘッド Attention を並列に動かし、最後に出力を連結したもの です。
各ヘッドは異なるパターンに注目できます。あるヘッドは主語と動詞の一致を追跡し、別のヘッドは代名詞と先行詞の関係を追跡するかもしれません。それらを連結すれば、単一のヘッドだけでは得られない豊かな表現が得られます。
18.4.2 コード
# Multi-Head Attention の定義
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads, head_size, context_length, dropout):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.head_size = head_size
self.context_length = context_length
self.dropout = dropout
# 複数の Attention ヘッドを作成
self.heads = nn.ModuleList([
Attention(self.d_model, self.head_size, self.context_length, self.dropout)
for _ in range(self.num_heads)
])
# 出力射影 Wo
self.projection_layer = nn.Linear(self.d_model, self.d_model)
self.dropout = nn.Dropout(self.dropout)
def forward(self, x):
# 全ヘッドを並列に実行
head_outputs = [head(x) for head in self.heads]
# 全ヘッドの出力を連結
head_outputs = torch.cat(head_outputs, dim=-1) # [B, T, num_heads * head_size] = [B, T, d_model]
# 出力射影を適用
out = self.dropout(self.projection_layer(head_outputs))
return out
18.4.3 形状の追跡
d_model=512, num_heads=8, head_size=64 を仮定します。
入力 x: [B, T, 512]
|
各ヘッドの出力: [B, T, 64] # 8 ヘッド
|
連結: [B, T, 512] # 64 × 8 = 512
|
Wo 射影: [B, T, 512]
|
出力: [B, T, 512]
重要な関係: head_size = d_model // num_heads
18.5 論文版 Multi-Head Attention
18.5.1 ふたつの実装の比較
これまでの実装は 物理的に分離 されています。各ヘッドが独自の Wq, Wk, Wv 行列を持ちます。
「Attention Is All You Need」論文の元実装は 論理的に分離 されています。ひとつの大きな線形層を通したあとで、複数のヘッドに reshape します。
18.5.2 論文版のコード
# 論文スタイルの Multi-Head Attention (論理的分離)
class MultiHeadAttention_Paper(nn.Module):
def __init__(self, d_model, num_heads, head_size, context_length, dropout):
super().__init__()
self.context_length = context_length
self.d_model = d_model
self.num_heads = num_heads
self.head_size = head_size
# ひとつの大きな線形層、出力次元はそのまま d_model
self.Wq = nn.Linear(d_model, d_model)
self.Wk = nn.Linear(d_model, d_model)
self.Wv = nn.Linear(d_model, d_model)
self.Wo = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
self.register_buffer('mask', torch.tril(torch.ones(self.context_length, self.context_length)))
def split_heads(self, x):
"""論理的に複数ヘッドへ分割"""
batch_size = x.shape[0]
context_length = x.shape[1]
# [B, T, d_model] -> [B, T, num_heads, head_size] -> [B, num_heads, T, head_size]
x = x.reshape(batch_size, context_length, self.num_heads, self.head_size)
x = x.permute(0, 2, 1, 3)
return x
def forward(self, x):
B, T, C = x.shape
# 射影してから分割
q = self.split_heads(self.Wq(x)) # [B, num_heads, T, head_size]
k = self.split_heads(self.Wk(x))
v = self.split_heads(self.Wv(x))
# Attention を計算
weights = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_size)
weights = weights.masked_fill(self.mask[:T, :T] == 0, float('-inf'))
weights = F.softmax(weights, dim=-1)
weights = self.dropout(weights)
output = weights @ v # [B, num_heads, T, head_size]
# ヘッドを統合: [B, num_heads, T, head_size] -> [B, T, d_model]
output = output.transpose(1, 2).reshape(-1, T, C)
# 出力射影
output = self.Wo(output)
return output
18.5.3 比較
| 物理的分離 | 論理的分離 (論文) | |
|---|---|---|
| Wq/Wk/Wv の数 | 各 num_heads 個 | 各 1 個 |
| パラメータ数 | 同じ | 同じ |
| 計算効率 | やや劣る (ループ) | より高い (GPU 並列化) |
| コードの読みやすさ | 学習用としては明快 | やや複雑 |
パラメータ数が同じになる理由:
- 物理的分離:
num_heads × (d_model × head_size) = d_model × d_model - 論理的分離:
d_model × d_model
実用上、論文版のほうが GPU レベルの並列性を引き出せるので高速です。学習目的では、物理的に分離したほうが理解しやすいでしょう。
18.6 Transformer Block
18.6.1 ブロック構造
各ブロックには以下が含まれます。
- LayerNorm → Multi-Head Attention → 残差
- LayerNorm → FFN → 残差
これは GPT-2 が採用している Pre-Norm 構造です。
18.6.2 コード
# Transformer Block の定義
class TransformerBlock(nn.Module):
def __init__(self, d_model, num_heads, head_size, context_length, dropout):
super().__init__()
self.ln1 = nn.LayerNorm(d_model)
self.ln2 = nn.LayerNorm(d_model)
self.mha = MultiHeadAttention(d_model, num_heads, head_size, context_length, dropout)
self.ffn = FeedForwardNetwork(d_model, dropout)
def forward(self, x):
# Attention + 残差
x = x + self.mha(self.ln1(x))
# FFN + 残差
x = x + self.ffn(self.ln2(x))
return x
18.6.3 Pre-Norm と Post-Norm
Pre-Norm (本書で採用):
x = x + self.mha(self.ln1(x)) # Attention の前に正規化
Post-Norm (元の Transformer):
x = self.ln1(x + self.mha(x)) # Attention の後に正規化
Pre-Norm のほうが学習が安定します。GPT-2、LLaMA、そして現代のすべてのモデルがこちらを採用しているのはそのためです。
18.7 完全な Model クラス
18.7.1 モデル構造
# 完全なモデル定義
class Model(nn.Module):
def __init__(self, h_params):
super().__init__()
# ハイパーパラメータ辞書から設定を読み込む
self.context_length = h_params['context_length']
self.d_model = h_params['d_model']
self.num_blocks = h_params['num_blocks']
self.num_heads = h_params['num_heads']
self.head_size = self.d_model // self.num_heads
self.dropout = h_params['dropout']
self.device = h_params['device']
self.max_token_value = h_params['max_token_value']
# Token Embedding
self.token_embedding_lookup_table = nn.Embedding(self.max_token_value, self.d_model)
# Transformer ブロック + 最終 LayerNorm
self.transformer_blocks = nn.Sequential(*(
[TransformerBlock(self.d_model, self.num_heads, self.head_size,
self.context_length, self.dropout)
for _ in range(self.num_blocks)] +
[nn.LayerNorm(self.d_model)]
))
# 出力射影層
self.model_out_linear_layer = nn.Linear(self.d_model, self.max_token_value)
18.7.2 順伝播
def forward(self, idx, targets=None):
B, T = idx.shape
# 1. 位置エンコーディング (sinusoidal)
position_encoding_lookup_table = torch.zeros(self.context_length, self.d_model, device=self.device)
position = torch.arange(0, self.context_length, dtype=torch.float, device=self.device).unsqueeze(1)
div_term = torch.exp(torch.arange(0, self.d_model, 2, dtype=torch.float, device=self.device) * (-math.log(10000.0) / self.d_model))
position_encoding_lookup_table[:, 0::2] = torch.sin(position * div_term)
position_encoding_lookup_table[:, 1::2] = torch.cos(position * div_term)
position_embedding = position_encoding_lookup_table[:T, :].to(self.device)
# 2. Token Embedding + 位置エンコーディング
x = self.token_embedding_lookup_table(idx) + position_embedding
# 3. すべての Transformer ブロックを通過
x = self.transformer_blocks(x)
# 4. 語彙へ射影
logits = self.model_out_linear_layer(x)
# 5. targets が与えられていれば (学習モード) 損失を計算
if targets is not None:
B, T, C = logits.shape
logits_reshaped = logits.view(B * T, C)
targets_reshaped = targets.view(B * T)
loss = F.cross_entropy(input=logits_reshaped, target=targets_reshaped)
else:
loss = None
return logits, loss
18.7.3 重要なコードのポイント
位置エンコーディングの式:
div_term = torch.exp(torch.arange(0, self.d_model, 2).float() * (-math.log(10000.0) / self.d_model))
position_encoding_lookup_table[:, 0::2] = torch.sin(position * div_term)
position_encoding_lookup_table[:, 1::2] = torch.cos(position * div_term)
これは第5章で扱った sinusoidal 位置エンコーディングです。
- 偶数次元には sin
- 奇数次元には cos
- 次元インデックスが大きくなるほど周波数が小さくなる
損失関数:
loss = F.cross_entropy(input=logits_reshaped, target=targets_reshaped)
クロスエントロピー損失は、モデルの予測分布と真の one-hot ターゲットとの KL ダイバージェンスを測ります。語彙数が 50,000 で、モデルがランダム初期化されている場合、損失は ln(50000) ≈ 10.8 あたりから始まります。十分に学習されたモデルではこれが 3 を切ります。
18.8 生成関数
18.8.1 自己回帰生成
def generate(self, idx, max_new_tokens=100, temperature=1.0, top_k=None):
"""
自己回帰によるテキスト生成。
Args:
idx: 初期トークン ID [B, T]
max_new_tokens: 生成する新しいトークンの最大数
temperature: 出力のランダム性を制御
top_k: 確率上位 k 個のトークンからのみサンプリング
"""
for _ in range(max_new_tokens):
# 1. コンテキスト長の上限まで切り取る
idx_crop = idx[:, -self.context_length:]
# 2. 順伝播
logits, loss = self.forward(idx_crop)
# 3. 最終位置の logits のみを取り出し、temperature を適用
logits = logits[:, -1, :] / temperature
# 4. オプション: top-k 候補だけを残す
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('Inf')
# 5. Softmax で確率に変換
probs = F.softmax(input=logits, dim=-1)
# 6. 次のトークンをサンプリング
idx_next = torch.multinomial(input=probs, num_samples=1)
# 7. シーケンスに追加
idx = torch.cat((idx, idx_next), dim=1)
return idx
18.8.2 Temperature
第6章で扱った temperature です。
logits = logits[:, -1, :] / temperature
- T < 1: 確率がより集中する — より決定論的
- T = 1: 元の分布
- T > 1: 確率がより一様になる — よりランダム
事実ベースの補完には T ≈ 0.3、創造的な生成には T ≈ 0.8 から 1.0 を使います。
18.8.3 Top-K サンプリング
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('Inf')
確率上位 k 個のトークンだけを残し、それ以外を -inf に設定します。これにより、統計的にあり得ず、しばしば意味の通らない低確率トークンの生成を防げます。
18.9 完成版 model.py
"""
テキスト生成のための Transformer Decoder-only ベースモデル
"""
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
class FeedForwardNetwork(nn.Module):
def __init__(self, d_model, dropout):
super().__init__()
self.ffn = nn.Sequential(
nn.Linear(d_model, d_model * 4),
nn.ReLU(),
nn.Linear(d_model * 4, d_model),
nn.Dropout(dropout)
)
def forward(self, x):
return self.ffn(x)
class Attention(nn.Module):
def __init__(self, d_model, head_size, context_length, dropout):
super().__init__()
self.head_size = head_size
self.Wq = nn.Linear(d_model, head_size, bias=False)
self.Wk = nn.Linear(d_model, head_size, bias=False)
self.Wv = nn.Linear(d_model, head_size, bias=False)
self.register_buffer('mask', torch.tril(torch.ones(context_length, context_length)))
self.dropout = nn.Dropout(dropout)
def forward(self, x):
B, T, C = x.shape
q = self.Wq(x)
k = self.Wk(x)
v = self.Wv(x)
weights = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_size)
weights = weights.masked_fill(self.mask[:T, :T] == 0, float('-inf'))
weights = F.softmax(weights, dim=-1)
weights = self.dropout(weights)
return weights @ v
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads, head_size, context_length, dropout):
super().__init__()
self.heads = nn.ModuleList([
Attention(d_model, head_size, context_length, dropout)
for _ in range(num_heads)
])
self.projection_layer = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
head_outputs = torch.cat([head(x) for head in self.heads], dim=-1)
return self.dropout(self.projection_layer(head_outputs))
class TransformerBlock(nn.Module):
def __init__(self, d_model, num_heads, head_size, context_length, dropout):
super().__init__()
self.ln1 = nn.LayerNorm(d_model)
self.ln2 = nn.LayerNorm(d_model)
self.mha = MultiHeadAttention(d_model, num_heads, head_size, context_length, dropout)
self.ffn = FeedForwardNetwork(d_model, dropout)
def forward(self, x):
x = x + self.mha(self.ln1(x))
x = x + self.ffn(self.ln2(x))
return x
class Model(nn.Module):
def __init__(self, h_params):
super().__init__()
self.context_length = h_params['context_length']
self.d_model = h_params['d_model']
self.num_blocks = h_params['num_blocks']
self.num_heads = h_params['num_heads']
self.head_size = self.d_model // self.num_heads
self.dropout = h_params['dropout']
self.device = h_params['device']
self.max_token_value = h_params['max_token_value']
self.token_embedding_lookup_table = nn.Embedding(self.max_token_value, self.d_model)
self.transformer_blocks = nn.Sequential(*(
[TransformerBlock(self.d_model, self.num_heads, self.head_size,
self.context_length, self.dropout)
for _ in range(self.num_blocks)] +
[nn.LayerNorm(self.d_model)]
))
self.model_out_linear_layer = nn.Linear(self.d_model, self.max_token_value)
def forward(self, idx, targets=None):
B, T = idx.shape
# Positional Encoding
position_encoding = torch.zeros(self.context_length, self.d_model, device=self.device)
position = torch.arange(0, self.context_length, dtype=torch.float, device=self.device).unsqueeze(1)
div_term = torch.exp(torch.arange(0, self.d_model, 2, dtype=torch.float, device=self.device) * (-math.log(10000.0) / self.d_model))
position_encoding[:, 0::2] = torch.sin(position * div_term)
position_encoding[:, 1::2] = torch.cos(position * div_term)
x = self.token_embedding_lookup_table(idx) + position_encoding[:T, :].to(self.device)
x = self.transformer_blocks(x)
logits = self.model_out_linear_layer(x)
if targets is not None:
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
else:
loss = None
return logits, loss
def generate(self, idx, max_new_tokens=100, temperature=1.0, top_k=None):
for _ in range(max_new_tokens):
idx_crop = idx[:, -self.context_length:]
logits, _ = self.forward(idx_crop)
logits = logits[:, -1, :] / temperature
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('Inf')
probs = F.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1)
idx = torch.cat((idx, idx_next), dim=1)
return idx
18.10 章のまとめ
18.10.1 コードと概念の対応
| クラス | 概念 | 章 |
|---|---|---|
FeedForwardNetwork | フィードフォワードネットワーク | 第7章 |
Attention | 単一ヘッド Attention | 第9〜12章 |
MultiHeadAttention | Multi-Head Attention | 第11章 |
TransformerBlock | Transformer ブロック | 第13章 |
Model | 完全なモデル | 第15章 |
18.10.2 パラメータ数の概算
d_model=512, num_heads=8, num_blocks=6, vocab_size=50,000 を仮定します。
| コンポーネント | 式 | パラメータ数 |
|---|---|---|
| Token Embedding | vocab × d_model | 約 25.6M |
| Attention (×6) | 4 × d_model² × 6 | 約 6.3M |
| FFN (×6) | 2 × d_model × 4×d_model × 6 | 約 12.6M |
| Output Linear | d_model × vocab | 約 25.6M |
合計: およそ 70M パラメータ
18.10.3 核心的な気づき
model.pyとは、これまでの章のコンポーネントを PyTorch でつなげただけのものです。各クラスはひとつの概念に対応します — FFN、Attention、MultiHeadAttention、TransformerBlock、Model。概念を理解していればコードは自明に見えます。逆ではありません。
章のチェックリスト
この章を終えたら、次のことができるはずです。
-
FeedForwardNetworkを独力で実装できる。 -
Attention(Causal Mask 含む) を独力で実装できる。 -
MultiHeadAttentionを独力で実装できる。 - 物理的分離型と論理的分離型の MHA 実装の違いを説明できる。
-
Model.forward()全体のデータフローを説明できる。
完全なコード
完全な実装は GitHub にあります。
model.py、train.py、inference.py、そしてステップバイステップの Jupyter ノートブックが含まれています。
次の章でお会いしましょう
モデルは出来上がりました。しかし今のモデルは何も知りません — すべてのパラメータがランダムに初期化されたままだからです。プロンプトを与えても、出てくるのはノイズです。
第19章では学習ループを書きます。データを読み込み、順伝播し、損失を計算し、誤差逆伝播でパラメータを更新する。読み終えるころには、モデルが本当に「次のトークンを予測する」ことを学び始めます。お楽しみに。