Skip to content
The loss curve

Chapitre 17 · 16 min

Donner des instructions à ton modèle

Fine-tune le checkpoint du chapitre 13 sur un petit dataset prompt/completion pour qu’il réponde réellement aux instructions, au lieu de continuer en pseudo-shakespearien.

À la fin du chapitre 14, tu as échantillonné ton modèle entraîné. Le résultat avait la forme du shakespearien : retours à la ligne, noms en majuscules, vocabulaire archaïque. Demande-lui maintenant « combien font deux plus deux ? » et tu obtiens une continuation en vers, pas une réponse.

Ce gap est le sujet du chapitre 16. Ce chapitre enseigne la technique la moins chère et la plus directe pour le combler : le fine-tuning supervisé, ou SFT.

Tu vas reprendre le checkpoint du chapitre 13, continuer son sur un petit ensemble d’exemples prompt/completion, et regarder la forme se déplacer. La même recette — à très grande échelle — est celle qu’OpenAI a utilisée pour transformer GPT-3 (un modèle de complétion brut) en InstructGPT.

1. Le chat template

Un modèle pré-entraîné ne sait pas que User: et Assistant: sont spéciaux. Il ne connaît que les statistiques de . Si tu l’ sur un où ces marqueurs sont cohérents, il apprend à les associer à un changement de rôle.

La convention la plus simple — proche de ce que font les systèmes de production — est un frame à trois rôles :

System: You answer questions briefly and accurately.
User: What is two plus two?
Assistant: Four.

Garde le même frame pour chaque exemple d’. Le modèle apprend le rythme : il voit User: ..., il produit Assistant: .... Les systèmes de production utilisent des templates plus riches (ChatML, [INST] de Llama, etc.) mais le principe est identique.

Écris le renderer. La cellule prend un exemple et le coupe en prompt (la partie montrée au modèle) et completion (la partie qu’il doit produire). Les rôles sont colorés pour visualiser le frame.

Code · JavaScript

2. L’astuce : masquer la sur le prompt

Si tu entraînes naïvement sur la séquence entière, le modèle doit prédire chaque — y compris la question de l’utilisateur. C’est du gâchis : les du prompt sont donnés, pas générés. Pire, s’entraîner dessus pousse le modèle à imiter la formulation des utilisateurs au lieu de celle des assistants.

La correction est la pièce de SFT la plus souvent omise dans les tutoriels : n’entraîne que sur les de l’assistant. Construis un masque sur la séquence qui vaut 0 partout sauf sur les à générer, multiplie la cross-entropy par par ce masque, et moyenne sur les positions masquées uniquement.

LSFT=tmtCE(y^t,yt)tmt\mathcal{L}_{\text{SFT}} = \frac{\sum_{t} m_t \cdot \text{CE}(\hat{y}_t, y_t)}{\sum_t m_t}

avec m_t = 1 si le appartient à un tour assistant, 0 sinon.

Sans ce masquage, SFT converge 5 à 10× plus lentement et le modèle suit moins bien la forme. Avec, 50 à 200 exemples suffisent pour voir la bascule.

Construis le masque. La cellule tokenize le même exemple et te demande d’écrire le masque par . Les en vert contribuent à la ; les autres sont masqués.

Code · JavaScript

3. Un mini dataset SFT

Les vrais datasets SFT contiennent des milliers à des millions d’exemples. Pas besoin d’en arriver là pour sentir l’effet. Crée data/sft.jsonl :

{"system": "You answer questions briefly.", "user": "What is two plus two?", "assistant": "Four."}
{"system": "You answer questions briefly.", "user": "Capital of France?", "assistant": "Paris."}
{"system": "You answer questions briefly.", "user": "How many sides does a triangle have?", "assistant": "Three."}

Vise 30 à 100 lignes. Pour un vrai run, passe à quelques centaines d’exemples sur un domaine étroit : politique de remboursement, code dans un langage précis, résumés dans ton ton. Plus le domaine est étroit, plus le dataset peut être petit tout en restant utile.

4. Le script SFT

Sauvegarde scripts/sft.py :

"""scripts/sft.py — supervised fine-tune the chapter-13 checkpoint."""
import json
from pathlib import Path
 
import numpy as np
import torch
import torch.nn.functional as F
import tiktoken
 
from llm.model import GPT, GPTConfig
 
 
cfg = GPTConfig()
batch_size = 8
max_steps = 500
lr = 1e-4
device = "mps" if torch.backends.mps.is_available() else (
    "cuda" if torch.cuda.is_available() else "cpu"
)
 
enc = tiktoken.get_encoding("gpt2")
 
 
def render(ex):
    prompt = (
        f"System: {ex['system']}\n"
        f"User: {ex['user']}\n"
        f"Assistant: "
    )
    completion = ex["assistant"] + "\n"
    return prompt, completion
 
 
records = []
for line in Path("data/sft.jsonl").read_text().splitlines():
    line = line.strip()
    if not line:
        continue
    ex = json.loads(line)
    prompt, completion = render(ex)
    records.append((enc.encode_ordinary(prompt), enc.encode_ordinary(completion)))
 
 
def make_batch():
    idx = np.random.randint(0, len(records), size=batch_size)
    seqs, masks = [], []
    for i in idx:
        prompt_ids, completion_ids = records[i]
        ids = (prompt_ids + completion_ids)[: cfg.block_size]
        mask = ([0] * len(prompt_ids) + [1] * len(completion_ids))[: cfg.block_size]
        pad = cfg.block_size - len(ids)
        seqs.append(ids + [0] * pad)
        masks.append(mask + [0] * pad)
    x = torch.tensor(seqs, dtype=torch.long, device=device)
    m = torch.tensor(masks, dtype=torch.float, device=device)
    return x, m
 
 
model = GPT(cfg).to(device)
model.load_state_dict(torch.load("checkpoints/model.pt", map_location=device))
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
 
for step in range(max_steps):
    x, mask = make_batch()
    inputs = x[:, :-1]
    targets = x[:, 1:]
    target_mask = mask[:, 1:]
 
    logits, _ = model(inputs)
    per_token = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        targets.reshape(-1),
        reduction="none",
    ).reshape(targets.shape)
    loss = (per_token * target_mask).sum() / target_mask.sum().clamp(min=1)
 
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
 
    if step % 50 == 0 or step == max_steps - 1:
        print(f"step {step:4d} | loss {loss.item():.4f}")
 
Path("checkpoints").mkdir(exist_ok=True)
torch.save(model.state_dict(), "checkpoints/model_sft.pt")

Lis-le comme la loop du chapitre 13 avec deux changements chirurgicaux :

  • render met chaque exemple dans le chat template. Le même format devra être réutilisé à l’.
  • make_batch concatène prompt et completion et construit un masque 0/1 qui n’active que les de l’assistant. Le padding reçoit 0 pour ne pas fuiter dans la .
  • load_state_dict repart du du chapitre 13 — SFT continue l’, ne le redémarre pas. Le est plus petit (1e-4) parce que le modèle parle déjà : on veut un coup de pouce, pas un séisme.
  • reduction="none" garde la cross-entropy par , qu’on multiplie ensuite par le masque avant moyenne — c’est le cœur de SFT.

Lance :

python -m scripts.sft
python -m scripts.sft
python -m scripts.sft

500 steps en CPU = environ deux minutes. La chute typiquement de ~6 à ~2 en 200 steps — vite, parce que le modèle parle déjà et n’a qu’à apprendre la forme du chat.

5. Échantillonner depuis le modèle SFT

Sauvegarde scripts/sample_sft.py :

"""scripts/sample_sft.py — sample with chat template from the SFT checkpoint."""
import torch
import tiktoken
 
from llm.model import GPT, GPTConfig
 
 
device = "mps" if torch.backends.mps.is_available() else (
    "cuda" if torch.cuda.is_available() else "cpu"
)
cfg = GPTConfig()
model = GPT(cfg).to(device)
model.load_state_dict(torch.load("checkpoints/model_sft.pt", map_location=device))
model.eval()
 
enc = tiktoken.get_encoding("gpt2")
prompt = (
    "System: You answer questions briefly.\n"
    "User: What is two plus two?\n"
    "Assistant: "
)
prompt_ids = enc.encode_ordinary(prompt)
idx = torch.tensor([prompt_ids], device=device)
newline = enc.encode_ordinary("\n")[0]
 
with torch.no_grad():
    for _ in range(40):
        ctx = idx if idx.size(1) <= cfg.block_size else idx[:, -cfg.block_size :]
        logits, _ = model(ctx)
        probs = torch.softmax(logits[:, -1, :] / 0.7, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)
        if next_id.item() == newline and idx.size(1) > len(prompt_ids):
            break
        idx = torch.cat([idx, next_id], dim=1)
 
print(enc.decode(idx[0].tolist()))

Deux différences avec le sampler du chapitre 14 :

  • Le prompt utilise le même chat template sur lequel le modèle a été fine-tuné. Si tu dévies du format, le modèle perd la majorité du gain SFT.
  • Le saut de ligne est un stop , pour finir à la première ligne d’assistant. Pour des chats multi-tours, on utilise des stop strings plus longues (\nUser: etc.) — pattern couvert au chapitre 20.

Lance et compare au sampler du chapitre 14 sur la même question. Le modèle de base continue en pseudo-shakespearien. Le modèle SFT produit quelque chose de formé comme une réponse — et si ton dataset n’a que 30 exemples, ces réponses sont en grande partie un remix de l’. C’est normal : ce que tu as démontré, c’est la forme. La qualité du contenu, à partir d’ici, est fonction de la taille et de la qualité du dataset — c’est là que les petits modèles gagnent leur place quand le domaine est étroit et privé.

6. D’où viennent les vrais datasets SFT

InstructGPT (OpenAI) a utilisé environ 13 000 exemples écrits à la main pour l’étape SFT, avec un pool bien plus large réservé à l’ du reward model qui suivait. Les datasets SFT open-source modernes sont en général :

  • Curés par des humains : HH-RLHF d’Anthropic, ShareGPT, OpenAssistant.
  • Distillés depuis de plus gros modèles : Alpaca (sorties LLaMA filtrées), Dolly (employés Databricks).
  • Minés par domaine : transformer des repos de code, des tickets de support ou de la documentation en paires prompt/completion programmatiquement.

Pour un projet à ton échelle, l’angle qui paie le plus est le SFT spécifique à un domaine sur des données que toi seul as. Un SFT générique contre un dataset ouvert ne produit qu’une version dégradée de ce qui traîne déjà sur Hugging Face. Le SFT étroit et privé est l’endroit où les petits modèles gagnent leur place.

Recap

  • SFT transforme un prédicteur next- en modèle qui suit la forme du chat. Même architecture, même loop ; données et différentes. - Le chat template est un frame textuel cohérent (System / User / Assistant) que le modèle apprend à reconnaître. - Le masquage de sur le prompt est l’astuce centrale. Seuls les de l’assistant comptent. - Repars du du chapitre 13 avec un plus petit. - Ton projet local a maintenant data/sft.jsonl, scripts/sft.py, scripts/sample_sft.py et checkpoints/model_sft.pt.

Pour aller plus loin

Prochaine étape : LoRA — même idée (fine-tuner sur un nouvel objectif), mais avec un entraînement -efficient pour que chaque adaptateur reste minuscule.