Chapter 20 · 12 min
Talk to your model
A minimal chat loop with a KV cache — and the difference between cached and uncached generation.
You have the pieces: a tokenizer, a model, a base checkpoint, SFT from chapter 17, a sampler, LoRA, quantization. The missing piece is not more neural-network theory. It is an interface.
Modern chat products are still autoregressive language models underneath. The trick is to format the prompt as a conversation, generate the assistant turn, append that turn to the history, and repeat — using exactly the same chat template the model was fine-tuned on.
This chapter adds that wrapper, plus the one optimization that makes long generations bearable: a KV cache.
1. Chat is prompt formatting
A plain generator receives text like this:
ROMEO:

A chat loop gives the model more structure:
System: You are a small language model trained in The loss curve.
User: explain what a token is
Assistant:

The model still predicts the next token. The only difference is that the prefix now tells it which speaker should come next. That means the quality depends on your data. If the model was trained mostly on Shakespeare, it will not suddenly become a helpful assistant. But the interface is the same interface used by useful chat systems: prompt, sample, append, repeat.
2. Add the terminal loop
Create scripts/chat.py:
"""scripts/chat.py — talk to the checkpoint from your terminal."""
from __future__ import annotations
import torch
import torch.nn.functional as F
import tiktoken
from llm.model import GPT, GPTConfig
SYSTEM_PROMPT = "You answer questions briefly."
MAX_NEW_TOKENS = 160
TEMPERATURE = 0.8
TOP_K = 50
TOP_P = 0.9
# the SFT examples in ch.16 end each assistant turn with "\n", so that's the
# primary stop. The "\nUser:" / "\nSystem:" fallbacks catch the model if it
# starts a new turn on its own.
STOP_STRINGS = ["\nUser:", "\nSystem:", "\n"]
# [1]
def pick_device() -> str:
    if torch.backends.mps.is_available():
        return "mps"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
# [2]
def load_model(device: str) -> GPT:
    cfg = GPTConfig()
    model = GPT(cfg).to(device)
    state = torch.load("checkpoints/model_sft.pt", map_location=device)
    model.load_state_dict(state)
    model.eval()
    return model
# [3]
def render_prompt(turns: list[tuple[str, str]]) -> str:
    lines = [f"System: {SYSTEM_PROMPT}"]
    for role, text in turns:
        lines.append(f"{role}: {text.strip()}")
    lines.append("Assistant:")
    return "\n".join(lines)
# [4]
def sample_next(logits: torch.Tensor) -> torch.Tensor:
    logits = logits / TEMPERATURE
    probs = F.softmax(logits, dim=-1)
    if TOP_K is not None:
        top_values, _ = probs.topk(TOP_K)
        probs[probs < top_values[..., -1, None]] = 0
        probs = probs / probs.sum(dim=-1, keepdim=True)
    if TOP_P < 1.0:
        sorted_probs, sorted_idx = probs.sort(descending=True, dim=-1)
        cumulative = sorted_probs.cumsum(dim=-1)
        mask = cumulative > TOP_P
        mask[..., 0] = False
        sorted_probs[mask] = 0
        probs = torch.zeros_like(probs).scatter_(-1, sorted_idx, sorted_probs)
        probs = probs / probs.sum(dim=-1, keepdim=True)
    return torch.multinomial(probs, num_samples=1)
# [5]
@torch.no_grad()
def generate_reply(model: GPT, enc, prompt: str, device: str) -> str:
    cfg = model.cfg
    idx = torch.tensor([enc.encode_ordinary(prompt)], device=device)
    pieces: list[str] = []
    for _ in range(MAX_NEW_TOKENS):
        idx_cond = idx if idx.size(1) <= cfg.block_size else idx[:, -cfg.block_size :]
        logits, _ = model(idx_cond)
        next_id = sample_next(logits[:, -1, :])
        idx = torch.cat([idx, next_id], dim=1)
        pieces.append(enc.decode([int(next_id.item())]))
    text = "".join(pieces)
    for stop in STOP_STRINGS:
        if stop in text:
            return text.split(stop, 1)[0].strip()
    return "".join(pieces).strip()
# [6]
def main() -> None:
    device = pick_device()
    enc = tiktoken.get_encoding("gpt2")
    model = load_model(device)
    turns: list[tuple[str, str]] = []
    print("The loss curve chat. Type /quit to stop.\n")
    while True:
        user_text = input("You: ").strip()
        if user_text in {"/quit", "/exit"}:
            break
        if not user_text:
            continue
        turns.append(("User", user_text))
        prompt = render_prompt(turns)
        reply = generate_reply(model, enc, prompt, device)
        print(f"Model: {reply}\n")
        turns.append(("Assistant", reply))

if __name__ == "__main__":
    main()

Read it as a thin shell around chapter 14:
- [1] pick_device uses the best local accelerator available, then falls back to CPU.
- [2] load_model restores the SFT checkpoint from checkpoints/model_sft.pt. If you skipped chapter 17, point this at model.pt instead and prepare for less assistant-shaped output.
- [3] render_prompt turns the conversation history into one text prefix.
- [4] sample_next is the same temperature + top-K + top-P pipeline from chapter 14.
- [5] generate_reply crops to the model's context window, samples one token at a time, and stops if the model starts a new speaker turn.
- [6] main is the terminal loop: read user input, append it, generate, print, append the reply.
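To see exactly what the model receives, you can call render_prompt by hand. A quick check from a Python REPL (the import path and the two-turn history here are made up for illustration):

from scripts.chat import render_prompt

history = [
    ("User", "what is a token?"),
    ("Assistant", "A small chunk of text the model reads and predicts."),
    ("User", "and what is a tokenizer?"),
]
print(render_prompt(history))
# System: You answer questions briefly.
# User: what is a token?
# Assistant: A small chunk of text the model reads and predicts.
# User: and what is a tokenizer?
# Assistant:

This is the append-and-repeat loop from section 1 made visible: each turn just extends the prefix, and the trailing "Assistant:" tells the model whose turn it is.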
Then run it:
python -m scripts.chat

You should get a prompt like this:
The loss curve chat. Type /quit to stop.
You: what is a token?
Model: ...

The answer may be a little better than chapter 14's — at minimum it should attempt a brief answer rather than continuing in Shakespeare. What matters is that the full path now exists, from local data through the tokenizer, pretraining, and SFT to the sampler and this chat interface.
You will also notice it is slow. The naive loop above is correct but redundant in a way the next section fixes.
3. Make it fast: the KV cache
Run a 160-token generation. On CPU, expect 30–60 seconds. The pain gets worse as the conversation history grows.
The cause: at every step of generation, the model processes the entire context (the prompt plus every token generated so far). For every layer, that means recomputing Q, K, and V over the full sequence even though only the last token changed. With T total tokens and L layers, each step does O(L · T²) work; the full generation is O(L · T³). At our tiny scale you still feel it.
The standard fix is the KV cache: store the K and V tensors from each layer's attention after the prompt's first pass. At each subsequent step, compute K and V only for the new token, concatenate them onto the cache, and let the new query attend to the whole cached K and V. Per-step work drops to O(L · T).
Quantify the difference by counting the attention work for a 32-token prompt plus 100 generated tokens, naive vs. cached.
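Here is a back-of-the-envelope count in Python. It charges one unit of work per query-key pair per layer; the 6-layer figure is an assumption, so substitute your own GPTConfig values:

PROMPT_LEN = 32
NEW_TOKENS = 100
N_LAYERS = 6                           # assumption: set this to your GPTConfig.n_layer

naive = cached = 0
for i in range(NEW_TOKENS):
    seq_len = PROMPT_LEN + i           # context length when sampling token i
    naive += N_LAYERS * seq_len ** 2   # naive: recompute scores for every position
    cached += N_LAYERS * seq_len       # cached: one new query attends to the cache

print(f"naive : {naive:,} pair computations")
print(f"cached: {cached:,} pair computations")
print(f"ratio : {naive / cached:.0f}x")

With these numbers the naive loop does roughly 90× more attention work than the cached one, and the gap widens as the prompt grows.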
The architectural support is already there
The llm/model.py from chapter 12 was written with the KV cache in mind. Look at CausalSelfAttention.forward, Block.forward, and GPT.forward: each takes an optional past_kv / past_kvs argument, and the attention module conditionally concatenates new keys/values onto the cache and skips the causal mask when the new query is a single token. When you pass past_kvs=None (the default in training), behavior is unchanged from chapter 10. Cache enabled = inference-time speedup. No retroactive patch needed.
Three subtleties to keep in mind when you use the cache:
- Position embeddings offset by the cached length. The single new token is at position cached_length, not at 0. GPT.forward reads past_kvs[0][0].size(2) and shifts the position arange accordingly. Without the offset, every generated token would get position 0 and the model collapses.
- The causal mask is only applied on the prefix pass. During cached generation, T = 1 and the new query is supposed to see the entire cache; no mask is needed. CausalSelfAttention checks past_kv is None for this.
- Training is unaffected — when past_kvs is None, the behavior is identical to chapter 12's default forward path. The cache is purely an optimization.
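To make those three points concrete, here is a minimal sketch of the cached-attention pattern. It is not the book's exact CausalSelfAttention code; it assumes q, k, v tensors shaped (B, n_head, T, head_dim) and a plain (k, v) tuple as the cache:

import math
import torch

def cached_attention(q, k, v, past_kv=None):
    """One attention call with an optional KV cache (illustrative only)."""
    if past_kv is not None:
        past_k, past_v = past_kv
        k = torch.cat([past_k, k], dim=2)   # grow the cache along the time axis
        v = torch.cat([past_v, v], dim=2)
    att = (q @ k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    if past_kv is None:
        # prefix pass over the full prompt: causal mask required
        T = q.size(2)
        causal = torch.tril(torch.ones(T, T, dtype=torch.bool, device=q.device))
        att = att.masked_fill(~causal, float("-inf"))
    # else: cached step with T = 1; the single new query may see the whole
    # cache, so no mask is applied
    y = torch.softmax(att, dim=-1) @ v
    return y, (k, v)

# The position-id offset (first subtlety), as GPT.forward would compute it:
# past_len = 0 if past_kvs is None else past_kvs[0][0].size(2)
# pos = torch.arange(past_len, past_len + T, device=idx.device)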
The cached generation loop
Replace generate_reply with this version:
@torch.no_grad()
def generate_reply(model: GPT, enc, prompt: str, device: str) -> str:
    cfg = model.cfg
    prompt_ids = enc.encode_ordinary(prompt)
    idx = torch.tensor([prompt_ids], device=device)
    # [1] one full forward pass on the prompt populates the cache;
    #     keep the most recent block_size tokens if the prompt is longer
    logits, past_kvs = model(idx[:, -cfg.block_size :])
    pieces: list[str] = []
    for _ in range(MAX_NEW_TOKENS):
        next_id = sample_next(logits[:, -1, :])
        # [2] every subsequent pass processes ONE new token
        logits, past_kvs = model(next_id, past_kvs=past_kvs)
        pieces.append(enc.decode([int(next_id.item())]))
    text = "".join(pieces)
    for stop in STOP_STRINGS:
        if stop in text:
            return text.split(stop, 1)[0].strip()
    return "".join(pieces).strip()

The structural shift: the model is called with the full prompt once, then with a single token per step. The cache carries the state.
Empirically, on a 14M-parameter model on CPU, a 100-token generation drops from ~30 s to ~3 s. The exact speedup depends on prompt length; longer prompts win more.
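If you want to measure it on your own machine, a small timing harness is enough. The names generate_reply_naive and generate_reply_cached are hypothetical: keep both versions around under whatever names you like and run this inside main() once the model is loaded:

import time

def timed(fn, *args):
    """Return (reply, seconds) for one call to a generate function."""
    start = time.perf_counter()
    reply = fn(*args)
    return reply, time.perf_counter() - start

# Sampling is stochastic, so the two runs will not produce identical text,
# but the wall-clock difference is what matters here.
prompt = render_prompt([("User", "what is a token?")])
_, naive_s = timed(generate_reply_naive, model, enc, prompt, device)
_, cached_s = timed(generate_reply_cached, model, enc, prompt, device)
print(f"naive {naive_s:.1f}s | cached {cached_s:.1f}s | speedup {naive_s / cached_s:.1f}x")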
4. Why the model still does not behave like ChatGPT
Three ingredients make a model feel assistant-like:
- Base capability — enough parameters and pretraining data to model language well. Your model: 14M params, ~272k tokens. Far below frontier.
- Instruction data — many examples shaped like User: ... Assistant: .... Your model: 30–100 examples from chapter 17. Far below the 13k+ that even InstructGPT used.
- Preference tuning — extra optimization toward helpful, harmless, concise answers. Your model: none. Beyond the scope of this book.
You added (2) in chapter 17 — that is what makes the chat actually follow the format. (1) is the part that demands scale you do not have. (3) is the part that lives in a separate workflow we did not build.
The takeaway: at your scale, expect format-correct but content-thin responses. The right move is to lean into a narrow domain where the small model can be good enough.
5. Going deeper
If the answers are not useful enough, the next levers are:
- More SFT data in the domain you actually care about. 30 examples teach the model the format; 300 teach it the domain vocabulary; 3,000 teach it the patterns. Diminishing returns kick in past a few thousand for a narrow task.
- LoRA on top of the SFT checkpoint (chapter 18) when you want several domain adapters but one base model. Cheap to train, cheap to ship.
- A bigger base model. Swap your llm/model.py weights for GPT-2 small or another open base, then SFT on your data. Same training script, same chat loop — the wrapper does not change. Just the weights and the time-to-loss.
6. What you now have
You do not have ChatGPT. You have the smallest honest version of the same product shape:
- a base model + an SFT checkpoint
- a tokenizer
- a context window
- a sampling policy (temperature, top-k, top-p)
- a chat template (System / User / Assistant)
- a turn history
- a KV-cached generation loop
- a terminal REPL
That is the bridge between "I understand how language models work" and "I can ship a small LLM-powered prototype."
Recap
- Chat is next-token prediction with structure. The model still predicts the next token; the prompt template tells it whose turn it is.
- Reuse the SFT chat template from chapter 17 at inference exactly. Drift between train-time and inference-time format costs most of the SFT gains.
- scripts/chat.py loads model_sft.pt, keeps conversation history, samples an assistant turn, and loops.
- Stop strings keep the model from continuing past one assistant turn — the SFT model learned \n as the natural end.
- The KV cache is the difference between 10× faster generation and unusable latency. llm/model.py already accepts past_kvs since chapter 12; this chapter just uses that argument in the generate loop so only the new token goes through the model per step.
- The local project now has a usable endpoint: python -m scripts.chat.
Going further
- Karpathy's nanoGPT — the codebase this course stays closest to in spirit.
- Hugging Face chat templates — production models use explicit conversation templates for the same reason this chapter does.
- OpenAI model spec — useful reading on how product-level assistant behavior is specified beyond raw next-token prediction.
Next up: ship a useful one — the capstone. Pick a narrow domain, generate ~150 SFT examples, fine-tune GPT-2 small, evaluate side by side, and end the book with one working specialized assistant instead of one generic one.