The loss curve

Chapter 18 · 12 min

Fine-tuning with LoRA

Implement Low-Rank Adaptation in ~30 lines and fine-tune GPT-2 with a fraction of the parameters. Math, code, results.

You have a model. It works (sort of). You want to specialize it — make it better at code, or at one specific dialect, or at producing JSON. The naïve approach: keep training on the new data, updating every parameter.

That works but it's expensive. Re-training a billion-parameter model on a million tokens of new data takes hundreds of GPU-hours and produces a billion-parameter checkpoint per task. Most of the weights barely change. We're burning storage and compute we don't need to.

LoRA (Hu et al., 2021) is the trick that makes fine-tuning cheap. Instead of updating W, freeze it. Add a small low-rank correction A · B and only train those. The math:

y = x · W + α · x · (A · B)

Where W is [d × d] (frozen), A is [d × r], B is [r × d], and r ≪ d. The number of trainable parameters drops from d² to 2 · d · r — a savings ratio of d / (2r). The trick gets dramatically better as models grow:

  • Your ch.12 model (n_embd = 128, r = 8): each square nn.Linear goes 16,384 → 2,048 — an 8× reduction.
  • GPT-2 medium (d = 1024, r = 8): 1,048,576 → 16,384 — a 64× reduction.
  • GPT-3 territory (d = 4,096, r = 8): 16,777,216 → 65,536 — a 256× reduction.

LoRA earns its reputation when models are big. On your local 14M-parameter model, wrapping every nn.Linear still gets the trainable count down to roughly 65k — about 200× fewer trainable parameters than the full model. The relative win only widens as you scale.
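The savings arithmetic above is worth verifying by hand. A throwaway sketch (not part of the project code):

```python
def lora_savings(d: int, r: int = 8):
    """Trainable params for one square [d x d] layer: full vs. LoRA."""
    full = d * d       # full fine-tune: every entry of W
    lora = 2 * d * r   # LoRA: the A [d x r] and B [r x d] factors
    return full, lora, full // lora

for d in (128, 1024, 4096):
    full, lora, ratio = lora_savings(d)
    print(f"d={d}: {full:,} -> {lora:,} ({ratio}x)")
```

The printed ratios match the bullet list: 8×, 64×, 256×.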

1. The LoRA forward pass

Start by writing the formula as code. You have a fixed W (the "pretrained" matrix), a low-rank pair A and B (the only things we'll train), and an alpha scaling factor.

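A minimal sketch of the forward pass in plain NumPy, with illustrative shapes (the real PyTorch version comes in section 3):

```python
import numpy as np

rng = np.random.default_rng(0)
d, r, alpha = 64, 8, 16.0

W = rng.normal(size=(d, d))          # frozen "pretrained" weight
A = rng.normal(size=(d, r)) * 0.01   # trainable down-projection [d x r]
B = np.zeros((r, d))                 # trainable up-projection [r x d], zero init

def lora_forward(x):
    # y = x · W + α · x · (A · B)
    return x @ W + alpha * (x @ (A @ B))

x = rng.normal(size=(1, d))
# with B = 0 the correction term vanishes: output equals the base layer
assert np.allclose(lora_forward(x), x @ W)
```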

The output of LoRA is mathematically identical to the output of a full fine-tune if the full fine-tune's update happens to be representable as a rank-r matrix. The empirical claim of the LoRA paper is that for most adaptation tasks the update is approximately low-rank, so you usually lose little or nothing by accepting the constraint.
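To make "representable as a rank-r matrix" concrete, here is a throwaway NumPy check: build an update that genuinely has rank r, take its best rank-r approximation via truncated SVD, and confirm nothing is lost:

```python
import numpy as np

rng = np.random.default_rng(0)
d, r = 64, 8

# a truly rank-r update: the product of two thin factors
delta = rng.normal(size=(d, r)) @ rng.normal(size=(r, d))

# best rank-r approximation via truncated SVD
U, S, Vt = np.linalg.svd(delta)
delta_r = (U[:, :r] * S[:r]) @ Vt[:r]

# for a rank-r update, the rank-r constraint loses nothing
assert np.allclose(delta, delta_r)
```

For a full-rank update the truncation would discard the tail singular values; the LoRA claim is that for adaptation tasks that tail is small.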

2. Why this works

The thing being adapted in fine-tuning is usually not a totally new behavior — it's a relatively narrow shift. Speak in a different register. Follow instructions more carefully. Cite sources. These shifts don't need millions of dimensions of expressive capacity; a few hundred dimensions of "delta" is plenty.

Think of W as the model's baseline competence and α · A · B as a small per-task overlay. When you swap tasks, you swap overlays. The base model stays untouched.
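The overlay picture, in code: one frozen W, a dictionary of per-task (A, B) pairs, and swapping tasks means swapping which pair the forward pass uses. A sketch with made-up shapes and hypothetical task names:

```python
import numpy as np

rng = np.random.default_rng(1)
d, r, alpha = 32, 4, 16.0

W = rng.normal(size=(d, d))   # one shared base weight, never touched

# two hypothetical trained adapters, e.g. "json" and "sql"
adapters = {
    name: (rng.normal(size=(d, r)), rng.normal(size=(r, d)))
    for name in ("json", "sql")
}

def forward(x, task):
    A, B = adapters[task]
    return x @ W + alpha * (x @ (A @ B))

x = rng.normal(size=(1, d))
y_json = forward(x, "json")
y_sql = forward(x, "sql")     # same base model, different overlay
```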

3. The Python version

The pattern in PyTorch is to wrap the existing nn.Linear layers of your model with a LoraLinear version. Save as llm/lora.py:

"""llm/lora.py — LoRA adapter wrapper for nn.Linear."""
import math
import torch
import torch.nn as nn
 
class LoraLinear(nn.Module):
    """Wraps an existing nn.Linear with frozen weights + a trainable low-rank update."""
    def __init__(self, base: nn.Linear, r: int = 8, alpha: float = 16.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False  # freeze
        self.r = r
        self.alpha = alpha  # note: the LoRA paper scales the update by alpha / r; here alpha is applied directly
        d_in = base.in_features
        d_out = base.out_features
        self.A = nn.Parameter(torch.empty(d_in, r))
        nn.init.kaiming_uniform_(self.A, a=math.sqrt(5))
        self.B = nn.Parameter(torch.zeros(r, d_out))  # zero init: model starts at base behavior
 
    def forward(self, x):
        return self.base(x) + self.alpha * (x @ self.A @ self.B)
 
def apply_lora(model, r=8, alpha=16):
    """Replace every nn.Linear inside the model with a LoraLinear wrapper."""
    for name, module in model.named_children():
        if isinstance(module, nn.Linear):
            setattr(model, name, LoraLinear(module, r=r, alpha=alpha))
        else:
            apply_lora(module, r=r, alpha=alpha)
    return model

Read the wrapper as “base behavior plus a small learned correction”:

  • self.base = base keeps the original linear layer.
  • Setting requires_grad = False freezes the original weights, so training cannot damage the base model.
  • A maps from input dimension to a tiny rank r.
  • B maps from rank r back to the output dimension.
  • x @ self.A @ self.B is the low-rank delta.
  • self.base(x) + ... means the model starts from the old answer and learns a task-specific adjustment.
  • apply_lora walks the model tree and swaps every nn.Linear it finds.

Two details worth flagging:

  • B initialized to zero. That makes α · A · B = 0 at the start, so the wrapped model produces exactly the base model's output before any training. Training nudges B away from zero gradually.
  • A initialized with Kaiming. Standard small random init, ready to receive gradients.
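The zero-init property is easy to verify. A quick self-contained check on a bare nn.Linear, mirroring what LoraLinear does internally:

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
base = nn.Linear(32, 32)
r, alpha = 8, 16.0

A = nn.Parameter(torch.empty(32, r))
nn.init.kaiming_uniform_(A, a=math.sqrt(5))   # small random init, as in LoraLinear
B = nn.Parameter(torch.zeros(r, 32))          # zero init, as in LoraLinear

x = torch.randn(4, 32)
y_wrapped = base(x) + alpha * (x @ A @ B)

# before any training, the wrapped output is exactly the base output
assert torch.allclose(y_wrapped, base(x))
```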

To use LoRA on your chapter-13 trained model, modify scripts/train.py to load the checkpoint, apply LoRA, then continue training only on the new data:

import torch
from llm.model import GPT, GPTConfig
from llm.lora import apply_lora
 
# [1]
model = GPT(GPTConfig())
# [2]
model.load_state_dict(torch.load("checkpoints/model.pt"))
# [3]
model = apply_lora(model, r=8, alpha=16)
 
# only A and B are trainable now
# [4]
trainable = [p for p in model.parameters() if p.requires_grad]
print(f"trainable params: {sum(p.numel() for p in trainable):,}")
# ... rest of the training loop, using `trainable` for the optimizer

The modified training setup changes what gets updated, not the shape of the training loop:

  • [1] recreates the same base model architecture.
  • [2] loads the chapter-13 checkpoint before adding adapters.
  • [3] wraps linear layers with LoRA modules.
  • [4] builds the optimizer input from trainable adapter parameters only, not from all model parameters.
  • Save only the adapter weights if you want the small, shareable checkpoint.

You should see a number a couple orders of magnitude smaller than the full model's parameter count.
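Saving only the adapter weights can be sketched as a state-dict filter: keep exactly the parameters that are still trainable (the A and B matrices). The Toy class below is a stand-in for the wrapped model, just to make the sketch runnable:

```python
import torch
import torch.nn as nn

def adapter_state_dict(model):
    # keep only parameters that are still trainable, i.e. the LoRA A/B matrices
    return {name: p.detach().cpu()
            for name, p in model.named_parameters() if p.requires_grad}

# tiny stand-in for a LoRA-wrapped model: one frozen base layer, one adapter tensor
class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.base = nn.Linear(8, 8)
        for p in self.base.parameters():
            p.requires_grad = False
        self.A = nn.Parameter(torch.zeros(8, 2))

m = Toy()
sd = adapter_state_dict(m)   # only the trainable "A" survives the filter
# save with torch.save(sd, ...); restore later via model.load_state_dict(sd, strict=False)
```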

4. Checkpoint diet

A LoRA fine-tune of a 1B-parameter model with r = 8 produces a checkpoint of ~16M parameters — about 30 MB. The original 1B-parameter checkpoint is ~2 GB. You can ship one base model + 100 LoRA adapters and have the storage footprint of a few base models.

This is why LoRA dominates open-source fine-tuning. Most domain-specific adapters on Hugging Face are some flavor of LoRA. The original base model stays where it is; people share the deltas.

Recap

  • LoRA = freeze W, add a trainable low-rank update α · A · B. Trainable params go from d² to 2 · d · r.
  • r is the rank of the update. Typical values: 4–32 for adaptation tasks, larger for tasks that need more capacity.
  • B zero-init ensures the wrapped model starts indistinguishable from the base.
  • Disk savings are huge: a LoRA adapter is 30–200 MB even for billion-parameter base models.
  • Quality: LoRA fine-tuned models match or come close to full-fine-tuned baselines on most adaptation tasks. The rank assumption usually holds.
  • Your local project now has llm/lora.py, the adapter pattern used by most cheap fine-tunes.

Going further

Next up: quantization — the second half of "make your model cheap to serve". LoRA cuts training cost; quantization cuts inference cost.