Skip to content
The loss curve

Chapitre 13 · 16 min

La boucle d’entraînement

Entraîne le modèle PyTorch local sur data/train.bin, écris des checkpoints et génère le premier texte de ton propre modèle.

Tu as data/train.bin (chapitre 11) et llm/model.py (chapitre 12). Maintenant, on entraîne.

L’entraînement est fondamentalement la boucle du chapitre 5 :

  1. Échantillonner un batch.
  2. Passe avant : obtenir logits et loss.
  3. Passe arrière : calculer les gradients.
  4. Faire un pas d’optimiseur.
  5. Répéter. Ce qui change à l’échelle, c’est tout autour : batching, sampling, learning-rate schedule, évaluation, checkpoints.

1. Le schedule de learning rate

Pour les transformers, un learning rate fixe est rarement idéal. Le motif standard :

  1. Warmup linéaire au début, pour éviter que la loss explose pendant l’initialisation aléatoire.
  2. Décroissance cosinus vers un plancher bas, pour finir proprement. Écris ce schedule.

Code · JavaScript

Fais varier le warmup. Trop court avec un LR élevé est risqué ; trop long gaspille du compute. Une heuristique courante : warmup ≈ 1-5 % des steps.

2. Le script d’entraînement

Sauvegarde ceci comme scripts/train.py :

"""scripts/train.py — train the model from chapter 12 on data/train.bin."""
import math
import time
from pathlib import Path
 
import numpy as np
import torch
 
from llm.model import GPT, GPTConfig
 
# --- config ---
# [1]
cfg = GPTConfig()
batch_size = 32
max_steps = 5000
eval_every = 250
warmup_steps = 100
max_lr = 3e-4
min_lr = 3e-5
device = "mps" if torch.backends.mps.is_available() else (
    "cuda" if torch.cuda.is_available() else "cpu"
)
print(f"training on {device}")
 
# --- data ---
# [2]
train_ids = np.memmap("data/train.bin", dtype=np.uint16, mode="r")
val_ids = np.memmap("data/val.bin", dtype=np.uint16, mode="r")
 
# [3]
def get_batch(split):
    data = train_ids if split == "train" else val_ids
    ix = torch.randint(len(data) - cfg.block_size - 1, (batch_size,))
    x = torch.stack([torch.from_numpy(data[i : i + cfg.block_size].astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy(data[i + 1 : i + 1 + cfg.block_size].astype(np.int64)) for i in ix])
    return x.to(device), y.to(device)
 
# --- model ---
# [4]
model = GPT(cfg).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr, betas=(0.9, 0.95))
 
# [5]
def lr_at(step):
    if step < warmup_steps:
        return max_lr * (step / warmup_steps)
    t = (step - warmup_steps) / (max_steps - warmup_steps)
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * t))
 
@torch.no_grad()
# [6]
def eval_loss():
    model.eval()
    losses = []
    for _ in range(20):
        x, y = get_batch("val")
        _, loss = model(x, y)
        losses.append(loss.item())
    model.train()
    return sum(losses) / len(losses)
 
# --- loop ---
t0 = time.time()
# [7]
for step in range(max_steps):
    lr = lr_at(step)
    for g in optimizer.param_groups:
        g["lr"] = lr
 
    x, y = get_batch("train")
    _, loss = model(x, y)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
 
    if step % eval_every == 0 or step == max_steps - 1:
        val = eval_loss()
        dt = time.time() - t0
        print(f"step {step:5d} | lr {lr:.5f} | train {loss.item():.4f} | val {val:.4f} | t {dt:.1f}s")
 
# [8]
Path("checkpoints").mkdir(exist_ok=True)
torch.save(model.state_dict(), "checkpoints/model.pt")
print("saved checkpoints/model.pt")

Le script a cinq pièces :

  • [1] Config : taille de modèle, batch, LR, device.
  • [2] Data : np.memmap ouvre les fichiers sans tout charger en RAM.
  • [3] get_batch prend des fenêtres aléatoires et les décale d’un token pour créer (x, y).
  • [4] Model : crée le réseau ; AdamW garde l’état d’optimisation.
  • [5] lr_at fait warmup puis décroissance.
  • [6] eval_loss valide sans gradients.
  • [7] Loop : forward, zero grad, backward, step, validation périodique.
  • [8] Checkpoint : sauvegarde les poids appris.

Trace un batch : x contient t0..t63, y contient t1..t64. Le modèle apprend à prédire chaque prochain token à partir des précédents.

Lance :

python -m scripts.train
python -m scripts.train
python -m scripts.train

Sur CPU, 5000 steps prennent environ dix minutes ; sur MPS ou CUDA, plutôt quelques minutes.

3. Ce que tu devrais voir

La console finit par ressembler à :

step     0 | lr 0.00000 | train 10.83 | val 10.82 | t 0.5s
step   250 | lr 0.00030 | train 5.41  | val 5.43  | t 28.4s
step   500 | lr 0.00029 | train 4.62  | val 4.66  | t 56.1s
step  1000 | lr 0.00026 | train 3.95  | val 4.02  | t 113.0s
step  2500 | lr 0.00017 | train 3.21  | val 3.36  | t 280.5s
step  4999 | lr 0.00003 | train 2.78  | val 3.07  | t 560.8s

Les nombres varient, mais la trajectoire est universelle :

  • Step 0 : loss ≈ log(vocab_size) ≈ 10.8. Le modèle ne sait rien.
  • Step 250 : la loss chute fortement ; le modèle apprend les fréquences de tokens.
  • Step 5000 : loss autour de 2.8-3.0 ; le texte devient localement cohérent.

L’écart entre train et validation s’élargit un peu : c’est de l’overfitting. Plus de données, dropout et weight decay sont les remèdes classiques.

4. Vérifier ce que le modèle a appris

Après l’entraînement, génère un court texte. Sauvegarde scripts/sample.py :

"""scripts/sample.py — generate a short continuation from a trained model."""
import torch
import tiktoken
 
from llm.model import GPT, GPTConfig
 
# [1]
device = "mps" if torch.backends.mps.is_available() else (
    "cuda" if torch.cuda.is_available() else "cpu"
)
cfg = GPTConfig()
model = GPT(cfg).to(device)
model.load_state_dict(torch.load("checkpoints/model.pt", map_location=device))
model.eval()
 
# [2]
enc = tiktoken.get_encoding("gpt2")
prompt = "ROMEO:"
idx = torch.tensor([enc.encode_ordinary(prompt)], device=device)
 
with torch.no_grad():
    for _ in range(100):
        # [3]
        idx_cond = idx if idx.size(1) <= cfg.block_size else idx[:, -cfg.block_size:]
        # [4]
        logits, _ = model(idx_cond)
        probs = torch.softmax(logits[:, -1, :], dim=-1)
        # [5]
        next_id = torch.multinomial(probs, num_samples=1)
        # [6]
        idx = torch.cat([idx, next_id], dim=1)
 
print(enc.decode(idx[0].tolist()))

La génération est la boucle d’entraînement sans entraînement :

  • [1] recharge le checkpoint.
  • [2] encode le prompt.
  • [3] garde les derniers block_size tokens.
  • [4] récupère les logits du prochain token.
  • [5] échantillonne un id.
  • [6] l’ajoute et répète.

Lance :

python -m scripts.sample
python -m scripts.sample
python -m scripts.sample

Tu verras du texte vaguement shakespearien : retours à la ligne, noms de personnages, tournures archaïques. La grammaire reste fragile, le sens dérive, mais le style du corpus est reconnaissable.

Recap

  • Entraîner = boucle du chapitre 5 avec infrastructure. - AdamW est l’optimiseur standard des transformers. - Warmup linéaire + décroissance cosinus est le schedule classique. - La loss démarre à log(vocab_size) pour un modèle neuf. - CPU fonctionne mais lentement ; MPS/CUDA accélèrent fortement. - Ton projet local a maintenant scripts/train.py, scripts/sample.py et checkpoints/model.pt après un run réussi.

Pour aller plus loin

Prochaine étape : génération et sampling — ton modèle produit une distribution à chaque position. Regardons les façons de choisir dedans.