Ananda Theertha Suresh
Ananda Theertha Suresh is a research scientist at Google. He obtained his PhD from the University of California, San Diego, where he was advised by Prof. Alon Orlitsky. His research interests lie at the intersection of machine learning, information theory, and statistics. More details can be found at theertha.info.
Research Areas
Authored Publications
Preview abstract
Suppose Alice has a distribution ${P}$ and Bob has a distribution ${Q}$. Alice wants to draw a sample $a\sim {P}$ and Bob a sample $b \sim {Q}$ such that $a = b$ with as high a probability as possible. It is well known that, by sampling from an optimal coupling between the distributions, Alice and Bob can achieve $\Pr[a = b] = 1 - D_{\text{tv}}({P},{Q})$, where $D_{\text{tv}}({P},{Q})$ is the total variation distance between ${P}$ and ${Q}$.
What if Alice and Bob must solve this same problem \emph{without communicating at all?} Perhaps surprisingly, with access to public randomness, they can still achieve $\Pr[a = b] \geq \frac{1 - D_{\text{tv}}({P},{Q})}{1 + D_{\text{tv}}({P},{Q})} \geq 1-2D_{\text{tv}}({P},{Q})$ using a simple protocol based on the Weighted MinHash algorithm. This bound was shown to be optimal in the worst case by Bavarian, Ghazi, Haramaty, Kamath, Rivest, and Sudan [ToC 2020].
In this work, we revisit the ``communication-free coupling'' problem. We provide a simpler proof of the optimality result from [Bavarian et al., 2020]. Moreover, we show that, while the \emph{worst-case} success probability of Weighted MinHash cannot be improved, an equally simple protocol based on Gumbel sampling offers a Pareto improvement: for every pair of distributions ${P}$ and ${Q}$, Gumbel sampling achieves an equal or higher value of $\Pr[a = b]$ than Weighted MinHash.
Importantly, this improvement translates to practice. We demonstrate an application of communication-free coupling to \emph{speculative decoding}, a recent method for accelerating autoregressive large language models [Leviathan, Kalman, Matias, ICML 2023].
We show that communication-free protocols can be used to construct \emph{drafter-invariant speculative decoding} schemes, which have the desirable property that their output is fixed given a fixed random seed, regardless of which drafter is used for speculation. In experiments on a language generation task, Gumbel sampling outperforms Weighted MinHash.
Code is available at \url{https://github.com/majid-daliri/DISD}.
Finally, we study the coupling problem in the setting where communication is \emph{bounded}, rather than completely eliminated. We describe a protocol that uses just $O(\log(n/\epsilon))$ bits of communication to achieve $\Pr[a = b] = 1 - D_{\text{tv}}({P},{Q}) - \epsilon$, i.e., to essentially match the optimal coupling.
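For intuition, here is a minimal NumPy sketch of the shared-randomness Gumbel protocol described above. The authors' actual implementation is in the linked repository; the toy distributions and seed below are ours.

```python
import numpy as np

def gumbel_sample(dist, seed):
    """Draw one index from `dist` via the Gumbel-max trick, using the
    public random seed as the only shared resource between the parties."""
    rng = np.random.default_rng(seed)        # public randomness
    g = rng.gumbel(size=len(dist))           # identical Gumbel noise for both parties
    return int(np.argmax(np.log(dist) + g))  # exact sample from `dist`

# Alice holds P, Bob holds Q; they never communicate, only share the seed.
P = np.array([0.5, 0.3, 0.2])
Q = np.array([0.4, 0.4, 0.2])
a = gumbel_sample(P, seed=1234)
b = gumbel_sample(Q, seed=1234)
# Pr[a == b] is at least (1 - D_tv(P, Q)) / (1 + D_tv(P, Q)).
print(a, b)
```

Because both parties apply the argmax to the same Gumbel noise, each output is an exact sample from its own distribution, yet the two outputs coincide with high probability whenever $P$ and $Q$ are close in total variation.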
View details
Efficient Language Model Architectures for Differentially Private Federated Learning
Zheng Xu
Yanxiang Zhang
Privacy Regulation and Protection in Machine Learning Workshop at ICLR 2024 (2024) (to appear)
Preview abstract
Cross-device federated learning (FL) is a technique that trains a model on data distributed across typically millions of edge devices without data ever leaving the devices.
SGD is the standard client optimizer for on-device training in cross-device FL, favored for its memory and computational efficiency.
However, in centralized training of neural language models, adaptive optimizers are preferred as they offer improved stability and performance.
In light of this, we ask whether language models can be modified such that they can be efficiently trained with SGD client optimizers, and answer this affirmatively.
We propose a scale-invariant \emph{Coupled Input Forget Gate} (SI CIFG) recurrent network by modifying the sigmoid and tanh activations in the recurrent cell
and show that this new model converges faster and achieves better utility than the standard CIFG recurrent model in cross-device FL in large-scale experiments.
We further show that the proposed scale invariant modification also helps in federated learning of larger transformer models.
Finally, we demonstrate the scale invariant modification is also compatible with other non-adaptive algorithms.
In particular, our results suggest an improved privacy-utility trade-off in federated learning with differential privacy.
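The paper gives the precise SI CIFG formulation; purely as a rough, hypothetical illustration of the underlying idea of making a recurrent cell's gate activations scale invariant, the sketch below normalizes each pre-activation before the nonlinearity in a CIFG-style cell. The normalization choice, helper names, and parameter layout are our assumptions, not the paper's.

```python
import numpy as np

def norm_sigmoid(x, eps=1e-6):
    # Hypothetical scale-invariant gate: normalizing the pre-activation means
    # rescaling the weights and biases by a constant leaves the gate unchanged.
    return 1.0 / (1.0 + np.exp(-x / (np.linalg.norm(x) + eps)))

def norm_tanh(x, eps=1e-6):
    return np.tanh(x / (np.linalg.norm(x) + eps))

def cifg_step(x, h, c, params):
    """One step of a CIFG cell (input gate coupled as 1 - forget gate)
    using the illustrative normalized activations above."""
    z = np.concatenate([x, h])
    f = norm_sigmoid(params["Wf"] @ z + params["bf"])      # forget gate
    o = norm_sigmoid(params["Wo"] @ z + params["bo"])      # output gate
    c_tilde = norm_tanh(params["Wc"] @ z + params["bc"])   # candidate cell state
    c_new = f * c + (1.0 - f) * c_tilde                    # coupled input/forget gates
    h_new = o * np.tanh(c_new)
    return h_new, c_new

# Toy usage with random parameters.
dim, hid = 4, 3
rng = np.random.default_rng(0)
params = {k: rng.normal(size=(hid, dim + hid)) for k in ("Wf", "Wo", "Wc")}
params.update({k: np.zeros(hid) for k in ("bf", "bo", "bc")})
h, c = cifg_step(rng.normal(size=dim), np.zeros(hid), np.zeros(hid), params)
```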
View details
FedAQT: Accurate Quantized Training with Federated Learning
Renkun Ni
Yonghui Xiao
Oleg Rybakov
Phoenix Meadowlark
Tom Goldstein
2024
Preview abstract
Federated learning has been widely used to train automatic speech recognition models, where the training procedure is decentralized to client devices to avoid data privacy concerns by keeping the training data local. However, the limited computation resources on client devices prevent training with large models. Recently, quantization-aware training has shown the potential to train a quantized neural network with performance similar to the full-precision model while keeping the model size small and inference fast. However, these quantization methods do not save memory during training, since they still keep the full-precision model. To address this issue, we propose a new quantized training framework for federated learning which saves memory by training with quantized variables directly on local devices. We empirically show that our method can achieve a comparable word error rate (WER) while using only 60% of the memory of the full-precision model.
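As a hedged illustration of the general idea, not the paper's exact framework, the sketch below keeps only int8 weights resident on the client and dequantizes them transiently for each local update. The symmetric uniform quantizer, function names, and toy objective are our assumptions.

```python
import numpy as np

def quantize(w, num_bits=8):
    """Uniform symmetric quantization: store int8 values plus one float scale."""
    scale = np.max(np.abs(w)) / (2 ** (num_bits - 1) - 1) + 1e-12
    return np.round(w / scale).astype(np.int8), scale

def dequantize(q, scale):
    return q.astype(np.float32) * scale

def local_step(q_w, scale, grad_fn, lr=0.1):
    """Hypothetical client update: only the int8 weights stay in client memory;
    a full-precision copy exists just transiently while computing the gradient."""
    w = dequantize(q_w, scale)
    w = w - lr * grad_fn(w)
    return quantize(w)          # immediately re-quantize before storing

# Toy usage with the quadratic objective 0.5 * ||w||^2 (gradient = w).
q_w, scale = quantize(np.array([0.8, -0.3, 0.1], dtype=np.float32))
for _ in range(5):
    q_w, scale = local_step(q_w, scale, grad_fn=lambda w: w)
print(dequantize(q_w, scale))
```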
View details
Preview abstract
We propose a practical maximum-likelihood-estimation (MLE) framework for regression as an alternative to the typical approach of Empirical Risk Minimization (ERM) over a specific loss metric. Our approach is better suited to capture inductive biases in datasets, and can output post-hoc estimators at inference time that optimize different types of loss metrics. We present theoretical evidence (in the fixed-design setting) that our approach is always competitive with ERM over the loss metric, and in many practical scenarios can be much superior to it. For time-series forecasting, we propose an end-to-end MLE-based training and inference approach that can flexibly capture various inductive biases and optimize prediction accuracy for a variety of typical loss metrics, without having to choose a specific loss metric at training time. We demonstrate empirically that our method, instantiated with a well-designed general-purpose likelihood, can obtain superior performance over ERM for a variety of time-series forecasting and regression datasets with different inductive biases and data distributions.
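As a toy illustration of the post-hoc estimator idea, using a log-normal likelihood of our own choosing rather than the paper's general-purpose likelihood: fit the distribution by maximum likelihood once, then read off a different point prediction for each loss metric at inference time.

```python
import numpy as np

# Skewed, positive toy targets for a single input.
rng = np.random.default_rng(0)
y = np.exp(1.0 + 0.8 * rng.standard_normal(5000))

# Log-normal MLE: fit the mean and standard deviation of log(y).
mu, sigma = np.log(y).mean(), np.log(y).std()

# One fitted distribution, several post-hoc point predictions.
post_hoc_estimators = {
    "mse":         np.exp(mu + 0.5 * sigma**2),   # distribution mean minimizes squared error
    "mae":         np.exp(mu),                    # distribution median minimizes absolute error
    "pinball_0.9": np.exp(mu + 1.2816 * sigma),   # 90th percentile minimizes the 0.9-quantile loss
}
print(post_hoc_estimators)
```

Because the log-normal is asymmetric, the three estimators differ, which is exactly why committing to a single loss metric at training time is unnecessary under this approach.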
View details
Scaling Language Model Size in Cross-Device Federated Learning
FL4NLP@ACL2022 (2022) (to appear)
Preview abstract
Most studies in cross-device federated learning focus on small models, due to the server-client communication and on-device computation bottlenecks. In this work, we leverage various techniques for mitigating these bottlenecks to train larger language models in cross-device federated learning. With systematic applications of partial model training, quantization, efficient transfer learning, and communication-efficient optimizers, we are able to train a 21M-parameter Transformer that achieves the same perplexity as a similarly sized LSTM with ~10x smaller client-to-server communication cost, and 11% lower perplexity than smaller LSTMs commonly studied in the literature.
View details
Preview abstract
We study multiple-source domain adaptation, where the learner has access to abundant labeled data from multiple source domains and limited labeled data from the target domain. We analyze existing algorithms and propose an instance-optimal approach based on model selection. We provide efficient algorithms and empirically demonstrate the benefits of our approach.
View details
FedJAX: Federated learning simulation with JAX
Ke Wu
1st NeurIPS Workshop on New Frontiers in Federated Learning (NFFL 2021) (2021)
Preview abstract
Federated learning is a machine learning technique that enables training across decentralized data. Recently, federated learning has become an active area of research due to an increased focus on privacy and security. In light of this, a variety of open source federated learning libraries have been developed and released.
We introduce FedJAX, a JAX-based open source library for federated learning simulations that emphasizes ease-of-use in research. With its simple primitives for implementing federated learning algorithms, prepackaged datasets, models and algorithms, and fast simulation speed, FedJAX aims to make developing and evaluating federated algorithms faster and easier for researchers. Our benchmark results show that FedJAX can be used to train models with federated averaging on the EMNIST dataset in a few minutes and the Stack Overflow dataset in roughly an hour with standard hyperparameters using TPUs.
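FedJAX's actual API and benchmarks are documented with the library; purely to illustrate the kind of computation such a simulation performs, here is a generic federated-averaging loop on a toy least-squares problem in plain NumPy. This is not the FedJAX API.

```python
import numpy as np

def client_update(global_w, data, lr=0.1, local_steps=5):
    """Run a few local SGD steps on one client's least-squares data."""
    w = global_w.copy()
    x, y = data
    for _ in range(local_steps):
        grad = 2 * x.T @ (x @ w - y) / len(y)   # gradient of mean squared error
        w -= lr * grad
    return w

def federated_averaging(global_w, clients, rounds=10):
    """Each round, every client trains locally and the server averages the results."""
    for _ in range(rounds):
        updates = [client_update(global_w, data) for data in clients]
        global_w = np.mean(updates, axis=0)
    return global_w

# Toy simulation: 8 clients, each with 32 examples of a 4-dimensional regression.
rng = np.random.default_rng(0)
clients = [(rng.normal(size=(32, 4)), rng.normal(size=32)) for _ in range(8)]
w = federated_averaging(np.zeros(4), clients)
print(w)
```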
View details
Approximating probabilistic models as weighted finite automata
Vlad Schogol
Computational Linguistics, 47 (2021), pp. 221-254
Preview abstract
Weighted finite automata (WFA) are often used to represent probabilistic models, such as n-gram language models, since they are efficient for recognition tasks in time and space. The probabilistic source to be represented as a WFA, however, may come in many forms. Given a generic probabilistic model over sequences, we propose an algorithm to approximate it as a weighted finite automaton such that the Kullback-Leibler divergence between the source model and the WFA target model is minimized. The proposed algorithm involves a counting step and a difference-of-convex optimization step, both of which can be performed efficiently. We demonstrate the usefulness of our approach on various tasks, including distilling n-gram models from neural models, building compact language models, and building open-vocabulary character models. The algorithms used for these experiments are available in an open-source software library.
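As a toy illustration of the counting step only (the difference-of-convex optimization step is not shown), the snippet below estimates bigram WFA arc weights as conditional frequencies from sequences sampled from the source model; the function name and toy data are ours.

```python
from collections import Counter, defaultdict

def bigram_wfa_from_samples(sample_sequences):
    """Approximate a sequence model by a bigram WFA: each state is a history
    token and each arc weight is a normalized transition count estimated from
    samples drawn from the source model."""
    counts = defaultdict(Counter)
    for seq in sample_sequences:
        for prev, cur in zip(["<s>"] + seq, seq + ["</s>"]):
            counts[prev][cur] += 1
    return {
        state: {sym: c / sum(ctr.values()) for sym, c in ctr.items()}
        for state, ctr in counts.items()
    }

# Toy usage: three sequences sampled from some source model.
samples = [["a", "b", "a"], ["a", "a", "b"], ["b", "a"]]
print(bigram_wfa_from_samples(samples))
```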
View details
Preview abstract
We present a theoretical and algorithmic study of the multiple-source domain adaptation problem in the common scenario where the learner has access only to a limited amount of labeled target data, but has at their disposal a large amount of labeled data from multiple source domains. We show that a new family of algorithms based on model selection ideas benefits from very favorable guarantees in this scenario, and we discuss some theoretical obstacles affecting alternative techniques. We also report the results of several experiments with our algorithms that demonstrate their practical effectiveness in several tasks.
View details