Towards Long-Context Memory
Thank you to Dhruv Pai and Ben Keigwin for discussions that have contributed to this post.
Introduction
Modern-day Large Language Models (LLMs) are variants of the transformer architecture, which alternates sequence mixers (attention blocks) and feature mixers (MLPs). Attention moves information across tokens: a current token's query is compared against earlier keys, those similarities weight the corresponding values, and the weighted sum is added to the residual stream. MLPs, by contrast, only transform each token within its hidden dimension. Because attention computes pairwise similarities across the context, it is the primary computational bottleneck as sequence length grows. Long-context computation is essential for problem-solving tasks and code generation, where chain-of-thought models thrive.
This is one reason why LLMs have a maximum context length, typically ~250k tokens for frontier models. The other reason is that many models use Rotary Positional Embeddings (RoPE) to encode the position of a token by applying a rotation in the complex plane to the query and key vectors. However, because such rotations are inherently sinusoidal (i.e. they repeat with a period), past a certain context length the phase transformations begin to wrap. Approaches like YaRN gradually extend RoPE to longer context lengths by rescaling position indices, but they do not change the quadratic time/space cost of vanilla attention.
In autoregressive generation, KV caching — storing the newly generated keys and values at each token, and only comparing the new query with previously cached keys — decreases the per-step cost from $O(t^2)$ to $O(t)$. However, we still run into issues: for large $t$, there isn't enough VRAM to cache an ever-growing number of high-dimensional key and value vectors. Moreover, generating $n$ tokens still takes $O(n^2)$ total time, so KV caching alone isn't enough.
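To make the bookkeeping concrete, here is a purely illustrative decode step with a KV cache (toy Python lists, not a real model; `attend_with_cache` is a hypothetical name): each step appends one key-value pair and scores the new query against every cached key, so step $t$ costs $O(t)$ rather than recomputing full attention.

```python
# Minimal sketch of one KV-cached decode step (illustrative, not a real model).
import math

def attend_with_cache(q, k_new, v_new, cache):
    """Append (k_new, v_new) to the cache, then softmax-attend q over all cached keys."""
    cache.append((k_new, v_new))
    scores = [sum(qi * ki for qi, ki in zip(q, k)) for k, _ in cache]  # O(t) dot products
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]  # numerically stable softmax
    z = sum(weights)
    d = len(v_new)
    return [sum(w * v[j] for w, (_, v) in zip(weights, cache)) / z for j in range(d)]

cache = []
out = attend_with_cache([1.0, 0.0], [1.0, 0.0], [2.0, 3.0], cache)
# With a single cached pair, the output is exactly that value vector: [2.0, 3.0]
```

The cache grows by one entry per step, which is exactly the VRAM pressure described above.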
Linear Attention
Instead of caching every individual key and value, linear attention stores a fixed-size matrix $S_t$ that compresses all of the important past context. Concretely, if our keys are $k_1, \dots, k_t$ and our values are $v_1, \dots, v_t$, we can multiply $S_t$ by our computed query $q_t$ to derive the output $o_t = S_t q_t$. This yields a simple calculation that does not become more computationally intensive as $t$ increases. But how is $S_t$ updated?
Vanilla attention uses
$$o_t = \frac{\sum_{i=1}^{t} \exp(q_t^\top k_i)\, v_i}{\sum_{i=1}^{t} \exp(q_t^\top k_i)}$$
to generate the output for the next token. To turn this into the form $o_t = S_t\, \phi(q_t)$ for some updating matrix $S_t$, we want a feature map $\phi$ such that $\exp(q^\top k) \approx \phi(q)^\top \phi(k)$. This factorization rewrites the attention calculation as
$$o_t = \frac{\left( \sum_{i=1}^{t} v_i\, \phi(k_i)^\top \right) \phi(q_t)}{\left( \sum_{i=1}^{t} \phi(k_i) \right)^\top \phi(q_t)},$$
which means that a prefix sum now enables $O(1)$ updates to the state matrix $S_t = \sum_{i=1}^{t} v_i\, \phi(k_i)^\top$. A simple example of such a $\phi$ is the identity map.
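As a sanity check on this recurrence, here is a toy pure-Python sketch with the identity feature map: the state accumulates rank-one updates $v_i k_i^\top$, and reading out is a single matrix-vector product. Function names are illustrative.

```python
# Sketch of the linear-attention recurrence with the identity feature map phi(x) = x.
# The state S accumulates sum_i v_i k_i^T; the (unnormalized) output is S q.
def update_state(S, k, v):
    """Rank-one update S <- S + v k^T (S stored as a list of rows, shape d_v x d_k)."""
    return [[S[a][b] + v[a] * k[b] for b in range(len(k))] for a in range(len(v))]

def read_state(S, q):
    """Read-out o = S q."""
    return [sum(S[a][b] * q[b] for b in range(len(q))) for a in range(len(S))]

d_k, d_v = 2, 2
S = [[0.0] * d_k for _ in range(d_v)]
S = update_state(S, [1.0, 0.0], [5.0, 6.0])   # store (k1, v1)
S = update_state(S, [0.0, 1.0], [7.0, 8.0])   # store (k2, v2)
# With orthonormal keys, querying with k1 retrieves v1 exactly:
# read_state(S, [1.0, 0.0]) -> [5.0, 6.0]
```

The state stays $d_v \times d_k$ no matter how many pairs are stored, which is the whole point, and the cause of the interference discussed next.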
However, errors will accumulate with only finite memory. One significant challenge in long-context tasks is for our model to accurately retrieve specific values, since the quantity of information we want to remember increases linearly, but our state matrix remains fixed in size. For simplicity, assume $\phi$ is the identity and ignore the (normalizing) denominator. Then,
$$S_t = \sum_{i=1}^{t} v_i k_i^\top.$$
Suppose we want to retrieve a specific value $v_j$; if the keys are normalized, we can multiply $S_t$ by $k_j$ to obtain $v_j$. However, in practice, multiplying by the matrix also accumulates a residual error given by the second term in
$$S_t k_j = v_j + \sum_{i \neq j} (k_i^\top k_j)\, v_i.$$
Note that any two randomly selected normalized vectors in $d$-dimensional space have an expected dot-product magnitude on the order of $1/\sqrt{d}$, so the retrieval error grows on the order of $\sqrt{t/d}$, which becomes prohibitive for large $t$. This matches empirical findings: gated-convolution linear attention architectures underperform transformers on associative recall tasks (Arora et al., 2023); for example, predicting the next token in "Hakuna Matata means no worries Hakuna Matata it means no _." Linear attention will struggle with this kind of precision across many thousands of tokens.
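The interference argument can be simulated directly. The hypothetical `retrieval_error` below stores $t$ random unit key-value pairs in $S = \sum_i v_i k_i^\top$ and measures $\|S k_0 - v_0\|$; the error grows as more pairs are stored.

```python
# Simulating retrieval interference: store t random unit key-value pairs in
# S = sum_i v_i k_i^T, then measure how far the read-out S k_0 drifts from v_0.
import math, random

random.seed(0)

def unit(d):
    """Random unit vector in d dimensions."""
    x = [random.gauss(0, 1) for _ in range(d)]
    n = math.sqrt(sum(xi * xi for xi in x))
    return [xi / n for xi in x]

def retrieval_error(t, d):
    keys = [unit(d) for _ in range(t)]
    vals = [unit(d) for _ in range(t)]
    dots = [sum(keys[i][b] * keys[0][b] for b in range(d)) for i in range(t)]
    out = [sum(vals[i][a] * dots[i] for i in range(t)) for a in range(d)]  # S k_0
    return math.sqrt(sum((out[a] - vals[0][a]) ** 2 for a in range(d)))

# Error is ~0 with a single stored pair and grows roughly like sqrt(t/d):
# e.g. retrieval_error(1, 64) is ~0, while retrieval_error(500, 64) is not.
```

With $d$ fixed and $t$ growing, the residual term dominates, matching the $\sqrt{t/d}$ scaling above.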
Deriving a Regression
The retrieval objective suggested by the calculation above is: learn a matrix $S$ such that $S k_i \approx v_i$ for past pairs. Thus, to minimize retrieval loss, it seems reasonable to minimize $\sum_i \| S k_i - v_i \|_2^2$, which turns out to be what the Delta update rule minimizes (Yang, 2024). However, it would be nice to extend this more generally to ensure continued recall across all keys and values; the most general form of this would be
$$\min_S \sum_{i=1}^{t} \gamma_i \, \| S k_i - v_i \|_2^2,$$
where $\gamma_i$ are input-dependent parameters.
Why do we care so much about key-value retrieval in the first place? We'll revisit this assumption later, but there is substantial evidence that transformers naturally solve a similar regression during in-context learning. (von Oswald et al., 2023) finds that single attention layers implement one step of gradient descent on this loss, and that deeper Transformers approach multiple gradient steps on the same loss. Moreover, replacing the linear attention layers with a Mesa-layer that exactly solves the inner optimization improves the trained Transformer's in-context learning performance. But instead of multiple sub-optimal gradient steps towards the minimum, models would ideally solve the regression exactly at each step. MesaNet leverages this insight to achieve "optimal test-time regression" by computing the optimal solution to
$$\min_S \sum_{i=1}^{t} \gamma_i \, \| S k_i - v_i \|_2^2 + \lambda \| S \|_F^2$$
at each token. This weighted ridge regression has the closed-form solution
$$S_t = \left( \sum_{i=1}^{t} \gamma_i \, v_i k_i^\top \right) \left( \sum_{i=1}^{t} \gamma_i \, k_i k_i^\top + \lambda I \right)^{-1},$$
which enables per-token updates whose cost is independent of $t$, although despite MesaNet's parallelization, the update runs slightly slower than simpler linear attention rules.
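To illustrate the closed form (a pure-Python sketch, not MesaNet's actual kernel; `solve` and `ridge_readout` are hypothetical names), here is a weighted ridge-regression read-out at one token:

```python
# Sketch of the weighted ridge-regression read-out: given past (k_i, v_i) with weights
# gamma_i, form o = (sum gamma_i v_i k_i^T)(sum gamma_i k_i k_i^T + lam I)^{-1} q.
def solve(A, b):
    """Gauss-Jordan elimination with partial pivoting for A x = b (small dense systems)."""
    n = len(A)
    M = [row[:] + [b[i]] for i, row in enumerate(A)]
    for c in range(n):
        p = max(range(c, n), key=lambda r: abs(M[r][c]))
        M[c], M[p] = M[p], M[c]
        for r in range(n):
            if r != c and M[c][c] != 0:
                f = M[r][c] / M[c][c]
                M[r] = [M[r][j] - f * M[c][j] for j in range(n + 1)]
    return [M[i][n] / M[i][i] for i in range(n)]

def ridge_readout(keys, vals, gammas, q, lam=1e-3):
    d = len(q)
    # G = sum_i gamma_i k_i k_i^T + lam I
    G = [[lam * (a == b) + sum(g * k[a] * k[b] for g, k in zip(gammas, keys))
          for b in range(d)] for a in range(d)]
    x = solve(G, q)  # x = G^{-1} q
    # o = sum_i gamma_i v_i (k_i^T x)
    return [sum(g * v[a] * sum(k[b] * x[b] for b in range(d))
                for g, k, v in zip(gammas, keys, vals)) for a in range(len(vals[0]))]

keys = [[1.0, 0.0], [0.0, 1.0]]
vals = [[3.0, 4.0], [5.0, 6.0]]
o = ridge_readout(keys, vals, [1.0, 1.0], [1.0, 0.0], lam=0.0)
# With orthonormal keys and lam = 0, querying k_1 recovers v_1 exactly: [3.0, 4.0]
```

In a real implementation the two sums are maintained incrementally per token; only the solve is redone, which is why the per-token cost does not grow with $t$.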
So far, the memory has been stored in a matrix. But an MLP's non-linearity could potentially offer far more mixing and memorization ability, albeit with efficiency costs. ATLAS (Behrouz, 2025) uses the Muon optimizer to optimize the same regression with a deep memory $M$, albeit over a sliding window of the last $c$ tokens instead of globally:
$$\min_{M} \sum_{i=t-c+1}^{t} \| M(k_i) - v_i \|_2^2.$$
Importantly, the regression ATLAS minimizes is devoid of the regularizing trace term, since Muon, by design, makes near-optimal gradient updates while ensuring numerical stability by bounding the norm of each update, playing exactly the role of MesaNet's regularization term. Without such guarantees, the state matrices might overfit, becoming highly sensitive to small changes in the key space, underscoring the need for a regularizer to prevent gradient update explosions.
Regression Alternatives
Now, let's challenge the importance of the aforementioned key-value retrieval regression. Even though the exponentials in softmax attention do amplify differences between query-key similarities, attention layers also serve as sequence mixers which transfer information between tokens. The temperature in attention determines how much mixing vs. precise retrieval the model performs. Since the temperature that minimizes the aforementioned regression is the limit in which softmax collapses to a hard argmax over keys, and trained models never operate in that limit, we suspect that solely focusing on retrieval is not correct.
Unfortunately, there isn't a theoretically grounded answer to what the correct combination of mixing vs. retrieval is, and methods today are highly empirical. This section documents strategies that have improved long-context memory.
Memory Compression
We return to our first example of vanilla linear attention. One downside, which causes the previously observed retrieval error, is that keys and values are never deleted, causing interference. The secondary issue is that language models have maximum context limits, and we will necessarily need to reset our KV cache after that point, making further contextually-aware inference impossible. For linear attention, an intuitive way to perform compression is through a deep, persistent neural network that generates an output which is incorporated into attention, as in Titans (Behrouz, 2024), which stores contextual information well past the current context window. However, in vanilla attention, the KV cache is continually appended to, making the memorization of keys and values in the far past tricky. An Evolved Universal Transformer Memory employs a neural network to compress the KV cache to a fixed size: at each step, it takes in the generated attention matrices and determines whether to append to the cache, replace a previous key-value pair, or merge with existing pairs.
Another intuitive solution is to gradually diminish values in the more distant past. Hierarchical Memory Architectures (HMA) achieve this by organizing memory at multiple granularities: short-term, real memory stores a queue of key-value pairs, whereas synthetic memory comprises encoded "memory tokens" structured into RMA, mid-term, and long-term memory. At periodic intervals, a window of recent RMA memory tokens is consolidated into mid-term memory, and mid-term memory tokens are consolidated into long-term memory, creating coarser summaries of the distant past. MEGALODON (Ma, Wang, 2024) entirely removes the context window by weighting each value with an exponential decay term $\lambda^{t-i}$, where $0 < \lambda < 1$.
Thus, the weighted values become arbitrarily small for large enough $t - i$ and can be safely clipped, providing a more robust way to discard unneeded keys and values.
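A toy sketch of why clipping is safe under exponential decay (illustrative helper names, scalar values for simplicity): contributions shrink geometrically with age, so the cutoff age beyond which $\lambda^a < \epsilon$ has a closed form.

```python
# Toy sketch of exponential-decay weighting: a contribution of age a is scaled by
# lambda^a, so anything older than a computable horizon is negligible and clippable.
import math

def decayed_contributions(vals, lam):
    """Scale each (scalar) value by lam^(age), where the last value has age 0."""
    t = len(vals)
    return [lam ** (t - 1 - i) * v for i, v in enumerate(vals)]

def clip_horizon(lam, eps):
    """Smallest age a with lam^a < eps: contributions older than this can be dropped."""
    return math.ceil(math.log(eps) / math.log(lam))

# With lam = 0.9 and eps = 1e-6, anything older than ~132 steps is negligible:
# clip_horizon(0.9, 1e-6) -> 132
```

This is what makes the memory footprint bounded: the effective window length depends only on $\lambda$ and the tolerated error, not on the total sequence length.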
Persistent Memory
Across long horizons, the line between training and inference itself becomes blurred due to in-context learning, and important keys and values could be integrated into the slow weights, enabling a form of fine-tuning on the specific context provided. Under this lens, vanilla softmax attention begins without any stored keys and values and appends every key-value pair, memory compression starts from scratch and persists vectors as needed, and persistent memory initializes with slow weights that the model appends to during inference time. For example, Titans (Behrouz, 2024) leverages a set of learnable, input-independent parameters appended to the beginning of the sequence, where they are concatenated with contextual memory compressed from previous context windows and the current sequence's attention.
Many existing associative memory architectures lend themselves readily to such integration. Memory Layers at Scale (Berges, Oğuz, 2024) leverages extreme sparsity by learning millions of key-value pairs through product quantization to replace MLP layers, selecting around a hundred such pairs to incorporate into the residual stream. It is designed similarly to a Mixture of Experts layer, except with far greater sparsity, at the cost of reduced expressivity from retrieving vectors instead of applying MLPs. Memory Mosaics (Zhang, 2025) also use layers of memory modules comprising key-value pairs, but these vectors are autoregressively generated instead of learned during pre-training. Fortunately, associative memory modules are flexible enough to support persistent memory that adapts during inference.
We illustratively augment the Memory Layers at Scale architecture to support test-time modifications whilst improving its time complexity from $O(n)$ to $O(kb \log_b n)$, where $n$ is the number of stored pairs and $k$ is the number of retrieved pairs. Iteratively build a tree of depth $\log_b n$, using our $n$ vectors as leaf nodes, such that each group of $b$ nodes has a parent node with a learnable vector as its label, until we have a single root node.
Perform retrieval for a query vector $q$ in $O(kb \log_b n)$ time with the following algorithm:
- Start at the root node, and iteratively go down by a layer.
- At each layer, save the $k$ nodes with labels of highest cosine similarity to the query vector $q$, and recurse down those nodes in the next layer to evaluate the next $kb$ vectors, yielding the desired time complexity as we process at most $kb$ vectors per layer. At the last layer, take those keys and values as our retrieved top-$k$ values.
Note that inserts will be similarly easy to implement, by taking a greedy approach: go down the node with the highest cosine similarity to the key at each step. Load-balancing challenges can be addressed by periodically updating each node's label with an exponential moving average toward the mean of its child vectors (or keys, if the child nodes are leaf nodes), and by incorporating DeepSeek's Auxiliary-Loss-Free Load Balancing (Wang, 2024). Such integrative approaches serve as strict generalizations of attention, which is essential for saving key information across millions of tokens.
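The retrieval procedure above can be sketched in a few lines of Python. This is a hypothetical illustration of the beam-style tree search, not the Memory Layers at Scale implementation; `Node`, `cos`, and `retrieve` are assumed names.

```python
# Beam-search retrieval over a b-ary label tree: keep the k best-matching nodes per
# level and descend into their children, touching at most k*b candidates per layer.
import math

class Node:
    def __init__(self, label, children=None, value=None):
        self.label = label              # vector used for scoring
        self.children = children or []
        self.value = value              # payload stored at leaves

def cos(u, v):
    num = sum(a * b for a, b in zip(u, v))
    den = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return num / den if den else 0.0

def retrieve(root, q, k):
    frontier = [root]
    while any(n.children for n in frontier):
        candidates = [c for n in frontier for c in n.children]  # at most k*b nodes
        candidates.sort(key=lambda n: cos(n.label, q), reverse=True)
        frontier = candidates[:k]
    return [n.value for n in frontier]

leaves = [Node([1.0, 0.0], value="v_a"), Node([0.9, 0.1], value="v_b"),
          Node([0.0, 1.0], value="v_c"), Node([0.1, 0.9], value="v_d")]
root = Node([0.0, 0.0], children=[Node([1.0, 0.0], children=leaves[:2]),
                                  Node([0.0, 1.0], children=leaves[2:])])
# retrieve(root, [1.0, 0.0], k=1) -> ["v_a"]
```

A greedy insert would reuse the same descent with `k = 1`, attaching the new pair under the best-matching leaf's parent.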
Polynomial Feature Maps
- LMUFormer
- Fourier transform
Hybrid Architectures
Finally, we discuss hybrid architectures, where the model chooses how much of each of multiple module outputs to use in a final weighted summation. For two modules with outputs $y_1, y_2$, a gated combination $y = g \odot y_1 + (1 - g) \odot y_2$, with $g = \sigma(W x)$ computed from the input, is differentiable and expressive. Routing works similarly, computing probabilities with a Gumbel-softmax to select one output. The Gumbel-softmax ensures that selection over both modules remains differentiable.
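A minimal sketch of the two combination rules, with illustrative names (`gate`, `gumbel_softmax`); in a real model the gate logit would come from a learned projection of the input.

```python
# Gating: a soft, differentiable blend of two module outputs.
# Routing: sample (near-)one-hot weights via the Gumbel-softmax trick.
import math, random

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def gate(y1, y2, g_logit):
    """Elementwise g*y1 + (1-g)*y2, with g = sigmoid(g_logit)."""
    g = sigmoid(g_logit)
    return [g * a + (1.0 - g) * b for a, b in zip(y1, y2)]

def gumbel_softmax(logits, tau=1.0):
    """Routing weights that approach a hard one-hot selection as tau -> 0."""
    gumbels = [-math.log(-math.log(random.random())) for _ in logits]
    z = [(l + g) / tau for l, g in zip(logits, gumbels)]
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

y = gate([1.0, 2.0], [3.0, 4.0], g_logit=0.0)  # g = 0.5 -> elementwise average
# y == [2.0, 3.0]
```

Gating always runs both modules (more compute, smoother gradients); routing runs one (cheaper, but noisier gradient signal), which is the trade-off discussed below.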
Note that gating provides greater expressivity: the 0.5B Falcon H-1 concatenates the outputs of a traditional attention block and the state space model Mamba to rival 7B models on benchmarks. However, this gating comes at the cost of increased compute, since both modules run on every token. Routing saves compute, but there is no free lunch either: we still need to store both modules in memory, and we can suffer from weak gradient signal (both issues that plague Mixture of Experts routers).
Gating and routing also commonly show up in sparse attention to decide which tokens the model attends to. Reformer: The Efficient Transformer uses locality-sensitive hashing (LSH), which hashes each query and key into a bucket and only compares tokens in the same bucket for attention.
Another example is DeepSeek's hardware-aligned NSA (Native Sparse Attention), which groups keys and values into blocks (say, of size $B$) and assigns each block a differentiable centroid, e.g. the mean of its constituents. Then, it scores the dot product between the query and each block's representative key and selects the blocks with the highest similarities. Additionally, a local sliding window of blocks is selected. Finally, NSA optionally gates the previous attention output with the attention output on the compressed stream (using the representative keys and values for each block).
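The block-selection step can be sketched as follows; this is an assumption-laden toy (mean-pooled centroids, dot-product scores, a trailing local window), not DeepSeek's kernel.

```python
# Toy block-sparse selection: score the query against per-block mean keys, keep the
# top-scoring blocks plus a local sliding window, and attend only over those blocks.
def block_means(keys, block_size):
    blocks = [keys[i:i + block_size] for i in range(0, len(keys), block_size)]
    d = len(keys[0])
    return [[sum(k[a] for k in blk) / len(blk) for a in range(d)] for blk in blocks]

def select_blocks(q, keys, block_size, top_n, local_n):
    means = block_means(keys, block_size)
    scores = [sum(qa * ma for qa, ma in zip(q, m)) for m in means]
    top = sorted(range(len(means)), key=lambda i: scores[i], reverse=True)[:top_n]
    local = list(range(max(0, len(means) - local_n), len(means)))  # sliding window
    return sorted(set(top) | set(local))

keys = [[1.0, 0.0]] * 4 + [[0.0, 1.0]] * 4 + [[0.5, 0.5]] * 4
# A query aligned with the first block selects block 0 plus the local block 2:
# select_blocks([1.0, 0.0], keys, block_size=4, top_n=1, local_n=1) -> [0, 2]
```

Attention is then computed only over the keys and values inside the selected blocks, which is where the sparsity savings come from.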
It's worth mentioning hybrid architectures that combine an autoregressive model with a masked diffusion model (MDM). MDMs are highly parallelizable, providing large upsides over autoregressive transformers, but they can only generate a fixed number of tokens. Moreover, since their attention is bidirectional, MDMs cannot keep a KV cache, and keys and values are recomputed at each denoising step. One hybrid solution is Block Diffusion, which is autoregressive across blocks of tokens, while tokens within each block are generated using discrete diffusion. Esoteric Language Models (Sahoo, Yang et al., 2025) go further and introduce KV caching to MDMs.
Future Directions
Back to the Regression
Inspired by approaches mentioned in the Persistent Memory section (a la Titans), concatenating the regression with auxiliary keys and values yields
$$\min_S \sum_{i=1}^{t} \gamma_i \, \| S k_i - v_i \|_2^2 + \sum_{j=1}^{m} \tilde{\gamma}_j \, \| S \tilde{k}_j - \tilde{v}_j \|_2^2,$$
where the additional keys and values are $\tilde{k}_j$ and $\tilde{v}_j$. This facilitates context-length extensions, since information from previous context windows can be compressed into the auxiliary pairs before $S$ is reset, and enables the retrieval of specific information learned during training (a la Memory Layers at Scale). Solely optimizing the regression, which corresponds to retrieval, isn't optimal; but just as varying the temperature balances mixing and retrieval, the additional keys and values can enable both. Moreover, this framing amounts to a "semi-parametric" memory across long contexts: new, important keys and values can be permanently inserted into the regression objective as valuable, persistent data points, surviving even after the current state matrix, keys, and values are reset.
Another area for improvement is the selection of norm. Traditionally, the L2 norm is preferred, as $\| x \|_2^2 = x^\top x$ and the dot product can exploit fused multiply-add paths on GPUs that the L1 norm cannot. However, the efficiency difference between the L2 and L1 norms is not major, and alternative losses have been proposed, e.g. the Huber loss (Behrouz, 2025), which is given by
$$\ell_\delta(r) = \begin{cases} \frac{1}{2} r^2 & |r| \le \delta, \\ \delta \left( |r| - \frac{\delta}{2} \right) & \text{otherwise}. \end{cases}$$
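For concreteness, here is a pure-Python Huber loss on a scalar residual: quadratic near zero and linear in the tails, with the standard threshold parameter $\delta$.

```python
# Huber loss: behaves like 0.5*r^2 for small residuals and like a (shifted) absolute
# value for large ones, making it less sensitive to outlier key-value pairs.
def huber(r, delta=1.0):
    a = abs(r)
    return 0.5 * r * r if a <= delta else delta * (a - 0.5 * delta)

# huber(0.5) == 0.125 ; huber(3.0) == 2.5 (with delta = 1.0)
```

The linear tails bound the gradient magnitude at `delta`, which is the robustness property that motivates it as an alternative to the L2 objective.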
More hardware-aligned characterizations of such losses serve as a key area for improvement in linear attention design.
Utilizing External Memory
- RAG
- Million dummies
- SSD
- kNN-LM
- Retro
Data-Architecture Co-design
Unfortunately, high-quality long-context data is scarce. Hierarchical strategies maximize existing data by concatenating shorter chunks and using curriculum learning schedules that gradually increase difficulty over time, similar to how YaRN expands context limits. Synthetic approaches have also seen promise: He et al. (2025) split long books into variable-length chunks, generate QA pairs to fine-tune on, and gradually increase the number of concatenated chunks up to 1M tokens.
Conclusion
Efficient long-context reasoning is one of the biggest bottlenecks towards generally intelligent AI. If we trace back the last decade of machine learning, we have evidence for The Bitter Lesson, which posits that in the limit, models that effectively scale with increased computation will win. Self-attention's genius lies in its remarkable parallelization, designed with specialized GPU hardware in mind. This enabled realized performance gains from increased compute and larger model sizes that were not replicated in RNNs and LSTMs. The Chinchilla scaling laws hammer the point home: despite architectural differences, model perplexity boils down to dataset size and compute (which governs parameter count), and past a certain point, compute is much easier to reliably continue to source than high-quality data. The recent innovation in chain-of-thought models again reiterates this philosophy: it enables the continued trading of compute for performance gains by using generated reasoning tokens as a "scratchpad", in a way that universal transformers of the past couldn't. It's likely that this trend will continue, and optimized architectures for ever-longer chain-of-thought traces will be critical for building fully autonomous ML engineers or researchers.
This is also my first go at technical writing, so I'd love to discuss any thoughts or feedback; apologies for the rough patches :)