Publications
Learning Robust Representations for Transfer in Reinforcement Learning
Learning transferable representations for deep reinforcement learning (RL) is a challenging problem due to the inherent non-stationarity, distribution shift, and unstable training dynamics. To be useful, a transferable representation needs to be robust to such factors. In this work, we introduce a new architecture and training strategy for learning robust representations for transfer learning in RL. We propose leveraging multiple CNN encoders and training them not to specialize in areas of the state space but instead to match each other's representation. We find that learned representations transfer well across many Atari tasks, resulting in better transfer learning performance and data efficiency than training from scratch.
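The abstract does not give the exact objective used to make the encoders agree. As a minimal sketch, assuming the agreement term is a mean squared distance over all pairs of encoder outputs for the same observation (the name `matching_loss` and the toy vectors are hypothetical), it could look like this:

```python
from itertools import combinations


def matching_loss(representations):
    """Mean squared distance over all pairs of encoder outputs for one input.

    `representations` holds one embedding (a list of floats) per encoder.
    This pairwise form is an assumption, not the paper's stated loss.
    """
    pairs = list(combinations(representations, 2))
    total = sum(
        sum((a - b) ** 2 for a, b in zip(rep_i, rep_j))
        for rep_i, rep_j in pairs
    )
    return total / len(pairs)


# Three hypothetical encoders embedding the same observation.
agreeing = [[1.0, 2.0], [1.0, 2.0], [1.0, 2.0]]
disagreeing = [[0.0, 0.0], [2.0, 0.0], [0.0, 2.0]]

print(matching_loss(agreeing))
print(matching_loss(disagreeing))
```

A term like this is also minimized when every encoder outputs the same constant vector, so in practice it would have to be combined with the RL objective or another safeguard against collapse.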
Random feature models are a popular approach for studying network learning that can capture important behaviors while remaining simpler than traditional training.
Guth et al. [2024] introduced “rainbow” networks which model the distribution of trained weights as correlated random features conditioned on previous layer activity.
Sampling new weights from distributions fit to learned networks led to similar performance in entirely untrained networks, and the observed weight covariances were found to be low rank.
This provided evidence that random feature models could be extended to some networks away from initialization, but White et al. [2024] failed to replicate their results in the deeper ResNet18 architecture.
Here we ask whether the rainbow formulation can succeed in deeper networks by directly training a stochastic ensemble of random features, which we call stochastic rainbow networks.
At every gradient descent iteration, new weights are sampled for all intermediate layers and features are aligned layer-wise.
We find:
(1) this approach scales to deeper models, which outperform shallow networks at large widths;
(2) ensembling multiple samples from the stochastic model is better than retraining the classifier head; and
(3) low-rank parameterization of the learnable weight covariances can approach the accuracy of full-rank networks.
This offers more evidence for rainbow and other structured random feature networks as reduced models of deep learning.
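To illustrate what a low-rank weight covariance means, here is a minimal sketch under assumptions of our own: weights are drawn as `w = U z` with `z` standard normal, so their covariance is `U Uᵀ` and has rank at most the number of columns of `U`. The factor values and function names are hypothetical and do not come from the paper.

```python
import random


def sample_low_rank_weights(factor, rng):
    """Draw w = U z with z ~ N(0, I_r); Cov(w) = U U^T has rank <= r."""
    rank = len(factor[0])
    z = [rng.gauss(0.0, 1.0) for _ in range(rank)]
    return [sum(u * z_k for u, z_k in zip(row, z)) for row in factor]


def implied_covariance(factor):
    """Return U U^T, the covariance implied by the factor U."""
    return [
        [sum(a * b for a, b in zip(row_i, row_j)) for row_j in factor]
        for row_i in factor
    ]


# Hypothetical 3 x 2 factor: three weights driven by two latent directions.
factor = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
rng = random.Random(0)
w = sample_low_rank_weights(factor, rng)

print(w)
print(implied_covariance(factor))
```

With this factor the third weight always equals the sum of the first two, which shows that samples are confined to a two-dimensional subspace. Storing `U` costs `d * r` numbers instead of the `d * (d + 1) / 2` needed for a full symmetric covariance.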
A significant approach in natural language processing involves large-scale pre-training on general domain data followed by adaptation to specific tasks or domains. As models grow in size, full fine-tuning of all parameters becomes increasingly impractical. To address this, some methods for low-rank task adaptation of language models have been proposed, e.g. LoRA and FLoRA. These methods keep the pre-trained model weights fixed and incorporate trainable low-rank decomposition matrices, called adapters, into some layers of the transformer architecture. This approach significantly reduces the number of trainable parameters required for downstream tasks compared to full fine-tuning of all parameters. In this work, we look at low-rank adaptation through the lens of data privacy. We show theoretically that the low-rank adaptation used in LoRA and FLoRA is equivalent to injecting some random noise into the batch gradients w.r.t. the adapter parameters coming from their full fine-tuning, and we quantify the variance of the injected noise. By establishing a Berry-Esseen type bound on the total variation distance between the noise distribution and a Gaussian distribution with the same variance, we show that the dynamics of LoRA and FLoRA are very close to differentially private full fine-tuning of the adapters, which suggests that low-rank adaptation implicitly provides privacy w.r.t. the fine-tuning data. Finally, using the Johnson-Lindenstrauss lemma, we show that when augmented with gradient clipping, low-rank adaptation is almost equivalent to differentially private full fine-tuning of the adapters with a fixed noise scale.
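The low-rank adapter structure described here can be sketched in a few lines. This is a plain-Python illustration of the general idea (a frozen weight `W` plus a trainable product `B A`), not the LoRA or FLoRA reference implementation; the helper names, the `scale` argument, and the toy matrices are assumptions.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def adapted_weight(w, b, a, scale=1.0):
    """Effective weight W + scale * (B A); W stays frozen, only B and A train."""
    delta = matmul(b, a)
    return [
        [w_ij + scale * d_ij for w_ij, d_ij in zip(w_row, d_row)]
        for w_row, d_row in zip(w, delta)
    ]


def adapter_param_count(d_out, d_in, rank):
    """Trainable parameters in B (d_out x rank) plus A (rank x d_in)."""
    return rank * (d_out + d_in)


w = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]  # frozen 2 x 3 weight
a = [[0.5, -0.5, 1.0]]                  # rank-1 factor A, 1 x 3
b_zero = [[0.0], [0.0]]                 # zero B leaves the model unchanged
b_trained = [[1.0], [2.0]]              # hypothetical values after training

print(adapted_weight(w, b_zero, a))
print(adapted_weight(w, b_trained, a))
print(adapter_param_count(768, 768, 8), "vs", 768 * 768)
```

With `B` set to zero the adapted layer reproduces the pre-trained one exactly. The parameter count shows where the savings come from: the adapter grows linearly in the layer dimensions, while the full weight grows with their product.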
Neural networks often learn simple explanations that fit the majority of the data while memorizing exceptions that deviate from these explanations. This leads to poor generalization when the learned explanations are spurious. In this work, we formalize
In many domains, such as healthcare, time-series data is irregularly sampled with varying intervals between observations. This creates challenges for classical time-series models that require equally spaced data. To address this, we propose a novel time-series Transformer called **Trajectory Generative Pre-trained Transformer (TrajGPT)**. It introduces a data-dependent decay mechanism that adaptively forgets irrelevant information based on clinical context. By interpreting TrajGPT as ordinary differential equations (ODEs), our approach captures continuous dynamics from sparse and irregular time-series data. Experimental results show that TrajGPT, with its time-specific inference approach, accurately predicts trajectories without requiring task-specific fine-tuning.
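The abstract does not specify how TrajGPT computes its decay. As a generic illustration only, the sketch below keeps a scalar state that decays exponentially over the elapsed time between observations, with a decay rate that depends on the incoming value. The update rule, `decayed_states`, and the rate functions are all assumptions, not the paper's mechanism.

```python
import math


def decayed_states(observations, decay_rate_fn):
    """Running state over irregularly spaced (time, value) observations.

    Before each new value is added, the previous state is shrunk by
    exp(-rate * dt), where dt is the gap since the last observation and
    the rate is computed from the new value (data-dependent decay).
    """
    state = 0.0
    prev_time = None
    states = []
    for time, value in observations:
        if prev_time is not None:
            dt = time - prev_time
            state *= math.exp(-decay_rate_fn(value) * dt)
        state += value
        prev_time = time
        states.append(state)
    return states


# Irregular gaps: one time unit, then two time units.
observations = [(0.0, 1.0), (1.0, 1.0), (3.0, 1.0)]

no_forgetting = decayed_states(observations, lambda value: 0.0)
halving = decayed_states(observations, lambda value: math.log(2.0))

print(no_forgetting)
print(halving)
```

Because the decay is a function of the elapsed time rather than the step index, a longer gap forgets more of the past without resampling the series onto a regular grid.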
Linear mode connectivity (LMC) has become a topic of great interest in recent years. It has been empirically demonstrated that popular deep learning models trained from different initializations exhibit linear mode connectivity up to permutation. Based on this, several approaches for finding a permutation of the model's features or weights have been proposed, leading to several popular methods for model merging. These methods enable the simple averaging of two models to create a new high-performance model. However, besides accuracy, the properties of these models and their relationships to the representations of the models they derive from are poorly understood.
In this work, we study the inner mechanisms behind LMC in model merging through the lens of classic feature visualization methods. Focusing on convolutional neural networks (CNNs), we make several observations that shed light on the underlying mechanisms of model merging by permute and average.
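For readers unfamiliar with permute and average, here is a toy sketch on a network with one hidden layer. It assumes the permutation is found by brute-force search over hidden-unit orderings that minimize squared weight distance, which is only feasible for tiny layers; practical merging methods use more scalable matching procedures. All names and weights are hypothetical.

```python
from itertools import permutations


def find_permutation(w1_a, w1_b):
    """Return perm so that row perm[i] of w1_b best matches row i of w1_a."""
    n = len(w1_a)

    def cost(perm):
        return sum(
            (x - y) ** 2
            for i in range(n)
            for x, y in zip(w1_a[i], w1_b[perm[i]])
        )

    return min(permutations(range(n)), key=cost)


def permute_hidden_units(w1, w2, perm):
    """Reorder hidden units: rows of layer 1 and matching columns of layer 2."""
    new_w1 = [w1[j] for j in perm]
    new_w2 = [[row[j] for j in perm] for row in w2]
    return new_w1, new_w2


def average(m1, m2):
    """Element-wise mean of two equally shaped matrices."""
    return [[(x + y) / 2 for x, y in zip(r1, r2)] for r1, r2 in zip(m1, m2)]


# Model A: 2 inputs -> 3 hidden units -> 1 output.
w1_a = [[1.0, 0.0], [0.0, 2.0], [3.0, 3.0]]
w2_a = [[1.0, -1.0, 0.5]]
# Model B computes the same function with its hidden units in another order.
w1_b = [[3.0, 3.0], [1.0, 0.0], [0.0, 2.0]]
w2_b = [[0.5, 1.0, -1.0]]

perm = find_permutation(w1_a, w1_b)
w1_b_aligned, w2_b_aligned = permute_hidden_units(w1_b, w2_b, perm)

naive_w1 = average(w1_a, w1_b)
merged_w1 = average(w1_a, w1_b_aligned)
merged_w2 = average(w2_a, w2_b_aligned)

print(perm)
print(naive_w1)
print(merged_w1, merged_w2)
```

Permuting the hidden units leaves model B's function unchanged, yet only the aligned average recovers the shared weights. The naive average blends unrelated units.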
Current LLM training positions mathematical reasoning as a core capability. With publicly available sources fully tapped, there is unmet demand for diverse and challenging math questions. Relying solely on human experts is both time-consuming and costly, while LLM-generated questions often lack the requisite diversity and difficulty. We present a design framework that combines the strengths of LLMs with a human-in-the-loop approach to generate a diverse array of challenging math questions. We leverage the metacognition skills [Didolkar et al., 2024] of a strong LLM to extract core "skills" from existing math datasets. These skills serve as the basis for generating novel and difficult questions by prompting the LLM with random pairs of core skills. The use of two different skills within each question makes finding such questions an "out of distribution" task for both LLMs and humans. Our pipeline employs LLMs to iteratively generate and refine questions and solutions through multi-turn prompting. Human annotators then verify and further refine the questions, with their efficiency enhanced via further LLM interactions. Applying this pipeline to skills extracted from the MATH dataset [Hendrycks et al., 2021] resulted in MATH
Advances in Large Language Models (LLMs) have spurred a wave of LLM library learning systems for mathematical reasoning. These systems aim to learn a reusable library of *tools*, such as formal Isabelle lemmas or Python programs that are tailored to a family of tasks. Many of these systems are inspired by the human structuring of knowledge into reusable and extendable concepts, but do current methods actually learn reusable libraries of tools?
We study two library learning systems for mathematics which both reported increased accuracy: LEGO-Prover and TroVE. We find that function reuse is extremely infrequent on miniF2F and MATH. Our follow-up ablation experiments suggest that, rather than reuse, self-correction and self-consistency are the primary drivers of the observed performance gains.
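The abstract does not say how function reuse was measured. One simple way to quantify it for Python-based tools, sketched here under our own assumptions, is to parse each generated solution and check which learned library functions it calls. The reuse definition (a function called in at least two different solutions), the function names, and the toy solutions are all hypothetical.

```python
import ast


def count_library_calls(solution_source, library_names):
    """Count direct calls to each library function in one solution."""
    counts = {name: 0 for name in library_names}
    for node in ast.walk(ast.parse(solution_source)):
        if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
            if node.func.id in counts:
                counts[node.func.id] += 1
    return counts


def reuse_rate(solutions, library_names):
    """Fraction of library functions called in at least two solutions."""
    used_in = {name: 0 for name in library_names}
    for source in solutions:
        for name, n_calls in count_library_calls(source, library_names).items():
            if n_calls > 0:
                used_in[name] += 1
    reused = sum(1 for n_solutions in used_in.values() if n_solutions >= 2)
    return reused / len(library_names)


library = ["triangle_area", "gcd_of_list"]
solutions = [
    "answer = triangle_area(3, 4)",
    "answer = triangle_area(6, 8) + 1",
    "answer = gcd_of_list([12, 18])",
]

print(count_library_calls(solutions[0], library))
print(reuse_rate(solutions, library))
```

This only detects direct calls by name, so aliased or dynamically dispatched calls would be missed, and formal Isabelle lemmas would need a different parser.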