  • Semi-supervised Learning with Variational Autoencoders
    Research/Generative Model 2024. 4. 4. 12:03

    https://bjlkeng.io/posts/semi-supervised-learning-with-variational-autoencoders/


    Semi-supervised Learning

    Semi-supervised learning is a set of techniques used to make use of unlabelled data in supervised learning problems (e.g. classification and regression). Semi-supervised learning falls in between unsupervised and supervised learning because you make use of both labelled and unlabelled data points.

     

    If you think about the plethora of data out there, most of it is unlabelled. Rarely do you have something in a nice benchmark format that tells you exactly what you need to do. As an example, there are billions (trillions?) of unlabelled images all over the internet, but only a tiny fraction actually have any sort of label. So our goal here is to get the best performance with a tiny amount of labelled data.

     

    Humans are somehow very good at this. Even if you have never seen an anteater, I could probably show you a handful of anteater images and you would classify new ones pretty accurately. We're so good at this because our brains have learned common features of what we see, which let us quickly sort things into buckets like anteaters. For machines it's no different: we want a machine to learn some additional (useful) features in an unsupervised way to help with the actual task, for which we have very few examples.

     


    Variational Lower Bound

    Many ML papers take the "maximization of the variational lower bound" for granted, so I just want to give a bit of intuition behind it.

     

    Let's start off with the high-level problem. We have some data X and a generative probability model P(X|θ) that shows us how to randomly sample (i.e. generate) data points that follow the distribution of X, assuming we know the "magic values" of the θ parameters. We can see Bayes' theorem in Equation 1 (small p for densities):

    $$p(\theta|X) = \frac{p(X|\theta)\,p(\theta)}{p(X)} = \frac{p(X|\theta)\,p(\theta)}{\int p(X|\theta)\,p(\theta)\,d\theta} \tag{1}$$

    Our goal is to find the posterior, P(θ|X), which tells us the distribution of the θ parameters. Sometimes this is the end goal (e.g. the cluster centers and mixture weights of a Gaussian mixture model); other times we just want the parameters so we can use P(X|θ) to generate new data points (e.g. use a variational autoencoder to generate a new image). Unfortunately, for all but the simplest problems this is intractable (mostly because of the integral in the denominator), that is, we can't get a nice closed-form solution.
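    To make Equation 1 concrete, here is a minimal sketch on an assumed toy problem that is not from the post: coin flips with a Bernoulli likelihood, and θ restricted to a discrete grid with a uniform prior. On a grid the denominator p(X) is just a finite sum; it is the integral over a continuous, high-dimensional θ that becomes intractable. Because the data points are independent, their log-likelihoods simply add.

```python
import numpy as np


def grid_posterior(data, grid, prior):
    """Bayes' theorem on a discrete grid of theta values.

    data:  0/1 coin flips, assumed i.i.d. Bernoulli(theta)
    grid:  candidate theta values strictly inside (0, 1)
    prior: p(theta) evaluated on the grid, summing to 1
    """
    heads = data.sum()
    tails = data.size - heads
    # log p(X|theta): independent data points, so their log-likelihoods add up
    log_lik = heads * np.log(grid) + tails * np.log1p(-grid)
    log_joint = log_lik + np.log(prior)
    # log p(X): the denominator of Bayes' theorem, a finite sum on a grid
    log_evidence = np.logaddexp.reduce(log_joint)
    posterior = np.exp(log_joint - log_evidence)
    return posterior, log_evidence


grid = np.linspace(0.01, 0.99, 99)
prior = np.full(grid.size, 1.0 / grid.size)  # uniform prior
data = np.array([1, 1, 0, 1, 1, 0, 1, 1])    # 6 heads, 2 tails
posterior, log_evidence = grid_posterior(data, grid, prior)
print(grid[np.argmax(posterior)])  # posterior mode, near 6/8
```

    With a uniform prior the posterior mode coincides with the maximum-likelihood estimate, 6/8.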

     

    Our solution? Approximation! We'll approximate P(θ|X) by another function Q(θ|X) (it's usually conditioned on X, but not necessarily). Solving for Q is (relatively) fast because we can assume a particular shape for Q(θ|X) and turn the inference problem (i.e. finding P(θ|X)) into an optimization problem (i.e. finding Q). Of course, it can't be just any function: we want it to be as close as possible to P(θ|X), and how close we get depends on the structural form of Q(θ|X) (how much flexibility it has), the technique we use to find it, and our metric of "closeness".

     

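    In terms of "closeness", the standard way of measuring it is KL divergence, which we can neatly write down here:

    $$\begin{aligned}
    D_{KL}(Q\,\|\,P) &= \int q(\theta|X)\log\frac{q(\theta|X)}{p(\theta|X)}\,d\theta \\
    &= E_{q}\big[\log q(\theta|X) - \log p(X|\theta) - \log p(\theta) + \log p(X)\big] \\
    &= E_{q}\big[\log q(\theta|X) - \log p(X|\theta) - \log p(\theta)\big] + \log p(X)
    \end{aligned} \tag{2}$$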

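    Rearranging, dropping the KL divergence term (which is always non-negative) and putting it in terms of an expectation over q(θ), we get what's called the Evidence Lower Bound (ELBO) for a single data point X:

    $$\log p(X) \geq E_{q}\big[\log p(X|\theta)\big] + E_{q}\big[\log p(\theta)\big] - E_{q}\big[\log q(\theta|X)\big] \tag{3}$$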

    If you have multiple data points, you can just sum over them because we're in log space (assuming independence between data points).

     

    In a lot of papers, you'll see that people will go straight to optimizing the ELBO whenever they are talking about variational inference. And if you look at it in isolation, you can gain some intuition of how it works:

    • It's a lower bound on the evidence, that is, it's a lower bound on the probability of your data occurring given your model.
    • Maximizing the ELBO is equivalent to minimizing the KL divergence.
    • The first two terms try to maximize the MAP estimate (likelihood + prior).
    • The last term tries to ensure Q is diffuse (maximize information entropy).
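    To see these properties numerically, here is a minimal sketch assuming a toy conjugate model that is not from the post: prior θ ~ N(0, 1), a single observation x | θ ~ N(θ, 1), and a Gaussian approximation q(θ) = N(m, s²). Every ELBO term then has a closed form, the exact posterior is N(x/2, 1/2), and the exact evidence is p(x) = N(x | 0, 2).

```python
import math


def elbo(x, m, s2):
    """Closed-form ELBO for prior N(0, 1), likelihood N(x | theta, 1), q = N(m, s2)."""
    log2pi = math.log(2 * math.pi)
    exp_log_lik = -0.5 * log2pi - 0.5 * ((x - m) ** 2 + s2)  # E_q[log p(x|theta)]
    exp_log_prior = -0.5 * log2pi - 0.5 * (m ** 2 + s2)      # E_q[log p(theta)]
    entropy = 0.5 * math.log(2 * math.pi * math.e * s2)      # -E_q[log q(theta)]
    return exp_log_lik + exp_log_prior + entropy


def log_evidence(x):
    """Exact log p(x): marginally, x ~ N(0, 2)."""
    return -0.5 * math.log(2 * math.pi * 2) - x ** 2 / 4


x = 1.3
print(log_evidence(x))
print(elbo(x, m=x / 2, s2=0.5))  # q is the exact posterior: the bound is tight
print(elbo(x, m=0.0, s2=1.0))    # q is just the prior: strictly below log p(x)
```

    The gap between log p(x) and the ELBO is exactly KL(q ‖ p(θ|x)), which is why maximizing the ELBO over (m, s²) is the same as minimizing the KL divergence.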

    A Vanilla VAE for Semi-Supervised Learning (M1 Model)

    The high-level idea is pretty easy to understand. A variational autoencoder defines a generative model for your data, which basically says: take a sample from an isotropic standard normal distribution (Z) and run it through a deep net (defined by g) to produce the observed data (X). The hard part is figuring out how to train it.

     

    Using the autoencoder analogy, the generative model is the "decoder" since you're starting from a latent state and translating it into the observed data. A VAE also has an "encoder" part that is used to help train the decoder. It goes from observed values to a latent state (X to z). A keen observer will notice that this is actually our variational approximation of the posterior (q(z|X)), which coincidentally is also a neural network (defined by g_{z|X}). This is visualized in Figure 1.

    Figure 1: Vanilla Variational Autoencoder
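    A minimal numpy sketch of a single-sample estimate of the vanilla VAE's negative ELBO, assuming a Bernoulli decoder (binary pixels) and a diagonal-Gaussian encoder. The "networks" are hypothetical single linear maps (W_mu, W_logvar, W_dec) standing in for the deep encoder and decoder, and the sizes are arbitrary; a real implementation would use a framework with automatic differentiation so that this loss can actually be minimized.

```python
import numpy as np

rng = np.random.default_rng(0)
D, K_LATENT = 784, 2  # hypothetical sizes: flattened 28x28 image, 2-d latent space

# Hypothetical stand-ins for the encoder and decoder networks
W_mu = rng.normal(0, 0.01, (D, K_LATENT))
W_logvar = rng.normal(0, 0.01, (D, K_LATENT))
W_dec = rng.normal(0, 0.01, (K_LATENT, D))


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


def vae_loss(x):
    """Single-sample estimate of the negative ELBO for one binary image x."""
    # Encoder q(z|x) = N(mu, diag(exp(logvar)))
    mu, logvar = x @ W_mu, x @ W_logvar
    # Reparameterization: z = mu + sigma * eps with eps ~ N(0, I)
    eps = rng.standard_normal(K_LATENT)
    z = mu + np.exp(0.5 * logvar) * eps
    # Decoder p(x|z): Bernoulli pixel probabilities
    p = np.clip(sigmoid(z @ W_dec), 1e-7, 1 - 1e-7)
    recon = -np.sum(x * np.log(p) + (1 - x) * np.log(1 - p))
    # Closed-form KL(q(z|x) || N(0, I)) for a diagonal Gaussian
    kl = -0.5 * np.sum(1 + logvar - mu ** 2 - np.exp(logvar))
    return recon + kl


x = (rng.random(D) > 0.5).astype(float)  # toy binary "image"
print(vae_loss(x))
```

    The reparameterization keeps the sampled z differentiable with respect to the encoder outputs, which is what lets the decoder's reconstruction error train the encoder.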

     

    After our VAE has been fully trained, it's easy to see how we can just use the "encoder" to directly help with semi-supervised learning:

    • Train a VAE using all our data points (labelled and unlabelled), and transform our observed data (X) into the latent space defined by the Z variables.
    • Solve a standard supervised learning problem on the labelled data using (Z, Y) pairs (where Y is our label).
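    The two M1 steps can be sketched as follows, under loud assumptions: the toy data are two Gaussian blobs, a 2-d PCA projection fitted on all points stands in for the trained VAE encoder (the hypothetical `encode_mean` returns what would be the posterior mean of z), and a nearest-centroid rule stands in for whatever supervised learner you choose.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy data (assumption): two well-separated classes in 20-d, 500 points each
X = np.vstack([rng.normal(-3, 1, (500, 20)), rng.normal(3, 1, (500, 20))])
y = np.repeat([0, 1], 500)
labelled = np.concatenate([np.arange(5), 500 + np.arange(5)])  # only 10 labels

# Step 1: fit the unsupervised encoder on ALL points (labelled + unlabelled).
# Hypothetical stand-in for a trained VAE encoder: a 2-d PCA projection.
mean = X.mean(axis=0)
_, _, Vt = np.linalg.svd(X - mean, full_matrices=False)


def encode_mean(x):
    return (x - mean) @ Vt[:2].T


Z = encode_mean(X)

# Step 2: supervised learning on (Z, Y) pairs from the labelled subset only.
# A nearest-centroid classifier stands in for any supervised learner.
Z_lab, y_lab = Z[labelled], y[labelled]
centroids = np.stack([Z_lab[y_lab == k].mean(axis=0) for k in (0, 1)])


def predict(z):
    dists = ((z[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
    return dists.argmin(axis=1)


accuracy = (predict(Z) == y).mean()
print(accuracy)
```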

    Intuitively, the latent space defined by z should capture some useful information about our data such that it's easily separable in our supervised learning problem. This technique is called the M1 model in the Kingma paper. As you may have noticed, though, step 1 doesn't directly involve any of the y labels; the steps are disjoint. Kingma also introduces another model, "M2", that attempts to solve this problem.


    Extending the VAE for Semi-Supervised Learning (M2 Model)

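    In the M1 model, we basically ignored our labelled data in our VAE. The M2 model (from the Kingma paper) explicitly takes it into account. Let's take a look at the generative model (i.e. the "decoder"):

    $$\begin{aligned}
    p_\theta(\pi) &= \text{SymDir}(\alpha) \\
    p_\theta(y|\pi) &= \text{Cat}(y|\pi) \\
    p_\theta(z) &= \mathcal{N}(z|0, I) \\
    p_\theta(x|y,z) &= f(x; y, z, \theta)
    \end{aligned} \tag{4}$$

    where: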

    • x is a vector of our observed variables.
    • f(x; y, z, θ) is a suitable likelihood function to model our output such as a Gaussian or Bernoulli. We use a deep net to approximate it based on inputs y, z with network weights defined by θ.
    • z is a vector of latent variables (same as in the vanilla VAE).
    • y is a one-hot encoded categorical variable representing our class labels, whose relative probabilities are parameterized by π.
    • SymDir is the symmetric Dirichlet distribution with hyper-parameter α (a conjugate prior for categorical/multinomial variables).
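    Ancestral sampling from this generative model follows the bullets from top to bottom. A minimal sketch assuming binary (Bernoulli) pixels; `decoder` is a hypothetical stand-in (one random linear map plus a sigmoid) for the deep net behind f(x; y, z, θ), and the sizes and α value are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
N_CLASSES, N_LATENT, N_PIXELS = 10, 2, 784
ALPHA = 1.0  # symmetric Dirichlet hyper-parameter (value assumed)

# Hypothetical decoder weights standing in for the deep net behind f(x; y, z, theta)
W = rng.normal(0, 0.1, (N_CLASSES + N_LATENT, N_PIXELS))


def decoder(y_onehot, z):
    logits = np.concatenate([y_onehot, z]) @ W
    return 1.0 / (1.0 + np.exp(-logits))  # Bernoulli mean for each pixel


def sample_m2():
    pi = rng.dirichlet(np.full(N_CLASSES, ALPHA))        # pi ~ SymDir(alpha)
    y = np.eye(N_CLASSES)[rng.choice(N_CLASSES, p=pi)]  # y ~ Cat(pi), one-hot
    z = rng.standard_normal(N_LATENT)                   # z ~ N(0, I)
    x = rng.binomial(1, decoder(y, z))                  # x ~ p_theta(x | y, z)
    return x, y, z


x, y, z = sample_m2()
print(x.shape, y.argmax())
```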

    How do we use this for semi-supervised learning, you ask? The basic gist is that we define an approximate posterior function qϕ(y|x) using a deep net that is basically a classifier. The genius is that we can train this classifier on both labelled and unlabelled data just by training this extended VAE. Figure 2 shows a visualization of the network.

    Figure 2: M2 Variational Autoencoder for Semi-Supervised Learning

     

    Now the interesting part is that we have two cases: one where we observe the y labels and one where we don't. We have to deal with them differently when constructing the approximate posterior q as well as in the variational objective.


    Variational Objective with Unlabelled Data

    For any variational inference problem, we need to start with our approximate posterior. In this case, we'll treat y, z as the unknown latent variables, and perform variational inference (i.e. define approximate posteriors) over them. Notice that we excluded π because we don't really care what its posterior is in this case.

     

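    We'll assume the approximate posterior qϕ(y,z|x) has a fully factorized form as such:

    $$\begin{aligned}
    q_\phi(y,z|x) &= q_\phi(z|x)\,q_\phi(y|x) \\
    q_\phi(z|x) &= \mathcal{N}\big(z \,|\, \mu_\phi(x), \operatorname{diag}(\sigma^2_\phi(x))\big) \\
    q_\phi(y|x) &= \text{Cat}\big(y \,|\, \pi_\phi(x)\big)
    \end{aligned} \tag{5}$$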

    where μϕ(x), σ²ϕ(x) and πϕ(x) are all defined by neural networks parameterized by ϕ that we will learn. Here, πϕ(x) should not be confused with our actual parameter π above: the former is a point estimate coming out of our network, while the latter is a random variable with a symmetric Dirichlet prior.
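    Because qϕ(y,z|x) factorizes, sampling from it and scoring it split into two independent pieces. A minimal sketch in which the encoder outputs for a single x are hypothetical fixed arrays rather than network outputs: logits passed through a softmax give πϕ(x), and a mean and log-variance vector parameterize the Gaussian over z.

```python
import numpy as np

rng = np.random.default_rng(0)


def softmax(logits):
    e = np.exp(logits - logits.max())
    return e / e.sum()


# Hypothetical encoder outputs for a single x
pi_phi = softmax(np.array([2.0, 0.5, -1.0]))  # q_phi(y|x) = Cat(y | pi_phi(x))
mu_phi = np.array([0.3, -1.2])                # mean of q_phi(z|x)
logvar_phi = np.array([-0.5, 0.1])            # log sigma^2 of q_phi(z|x)


def log_q(y_idx, z):
    """log q_phi(y, z | x) = log q_phi(y|x) + log q_phi(z|x) under the factorized form."""
    var = np.exp(logvar_phi)
    log_q_z = -0.5 * np.sum(np.log(2 * np.pi * var) + (z - mu_phi) ** 2 / var)
    return np.log(pi_phi[y_idx]) + log_q_z


# Sample each factor independently
y_idx = rng.choice(len(pi_phi), p=pi_phi)
z = mu_phi + np.exp(0.5 * logvar_phi) * rng.standard_normal(mu_phi.shape)
print(y_idx, z, log_q(y_idx, z))
```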

     

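    From here, we use the ELBO to determine our variational objective for a single data point:

    $$\begin{aligned}
    \log p_\theta(x) &\geq E_{q_\phi(y,z|x)}\big[\log p_\theta(x|y,z) + \log p_\theta(y) + \log p_\theta(z) - \log q_\phi(y,z|x)\big] \\
    &= E_{q_\phi(y|x)}\Big[E_{q_\phi(z|x)}\big[\log p_\theta(x|y,z) + K_1 + \log p_\theta(z) - \log q_\phi(y|x) - \log q_\phi(z|x)\big]\Big] \\
    &= E_{q_\phi(y|x)}\Big[E_{q_\phi(z|x)}\big[\log p_\theta(x|y,z)\big] - KL\big[q_\phi(z|x)\,\|\,p_\theta(z)\big] - \log q_\phi(y|x)\Big] + K_1 \\
    &= E_{q_\phi(y|x)}\big[-\mathcal{L}(x,y) - \log q_\phi(y|x)\big] + K_1 \\
    &= \sum_y \Big[q_\phi(y|x)\big(-\mathcal{L}(x,y)\big) - q_\phi(y|x)\log q_\phi(y|x)\Big] + K_1
    \end{aligned} \tag{6}$$

    $$-\mathcal{L}(x,y) = E_{q_\phi(z|x)}\big[\log p_\theta(x|y,z)\big] - KL\big[q_\phi(z|x)\,\|\,p_\theta(z)\big]$$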

    Going through line by line, we factor our qϕ function into the separate y and z parts for both the expectation and the log. Notice we also absorb log pθ(y) into a constant: pθ(y) = ∫ pθ(y|π)pθ(π) dπ is a Dirichlet-multinomial distribution, which simplifies to a constant (alternatively, our model's assumption is that all values of y are equally likely).

     

    Next, we notice that some terms form a KL divergence between qϕ(z|x) and pθ(z). Then, we group a few terms together and name it L(x,y). This latter term is essentially the same variational objective we used for a vanilla variational autoencoder (sans the reference to y). Finally, we explicitly write out the expectation with respect to y. I won't write out all the details of how to compute it: for L(x,y) you can look at my previous post, and for the rest, the implementation notebooks. The loss functions are clearly labelled, so it shouldn't be too hard to map them back to these equations.

     

    So Equation 6 defines the objective function for our VAE, which will simultaneously train both the θ parameters of the "decoder" network and the ϕ parameters of the approximate posterior "encoder" relating to y, z.
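    Negating Equation 6 and dropping the constant gives the per-example unlabelled loss U(x) = Σ_y qϕ(y|x) L(x,y) − H(qϕ(y|x)). In practice, writing out the expectation over y means evaluating L(x,y) once per class, i.e. running the decoder with every possible one-hot y. A minimal sketch, assuming those per-class values of L(x,y) have already been computed (the `per_class_L` input is hypothetical):

```python
import numpy as np


def unlabelled_loss(q_y, per_class_L):
    """U(x) = sum_y q(y|x) L(x, y) - H(q(y|x)), with constants dropped.

    q_y:         classifier probabilities q_phi(y|x), shape (K,)
    per_class_L: L(x, y) evaluated for each of the K one-hot labels, shape (K,)
    """
    q_y = np.clip(q_y, 1e-12, 1.0)  # guard against log(0)
    expected_L = np.sum(q_y * per_class_L)
    entropy = -np.sum(q_y * np.log(q_y))
    return expected_L - entropy


print(unlabelled_loss(np.array([0.5, 0.5]), np.array([1.0, 3.0])))  # 2 - log(2)
```

    Minimizing U(x) pushes qϕ(y|x) toward classes whose reconstruction loss is low, while the entropy term keeps it from collapsing onto one class without that evidence.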


    Variational Objective with Labelled Data

    So here's where it gets a bit trickier, because this part was glossed over in the paper. In particular, when training with labelled data, you want to make sure you train both the y and the z networks at the same time. It's actually easy to leave out the y network: since you observe y, you can simply ignore the classifier network.

     

    Now of course the whole point of semi-supervised learning is to learn a mapping from x to y using labelled data, so it's pretty silly not to train that part of your VAE with labelled data. So Kingma et al. add an extra loss term, initially describing it as a fix to this problem. Then they add an innocent throw-away line that it can actually be derived by performing variational inference over π. It's true (I think), but it's not that straightforward to derive! Well, I worked out the details, so here's my derivation of the variational objective with labelled data.

     

    For the case when we have both (x,y) points, we'll treat both z and π as unknown latent variables and perform variational inference for both z and π using a fully factorized posterior dependent only on x.

     

    Remember we can define our approximate posteriors however we want, so we explicitly choose to have π depend only on x and not on our observed y. Why, you ask? Because we want to make sure the ϕ parameters of our classifier are trained when we have labelled data.

     

    As before, we start with the ELBO to determine our variational objective for a single data point (x,y):

     

    where α is a hyper-parameter that controls how strongly you want to train the discriminative classifier (qϕ(y|x)) relative to the other terms. In the paper, they set it to α=0.1N.

     

    Going line by line, we start off with the ELBO, expanding all the priors. The one trick is that instead of expanding the joint distribution of y, π by conditioning on π (i.e. pθ(y,π) = pθ(y|π)pθ(π)), we expand it using the posterior: pθ(y,π) = pθ(π|y)pθ(y). The posterior pθ(π|y) is again a Dirichlet distribution because the Dirichlet is the conjugate prior of y's categorical/multinomial distribution.

     

    Next, we just rearrange and factor qϕ, both in the log term as well as the expectation. We notice that the first part is exactly our L loss function from above and the rest is a KL divergence between our π posterior and our approximate posterior. The last simplification of the KL divergence is a bit verbose (and hand-wavy) so I've put it in Appendix A.
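    Two Dirichlet facts are used in the derivations above: pθ(y) is a constant, and pθ(π|y) is again Dirichlet, namely Dir(α + y). Both can be checked numerically with the standard library. By Bayes' theorem, log pθ(y|π) + log pθ(π) − log pθ(π|y) = log pθ(y) for every π, so if Dir(α + y) is the right posterior, the difference below is the same at every sampled π, and it equals log(1/K). K, α and the observed class are arbitrary illustration values.

```python
import math
import random


def log_dirichlet(pi, a):
    """log Dir(pi; a) for a probability vector pi and concentration vector a."""
    return (math.lgamma(sum(a)) - sum(math.lgamma(ak) for ak in a)
            + sum((ak - 1.0) * math.log(pk) for ak, pk in zip(a, pi)))


K, alpha, observed = 4, 1.5, 2  # arbitrary illustration values
prior_a = [alpha] * K                                                      # SymDir(alpha)
posterior_a = [alpha + (1.0 if k == observed else 0.0) for k in range(K)]  # alpha + y

random.seed(0)
diffs = []
for _ in range(5):
    # Draw pi ~ Dir(alpha) by normalizing independent Gamma(alpha, 1) samples
    g = [random.gammavariate(alpha, 1.0) for _ in range(K)]
    total = sum(g)
    pi = [gk / total for gk in g]
    log_joint = math.log(pi[observed]) + log_dirichlet(pi, prior_a)  # log p(y|pi) + log p(pi)
    diffs.append(log_joint - log_dirichlet(pi, posterior_a))         # log p(y) if the posterior is right

print(diffs)            # the same value at every sampled pi ...
print(math.log(1 / K))  # ... equal to log(1/K): p(y) is uniform, hence constant
```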


    Training the M2 Model

    Using Equations 6 and 8, we can derive a loss function as such (remember it's the negative of the ELBO above):

     

    With this loss function, we just train the network as you would expect. Simply grab a mini-batch, compute the needed values in the network (i.e. q(y|x), q(z|x), p(x|y,z)), compute the loss function above using the appropriate summation depending on whether you have labelled or unlabelled data, and finally take the gradients to update our network parameters θ, ϕ. The network is remarkably similar to a vanilla VAE, with the addition of the posterior on y and the additional terms in the loss function. The tricky part is dealing with the two types of data (labelled and unlabelled), which I explain in the implementation notes below.
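    A minimal sketch of the mini-batch loss just described, assuming the labelled term has the form L(x,y) − α log qϕ(y|x) used in the Kingma et al. paper and the unlabelled term is U(x). The input arrays are hypothetical placeholders for what a real network would produce; the gradients with respect to θ and ϕ would come from a framework's automatic differentiation, not from this numpy code.

```python
import numpy as np


def m2_batch_loss(L_lab, q_true, q_unlab, L_unlab, alpha, eps=1e-12):
    """Loss (negative ELBO, constants dropped) for one mixed mini-batch.

    L_lab:   L(x, y) for each labelled example, shape (B_l,)
    q_true:  q_phi(y|x) of the observed class for each labelled example, shape (B_l,)
    q_unlab: q_phi(y|x) over all K classes for each unlabelled example, shape (B_u, K)
    L_unlab: L(x, y) for every possible class y of each unlabelled example, shape (B_u, K)
    alpha:   weight of the extra classification term
    """
    # Labelled part: L(x, y) - alpha * log q_phi(y|x)
    labelled = np.sum(L_lab - alpha * np.log(q_true + eps))
    # Unlabelled part: U(x) = sum_y q(y|x) L(x, y) - H(q(y|x))
    q = np.clip(q_unlab, eps, 1.0)
    unlabelled = np.sum(q * L_unlab) + np.sum(q * np.log(q))
    return labelled + unlabelled


rng = np.random.default_rng(0)
K, B_l, B_u = 10, 4, 6
loss = m2_batch_loss(
    L_lab=rng.uniform(50, 100, B_l),
    q_true=rng.uniform(0.1, 0.9, B_l),
    q_unlab=rng.dirichlet(np.ones(K), size=B_u),
    L_unlab=rng.uniform(50, 100, (B_u, K)),
    alpha=0.1 * 1000,  # alpha = 0.1 * N with N = 1000, purely as an example
)
print(loss)
```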


    Implementation Notes

    Variational Autoencoder Implementations (M1 and M2)

    The architectures I used for the VAEs were as follows:

    • For q(y|x), I used the CNN example from Keras, which has 3 conv layers, 2 max pool layers, a softmax layer, with dropout and ReLU activation.
    • For q(z|x), I used 3 conv layers, and 2 fully connected layers with batch normalization, dropout and ReLU activation.
    • For p(x|z) and p(x|y,z), I used a fully connected layer, followed by 4 transposed conv layers (the first 3 with ReLU activation the last with sigmoid for the output).

    The one complication I had was how to implement the training of the M2 model, because you need to treat y simultaneously as an input and an output depending on whether you have labelled or unlabelled data. I still wanted to use Keras and didn't want to go as low-level as TensorFlow, so I came up with a workaround: train two networks (with shared layers)!

     

    So basically, I have one network for labelled data and one for unlabelled data. They both share all the same components (q(y|x),q(z|x),p(x|y,z)) but differ in their input/output as well as loss functions. The labelled data has input (x,y) and output (x′,y′). y′ corresponds to the predictions from the posterior, while x′ corresponds to the decoder output. The loss function is Equation 8 with α=0.1N (not the one I derived in Appendix A). For the unlabelled case, the input is x and the output is the predicted x′.
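    The post does not say how updates to the two networks were interleaved, so the schedule below is an assumption: alternate one labelled and one unlabelled training step, cycling the much smaller labelled set. Weight sharing itself is not shown; in Keras it comes from calling the same layer objects when building both models, so a step through either model updates the shared weights. A framework-free sketch of the schedule only (all names are hypothetical):

```python
import itertools


def minibatches(data, size):
    """Split a sequence into consecutive mini-batches (the last one may be smaller)."""
    return [data[i:i + size] for i in range(0, len(data), size)]


def alternating_schedule(labelled, unlabelled, size):
    """Yield ("labelled", batch) and ("unlabelled", batch) pairs for one pass over the
    unlabelled data, cycling the much smaller labelled set so both models keep
    updating the shared layers."""
    labelled_cycle = itertools.cycle(minibatches(labelled, size))
    for u_batch in minibatches(unlabelled, size):
        yield "labelled", next(labelled_cycle)  # train step of the (x, y) -> (x', y') model
        yield "unlabelled", u_batch             # train step of the x -> x' model


labelled = [(f"x{i}", i % 10) for i in range(30)]  # toy (x, y) pairs
unlabelled = [f"u{i}" for i in range(100)]         # toy unlabelled x
schedule = list(alternating_schedule(labelled, unlabelled, size=25))
print([(kind, len(batch)) for kind, batch in schedule])
```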


    Comparison Implementations

    In the results below I compared a semi-supervised VAE with several other ways of dealing with semi-supervised learning problems:

    • PCA + SVM: Here I just ran principal component analysis on the entire image set, and then trained an SVM on the PCA-transformed representation of only the labelled data.
    • CNN: a vanilla CNN using the Keras CNN example trained only on labelled data.
    • Inception: Here I used a pre-trained Inception network available in Keras. I pretty much just used their example, which adds a global average pooling layer, a dense layer, and a softmax layer. It was trained only on the labelled data while freezing all the original pre-trained Inception layers; I didn't do any fine-tuning of the Inception layers.

    Semi-supervised Results

    The datasets I used were MNIST and CIFAR10, with stratified sampling on the training data to create the semi-supervised dataset. The test sets are the ones included with the data. Here are the results for MNIST:

    Table 1: MNIST Results

     

    The M2 model was only run for N=1000. From the MNIST results table, we really see the M2 model shine: at a comparable sample size, all the other methods have much lower performance. You need to get to N=5000 before the CNN gets into the same range. Interestingly, at N=100 the models that make use of the unlabelled data do better than a CNN, which has so little training data that it surely is not learning to generalize. Next, on to the CIFAR10 results shown in Table 2.

    Table 2: CIFAR10 Results

     

    Again, I only trained M2 on N=1000. The CIFAR10 results tell a different story. Clearly the pre-trained Inception network is doing the best. It's pre-trained on ImageNet, which is very similar to CIFAR10. You have to get to relatively large sample sizes before even the CNN starts approaching the same accuracy.

     

    The M1/M2 results are quite poor, not even beating out PCA in most cases! My reasoning here is that the CIFAR10 dataset is too complex for the VAE model. That is, when I look at the images generated from it, it's pretty hard for me to figure out what the label should be. Take a look at some of the randomly generated images from my M2 model:

    Figure 3: Images generated from M2 VAE model trained on CIFAR data.

     

    Other people have had similar problems. I suspect the z Gaussian latent variables are not powerful enough to encode the complexity of the CIFAR10 dataset. I've read somewhere that the unimodal nature of the latent variables is thought to be quite limiting, and here I guess we see that is the case. I'm pretty sure more recent research has tried to tackle this problem so I'm excited to explore this phenomenon more later.


    Conclusion

    As I've been writing about for the past few posts, I'm a huge fan of scalable probabilistic models using deep learning. I think it's both elegant and intuitive because of the probabilistic formulation. Unfortunately, VAEs using Gaussians as the latent variable do have limitations, and obviously they are not quite the state-of-the-art in generative models (i.e. GANs seem to be the top dog). In any case, there is still a lot more recent research in this area that I'm going to follow up on and hopefully I'll have something to post about soon. Thanks for reading!

     
