The loss curve

Chapter 19 · 10 min

Simple quantization

Quantize your model to INT8: a quarter of the memory, almost the same outputs. See where it breaks and what KV cache costs.

A trained transformer is a collection of weight matrices full of float32 numbers. A 1B-parameter model in float32 is 4 GB. The same model in int8 is 1 GB. The same model in int4 is 500 MB. The output quality difference between float32 and int8 is small — often imperceptible — and the inference speedup is real (modern CPUs and GPUs have dedicated integer-arithmetic paths).
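The sizes above are just parameter count times bits per parameter. Here is a quick check in Python. It counts the weights only, with no per-tensor scales or other metadata, and uses decimal gigabytes (10^9 bytes):

```python
def model_size_bytes(n_params: int, bits_per_param: int) -> float:
    """Storage for the weights alone: parameters x bits, converted to bytes."""
    return n_params * bits_per_param / 8


n_params = 1_000_000_000  # a 1B-parameter model
for name, bits in [("float32", 32), ("int8", 8), ("int4", 4)]:
    gb = model_size_bytes(n_params, bits) / 1e9  # decimal gigabytes
    print(f"{name:>7}: {gb:.1f} GB")
```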

This is quantization: a one-line idea, applied systematically across every weight matrix, that makes large models cheap enough to run on consumer hardware.

1. Symmetric uniform quantization

The simplest scheme: pick a scale s, store every weight as a small integer, multiply by the scale to recover an approximation:

$$q = \text{round}(w / s), \qquad \hat{w} = q \cdot s$$

The method trades model quality for size reduction and speedup.

For symmetric INT8 (the common case), s = max(|w|) / 127. Now every w maps to an integer in [-127, 127]. To use the weights at inference time, multiply each integer by s. The error is at most s/2 per weight — bounded by the precision of the quantization, independent of magnitude.
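Below is a minimal sketch of the round trip in plain Python. It assumes one scale for the whole list of weights and uses the general b-bit scale s = max(|w|) / (2^(b-1) - 1), which reduces to max(|w|) / 127 for INT8. The names `quantize` and `dequantize` are illustrative, not a library API, and the fake weights are random Gaussians.

```python
import random


def quantize(weights, bits=8):
    """Symmetric uniform quantization: returns integer codes and the scale s."""
    qmax = 2 ** (bits - 1) - 1                   # 127 for 8 bits
    max_abs = max(abs(w) for w in weights)
    s = max_abs / qmax if max_abs > 0 else 1.0   # avoid division by zero for all-zero input
    q = [round(w / s) for w in weights]          # every code lands in [-qmax, qmax]
    return q, s


def dequantize(q, s):
    """Recover the approximation w_hat = q * s."""
    return [qi * s for qi in q]


rng = random.Random(0)
w = [rng.gauss(0.0, 0.02) for _ in range(1000)]  # fake weights
q, s = quantize(w, bits=8)
w_hat = dequantize(q, s)
max_err = max(abs(a - b) for a, b in zip(w, w_hat))
print(f"scale={s:.2e}  max error={max_err:.2e}  bound s/2={s / 2:.2e}")
```

The printed maximum error should never exceed s/2, regardless of the weights' magnitudes. The round trip can only lose up to half a step.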

Write the round-trip function. The chapter applies it to a fake weight matrix and shows you the original vs reconstructed curves at different bit widths.


Slide the bit width. At 8 bits, the reconstructed curve is visually indistinguishable from the original. At 4 bits, it's recognizable but stair-stepped. At 2 bits, you've lost most of the shape. Two bits give only 4 codes, and the symmetric scheme uses just 3 of them (-1, 0, 1) across the entire range. That is usually too coarse for neural-network weights.
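The same sweep can also be run numerically. This sketch applies the symmetric scheme to a hypothetical smooth "weight curve": one period of a sine wave, sampled 200 times. It reports how many distinct codes each bit width actually uses and the worst reconstruction error. At 2 bits, s = max(|w|) / 1, so the scheme can only ever produce -1, 0 and 1.

```python
import math


def round_trip(weights, bits):
    """Symmetric quantize-dequantize; returns (reconstruction, distinct codes used, scale)."""
    qmax = 2 ** (bits - 1) - 1
    s = max(abs(w) for w in weights) / qmax
    q = [round(w / s) for w in weights]
    return [qi * s for qi in q], len(set(q)), s


# Hypothetical weight curve: one period of a sine wave, values in [-1, 1].
w = [math.sin(2 * math.pi * i / 200) for i in range(200)]

results = {}
for bits in (8, 4, 2):
    w_hat, n_codes, s = round_trip(w, bits)
    max_err = max(abs(a - b) for a, b in zip(w, w_hat))
    results[bits] = (n_codes, max_err)
    print(f"{bits} bits: {n_codes:3d} distinct codes, max error {max_err:.4f} (bound s/2 = {s / 2:.4f})")
```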

2. Why this works at all

Quantizing a single weight matrix gives you compression for free, but the model has to compute with those weights. If the quantization error compounds across layers, output quality collapses.

In practice, it doesn't — for two reasons:

  • Neural networks are robust to small perturbations, especially weight perturbations. The training process produces models whose loss surface is roughly flat in the immediate neighborhood of the trained weights; quantization moves the weights within that neighborhood.
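  • The error has zero mean. Round-off errors are roughly independent and centered on zero, so they partly cancel across many multiplications instead of piling up. Suppose the rounding errors are spread uniformly over [-s/2, s/2] and the inputs are of order 1. Then the perturbation of one n-term dot product has a typical size of about √n · s/√12, because independent errors add in variance. That is far below the worst case of n · s/2, where every error points the same way. The central limit theorem makes the sum roughly Gaussian.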
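A quick simulation makes the √n claim concrete. It rests on modeling assumptions, not measurements of a real network. Each rounding error is drawn uniformly from [-s/2, s/2] and each input from a standard normal, all independently. The script compares the observed spread of the dot-product error with √n · s/√12 and with the worst case n · s/2.

```python
import math
import random
import statistics


def dot_error_std(n, s, trials=1000, seed=0):
    """Empirical std of sum(e_i * x_i) over many trials.

    Assumed model: e_i ~ Uniform[-s/2, s/2] (rounding error), x_i ~ Normal(0, 1)
    (input), all independent.
    """
    rng = random.Random(seed)
    errors = [
        sum(rng.uniform(-s / 2, s / 2) * rng.gauss(0.0, 1.0) for _ in range(n))
        for _ in range(trials)
    ]
    return statistics.pstdev(errors)


s = 0.01
results = {}
for n in (64, 256, 1024):
    observed = dot_error_std(n, s)
    predicted = math.sqrt(n) * s / math.sqrt(12)  # sqrt(n) growth
    worst_case = n * s / 2                         # every error aligned, |x_i| = 1
    results[n] = (observed, predicted)
    print(f"n={n:5d}  observed={observed:.4f}  sqrt(n) model={predicted:.4f}  worst case={worst_case:.2f}")
```

Quadrupling n from 256 to 1024 should roughly double the observed spread, while the worst-case bound quadruples.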

For a typical INT8 quantized transformer, perplexity on a held-out test set goes up by ~1-2%. The model is, by most measures, the same model.

3. Where naïve quantization fails

The scheme above (one scale per weight matrix) leaves performance on the table. Two refinements get used in practice:

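  • Per-channel scales (one per row of W) capture the fact that different output channels often have different magnitude distributions. The cost is nearly zero: one extra float per row, applied to that row's output.
  • Outlier-aware methods (LLM.int8(), GPTQ, AWQ) deal with a small fraction of very large values. These values dominate max(|w|) and force coarse steps on everything else. Strategies: keep outliers in higher precision (LLM.int8() does this for outlier activation features), or rescale them before quantizing (AWQ's per-channel scaling). GPTQ takes a different route. It compensates each weight's rounding error by adjusting the weights that have not been quantized yet.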
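To see why per-channel scales help, this sketch compares one scale for the whole matrix against one scale per row. It uses a made-up matrix whose rows differ in magnitude by 1000×. It then plants a single hypothetical outlier in one row to show how one large value inflates the scale for every other weight in that row. The matrix and the outlier value are invented for illustration; real weight statistics differ.

```python
import random

QMAX = 127  # symmetric INT8


def round_trip(values, s):
    """Quantize each value with scale s, then dequantize."""
    return [round(v / s) * s for v in values]


def mae(a, b):
    """Mean absolute reconstruction error."""
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)


def scale_for(values):
    """Symmetric scale: the largest magnitude maps to QMAX."""
    return max(abs(v) for v in values) / QMAX


rng = random.Random(0)
row_stds = [0.001, 0.01, 0.1, 1.0]  # hypothetical: rows differ by 1000x in magnitude
W = [[rng.gauss(0.0, sd) for _ in range(256)] for sd in row_stds]
flat = [v for row in W for v in row]

# Per-tensor: a single scale set by the largest weight anywhere in W.
err_tensor = mae(flat, round_trip(flat, scale_for(flat)))

# Per-channel: each row gets its own scale.
flat_hat = [v for row in W for v in round_trip(row, scale_for(row))]
err_channel = mae(flat, flat_hat)

print(f"per-tensor MAE:  {err_tensor:.2e}")
print(f"per-channel MAE: {err_channel:.2e}")

# Outliers: one large weight sets max(|w|), so the rest of its row gets coarse steps.
row = W[2]
row_with_outlier = [50.0] + row[1:]  # hypothetical outlier replacing the first weight
err_clean = mae(row[1:], round_trip(row[1:], scale_for(row)))
err_outlier = mae(row[1:], round_trip(row[1:], scale_for(row_with_outlier)))
print(f"row MAE without outlier: {err_clean:.2e}, with outlier: {err_outlier:.2e}")
```

With the outlier present, most of the other weights in that row round to zero. Keeping outliers in higher precision, or rescaling them, is meant to prevent exactly this.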

Both add complexity but recover most of the perplexity gap on large models. For our scope, the simple symmetric scheme above is fine.

4. The PyTorch version

Quantizing a model after training, with no fine-tuning, is called post-training quantization (PTQ). Here is the minimal version using PyTorch's built-in tools. Save this as scripts/quantize.py:

"""scripts/quantize.py: apply dynamic INT8 quantization to a trained model."""
import os

import torch

from llm.model import GPT, GPTConfig

# [1]
cfg = GPTConfig()
model = GPT(cfg)
model.load_state_dict(torch.load("checkpoints/model.pt", map_location="cpu"))
model.eval()

# Dynamic quantization: nn.Linear weights are converted to INT8 here, once;
# activations stay float and are quantized on the fly during each forward pass.
# [2]
qmodel = torch.quantization.quantize_dynamic(
    model,
    {torch.nn.Linear},  # which layer types to quantize
    dtype=torch.qint8,
)

# Save both versions so the file sizes can be compared
# [3]
torch.save(model.state_dict(), "checkpoints/model_fp32.pt")
torch.save(qmodel.state_dict(), "checkpoints/model_int8.pt")

# [4]
fp32_size = os.path.getsize("checkpoints/model_fp32.pt")
int8_size = os.path.getsize("checkpoints/model_int8.pt")
print(f"fp32: {fp32_size / 1024**2:.1f} MB")
print(f"int8: {int8_size / 1024**2:.1f} MB")
print(f"ratio: {fp32_size / int8_size:.2f}x")

Read this script as a before/after measurement:

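  • [1] loads the normal checkpoint on CPU. PyTorch's dynamic quantization kernels run on CPU, so this is a CPU inference win.
  • [2] quantize_dynamic finds the nn.Linear layers and stores their weights as INT8. Activations stay floating point between layers and are quantized on the fly, with scales computed from each batch at runtime. That is why you do not need a calibration dataset.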
  • [3] saves both state dicts so you can compare files directly.
  • [4] measures the before/after size ratio. This does not retrain the model; it is post-training compression.
python -m scripts.quantize

For our 14M-parameter model you should see roughly a 3× size reduction. That is less than 4× because only nn.Linear weights are quantized; embeddings, LayerNorm parameters, and biases stay in float32. Inference on CPU should be roughly 2× faster, and perplexity should go up by less than 1%.
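A back-of-the-envelope model explains why the ratio lands below 4×. Suppose a fraction f of the parameters stays in float32 (4 bytes each) and the rest becomes INT8 (1 byte each). The compression ratio is then 4 / (4f + (1 - f)) = 4 / (1 + 3f). The fractions below are hypothetical, and the formula ignores the small per-layer scale metadata. For example, a ~3× ratio corresponds to about one ninth of the parameters staying in float32.

```python
def compression_ratio(frac_fp32: float) -> float:
    """fp32 size / mixed size, when a fraction of parameters stays fp32 and the rest is INT8."""
    fp32_bytes = 4.0
    mixed_bytes = 4.0 * frac_fp32 + 1.0 * (1.0 - frac_fp32)
    return fp32_bytes / mixed_bytes


for f in (0.0, 0.05, 0.11, 0.25):  # hypothetical fractions left in float32
    print(f"{f:4.0%} of params in fp32 -> {compression_ratio(f):.2f}x")
```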

What about INT4 and below?

You can keep going. INT4 quantization (4 bits per weight) gets you to 8× compression relative to float32 and is a common default in llama.cpp-style local-inference setups. INT2 mostly doesn't work: too few levels.
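INT4 only reaches 8× if two 4-bit codes share each byte, since there is no 4-bit storage type to put them in. Below is a minimal sketch of that packing. It assumes signed codes stored as two's-complement nibbles, low nibble first; the symmetric scheme only produces codes in [-7, 7]. Real formats, such as llama.cpp's, also store per-block scales and define their own bit layouts, which this sketch does not reproduce.

```python
def pack_int4(codes):
    """Pack signed 4-bit codes (-8..7) two per byte, low nibble first."""
    if any(not -8 <= c <= 7 for c in codes):
        raise ValueError("INT4 codes must lie in [-8, 7]")
    if len(codes) % 2:
        codes = list(codes) + [0]  # pad to an even length
    packed = bytearray()
    for lo, hi in zip(codes[0::2], codes[1::2]):
        # & 0x0F keeps the low 4 bits of each code's two's-complement form.
        packed.append((lo & 0x0F) | ((hi & 0x0F) << 4))
    return bytes(packed)


def unpack_int4(data, n):
    """Inverse of pack_int4: recover the first n signed codes."""
    codes = []
    for byte in data:
        for nibble in (byte & 0x0F, byte >> 4):
            codes.append(nibble - 16 if nibble >= 8 else nibble)  # sign-extend 4 bits
    return codes[:n]


codes = [-7, 3, 0, 7, -1, 5, 2]  # symmetric INT4 codes live in [-7, 7]
packed = pack_int4(codes)
print(f"{len(codes)} codes -> {len(packed)} bytes")
print(unpack_int4(packed, len(codes)) == codes)
```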

The state of the art (GPTQ, AWQ) does INT4 with very small quality loss. The trick: run a small calibration dataset through the model. Then choose the quantized weights (or scales) that minimize the error in each layer's outputs on those activations, not just the per-weight reconstruction error. We don't cover that here. The principle is the same, and implementations are available in production quantization libraries.
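The underlying principle is to judge a quantization choice by the error it causes on real inputs, not on the weights alone. It can be shown with a much simpler knob than GPTQ or AWQ use. This sketch searches over a clipping factor c for a per-tensor INT4 scale, s = c · max(|w|) / 7. It picks the c that minimizes the squared error of one layer's outputs on a small calibration batch. Everything here is a hypothetical setup: the Laplace-distributed fake weights, the random "calibration" inputs, and the grid of c values. It is not the GPTQ or AWQ algorithm. With heavy-tailed weights, clipping the few largest values often pays for finer steps everywhere else, but the winning c depends on the data.

```python
import random


def round_trip_clipped(w, s, qmax=7):
    """Symmetric INT4 round trip: codes are clipped to [-qmax, qmax]."""
    return [max(-qmax, min(qmax, round(v / s))) * s for v in w]


def output_error(W, W_hat, X):
    """Sum of squared differences between W x and W_hat x over the calibration inputs."""
    total = 0.0
    for x in X:
        for row, row_hat in zip(W, W_hat):
            y = sum(a * b for a, b in zip(row, x))
            y_hat = sum(a * b for a, b in zip(row_hat, x))
            total += (y - y_hat) ** 2
    return total


rng = random.Random(0)
d_in, d_out, n_calib = 64, 16, 32
# Hypothetical heavy-tailed weights: the difference of two exponentials is Laplace-distributed.
W = [[rng.expovariate(50) - rng.expovariate(50) for _ in range(d_in)] for _ in range(d_out)]
# Hypothetical calibration inputs; in practice these would be real activations.
X = [[rng.gauss(0.0, 1.0) for _ in range(d_in)] for _ in range(n_calib)]

max_abs = max(abs(v) for row in W for v in row)
errors = {}
for c in (1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4):
    s = c * max_abs / 7  # c = 1.0 is the plain max(|w|) scale
    W_hat = [round_trip_clipped(row, s) for row in W]
    errors[c] = output_error(W, W_hat, X)
    print(f"clip c={c:.1f}: output error {errors[c]:.4f}")

best_c = min(errors, key=errors.get)
print(f"chosen clip factor: {best_c}")
```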

Recap

  • Quantization trades a small amount of model quality for 4-8× size reduction and 2-4× inference speedup.
  • Symmetric uniform quantization is the simplest scheme: q = round(w / s), s = max(|w|) / (2^(bits-1) - 1).
  • Why it works: neural network outputs are robust to small weight perturbations, and quantization error averages out across many products.
  • INT8 is essentially free; INT4 with good methods loses 2-5% perplexity; INT2 usually doesn't work.
  • PyTorch's quantize_dynamic is one line, handles nn.Linear for you, and is the right default for CPU transformer inference.
  • Your local project now has scripts/quantize.py and a smaller checkpoint to compare against the original.

Going further

Next up: talk to your model — the checkpoint is trained and cheaper to run. Now wrap it in the smallest honest chat interface.