How Do Models Learn?

Gradient descent, backpropagation, loss functions, and the optimization loop. The engine under every neural network.

Concept Foundational
7 min read
optimization gradient-descent backpropagation training-loop

Summary

A neural network learns by following the gradient of a loss function down to a configuration of parameters where its predictions are mostly right. That’s the whole loop: forward pass to compute a prediction, loss to score how wrong it was, backward pass (backpropagation) to compute the derivative of the loss with respect to every weight, optimizer step to nudge each weight in the direction that reduces loss. Repeat for trillions of tokens.
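The four-step loop can be sketched end to end on a toy problem. This is a minimal sketch under stated assumptions: a hypothetical one-feature linear model y_hat = w * x + b, mean squared error as the loss, gradients derived by hand instead of by autograd, and plain gradient descent as the optimizer. The function name `train` and the synthetic data are illustrative only.

```python
# Minimal forward -> loss -> backward -> step loop for a hypothetical
# one-feature linear model y_hat = w * x + b trained with mean squared error.
# Gradients are derived by hand; frameworks get them from autograd instead.

def train(xs, ys, lr=0.05, steps=1000):
    w, b = 0.0, 0.0                      # parameters (theta)
    n = len(xs)
    loss = float("inf")
    for _ in range(steps):
        # Forward pass: a prediction for every example
        preds = [w * x + b for x in xs]
        # Loss: mean squared error
        loss = sum((p - y) ** 2 for p, y in zip(preds, ys)) / n
        # Backward pass: dL/dw and dL/db via the chain rule
        grad_w = sum(2 * (p - y) * x for p, y, x in zip(preds, ys, xs)) / n
        grad_b = sum(2 * (p - y) for p, y in zip(preds, ys)) / n
        # Optimizer step: plain gradient descent, theta <- theta - lr * grad
        w -= lr * grad_w
        b -= lr * grad_b
    return w, b, loss


xs = [0.0, 1.0, 2.0, 3.0, 4.0]
ys = [2 * x + 1 for x in xs]             # synthetic data from w=2, b=1
w, b, loss = train(xs, ys)
print(f"w={w:.4f} b={b:.4f} loss={loss:.2e}")
```

A real training run swaps each piece for something bigger (a transformer, cross-entropy, autograd, AdamW), but the control flow is the same loop.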

What scales this from a textbook exercise to a foundation-model training run is mini-batched stochastic gradient descent with modern optimizers (Adam, AdamW, Lion, Shampoo), mixed-precision arithmetic (bfloat16 compute with fp32 master copies of the weights), and parallelism schemes (data, tensor, pipeline, expert, sequence) that split the work across thousands of accelerators. The core math of gradient descent and backpropagation has not changed since backprop was popularized in 1986; the engineering around it is what got us to GPT-4 scale.

Why it matters

You don’t have to implement Adam by hand to ship a model, but you do have to understand it. Almost every training failure an engineer hits (loss not decreasing, loss exploding, gradient norms going NaN, a fine-tune destroying the base model) is rooted in the optimizer-loss-gradient loop. If you can read a learning-rate schedule and a gradient-norm chart, you can debug most training-time problems. If you can’t, every training run becomes opaque.

Inference engineers also need this. Quantization, distillation, and LoRA all manipulate learned parameters, and they only make sense once you understand what those parameters represent: a point in a loss landscape, found by the optimizer, whose precision and rank you can trade off against memory and compute. Knowing how the model got there tells you what you can safely change.

How it works

The forward pass: compute a prediction

Given input x and current parameters θ, the model computes a prediction ŷ = f(x; θ). For a language model, x is a sequence of token IDs and ŷ is a probability distribution over the next token for each position. For a vision model, x is an image and ŷ might be a class distribution or a pixel-by-pixel reconstruction.

The forward pass is a chain of matrix multiplications, nonlinearities (GELU, SwiGLU), normalizations (LayerNorm, RMSNorm), and attention operations. Each operation’s output is held in memory because the backward pass needs it.
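To make "each operation’s output is held in memory" concrete, here is a minimal sketch of a forward pass that records every intermediate result. The network (a matrix-vector product, GELU, a second matrix-vector product, then RMSNorm without a learned gain) and its weights are hypothetical and tiny. The `cache` list stands in for the activations a framework keeps for the backward pass.

```python
import math

# Hypothetical tiny forward pass: matmul -> GELU -> matmul -> RMSNorm.
# Each intermediate result is appended to `cache`, standing in for the
# activations a framework keeps for the backward pass.

def matvec(W, x):
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in W]

def gelu(v):
    # tanh approximation of GELU
    c = math.sqrt(2 / math.pi)
    return [0.5 * a * (1 + math.tanh(c * (a + 0.044715 * a ** 3))) for a in v]

def rmsnorm(v, eps=1e-6):
    # RMSNorm without a learned gain, for brevity
    rms = math.sqrt(sum(a * a for a in v) / len(v) + eps)
    return [a / rms for a in v]

def forward(x, W1, W2):
    cache = [x]
    h = matvec(W1, x)
    cache.append(h)
    a = gelu(h)
    cache.append(a)
    z = matvec(W2, a)
    cache.append(z)
    out = rmsnorm(z)
    cache.append(out)
    return out, cache


W1 = [[0.5, -0.2], [0.1, 0.4], [-0.3, 0.8]]   # 3x2, illustrative values
W2 = [[1.0, 0.0, -1.0], [0.2, 0.3, 0.5]]      # 2x3, illustrative values
out, cache = forward([1.0, 2.0], W1, W2)
print(len(cache), out)
```

In a real model, `cache` is what grows with batch size, sequence length, and depth. That growth is why activation memory, not weights, often limits training.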

The loss: score the prediction

A loss function L(ŷ, y) measures how wrong the prediction was. For autoregressive language models, the standard loss is cross-entropy — the negative log probability the model assigned to the correct next token, averaged over every position. For diffusion models, the loss is the squared error between predicted noise and true noise. For contrastive learners (CLIP), it’s InfoNCE — pulling matching pairs together and pushing mismatched pairs apart in embedding space.
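The language-model case can be sketched in a few lines: convert each position’s logits to log-probabilities, take the negative log-probability of the correct next token, and average over positions. The logits and targets below are made-up numbers for a 4-token vocabulary. Subtracting the max logit before exponentiating is the usual trick for avoiding overflow.

```python
import math

def log_softmax(logits):
    # Subtract the max before exponentiating for numerical stability
    m = max(logits)
    log_z = m + math.log(sum(math.exp(l - m) for l in logits))
    return [l - log_z for l in logits]

def cross_entropy(logits_per_position, targets):
    """Mean negative log-probability assigned to the correct next token."""
    losses = [-log_softmax(logits)[t]
              for logits, t in zip(logits_per_position, targets)]
    return sum(losses) / len(losses)


# Two positions, vocabulary of 4 tokens (illustrative numbers)
logits = [[2.0, 0.5, -1.0, 0.0],
          [0.1, 0.1, 3.0, 0.1]]
targets = [0, 2]
print(round(cross_entropy(logits, targets), 4))
```

With uniform logits over a vocabulary of size V, the loss is exactly log V. That is a useful sanity check for a freshly initialized model.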

The choice of loss determines what the model optimizes for, which is not always what you actually want. Cross-entropy treats all wrong tokens as equally bad even though “the” instead of “a” is much less wrong than “purple” instead of “Paris”. This mismatch is one reason post-training (RLHF, DPO) exists.
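The "all wrong tokens are equally bad" point follows from the formula: cross-entropy depends only on the probability given to the correct token. The sketch below uses a hypothetical 4-token vocabulary and two distributions that both give the correct token 0.5. One puts most of the remaining mass on a plausible alternative, the other on a nonsensical one. Both receive the same loss.

```python
import math

# Hypothetical 4-token vocabulary; index 0 ("Paris") is the correct next token.
vocab = ["Paris", "Lyon", "purple", "the"]
near_miss = [0.5, 0.45, 0.025, 0.025]   # wrong mass mostly on a plausible city
far_miss = [0.5, 0.025, 0.45, 0.025]    # wrong mass mostly on "purple"

def cross_entropy(probs, target):
    # Only the probability of the target token enters the loss
    return -math.log(probs[target])


print(vocab[0], cross_entropy(near_miss, 0), cross_entropy(far_miss, 0))  # identical values
```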

Backpropagation: compute gradients

Backpropagation is the chain rule applied to the computational graph of the forward pass. Starting from dL/dŷ, it walks the graph backward and computes dL/dθᵢ for every parameter θᵢ. The backward pass costs roughly 2x the forward pass in compute, because each layer needs gradients with respect to both its inputs and its weights. Its main memory cost is the activations saved during the forward pass.
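Here is the chain rule done by hand on a hypothetical two-parameter network, y_hat = w2 * tanh(w1 * x), with squared-error loss. It is a minimal sketch, not how frameworks implement autograd. The backward function starts from dL/dy_hat, reuses the activations cached by the forward pass, and propagates one local derivative at a time. Central finite differences provide an independent check.

```python
import math

def forward(w1, w2, x, t):
    h_pre = w1 * x
    h = math.tanh(h_pre)
    y = w2 * h
    loss = (y - t) ** 2
    cache = (x, h, y)                 # activations saved for the backward pass
    return loss, cache

def backward(w2, t, cache):
    x, h, y = cache
    dL_dy = 2 * (y - t)               # start from dL/dy_hat
    dL_dw2 = dL_dy * h                # y = w2 * h
    dL_dh = dL_dy * w2
    dL_dhpre = dL_dh * (1 - h * h)    # d tanh(u)/du = 1 - tanh(u)^2
    dL_dw1 = dL_dhpre * x             # h_pre = w1 * x
    return dL_dw1, dL_dw2


w1, w2, x, t = 0.7, -1.3, 1.5, 0.4    # illustrative values
loss, cache = forward(w1, w2, x, t)
g1, g2 = backward(w2, t, cache)

# Central finite differences as an independent check of the analytic gradients
eps = 1e-6
fd1 = (forward(w1 + eps, w2, x, t)[0] - forward(w1 - eps, w2, x, t)[0]) / (2 * eps)
fd2 = (forward(w1, w2 + eps, x, t)[0] - forward(w1, w2 - eps, x, t)[0]) / (2 * eps)
print(g1, fd1)
print(g2, fd2)
```

Gradient checking like this is slow, since it needs two forward passes per parameter. It is still a standard way to validate a hand-written backward pass or a custom autograd op.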

The optimizer step: update the weights

An optimizer takes the gradients and produces a weight update. The basic version is SGD: θ ← θ - η · ∇L, where η is the learning rate. In practice, almost all transformer training uses Adam or AdamW. These keep running estimates of the gradient’s first moment (mean) and second moment (mean of the squared gradient, an uncentered variance) and use them to give each parameter its own adaptive step size.
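The difference between SGD and Adam shows up clearly on a single step. This minimal sketch uses the commonly cited Adam defaults (beta1=0.9, beta2=0.999, eps=1e-8). Two parameters receive gradients that differ by 1000x. SGD’s step scales with the raw gradient. Adam divides by the running RMS of each parameter’s gradient, so on the first step both parameters move by roughly the learning rate.

```python
import math

def sgd_step(params, grads, lr):
    # theta <- theta - lr * grad
    return [p - lr * g for p, g in zip(params, grads)]

class Adam:
    def __init__(self, n_params, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps
        self.m = [0.0] * n_params   # first moment: running mean of gradients
        self.v = [0.0] * n_params   # second moment: running mean of squared gradients
        self.t = 0

    def step(self, params, grads):
        self.t += 1
        out = []
        for i, (p, g) in enumerate(zip(params, grads)):
            self.m[i] = self.beta1 * self.m[i] + (1 - self.beta1) * g
            self.v[i] = self.beta2 * self.v[i] + (1 - self.beta2) * g * g
            # Bias correction, because m and v start at zero
            m_hat = self.m[i] / (1 - self.beta1 ** self.t)
            v_hat = self.v[i] / (1 - self.beta2 ** self.t)
            # Effective per-parameter step: lr / (sqrt(v_hat) + eps)
            out.append(p - self.lr * m_hat / (math.sqrt(v_hat) + self.eps))
        return out


params, grads = [1.0, 1.0], [1000.0, 1.0]   # gradients differ by 1000x
print(sgd_step(params, grads, lr=1e-3))     # the large gradient dominates
print(Adam(2).step(params, grads))          # both move by roughly lr on step 1
```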

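AdamW is the variant in which weight decay is applied directly to the parameters (decoupled) rather than added to the gradient as an L2 penalty, where Adam’s per-parameter scaling would also rescale it. Empirically, it is more stable for language models. Lion (a sign-based update with one state buffer instead of two, so less memory) and Shampoo (a preconditioned, approximately second-order method with better-conditioned updates) are newer alternatives that some labs use in place of AdamW.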
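Applying the decay "directly to the parameters" versus folding it into the gradient is a real behavioral difference. The sketch below shows it on a single scalar with the loss gradient set to zero, so only weight decay acts. Hyperparameters are illustrative. The decoupled branch multiplies the weight by (1 - lr * wd) before the adaptive step, which is how PyTorch documents its AdamW. With L2 folded in, Adam normalizes the decay term away, so the weight shrinks by roughly lr per step regardless of its size. With decoupled decay, it shrinks geometrically.

```python
import math

def adam_like(theta, grad_fn, steps, lr=0.01, wd=0.1, decoupled=False,
              beta1=0.9, beta2=0.999, eps=1e-8):
    m = v = 0.0
    for t in range(1, steps + 1):
        g = grad_fn(theta)
        if decoupled:
            theta *= 1 - lr * wd          # AdamW: decay the weight directly
        else:
            g = g + wd * theta            # Adam + L2: decay flows through m and v
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
        m_hat = m / (1 - beta1 ** t)
        v_hat = v / (1 - beta2 ** t)
        theta -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta

def no_loss_grad(theta):
    return 0.0                            # isolate the effect of weight decay


adam_l2 = adam_like(1.0, no_loss_grad, steps=100)
adamw = adam_like(1.0, no_loss_grad, steps=100, decoupled=True)
print(adam_l2, adamw)
```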

The schedule: learning rate over time

A flat learning rate rarely works well for large models. A common schedule is warmup-then-cosine-decay: ramp the learning rate linearly from 0 to its peak over the first few thousand steps, then decay it along a cosine curve to ~10% of peak by the end of training. Warmup prevents instability while the optimizer’s variance estimates are still noisy; cosine decay lets the model settle into a low-loss region without overshooting.
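The schedule fits in a single function. This is a minimal sketch assuming linear warmup and a floor at 10% of peak, matching the description above. The peak learning rate, warmup length, and total steps in the usage lines are illustrative, not taken from any specific run.

```python
import math

def lr_at(step, peak_lr, warmup_steps, total_steps, final_ratio=0.1):
    """Linear warmup from 0 to peak_lr, then cosine decay to final_ratio * peak_lr."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(progress, 1.0)                      # hold the floor after the end
    cosine = 0.5 * (1 + math.cos(math.pi * progress))  # goes from 1 down to 0
    min_lr = final_ratio * peak_lr
    return min_lr + (peak_lr - min_lr) * cosine


for s in [0, 1000, 2000, 51_000, 100_000]:
    print(s, lr_at(s, peak_lr=3e-4, warmup_steps=2000, total_steps=100_000))
```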

Variants and trade-offs

  • SGD + momentum. Small memory footprint (one extra buffer per parameter), well understood theoretically, works well for vision CNNs. Slow to converge on transformers, and doesn’t adapt to per-parameter gradient scale.
  • AdamW. Adaptive per-parameter learning rate, fast convergence on transformers, the de facto choice for LLM pretraining. Twice the optimizer state of SGD + momentum (two extra buffers per parameter: first and second moments). Lion (one buffer, so less memory) and Shampoo (larger preconditioner state, better-conditioned updates) are more recent alternatives that trade memory against update quality in different ways.
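The memory comparison comes down to how many buffers each optimizer keeps per parameter. The sketch below shows two single-buffer optimizers. One is SGD with momentum, in the form buf = mu * buf + g. The other is Lion, which the list names as a newer alternative. Its update is written the way we understand it from the Lion paper: step by the sign of an interpolated momentum, then update the momentum with a second beta. Treat both classes as illustrative sketches, not library implementations. The hyperparameter values are assumptions.

```python
def sign(x):
    return (x > 0) - (x < 0)

class SGDMomentum:
    def __init__(self, n, lr=0.1, momentum=0.9):
        self.lr, self.mu = lr, momentum
        self.buf = [0.0] * n          # the single extra buffer (velocity)

    def step(self, params, grads):
        for i, g in enumerate(grads):
            self.buf[i] = self.mu * self.buf[i] + g
            params[i] -= self.lr * self.buf[i]
        return params

class Lion:
    def __init__(self, n, lr=1e-4, beta1=0.9, beta2=0.99, wd=0.0):
        self.lr, self.b1, self.b2, self.wd = lr, beta1, beta2, wd
        self.m = [0.0] * n            # also a single buffer (momentum)

    def step(self, params, grads):
        for i, g in enumerate(grads):
            c = self.b1 * self.m[i] + (1 - self.b1) * g
            # Sign update: every coordinate moves by exactly lr (plus decay)
            params[i] -= self.lr * (sign(c) + self.wd * params[i])
            self.m[i] = self.b2 * self.m[i] + (1 - self.b2) * g
        return params


grads = [1000.0, 0.001]               # wildly different gradient scales
print(Lion(2, lr=1e-4).step([1.0, 1.0], grads))   # both coordinates move by lr
```

Adam and AdamW need a second buffer (the squared-gradient average) on top of this, which is where their extra optimizer memory comes from.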

Other axes that matter:

  • Mixed precision. bfloat16 forward and backward passes with fp32 master weights are the default. Activation memory and bandwidth are roughly halved, typically with no measurable quality loss on modern hardware (H100, TPU v5). fp16 is older and needs loss scaling to keep small gradients from underflowing; bfloat16 has the same exponent range as fp32 and “just works”.
  • Gradient accumulation. When the batch you want doesn’t fit in memory, accumulate gradients over several micro-batches and take one optimizer step at the end. If each micro-batch is weighted by its share of the batch, this is equivalent to the larger batch. The cost is that micro-batches run sequentially, so each optimizer step takes longer in wall-clock time.
  • Gradient clipping. Cap the global gradient norm (typically at 1.0) to prevent rare exploding-gradient batches from destabilizing the model. Cheap, ubiquitous, occasionally load-bearing.
  • Activation checkpointing. Trade compute for memory: don’t save activations on the forward, recompute them on the backward. Roughly 30% more compute, often >50% memory saved. Essential at large scale.
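Gradient accumulation and clipping are both simple arithmetic on gradients. The minimal sketch below assumes a hypothetical scalar linear model y_hat = w * x with mean-squared-error loss. It checks that weighting each micro-batch gradient by its fraction of the batch reproduces the full-batch gradient, including with uneven micro-batches. It also shows global-norm clipping, which rescales all gradients together when their combined L2 norm exceeds the threshold. Function names are illustrative.

```python
import math

def grad_mse(w, xs, ys):
    # d/dw of mean((w * x - y)^2) over the batch
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

def accumulated_grad(w, xs, ys, micro_batch_size):
    n = len(xs)
    total = 0.0
    for start in range(0, n, micro_batch_size):
        mx = xs[start:start + micro_batch_size]
        my = ys[start:start + micro_batch_size]
        # Weight each micro-batch mean by its fraction of the full batch
        total += grad_mse(w, mx, my) * len(mx) / n
    return total

def clip_by_global_norm(grads, max_norm=1.0):
    norm = math.sqrt(sum(g * g for g in grads))
    if norm > max_norm:
        scale = max_norm / norm       # same scale for every gradient
        grads = [g * scale for g in grads]
    return grads, norm


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
ys = [3.0, 5.0, 7.0, 9.0, 11.0, 13.0]
full = grad_mse(0.5, xs, ys)
accum = accumulated_grad(0.5, xs, ys, micro_batch_size=4)  # micro-batches of 4 and 2
print(full, accum)                    # equal up to float rounding

clipped, norm = clip_by_global_norm([3.0, 4.0], max_norm=1.0)
print(norm, clipped)                  # rescaled to unit norm
```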

Why 'just train longer' isn't a complete strategy

The Chinchilla scaling laws (DeepMind, 2022) showed that for a given compute budget, the optimal ratio of parameters to training tokens is roughly 1:20 — a 70B-parameter model should see about 1.4T tokens. Earlier models (GPT-3, 175B trained on 300B tokens) were under-trained: they had more parameters than their compute budget could optimally use. Today’s frontier models are often trained well past Chinchilla-optimal (Llama-3 70B on 15T tokens) because inference cost matters too — a smaller model trained longer is cheaper to serve. The optimizer loop is the same; the budgeting around it changed.
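The budgeting can be checked with back-of-the-envelope arithmetic. This sketch assumes the ~20 tokens-per-parameter ratio quoted above, plus the common approximation that training compute is C ≈ 6 · N · D FLOPs (N parameters, D training tokens). The 6ND rule is an assumption brought in for illustration, not something stated in this article.

```python
TOKENS_PER_PARAM = 20   # Chinchilla-style ratio quoted in the text

def chinchilla_tokens(n_params):
    return TOKENS_PER_PARAM * n_params

def training_flops(n_params, n_tokens):
    # Common approximation: ~6 FLOPs per parameter per training token
    return 6 * n_params * n_tokens

def tokens_per_param(n_params, n_tokens):
    return n_tokens / n_params


n = 70e9
print(f"{chinchilla_tokens(n):.2e} tokens")   # 1.40e+12 tokens
print(f"{training_flops(n, chinchilla_tokens(n)):.2e} FLOPs")
print(tokens_per_param(175e9, 300e9))   # GPT-3: well under 20
print(tokens_per_param(70e9, 15e12))    # Llama-3 70B: roughly 10x past 20
```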

When this is asked in interviews

This question is the gating “do you know how neural networks actually work” check on ML-engineering loops. It also appears on AI-platform loops when the team has to maintain training infrastructure.

What the interviewer is checking:

  1. Can you walk the forward-loss-backward-optimizer-step loop without notes?
  2. Do you know the difference between SGD and Adam well enough to pick one for a real situation?
  3. Can you reason about a failing training run from gradient-norm and loss-curve evidence: overfitting vs underfitting vs unstable optimizer vs bad data?

Common follow-ups:

  • “What happens if the learning rate is too high?” Loss oscillates, gradient norms spike, and a NaN often appears within a few hundred steps. Lower the LR, add warmup, clip gradients.
  • “How do batch size and learning rate interact?” Larger batches give lower-variance gradient estimates, which allows a higher learning rate. A common heuristic for SGD is the linear-scaling rule: double the batch, double the LR. It holds only within limits; past a critical batch size the relationship breaks down.
  • “Why does Adam need so much more memory than SGD?” It keeps first- and second-moment buffers, each the size of the model. For a 70B-parameter model with both buffers in bfloat16, that’s an extra ~280GB just for optimizer state. With fp32 buffers, as is common in mixed-precision training, it’s ~560GB.
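The memory and batch-size answers reduce to a few lines of arithmetic. This sketch assumes two Adam buffers of one entry per parameter, 2 bytes per entry for bfloat16 and 4 for fp32, and decimal gigabytes. The linear-scaling helper is the heuristic described above, with illustrative numbers. It is a starting point for tuning, not a guarantee.

```python
BYTES = {"bf16": 2, "fp32": 4}

def adam_state_gb(n_params, dtype="fp32"):
    # Two buffers (m and v), one entry per parameter each
    return 2 * n_params * BYTES[dtype] / 1e9

def linear_scaled_lr(base_lr, base_batch, new_batch):
    # Linear-scaling heuristic: LR grows in proportion to batch size.
    # Breaks down past the critical batch size.
    return base_lr * new_batch / base_batch


print(adam_state_gb(70e9, "bf16"))   # 280.0
print(adam_state_gb(70e9, "fp32"))   # 560.0
print(linear_scaled_lr(3e-4, 256, 512))
```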