421 Publications

Learning sum of diverse features: computational hardness and efficient gradient-based training for ridge combinations

Kazusato Oko, Yujin Song, Taiji Suzuki, D. Wu

We study the computational and sample complexity of learning a target function $f_* : \mathbb{R}^d \to \mathbb{R}$ with additive structure, that is, $f_*(x) = \frac{1}{\sqrt{M}} \sum_{m=1}^{M} f_m(\langle x, v_m \rangle)$, where $f_1, f_2, \dots, f_M : \mathbb{R} \to \mathbb{R}$ are nonlinear link functions of single-index models (ridge functions) with diverse and near-orthogonal index features $\{v_m\}_{m=1}^{M}$, and the number of additive tasks $M$ grows with the dimensionality as $M \asymp d^{\gamma}$ for $\gamma \ge 0$. This problem setting is motivated by the classical additive model literature, the recent representation learning theory of two-layer neural networks, and large-scale pretraining where the model simultaneously acquires a large number of “skills” that are often localized in distinct parts of the trained network. We prove that a large subset of polynomial $f_*$ can be efficiently learned by gradient descent training of a two-layer neural network, with polynomial statistical and computational complexity that depends on the number of tasks $M$ and the information exponent of $f_m$, despite the unknown link functions and $M$ growing with the dimensionality. We complement this learnability guarantee with a computational hardness result by establishing statistical query (SQ) lower bounds for both correlational SQ and full SQ algorithms.
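As a rough illustration of the additive target described above, the following sketch builds a sum of M ridge functions with orthonormal index features and evaluates it on Gaussian inputs. The function names, the choice of probabilists' Hermite polynomials as link functions, the exactly orthonormal directions, and all sizes are assumptions made for the example; they are not taken from the paper.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermeval


def hermite_link(z, degree):
    """Probabilists' Hermite polynomial He_degree(z), a hypothetical link f_m."""
    return hermeval(z, [0] * degree + [1])


def make_additive_target(d, M, degrees, rng):
    """Return index features V (M x d) and the additive target f_star."""
    # Orthonormal directions v_1..v_M via reduced QR (assumes M <= d).
    Q, _ = np.linalg.qr(rng.standard_normal((d, M)))
    V = Q.T

    def f_star(X):
        Z = X @ V.T  # projections <x, v_m>, shape (n, M)
        parts = [hermite_link(Z[:, m], degrees[m]) for m in range(M)]
        # Sum of the M ridge functions, scaled by 1 / sqrt(M).
        return np.sum(parts, axis=0) / np.sqrt(M)

    return V, f_star


rng = np.random.default_rng(0)
d, M = 64, 8
V, f_star = make_additive_target(d, M, degrees=[2, 3] * 4, rng=rng)
X = rng.standard_normal((1000, d))  # isotropic Gaussian inputs
y = f_star(X)
```

Using a single Hermite polynomial of degree k as the link gives that component an information exponent of k, which is one simple way to vary the difficulty of each task in such a toy setup.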

Variational Inference for Uncertainty Quantification: an Analysis of Trade-offs

C. Margossian, L. Pillaud-Vivien, L. Saul

Given an intractable distribution p, the problem of variational inference (VI) is to find the best approximation from some more tractable family Q. Commonly, one chooses Q to be a family of factorized distributions (i.e., the mean-field assumption), even though p itself does not factorize. We show that this mismatch leads to an impossibility theorem: if p does not factorize, then any factorized approximation q ∈ Q can correctly estimate at most one of the following three measures of uncertainty: (i) the marginal variances, (ii) the marginal precisions, or (iii) the generalized variance (which can be related to the entropy). In practice, the best variational approximation in Q is found by minimizing some divergence D(q, p) between distributions, and so we ask: how does the choice of divergence determine which measure of uncertainty, if any, is correctly estimated by VI? We consider the classic Kullback-Leibler divergences, the more general Rényi divergences, and a score-based divergence which compares ∇ log p and ∇ log q. We provide a thorough theoretical analysis in the setting where p is a Gaussian and q is a (factorized) Gaussian. We show that all the considered divergences can be

How Truncating Weights Improves Reasoning in Language Models

Lei Chen, Joan Bruna, A. Bietti

In addition to the ability to generate fluent text in various languages, large language models have been successful at tasks that involve basic forms of logical "reasoning" over their context. Recent work found that selectively removing certain components from weight matrices in pre-trained models can improve such reasoning capabilities. We investigate this phenomenon further by carefully studying how certain global associations tend to be stored in specific weight components or Transformer blocks, in particular feed-forward layers. Such associations may hurt predictions in reasoning tasks, and removing the corresponding components may then improve performance. We analyze how this arises during training, both empirically and theoretically, on a two-layer Transformer trained on a basic reasoning task with noise, a toy associative memory model, and on the Pythia family of pre-trained models tested on simple reasoning tasks.

Offline supervised learning vs online direct policy optimization: A comparative study and a unified training paradigm for neural network-based optimal feedback control

Yue Zhao, J. Han

This work is concerned with efficiently solving for neural network-based feedback controllers in optimal control problems. We first conduct a comparative study of two prevalent approaches: offline supervised learning and online direct policy optimization. Although the training part of the supervised learning approach is relatively easy, the success of the method depends heavily on the optimal control dataset generated by open-loop optimal control solvers. In contrast, direct policy optimization turns the optimal control problem directly into an optimization problem without requiring any pre-computation, but the dynamics-related objective can be hard to optimize when the problem is complicated. Our results underscore the superiority of offline supervised learning in terms of both optimality and training time. To overcome the main challenges of the two approaches, dataset and optimization respectively, we combine them and propose the Pre-train and Fine-tune strategy as a unified training paradigm for optimal feedback control, which further improves performance and robustness significantly. Our code is accessible at https://github.com/yzhao98/DeepOptimalControl.

Contextual Counting: A Mechanistic Study of Transformers on a Quantitative Task

Siavash Golkar, A. Bietti, Mariel Pettee, Michael Eickenberg, et al.

Transformers have revolutionized machine learning across diverse domains, yet understanding their behavior remains crucial, particularly in high-stakes applications. This paper introduces the contextual counting task, a novel toy problem aimed at enhancing our understanding of Transformers in quantitative and scientific contexts. This task requires precise localization and computation within datasets, akin to object detection or region-based scientific analysis. We present theoretical and empirical analysis using both causal and non-causal Transformer architectures, investigating the influence of various positional encodings on performance and interpretability. In particular, we find that causal attention is much better suited for the task, and that using no positional embeddings leads to the best accuracy, though rotary embeddings are competitive and easier to train. We also show that out-of-distribution performance is tightly linked to which tokens the model uses as a bias term.

Crowdsourcing with Difficulty: A Bayesian Rating Model for Heterogeneous Items

Seong Woo Han, Ozan Adıgüzel, B. Carpenter

In applied statistics and machine learning, the "gold standards" used for training are often biased and almost always noisy. Dawid and Skene's justifiably popular crowdsourcing model adjusts for rater (coder, annotator) sensitivity and specificity, but fails to capture distributional properties of rating data gathered for training, which in turn biases training. In this study, we introduce a general purpose measurement-error model with which we can infer consensus categories by adding item-level effects for difficulty, discriminativeness, and guessability. We further show how to constrain the bimodal posterior of these models to avoid (or if necessary, allow) adversarial raters. We validate our model's goodness of fit with posterior predictive checks, the Bayesian analogue of χ

Neurosift: DANDI exploration and NWB visualization in the browser

J. Magland, J. Soules, Cody Baker, Benjamin Dichter

Neurosift, a browser-based visualization tool, is designed for the interactive exploration of Neurodata Without Borders (NWB) files, whether stored locally, on remote servers, or within the Distributed Archives for Neurophysiology Data Integration (DANDI). NWB (Rübel et al., 2022; Teeters et al., 2015) is an open data standard for neurophysiology that enables the sharing, archiving, and analysis of various types of neurophysiology data. DANDI (Rübel et al., 2022) is a cloud-based platform that supports the storage, sharing, and analysis of neurophysiology data including NWB files. With Neurosift integration, users browsing DANDI can easily open any NWB file in the browser and explore its contents, including timeseries data, images, and more. Neurosift can also be used to browse the DANDI database or individual Dandisets. Overall, Neurosift simplifies the visualization and exploration of complex NWB file structures, making it a valuable tool for neuroscientists.

Why is parameter averaging beneficial in SGD? An objective smoothing perspective

Atsushi Nitanda, Ryuhei Kikuchi, Shugo Maeda, D. Wu

It is often observed that stochastic gradient descent (SGD) and its variants implicitly select a solution with good generalization performance; such implicit bias is often characterized in terms of the sharpness of the minima. Kleinberg et al. (2018) connected this bias with the smoothing effect of SGD, which eliminates sharp local minima through convolution with the stochastic gradient noise. We follow this line of research and study the commonly used averaged SGD algorithm, which has been empirically observed in Izmailov et al. (2018) to prefer a flat minimum and therefore achieve better generalization. We prove that in certain problem settings, averaged SGD can efficiently optimize the smoothed objective, which avoids sharp local minima. In experiments, we verify our theory and show that parameter averaging with an appropriate step size indeed leads to significant improvement in the performance of SGD.
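For readers unfamiliar with parameter averaging, the sketch below shows only its basic mechanics: run SGD with noisy gradients and report the mean of the iterates after a burn-in period alongside the last iterate. The one-dimensional quadratic objective, the Gaussian gradient noise, the function name `averaged_sgd`, and every hyperparameter are assumptions chosen for illustration. This toy problem has no sharp minima and does not reproduce the paper's setting or results.

```python
import numpy as np


def averaged_sgd(grad, x0, step, n_steps, noise_std, rng, burn_in=0):
    """Run noisy SGD; return (last iterate, average of iterates after burn_in)."""
    x = float(x0)
    running_sum, count = 0.0, 0
    for t in range(n_steps):
        # Stochastic gradient: true gradient plus Gaussian noise (an assumption).
        g = grad(x) + noise_std * rng.standard_normal()
        x -= step * g
        if t >= burn_in:
            running_sum += x
            count += 1
    return x, running_sum / count


rng = np.random.default_rng(0)
# Toy objective f(x) = x^2 / 2, whose gradient is x and whose minimizer is 0.
last, avg = averaged_sgd(
    grad=lambda x: x, x0=3.0, step=0.1, n_steps=5000,
    noise_std=1.0, rng=rng, burn_in=2500,
)
```

With a constant step size the last iterate keeps fluctuating around the minimizer, while the averaged iterate suppresses much of that noise; the step size still controls how strongly the gradient noise perturbs the trajectory.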

Simulation-Based Stacking

Yuling Yao, B. Régaldo-Saint Blancard, Justin Domke

Simulation-based inference has been popular for amortized Bayesian computation. It is typical to have more than one posterior approximation, from different inference algorithms, different architectures, or simply the randomness of initialization and stochastic gradients. With a consistency guarantee, we present a general posterior stacking framework to make use of all available approximations. Our stacking method is able to combine densities, simulation draws, confidence intervals, and moments, and address the overall precision, calibration, coverage, and bias of the posterior approximation at the same time. We illustrate our method on several benchmark simulations and a challenging cosmological inference task.

GIST: Gibbs self-tuning for locally adaptive Hamiltonian Monte Carlo

N. Bou-Rabee, B. Carpenter, Milo Marsden

We introduce a novel and flexible framework for constructing locally adaptive Hamiltonian Monte Carlo (HMC) samplers by Gibbs sampling the algorithm's tuning parameters conditionally based on the position and momentum at each step. For adaptively sampling path lengths, this framework -- which we call Gibbs self-tuning (GIST) -- encompasses randomized HMC, multinomial HMC, the No-U-Turn Sampler (NUTS), and the Apogee-to-Apogee Path Sampler as special cases. The GIST framework is illustrated with a novel alternative to NUTS for locally adapting path lengths, evaluated with an exact Hamiltonian for a high-dimensional, ill-conditioned Gaussian measure and with the leapfrog integrator for a suite of diverse models.
