Chapitre 13 · 16 min
La boucle d’entraînement
Entraîne le modèle PyTorch local sur data/train.bin, écris des checkpoints et génère le premier texte de ton propre modèle.
Tu as data/train.bin (chapitre 11) et llm/model.py (chapitre 12). Maintenant, on entraîne.
L’entraînement est fondamentalement la boucle du chapitre 5 :
- Échantillonner un batch.
- Passe avant : obtenir logits et loss.
- Passe arrière : calculer les gradients.
- Faire un pas d’optimiseur.
- Répéter. Ce qui change à l’échelle, c’est tout autour : batching, sampling, learning-rate schedule, évaluation, checkpoints.
1. Le schedule de learning rate
Pour les transformers, un learning rate fixe est rarement idéal. Le motif standard :
- Warmup linéaire au début, pour éviter que la loss explose pendant l’initialisation aléatoire.
- Décroissance cosinus vers un plancher bas, pour finir proprement. Écris ce schedule.
Code · JavaScript
Fais varier le warmup. Trop court avec un LR élevé est risqué ; trop long gaspille du compute. Une heuristique courante : warmup ≈ 1-5 % des steps.
2. Le script d’entraînement
Sauvegarde ceci comme scripts/train.py :
"""scripts/train.py — train the model from chapter 12 on data/train.bin."""
import math
import time
from pathlib import Path
import numpy as np
import torch
from llm.model import GPT, GPTConfig
# --- config ---
# [1]
cfg = GPTConfig()
batch_size = 32
max_steps = 5000
eval_every = 250
warmup_steps = 100
max_lr = 3e-4
min_lr = 3e-5
device = "mps" if torch.backends.mps.is_available() else (
"cuda" if torch.cuda.is_available() else "cpu"
)
print(f"training on {device}")
# --- data ---
# [2]
train_ids = np.memmap("data/train.bin", dtype=np.uint16, mode="r")
val_ids = np.memmap("data/val.bin", dtype=np.uint16, mode="r")
# [3]
def get_batch(split):
data = train_ids if split == "train" else val_ids
ix = torch.randint(len(data) - cfg.block_size - 1, (batch_size,))
x = torch.stack([torch.from_numpy(data[i : i + cfg.block_size].astype(np.int64)) for i in ix])
y = torch.stack([torch.from_numpy(data[i + 1 : i + 1 + cfg.block_size].astype(np.int64)) for i in ix])
return x.to(device), y.to(device)
# --- model ---
# [4]
model = GPT(cfg).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr, betas=(0.9, 0.95))
# [5]
def lr_at(step):
if step < warmup_steps:
return max_lr * (step / warmup_steps)
t = (step - warmup_steps) / (max_steps - warmup_steps)
return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * t))
@torch.no_grad()
# [6]
def eval_loss():
model.eval()
losses = []
for _ in range(20):
x, y = get_batch("val")
_, loss = model(x, y)
losses.append(loss.item())
model.train()
return sum(losses) / len(losses)
# --- loop ---
t0 = time.time()
# [7]
for step in range(max_steps):
lr = lr_at(step)
for g in optimizer.param_groups:
g["lr"] = lr
x, y = get_batch("train")
_, loss = model(x, y)
optimizer.zero_grad(set_to_none=True)
loss.backward()
optimizer.step()
if step % eval_every == 0 or step == max_steps - 1:
val = eval_loss()
dt = time.time() - t0
print(f"step {step:5d} | lr {lr:.5f} | train {loss.item():.4f} | val {val:.4f} | t {dt:.1f}s")
# [8]
Path("checkpoints").mkdir(exist_ok=True)
torch.save(model.state_dict(), "checkpoints/model.pt")
print("saved checkpoints/model.pt")Le script a cinq pièces :
- [1] Config : taille de modèle, batch, LR, device.
- [2] Data :
np.memmapouvre les fichiers sans tout charger en RAM. - [3]
get_batchprend des fenêtres aléatoires et les décale d’un token pour créer(x, y). - [4] Model : crée le réseau ;
AdamWgarde l’état d’optimisation. - [5]
lr_atfait warmup puis décroissance. - [6]
eval_lossvalide sans gradients. - [7] Loop : forward, zero grad, backward, step, validation périodique.
- [8] Checkpoint : sauvegarde les poids appris.
Trace un batch : x contient t0..t63, y contient t1..t64. Le modèle apprend à prédire chaque prochain token à partir des précédents.
Lance :
python -m scripts.trainpython -m scripts.trainpython -m scripts.trainSur CPU, 5000 steps prennent environ dix minutes ; sur MPS ou CUDA, plutôt quelques minutes.
3. Ce que tu devrais voir
La console finit par ressembler à :
step 0 | lr 0.00000 | train 10.83 | val 10.82 | t 0.5s
step 250 | lr 0.00030 | train 5.41 | val 5.43 | t 28.4s
step 500 | lr 0.00029 | train 4.62 | val 4.66 | t 56.1s
step 1000 | lr 0.00026 | train 3.95 | val 4.02 | t 113.0s
step 2500 | lr 0.00017 | train 3.21 | val 3.36 | t 280.5s
step 4999 | lr 0.00003 | train 2.78 | val 3.07 | t 560.8sLes nombres varient, mais la trajectoire est universelle :
- Step 0 : loss ≈
log(vocab_size) ≈ 10.8. Le modèle ne sait rien. - Step 250 : la loss chute fortement ; le modèle apprend les fréquences de tokens.
- Step 5000 : loss autour de 2.8-3.0 ; le texte devient localement cohérent.
L’écart entre train et validation s’élargit un peu : c’est de l’overfitting. Plus de données, dropout et weight decay sont les remèdes classiques.
4. Vérifier ce que le modèle a appris
Après l’entraînement, génère un court texte. Sauvegarde scripts/sample.py :
"""scripts/sample.py — generate a short continuation from a trained model."""
import torch
import tiktoken
from llm.model import GPT, GPTConfig
# [1]
device = "mps" if torch.backends.mps.is_available() else (
"cuda" if torch.cuda.is_available() else "cpu"
)
cfg = GPTConfig()
model = GPT(cfg).to(device)
model.load_state_dict(torch.load("checkpoints/model.pt", map_location=device))
model.eval()
# [2]
enc = tiktoken.get_encoding("gpt2")
prompt = "ROMEO:"
idx = torch.tensor([enc.encode_ordinary(prompt)], device=device)
with torch.no_grad():
for _ in range(100):
# [3]
idx_cond = idx if idx.size(1) <= cfg.block_size else idx[:, -cfg.block_size:]
# [4]
logits, _ = model(idx_cond)
probs = torch.softmax(logits[:, -1, :], dim=-1)
# [5]
next_id = torch.multinomial(probs, num_samples=1)
# [6]
idx = torch.cat([idx, next_id], dim=1)
print(enc.decode(idx[0].tolist()))La génération est la boucle d’entraînement sans entraînement :
- [1] recharge le checkpoint.
- [2] encode le prompt.
- [3] garde les derniers
block_sizetokens. - [4] récupère les logits du prochain token.
- [5] échantillonne un id.
- [6] l’ajoute et répète.
Lance :
python -m scripts.samplepython -m scripts.samplepython -m scripts.sampleTu verras du texte vaguement shakespearien : retours à la ligne, noms de personnages, tournures archaïques. La grammaire reste fragile, le sens dérive, mais le style du corpus est reconnaissable.
Recap
- Entraîner = boucle du chapitre 5 avec infrastructure. - AdamW est l’optimiseur standard
des transformers. - Warmup linéaire + décroissance cosinus est le schedule classique. - La
loss démarre à
log(vocab_size)pour un modèle neuf. - CPU fonctionne mais lentement ; MPS/CUDA accélèrent fortement. - Ton projet local a maintenantscripts/train.py,scripts/sample.pyetcheckpoints/model.ptaprès un run réussi.
Pour aller plus loin
Prochaine étape : génération et sampling — ton modèle produit une distribution à chaque position. Regardons les façons de choisir dedans.