Chapitre 20 · 12 min
Parler à ton modèle
Emballe ton checkpoint entraîné dans une petite boucle de chat en terminal pour finir avec un modèle que tu peux vraiment prompter.
Tu as les pièces : , modèle, de base, SFT du chapitre 17, , LoRA, quantification. La pièce manquante n’est pas une nouvelle théorie de réseau neuronal — c’est une interface.
Les produits chat modernes restent des modèles autorégressifs. L’astuce consiste à formater le prompt comme une conversation, générer le tour assistant, l’ajouter à l’historique, recommencer — en utilisant exactement le chat template sur lequel le modèle a été fine-tuné.
Ce chapitre ajoute ce wrapper, plus l’optimisation qui rend les générations longues supportables : un KV cache.
1. Le chat est du formatage de prompt
Un générateur simple reçoit du texte comme :
ROMEO:Une boucle de chat donne plus de structure :
System: You are a small language model trained in The loss curve.
User: explain what a token is
Assistant:Le modèle prédit toujours le prochain token. La seule différence est que le préfixe lui indique quel interlocuteur doit parler ensuite. La qualité dépend donc des données. Un modèle surtout entraîné sur Shakespeare ne devient pas soudain un assistant utile. Mais l’interface est la même : prompt, sample, append, repeat.
2. Ajouter la boucle terminal
Crée scripts/chat.py :
"""scripts/chat.py — talk to the checkpoint from your terminal."""
from __future__ import annotations
import torch
import torch.nn.functional as F
import tiktoken
from llm.model import GPT, GPTConfig
SYSTEM_PROMPT = "You answer questions briefly."
MAX_NEW_TOKENS = 160
TEMPERATURE = 0.8
TOP_K = 50
TOP_P = 0.9
# le SFT du ch.16 finit chaque tour assistant par "\n" : c'est le stop principal.
STOP_STRINGS = ["\nUser:", "\nSystem:", "\n"]
# [1]
def pick_device() -> str:
if torch.backends.mps.is_available():
return "mps"
if torch.cuda.is_available():
return "cuda"
return "cpu"
# [2]
def load_model(device: str) -> GPT:
cfg = GPTConfig()
model = GPT(cfg).to(device)
state = torch.load("checkpoints/model_sft.pt", map_location=device)
model.load_state_dict(state)
model.eval()
return model
# [3]
def render_prompt(turns: list[tuple[str, str]]) -> str:
lines = [f"System: {SYSTEM_PROMPT}"]
for role, text in turns:
lines.append(f"{role}: {text.strip()}")
lines.append("Assistant:")
return "\n".join(lines)
# [4]
def sample_next(logits: torch.Tensor) -> torch.Tensor:
logits = logits / TEMPERATURE
probs = F.softmax(logits, dim=-1)
if TOP_K is not None:
top_values, _ = probs.topk(TOP_K)
probs[probs < top_values[..., -1, None]] = 0
probs = probs / probs.sum(dim=-1, keepdim=True)
if TOP_P < 1.0:
sorted_probs, sorted_idx = probs.sort(descending=True, dim=-1)
cumulative = sorted_probs.cumsum(dim=-1)
mask = cumulative > TOP_P
mask[..., 0] = False
sorted_probs[mask] = 0
probs = torch.zeros_like(probs).scatter_(-1, sorted_idx, sorted_probs)
probs = probs / probs.sum(dim=-1, keepdim=True)
return torch.multinomial(probs, num_samples=1)
# [5]
@torch.no_grad()
def generate_reply(model: GPT, enc, prompt: str, device: str) -> str:
cfg = model.cfg
idx = torch.tensor([enc.encode_ordinary(prompt)], device=device)
pieces: list[str] = []
for _ in range(MAX_NEW_TOKENS):
idx_cond = idx if idx.size(1) <= cfg.block_size else idx[:, -cfg.block_size :]
logits, _ = model(idx_cond)
next_id = sample_next(logits[:, -1, :])
idx = torch.cat([idx, next_id], dim=1)
pieces.append(enc.decode([int(next_id.item())]))
text = "".join(pieces)
for stop in STOP_STRINGS:
if stop in text:
return text.split(stop, 1)[0].strip()
return "".join(pieces).strip()
# [6]
def main() -> None:
device = pick_device()
enc = tiktoken.get_encoding("gpt2")
model = load_model(device)
turns: list[tuple[str, str]] = []
print("The loss curve chat. Type /quit to stop.\n")
while True:
user_text = input("You: ").strip()
if user_text in {"/quit", "/exit"}:
break
if not user_text:
continue
turns.append(("User", user_text))
prompt = render_prompt(turns)
reply = generate_reply(model, enc, prompt, device)
print(f"Model: {reply}\n")
turns.append(("Assistant", reply))
if __name__ == "__main__":
main()Lis-le comme une fine enveloppe autour du chapitre 14 :
- [1]
pick_deviceprend le meilleur accélérateur disponible, sinon CPU. - [2]
load_modelrechargecheckpoints/model_sft.pt(si tu as sauté le chapitre 17, pointe surmodel.pt). - [3]
render_prompttransforme l’historique de conversation en préfixe texte. - [4]
sample_nextreprend température + top-K + top-P. - [5]
generate_replycroppe à la fenêtre de contexte, sample un token à la fois et s’arrête si le modèle commence un nouveau tour. - [6]
mainest la boucle terminal : lire l’entrée, générer, afficher, ajouter la réponse.
Lance :
python -m scripts.chatpython -m scripts.chatpython -m scripts.chatTu devrais obtenir :
The loss curve chat. Type /quit to stop.
You: what is a token?
Model: ...La réponse devrait au moins tenter une réponse brève au lieu de continuer en pseudo-shakespearien. Tu remarqueras aussi que c’est lent — la section suivante règle ça.
3. Rendre la génération rapide : le KV cache
Une génération de 160 sur CPU prend 30 à 60 secondes. À chaque pas, le modèle traite tout le contexte, alors que seul le dernier a changé. Pour T et L couches, chaque pas fait O(L · T²) de travail.
Le KV cache stocke les tenseurs K et V de chaque couche d’ après la première passe sur le prompt. À chaque pas suivant, on ne calcule K et V que pour le nouveau , on les concatène au cache, et la nouvelle query attend tout le cache. Le travail par pas passe à O(L · T).
Quantifie la différence. La cellule calcule le travail d’ total sur un prompt de 32 + 100 générés, naïf vs cached :
Code · JavaScript
Le support est déjà dans llm/model.py
Le llm/model.py du chapitre 12 a été écrit avec le KV cache en tête. Regarde CausalSelfAttention.forward, Block.forward et GPT.forward : chacun prend un argument optionnel past_kv / past_kvs, le module d’attention concatène K/V au cache si fourni et saute le masque causal quand la query est un seul . Quand tu passes past_kvs=None (le défaut en ), le comportement est inchangé. Cache activé = speedup d’. Pas de patch rétroactif.
Trois subtilités à garder en tête quand tu utilises le cache :
- Les embeddings de position se décalent de la taille du cache. Le nouveau est à la position
cached_length, pas 0.GPT.forwardlitpast_kvs[0][0].size(2)et décale l’arangeen conséquence. Sans ce décalage, chaque généré reçoit position 0 et le modèle s’écroule. - Le masque causal n’est appliqué que pendant la passe préfixe. Pendant la génération cached,
T = 1et la query doit voir tout le cache ; pas besoin de masque.CausalSelfAttentionvérifiepast_kv is Nonepour ça. - L’ reste inchangé — quand
past_kvs is None, le forward est strictement celui du chapitre 12.
La nouvelle boucle de génération
@torch.no_grad()
def generate_reply(model, enc, prompt, device):
cfg = model.cfg
prompt_ids = enc.encode_ordinary(prompt)
idx = torch.tensor([prompt_ids], device=device)
# [1] une passe complète sur le prompt, peuple le cache
logits, past_kvs = model(idx[:, : cfg.block_size])
pieces = []
for _ in range(MAX_NEW_TOKENS):
next_id = sample_next(logits[:, -1, :])
# [2] chaque pas suivant ne traite QU'UN nouveau token
logits, past_kvs = model(next_id, past_kvs=past_kvs)
pieces.append(enc.decode([int(next_id.item())]))
text = "".join(pieces)
for stop in STOP_STRINGS:
if stop in text:
return text.split(stop, 1)[0].strip()
return "".join(pieces).strip()Empiriquement sur ton modèle 14M en CPU : 100 passent de ~30s à ~3s. Plus le prompt est long, plus le gain est grand.
4. Pourquoi le modèle ne ressemble toujours pas à ChatGPT
Trois ingrédients rendent un modèle assistant-like :
- Capacité de base — assez de et de pré-entraînement. Ton modèle : 14M, ~272k . Très loin de la frontière.
- Données d’instruction — exemples en forme
User: ... Assistant: .... Ton modèle : 30-100 exemples du ch.16. Loin des 13k+ d’InstructGPT. - Preference tuning — optimisation vers des réponses utiles. Ton modèle : zéro. Hors scope du livre.
Tu as ajouté (2) au chapitre 17 — c’est ce qui fait que le chat suit la forme. (1) demande une échelle qu’on n’a pas. (3) vit dans un workflow séparé qu’on n’a pas construit.
À ta taille, attends-toi à des réponses correctes en forme mais pauvres en contenu. Le bon move : viser un domaine étroit où ton petit modèle peut être suffisamment bon.
5. Aller plus loin
- Plus de données SFT dans le domaine que tu cibles. 30 exemples enseignent la forme ; 300 enseignent le vocabulaire ; 3000 enseignent les patterns.
- LoRA par-dessus le SFT (chapitre 18) si tu veux plusieurs adaptateurs domaine sur un même modèle de base.
- Un base model plus gros. Remplace les poids de
llm/model.pypar GPT-2 small, puis SFT sur tes données. Même boucle, même chat — seulement les poids et le temps changent.
6. Ce que tu as maintenant
Tu n’as pas ChatGPT. Tu as la plus petite version honnête de la même forme produit :
- un de base + SFT
- un
- une fenêtre de contexte
- une politique de
- un chat template (System / User / Assistant)
- un historique de tours
- une boucle de génération avec KV cache
- un REPL terminal
C’est le pont entre « je comprends comment les fonctionnent » et « je peux livrer un petit prototype ».
Recap
- Le chat est de la structurée. Le modèle prédit toujours le prochain ; le
template indique le tour. - Réutilise exactement le chat template du chapitre 17 à l’. Toute dérive entre train-time et inference-time coûte la majorité du gain SFT. -
scripts/chat.pychargemodel_sft.pt, garde l’historique, un tour, boucle. - Les stop strings empêchent le modèle de dépasser un tour assistant ; le SFT lui a appris\ncomme fin naturelle. - Le KV cache fait passer la de 30s à 3s pour 100 .llm/model.pyaccepte déjàpast_kvsdepuis le chapitre 12 ; ce chapitre utilise simplement cet argument dans la boucle de génération. - Le projet local a maintenant un endpoint utilisable :python -m scripts.chat.
Pour aller plus loin
Prochaine étape : livrer quelque chose d’utile — le capstone. Choisis un domaine étroit, génère ~150 exemples SFT, fine-tune GPT-2 small, évalue côte à côte, et termine le livre avec un assistant spécialisé qui marche au lieu d’un assistant générique qui parle mal.