[SA-VAE] Semi-Amortized Variational Autoencoders
Research/Generative Model 2024. 4. 8. 09:02
※ https://arxiv.org/pdf/1802.02550.pdf
Abstract
Amortized variational inference (AVI) replaces instance-specific local inference with a global inference network. While AVI has enabled efficient training of deep generative models such as variational autoencoders (VAE), recent empirical work suggests that inference networks can produce suboptimal variational parameters. We propose a hybrid approach: use AVI to initialize the variational parameters, then run stochastic variational inference (SVI) to refine them. Crucially, the local SVI procedure is itself differentiable, so the inference network and generative model can be trained end-to-end with gradient-based optimization. This semi-amortized approach enables the use of rich generative models without experiencing the posterior-collapse phenomenon common in training VAEs for problems like text generation. Experiments show this approach outperforms strong autoregressive and variational baselines on standard text and image datasets.
1. Introduction
Variational inference (VI) (Jordan et al., 1999; Wainwright & Jordan, 2008) is a framework for approximating an intractable distribution by optimizing over a family of tractable surrogates. Traditional VI algorithms iterate over the observed data and update the variational parameters with closed-form coordinate ascent updates that exploit conditional conjugacy (Ghahramani & Beal, 2001). This style of optimization is challenging to extend to large datasets and non-conjugate models. However, recent advances in stochastic (Hoffman et al., 2013), black-box (Ranganath et al., 2014; 2016), and amortized (Mnih & Gregor, 2014; Kingma & Welling, 2014; Rezende et al., 2014) variational inference have made it possible to scale to large datasets and rich, non-conjugate models (see Blei et al. (2017), Zhang et al. (2017) for a review of modern methods).
In stochastic variational inference (SVI), the variational parameters for each data point are randomly initialized and then optimized to maximize the evidence lower bound (ELBO) with, for example, gradient ascent. These updates are based on a subset of the data, making it possible to scale the approach.
In amortized variational inference (AVI), the local variational parameters are instead predicted by an inference (or recognition) network, which is shared (i.e. amortized) across the dataset.
Variational autoencoders (VAEs) are deep generative models that utilize AVI for inference and jointly train the generative model alongside the inference network.
SVI gives good local (i.e. instance-specific) distributions within the variational family but requires performing optimization for each data point. AVI has fast inference, but requiring the variational parameters to be a parametric function of the input may be too strict a restriction. As a secondary effect, this may militate against learning a good generative model, since its parameters may be updated based on suboptimal variational parameters. Cremer et al. (2018) observe that the amortization gap (the gap between the log-likelihood and the ELBO due to amortization) can be significant for VAEs, especially on complex datasets.
Recent work has targeted this amortization gap by combining amortized inference with iterative refinement during training (Hjelm et al., 2016; Krishnan et al., 2018). These methods use an encoder to initialize the local variational parameters, and then run an iterative procedure to refine them. To train with this hybrid approach, they utilize a separate training-time objective. For example, Hjelm et al. (2016) train the inference network to minimize the KL-divergence between the initial and the final variational distributions, while Krishnan et al. (2018) train the inference network with the usual ELBO objective based on the initial variational distribution.
In this work, we address the train/test objective mismatch and consider methods for training semi-amortized variational autoencoders (SA-VAE) in a fully end-to-end manner. We propose an approach that leverages differentiable optimization (Domke, 2012; Maclaurin et al., 2015; Belanger et al., 2017) and differentiates through SVI while training the inference network/generative model. We find that this method is able to both improve estimation of variational parameters and produce better generative models.
We apply our approach to train deep generative models of text and images, and observe that they outperform autoregressive/VAE/SVI baselines, in addition to direct baselines that combine VAE with SVI but do not perform end-to-end training. We also find that under our framework, we are able to utilize a powerful generative model without experiencing the “posterior-collapse” phenomenon often observed in VAEs, wherein the variational posterior collapses to the prior and the generative model ignores the latent variable (Bowman et al., 2016; Chen et al., 2017; Zhao et al., 2017). This problem has made it particularly difficult to use VAEs for text, an important open issue in the field. With SA-VAE, we are able to outperform an LSTM language model by utilizing an LSTM generative model that maintains non-trivial latent representations. Code is available at https://github.com/harvardnlp/sa-vae.
2. Background
Notation
2.1. Variational Inference
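As a reference point for the sections that follow, the objective being optimized is the evidence lower bound (ELBO). The block below is a sketch of its standard form, not a reproduction of the original equations. The symbols z (latent variable), q(z; λ) (variational distribution), and p(x, z; θ) (generative model) are assumed names; λ, θ, and x follow the usage in Section 3.

```latex
\log p(x; \theta) \;\ge\; \mathrm{ELBO}(\lambda, \theta, x)
  = \mathbb{E}_{q(z; \lambda)}\big[\log p(x, z; \theta)\big]
    - \mathbb{E}_{q(z; \lambda)}\big[\log q(z; \lambda)\big]
  = \log p(x; \theta)
    - \mathrm{KL}\big[\, q(z; \lambda) \,\|\, p(z \mid x; \theta) \,\big]
```

The last equality shows why maximizing the ELBO over λ tightens the bound: the slack is exactly the KL divergence between the variational distribution and the true posterior.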
2.2. Stochastic Variational Inference
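We can apply SVI (Hoffman et al., 2013) with gradient ascent to approximately maximize the ELBO:
   1. Randomly initialize λ_0 and θ.
   2. For k = 0, ..., K - 1, set λ_{k+1} = λ_k + α∇λ ELBO(λ_k, θ, x).
   3. Update θ based on dELBO(λ_K, θ, x) / dθ.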
SVI optimizes directly for instance-specific variational distributions, but may require running iterative inference for a large number of steps. Further, because of this block coordinate ascent approach the variational parameters λ are optimized separately from θ, potentially making it difficult for θ to adapt to local optima.
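To make the per-instance optimization concrete, here is a minimal sketch of SVI on a toy model, not the models used in the paper. Everything in it is an assumption chosen for illustration: a scalar latent z ~ N(0, 1), a likelihood x | z ~ N(θz, 1), a Gaussian q(z; λ) with λ = (mu, log_sigma), a fixed rather than random initialization (for reproducibility), and the step size and step count. The ELBO is available in closed form for this model, so no sampling is needed.

```python
import math

def elbo(mu, log_sigma, theta, x):
    """Closed-form ELBO for the toy model z ~ N(0, 1), x | z ~ N(theta * z, 1),
    with q(z; lambda) = N(mu, sigma^2) and lambda = (mu, log_sigma)."""
    var = math.exp(2.0 * log_sigma)
    expected_log_lik = (-0.5 * math.log(2.0 * math.pi)
                        - 0.5 * ((x - theta * mu) ** 2 + theta ** 2 * var))
    kl_to_prior = 0.5 * (mu ** 2 + var - 1.0) - log_sigma
    return expected_log_lik - kl_to_prior

def elbo_grad(mu, log_sigma, theta, x):
    """Analytic gradient of the ELBO with respect to (mu, log_sigma)."""
    var = math.exp(2.0 * log_sigma)
    return theta * (x - theta * mu) - mu, 1.0 - (theta ** 2 + 1.0) * var

def svi(x, theta, mu=0.0, log_sigma=0.0, alpha=0.1, num_steps=200):
    """Gradient ascent on the ELBO for a single data point x, with theta held fixed."""
    for _ in range(num_steps):
        g_mu, g_log_sigma = elbo_grad(mu, log_sigma, theta, x)
        mu += alpha * g_mu
        log_sigma += alpha * g_log_sigma
    return mu, log_sigma

def log_marginal(x, theta):
    """Exact log p(x; theta): marginally x ~ N(0, 1 + theta^2) in this model."""
    var_x = 1.0 + theta ** 2
    return -0.5 * math.log(2.0 * math.pi * var_x) - 0.5 * x ** 2 / var_x

theta, x = 2.0, 1.5
mu, log_sigma = svi(x, theta)
print("ELBO at init   :", elbo(0.0, 0.0, theta, x))
print("ELBO after SVI :", elbo(mu, log_sigma, theta, x))
print("log p(x; theta):", log_marginal(x, theta))
```

Because the true posterior is Gaussian in this toy model, the refined ELBO should approach log p(x; θ). In realistic models the bound stays loose, and the loop has to be rerun for every data point, which is the cost noted above.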
2.3. Amortized Variational Inference
AVI uses a global parametric model to predict the local variational parameters for each data point. A particularly popular application of AVI is in training the variational autoencoder (VAE) (Kingma & Welling, 2014), which runs an inference network (i.e. encoder) enc(·) parameterized by φ over the input to obtain the variational parameters: λ = enc(x; φ).
The inference network is learned jointly alongside the generative model with the same loss function, allowing the pair to coadapt. Additionally, inference for AVI only involves running the inference network over the input, which is usually much faster than running iterative optimization on the ELBO. Despite these benefits, requiring the variational parameters to be a parametric function of the input may be too strict a restriction and can lead to an amortization gap. This gap can propagate forward to hinder the learning of the generative model if θ is updated based on suboptimal λ.
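The amortization gap can be illustrated with a small sketch under strong assumptions. The toy model (z ~ N(0, 1), x | z ~ N(θz, 1)), the linear `encoder`, and the parameter values are all hypothetical. In this particular model a linear encoder could in fact be exact, so the deliberately mis-set φ below only stands in for an imperfect inference network.

```python
import math

def elbo(mu, log_sigma, theta, x):
    """Closed-form ELBO for z ~ N(0, 1), x | z ~ N(theta * z, 1), q = N(mu, sigma^2)."""
    var = math.exp(2.0 * log_sigma)
    expected_log_lik = (-0.5 * math.log(2.0 * math.pi)
                        - 0.5 * ((x - theta * mu) ** 2 + theta ** 2 * var))
    kl_to_prior = 0.5 * (mu ** 2 + var - 1.0) - log_sigma
    return expected_log_lik - kl_to_prior

def encoder(x, phi):
    """Hypothetical linear inference network: lambda = enc(x; phi)."""
    a, b, c = phi
    return a * x + b, c  # (mu, log_sigma)

def optimal_lambda(x, theta):
    """Instance-specific optimum of the ELBO, known exactly for this toy model."""
    precision = 1.0 + theta ** 2
    return theta * x / precision, -0.5 * math.log(precision)

def amortization_gap(x, theta, phi):
    """Best achievable ELBO for x minus the ELBO at the encoder's prediction."""
    best = elbo(*optimal_lambda(x, theta), theta, x)
    amortized = elbo(*encoder(x, phi), theta, x)
    return best - amortized

theta = 2.0
phi_imperfect = (0.3, 0.0, -0.5)  # illustrative values, not fitted to anything
for x in (-1.0, 0.5, 3.0):
    print(x, amortization_gap(x, theta, phi_imperfect))
```

The gap is non-negative by construction and differs per data point, which is the motivation for refining the encoder's output separately for each input.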
3. Semi-Amortized Variational Autoencoders
Semi-amortized variational autoencoders (SA-VAE) utilize an inference network over the input to give the initial variational parameters, and subsequently run SVI to refine them. One might appeal to the universal approximation theorem (Hornik et al., 1989) and question the necessity of additional SVI steps given a rich-enough inference network. However, in practice we find that the variational parameters found from VAE are usually not optimal even with a powerful inference network, and the amortization gap can be significant especially on complex datasets (Cremer et al., 2018; Krishnan et al., 2018).
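SA-VAE models are trained using a combination of AVI and SVI steps:
   1. Sample x from the data distribution.
   2. Set λ_0 = enc(x; φ).
   3. For k = 0, ..., K - 1, set λ_{k+1} = λ_k + α∇λ ELBO(λ_k, θ, x).
   4. Update θ based on dELBO(λ_K, θ, x) / dθ.
   5. Update φ based on dELBO(λ_K, θ, x) / dφ.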
Note that for training we need to compute the total derivative of the final ELBO with respect to θ, φ (i.e. steps 4 and 5 above). Unlike with AVI, in order to update the encoder and generative model parameters, this total derivative requires backpropagating through the SVI updates. Specifically this requires backpropagating through gradient ascent (Domke, 2012; Maclaurin et al., 2015).
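Following past work, this backpropagation step can be done efficiently with fast Hessian-vector products (LeCun et al., 1993; Pearlmutter, 1994). In particular, consider the case where we perform one step of refinement, λ1 = λ0 + α∇λ ELBO(λ0, θ, x), and for brevity let L = ELBO(λ1, θ, x). To backpropagate through this, we receive the derivative dL / dλ1 and use the chain rule,
   dL / dλ0 = dλ1/dλ0 * dL/dλ1 = (I + α Hλ,λ ELBO(λ0, θ, x)) dL/dλ1 = dL/dλ1 + α Hλ,λ ELBO(λ0, θ, x) dL/dλ1,
where Hλ,λ ELBO(λ0, θ, x) denotes the Hessian of the ELBO with respect to λ, evaluated at λ0. The second term is a Hessian-vector product, so the full Hessian never needs to be formed.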
We can then backpropagate dL / dλ0 through the inference network to calculate the total derivative, i.e. dL / dφ = dλ0/dφ * dL/dλ0 . Similar rules can be used to derive dL / dθ. The full forward/backward step, which uses gradient descent with momentum on the negative ELBO, is shown in Algorithm 1.
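The chain rule through one refinement step can be checked numerically on a toy problem. The sketch below is illustrative only and is not the authors' Algorithm 1. It assumes the scalar model z ~ N(0, 1), x | z ~ N(θz, 1), a hypothetical linear `encoder` with φ = (a, b, c), and arbitrary values for α, θ, and x. It computes dL/dφ as dL/dλ1 plus α times a Hessian-vector product, pushes the result through the encoder, and compares it with a finite-difference derivative of the whole pipeline.

```python
import math

def elbo(lam, theta, x):
    """Closed-form ELBO for the toy model, with lam = (mu, log_sigma)."""
    mu, log_sigma = lam
    var = math.exp(2.0 * log_sigma)
    return (-0.5 * math.log(2.0 * math.pi)
            - 0.5 * ((x - theta * mu) ** 2 + theta ** 2 * var)
            - 0.5 * (mu ** 2 + var - 1.0) + log_sigma)

def elbo_grad(lam, theta, x):
    """Gradient of the ELBO with respect to lam."""
    mu, log_sigma = lam
    var = math.exp(2.0 * log_sigma)
    return (theta * (x - theta * mu) - mu, 1.0 - (theta ** 2 + 1.0) * var)

def elbo_hvp(lam, theta, v):
    """Exact Hessian-vector product H(lam) v; the Hessian is diagonal in this toy model."""
    _, log_sigma = lam
    k = theta ** 2 + 1.0
    return (-k * v[0], -2.0 * k * math.exp(2.0 * log_sigma) * v[1])

def encoder(x, phi):
    """Hypothetical linear inference network: lam0 = enc(x; phi)."""
    a, b, c = phi
    return (a * x + b, c)

def refine_once(lam0, theta, x, alpha):
    """One SVI step: lam1 = lam0 + alpha * grad ELBO(lam0)."""
    g0 = elbo_grad(lam0, theta, x)
    return (lam0[0] + alpha * g0[0], lam0[1] + alpha * g0[1])

def final_elbo(phi, theta, x, alpha):
    """L = ELBO(lam1, theta, x) as a function of the encoder parameters."""
    return elbo(refine_once(encoder(x, phi), theta, x, alpha), theta, x)

def grad_phi_chain_rule(phi, theta, x, alpha):
    lam0 = encoder(x, phi)
    lam1 = refine_once(lam0, theta, x, alpha)
    d_lam1 = elbo_grad(lam1, theta, x)        # dL/dlam1
    hv = elbo_hvp(lam0, theta, d_lam1)        # H(lam0) * dL/dlam1
    d_lam0 = (d_lam1[0] + alpha * hv[0], d_lam1[1] + alpha * hv[1])
    # Backpropagate through the encoder: mu0 = a * x + b, log_sigma0 = c.
    return (x * d_lam0[0], d_lam0[0], d_lam0[1])

def grad_phi_numeric(phi, theta, x, alpha, eps=1e-6):
    """Central finite differences of the full pipeline, used only as a check."""
    grads = []
    for i in range(len(phi)):
        up, down = list(phi), list(phi)
        up[i] += eps
        down[i] -= eps
        diff = final_elbo(up, theta, x, alpha) - final_elbo(down, theta, x, alpha)
        grads.append(diff / (2.0 * eps))
    return tuple(grads)

phi, theta, x, alpha = (0.3, 0.1, -0.5), 2.0, 1.5, 0.05
print("chain rule:", grad_phi_chain_rule(phi, theta, x, alpha))
print("numeric   :", grad_phi_numeric(phi, theta, x, alpha))
```

Dropping the α-weighted term would give the gradient of a model that treats λ1 as if it did not depend on λ0 through the SVI step. Keeping it is what makes the training end-to-end.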
In our implementation we calculate Hessian-vector products with finite differences (LeCun et al., 1993; Domke, 2012), which was found to be more memory-efficient than automatic differentiation (and therefore crucial for scaling our approach to rich inference networks/generative models).
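A finite-difference Hessian-vector product needs only gradient evaluations, which is where the memory saving comes from. The sketch below uses a central difference on the same toy ELBO gradient as an assumption for illustration. The exact differencing scheme and step size used by the authors are not given in this text, and `hvp_finite_diff`, `eps`, and the test vectors are hypothetical.

```python
import math

def elbo_grad(lam, theta, x):
    """Gradient of the toy ELBO (z ~ N(0, 1), x | z ~ N(theta * z, 1)) w.r.t. lam."""
    mu, log_sigma = lam
    var = math.exp(2.0 * log_sigma)
    return [theta * (x - theta * mu) - mu, 1.0 - (theta ** 2 + 1.0) * var]

def hvp_finite_diff(grad_fn, lam, v, eps=1e-4):
    """Approximate H(lam) v as (grad(lam + eps v) - grad(lam - eps v)) / (2 eps)."""
    plus = [l + eps * vi for l, vi in zip(lam, v)]
    minus = [l - eps * vi for l, vi in zip(lam, v)]
    g_plus, g_minus = grad_fn(plus), grad_fn(minus)
    return [(gp - gm) / (2.0 * eps) for gp, gm in zip(g_plus, g_minus)]

def hvp_exact(lam, theta, v):
    """Analytic H(lam) v for the toy ELBO, used only to check the approximation."""
    _, log_sigma = lam
    k = theta ** 2 + 1.0
    return [-k * v[0], -2.0 * k * math.exp(2.0 * log_sigma) * v[1]]

theta, x = 2.0, 1.5
lam = [0.55, -0.5]
v = [0.2, -0.7]
approx = hvp_finite_diff(lambda l: elbo_grad(l, theta, x), lam, v)
print("finite difference:", approx)
print("exact            :", hvp_exact(lam, theta, v))
```

The Hessian itself is never materialized: the cost is two extra gradient evaluations per product, at the price of an approximation error controlled by `eps`.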
7. Conclusion
This work outlines semi-amortized variational autoencoders, which combine amortized inference with local iterative refinement to train deep generative models of text and images. With this approach, we find that we are able to train deep latent variable models of text with an expressive autoregressive generative model that does not ignore the latent code.
From the perspective of learning latent representations, one might question the prudence of using an autoregressive model that fully conditions on its entire history (as opposed to assuming some conditional independence) given that p(x) can always be factorized as
   p(x) = p(x_1) p(x_2 | x_1) ··· p(x_T | x_1, ..., x_{T-1}) = ∏_{t=1}^{T} p(x_t | x_{<t}),
and therefore the model is non-identifiable (i.e. it does not have to utilize the latent variable). However in finite data regimes we might still expect a model that makes use of its latent variable to generalize better due to potentially better inductive bias (from the latent variable). Training generative models that both model the underlying data well and learn good latent representations is an important avenue for future work.