Ankit Singh Rawat

Authored Publications
    Faster Cascades via Speculative Decoding
    Seungyeon Kim
    Neha Gupta
    Aditya Menon
    International Conference on Learning Representations (ICLR) 2025
    Cascades and speculative decoding are two common approaches to improving language models' inference efficiency. Both approaches involve interleaving models of different sizes, but via fundamentally distinct mechanisms: cascades employ a deferral rule that invokes the larger model only for "hard" inputs, while speculative decoding uses speculative execution to primarily invoke the larger model in parallel verification mode. These mechanisms offer different benefits: empirically, cascades are often capable of yielding better quality than even the larger model, while theoretically, speculative decoding offers a guarantee of quality-neutrality. In this paper, we leverage the best of both these approaches by designing new speculative cascading techniques that implement their deferral rule through speculative execution. We characterize the optimal deferral rule for our speculative cascades, and employ a plug-in approximation to the optimal rule. Through experiments with T5 and Gemma models on benchmark language tasks, we show that the proposed cascading approach matches the quality of a regular cascade, but at reduced inference costs.
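    A minimal sketch of the idea, not the authors' implementation: a small model drafts a block of tokens, the large model scores them in one parallel pass, and a deferral rule decides per position whether to keep the draft token or fall back to the large model. The confidence-threshold rule below is a stand-in for the paper's optimal/plug-in rule, and all names and thresholds are illustrative.
```python
import numpy as np

def speculative_cascade_step(draft_probs, target_probs, draft_tokens, conf_threshold=0.7):
    """Illustrative per-block step of a speculative cascade.

    draft_probs:  (T, V) token distributions from the small model for T drafted positions
    target_probs: (T, V) distributions from the large model, obtained in one parallel pass
    draft_tokens: (T,)   tokens sampled from the small model
    Returns the accepted prefix of tokens; the deferral rule here is a simple
    confidence check, purely for illustration.
    """
    accepted = []
    for t, tok in enumerate(draft_tokens):
        # Keep the draft token when the small model is confident on this position;
        # otherwise defer to the large model's distribution for this token.
        if draft_probs[t, tok] >= conf_threshold:
            accepted.append(int(tok))
        else:
            accepted.append(int(np.argmax(target_probs[t])))
            break  # positions after a deferral must be re-drafted
    return accepted
```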
    Gating is Weighting: Understanding Gated Linear Attention through In-context Learning
    Yingcong Li
    Maryam Fazel
    Samet Oymak
    Davoud Ataee Tarzanagh
    Conference on Language Modeling (COLM) 2025
    Linear attention methods provide a strong alternative to softmax attention as they allow for efficient recurrent decoding. Recent research has focused on enhancing standard linear attention by incorporating gating while retaining its computational benefits. Such Gated Linear Attention (GLA) architectures include highly competitive models such as Mamba and RWKV. In this work, we examine the in-context learning capabilities of the GLA model and make the following contributions. We show that a multilayer GLA can implement a general class of Weighted Projected Gradient Descent (WPGD) algorithms with data-dependent weights. These weights are induced by the gating and allow the model to control the contribution of individual tokens to the prediction. To further understand the mechanics of weighting, we introduce a novel data model with multitask prompts and characterize the optimization landscape of the problem of learning a WPGD algorithm. We identify mild conditions under which there is a unique (global) minimum up to scaling invariance, and the associated WPGD algorithm is unique as well. Finally, we translate these findings to explore the optimization landscape of GLA and shed light on how gating facilitates context-aware learning and when it is provably better than vanilla linear attention.
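    For intuition, here is a small numpy sketch of the kind of weighted (projected) gradient descent step that the paper shows GLA can emulate in-context. The data-dependent weights stand in for the gating-induced weights; the least-squares objective, learning rate, and projection are illustrative assumptions, not the paper's construction.
```python
import numpy as np

def wpgd_step(W, X, y, weights, lr=0.1, proj=None):
    """One weighted gradient descent step on least squares, optionally
    followed by a projection (hence 'WPGD').

    X: (n, d) in-context examples, y: (n,) targets,
    weights: (n,) data-dependent weights (in GLA these would be induced by the gates).
    """
    residual = X @ W - y                        # (n,) per-example residuals
    grad = X.T @ (weights * residual) / len(y)  # weighted least-squares gradient
    W = W - lr * grad
    if proj is not None:                        # e.g. projection onto a subspace
        W = proj(W)
    return W
```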
    DistillSpec: Improving speculative decoding via knowledge distillation
    Yongchao Zhou
    Kaifeng Lyu
    Aditya Menon
    Afshin Rostamizadeh
    Jean-François Kagy
    Rishabh Agarwal
    International Conference on Learning Representations (ICLR) (2024)
    Speculative decoding is highly effective at expediting large language model inference by employing a smaller draft model for token generation and a larger model for parallel token verification. Nonetheless, identifying an accurate and compact draft model aligned with the target model presents challenges. To address this, we propose leveraging white-box knowledge distillation, significantly improving draft model alignment with the larger target model and thereby enhancing speculative decoding. Our findings underscore the pivotal role of on-policy data generation and a divergence function tailored to the task and decoding scheme for successful distillation. In practice, our refined distillation approach yields a 20% speedup over standard speculative decoding across five distinct tasks, using both greedy decoding and temperature sampling. Furthermore, we extend the concept of lossless speculative decoding to incorporate a lenience factor in the rejection sampling step, offering fine-grained control over the trade-off between quality and latency in lossy decoding. Finally, adopting a strategy of distilling for performance first and for speculative decoding second enables a remarkable 8x reduction in latency with minimal performance compromise, compared to a baseline with neither distillation nor speculative decoding.
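    The rejection-sampling step with a lenience knob can be sketched as follows. This is an assumption-laden illustration: the exact parameterization of lenience in the paper may differ; here it simply inflates the target probability before the standard acceptance test, and the residual resampling is the usual lossless form.
```python
import numpy as np

def accept_draft_token(p_target, q_draft, token, lenience=1.0, rng=np.random):
    """Speculative-decoding acceptance test with an illustrative lenience factor.

    p_target, q_draft: (V,) next-token distributions of target and draft models.
    lenience >= 1 inflates the target probability, accepting more draft tokens
    (lenience == 1 recovers the usual lossless test).
    """
    accept_prob = min(1.0, lenience * p_target[token] / max(q_draft[token], 1e-12))
    if rng.random() < accept_prob:
        return int(token)                 # draft token accepted
    # On rejection, resample from the clipped, renormalized residual distribution.
    residual = np.maximum(p_target - q_draft, 0.0)
    residual = residual / residual.sum() if residual.sum() > 0 else p_target
    return int(rng.choice(len(p_target), p=residual))
```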
    Mechanics of Next Token Prediction with Transformers
    Yingcong Li
    Yixiao Huang
    Muhammed Emrullah Ildiz
    Samet Oymak
    International Conference on Artificial Intelligence and Statistics (AISTATS) (2024)
    Transformer-based language models are trained on large datasets to predict the next token given an input sequence. Despite this seemingly simple training objective, they have revolutionized natural language processing within a short timeframe. Underlying this success is the self-attention mechanism. In this work, we ask: what does 1-layer self-attention learn from next-token prediction? We show that when trained with gradient descent, self-attention implements a simple automaton that induces a token hierarchy from the training data. Concretely, from the (sequence, label) pairs of the training data, we construct directed next-token graphs (NTGs) of the dataset that capture (input token, label) relations. We find that the implicit bias of self-attention is captured by the strongly-connected components (SCCs) that partition the NTGs into cyclic and acyclic subgraphs: the acyclic subgraph results in an SVM direction that enforces the priority order among SCCs, while the cyclic subgraph yields a correction term that allocates the nonzero softmax probabilities among tokens within the same SCC. We empirically and theoretically demonstrate that the superposition of these components accurately predicts the implicit bias of gradient descent in next-token prediction. We believe these results shed light on self-attention's ability to process sequential data and pave the way towards demystifying more complex transformer architectures.
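    The graph construction itself is easy to mimic. Below is a toy sketch, assuming the NTG is built by adding an edge from every input token to the example's label; the dataset is made up, and the SCC computation uses networkx purely for illustration.
```python
import networkx as nx

# Toy (sequence, label) training pairs: each pair contributes edges
# (input token -> next-token label) to the directed next-token graph (NTG).
train_pairs = [(("a", "b"), "c"), (("b", "c"), "a"), (("c",), "d")]

ntg = nx.DiGraph()
for seq, label in train_pairs:
    for tok in seq:
        ntg.add_edge(tok, label)

# Strongly connected components partition the NTG; per the paper's analysis,
# the acyclic part induces an SVM-like priority order among SCCs, while
# cycles within an SCC determine how softmax mass is split inside it.
sccs = list(nx.strongly_connected_components(ntg))
print(sccs)  # e.g. [{'a', 'c'}, {'b'}, {'d'}]
```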
    A Little Help Goes a Long Way: Efficient LLM Training by Leveraging Small LMs
    Vlad Feinberg
    Afshin Rostamizadeh
    Nikunj Saunshi
    Seungyeon Kim
    Veeru Sadhanala
    Rakesh Shivanna
    Rohan Anil
    Aditya Menon
    Hrayr Harutyunyan
    ArXiv (2024)
    A primary challenge in large language model (LLM) development is their onerous pre-training cost. This paper explores a promising paradigm to improve LLM pre-training efficiency and quality by leveraging a small language model (SLM). In particular, this paradigm relies on an SLM to both (1) provide soft labels as additional supervision, and (2) select a small subset of valuable training examples. Put together, this enables an effective transfer of the SLM's predictive distribution to the LLM, while prioritizing specific regions of the training data distribution. Empirically, this leads to reduced LLM training time compared to standard training, while improving the overall quality. Theoretically, we develop a statistical framework to study the utility of SLMs in enabling efficient training of high-quality LLMs. Our framework characterizes how the SLM's seemingly low-quality supervision can enhance the training of a much more capable LLM. Furthermore, it also highlights the need for an adaptive utilization of such supervision, by striking a balance between the bias and variance introduced by the SLM-provided soft labels. We corroborate our theoretical framework by improving the pre-training of LLMs with 2.8B and 8.6B parameters by utilizing smaller LMs on the Pile dataset.
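    A hypothetical sketch of the two ingredients, soft labels from the small LM and SLM-based example selection, is shown below. The mixing weight and the keep-highest-loss selection criterion are placeholders, not the paper's exact recipe.
```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def kd_loss(llm_logits, slm_logits, labels, alpha=0.5):
    """Mix hard-label cross-entropy with a soft-label term from the small LM.
    alpha trades the (biased but low-variance) SLM supervision against hard labels."""
    p_llm, p_slm = softmax(llm_logits), softmax(slm_logits)
    ce_hard = -np.log(p_llm[np.arange(len(labels)), labels] + 1e-12).mean()
    ce_soft = -(p_slm * np.log(p_llm + 1e-12)).sum(-1).mean()
    return (1 - alpha) * ce_hard + alpha * ce_soft

def select_examples(slm_losses, keep_frac=0.5):
    """Keep the examples the small LM flags as most informative; here, simply
    the highest-loss fraction (the actual selection criterion is a design choice)."""
    k = int(len(slm_losses) * keep_frac)
    return np.argsort(slm_losses)[-k:]
```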
    Dual-Encoders for Extreme Multi-label Classification
    Nilesh Gupta
    Devvrit Khatri
    Inderjit Dhillon
    International Conference on Learning Representations (ICLR) (2024)
    Dual-encoder models have demonstrated significant success in dense retrieval tasks for open-domain question answering, which mostly involve zero-shot and few-shot scenarios. However, their performance in many-shot retrieval problems, such as extreme classification, remains largely unexplored. State-of-the-art extreme classification techniques like NGAME use a combination of dual-encoders and a learnable classification head for each class to excel on these tasks. Existing empirical evidence shows that, for such problems, the dual-encoder method's accuracy lags behind that of the SOTA extreme classification methods, which grow the number of learnable parameters with the number of classes. In this work, we investigate the potential reasons behind this observed gap, such as the intrinsic capacity limit of dual-encoder models, whose fixed model size is independent of the number of classes, as well as training, loss formulation, and negative sampling. We methodically experiment along these different axes and find that model size is not the main bottleneck; rather, the training and loss formulation are. When trained correctly, even small dual-encoders can outperform state-of-the-art extreme classification methods by up to 2% in precision on million-label-scale extreme classification datasets, while being 20x smaller in terms of the number of trainable parameters. We further propose a differentiable top-k error-based loss function, which can be used to specifically optimize for recall@k metrics.
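    To make the recall@k idea concrete, here is one plausible smooth surrogate, not the paper's loss: treat the k-th largest score as a soft threshold and push positive labels above it. Everything here is an illustrative assumption, and a real trainer would use autodiff tensors rather than numpy arrays.
```python
import numpy as np

def soft_recall_at_k(scores, positive_ids, k=5, temperature=0.1):
    """Smooth surrogate for recall@k for a single query.

    scores: (L,) query-label similarity scores from the dual encoder.
    positive_ids: indices of the relevant labels.
    Uses the k-th largest score as a threshold and measures how many positives
    sit (softly) above it; 1 minus this value can serve as a loss.
    """
    threshold = np.sort(scores)[-k]                     # k-th largest score
    margins = (scores[positive_ids] - threshold) / temperature
    soft_hits = 1.0 / (1.0 + np.exp(-margins))          # sigmoid of the margin
    return soft_hits.mean()
```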
    Recent advances in language model (LM) design have yielded a series of models with remarkably improved quality on complex NLP tasks, but significantly increased inference cost. A simple strategy to achieve more favourable cost-quality tradeoffs is cascading: here, a small model is invoked for most "easy" instances, while a large model is invoked for a few "hard" instances. Typically, "easy" instances are those where the small model has high confidence in its prediction. While the principles underpinning effective cascading are well-studied for classification problems, a similar understanding is lacking for generative tasks. The extension of the simple "Chow" rule, which defers based on the probability of predicting an answer, is not straightforward for generative tasks where the number of output tokens is variable. Moreover, LMs are known to suffer from length bias, where longer answers are penalized more than shorter answers, which complicates things further. In this work, we initiate a systematic study of deferral rules for cascades of language models. For example, how does one best summarise model confidence across a variable number of output tokens? We show experimentally that there is no single straightforward extension of probability-based uncertainty for LMs that works well across all tasks. Via experiments on a range of benchmarks with FLAN-T5 models, we find that incorporating token-level uncertainty can significantly improve the cost-quality tradeoff of cascades. We further show that incorporating embeddings from the smaller model and intermediate-layer embeddings from the larger model can further boost performance.
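    A minimal sketch of the kind of deferral rule studied: aggregate the per-token log-probabilities of the small model's own output and defer when the aggregate confidence is low. The aggregation choices and thresholds below are illustrative; the quantile rule is just one way to summarise token-level uncertainty.
```python
import numpy as np

def defer_to_large_model(token_logprobs, rule="quantile", q=0.2, threshold=-1.0):
    """Decide whether to route a generation to the large model.

    token_logprobs: log-probabilities the small model assigned to its own output tokens.
    Averaging is length-debiased compared to summing; a low quantile focuses on
    the least confident tokens. Threshold values here are illustrative.
    """
    if rule == "mean":
        conf = np.mean(token_logprobs)
    elif rule == "quantile":
        conf = np.quantile(token_logprobs, q)
    else:  # "sum" suffers from length bias: longer answers look less confident
        conf = np.sum(token_logprobs)
    return conf < threshold
```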
    Think before you speak: Training language models with pause tokens
    Sachin Goyal
    Ziwei Ji
    Aditya Menon
    Vaishnavh Nagarajan
    International Conference on Learning Representations (ICLR) (2024)
    The present-day language model generates its response by producing a series of tokens in immediate succession: the (K+1)-th token is an outcome of manipulating exactly K hidden values in each layer, one corresponding to each of the K previous tokens. Is it possible to somehow allow the model to manipulate more hidden values before committing to an answer? If yes, would this help? We explore these questions by training models with learnable pause tokens. Besides feeding the usual prefix to the model, our idea is to feed the model with an additional sequence of pause tokens. On these tokens, the model's output is ignored all the way until the last pause token, where we begin extracting the answer. We explore this idea of "delayed answering" in a 1B model, where we consider both pre-training and/or fine-tuning with pause tokens. We find that while merely finetuning a standard model is not very helpful, pause-pretrained models show promise on some downstream tasks such as GSM (reasoning), and SQuAD, CommonSenseQA, and Lambada (question answering). We also conduct various ablations to explore the effect of the number of pause tokens. While our work is a preliminary exploration of delayed computation for language models, focusing on a 1B model, we hope it inspires future work that can make this idea practically feasible without pre-training, and for models trained with other pretraining objectives and at other sizes.
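    A hypothetical sketch of pause-token inference: append M pause tokens to the prefix, run the model as usual, and only start reading outputs after the last pause token. The tokenizer and model interfaces below are placeholders, not a real library API.
```python
def answer_with_pauses(model, tokenizer, prompt, num_pauses=10, pause_token="<pause>"):
    """Delayed answering with pause tokens (interfaces are hypothetical placeholders).

    The model sees `prompt + <pause> * num_pauses`; outputs produced at pause
    positions are ignored, and the answer is extracted only after the last pause,
    giving the model extra hidden computation before it commits to an answer.
    """
    input_ids = tokenizer.encode(prompt) + [tokenizer.token_to_id(pause_token)] * num_pauses
    output_ids = model.generate(input_ids)       # standard autoregressive decoding
    answer_ids = output_ids[len(input_ids):]     # skip prompt and pause positions
    return tokenizer.decode(answer_ids)
```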
    USTAD: Unified Single-model Training Achieving Diverse Scores for Information Retrieval
    Seungyeon Kim
    Manzil Zaheer
    Veeru Sadhanala
    Sadeep Jayasumana
    Aditya Menon
    Rob Fergus
    International Conference on Machine Learning (ICML) (2024)
    Modern information retrieval (IR) systems consist of multiple stages, such as retrieval and ranking. Transformers are employed across these different IR stages, achieving state-of-the-art performance, but each model is trained separately, leading to complex pipelines and increased cost for maintaining multiple models. The apparent need for separate models is due to different input/output semantics at different stages. In this paper, we challenge this tradition of using separate models: since transformers are very expressive models, would changing just the score function suffice? We present a new unified approach, USTAD, to train a single network that can provide powerful ranking scores as a cross-encoder (CE) as well as factorized embeddings for large-scale retrieval as a dual-encoder (DE). Empirically, we find a single USTAD model to be competitive with separate ranking CE and retrieval DE models. Furthermore, USTAD enables new distillation techniques, significantly improving CE-to-DE distillation. Moreover, using a USTAD teacher, we can deploy novel asymmetric architectures for student models that realize better embedding alignment without increasing online inference cost. On standard benchmarks like MSMARCO, we show that our approach successfully distills from both dual-encoder (DE) and cross-encoder (CE) teacher models to 1/10th-size asymmetric students that retain 95-97% of the teacher performance.
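    The "change just the score function" idea can be sketched as one network serving two roles: as a cross-encoder it jointly scores a (query, document) pair, and as a dual-encoder it produces independent embeddings scored by a dot product. The class and `encode` interface below are illustrative assumptions, not the USTAD implementation.
```python
import numpy as np

class UnifiedScorer:
    """Single encoder used in two scoring modes (illustrative interfaces only).

    `encode` stands in for a shared transformer returning a pooled vector."""

    def __init__(self, encode):
        self.encode = encode

    def cross_encoder_score(self, query, doc):
        # Ranking mode: jointly encode the concatenated pair; the first pooled
        # dimension stands in for a scalar scoring head.
        return float(self.encode(query + " [SEP] " + doc)[0])

    def dual_encoder_score(self, query, doc):
        # Retrieval mode: encode independently, so document vectors can be
        # pre-computed and indexed; the score is a dot product.
        q, d = self.encode(query), self.encode(doc)
        return float(np.dot(q, d))
```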
    A Statistical Framework for Data-dependent Retrieval-Augmented Models
    Soumya Basu
    Manzil Zaheer
    International Conference on Machine Learning (ICML) (2024)
    Modern ML systems increasingly augment input instances with additional relevant information to enhance the final prediction. Despite growing interest in such retrieval-augmented models, their fundamental properties and training are not well understood. We propose a statistical framework to study such models with two components: 1) a retriever that identifies the relevant information in a large corpus via a data-dependent metric; and 2) a predictor that consumes the input instance along with the retrieved information to make the final prediction. We present a principled method for end-to-end training of both components and draw connections with various training approaches in the literature. Furthermore, we establish excess risk bounds for retrieval-augmented models while delineating the contributions of both the retriever and the predictor towards model performance. We validate the utility of our proposed training methods, along with the key takeaways from our statistical analysis, on an open-domain question answering task where retrieval augmentation is important.
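    One common way to couple the two components, sketched here as an assumption rather than the paper's exact objective: let the retriever define a softmax distribution over corpus items and minimize the expected predictor loss under that distribution, so gradients reach both the retriever and the predictor.
```python
import numpy as np

def retrieval_augmented_loss(query_vec, corpus_vecs, predictor_loss, top_k=5, tau=1.0):
    """Expected predictor loss under the retriever's distribution (illustrative).

    query_vec: (d,), corpus_vecs: (N, d); retrieval scores are dot products.
    predictor_loss(i) returns the predictor's loss when item i is retrieved.
    Averaging losses with the retriever's softmax weights couples the two
    components; an autodiff framework would backpropagate through both.
    """
    scores = corpus_vecs @ query_vec / tau
    top = np.argsort(scores)[-top_k:]                 # restrict to the top-k items
    weights = np.exp(scores[top] - scores[top].max())
    weights /= weights.sum()
    return float(sum(w * predictor_loss(i) for w, i in zip(weights, top)))
```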