Probabilistic modeling often aims to learn and/or sample from a probability distribution. In the specific context of in-context learning, the distribution of interest is often a conditional distribution in which some data $D$ is provided “in-context”:

$$p(x \mid D)$$
For concreteness, the in-context data might be text (Brown et al., 2020), synthetic linear regression covariates and targets (Garg et al., 2022), or images and assigned classes (Chan et al., 2022). Directly learning this conditional distribution can be straightforward if the probability distribution can be easily parameterized; for instance, next-token prediction can be readily specified as a classification problem, where the conditional distribution is a categorical distribution parameterized by the model’s output logits. However, this limits the expressivity of in-context learning to situations where the conditional distribution can be straightforwardly parameterized.
In this work, we explore a more general form of in-context learning with no such constraint on how readily the conditional distribution can be specified. We call this more general form in-context learning of energy functions. The key insight is that rather than working with the constrained conditional distribution $p(x \mid D)$ directly, we re-express it in its Boltzmann distribution form (Bishop & Nasrabadi, 2006):

$$p(x \mid D) = \frac{\exp(-E(x \mid D))}{Z(D)},$$

where $Z(D) = \int \exp(-E(x \mid D)) \, dx$ is the normalizing constant. This alternative form is preferable because the energy function $E(x \mid D)$ is an arbitrary unconstrained function that can be used to express any probability distribution without requiring a particular form. We then propose learning the in-context energy function $E_\theta(x \mid D)$ rather than the constrained in-context conditional distribution $p(x \mid D)$, which we accomplish by drawing upon well-established ideas from probabilistic modeling known as energy-based models (Hinton, 2002; Mordatch, 2018; Du & Mordatch, 2019; Du et al., 2020).
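To make the Boltzmann form concrete, the following minimal sketch turns an arbitrary energy function into a normalized density on a 1-D grid. The double-well energy, the grid, and the function names are illustrative assumptions, not part of the method described here; the partition function is approximated with a simple Riemann sum.

```python
import math


def boltzmann_from_energy(energy, grid):
    """Turn an arbitrary energy function into a normalized density on a 1-D grid."""
    dx = grid[1] - grid[0]
    unnormalized = [math.exp(-energy(x)) for x in grid]
    # Riemann-sum estimate of the partition function Z = integral of exp(-E(x)) dx.
    z = sum(unnormalized) * dx
    return [u / z for u in unnormalized]


def double_well_energy(x):
    """Hypothetical unconstrained energy with wells at x = -1 and x = +1."""
    return 4.0 * (x * x - 1.0) ** 2


grid = [-3.0 + 0.01 * i for i in range(601)]
density = boltzmann_from_energy(double_well_energy, grid)
total_mass = sum(density) * 0.01  # Should be close to 1.
```

No parametric family (Gaussian, categorical, and so on) is assumed for the density; only the scalar energy is specified, and low energy corresponds to high probability.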
2.1. Learning In-Context Energy Functions
Our goal is to learn the in-context energy function:

$$E_\theta(x \mid D)$$

What concretely does this mean? We seek a model with parameters $\theta$ that accepts as input a dataset $D$ of arbitrary cardinality and a single datum $x$, and adaptively changes its output energy function based on the input dataset $D$ without changing its parameters $\theta$.
Figure 1. In-Context Learning of Energy Functions. Transformers learn to compute energy functions $E_\theta(x \mid D)$ corresponding to probability distributions $p(x \mid D)$, where $D$ are in-context datasets that vary during pretraining. At inference time, when conditioned on a new in-context dataset, the transformer computes a new energy function using fixed network parameters $\theta$. The transformers’ energy landscapes progressively sharpen as additional in-context training data are conditioned upon (left to right). Bottom. The energy function can be used to compute a gradient with respect to $x$ that enables sampling higher-probability points, without requiring a restricted parametric form for the corresponding conditional probability distribution.
For concreteness, in conditional probabilistic modeling, a causal transformer is typically trained to output a conditional probability distribution at every index, i.e.,

$$p_\theta(x_n \mid x_1, \ldots, x_{n-1}) \quad \text{for each index } n.$$

Instead of learning each conditional distribution $p_\theta(x_n \mid x_1, \ldots, x_{n-1})$, we learn the corresponding energy function $E_\theta(x_n \mid x_1, \ldots, x_{n-1})$. This means that the transformer outputs a scalar at every index, regardless of the shape of the inputs:

$$E_\theta(x_n \mid x_1, \ldots, x_{n-1}) \in \mathbb{R} \quad \text{for each index } n.$$

This scalar at each index is the model’s estimate of the energy at the last ($n$-th) input datum, based on an energy function constructed from the previous datapoints.
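The input/output contract described above can be sketched as follows. This is only an illustration of the interface: `in_context_energy` is a hypothetical hand-written stand-in (a Gaussian-kernel energy with a fixed `bandwidth`), not a transformer, and the flat energy for an empty context is an assumption made for the example.

```python
import math


def in_context_energy(x, context, bandwidth=0.5):
    """Hypothetical stand-in for E_theta(x | D).

    Negative log of a Gaussian-kernel average over the context, up to an
    additive constant. The only "parameter" (bandwidth) stays fixed while
    the energy function changes with the context.
    """
    if not context:
        return 0.0  # Assumption: flat, uninformative energy with no context.
    kernel_sum = sum(math.exp(-0.5 * ((x - d) / bandwidth) ** 2) for d in context)
    return -math.log(kernel_sum / len(context) + 1e-12)


def causal_energies(sequence):
    """One scalar per index: the energy of sequence[n] given sequence[:n]."""
    return [in_context_energy(x, sequence[:n]) for n, x in enumerate(sequence)]


sequence = [0.0, 0.1, -0.1, 3.0]
energies = causal_energies(sequence)  # Four scalars, one per index.
```

The last datum (3.0) lies far from the preceding points and so receives a much higher energy than the earlier ones, even though nothing about the stand-in model itself has changed.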
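To achieve this practically, we use causal GPT-style transformers (Vaswani et al., 2017; Radford et al., 2018; 2019). Just like with standard in-context learning of language models, we train our transformers by minimizing the negative log likelihood of the data conditioned on the in-context data. Because the normalizing constant of the Boltzmann form is intractable, the gradient of this loss is estimated via contrastive divergence (Hinton, 2002). Letting $x^+$ denote real data and $x^-$ denote confabulated data sampled from the learned energy function, the gradient is:

$$\nabla_\theta \, \mathbb{E}_{x^+}\left[-\log p_\theta(x^+ \mid D)\right] = \mathbb{E}_{x^+}\left[\nabla_\theta E_\theta(x^+ \mid D)\right] - \mathbb{E}_{x^-}\left[\nabla_\theta E_\theta(x^- \mid D)\right]$$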
    # Sample new confabulated data using Langevin MCMC.
    initial_sampled_data = ...
    confab_data = sample_data_with_langevin_mcmc(real_data, initial_sampled_data)

    # Compute difference in energy between real and confabulatory data.
    diff_of_energy = energies_on_real_data - energies_on_confab_data

    # Compute total loss.
    total_loss = mean(diff_of_energy)
Figure 2. Pseudocode for Training In-Context Learning of Energy Functions.
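For readers who want something executable, here is a minimal runnable counterpart to the pseudocode, under strong simplifying assumptions. The transformer is replaced by a hypothetical one-parameter quadratic energy centered on the context mean, gradients with respect to x are written by hand, and the step size, context size, and initialization range are made-up values. Only the names `sample_data_with_langevin_mcmc`, `energies_on_real_data`, `energies_on_confab_data`, `diff_of_energy`, and `total_loss` follow the figure.

```python
import random
from statistics import fmean


def energy(theta, x, context):
    """Toy stand-in for E_theta(x | D): quadratic bowl around the context mean."""
    return 0.5 * theta * (x - fmean(context)) ** 2


def grad_energy_wrt_x(theta, x, context):
    """Analytic gradient of the toy energy with respect to x."""
    return theta * (x - fmean(context))


def sample_data_with_langevin_mcmc(theta, context, initial_sampled_data,
                                   n_steps=15, step_size=0.1,
                                   noise_scale=0.01, rng=random):
    """Move the initial samples downhill in energy, with small Gaussian noise."""
    samples = list(initial_sampled_data)
    for _ in range(n_steps):
        samples = [
            x - step_size * grad_energy_wrt_x(theta, x, context)
            + rng.gauss(0.0, noise_scale)
            for x in samples
        ]
    return samples


def contrastive_loss(theta, real_data, confab_data, context):
    """Mean energy on real data minus mean energy on confabulated data."""
    energies_on_real_data = [energy(theta, x, context) for x in real_data]
    energies_on_confab_data = [energy(theta, x, context) for x in confab_data]
    diff_of_energy = [
        r - c for r, c in zip(energies_on_real_data, energies_on_confab_data)
    ]
    return fmean(diff_of_energy)


rng = random.Random(0)
context = [rng.gauss(2.0, 0.5) for _ in range(32)]
real_data = [rng.gauss(2.0, 0.5) for _ in range(64)]
initial_sampled_data = [rng.uniform(-5.0, 5.0) for _ in range(64)]
theta = 1.0
confab_data = sample_data_with_langevin_mcmc(
    theta, context, initial_sampled_data, rng=rng
)
total_loss = contrastive_loss(theta, real_data, confab_data, context)
```

In an actual training loop the confabulated samples would be treated as constants when differentiating `total_loss` with respect to the parameters, so that gradients do not flow back through the sampler.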
This equation tells us that we can minimize the negative log likelihood by equivalently minimizing the energy of real data (conditioning upon the in-context data) while simultaneously maximizing the energy of confabulated data (again conditioning upon the in-context data). Python pseudocode for training is given in Figure 2.
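The reasoning behind this statement can be spelled out with the standard energy-based-model derivation. The notation below is an assumption chosen to match the rest of this section: $x^+$ is a real datum, $x^-$ is a sample from the model distribution $p_\theta(\cdot \mid D)$, and $Z_\theta(D)$ is the normalizing constant of the Boltzmann form.

```latex
\begin{align}
-\log p_\theta(x^+ \mid D)
  &= E_\theta(x^+ \mid D) + \log Z_\theta(D), \\
\nabla_\theta \log Z_\theta(D)
  &= \frac{1}{Z_\theta(D)} \int \nabla_\theta \exp\!\left(-E_\theta(x \mid D)\right) dx
   = -\,\mathbb{E}_{x^- \sim p_\theta(\cdot \mid D)}\!\left[\nabla_\theta E_\theta(x^- \mid D)\right], \\
\nabla_\theta \left[-\log p_\theta(x^+ \mid D)\right]
  &= \nabla_\theta E_\theta(x^+ \mid D)
   - \mathbb{E}_{x^- \sim p_\theta(\cdot \mid D)}\!\left[\nabla_\theta E_\theta(x^- \mid D)\right].
\end{align}
```

The first term lowers the energy of real data and the second raises the energy of model samples. In practice the expectation over $x^-$ is approximated with a finite number of short-run MCMC samples, so the resulting gradient is an estimate rather than an exact quantity.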
2.2. Sampling From In-Context Energy Functions
To sample from the conditional distribution $p(x \mid D)$, we follow standard practice in energy-based modeling (Hinton, 2002; Du & Mordatch, 2019; Du et al., 2020): We first choose $N$ data (deterministically or stochastically) to condition on, and sample $x_0 \sim U$ for some distribution $U$ to compute the initial energy $E_\theta(x_0 \mid D)$. We then use Langevin dynamics to iteratively increase the probability of $x_t$ by sampling noise $\omega_t \sim \mathcal{N}(0, \sigma^2)$ and minimizing the energy with respect to $x_t$ with step size $\alpha$:

$$x_{t+1} = x_t - \alpha \nabla_x E_\theta(x_t \mid D) + \omega_t$$
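The sampling procedure can be sketched in a few lines. Everything here is an assumption for illustration: the energy is a hand-written Gaussian-kernel stand-in rather than a trained transformer, the gradient is a central finite difference instead of automatic differentiation, and the step size, noise scale, and number of steps are arbitrary values.

```python
import math
import random


def in_context_energy(x, context, bandwidth=0.5):
    """Hypothetical stand-in for E_theta(x | D) built from the context points."""
    kernel_sum = sum(math.exp(-0.5 * ((x - d) / bandwidth) ** 2) for d in context)
    return -math.log(kernel_sum / len(context) + 1e-12)


def energy_gradient(x, context, eps=1e-4):
    """Central finite-difference estimate of dE/dx."""
    upper = in_context_energy(x + eps, context)
    lower = in_context_energy(x - eps, context)
    return (upper - lower) / (2.0 * eps)


def langevin_sample(context, x0, n_steps=200, step_size=0.05,
                    noise_scale=0.01, rng=random):
    """Iterate x <- x - step_size * dE/dx + noise, starting from x0."""
    x = x0
    for _ in range(n_steps):
        omega = rng.gauss(0.0, noise_scale)
        x = x - step_size * energy_gradient(x, context) + omega
    return x


rng = random.Random(0)
context = [1.9, 2.0, 2.1, 2.05, 1.95]  # In-context data clustered near 2.
x0 = rng.uniform(0.0, 4.0)             # Initial sample from U = Uniform(0, 4).
x_final = langevin_sample(context, x0, rng=rng)
```

Starting from a uniformly drawn point, the iterate drifts toward the low-energy region around the in-context data, which is the behavior the update rule is meant to produce.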
This in-context learning of energy functions is akin to Mordatch (2018), but rather than conditioning on a “mask” and “concepts”, we instead condition on sequences of data from the same distribution and we additionally replace the all-to-all relational network with a causal transformer.
2.3. Preliminary Experimental Results of In-Context Learning of Energy Functions
As proof of concept, we train causal transformer-based in-context learning energy-based models (ICL-EBMs) on synthetic mixture-of-Gaussians datasets. The transformers have 6 layers, 8 heads, 128 embedding dimensions, and GELU nonlinearities (Hendrycks & Gimpel, 2016). The transformers are pretrained on a set of randomly sampled synthetic 2-dimensional mixtures of three Gaussians with uniform mixing proportions, using Langevin noise scale 0.01 and 15 MCMC steps of size . After pretraining, we freeze the ICL-EBMs’ parameters and measure whether the model can adapt its energy function to new in-context datasets drawn from the same distribution as the pretraining datasets. The energy landscapes of frozen ICL-EBMs display clear signs of in-context learning (Fig. 1).
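A sketch of how such pretraining datasets might be generated is shown below. Only the dimensionality (2), the number of components (3), and the uniform mixing proportions come from the text; the range of the component means, the shared isotropic standard deviation, the dataset size, and the function name `sample_mixture_dataset` are assumptions.

```python
import random


def sample_mixture_dataset(n_points, rng, n_components=3,
                           mean_range=5.0, std=0.5):
    """Draw one in-context dataset from a random 2-D mixture of Gaussians."""
    # Assumption: component means are uniform in a square, with a shared std.
    means = [
        (rng.uniform(-mean_range, mean_range), rng.uniform(-mean_range, mean_range))
        for _ in range(n_components)
    ]
    points = []
    for _ in range(n_points):
        mean_x, mean_y = rng.choice(means)  # Uniform mixing proportions.
        points.append((rng.gauss(mean_x, std), rng.gauss(mean_y, std)))
    return means, points


rng = random.Random(0)
# Each dataset comes from a different random mixture, so the model must
# infer the mixture from the in-context points rather than memorize it.
pretraining_datasets = [sample_mixture_dataset(64, rng) for _ in range(4)]
means, points = pretraining_datasets[0]
```

At evaluation time, new datasets would be drawn from the same generator and fed to the frozen model as in-context data.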
To the best of our knowledge, this is the first instance of in-context learning where the input and output spaces differ. This stands in stark contrast with more common examples of in-context learning such as language modeling (Brown et al., 2020), linear regression (Garg et al., 2022) and image classification (Chan et al., 2022). Our method is similar to that of Mordatch (2018), as well as Müller et al. (2022). Our results demonstrate that transformers are capable of more types of in-context learning than previously known, and in particular that they can successfully learn energy functions rather than probability distributions. Although our results are quite preliminary, we believe this is an exciting direction that can be pushed significantly further.
Limitations. Energy-based models require differentiating with respect to network inputs during training, often with tens to hundreds of backward steps per batch, making training these models significantly more expensive than standard pretraining. Future work should aim to reduce this cost.
Bishop, C. M. and Nasrabadi, N. M. Pattern recognition and machine learning, volume 4. Springer, 2006.
Brown, T., Mann, B., Ryder, N., Subbiah, M., Kaplan, J. D., Dhariwal, P., Neelakantan, A., Shyam, P., Sastry, G., Askell, A., et al. Language models are few-shot learners. Advances in neural information processing systems, 33: 1877–1901, 2020.
Chan, S., Santoro, A., Lampinen, A., Wang, J., Singh, A., Richemond, P., McClelland, J., and Hill, F. Data distributional properties drive emergent in-context learning in transformers. Advances in Neural Information Processing Systems, 35:18878–18891, 2022.
Du, Y. and Mordatch, I. Implicit generation and modeling with energy based models. Advances in Neural Information Processing Systems, 32, 2019.
Du, Y., Li, S., Tenenbaum, J., and Mordatch, I. Improved contrastive divergence training of energy based models. arXiv preprint arXiv:2012.01316, 2020.
Garg, S., Tsipras, D., Liang, P. S., and Valiant, G. What can transformers learn in-context? A case study of simple function classes. Advances in Neural Information Processing Systems, 35:30583–30598, 2022.
Hendrycks, D. and Gimpel, K. Gaussian error linear units (GELUs). arXiv preprint arXiv:1606.08415, 2016.
Hinton, G. E. Training products of experts by minimizing contrastive divergence. Neural computation, 14(8):1771–1800, 2002.
Mordatch, I. Concept learning with energy-based models. arXiv preprint arXiv:1811.02486, 2018.
Müller, S., Hollmann, N., Arango, S. P., Grabocka, J., and Hutter, F. Transformers can do Bayesian inference. In