[HFVAE] Structured Disentangled Representations
Research/Generative Model · 2024. 4. 5. 22:45
※ https://arxiv.org/pdf/1804.02086.pdf
Abstract
Deep latent-variable models learn representations of high-dimensional data in an unsupervised manner. A number of recent efforts have focused on learning representations that disentangle statistically independent axes of variation by introducing modifications to the standard objective function. These approaches generally assume a simple diagonal Gaussian prior and as a result are not able to reliably disentangle discrete factors of variation. We propose a two-level hierarchical objective to control the relative degree of statistical independence between blocks of variables and between individual variables within blocks. We derive this objective as a generalization of the evidence lower bound, which allows us to explicitly represent the trade-offs between the mutual information between data and representation, the KL divergence between representation and prior, and coverage of the support of the empirical data distribution. Experiments on a variety of datasets demonstrate that our objective can not only disentangle discrete variables, but that doing so also improves disentanglement of other variables and, importantly, generalization even to unseen combinations of factors.
1. Introduction
Deep generative models represent data x using a low-dimensional variable z (sometimes referred to as a code). The relationship between x and z is described by a conditional probability distribution pθ(x|z) parameterized by a deep neural network. There have been many recent successes in training deep generative models for complex data types such as images [Gatys et al., 2015, Gulrajani et al., 2017], audio [Oord et al., 2016], and language [Bowman et al., 2016]. The latent code z can also serve as a compressed representation for downstream tasks such as text classification [Xu et al., 2017], Bayesian optimization [Gómez-Bombarelli et al., 2018, Kusner et al., 2017], and lossy image compression [Theis et al., 2017]. The setting in which an approximate posterior distribution qφ(z|x) is learnt simultaneously with the generative model via optimization of the evidence lower bound (ELBO) is known as a variational autoencoder (VAE), where qφ(z|x) and pθ(x|z) represent the probabilistic encoder and decoder, respectively. In contrast to VAEs, inference and generative models can also be learnt jointly in an adversarial setting [Makhzani et al., 2015, Dumoulin et al., 2016, Donahue et al., 2016].
While deep generative models often provide high-fidelity reconstructions, the representation z is generally not directly amenable to human interpretation. In contrast to classical methods such as principal components or factor analysis, individual dimensions of z don’t necessarily encode any particular semantically meaningful variation in x. This has motivated a search for ways of learning disentangled representations, where perturbations of an individual dimension of the latent code z perturb the corresponding x in an interpretable manner. Various strategies for weak supervision have been employed, including semi-supervision of latent variables [Kingma et al., 2014, Siddharth et al., 2017], triplet supervision [Karaletsos et al., 2015, Veit et al., 2016], or batch-level factor invariances [Kulkarni et al., 2015, Bouchacourt et al., 2017]. There has also been a concerted effort to develop fully unsupervised approaches that modify the VAE objective to induce disentangled representations. A well-known example is β-VAE [Higgins et al., 2016]. This has prompted a number of approaches that modify the VAE objective by adding, removing, or altering the weight of individual terms [Kumar et al., 2017, Zhao et al., 2017, Gao et al., 2018, Achille and Soatto, 2018].
In this paper, we introduce hierarchically factorized VAEs (HFVAEs). The HFVAE objective is based on a two-level hierarchical decomposition of the VAE objective, which allows us to control the relative levels of statistical independence between groups of variables and between individual variables within the same group. At each level, we induce statistical independence by minimizing the total correlation (TC), a generalization of the mutual information to more than two variables. A number of related approaches have also considered the TC [Kim and Mnih, 2018, Chen et al., 2018, Gao et al., 2018], but they do not employ the two-level decomposition that we consider here. In our derivation, we reinterpret the standard VAE objective as a KL divergence between a generative model and its corresponding inference model. This has the side benefit of providing a unified perspective on the trade-offs made by various modifications of the VAE objective.
We illustrate the power of this decomposition by disentangling discrete factors of variation from continuous variables, which remains problematic for many existing approaches. We evaluate our methodology on a variety of datasets including dSprites, MNIST, Fashion MNIST (F-MNIST), CelebA and 20NewsGroups. Inspection of the learned representations confirms that our objective uncovers interpretable features in an unsupervised setting, and quantitative metrics demonstrate improvement over related methods. Crucially, we show that the learned representations can recover combinations of latent features that were not present in any examples in the training set, a goal that has long been implicit in learning disentangled representations and that is now considered explicitly.
2. A Unified View of Generalized VAE Objectives
Variational autoencoders jointly optimize two models. The generative model pθ(x, z) defines a distribution on a set of latent variables z and observed data x in terms of a prior p(z) and a likelihood pθ(x | z), which is often referred to as the decoder model. This distribution is estimated in tandem with an encoder, a conditional distribution qφ(z | x) that performs approximate inference in this model. The encoder and decoder together define a probabilistic autoencoder.
The VAE objective is traditionally defined as an average over datapoints x_n of the per-datapoint ELBO, or alternatively as an expectation over an empirical distribution q(x) that approximates an unknown data distribution with a finite set of N data points,

$$\mathcal{L}(\theta, \phi) \;=\; \frac{1}{N}\sum_{n=1}^{N} \mathbb{E}_{q_\phi(z \mid x_n)}\!\left[\log \frac{p_\theta(x_n, z)}{q_\phi(z \mid x_n)}\right] \;=\; \mathbb{E}_{q(x)}\,\mathbb{E}_{q_\phi(z \mid x)}\!\left[\log \frac{p_\theta(x, z)}{q_\phi(z \mid x)}\right]. \qquad (1)$$
To better understand the various modifications of the VAE objective, which have often been introduced in an ad hoc manner, we here consider an alternate but equivalent definition of the VAE objective as a (negative) KL divergence between the generative model pθ(x, z) and the inference model qφ(z, x) = qφ(z | x) q(x),

$$-\,\mathrm{KL}\big(q_\phi(z, x)\,\big\|\,p_\theta(x, z)\big) \;=\; \mathbb{E}_{q(x)}\,\mathbb{E}_{q_\phi(z \mid x)}\!\left[\log \frac{p_\theta(x, z)}{q_\phi(z \mid x)\, q(x)}\right] \;=\; \mathcal{L}(\theta, \phi) + \log N. \qquad (2)$$
This definition differs from the expression in Equation (1) only by a constant term log N, which is the entropy of the empirical data distribution q(x). The advantage of this interpretation as a KL divergence is that it becomes more apparent what it means to optimize the objective with respect to the generative model parameters θ and the inference model parameters φ. In particular, it is clear that the KL is minimized when pθ(x, z) = qφ(z, x), which in turn implies that the marginal distributions must also match: pθ(x) = q(x) on the data and qφ(z) = p(z) on the latent code. We will refer to qφ(z) as the inference marginal, the average over the data of the encoder distribution,

$$q_\phi(z) \;=\; \mathbb{E}_{q(x)}\big[q_\phi(z \mid x)\big] \;=\; \frac{1}{N}\sum_{n=1}^{N} q_\phi(z \mid x_n).$$
To more explicitly represent the trade-offs that are implicit in optimizing the VAE objective, we perform a decomposition (Figure 1) similar to the one obtained by Hoffman and Johnson [2016]. This decomposition yields four terms, ①–④. Terms ③ and ④ enforce consistency between the marginal distributions over x and z: minimizing the KL in term ③ maximizes the marginal likelihood E_q(x)[log pθ(x)], whereas minimizing term ④ ensures that the inference marginal qφ(z) approximates the prior p(z). Terms ① and ② enforce consistency between the conditional distributions. Intuitively speaking, term ① maximizes the identifiability of the values z that generate each x_n: when we sample z ∼ qφ(z | x_n), the likelihood pθ(x_n | z) under the generative model should be higher than the marginal likelihood pθ(x_n). Term ② regularizes term ① by minimizing the mutual information I(x; z) in the inference model, which means that qφ(z | x_n) maps each x_n to less identifiable values.
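Figure 1 itself is not reproduced in this post. Based on the surrounding description (and on the ELBO-surgery decomposition of Hoffman and Johnson [2016]), the four terms can be written out as follows; the exact grouping in the paper's figure may differ slightly:

$$-\,\mathrm{KL}\big(q_\phi(z,x)\,\|\,p_\theta(x,z)\big) \;=\; \underbrace{\mathbb{E}_{q_\phi(z,x)}\!\left[\log \frac{p_\theta(x \mid z)}{p_\theta(x)}\right]}_{\text{①}} \;\underbrace{-\,\mathrm{KL}\big(q_\phi(z,x)\,\|\,q_\phi(z)\,q(x)\big)}_{\text{②}\;=\;-I_q(x;z)} \;\underbrace{-\,\mathrm{KL}\big(q(x)\,\|\,p_\theta(x)\big)}_{\text{③}} \;\underbrace{-\,\mathrm{KL}\big(q_\phi(z)\,\|\,p(z)\big)}_{\text{④}}$$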
Note that term ① is intractable in practice, since we are not able to pointwise evaluate pθ(x). We can circumvent this intractability by combining ① + ③ into a single term, which recovers the usual reconstruction likelihood (up to the constant entropy of q(x)),

$$① + ③ \;=\; \mathbb{E}_{q(x)}\,\mathbb{E}_{q_\phi(z \mid x)}\big[\log p_\theta(x \mid z)\big] + \log N.$$
To build intuition for the impact of each of these terms, Figure 2 shows the effect of removing each term from the objective. When we remove ③ or ④, we can learn models in which pθ(x) deviates from q(x), or qφ(z) deviates from p(z). When we remove ①, we eliminate the requirement that pθ(x_n | z) should be higher when z ∼ qφ(z | x_n) than when z ∼ p(z). Provided the decoder model is sufficiently expressive, we would then learn a generative model that ignores the latent code z. This undesirable type of solution does in fact arise in certain cases, even when ① is included in the objective, particularly when using auto-regressive decoder architectures [Chen et al., 2016b].
When we remove ②, we learn a model that minimizes the overlap between qφ(z | x_n) for different data points x_n in order to maximize ①. This maximizes the mutual information I(x; z), which is upper-bounded by log N. In practice ② often saturates to log N even when it is included in the objective, which suggests that maximizing ① outweighs this cost, at least for the encoder/decoder architectures that are commonly considered in present-day models.
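To make the bookkeeping above concrete, here is a minimal PyTorch-style sketch of the standard per-batch VAE loss, with comments indicating which pieces correspond to terms ①–④. The model class, layer sizes, and Bernoulli likelihood are illustrative assumptions rather than the paper's architecture; note that the code only ever computes the tractable combinations ① + ③ (reconstruction) and ② + ④ (expected KL to the prior).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GaussianVAE(nn.Module):
    """Minimal VAE with a diagonal-Gaussian encoder and Bernoulli decoder (illustrative)."""
    def __init__(self, x_dim=784, z_dim=10, h_dim=400):
        super().__init__()
        self.enc = nn.Sequential(nn.Linear(x_dim, h_dim), nn.ReLU(), nn.Linear(h_dim, 2 * z_dim))
        self.dec = nn.Sequential(nn.Linear(z_dim, h_dim), nn.ReLU(), nn.Linear(h_dim, x_dim))

    def loss(self, x):
        mu, logvar = self.enc(x).chunk(2, dim=-1)              # parameters of q_phi(z | x)
        z = mu + torch.randn_like(mu) * (0.5 * logvar).exp()   # reparameterized sample z ~ q_phi(z | x)
        logits = self.dec(z)                                    # parameters of p_theta(x | z)

        # Terms ① + ③: expected reconstruction log-likelihood E_q[log p_theta(x | z)]
        recon = -F.binary_cross_entropy_with_logits(logits, x, reduction="none").sum(-1)

        # Terms ② + ④: E_q(x)[ KL(q_phi(z | x) || p(z)) ] = I_q(x; z) + KL(q_phi(z) || p(z)),
        # in closed form for a diagonal Gaussian encoder and standard normal prior
        kl = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1.0).sum(-1)

        return -(recon - kl).mean()  # negative ELBO, averaged over the batch
```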
3. Hierarchically Factorized VAEs (HFVAEs)
In this paper, we are interested in defining an objective that will encourage statistical independence between features. The β-VAE objective [Higgins et al., 2016] aims to achieve this goal by defining the objective

$$\mathcal{L}_{\beta}(\theta, \phi) \;=\; \mathbb{E}_{q(x)}\Big[\mathbb{E}_{q_\phi(z \mid x)}\big[\log p_\theta(x \mid z)\big] \;-\; \beta\,\mathrm{KL}\big(q_\phi(z \mid x)\,\|\,p(z)\big)\Big].$$
We can express this objective in the terms of Figure 1 as ① + ③ + β(② + ④). In order to induce disentangled representations, the authors set β > 1. This works well in certain cases, but it has the drawback that it also increases the strength of term ②, which means that the encoder model may discard more information about x in order to minimize the mutual information I(x; z).
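In code, the β-VAE loss only changes the weight on the KL term of the sketch above. The function below is an illustrative sketch under the same diagonal-Gaussian assumptions; the default β = 4 is just an example value, not taken from the paper.

```python
import torch

def beta_vae_loss(recon_loglik: torch.Tensor, mu: torch.Tensor, logvar: torch.Tensor,
                  beta: float = 4.0) -> torch.Tensor:
    """Negative beta-VAE objective, i.e. -(① + ③ + β(② + ④)), for a diagonal-Gaussian encoder.

    recon_loglik: per-example E_q(z|x)[log p_theta(x|z)] (e.g. from a decoder).
    mu, logvar:   parameters of q_phi(z|x); the prior is N(0, I).
    """
    kl = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1.0).sum(-1)  # KL(q(z|x) || p(z)) per example
    return -(recon_loglik - beta * kl).mean()
```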
Looking at the β-VAE objective, it seems intuitive that increasing the weight of term ④ is likely to aid disentanglement. One notion of disentanglement is that there should be a low degree of correlation between different latent variables z_d. If we choose a mean-field prior

$$p(z) \;=\; \prod_{d} p(z_d),$$
then minimizing the KL term should induce an inference marginal

$$q_\phi(z) \;\approx\; \prod_{d} q_\phi(z_d),$$
in which the z_d are also independent. However, in addition to being sensitive to correlations, the KL will also be sensitive to discrepancies in the shape of the distribution. When our primary interest is to disentangle representations, we may wish to relax the constraint that the shape of the distribution matches the prior in favor of enforcing statistical independence.
To make this intuition explicit, we decompose term ④ into two terms, Ⓐ and Ⓑ (Figure 3). As with the combination ① + ②, term Ⓐ consists of two components. The second of these takes the form of a total correlation, which is the generalization of the mutual information to more than two variables,

$$\mathrm{TC}\big(q_\phi(z)\big) \;=\; \mathrm{KL}\Big(q_\phi(z)\,\Big\|\,\prod_{d} q_\phi(z_d)\Big) \;=\; \mathbb{E}_{q_\phi(z)}\Big[\log q_\phi(z) - \sum_{d}\log q_\phi(z_d)\Big].$$
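Figure 3 is likewise not reproduced here. A decomposition consistent with the description in this section, and with the later remark that the first component of Ⓐ vanishes for a factorized prior, can be sketched as

$$④ \;=\; -\,\mathrm{KL}\big(q_\phi(z)\,\|\,p(z)\big) \;=\; \underbrace{\mathbb{E}_{q_\phi(z)}\!\left[\log \frac{p(z)}{\prod_d p(z_d)}\right] \;-\; \mathrm{KL}\Big(q_\phi(z)\,\Big\|\,\prod_d q_\phi(z_d)\Big)}_{\text{Ⓐ}} \;\underbrace{-\,\sum_d \mathrm{KL}\big(q_\phi(z_d)\,\|\,p(z_d)\big)}_{\text{Ⓑ}},$$

where the second component of Ⓐ is exactly the negative total correlation defined above.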
Minimizing the total correlation yields a qφ(z) in which the different z_d are statistically independent, thereby providing a possible mechanism for inducing disentanglement. In cases where z_d itself represents a group of variables rather than a single variable, we can continue the decomposition, obtaining another set of terms (i) and (ii) which match the total correlation within each group z_d and the KL divergences for the constituent variables z_{d,e}. This provides an opportunity to induce hierarchies of disentangled features. We can in principle continue this decomposition for any number of levels to define an HFVAE objective. We here restrict ourselves to the two-level case, which corresponds to an objective of the form

$$\mathcal{L}_{\text{HFVAE}} \;=\; ① + ③ \;+\; \alpha\,② \;+\; \beta\,\text{Ⓐ} \;+\; \sum_{d}\Big(\gamma\,\text{(i)}_d + \text{(ii)}_d\Big),$$

where (i)_d and (ii)_d denote the within-group analogues of Ⓐ and Ⓑ for group z_d.
In this objective, α controls the I(x; z) regularization, β controls the TC regularization between groups of variables, and γ controls the TC regularization within groups. This objective is similar to, but more general than, the ones recently proposed by Kim and Mnih [2018] and Chen et al. [2018]. Our objective admits these objectives as special cases corresponding to a non-hierarchical decomposition in which β = γ. The first component of Ⓐ is not present in these objectives, which implicitly assume that

$$p(z) \;=\; \prod_{d} p(z_d).$$

In the more general case where

$$p(z) \;\neq\; \prod_{d} p(z_d),$$

maximizing Ⓐ with respect to φ will match the total correlation in qφ(z) to that in p(z).
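The total-correlation terms involve the inference marginal qφ(z), which cannot be evaluated exactly; FactorVAE and β-TCVAE estimate it from minibatch samples. The sketch below is an assumption-laden illustration of a two-level TC penalty using a naive minibatch (log-sum-exp) estimator of log qφ(z) for a Gaussian encoder. It is not the estimator used in the paper, the default weights β and γ are arbitrary, and it assumes a fully factorized prior, so it omits the first component of Ⓐ discussed above.

```python
import math
import torch

def gaussian_log_density(z, mu, logvar):
    """log N(z; mu, diag(exp(logvar))), evaluated elementwise."""
    c = math.log(2 * math.pi)
    return -0.5 * (c + logvar + (z - mu).pow(2) / logvar.exp())

def two_level_tc_penalty(z, mu, logvar, groups, beta=5.0, gamma=1.0):
    """Hedged sketch of a hierarchical total-correlation penalty.

    z, mu, logvar: [B, D] samples and parameters of q_phi(z|x) for a minibatch.
    groups:        list of index lists partitioning the D dimensions into groups.
    Uses the naive estimator log q(z_i) ~= logsumexp_j log q(z_i | x_j) - log B.
    """
    B = z.size(0)
    # [B, B, D]: log q(z_{i,d} | x_j) for every pair of batch elements
    log_qz_cond = gaussian_log_density(z.unsqueeze(1), mu.unsqueeze(0), logvar.unsqueeze(0))
    # log q(z): joint density over all dimensions
    log_qz = torch.logsumexp(log_qz_cond.sum(-1), dim=1) - math.log(B)
    # sum over groups of log q(z_G), and sum over dimensions of log q(z_d)
    log_qz_groups = sum(
        torch.logsumexp(log_qz_cond[..., g].sum(-1), dim=1) - math.log(B) for g in groups
    )
    log_qz_dims = (torch.logsumexp(log_qz_cond, dim=1) - math.log(B)).sum(-1)

    tc_between = (log_qz - log_qz_groups).mean()      # TC between groups of variables
    tc_within = (log_qz_groups - log_qz_dims).mean()  # TC within groups, summed over groups
    return beta * tc_between + gamma * tc_within
```

In practice such a penalty would simply be added to the negative ELBO from the earlier sketch, with `groups` chosen to reflect which latent dimensions are meant to form a block (for example, a discrete one-hot block and a continuous block).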
6. Discussion
Much of the work on learning disentangled representations thus far has focused on cases where the factors of variation are uncorrelated scalar variables. As we begin to apply these techniques to real-world datasets, we are likely to encounter correlations between latent variables, particularly when there are causal dependencies between them. This work is a first step towards learning more structured disentangled representations. By enforcing statistical independence between groups of variables, or relaxing this constraint, we now have the capability to disentangle variables that have higher-dimensional representations. An avenue of future work is to develop datasets that allow us to more rigorously test our ability to characterize correlations between higher-dimensional variables.