Chapter 13
The training loop
Write the training loop, plot the loss curve, save a checkpoint, generate a sample. This is where the project starts to feel real.
You have data/train.bin (chapter 11) and llm/model.py (chapter 12). Now we train.
Training is fundamentally the loop from chapter 5:
- Sample a batch.
- Forward pass — get logits and loss.
- Backward pass — compute gradients.
- Step the optimizer.
- Repeat.

What changes at scale is everything around the loop: how you batch, how you sample, how you schedule the learning rate, how you decide when to stop. This chapter wires those pieces up.
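Here is the bare loop in its smallest runnable form, with a toy linear model standing in for the transformer. The names and the synthetic task are purely illustrative; the real version, with the GPT from chapter 12, is the script in section 2.

```python
import torch

# Toy stand-in: a linear model learning y = sum(x). The shape of the loop is
# exactly the five steps above; only the model and data change at scale.
model = torch.nn.Linear(10, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2)

for step in range(200):
    x = torch.randn(32, 10)                            # sample a batch
    y = x.sum(dim=1, keepdim=True)                     # targets for the toy task
    loss = torch.nn.functional.mse_loss(model(x), y)   # forward pass -> loss
    optimizer.zero_grad(set_to_none=True)              # clear old gradients
    loss.backward()                                    # backward pass
    optimizer.step()                                   # update the weights
    if step % 50 == 0:
        print(f"step {step} | loss {loss.item():.4f}")
```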
1. The learning-rate schedule
Vanilla GD with a fixed learning rate is rarely the best choice for transformers. The standard pattern is:
- Linear warmup for the first few hundred steps. Stops the loss from exploding when the model is still in random-init noise.
- Cosine decay down to a small floor (usually 10% of the max). Lets the model fine-tune at the end without overshooting.

Write the schedule.
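One way to write it, previewing the lr_at function that reappears in the training script in section 2 (same constants: peak 3e-4, floor 3e-5, 100 warmup steps out of 5 000):

```python
import math

max_lr, min_lr = 3e-4, 3e-5        # peak rate and floor (floor = 10% of peak)
warmup_steps, max_steps = 100, 5_000

def lr_at(step):
    # linear warmup: 0 -> max_lr over the first warmup_steps
    if step < warmup_steps:
        return max_lr * step / warmup_steps
    # cosine decay: max_lr -> min_lr over the remaining steps
    t = (step - warmup_steps) / (max_steps - warmup_steps)
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * t))

print(lr_at(0), lr_at(100), lr_at(2_500), lr_at(4_999))   # 0.0, 3e-4, ~1.7e-4, ~3e-5
```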
Pick the warmup length carefully: short warmup with a high peak learning rate is risky on small batches; long warmup wastes compute. The standard heuristic is warmup ≈ 1-5% of total steps. For 10 000 steps, that's 100-500.
2. The training script
Save this as scripts/train.py. It loads data/train.bin and data/val.bin, sets up the model and optimizer, and runs the loop.
"""scripts/train.py — train the model from chapter 12 on data/train.bin."""
import math
import time
from pathlib import Path
import numpy as np
import torch
from llm.model import GPT, GPTConfig
# --- config ---
# [1]
cfg = GPTConfig()
batch_size = 32
max_steps = 5000
eval_every = 250
warmup_steps = 100
max_lr = 3e-4
min_lr = 3e-5
device = "mps" if torch.backends.mps.is_available() else (
"cuda" if torch.cuda.is_available() else "cpu"
)
print(f"training on {device}")
# --- data ---
# [2]
train_ids = np.memmap("data/train.bin", dtype=np.uint16, mode="r")
val_ids = np.memmap("data/val.bin", dtype=np.uint16, mode="r")
# [3]
def get_batch(split):
data = train_ids if split == "train" else val_ids
ix = torch.randint(len(data) - cfg.block_size - 1, (batch_size,))
x = torch.stack([torch.from_numpy(data[i : i + cfg.block_size].astype(np.int64)) for i in ix])
y = torch.stack([torch.from_numpy(data[i + 1 : i + 1 + cfg.block_size].astype(np.int64)) for i in ix])
return x.to(device), y.to(device)
# --- model ---
# [4]
model = GPT(cfg).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr, betas=(0.9, 0.95))
# [5]
def lr_at(step):
if step < warmup_steps:
return max_lr * (step / warmup_steps)
t = (step - warmup_steps) / (max_steps - warmup_steps)
return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * t))
@torch.no_grad()
# [6]
def eval_loss():
model.eval()
losses = []
for _ in range(20):
x, y = get_batch("val")
_, loss = model(x, y)
losses.append(loss.item())
model.train()
return sum(losses) / len(losses)
# --- loop ---
t0 = time.time()
# [7]
for step in range(max_steps):
lr = lr_at(step)
for g in optimizer.param_groups:
g["lr"] = lr
x, y = get_batch("train")
_, loss = model(x, y)
optimizer.zero_grad(set_to_none=True)
loss.backward()
optimizer.step()
if step % eval_every == 0 or step == max_steps - 1:
val = eval_loss()
dt = time.time() - t0
print(f"step {step:5d} | lr {lr:.5f} | train {loss.item():.4f} | val {val:.4f} | t {dt:.1f}s")
# [8]
Path("checkpoints").mkdir(exist_ok=True)
torch.save(model.state_dict(), "checkpoints/model.pt")
print("saved checkpoints/model.pt")The script has five moving parts:
- [1] Config sets model size, batch size, learning-rate schedule, and device. These are the knobs you will tune later.
- [2] Data uses np.memmap to open the binary files without loading everything into RAM.
- [3] get_batch slices random windows and shifts them by one token to create (x, y).
- [4] Model creates the network; AdamW owns the optimizer state from chapter 7.
- [5] lr_at warms up, then decays. The loop writes that value into the optimizer every step.
- [6] eval_loss runs validation without gradients, so evaluation is cheaper and cannot accidentally train.
- [7] Loop is forward pass, zero old gradients, loss.backward(), optimizer step, occasional validation.
- [8] Checkpoint saves the learned weights so later scripts can generate, fine-tune, or quantize (a resumable variant is sketched after this list).
Trace one batch carefully: x is tokens t0..t63; y is tokens t1..t64. The model learns to predict each next token from the tokens before it.
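You can check this yourself by slicing the memmap by hand. This sketch assumes block_size is 64, matching the t0..t63 example; use whatever your GPTConfig sets:

```python
import numpy as np
import torch

ids = np.memmap("data/train.bin", dtype=np.uint16, mode="r")
block_size = 64                          # assumed; use cfg.block_size in practice
i = 0                                    # get_batch picks this offset at random
x = torch.from_numpy(ids[i : i + block_size].astype(np.int64))
y = torch.from_numpy(ids[i + 1 : i + 1 + block_size].astype(np.int64))
print(x[:5].tolist(), y[:5].tolist())    # y is x shifted left by one token
assert torch.equal(x[1:], y[:-1])        # every target is the next input token
```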
Then run the training script:
python -m scripts.train

On CPU (no GPU/MPS), training to 5 000 steps takes ~10 minutes. On Apple Silicon MPS or a recent NVIDIA card, it's ~2 minutes.
3. What you should see
The console output settles into a rhythm:
step 0 | lr 0.00000 | train 10.83 | val 10.82 | t 0.5s
step 250 | lr 0.00030 | train 5.41 | val 5.43 | t 28.4s
step 500 | lr 0.00029 | train 4.62 | val 4.66 | t 56.1s
step 1000 | lr 0.00026 | train 3.95 | val 4.02 | t 113.0s
step 2500 | lr 0.00017 | train 3.21 | val 3.36 | t 280.5s
step 4999 | lr 0.00003 | train 2.78 | val 3.07 | t 560.8s
Numbers will vary by hardware and random seed, but the trajectory is universal:
- Step 0: loss ≈ log(vocab_size) ≈ 10.8. The model knows nothing — it's uniform.
- Step 250: loss has dropped 5+ nats. The model has learned the token frequency distribution.
- By step 5000: loss around 2.8-3.0. The model is producing locally coherent English most of the time. It's not eloquent, but it has structure.
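The step-0 number isn't mysterious: a freshly initialized model spreads probability roughly uniformly over the vocabulary, so its cross-entropy is about log(vocab_size). A quick check of the milestones, assuming the GPT-2 BPE vocabulary from chapter 11:

```python
import math

vocab_size = 50_257              # GPT-2 BPE vocabulary size
print(math.log(vocab_size))      # ~10.82: loss of a uniform guess (step 0)
print(math.log(vocab_size) / 2)  # ~5.41:  roughly where token frequencies are learned
print(math.log(vocab_size) / 4)  # ~2.71:  roughly where locally coherent text appears
```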
The gap between train and val loss widens slightly over training. That's overfitting — small for our model (~14M params on ~272k training tokens), big if we scaled up. Dropout, weight decay, and more data are the standard cures.
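If you want to try those cures on this setup, weight decay is a one-argument change to the optimizer line in scripts/train.py, and dropout is a model-side knob. Whether GPTConfig exposes a dropout field depends on how you wrote chapter 12, so treat that part as a sketch:

```python
# Weight decay: AdamW applies decoupled weight decay; 0.1 is a common transformer setting.
optimizer = torch.optim.AdamW(
    model.parameters(), lr=max_lr, betas=(0.9, 0.95), weight_decay=0.1
)

# Dropout (sketch): only if your chapter-12 GPTConfig has such a field.
# cfg = GPTConfig(dropout=0.1)
```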
4. Sanity-check what the model learned
After training, do a quick generation to see what came out. Save as scripts/sample.py:
"""scripts/sample.py — generate a short continuation from a trained model."""
import torch
import tiktoken
from llm.model import GPT, GPTConfig
# [1]
device = "mps" if torch.backends.mps.is_available() else (
"cuda" if torch.cuda.is_available() else "cpu"
)
cfg = GPTConfig()
model = GPT(cfg).to(device)
model.load_state_dict(torch.load("checkpoints/model.pt", map_location=device))
model.eval()
# [2]
enc = tiktoken.get_encoding("gpt2")
prompt = "ROMEO:"
idx = torch.tensor([enc.encode_ordinary(prompt)], device=device)
with torch.no_grad():
for _ in range(100):
# [3]
idx_cond = idx if idx.size(1) <= cfg.block_size else idx[:, -cfg.block_size:]
# [4]
logits, _ = model(idx_cond)
probs = torch.softmax(logits[:, -1, :], dim=-1)
# [5]
next_id = torch.multinomial(probs, num_samples=1)
# [6]
idx = torch.cat([idx, next_id], dim=1)
print(enc.decode(idx[0].tolist()))

Generation is the training loop without the training:
- [1] reloads the trained checkpoint on the best available device.
- [2] encodes the prompt into token ids.
- [3] crops to the last block_size ids, because the model was trained with that maximum context length.
- [4] runs the model and keeps only logits[:, -1, :], the next-token scores after the final prompt token.
- [5] samples one id from the probability distribution (a couple of variations on this draw are sketched below).
- [6] appends that id and repeats.
python -m scripts.sample

You'll see something Shakespeare-shaped: line breaks, character names in caps, archaic phrasing. The grammar is loose; the semantics drift. It's recognizably the same kind of language as the training data — which is the bare minimum bar for a language model.
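The multinomial draw in step [5] is the only place randomness enters. Two quick variations to experiment with before chapter 14 covers them properly, written as drop-in replacements for that line in scripts/sample.py:

```python
# Greedy decoding: always take the single most likely token (deterministic output).
next_id = torch.argmax(probs, dim=-1, keepdim=True)

# Temperature: rescale the logits before softmax; <1.0 sharpens, >1.0 flattens.
temperature = 0.8
probs = torch.softmax(logits[:, -1, :] / temperature, dim=-1)
next_id = torch.multinomial(probs, num_samples=1)
```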
Chapter 14 picks up the generation story properly: how the next-token sampling actually works, and what knobs you have to control it.
Recap
- Training = the chapter-5 loop, with scaffolding. Sample batches, forward, backward, step, log.
- AdamW (chapter 7) is the standard optimizer for transformers. betas=(0.9, 0.95) are the usual values.
- Linear warmup + cosine decay is the standard learning-rate schedule. Short warmup avoids early instability; cosine decay finishes the training cleanly.
- Loss starts at log(vocab_size) for a fresh model. Reaching log(vocab_size) / 2 is a sign the model has learned token frequencies. Reaching log(vocab_size) / 4 is when it's producing locally coherent text.
- CPU is slow but works. MPS/CUDA accelerate dense matmuls 5-20×. The script picks automatically.
- Your local project now has scripts/train.py, scripts/sample.py, and checkpoints/model.pt after a successful run.
Going further
- Karpathy's nanoGPT training script — the reference. Same structure as ours, plus DDP, gradient accumulation, FP16, checkpointing.
- Chinchilla scaling laws — Hoffmann et al. on how much data you actually need for a given model size (TL;DR: roughly 20× more tokens than parameters).
Next up: generation and sampling — your model produces a distribution at every position. We've been picking with multinomial; let's look at the alternatives.