Abstract
Sequence modeling is currently dominated by causal transformer architectures that
use softmax self-attention. Although widely adopted, transformers require memory and
per-token compute that grow linearly with sequence length during inference. A recent
stream of work has linearized the softmax operation, resulting in powerful recurrent
neural network (RNN) models with constant memory and compute costs, such as DeltaNet,
Mamba, or xLSTM. These models can be unified by noting that their recurrent layer dynamics
can all be derived from an in-context regression objective, approximately optimized
through an online learning rule. Here, we join this line of work and introduce a
numerically stable, chunkwise parallelizable version of the recently proposed Mesa
layer (von Oswald et al., 2024); the original version could only run sequentially in
time and was therefore not scalable. Like the models above, this layer stems from an
in-context loss, but the loss is now minimized to optimality at every time point using
a fast conjugate gradient solver. Through an extensive suite of experiments at up to the billion-parameter
scale, we show that optimal test-time training achieves lower language
modeling perplexity and higher downstream benchmark performance than previous
RNNs, especially on tasks requiring long-context understanding. This performance
gain comes at the cost of additional FLOPs spent at inference time. Our results
therefore relate intriguingly to the recent trend of increasing test-time compute to
improve performance, here by spending compute to solve sequential optimization
problems within the neural network itself.
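The claim that recurrent layers approximately optimize an in-context regression objective through an online learning rule can be made concrete with the delta rule used by DeltaNet. The sketch below assumes the update S <- S - beta * (S k_t - v_t) k_t^T, which is exactly one gradient step on the per-token loss 0.5 * ||S k_t - v_t||^2. The helper names (matvec, delta_rule_step, delta_rule_layer) and the step size beta are hypothetical, chosen for illustration; this is not the authors' code, and gating and normalization are omitted.

```python
# Hypothetical helper names; a sketch of the delta rule, not the authors' code.
# Each step takes one gradient step on the in-context regression loss
# 0.5 * ||S k_t - v_t||^2, i.e. S <- S - beta * (S k_t - v_t) k_t^T.


def matvec(S, x):
    """Return S @ x for a matrix stored as a list of rows."""
    return [sum(s * xj for s, xj in zip(row, x)) for row in S]


def delta_rule_step(S, k, v, beta):
    """Apply one online gradient step to the state S in place."""
    err = [p - t for p, t in zip(matvec(S, k), v)]  # prediction error S k - v
    for i, e in enumerate(err):
        for j, kj in enumerate(k):
            S[i][j] -= beta * e * kj  # gradient of the loss is err k^T
    return S


def delta_rule_layer(keys, values, queries, beta=0.5):
    """Run the recurrence over a sequence; output o_t = S_t q_t."""
    d_k, d_v = len(keys[0]), len(values[0])
    S = [[0.0] * d_k for _ in range(d_v)]  # fixed-size state, independent of length
    outputs = []
    for k, v, q in zip(keys, values, queries):
        delta_rule_step(S, k, v, beta)
        outputs.append(matvec(S, q))
    return outputs


# Two orthogonal unit keys with beta = 1: each value is stored exactly.
print(delta_rule_layer([[1.0, 0.0], [0.0, 1.0]], [[2.0], [3.0]],
                       [[1.0, 0.0], [0.0, 1.0]], beta=1.0))  # [[2.0], [3.0]]
```

Because each step only descends the loss for the current token, the state does not in general minimize the loss summed over the whole context once keys overlap; the memory and compute per step stay constant regardless.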
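Minimizing an in-context loss to optimality at every time point, with a conjugate gradient solver, can be illustrated under simplifying assumptions. Here the loss is taken to be ridge regression, sum_i 0.5 * ||W k_i - v_i||^2 + (lam / 2) * ||W||^2, whose minimizer is W_t = S_t (H_t + lam I)^{-1} with S_t = sum_i v_i k_i^T and H_t = sum_i k_i k_i^T. The output W_t q_t is obtained by solving (H_t + lam I) x_t = q_t with conjugate gradient instead of forming an inverse. The names lam, num_iters, conjugate_gradient, and mesa_layer are hypothetical; forgetting gates, chunkwise parallelization, and the numerical stabilization mentioned in the abstract are deliberately left out, so this is a sequential toy, not the MesaNet implementation.

```python
# Hypothetical helper names and parameters (lam, num_iters); a simplified
# sketch, not the MesaNet implementation. No gates, no chunkwise parallelism.
# Output o_t = S_t x_t, where x_t solves (H_t + lam * I) x_t = q_t,
# H_t = sum_i k_i k_i^T and S_t = sum_i v_i k_i^T (the ridge-regression optimum).


def matvec(A, x):
    """Return A @ x for a matrix stored as a list of rows."""
    return [sum(a * xj for a, xj in zip(row, x)) for row in A]


def dot(x, y):
    return sum(a * b for a, b in zip(x, y))


def conjugate_gradient(A, b, num_iters, tol=1e-20):
    """Approximately solve A x = b for a symmetric positive definite A."""
    x = [0.0] * len(b)
    r = list(b)  # residual b - A x for the initial guess x = 0
    p = list(r)  # first search direction
    rs_old = dot(r, r)
    for _ in range(num_iters):
        if rs_old < tol:  # converged
            break
        Ap = matvec(A, p)
        alpha = rs_old / dot(p, Ap)  # exact line search along p
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * api for ri, api in zip(r, Ap)]
        rs_new = dot(r, r)
        p = [ri + (rs_new / rs_old) * pi for ri, pi in zip(r, p)]  # next A-conjugate direction
        rs_old = rs_new
    return x


def mesa_layer(keys, values, queries, lam=1e-3, num_iters=None):
    """Return o_t = S_t (H_t + lam I)^{-1} q_t for every time step t."""
    d_k, d_v = len(keys[0]), len(values[0])
    if num_iters is None:
        num_iters = d_k  # CG is exact after d_k steps in exact arithmetic
    H = [[lam if i == j else 0.0 for j in range(d_k)] for i in range(d_k)]  # H_t + lam I
    S = [[0.0] * d_k for _ in range(d_v)]
    outputs = []
    for k, v, q in zip(keys, values, queries):
        for i in range(d_k):  # H_t = H_{t-1} + k_t k_t^T
            for j in range(d_k):
                H[i][j] += k[i] * k[j]
        for i in range(d_v):  # S_t = S_{t-1} + v_t k_t^T
            for j in range(d_k):
                S[i][j] += v[i] * k[j]
        x = conjugate_gradient(H, q, num_iters)  # x_t ~= (H_t + lam I)^{-1} q_t
        outputs.append(matvec(S, x))
    return outputs


# Non-orthogonal keys: the exact least-squares map sends [0, 1] to 2.
out = mesa_layer([[1.0, 0.0], [1.0, 1.0]], [[1.0], [3.0]],
                 [[1.0, 0.0], [0.0, 1.0]])
print(round(out[1][0], 3))  # 1.997 (close to 2; the gap comes from lam)
```

Each step costs several matrix-vector products instead of one state update, which is one way to see the extra inference-time FLOPs traded for an optimal fit to the context.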