Chapitre 22 · 14 min
Appendice · La backprop à la main
Dérive la backprop sur un réseau à 2 couches étape par étape, puis vérifie le résultat contre loss.backward() de PyTorch pour que la magie cesse d’être magique.
Tu as appelé loss.backward() au chapitre 5, puis à chaque chapitre suivant. Ça remplit silencieusement .grad sur chaque du modèle. Cet appendice déballe ce que cette ligne fait, en dérivant à la main les d’un petit réseau, puis en vérifiant contre PyTorch.
Une fois fait, chaque loss.backward() ultérieur cesse d’être magique.
1. Le réseau
Le plus petit réseau intéressant : 2 entrées, 1 unité cachée, 1 sortie. sur le caché, identité en sortie. quadratique contre une seule cible.
Trois (w₁, w₂, v), quatre intermédiaires (z, h, ŷ, L), deux entrées et une cible.
Valeurs concrètes :
| valeur | |
|---|---|
| 1.0, 0.5 | |
| 0.4, 0.6 | |
| 0.8 | |
| 1.0 |
2. Passe avant à la main
Huit opérations arithmétiques pour transformer deux entrées en un seul nombre.
3. Passe arrière par la règle de la chaîne
On veut ∂L/∂v, ∂L/∂w₁, ∂L/∂w₂ — un nombre par . La règle de la chaîne remonte le graphe.
Étape 1. Sensibilité de la loss à la sortie :
Étape 2. Pousser à travers v :
v est fini. Pour w₁ et w₂, le chemin est plus long.
Étape 3. Pousser à travers h, puis à travers z. La dérivée de est h(1 − h) :
Étape 4. Enfin à travers les linéaires z = w₁ x₁ + w₂ x₂ :
Trois nombres : ∂L/∂v ≈ -0.6220, ∂L/∂w₁ ≈ -0.1651, ∂L/∂w₂ ≈ -0.0825.
Écris la chain rule toi-même. La cellule pré-remplit la passe avant ; tu remplis les quatre lignes de la passe arrière et regardes les gradients tomber :
Code · JavaScript
4. Vérification avec PyTorch
"""check_backprop.py — verify hand-derived gradients match PyTorch."""
import math
import torch
x = torch.tensor([1.0, 0.5])
y = torch.tensor(1.0)
w = torch.tensor([0.4, 0.6], requires_grad=True)
v = torch.tensor(0.8, requires_grad=True)
z = (w * x).sum()
h = torch.sigmoid(z)
y_hat = v * h
loss = (y_hat - y) ** 2
loss.backward()
print(f"forward: z={z.item():.4f} h={h.item():.4f} y_hat={y_hat.item():.4f} loss={loss.item():.4f}")
print(f"grads: w1={w.grad[0].item():.4f} w2={w.grad[1].item():.4f} v={v.grad.item():.4f}")
expected = {"w1": -0.1651, "w2": -0.0825, "v": -0.6220}
assert math.isclose(w.grad[0].item(), expected["w1"], abs_tol=1e-3)
assert math.isclose(w.grad[1].item(), expected["w2"], abs_tol=1e-3)
assert math.isclose(v.grad.item(), expected["v"], abs_tol=1e-3)
print("✓ hand-derived gradients match PyTorch")Sortie attendue :
forward: z=0.7000 h=0.6682 y_hat=0.5346 loss=0.2166
grads: w1=-0.1651 w2=-0.0825 v=-0.6220
✓ hand-derived gradients match PyTorchLes nombres collent à quatre décimales. C’est ce que loss.backward() fait pour chaque de chaque modèle du livre. Le mécanisme est identique ; seul le graphe est plus gros.
5. Pourquoi ça généralise
Un vrai modèle a des millions à des milliards de . La dérivation à la main ne passe pas à l’échelle, mais la règle de la chaîne, oui :
- Le graphe de calcul est construit automatiquement pendant la passe avant. Chaque opération enregistre ses entrées et sa dérivée locale.
- La passe arrière parcourt le graphe de la
lossvers chaque , multipliant les dérivées locales. Un nombre par , écrit dans.grad. - Le coût est du même ordre que la passe avant — 2 à 3× plus cher, pas exponentiellement plus. C’est toute la raison pour laquelle l’ moderne est possible.
Quand le module d’ du chapitre 8 appelle loss.backward(), le graphe parcouru a des centaines d’opérations et des dizaines de milliers de , mais chaque arête individuelle est l’une des règles simples de cet appendice.
Recap
loss.backward()parcourt un graphe de calcul à l’envers et applique la règle de la chaîne arête par arête.- Chaque opération contribue une dérivée locale. La passe avant mémorise ; la passe arrière utilise.
- La dérivation à la main colle à PyTorch à 4 décimales sur un réseau à 3 . Le même mécanisme s’étend à des milliards de .
- Un nombre par sort, écrit dans
.grad. L’optimiseur lit ces nombres pour déplacer les poids.
Pour aller plus loin
- 3blue1brown, « Backpropagation calculus ».
- micrograd de Karpathy — bibliothèque autodiff de 150 lignes en Python. À lire de bout en bout au moins une fois.
- Notes autograd de PyTorch.