Skip to content
The loss curve

Chapitre 20 · 12 min

Parler à ton modèle

Emballe ton checkpoint entraîné dans une petite boucle de chat en terminal pour finir avec un modèle que tu peux vraiment prompter.

Tu as les pièces : , modèle, de base, SFT du chapitre 17, , LoRA, quantification. La pièce manquante n’est pas une nouvelle théorie de réseau neuronal — c’est une interface.

Les produits chat modernes restent des modèles autorégressifs. L’astuce consiste à formater le prompt comme une conversation, générer le tour assistant, l’ajouter à l’historique, recommencer — en utilisant exactement le chat template sur lequel le modèle a été fine-tuné.

Ce chapitre ajoute ce wrapper, plus l’optimisation qui rend les générations longues supportables : un KV cache.

1. Le chat est du formatage de prompt

Un générateur simple reçoit du texte comme :

ROMEO:

Une boucle de chat donne plus de structure :

System: You are a small language model trained in The loss curve.
User: explain what a token is
Assistant:

Le modèle prédit toujours le prochain token. La seule différence est que le préfixe lui indique quel interlocuteur doit parler ensuite. La qualité dépend donc des données. Un modèle surtout entraîné sur Shakespeare ne devient pas soudain un assistant utile. Mais l’interface est la même : prompt, sample, append, repeat.

2. Ajouter la boucle terminal

Crée scripts/chat.py :

"""scripts/chat.py — talk to the checkpoint from your terminal."""
from __future__ import annotations
 
import torch
import torch.nn.functional as F
import tiktoken
 
from llm.model import GPT, GPTConfig
 
 
SYSTEM_PROMPT = "You answer questions briefly."
MAX_NEW_TOKENS = 160
TEMPERATURE = 0.8
TOP_K = 50
TOP_P = 0.9
# le SFT du ch.16 finit chaque tour assistant par "\n" : c'est le stop principal.
STOP_STRINGS = ["\nUser:", "\nSystem:", "\n"]
 
 
# [1]
def pick_device() -> str:
    if torch.backends.mps.is_available():
        return "mps"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
 
 
# [2]
def load_model(device: str) -> GPT:
    cfg = GPTConfig()
    model = GPT(cfg).to(device)
    state = torch.load("checkpoints/model_sft.pt", map_location=device)
    model.load_state_dict(state)
    model.eval()
    return model
 
 
# [3]
def render_prompt(turns: list[tuple[str, str]]) -> str:
    lines = [f"System: {SYSTEM_PROMPT}"]
    for role, text in turns:
        lines.append(f"{role}: {text.strip()}")
    lines.append("Assistant:")
    return "\n".join(lines)
 
 
# [4]
def sample_next(logits: torch.Tensor) -> torch.Tensor:
    logits = logits / TEMPERATURE
    probs = F.softmax(logits, dim=-1)
 
    if TOP_K is not None:
        top_values, _ = probs.topk(TOP_K)
        probs[probs < top_values[..., -1, None]] = 0
        probs = probs / probs.sum(dim=-1, keepdim=True)
 
    if TOP_P < 1.0:
        sorted_probs, sorted_idx = probs.sort(descending=True, dim=-1)
        cumulative = sorted_probs.cumsum(dim=-1)
        mask = cumulative > TOP_P
        mask[..., 0] = False
        sorted_probs[mask] = 0
        probs = torch.zeros_like(probs).scatter_(-1, sorted_idx, sorted_probs)
        probs = probs / probs.sum(dim=-1, keepdim=True)
 
    return torch.multinomial(probs, num_samples=1)
 
 
# [5]
@torch.no_grad()
def generate_reply(model: GPT, enc, prompt: str, device: str) -> str:
    cfg = model.cfg
    idx = torch.tensor([enc.encode_ordinary(prompt)], device=device)
    pieces: list[str] = []
 
    for _ in range(MAX_NEW_TOKENS):
        idx_cond = idx if idx.size(1) <= cfg.block_size else idx[:, -cfg.block_size :]
        logits, _ = model(idx_cond)
        next_id = sample_next(logits[:, -1, :])
        idx = torch.cat([idx, next_id], dim=1)
 
        pieces.append(enc.decode([int(next_id.item())]))
        text = "".join(pieces)
        for stop in STOP_STRINGS:
            if stop in text:
                return text.split(stop, 1)[0].strip()
 
    return "".join(pieces).strip()
 
 
# [6]
def main() -> None:
    device = pick_device()
    enc = tiktoken.get_encoding("gpt2")
    model = load_model(device)
    turns: list[tuple[str, str]] = []
 
    print("The loss curve chat. Type /quit to stop.\n")
    while True:
        user_text = input("You: ").strip()
        if user_text in {"/quit", "/exit"}:
            break
        if not user_text:
            continue
 
        turns.append(("User", user_text))
        prompt = render_prompt(turns)
        reply = generate_reply(model, enc, prompt, device)
        print(f"Model: {reply}\n")
        turns.append(("Assistant", reply))
 
 
if __name__ == "__main__":
    main()

Lis-le comme une fine enveloppe autour du chapitre 14 :

  • [1] pick_device prend le meilleur accélérateur disponible, sinon CPU.
  • [2] load_model recharge checkpoints/model_sft.pt (si tu as sauté le chapitre 17, pointe sur model.pt).
  • [3] render_prompt transforme l’historique de conversation en préfixe texte.
  • [4] sample_next reprend température + top-K + top-P.
  • [5] generate_reply croppe à la fenêtre de contexte, sample un token à la fois et s’arrête si le modèle commence un nouveau tour.
  • [6] main est la boucle terminal : lire l’entrée, générer, afficher, ajouter la réponse.

Lance :

python -m scripts.chat
python -m scripts.chat
python -m scripts.chat

Tu devrais obtenir :

The loss curve chat. Type /quit to stop.
 
You: what is a token?
Model: ...

La réponse devrait au moins tenter une réponse brève au lieu de continuer en pseudo-shakespearien. Tu remarqueras aussi que c’est lent — la section suivante règle ça.

3. Rendre la génération rapide : le KV cache

Une génération de 160 sur CPU prend 30 à 60 secondes. À chaque pas, le modèle traite tout le contexte, alors que seul le dernier a changé. Pour T et L couches, chaque pas fait O(L · T²) de travail.

Le KV cache stocke les tenseurs K et V de chaque couche d’ après la première passe sur le prompt. À chaque pas suivant, on ne calcule K et V que pour le nouveau , on les concatène au cache, et la nouvelle query attend tout le cache. Le travail par pas passe à O(L · T).

Quantifie la différence. La cellule calcule le travail d’ total sur un prompt de 32 + 100 générés, naïf vs cached :

Code · JavaScript

Le support est déjà dans llm/model.py

Le llm/model.py du chapitre 12 a été écrit avec le KV cache en tête. Regarde CausalSelfAttention.forward, Block.forward et GPT.forward : chacun prend un argument optionnel past_kv / past_kvs, le module d’attention concatène K/V au cache si fourni et saute le masque causal quand la query est un seul . Quand tu passes past_kvs=None (le défaut en ), le comportement est inchangé. Cache activé = speedup d’. Pas de patch rétroactif.

Trois subtilités à garder en tête quand tu utilises le cache :

  • Les embeddings de position se décalent de la taille du cache. Le nouveau est à la position cached_length, pas 0. GPT.forward lit past_kvs[0][0].size(2) et décale l’arange en conséquence. Sans ce décalage, chaque généré reçoit position 0 et le modèle s’écroule.
  • Le masque causal n’est appliqué que pendant la passe préfixe. Pendant la génération cached, T = 1 et la query doit voir tout le cache ; pas besoin de masque. CausalSelfAttention vérifie past_kv is None pour ça.
  • L’ reste inchangé — quand past_kvs is None, le forward est strictement celui du chapitre 12.

La nouvelle boucle de génération

@torch.no_grad()
def generate_reply(model, enc, prompt, device):
    cfg = model.cfg
    prompt_ids = enc.encode_ordinary(prompt)
    idx = torch.tensor([prompt_ids], device=device)
 
    # [1] une passe complète sur le prompt, peuple le cache
    logits, past_kvs = model(idx[:, : cfg.block_size])
 
    pieces = []
    for _ in range(MAX_NEW_TOKENS):
        next_id = sample_next(logits[:, -1, :])
        # [2] chaque pas suivant ne traite QU'UN nouveau token
        logits, past_kvs = model(next_id, past_kvs=past_kvs)
        pieces.append(enc.decode([int(next_id.item())]))
        text = "".join(pieces)
        for stop in STOP_STRINGS:
            if stop in text:
                return text.split(stop, 1)[0].strip()
    return "".join(pieces).strip()

Empiriquement sur ton modèle 14M en CPU : 100 passent de ~30s à ~3s. Plus le prompt est long, plus le gain est grand.

4. Pourquoi le modèle ne ressemble toujours pas à ChatGPT

Trois ingrédients rendent un modèle assistant-like :

  1. Capacité de base — assez de et de pré-entraînement. Ton modèle : 14M, ~272k . Très loin de la frontière.
  2. Données d’instruction — exemples en forme User: ... Assistant: .... Ton modèle : 30-100 exemples du ch.16. Loin des 13k+ d’InstructGPT.
  3. Preference tuning — optimisation vers des réponses utiles. Ton modèle : zéro. Hors scope du livre.

Tu as ajouté (2) au chapitre 17 — c’est ce qui fait que le chat suit la forme. (1) demande une échelle qu’on n’a pas. (3) vit dans un workflow séparé qu’on n’a pas construit.

À ta taille, attends-toi à des réponses correctes en forme mais pauvres en contenu. Le bon move : viser un domaine étroit où ton petit modèle peut être suffisamment bon.

5. Aller plus loin

  • Plus de données SFT dans le domaine que tu cibles. 30 exemples enseignent la forme ; 300 enseignent le vocabulaire ; 3000 enseignent les patterns.
  • LoRA par-dessus le SFT (chapitre 18) si tu veux plusieurs adaptateurs domaine sur un même modèle de base.
  • Un base model plus gros. Remplace les poids de llm/model.py par GPT-2 small, puis SFT sur tes données. Même boucle, même chat — seulement les poids et le temps changent.

6. Ce que tu as maintenant

Tu n’as pas ChatGPT. Tu as la plus petite version honnête de la même forme produit :

  • un de base + SFT
  • un
  • une fenêtre de contexte
  • une politique de
  • un chat template (System / User / Assistant)
  • un historique de tours
  • une boucle de génération avec KV cache
  • un REPL terminal

C’est le pont entre « je comprends comment les fonctionnent » et « je peux livrer un petit prototype ».

Recap

  • Le chat est de la structurée. Le modèle prédit toujours le prochain ; le template indique le tour. - Réutilise exactement le chat template du chapitre 17 à l’. Toute dérive entre train-time et inference-time coûte la majorité du gain SFT. - scripts/chat.py charge model_sft.pt, garde l’historique, un tour, boucle. - Les stop strings empêchent le modèle de dépasser un tour assistant ; le SFT lui a appris \n comme fin naturelle. - Le KV cache fait passer la de 30s à 3s pour 100 . llm/model.py accepte déjà past_kvs depuis le chapitre 12 ; ce chapitre utilise simplement cet argument dans la boucle de génération. - Le projet local a maintenant un endpoint utilisable : python -m scripts.chat.

Pour aller plus loin

Prochaine étape : livrer quelque chose d’utile — le capstone. Choisis un domaine étroit, génère ~150 exemples SFT, fine-tune GPT-2 small, évalue côte à côte, et termine le livre avec un assistant spécialisé qui marche au lieu d’un assistant générique qui parle mal.