Ankit Singh Rawat

Authored Publications
    DistillSpec: Improving speculative decoding via knowledge distillation
    Yongchao Zhou
    Kaifeng Lyu
    Aditya Menon
    Jean-François Kagy
    International Conference on Learning Representations (ICLR) (2024)
    Speculative decoding is highly effective at accelerating large language model inference: a smaller draft model generates tokens, and a larger target model verifies them in parallel. However, identifying an accurate and compact draft model that is aligned with the target model is challenging. To address this, we propose leveraging white-box knowledge distillation, which significantly improves the draft model's alignment with the larger target model and thereby enhances speculative decoding. Our findings underscore the pivotal role of on-policy data generation, and of a divergence function suited to the task and decoding scheme, for successful distillation. In practice, our refined distillation approach yields a 20% speedup over standard speculative decoding across five distinct tasks, using both greedy decoding and temperature sampling. Furthermore, we extend the concept of lossless speculative decoding to incorporate a lenience factor in the rejection sampling step, offering fine-grained control over the trade-off between quality and latency in lossy decoding. Finally, adopting a strategy of "distilling for performance first and distillation for speculative decoding second" enables a remarkable 8x reduction in latency with minimal performance compromise, compared to a baseline with neither distillation nor speculative decoding.
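The lenience factor mentioned above modifies the accept/reject test of speculative decoding. The following is a minimal Python sketch. It assumes per-token probability lists `p` (target) and `q` (draft) over a shared vocabulary. With `lenience=1.0` it is the standard lossless rule: accept draft token x with probability min(1, p(x)/q(x)), and otherwise resample from the normalized residual max(0, p - q). The form `min(1, lenience * p(x)/q(x))` for `lenience > 1` is an illustrative assumption. It is not necessarily the exact lenience function used in DistillSpec, and it makes decoding lossy. All function names are hypothetical.

```python
import random

def acceptance_prob(p_x, q_x, lenience=1.0):
    """Probability of keeping draft token x.

    lenience=1.0: standard lossless rule min(1, p(x)/q(x)).
    lenience > 1: assumed, illustrative lossy variant that accepts more often.
    """
    if q_x <= 0:
        raise ValueError("a sampled draft token must have q(x) > 0")
    return min(1.0, lenience * p_x / q_x)

def residual_distribution(p, q):
    """Resampling distribution after a rejection: normalize(max(0, p - q))."""
    diff = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    total = sum(diff)
    if total == 0.0:
        return list(p)  # p == q, so rejection cannot occur; fall back to p
    return [d / total for d in diff]

def verify_token(x, p, q, lenience=1.0, rng=random):
    """Accept draft token index x, or replace it with a sample from the residual."""
    if rng.random() < acceptance_prob(p[x], q[x], lenience):
        return x
    residual = residual_distribution(p, q)
    return rng.choices(range(len(p)), weights=residual, k=1)[0]
```

In standard speculative decoding, draft tokens are verified left to right, and the first rejection ends the drafted block.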
    Recent advances in language model (LM) design have yielded a series of models with remarkably improved quality on complex NLP tasks, but significantly increased inference cost. A simple strategy to achieve more favourable cost-quality tradeoffs is cascading: here, a small model is invoked for most “easy” instances, while a large model is invoked for a few “hard” instances. Typically, “easy” instances are those where the small model has high confidence in its prediction. While the principles underpinning effective cascading are well-studied for classification problems, a similar understanding is lacking for generative tasks. The extension of the simple “Chow” rule, which defers based on the probability of predicting an answer, is not straightforward for generative tasks, where the number of output tokens is variable. Moreover, LMs are known to suffer from length bias, where longer answers are penalized more than shorter answers, which complicates things further. In this work, we initiate a systematic study of deferral rules for cascades of language models. For example, how does one best summarise model confidence across a variable number of output tokens? We show experimentally that no single straightforward extension of probability-based uncertainty for LMs works well across all tasks. Via experiments on a range of benchmarks with FLAN-T5 models, we find that incorporating token-level uncertainty can significantly improve the cost-quality tradeoff of cascades. We further show that incorporating embeddings from the smaller model and intermediate-layer embeddings from the larger model can further boost performance.
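The deferral question above, namely how to summarise confidence over a variable number of output tokens, can be made concrete with a few candidate summaries of the small model's per-token log-probabilities. This is a minimal sketch. The summaries, the threshold and the function names are illustrative assumptions, not the specific rules evaluated in the paper. The joint probability is the direct analogue of the Chow rule and exhibits the length bias described above. The length-normalized and minimum-token summaries are simple token-level alternatives.

```python
import math

def sequence_scores(token_logprobs):
    """Candidate confidence summaries for one generated output.

    token_logprobs: the small model's log-probability for each output token.
    """
    total = sum(token_logprobs)
    return {
        # Chow-style joint probability of the whole output; shrinks with length.
        "joint": math.exp(total),
        # Length-normalized (geometric-mean) per-token probability.
        "geo_mean": math.exp(total / len(token_logprobs)),
        # Token-level view: probability of the least confident token.
        "min_token": math.exp(min(token_logprobs)),
    }

def defer(token_logprobs, threshold, summary="geo_mean"):
    """Return True if the cascade should defer to the large model."""
    return sequence_scores(token_logprobs)[summary] < threshold
```

Take 20 tokens at log-probability -0.1 each. The joint probability is about exp(-2) ≈ 0.14 and triggers deferral at threshold 0.5. The geometric mean stays near 0.90 and does not.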
    Dual-Encoders for Extreme Multi-label Classification
    Nilesh Gupta
    Devvrit Khatri
    Inderjit Dhillon
    International Conference on Learning Representations (ICLR) (2024)
    Dual-encoder models have demonstrated significant success in dense retrieval tasks for open-domain question answering, which mostly involve zero-shot and few-shot scenarios. However, their performance in many-shot retrieval problems, such as extreme classification, remains largely unexplored. State-of-the-art extreme classification techniques like NGAME use a combination of dual-encoders and a learnable classification head for each class to excel on these tasks. Existing empirical evidence shows that, for such problems, the accuracy of dual-encoder methods lags behind that of state-of-the-art extreme classification methods, whose number of learnable parameters grows with the number of classes. In this work, we investigate potential reasons behind this observed gap. These include the intrinsic capacity limit of dual-encoder models, whose fixed model size is independent of the number of classes, as well as training, loss formulation, negative sampling, etc. We methodically experiment along these different axes and find that model size is not the main bottleneck; rather, the training and loss formulation are. When trained correctly, even small dual-encoders can outperform state-of-the-art extreme classification methods by up to 2% in precision on million-label-scale extreme classification datasets, while being 20x smaller in terms of the number of trainable parameters. We further propose a differentiable top-k error-based loss function, which can be used to specifically optimize for recall@k metrics.
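For context on the metrics above: a dual-encoder scores every label by the inner product between a query embedding and a label embedding, so prediction is a top-k search over label embeddings. The following sketch uses hypothetical toy embeddings and function names. It shows that scoring step and how precision@k and recall@k are computed from the ranked labels. It does not implement the paper's differentiable top-k loss.

```python
def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def top_k_labels(query_emb, label_embs, k):
    """Rank labels by inner-product score with the query and keep the top k."""
    scores = {label: dot(query_emb, emb) for label, emb in label_embs.items()}
    return sorted(scores, key=scores.get, reverse=True)[:k]

def precision_recall_at_k(predicted, relevant, k):
    """precision@k = hits / k; recall@k = hits / number of relevant labels."""
    hits = len(set(predicted[:k]) & set(relevant))
    return hits / k, hits / len(relevant)
```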
    Think before you speak: Training language models with pause tokens
    Sachin Goyal
    Ziwei Ji
    Aditya Menon
    Vaishnavh Nagarajan
    International Conference on Learning Representations (ICLR) (2024)
    The present-day language model generates its response by producing a series of tokens in immediate succession: the $(K+1)$-th token is an outcome of manipulating exactly $K$ hidden values in each layer, corresponding to each of the $K$ previous tokens. Is it possible to somehow allow the model to manipulate more hidden values before committing to an answer? If yes, would this help? We explore these questions by training models with learnable pause tokens. Besides feeding the usual prefix to the model, our idea is to feed the model an additional sequence of pause tokens. On these tokens, the model's output is ignored until the last pause token, where we begin extracting the answer. We explore this idea of "delayed answering" in a 1B model, where we consider both pre-training and fine-tuning with pause tokens. We find that while merely fine-tuning a standard model is not very helpful, pause-pretrained models show promise on some downstream tasks such as GSM (reasoning) and SQuAD, CommonSenseQA and LAMBADA (question-answering tasks). We also conduct various ablations to explore the effect of the number of pause tokens. Our work is a preliminary exploration of delayed computation for language models, focused on a 1B model. We hope it inspires future work that can make this idea practically feasible without pre-training, and for models of other sizes trained with other pretraining objectives.
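One way to picture the training setup described above is to append pause tokens to the prefix and mask the loss. Only the predictions of answer tokens count, and the first answer token is predicted at the last pause position. This is a minimal sketch that assumes a standard next-token objective in which position i predicts token i+1. The `<pause>` string and the function name are hypothetical.

```python
PAUSE = "<pause>"  # hypothetical pause-token string

def build_pause_example(prefix, target, num_pauses):
    """Return (tokens, loss_mask) for delayed-answer training.

    Position i predicts tokens[i + 1]; mask[i] == 1 only where that next
    token is part of the target, so outputs on the prefix and on all but
    the last pause token are ignored.
    """
    tokens = list(prefix) + [PAUSE] * num_pauses + list(target)
    # The first target token is predicted at the last pause position
    # (or at the last prefix token when num_pauses == 0).
    start = len(prefix) + num_pauses - 1
    mask = [1 if start <= i < start + len(target) else 0 for i in range(len(tokens))]
    return tokens, mask
```

For example, `build_pause_example(["a", "b"], ["c"], 2)` yields the tokens `a b <pause> <pause> c` with a mask that is 1 only at the second pause.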
    Understanding Self-Attention through Prompt-Conditioned Markov Chains
    Muhammed Emrullah Ildiz
    Yixiao Huang
    Yingcong Li
    Samet Oymak
    International Conference on Machine Learning (ICML) (2024)
    Modern language models rely on the transformer architecture and the self-attention mechanism to perform language understanding and text generation. In this work, we study learning a 1-layer self-attention model from a set of prompts and associated output data sampled according to ground-truth weights. As our main contribution, we establish a precise mapping between a self-attention model and a Markov chain through a convex problem formulation: inputting a prompt to the model samples the output token according to a prompt-conditioned Markov chain, which weights the transitions of a base chain. Additionally, incorporating positional encoding results in position-dependent scaling of the chain transitions. Building on this formalism, we develop identifiability/coverage conditions on the data distribution that guarantee consistent estimation, and establish sample complexity guarantees under IID sampled data. Finally, we study the challenging problem of learning from a single dependent trajectory generated from an initial prompt. We characterize a winner-takes-all phenomenon, absent in standard Markov chains, where the sampling process degenerates into generating a limited subset of tokens due to the non-mixing nature of the attention layer. We argue that this phenomenon explains the tendency of modern LLMs to generate repetitive text. It also makes consistent estimation from a single trajectory intricate and problem-dependent, and we provide a preliminary characterization of this.
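The prompt-conditioned Markov chain can be illustrated with a toy sampler. Assume a single attention head that attends from the last prompt token to every prompt token with scores `W[last][token]` and then emits the attended token. This is a simplification without positional encoding or the paper's exact parameterization. Summing attention weights over repeated tokens gives P(k | prompt) ∝ count(k in prompt) · exp(W[last][k]). In other words, the base-chain row exp(W[last][k]) is reweighted by how often k occurs in the prompt. The two functions compute this distribution in two different ways. The names and the dictionary-based `W` are hypothetical.

```python
import math
from collections import Counter

def prompt_conditioned_transition(prompt, W, vocab):
    """P(k | prompt) proportional to count(k in prompt) * exp(W[last][k])."""
    last = prompt[-1]
    counts = Counter(prompt)
    # Base-chain row for the last token, reweighted by prompt token counts.
    weights = {k: counts[k] * math.exp(W[last][k]) for k in vocab}
    z = sum(weights.values())
    return {k: w / z for k, w in weights.items()}

def attention_output_distribution(prompt, W):
    """Same distribution computed directly: softmax attention over positions."""
    last = prompt[-1]
    scores = [math.exp(W[last][tok]) for tok in prompt]
    z = sum(scores)
    dist = {}
    for tok, s in zip(prompt, scores):
        # Positions holding the same token pool their attention mass.
        dist[tok] = dist.get(tok, 0.0) + s / z
    return dist
```

Tokens absent from the prompt get probability zero. This is one way to see why sampling can collapse onto a limited subset of tokens.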
    Modern information retrieval (IR) systems consist of multiple stages, such as retrieval and ranking. Transformers are employed across these different IR stages and achieve state-of-the-art performance. However, each model is trained separately, leading to complex pipelines and increased cost for maintaining multiple models. The apparent need for separate models is due to different input/output semantics at different stages. In this paper, we challenge this tradition of using separate models: since transformers are very expressive, we ask whether changing just the score function would suffice. We present a new unified approach, USTAD, to train a single network that can provide powerful ranking scores as a cross-encoder (CE) as well as factorized embeddings for large-scale retrieval as a dual-encoder (DE). Empirically, we find a single USTAD model to be competitive with separate ranking CE and retrieval DE models. Furthermore, USTAD enables new distillation techniques, significantly improving CE-to-DE distillation. Also, using a USTAD teacher, we can deploy novel asymmetric architectures for student models, which realize better embedding alignment without increasing online inference cost. On standard benchmarks like MSMARCO, we show that our approach successfully distills from both dual-encoder (DE) and cross-encoder (CE) teacher models to 1/10th-size asymmetric students that can retain 95-97% of the teacher performance.
    Mechanics of Next Token Prediction with Transformers
    Yingcong Li
    Yixiao Huang
    Muhammed Emrullah Ildiz
    Samet Oymak
    International Conference on Artificial Intelligence and Statistics (AISTATS) (2024)
    Transformer-based language models are trained on large datasets to predict the next token given an input sequence. Despite this seemingly simple training objective, they have revolutionized natural language processing within a short timeframe. Underlying this success is the self-attention mechanism. In this work, we ask: what does 1-layer self-attention learn from next-token prediction? We show that, when trained with gradient descent, self-attention implements a simple automaton that induces a token hierarchy determined by the training data. Concretely, from the (sequence, label) pairs of the training data, we construct directed next-token graphs (NTGs) of the dataset that capture (input token, label) relations. We find that the implicit bias of self-attention is captured by the strongly-connected components (SCCs), which partition the NTGs into cyclic and acyclic subgraphs. The acyclic subgraph results in an SVM direction that enforces the priority order among SCCs. The cyclic subgraph yields a correction term that allocates the nonzero softmax probabilities among tokens within the same SCC. We empirically and theoretically demonstrate that the superposition of these components can accurately predict the implicit bias of gradient descent in next-token prediction. We believe these results shed light on self-attention's ability to process sequential data and pave the path towards demystifying more complex transformer architectures.
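The graph construction above can be sketched with a standard strongly-connected-components routine (Tarjan's algorithm). As an illustrative assumption, this sketch builds a single graph for the whole dataset, with an edge x → y for every input token x of a sequence labelled y. The paper's exact NTG construction may differ, for example by conditioning on the last input token. Edges inside an SCC form the cyclic part, and edges between SCCs form the acyclic part. The function names are hypothetical.

```python
def next_token_graph(dataset):
    """Adjacency sets with an edge x -> y for each input token x of a sequence labelled y."""
    graph = {}
    for seq, label in dataset:
        graph.setdefault(label, set())
        for tok in seq:
            graph.setdefault(tok, set())
            if tok != label:  # skip self-loops
                graph[tok].add(label)
    return graph

def strongly_connected_components(graph):
    """Tarjan's algorithm (recursive, so intended for small graphs)."""
    index, low, on_stack, stack, comps = {}, {}, set(), [], []
    counter = 0

    def visit(v):
        nonlocal counter
        index[v] = low[v] = counter
        counter += 1
        stack.append(v)
        on_stack.add(v)
        for w in graph[v]:
            if w not in index:
                visit(w)
                low[v] = min(low[v], low[w])
            elif w in on_stack:
                low[v] = min(low[v], index[w])
        if low[v] == index[v]:  # v is the root of an SCC
            comp = set()
            while True:
                w = stack.pop()
                on_stack.discard(w)
                comp.add(w)
                if w == v:
                    break
            comps.append(frozenset(comp))

    for v in graph:
        if v not in index:
            visit(v)
    return comps

def split_edges(graph):
    """Split edges into within-SCC (cyclic) and between-SCC (acyclic) sets."""
    comp_of = {v: c for c in strongly_connected_components(graph) for v in c}
    cyclic, acyclic = set(), set()
    for u, nbrs in graph.items():
        for w in nbrs:
            (cyclic if comp_of[u] == comp_of[w] else acyclic).add((u, w))
    return cyclic, acyclic
```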
    Modern ML systems increasingly augment input instances with additional relevant information to enhance final predictions. Despite growing interest in such retrieval-augmented models, their fundamental properties and training are not well understood. We propose a statistical framework to study such models with two components: 1) a retriever to identify the relevant information out of a large corpus via a data-dependent metric; and 2) a predictor that consumes the input instances along with the retrieved information to make the final predictions. We present a principled method for end-to-end training of both components and draw connections with various training approaches in the literature. Furthermore, we establish excess risk bounds for retrieval-augmented models while delineating the contributions of both retriever and predictor towards the model performance. We validate the utility of our proposed training methods, along with the key takeaways from our statistical analysis, on an open-domain question answering task where retrieval augmentation is important.
    Many modern high-performing machine learning models such as GPT-3 primarily rely on scaling up models, e.g., transformer networks. Simultaneously, a parallel line of work aims to improve the model performance by augmenting an input instance with other (labeled) instances during inference. Examples of such augmentations include task-specific prompts and similar examples retrieved from the training data by a nonparametric component. Remarkably, retrieval-based methods have enjoyed success on a wide range of problems, ranging from standard natural language processing and vision tasks to protein folding, as demonstrated by many recent efforts, including WebGPT and AlphaFold. Despite a growing literature showcasing the promise of these models, the theoretical underpinning for such models remains underexplored. In this paper, we present a formal treatment of retrieval-based models to characterize their generalization ability. In particular, we focus on two classes of retrieval-based classification approaches: First, we analyze a local learning framework that employs an explicit local empirical risk minimization based on retrieved examples for each input instance. Interestingly, we show that breaking down the underlying learning task into local sub-tasks enables the model to employ a low complexity parametric component to ensure good overall accuracy. The second class of retrieval-based approaches we explore learns a global model using kernel methods to directly map an input instance and retrieved examples to a prediction, without explicitly solving a local learning task.
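The local learning framework above can be illustrated with explicit local empirical risk minimization. For each query, retrieve the k nearest training examples, fit the best model from a low-complexity class on just those examples, and predict with it. This sketch assumes Euclidean retrieval, ±1 labels, and decision stumps as the local hypothesis class. All of these choices and the function names are illustrative and are not taken from the paper.

```python
import math

def retrieve(query, train, k):
    """Return the k training pairs (x, y) closest to query in Euclidean distance."""
    return sorted(train, key=lambda xy: math.dist(query, xy[0]))[:k]

def local_erm_predict(query, train, k=5):
    """Fit the best decision stump on the retrieved examples, then predict query.

    A stump (feature j, threshold t, sign s) predicts s if x[j] >= t else -s.
    """
    local = retrieve(query, train, k)
    best, best_err = None, math.inf
    for j in range(len(query)):
        for t in sorted({x[j] for x, _ in local}):
            for s in (1, -1):
                # Local empirical 0-1 risk of this stump on the retrieved set.
                err = sum(1 for x, y in local if (s if x[j] >= t else -s) != y)
                if err < best_err:
                    best, best_err = (j, t, s), err
    j, t, s = best
    return s if query[j] >= t else -s
```

Consider training points labelled -1 at 0, 1, 10 and 11 and +1 at 2 and 3. No single stump fits all six. However, the stump fitted to the 4 neighbours of 2.5 separates them perfectly and predicts +1.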
    Supervision complexity and its role in knowledge distillation
    Hrayr Harutyunyan
    Aditya Krishna Menon
    International Conference on Learning Representations (ICLR) (2023)
    Knowledge distillation is a popular method of compressing a large teacher model (or an ensemble of models) to a more compact student model. Although distillation is empirically effective, there is limited understanding of why it helps, and of how to improve it to transfer richer knowledge from the teacher to the student. In this paper, we propose a new online distillation algorithm that applies distillation using a sequence of teacher models, corresponding to different checkpoints during teacher training. Intuitively, this gradually increases the complexity of the target functions that the student model is asked to mimic. Formally, we establish generalization bounds that explicate how the target label complexity can benefit the student. We empirically demonstrate that online distillation can significantly improve over regular offline distillation, particularly in scenarios where there is a large teacher-student capacity gap.
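Online distillation as described above has two ingredients, sketched minimally here. The first is a schedule that maps each student training step to one of a sequence of saved teacher checkpoints. The second is a standard soft-label distillation loss against the selected checkpoint's logits. The equal-length phase schedule, the temperature handling (without the common T² rescaling) and all names are assumptions for illustration, not the paper's exact algorithm.

```python
import math

def softmax(logits, temperature=1.0):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp((x - m) / temperature) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]

def distillation_loss(student_logits, teacher_logits, temperature=1.0):
    """Cross-entropy of the student's softened distribution against the teacher's."""
    p = softmax(teacher_logits, temperature)
    q = softmax(student_logits, temperature)
    return -sum(pi * math.log(qi) for pi, qi in zip(p, q))

def teacher_checkpoint_for_step(step, total_steps, num_checkpoints):
    """Pick which teacher checkpoint to distill from at a given student step.

    Assumed schedule: equal-length phases, moving to a later (more
    complex) teacher checkpoint in each phase.
    """
    phase = step * num_checkpoints // total_steps
    return min(phase, num_checkpoints - 1)
```

With 5 checkpoints and 100 student steps, steps 0 to 19 use checkpoint 0, steps 20 to 39 use checkpoint 1, and so on. The final phase uses the fully trained teacher.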