Skip to content
The loss curve

Chapitre 14 · 12 min

Génération et sampling

Transforme le checkpoint entraîné en générateur local contrôlable avec température, top-k et top-p sampling.

Au chapitre 13, tu as entraîné un modèle qui, à partir d’un préfixe de tokens, produit une distribution de probabilité sur le prochain token. Le script scripts/sample.py choisissait avec torch.multinomial(probs, num_samples=1). Ça marche, mais ce n’est qu’un choix possible, et ce choix change fortement ce que le modèle dit. Ce chapitre explore les trois knobs exposés par presque tous les LLM modernes : température, top-K et top-P. Ils ne demandent aucun réentraînement et changent beaucoup la qualité des sorties.

Trois knobs, un pipeline

Le pipeline standard prend des logits bruts et produit la distribution finale :

  1. Température : diviser chaque logit par T. T < 1 rend plus déterministe ; T > 1 rend plus aléatoire.2. Softmax : transformer en probabilités.
  2. Top-K : garder seulement les K tokens les plus probables.
  3. Top-P : garder le plus petit préfixe trié dont la probabilité cumulée atteint P.
  4. Sample : tirer dans la distribution restante.

Code · JavaScript

Joue avec les sliders :

  • Température → 0 : presque greedy.
  • Température → 2 : distribution plus plate, sorties plus surprenantes.
  • Top-K = 1 : greedy.
  • Top-P = 1 : aucun filtre top-P.
  • Top-P bas : seulement la tête de distribution.

La fonction generate complète en Python

Sauvegarde scripts/generate.py :

"""scripts/generate.py — autoregressive generation with temperature + top-K + top-P."""
import torch
import torch.nn.functional as F
import tiktoken
 
from llm.model import GPT, GPTConfig
 
# config
# [1]
prompt = "ROMEO:"
max_new_tokens = 200
temperature = 0.8
top_k = 50
top_p = 0.9
 
# [2]
device = "mps" if torch.backends.mps.is_available() else (
    "cuda" if torch.cuda.is_available() else "cpu"
)
cfg = GPTConfig()
model = GPT(cfg).to(device)
model.load_state_dict(torch.load("checkpoints/model.pt", map_location=device))
model.eval()
 
# [3]
enc = tiktoken.get_encoding("gpt2")
idx = torch.tensor([enc.encode_ordinary(prompt)], device=device)
 
@torch.no_grad()
def sample_next(logits):
    # [4]
    logits = logits / temperature
    probs = F.softmax(logits, dim=-1)
    if top_k is not None:
        # [5]
        top_vals, _ = probs.topk(top_k)
        probs[probs < top_vals[..., -1, None]] = 0
        probs = probs / probs.sum(dim=-1, keepdim=True)
    if top_p < 1.0:
        # [6]
        sorted_probs, sorted_idx = probs.sort(descending=True, dim=-1)
        cum = sorted_probs.cumsum(dim=-1)
        # find cutoff
        mask = cum > top_p
        mask[..., 0] = False  # always keep the most probable
        sorted_probs[mask] = 0
        # scatter back
        probs = torch.zeros_like(probs).scatter_(-1, sorted_idx, sorted_probs)
        probs = probs / probs.sum(dim=-1, keepdim=True)
    return torch.multinomial(probs, num_samples=1)
 
# [7]
for _ in range(max_new_tokens):
    idx_cond = idx if idx.size(1) <= cfg.block_size else idx[:, -cfg.block_size:]
    logits, _ = model(idx_cond)
    next_id = sample_next(logits[:, -1, :])
    idx = torch.cat([idx, next_id], dim=1)
 
print(enc.decode(idx[0].tolist()))

Le cœur est sample_next :

  • [1] contient les knobs de génération.
  • [2] charge le modèle.
  • [3] encode le prompt.
  • [4] applique la température avant softmax.
  • [5] filtre top-K.
  • [6] filtre top-P puis remet les probabilités à leurs ids d’origine.
  • [7] boucle : crop du contexte, modèle, sample, append.

Lance :

python -m scripts.generate
python -m scripts.generate
python -m scripts.generate

Modifie les paramètres, relance, compare. temperature = 0.1 donne souvent du texte répétitif ; temperature = 1.5 et top_p = 0.95 donnent plus de variété et plus d’incohérence.

Et beam search ?

Pour la génération libre : évite. Beam search produit des séquences très probables, mais souvent moins bonnes subjectivement. Le langage humain contient des choix créatifs à faible probabilité que beam search coupe trop tôt. Pour traduction ou tâches à réponse stricte, beam search peut rester utile ; pour texte libre, on sample.

Recap

  • La température scale les logits avant softmax. - Top-K garde les K tokens les plus probables. - Top-P garde le plus petit préfixe de probabilité cumulée P. - Les defaults des modèles chat sont souvent T = 0.7-1.0, K = 50, P = 0.9-0.95. - Beam search a sa place, mais pas dans la génération libre. - Ton projet local a maintenant scripts/generate.py.

Pour aller plus loin

Prochaine étape : charger les vrais poids — ton modèle a la forme d’un LLM mais ses poids sont minuscules et empoisonnés au Shakespeare. Place aux poids GPT-2 d’OpenAI dans la même architecture, et regarde ce que ton code peut vraiment faire.