The loss curve

Chapter 20 · 12 min

Talk to your model

Talk to your model. A minimal chat loop with a KV cache — and the difference between cached and uncached generation.

You have the pieces: a tokenizer, a model, a base checkpoint, the SFT checkpoint from chapter 17, a sampler, LoRA, quantization. The missing piece is not more neural-network theory. It is an interface.

Modern chat products are still autoregressive language models underneath. The trick is to format the prompt as a conversation, generate the assistant turn, append that turn to the history, and repeat — using exactly the same chat template the model was fine-tuned on.

This chapter adds that wrapper, plus the one optimization that makes long generations bearable: a KV cache.

1. Chat is prompt formatting

A plain generator receives text like this:

ROMEO:

A chat loop gives the model more structure:

System: You are a small language model trained in The loss curve.
User: explain what a token is
Assistant:

The model still predicts the next token. The only difference is that the prefix now tells it which speaker should come next. That means the quality depends on your data. If the model was trained mostly on Shakespeare, it will not suddenly become a helpful assistant. But the interface is the same interface used by useful chat systems: prompt, sample, append, repeat.

2. Add the terminal loop

Create scripts/chat.py:

"""scripts/chat.py — talk to the checkpoint from your terminal."""
from __future__ import annotations
 
import torch
import torch.nn.functional as F
import tiktoken
 
from llm.model import GPT, GPTConfig
 
 
SYSTEM_PROMPT = "You answer questions briefly."
MAX_NEW_TOKENS = 160
TEMPERATURE = 0.8
TOP_K = 50
TOP_P = 0.9
# the SFT examples in ch. 17 end each assistant turn with "\n", so that's the
# primary stop. The "\nUser:" / "\nSystem:" fallbacks catch the model if it
# starts a new turn on its own.
STOP_STRINGS = ["\nUser:", "\nSystem:", "\n"]
 
 
# [1]
def pick_device() -> str:
    if torch.backends.mps.is_available():
        return "mps"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
 
 
# [2]
def load_model(device: str) -> GPT:
    cfg = GPTConfig()
    model = GPT(cfg).to(device)
    state = torch.load("checkpoints/model_sft.pt", map_location=device)
    model.load_state_dict(state)
    model.eval()
    return model
 
 
# [3]
def render_prompt(turns: list[tuple[str, str]]) -> str:
    lines = [f"System: {SYSTEM_PROMPT}"]
    for role, text in turns:
        lines.append(f"{role}: {text.strip()}")
    lines.append("Assistant:")
    return "\n".join(lines)
 
 
# [4]
def sample_next(logits: torch.Tensor) -> torch.Tensor:
    logits = logits / TEMPERATURE
    probs = F.softmax(logits, dim=-1)
 
    if TOP_K is not None:
        top_values, _ = probs.topk(TOP_K)
        probs[probs < top_values[..., -1, None]] = 0
        probs = probs / probs.sum(dim=-1, keepdim=True)
 
    if TOP_P < 1.0:
        sorted_probs, sorted_idx = probs.sort(descending=True, dim=-1)
        cumulative = sorted_probs.cumsum(dim=-1)
        mask = cumulative > TOP_P
        mask[..., 0] = False
        sorted_probs[mask] = 0
        probs = torch.zeros_like(probs).scatter_(-1, sorted_idx, sorted_probs)
        probs = probs / probs.sum(dim=-1, keepdim=True)
 
    return torch.multinomial(probs, num_samples=1)
 
 
# [5]
@torch.no_grad()
def generate_reply(model: GPT, enc, prompt: str, device: str) -> str:
    cfg = model.cfg
    idx = torch.tensor([enc.encode_ordinary(prompt)], device=device)
    pieces: list[str] = []
 
    for _ in range(MAX_NEW_TOKENS):
        idx_cond = idx if idx.size(1) <= cfg.block_size else idx[:, -cfg.block_size :]
        logits, _ = model(idx_cond)
        next_id = sample_next(logits[:, -1, :])
        idx = torch.cat([idx, next_id], dim=1)
 
        pieces.append(enc.decode([int(next_id.item())]))
        text = "".join(pieces)
        for stop in STOP_STRINGS:
            if stop in text:
                return text.split(stop, 1)[0].strip()
 
    return "".join(pieces).strip()
 
 
# [6]
def main() -> None:
    device = pick_device()
    enc = tiktoken.get_encoding("gpt2")
    model = load_model(device)
    turns: list[tuple[str, str]] = []
 
    print("The loss curve chat. Type /quit to stop.\n")
    while True:
        user_text = input("You: ").strip()
        if user_text in {"/quit", "/exit"}:
            break
        if not user_text:
            continue
 
        turns.append(("User", user_text))
        prompt = render_prompt(turns)
        reply = generate_reply(model, enc, prompt, device)
        print(f"Model: {reply}\n")
        turns.append(("Assistant", reply))
 
 
if __name__ == "__main__":
    main()

Read it as a thin shell around chapter 14:

  • [1] pick_device uses the best local accelerator available, then falls back to CPU.
  • [2] load_model restores the SFT checkpoint from checkpoints/model_sft.pt. If you skipped chapter 17, point this at model.pt instead and prepare for less assistant-shaped output.
  • [3] render_prompt turns the conversation history into one text prefix.
  • [4] sample_next is the same temperature + top-K + top-P pipeline from chapter 14.
  • [5] generate_reply crops to the model's context window, samples one token at a time, and stops if the model starts a new speaker turn.
  • [6] main is the terminal loop: read user input, append it, generate, print, append the reply.
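To see what the [4] pipeline actually keeps, here is a small self-contained trace of the same top-K / top-P filtering on a made-up five-token distribution (filter_probs is an illustrative name, values chosen for the example; K=3, P=0.9 as in the script's constants):

```python
import torch

def filter_probs(probs: torch.Tensor, top_k: int, top_p: float) -> torch.Tensor:
    """Same filtering as sample_next, minus the final multinomial draw."""
    # top-K: zero everything below the K-th largest probability, renormalize
    top_values, _ = probs.topk(top_k)
    probs = probs.clone()
    probs[probs < top_values[..., -1, None]] = 0
    probs = probs / probs.sum(dim=-1, keepdim=True)

    # top-P: keep the shortest prefix of the sorted distribution with mass <= P
    sorted_probs, sorted_idx = probs.sort(descending=True, dim=-1)
    cumulative = sorted_probs.cumsum(dim=-1)
    mask = cumulative > top_p
    mask[..., 0] = False  # always keep the single most likely token
    sorted_probs[mask] = 0
    probs = torch.zeros_like(probs).scatter_(-1, sorted_idx, sorted_probs)
    return probs / probs.sum(dim=-1, keepdim=True)

probs = torch.tensor([0.50, 0.20, 0.15, 0.10, 0.05])
filtered = filter_probs(probs, top_k=3, top_p=0.9)
print(filtered)  # only the two most likely tokens survive: ~[0.714, 0.286, 0, 0, 0]
```

Top-K first drops the two rarest tokens; after renormalization the cumulative mass crosses 0.9 at the second token, so top-P drops the third as well.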

Then run it:

python -m scripts.chat

You should get a prompt like this:

The loss curve chat. Type /quit to stop.
 
You: what is a token?
Model: ...

The answer may be a little better than chapter 14's — at minimum it should attempt a brief answer rather than continuing in Shakespeare. What matters is that the full path now exists: local data → pretraining → SFT → checkpoint → sampler → chat interface.

You will also notice it is slow. The naive loop above is correct but redundant in a way the next section fixes.

3. Make it fast: the KV cache

Run a 160-token generation. On CPU, expect 30–60 seconds — and it gets worse as the conversation history grows.

The cause: at every step of generation, the model processes the entire context (prompt plus every token generated so far). For every layer, that means recomputing Q, K, and V over the full sequence even though only the last token changed. With T total tokens and L layers, each step does O(L · T²) work; the full generation is O(L · T³). At our tiny scale you still feel it.

The standard fix is the KV cache: store the K and V tensors from each layer's attention after the prompt's first pass. At each subsequent step, compute K and V only for the new token, concatenate them onto the cache, and let the new query attend to the whole cached K, V. Per-step work drops to O(L · T).

Quantify the difference: over a 32-token prompt plus 100 generated tokens, the naive loop does roughly 80× more attention work than the cached one.
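Here is that count as a back-of-the-envelope model in Python, taking one unit of work per query-key comparison and ignoring constant factors and the MLP:

```python
P, N = 32, 100  # prompt tokens, generated tokens

# naive: step i re-feeds the whole sequence of length P+i, so attention
# does (P+i)^2 query-key comparisons per layer
naive = sum((P + i) ** 2 for i in range(N))

# cached: one P^2 prefill pass, then each step is a single query
# attending to the P+i cached keys
cached = P ** 2 + sum(P + i for i in range(1, N))

print(naive, cached, round(naive / cached, 1))  # 747550 9142 81.8
```

The ratio grows with both prompt length and generation length, which is why the cache matters more as the conversation history accumulates.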

The architectural support is already there

The llm/model.py from chapter 12 was written with the KV cache in mind. Look at CausalSelfAttention.forward, Block.forward, and GPT.forward: each takes an optional past_kv / past_kvs argument, and the attention module conditionally concatenates new keys/values onto the cache and skips the causal mask when the new query is a single token. When you pass past_kvs=None (the default), behavior is unchanged from chapter 12. Enabling the cache is a pure inference-time speedup; no retroactive patch needed.

Three subtleties to keep in mind when you use the cache:

  • Position embeddings are offset by the cached length. The single new token is at position cached_length, not at 0. GPT.forward reads past_kvs[0][0].size(2) and shifts the position arange accordingly. Without the offset, every generated token would get position 0 and the model collapses.
  • The causal mask is only applied on the prefix pass. During cached generation, T = 1 and the new query is supposed to see the entire cache; no mask is needed. CausalSelfAttention checks past_kv is None for this.
  • Training is unaffected — when past_kvs is None, the behavior is identical to chapter 12's default forward path. The cache is purely an optimization.
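The mask subtlety is easy to verify in isolation. The sketch below is not llm/model.py — attend is an illustrative stand-in for a single attention step — but it shows the invariant the cache relies on: a cached single-query pass with no mask reproduces the last row of a full causal pass.

```python
import torch
import torch.nn.functional as F

def attend(q, k_new, v_new, past_kv=None):
    """One attention step over (B, heads, T, head_dim) tensors.

    With past_kv=None this is the ordinary causal prefix pass; with a
    cache it concatenates the new K/V and skips the causal mask, since
    the single new query is allowed to see the entire cache.
    """
    if past_kv is not None:
        k = torch.cat([past_kv[0], k_new], dim=2)
        v = torch.cat([past_kv[1], v_new], dim=2)
    else:
        k, v = k_new, v_new
    y = F.scaled_dot_product_attention(q, k, v, is_causal=(past_kv is None))
    return y, (k, v)

torch.manual_seed(0)
B, H, T, D = 1, 2, 5, 8
q, k, v = (torch.randn(B, H, T, D) for _ in range(3))

full, _ = attend(q, k, v)                                 # uncached causal pass
_, cache = attend(q[:, :, :4], k[:, :, :4], v[:, :, :4])  # prefix pass fills the cache
step, _ = attend(q[:, :, 4:], k[:, :, 4:], v[:, :, 4:], cache)

# the new token's position embedding would be offset by cache[0].size(2) == 4
print(torch.allclose(full[:, :, 4:], step, atol=1e-5))  # True
```

If you flip is_causal to True on the cached step, or forget the position offset in the real model, the equivalence breaks — which is exactly the collapse the bullet list above warns about.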

The cached generation loop

Replace generate_reply with this version:

@torch.no_grad()
def generate_reply(model: GPT, enc, prompt: str, device: str) -> str:
    cfg = model.cfg
    prompt_ids = enc.encode_ordinary(prompt)
    idx = torch.tensor([prompt_ids], device=device)
 
    # [1] one full forward pass on the prompt populates the cache
    logits, past_kvs = model(idx[:, -cfg.block_size :])
 
    pieces: list[str] = []
    for _ in range(MAX_NEW_TOKENS):
        next_id = sample_next(logits[:, -1, :])
        # [2] every subsequent pass processes ONE new token
        logits, past_kvs = model(next_id, past_kvs=past_kvs)
        pieces.append(enc.decode([int(next_id.item())]))
        text = "".join(pieces)
        for stop in STOP_STRINGS:
            if stop in text:
                return text.split(stop, 1)[0].strip()
 
    return "".join(pieces).strip()

The structural shift: the model is called with the full prompt once, then with a single token per step. The cache carries the state.

Empirically, on a 14M-param model on CPU, a 100-token generation drops from ~30 s to ~3 s. The exact speedup depends on prompt length; longer prompts win more.
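If you want your own numbers rather than mine, a tiny best-of-N timing helper is enough (best_time is an illustrative name; the commented usage assumes you kept both generate_reply variants around):

```python
import time

def best_time(fn, repeats=3):
    """Return the best wall-clock time of fn() over several runs."""
    best = float("inf")
    for _ in range(repeats):
        t0 = time.perf_counter()
        fn()
        best = min(best, time.perf_counter() - t0)
    return best

# usage sketch: wrap each generate_reply variant in a zero-arg lambda
# slow = best_time(lambda: generate_reply_naive(model, enc, prompt, device))
# fast = best_time(lambda: generate_reply(model, enc, prompt, device))
# print(f"speedup: {slow / fast:.1f}x")
```

Taking the best of several runs rather than the mean filters out warm-up and background noise, which dominates at these short durations.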

4. Why the model still does not behave like ChatGPT

Three ingredients make a model feel assistant-like:

  1. Base capability — enough parameters and pretraining data to model language well. Your model: 14M params, ~272k training tokens. Far below frontier.
  2. Instruction data — many examples shaped like User: ... Assistant: .... Your model: 30–100 examples from chapter 17. Far below the 13k+ demonstrations that even InstructGPT used.
  3. Preference tuning — extra optimization toward helpful, harmless, concise answers. Your model: none. Beyond the scope of this book.

You added (2) in chapter 17 — that is what makes the chat actually follow the format. (1) is the part that demands scale you do not have. (3) is the part that lives in a separate workflow we did not build.

The takeaway: at your scale, expect format-correct but content-thin responses. The right move is to lean into a narrow domain where the small model can be good enough.

5. Going deeper

If the answers are not useful enough, the next levers are:

  • More SFT data in the domain you actually care about. 30 examples teach the model the format; 300 teach it the domain vocabulary; 3,000 teach it the patterns. Diminishing returns kick in past a few thousand for a narrow task.
  • LoRA on top of the SFT checkpoint (chapter 18) when you want several domain adapters but one base model. Cheap to train, cheap to ship.
  • A bigger base model. Swap your llm/model.py weights for GPT-2 small or another open base, then SFT on your data. Same training script, same chat loop — the wrapper does not change. Just the weights and the time-to-loss.

6. What you now have

You do not have ChatGPT. You have the smallest honest version of the same product shape:

  • a base model + SFT checkpoint
  • a tokenizer
  • a context window
  • a sampling policy (temperature, top-k, top-p)
  • a chat template (System / User / Assistant)
  • a turn history
  • a KV-cached generation loop
  • a terminal REPL

That is the bridge between "I understand how LLMs work" and "I can ship a small LLM-powered prototype."

Recap

  • Chat is next-token prediction with structure. The model still predicts the next token; the prompt template tells it whose turn it is.
  • Reuse the SFT chat template from chapter 17 exactly. Drift between train-time and inference-time format costs most of the SFT gains.
  • scripts/chat.py loads model_sft.pt, keeps conversation history, generates an assistant turn, and loops.
  • Stop strings keep the model from continuing past one assistant turn — the SFT model learned \n as the natural end.
  • The KV cache is the difference between 10× faster generation and unusable latency. llm/model.py has accepted past_kvs since chapter 12; this chapter just uses that argument in the generate loop so only the new token goes through the model per step.
  • The local project now has a usable endpoint: python -m scripts.chat.

Going further

  • Karpathy's nanoGPT — the codebase this course stays closest to in spirit.
  • Hugging Face chat templates — production models use explicit conversation templates for the same reason this chapter does.
  • OpenAI model spec — useful reading on how product-level assistant behavior is specified beyond raw next-token prediction.

Next up: ship a useful one — the capstone. Pick a narrow domain, generate ~150 SFT examples, fine-tune GPT-2 small, evaluate side by side, and end the book with one working specialized assistant instead of one generic one.