Building Context with Neurons (RNNs)

Vanilla recurrent networks: sequential context, the gradient problem, why they fail past ~50 tokens.

Architecture Intermediate
10 min read
architecture rnn sequence-models foundational-paper

Origin and intuition

Feed-forward networks have a fixed-shape input. To classify a 28×28 image you flatten it to 784 numbers and feed that vector to the first layer. To classify a sentence you have a problem: sentences are variable-length, and the meaning of word number 17 depends on words 1 through 16. You could pad to a max length and treat every position as an independent input — that throws away order. You could concatenate everything into one long vector — the network has no way to know that two patterns at different offsets are the same pattern.

Recurrent neural networks (RNNs), in the form Elman proposed in 1990, fixed this by making the network stateful. At each step t, the network reads input x_t and a hidden state h_{t-1} produced at the previous step, computes a new hidden state h_t, and emits an output y_t. The same weights are reused at every step. The hidden state is a learned summary of “everything I’ve seen so far,” and the output at step t can depend on the entire history compressed into that vector.

This was the dominant architecture for sequence modeling from roughly 1990 to 2017. Word-level language models, character-level language models, machine translation (the original sequence-to-sequence paper used LSTMs, which are gated RNNs), speech recognition, and handwriting generation all ran on recurrent backbones. The intuition is appealing (humans read left-to-right, accumulating meaning), and the parameter count is small because the same weights are reused at every step.

The architecture has two structural problems that limit how far it can scale. Training is sequential because step t depends on step t-1 — you can’t parallelize across positions, which became a bigger and bigger handicap as GPUs widened. And the hidden state is a fixed-size vector trying to summarize an arbitrarily long history, which empirically degrades past a few dozen tokens. The transformer paper of 2017 was, more than anything, an answer to these two constraints.

Inputs and outputs

A vanilla RNN consumes a sequence x_1, x_2, ..., x_n of input vectors (one per timestep) and produces a sequence h_1, h_2, ..., h_n of hidden states, plus optional outputs y_1, ..., y_n. The input vectors are typically word or character embeddings looked up from an embedding table.

Four canonical usage modes:

  • One-to-many. One input, sequence of outputs. Image captioning: feed an image embedding to h_0, generate a caption autoregressively.
  • Many-to-one. Sequence of inputs, one output. Sentiment classification: read the sentence, take h_n (the final hidden state), feed it to a classifier head.
  • Many-to-many, aligned. Sequence in, sequence out, same length. Named entity recognition, part-of-speech tagging: predict a label per token.
  • Many-to-many, unaligned. Sequence in, sequence out, different lengths. Machine translation: read the source sentence, then start producing the target. This is what the encoder-decoder framework was built for.

For language modeling specifically, the setup is many-to-many aligned: at each position t, the RNN predicts a distribution over the next token given everything seen so far.
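The aligned setup for language modeling amounts to shifting the sequence by one position. A minimal sketch, assuming tokens arrive as a plain Python list; `next_token_pairs` is a hypothetical helper name, not part of any library:

```python
def next_token_pairs(tokens):
    """Return (inputs, targets) where targets[t] is the token that follows inputs[t]."""
    if len(tokens) < 2:
        raise ValueError("need at least two tokens to form one prediction")
    return tokens[:-1], tokens[1:]


inputs, targets = next_token_pairs(["the", "cat", "sat", "down"])
for x, y in zip(inputs, targets):
    # At position t the model has read inputs[0..t] and is asked for targets[t].
    print(f"after reading {x!r}, predict {y!r}")
```

Every position contributes one prediction, which is why a sequence of n tokens yields n-1 training examples in this setup.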

Architecture diagram

The defining feature is the loop. Unrolled across timesteps, an RNN looks like a feed-forward network with shared weights and a hidden state threaded through:

             ┌──────────────────────────┐
             │  shared weights W, U, V  │
             └──────────────────────────┘
  x_1           x_2           x_3                 x_n
   │             │             │                   │
   ▼             ▼             ▼                   ▼
 ┌────┐  h_1   ┌────┐  h_2   ┌────┐  h_3         ┌────┐
 │ U  │───────▶│ U  │───────▶│ U  │─── ... ─────▶│ U  │
 │ +W │        │ +W │        │ +W │              │ +W │
 └────┘        └────┘        └────┘              └────┘
   │             │             │                   │
   ▼             ▼             ▼                   ▼
   V             V             V                   V
   │             │             │                   │
   ▼             ▼             ▼                   ▼
  y_1           y_2           y_3                 y_n

At each step, the cell computes:

h_t = tanh(W · x_t + U · h_{t-1} + b_h)
y_t = V · h_t + b_y

Three weight matrices, reused at every timestep: W (input projection), U (recurrence), V (output projection). The activation is typically tanh (or ReLU, which has its own problems for RNNs). The hidden state h_t is a vector of size d_h, typically 128 to 1024.
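The two equations translate almost line for line into code. The following is a minimal sketch in pure Python, assuming small dense matrices stored as lists of rows, a zero initial hidden state h_0, and random toy weights; the helper names (`matvec`, `rnn_step`, `rnn_forward`, `rand_matrix`) and the dimensions are illustrative choices, not taken from any library.

```python
import math
import random


def matvec(M, v):
    """Multiply matrix M (list of rows) by vector v."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def rnn_step(x_t, h_prev, W, U, b_h):
    """One recurrence step: h_t = tanh(W x_t + U h_{t-1} + b_h)."""
    wx = matvec(W, x_t)
    uh = matvec(U, h_prev)
    return [math.tanh(a + b + c) for a, b, c in zip(wx, uh, b_h)]


def rnn_forward(xs, W, U, V, b_h, b_y):
    """Run the cell over a sequence, reusing W, U, V at every timestep."""
    h = [0.0] * len(b_h)  # assumption: h_0 is the zero vector
    hs, ys = [], []
    for x_t in xs:
        h = rnn_step(x_t, h, W, U, b_h)
        y = [a + b for a, b in zip(matvec(V, h), b_y)]  # y_t = V h_t + b_y
        hs.append(h)
        ys.append(y)
    return hs, ys


def rand_matrix(rows, cols, rng, scale=0.1):
    return [[rng.uniform(-scale, scale) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
d_in, d_h, d_out, n = 3, 4, 2, 5  # toy sizes; real d_h is far larger
W = rand_matrix(d_h, d_in, rng)
U = rand_matrix(d_h, d_h, rng)
V = rand_matrix(d_out, d_h, rng)
b_h, b_y = [0.0] * d_h, [0.0] * d_out
xs = [[rng.uniform(-1, 1) for _ in range(d_in)] for _ in range(n)]

hs, ys = rnn_forward(xs, W, U, V, b_h, b_y)
print(len(hs), len(hs[0]), len(ys[0]))
```

A many-to-one model would read only `hs[-1]`; a tagger or language model would use every `ys[t]`. The Python loop over `xs` is the sequential dependency that cannot be parallelized across positions.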

The horizontal arrow carrying h from one step to the next is what makes this recurrent. The vertical arrows are the same as a feed-forward layer. The shared weights are why the model can generalize across sequence positions: a pattern recognized at position 5 will also be recognized at position 50, because the same W and U process both.

Training objective

For language modeling, the objective is next-token prediction with cross-entropy loss, summed over positions:

L = - Σ_t log P(x_{t+1} | x_1, ..., x_t)

This is the same objective decoder-only transformers use today — only the architecture computing the probabilities changed.
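As a worked example of the loss, the sketch below sums the negative log-probability of each target token. It assumes the model emits one vector of unnormalized scores (logits) per position and that targets are integer vocabulary indices; `log_softmax` and `sequence_nll` are hypothetical helper names.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one vector of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_z for z in logits]


def sequence_nll(logits_per_step, target_ids):
    """L = - sum_t log P(target_t | prefix), with P given by softmax(logits_t)."""
    total = 0.0
    for logits, target in zip(logits_per_step, target_ids):
        total -= log_softmax(logits)[target]
    return total


# Sanity check: with uniform logits over a 4-token vocabulary,
# every position costs log(4), whatever the target is.
loss = sequence_nll([[0.0] * 4] * 3, [1, 0, 3])
print(loss, 3 * math.log(4))
```

An untrained model should score close to n · log(vocabulary size), which makes this a useful check on a fresh training run.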

Training uses backpropagation through time (BPTT): unroll the RNN across the entire sequence, compute the loss, and backpropagate. Because the same weights are used at every timestep, gradients accumulate: the gradient of the loss with respect to W is a sum of contributions from every position. This is what lets the model learn long-range patterns in principle.

In practice, full BPTT across a long document is prohibitive — you’d hold the entire activation history in memory. Truncated BPTT chunks the sequence into segments (typically 35-100 tokens) and backpropagates within each chunk while preserving the hidden state across chunks. This is a compromise: the forward pass sees long context, but gradients only flow back a chunk’s worth.
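The chunking itself is simple bookkeeping. The sketch below shows only how a token stream could be cut into truncated-BPTT windows; it assumes integer token ids in a list, and `tbptt_chunks` is a hypothetical helper. The training step is described in comments because it depends on the autograd framework in use.

```python
def tbptt_chunks(tokens, chunk_len):
    """Yield (inputs, targets) windows holding at most chunk_len next-token pairs."""
    last_input = len(tokens) - 1  # the final token has no successor to predict
    for start in range(0, last_input, chunk_len):
        end = min(start + chunk_len, last_input)
        yield tokens[start:end], tokens[start + 1:end + 1]


tokens = list(range(10))
chunks = list(tbptt_chunks(tokens, chunk_len=4))
for inputs, targets in chunks:
    # In a training loop (not shown):
    #   1. run the RNN forward over `inputs`, starting from the carried hidden state
    #   2. backpropagate the loss within this window only
    #   3. keep the final hidden state's value, but cut its gradient history,
    #      before moving to the next window
    print(inputs, targets)
```

Step 3 is the compromise described above: the value of the hidden state crosses the chunk boundary, its gradient does not.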

The deeper problem is what happens to the gradients as they flow backward through many timesteps.

Variants and refinements

Alongside LSTMs, the field produced a small zoo of recurrent variants, several of them attempts to sidestep the vanishing-gradient problem without gates:

  • Vanilla RNN (Elman). The setup above. Trains, sort of. Empirically struggles past ~10-20 timesteps.
  • Jordan networks. Feed the output back as the recurrent signal instead of the hidden state. Slightly different inductive bias; never caught on.
  • Bidirectional RNNs. Run one RNN left-to-right, another right-to-left, concatenate the hidden states. Used for tagging tasks where the whole sequence is available at inference time. The BERT-style bidirectional idea predates BERT by two decades, just with worse architectures.
  • Echo state networks / reservoir computing. Randomly initialize the recurrent matrix, freeze it, only train the output layer. Cheap to train, surprisingly competitive on small problems. A theoretical curiosity now.
  • Identity-init RNNs (IRNN). Initialize U to the identity and use ReLU. Proposed in 2015; trains better than a tanh RNN without adding gates, but never displaced LSTMs/GRUs.

Vanilla RNN: three matrices, one nonlinearity, hidden state fully overwritten at every step. Tiny parameter count. Useful context: ~10-50 tokens. Vanishing gradients dominate past that.
LSTM / GRU: gated cell with an explicit memory pathway. Roughly 4× (LSTM) or 3× (GRU) the recurrent parameters of a vanilla RNN at the same hidden size. Useful context: hundreds of tokens. The gated, additive path through the state is the entire difference.
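The parameter ratio can be checked with simple arithmetic. The sketch below counts only the recurrent cell (input matrix, recurrent matrix, one bias vector per gate) and leaves out the embedding table and output projection. It assumes the textbook formulations with four weight blocks for an LSTM and three for a GRU; some libraries keep two bias vectors per gate, which changes the totals slightly. The sizes 256 and 512 are illustrative.

```python
def vanilla_rnn_params(d_in, d_h):
    """W (d_h x d_in) + U (d_h x d_h) + b_h (d_h)."""
    return d_h * d_in + d_h * d_h + d_h


def lstm_params(d_in, d_h):
    """Input, forget, and output gates plus the candidate: four blocks of the same shape."""
    return 4 * vanilla_rnn_params(d_in, d_h)


def gru_params(d_in, d_h):
    """Update and reset gates plus the candidate: three blocks of the same shape."""
    return 3 * vanilla_rnn_params(d_in, d_h)


d_in, d_h = 256, 512  # illustrative sizes
for name, count in [
    ("vanilla RNN", vanilla_rnn_params(d_in, d_h)),
    ("GRU", gru_params(d_in, d_h)),
    ("LSTM", lstm_params(d_in, d_h)),
]:
    print(f"{name:12s} {count:>10,d}")
```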

Practical considerations

The vanishing / exploding gradient problem. When you backpropagate through a chain of n recurrent steps, the gradient of the loss with respect to a weight involves a product of n Jacobians of the recurrence. If the largest singular value of the recurrent Jacobian is consistently below 1, this product shrinks geometrically: the gradient vanishes, and the model can’t learn long-range dependencies. If the Jacobian consistently stretches the gradient by a factor above 1, the product grows geometrically: the gradient explodes, and training diverges.
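The geometric behaviour is easiest to see in one dimension, where the Jacobian is a single number. The sketch below is a toy illustration, not a statement about any trained model: it assumes a scalar recurrence h_t = tanh(u · h_{t-1}) with no input term, plus a purely linear variant for the exploding case. The function name and the values 0.9, 1.1, and 50 steps are arbitrary choices.

```python
import math


def grad_h_n_wrt_h_0(u, n_steps, h0=0.5, use_tanh=True):
    """Chain-rule product d h_n / d h_0 for a scalar recurrence.

    use_tanh=True:  h_t = tanh(u * h_{t-1}), one-step Jacobian u * (1 - h_t^2)
    use_tanh=False: h_t = u * h_{t-1},       one-step Jacobian u
    """
    h, grad = h0, 1.0
    for _ in range(n_steps):
        pre = u * h
        if use_tanh:
            h = math.tanh(pre)
            grad *= u * (1.0 - h * h)
        else:
            h = pre
            grad *= u
    return grad


vanishing = grad_h_n_wrt_h_0(0.9, 50)                   # every factor is at most 0.9
exploding = grad_h_n_wrt_h_0(1.1, 50, use_tanh=False)   # every factor is exactly 1.1
print(f"tanh,   u=0.9, 50 steps: {vanishing:.3e}")
print(f"linear, u=1.1, 50 steps: {exploding:.3e}")
```

In the tanh case each factor is u times a number in (0, 1], so the product can only be smaller than u to the power n. In the matrix case the same argument goes through with norms of Jacobians in place of scalars.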

Bengio, Simard and Frasconi proved in 1994 that with bounded activations like tanh, vanilla RNNs can’t simultaneously be stable in the forward pass and propagate gradients across long chains. The signal has to decay. This is not a hyperparameter problem; it’s a structural one. It motivated the LSTM design, published three years later.

Mitigations that partially help:

  • Gradient clipping by global norm. The standard fix for exploding gradients: rescale the whole gradient whenever its global norm exceeds a threshold, commonly 1 or 5.
  • Careful initialization. Orthogonal initialization of U sets all of its singular values to 1 at init, slowing the decay. Helps a little.
  • Skip connections. Identity shortcuts through time. Anticipates the LSTM cell-state highway.
  • Layer normalization on the recurrent step. Stabilizes training. Doesn’t solve the long-range problem.
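Of the mitigations listed above, clipping by global norm is the simplest to write down. The sketch below assumes gradients are held as a list of flat lists of floats, one per parameter tensor; `clip_by_global_norm` is a hypothetical helper, and deep-learning frameworks ship their own equivalents.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Rescale all gradients together so their joint L2 norm is at most max_norm.

    Returns (clipped_grads, original_global_norm).
    """
    total = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    if total <= max_norm:
        return [list(t) for t in grads], total  # already small enough: unchanged copy
    scale = max_norm / total
    return [[g * scale for g in t] for t in grads], total


grads = [[3.0, 4.0], [12.0]]  # global norm = sqrt(9 + 16 + 144) = 13
clipped, norm_before = clip_by_global_norm(grads, max_norm=5.0)
norm_after = math.sqrt(sum(g * g for t in clipped for g in t))
print(norm_before, norm_after)
```

Because every tensor is scaled by the same factor, the direction of the update is preserved and only its length changes. That makes clipping effective against explosions but of no help when gradients vanish.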

Useful context length. Empirically, vanilla RNNs handle ~10-50 tokens of meaningful context on language tasks. Past that, performance plateaus regardless of hidden-state size. LSTMs and GRUs push this to hundreds of tokens by giving gradients a gated, additive path through time. Transformers blew past both because attention is a one-hop lookup: no Jacobian product through time at all.

Training cost shape. The sequential nature means you can’t parallelize across the sequence. You can parallelize across the batch dimension. So the compute per step scales with (batch size × hidden dim²), and a single sequence’s processing time scales linearly with sequence length because each step has to wait for the previous one. On a GPU built for large parallel matrix multiplies this is an inefficient use of silicon: much of the chip sits idle waiting for the previous step. The transformer’s O(n²) attention is a worse FLOPs cost in theory but a much better fit for actual hardware.
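A back-of-the-envelope comparison makes the trade-off concrete. The sketch below uses rough per-layer counts for a single sequence: n · d² multiply-adds for the recurrent matrix product, and 2 · n² · d for attention scores plus the weighted sum. These are simplifying assumptions that ignore biases, input and output projections, and the softmax, and the sizes n = 4096, d = 512 are illustrative.

```python
def recurrent_layer_cost(n, d):
    """Rough cost of one recurrent layer over n tokens with hidden size d."""
    return {
        "mult_adds": n * d * d,   # one d x d matrix-vector product per step
        "sequential_steps": n,    # step t cannot start before step t-1 finishes
    }


def self_attention_cost(n, d):
    """Rough cost of one self-attention layer over n tokens with width d."""
    return {
        "mult_adds": 2 * n * n * d,  # n x n scores, then n x n weighted sum of values
        "sequential_steps": 1,       # all positions are computed together
    }


n, d = 4096, 512  # illustrative sizes
rnn = recurrent_layer_cost(n, d)
attn = self_attention_cost(n, d)
print("RNN      :", rnn)
print("attention:", attn)
print("attention / RNN mult-adds:", attn["mult_adds"] / rnn["mult_adds"])
```

Under these assumptions attention does 16 times the arithmetic but finishes in one parallel pass instead of 4096 dependent steps, which is the sense in which it fits the hardware better.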

Real-world deployments

Vanilla RNNs are essentially historical now for any task with sequences longer than a few dozen tokens. Where you find them:

  • Pedagogical implementations. Karpathy’s char-rnn (2015) is still one of the most-linked introductions to language modeling. It defaults to an LSTM rather than a vanilla RNN, but the structure generalizes.
  • Tiny embedded models. Keyword-spotting ("Hey Siri", "Alexa") on microcontrollers sometimes uses small RNNs because the constant per-step memory beats transformers’ KV cache for streaming audio at sub-watt power budgets.
  • Some classical NLP libraries. Older versions of spaCy used bidirectional RNNs for tagging; current versions use transformers but still ship the RNN paths.
  • Reinforcement-learning policies. Older Atari-era RL papers used RNN policy networks (e.g., A3C with LSTM) to handle partial observability. Replaced by transformer policies in most recent work.
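The memory argument in the keyword-spotting item above comes down to what has to be kept between steps. The sketch below counts stored floats only; it assumes a transformer that caches keys and values for every past token with no sliding window, and the function names and sizes are illustrative.

```python
def rnn_state_floats(d_h, n_layers=1):
    """Floats kept between steps by an RNN: one hidden vector per layer."""
    return n_layers * d_h


def kv_cache_floats(n_tokens, d_model, n_layers=1):
    """Floats kept by an uncapped KV cache: a key and a value per past token per layer."""
    return 2 * n_layers * n_tokens * d_model


d = 128  # illustrative width
for n_tokens in (10, 100, 1000):
    print(n_tokens, rnn_state_floats(d), kv_cache_floats(n_tokens, d))
```

The RNN figure does not depend on how many frames have been processed, while the cache grows linearly with them. Streaming transformer deployments usually cap the cache with a fixed window for this reason.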

The 2014 sequence-to-sequence paper using stacked LSTMs for translation, the direct ancestor of modern encoder-decoder architectures, was the high-water mark of recurrent NLP. By 2017 attention had reduced LSTMs to “the thing under the attention layer”, and from 2018 on the transformer displaced them in most state-of-the-art NLP systems.

Why this architecture matters even though nobody trains vanilla RNNs anymore

Three reasons. First, the vanishing-gradient analysis from RNNs is a foundational result on why deep networks are hard to train; later designs (LSTMs, residual connections, layer norm, transformers’ pre-norm) are in some sense responses to it. Second, the language-modeling objective P(x_{t+1} | x_1, ..., x_t) predates neural networks (n-gram models approximate it with a truncated history), but RNNs were the first neural architecture to condition on the full prefix; modern decoder-only transformers compute the exact same probability with a different function approximator. Third, state-space and linear-recurrence models (Mamba, RWKV) are essentially modern RNNs with better gradient flow and parallelizable training, and long-convolution models such as Hyena are close relatives. The shape is back, just not the 1990 version of it. Understanding what broke in vanilla RNNs is the cleanest way to understand why the things that replaced them are shaped the way they are.
