Chapitre 14 · 12 min
Génération et sampling
Transforme le checkpoint entraîné en générateur local contrôlable avec température, top-k et top-p sampling.
Au chapitre 13, tu as entraîné un modèle qui, à partir d’un préfixe de tokens, produit une distribution de probabilité sur le prochain token. Le script scripts/sample.py choisissait avec torch.multinomial(probs, num_samples=1). Ça marche, mais ce n’est qu’un choix possible, et ce choix change fortement ce que le modèle dit.
Ce chapitre explore les trois knobs exposés par presque tous les LLM modernes : température, top-K et top-P. Ils ne demandent aucun réentraînement et changent beaucoup la qualité des sorties.
Trois knobs, un pipeline
Le pipeline standard prend des logits bruts et produit la distribution finale :
- Température : diviser chaque logit par
T.T < 1rend plus déterministe ;T > 1rend plus aléatoire.2. Softmax : transformer en probabilités. - Top-K : garder seulement les K tokens les plus probables.
- Top-P : garder le plus petit préfixe trié dont la probabilité cumulée atteint P.
- Sample : tirer dans la distribution restante.
Code · JavaScript
Joue avec les sliders :
- Température → 0 : presque greedy.
- Température → 2 : distribution plus plate, sorties plus surprenantes.
- Top-K = 1 : greedy.
- Top-P = 1 : aucun filtre top-P.
- Top-P bas : seulement la tête de distribution.
La fonction generate complète en Python
Sauvegarde scripts/generate.py :
"""scripts/generate.py — autoregressive generation with temperature + top-K + top-P."""
import torch
import torch.nn.functional as F
import tiktoken
from llm.model import GPT, GPTConfig
# config
# [1]
prompt = "ROMEO:"
max_new_tokens = 200
temperature = 0.8
top_k = 50
top_p = 0.9
# [2]
device = "mps" if torch.backends.mps.is_available() else (
"cuda" if torch.cuda.is_available() else "cpu"
)
cfg = GPTConfig()
model = GPT(cfg).to(device)
model.load_state_dict(torch.load("checkpoints/model.pt", map_location=device))
model.eval()
# [3]
enc = tiktoken.get_encoding("gpt2")
idx = torch.tensor([enc.encode_ordinary(prompt)], device=device)
@torch.no_grad()
def sample_next(logits):
# [4]
logits = logits / temperature
probs = F.softmax(logits, dim=-1)
if top_k is not None:
# [5]
top_vals, _ = probs.topk(top_k)
probs[probs < top_vals[..., -1, None]] = 0
probs = probs / probs.sum(dim=-1, keepdim=True)
if top_p < 1.0:
# [6]
sorted_probs, sorted_idx = probs.sort(descending=True, dim=-1)
cum = sorted_probs.cumsum(dim=-1)
# find cutoff
mask = cum > top_p
mask[..., 0] = False # always keep the most probable
sorted_probs[mask] = 0
# scatter back
probs = torch.zeros_like(probs).scatter_(-1, sorted_idx, sorted_probs)
probs = probs / probs.sum(dim=-1, keepdim=True)
return torch.multinomial(probs, num_samples=1)
# [7]
for _ in range(max_new_tokens):
idx_cond = idx if idx.size(1) <= cfg.block_size else idx[:, -cfg.block_size:]
logits, _ = model(idx_cond)
next_id = sample_next(logits[:, -1, :])
idx = torch.cat([idx, next_id], dim=1)
print(enc.decode(idx[0].tolist()))Le cœur est sample_next :
- [1] contient les knobs de génération.
- [2] charge le modèle.
- [3] encode le prompt.
- [4] applique la température avant softmax.
- [5] filtre top-K.
- [6] filtre top-P puis remet les probabilités à leurs ids d’origine.
- [7] boucle : crop du contexte, modèle, sample, append.
Lance :
python -m scripts.generatepython -m scripts.generatepython -m scripts.generateModifie les paramètres, relance, compare. temperature = 0.1 donne souvent du texte répétitif ; temperature = 1.5 et top_p = 0.95 donnent plus de variété et plus d’incohérence.
Et beam search ?
Pour la génération libre : évite. Beam search produit des séquences très probables, mais souvent moins bonnes subjectivement. Le langage humain contient des choix créatifs à faible probabilité que beam search coupe trop tôt. Pour traduction ou tâches à réponse stricte, beam search peut rester utile ; pour texte libre, on sample.
Recap
- La température scale les logits avant softmax. - Top-K garde les K tokens les plus
probables. - Top-P garde le plus petit préfixe de probabilité cumulée P. - Les defaults
des modèles chat sont souvent
T = 0.7-1.0,K = 50,P = 0.9-0.95. - Beam search a sa place, mais pas dans la génération libre. - Ton projet local a maintenantscripts/generate.py.
Pour aller plus loin
- Holtzman et al., “The Curious Case of Neural Text Degeneration”.
- La doc génération de Hugging Face.
- Step by Token, chapitre 7.
Prochaine étape : charger les vrais poids — ton modèle a la forme d’un LLM mais ses poids sont minuscules et empoisonnés au Shakespeare. Place aux poids GPT-2 d’OpenAI dans la même architecture, et regarde ce que ton code peut vraiment faire.