Playing with Variational Auto Encoders - PCA vs. UMAP vs. VAE on FMNIST / MNIST

TLDR - they are very cool - but useful only on very simple domains and datasets

Posted by snakers41 on July 7, 2018

A VAE in a nutshell. For more details on intuition - please follow this link to an amazing post


TLDR;

What you can find here:

  • A working VAE (variational auto-encoder) example on PyTorch with a lot of flags (both FC and FCN, as well as a number of failed experiments);
  • Some tests - which loss works best (I did not do proper scaling, but out-of-the-box BCE works best compared to SSIM and MSE);
  • Some experimental boilerplate code for beginners on PyTorch 0.4 (I tried various architecture designs - see my tests below);
  • Comparison between embeddings produced by PCA / UMAP / VAEs (spoiler - VAEs win) on a simple classification task with KNN;
  • The step-by-step logic of what I did in main.ipynb;
  • Some comments on VAEs / GANs for practical use;

What exactly do VAEs do?

To intuitively understand VAEs you can read this awesome article, and to get some mathematical background you can follow these links (being a practitioner myself, I am more interested in the practical ramifications):


But in a nutshell, VAEs learn to:

  • Encode your object into some small vector;
  • Decode this vector back;


If you are lazy and have some background with CNNs you can follow this mental shortcut to VAEs:

  • Encoder CNNs (anything from AlexNet to NasNet) essentially encode your image into a small vector or feature map (e.g. 224x224x3 => 1532 in case of inception);
  • Encoder-decoder CNNs (anything from FCN to any modern monsters) can perform noise removal / image augmentation / semantic segmentation - and they benefit greatly from skip connections and residual connections;
  • An auto-encoder is essentially just an encoder-decoder CNN, where you treat the "middle" representation as, you guessed it, the "code";
  • To train an auto-encoder you essentially just take an encoder-decoder CNN and use a BCE loss to compare the output of the model with the input image;
  • To make the distribution of the embedded latent variables smooth and continuous (and to be able to sample from it):
    • You introduce a second loss term - KL divergence;
    • You train your model to output a mean and a standard deviation, which you then use in the sampling process (see the sketch after this list). So the training process is a bit fuzzy by design;
    • This is what distinguishes a VAE from a plain AE;
  • Also, intuitively, you can treat PCA as a simple auto-encoder, albeit a linear one;
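
To make the last few bullets concrete, here is a minimal sketch of a VAE in PyTorch - a toy fully-connected model with the reparameterization trick and the BCE + KLD loss. This is my own illustration (the layer sizes and names are made up), not the exact model from the repo:

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyVAE(nn.Module):
    # Toy fully-connected VAE for 28x28 images (MNIST / FMNIST)
    def __init__(self, latent_size=10):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(784, 400), nn.ReLU())
        self.fc_mu = nn.Linear(400, latent_size)      # mean of q(z|x)
        self.fc_logvar = nn.Linear(400, latent_size)  # log-variance of q(z|x)
        self.decoder = nn.Sequential(nn.Linear(latent_size, 400), nn.ReLU(),
                                     nn.Linear(400, 784), nn.Sigmoid())

    def forward(self, x):
        h = self.encoder(x.view(-1, 784))
        mu, logvar = self.fc_mu(h), self.fc_logvar(h)
        # reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I)
        z = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)
        return self.decoder(z), mu, logvar

def vae_loss(recon, x, mu, logvar):
    # reconstruction term (BCE, images scaled to [0, 1]) ...
    bce = F.binary_cross_entropy(recon, x.view(-1, 784), reduction='sum')
    # ... plus the KL divergence between q(z|x) and the unit Gaussian prior
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return bce + kld

Dropping the KLD term turns this back into a plain auto-encoder; dropping the BCE term collapses everything onto the prior - hence the balance between the two losses shown in the projections below.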

To further boost your intuition, consider the embedding projections below from this article for reconstruction loss / KLD loss / both losses on MNIST.


Now compare this to the embeddings I obtained for FMNIST (a tougher version of MNIST, because MNIST is too easy nowadays) using PCA / UMAP / VAE.


Practical conclusions:

  • PCA is good enough for fast evaluation and looks really similar to a VAE trained with reconstruction loss only;
  • UMAP works amazingly well for cluster separation, but the clusters end up far away from each other in the latent space;
  • VAE embeddings are smooth and continuous and also work well for data exploration / clustering and probably for EDA (a small sketch of how such projections are computed follows below);
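
For reference, 2-D PCA and UMAP projections like the ones above can be obtained in a few lines with scikit-learn and umap-learn. This is just an illustration (not the code from the repo), assuming a recent torchvision:

import numpy as np
from torchvision import datasets
from sklearn.decomposition import PCA
import umap  # pip install umap-learn

# Flatten the FMNIST training set into a (60000, 784) array in [0, 1]
fmnist = datasets.FashionMNIST('./data', train=True, download=True)
X = fmnist.data.numpy().reshape(len(fmnist), -1) / 255.0
y = fmnist.targets.numpy()

pca_emb = PCA(n_components=2).fit_transform(X)                       # linear projection
umap_emb = umap.UMAP(n_neighbors=15, min_dist=0.1).fit_transform(X)  # non-linear (takes a few minutes on 60k images)

The VAE embeddings used for the comparison are typically just the mean vectors produced by the trained encoder.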


But you have to know that PCA and UMAP will likely produce some result on any domain, whereas VAEs are not really popular and are not known to work with anything more complicated than ImageNet-64 (and even there they work kind of "meh").

Are VAEs or GANs even remotely applicable in business / real life (to anything remotely useful, not creepy cat photos)?

In 2018 I believe the following to be true:

  • GANs > VAEs, but GANs are notoriously difficult and expensive to train;
  • VAEs do not really work for difficult datasets and domains (anything more difficult than CIFAR or ImageNet-64);
  • GANs have been shown to work on some simple domains like faces even on Full-HD images - but I have not seen a practical / open-source / easy-to-replicate implementation of this yet (there is some marketing BS or secret sauce as usual);
  • Even if you manage to train a GAN properly, you will then have to reconstruct the latent features from it (I have not tested this, but I believe the discriminator will be useless in this case);
  • Though some applications that only change / clean / enhance images (and not create them from scratch) have been shown to work really well in practice:
    • Image de-noising;
    • "Seeing" in the dark;
    • Style transfer;
    • Image filters;
    • Image in-painting and artifact removal;
    • For more examples - just look at the latest videos on this channel;


You can find a good example of said limitations in my recent article and in this repo. The repo is really interesting because the NVIDIA paper boasted being able to produce even Full-HD images, but, as usually happens, this does not seem to hold up for independent researchers.

It is impossible to easily understand why (because the results are heavily doctored and usually no readable code is published - only some obfuscated unreadable corporate trash at best).



I believe these images to be a realistic representation of what is actually achievable with GANs now


Training VAE

The best model can be trained with the following command:

python3 train.py \
    --epochs 30 --batch-size 512 --seed 42 \
    --model_type fc_conv --dataset_type fmnist --latent_space_size 10 \
    --do_augs False \
    --lr 1e-3 --m1 40 --m2 50 \
    --optimizer adam \
    --do_running_mean False --img_loss_weight 1.0 --kl_loss_weight 1.0 \
    --image_loss_type bce --ssim_window_size 5 \
    --print-freq 10 \
    --lognumber fmnist_fc_conv_l10_rebalance_no_norm \
    --tensorboard True --tensorboard_images True


If you launch this command, a copy of the FMNIST dataset will be downloaded automatically.

Suggested alternative flag values to play with:

  • dataset_type - can be set to mnist or fmnist; in either case the necessary dataset will be downloaded;
  • latent_space_size - controls the size of the latent space in combination with model_type fc_conv or fc. Other model types do not work properly;
  • m1 and m2 - control the LR decay schedule, but it did not really help here;
  • image_loss_type - can be set to bce, mse or ssim. In practice bce works best and mse is worse. I suppose proper scaling is required to make ssim work (it does not train at the moment);
  • tensorboard and tensorboard_images - can also be set to False, but they just write logs, so you may simply not bother;


The flags --tensorboard True --tensorboard_images True are optional; in order to use them, you have to:

  • Install tensorboard (installs with tensorflow)
  • Launch tensorboard with the following command tensorboard --logdir='path/to/tb_logs' --port=6006

You can also resume from the best checkpoint using these flags:

python3 train.py \
--resume weights/fmnist_fc_conv_l10_rebalance_no_norm_best.pth.tar \
--epochs 60 --batch-size 512 --seed 42 \
--model_type fc_conv --dataset_type fmnist --latent_space_size 10 \
--do_augs False \
--lr 1e-3 --m1 50 --m2 100 \
--optimizer adam \
--do_running_mean False --img_loss_weight 1.0 --kl_loss_weight 1.0 \
--image_loss_type bce --ssim_window_size 5 \
--print-freq 10 \
--lognumber fmnist_resume \
--tensorboard True --tensorboard_images True


The best reconstructions are supposed to look like this (top row - original images, bottom row - reconstructions):


Brief ablation analysis of the results


What worked:

  1. Using BCE loss + KLD loss;
  2. Converting a plain FC model into a conv model in the most straightforward fashion possible, i.e. replacing this with this;
  3. Using `SSIM` as an evaluation metric - it correlates really well with the perceived visual similarity between an image and its reconstruction (see the sketch below);
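
For reference, such an SSIM check can be computed with recent scikit-image; the arrays below are synthetic placeholders, not the repo's code:

import numpy as np
from skimage.metrics import structural_similarity as ssim

# original and reconstruction: 28x28 float images in [0, 1]
original = np.random.rand(28, 28)
reconstruction = np.clip(original + 0.05 * np.random.randn(28, 28), 0, 1)

score = ssim(original, reconstruction, data_range=1.0)
print('SSIM: {:.3f}'.format(score))  # 1.0 would mean a perfect reconstruction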


What did not work:

  1. Extracting `mean` and `std` from images - removing this feature boosted SSIM on FMNIST 4-5x;
  2. Doing any simple augmentations (unsurprisingly - it adds a complexity level to a simple task);
  3. Any architectures beyond the most obvious ones:
    1. UNet-inspired architectures (my speculation - this is because the image size is very small and very global features work best, i.e. a feature extraction cascade is overkill):
      1. I tried various combinations of convolution weights, none of which worked;
      2. 1xN convolutions;
  4. MSE loss performed poorly, and SSIM loss did not work at all;
  5. LR decay, as well as any LR other than 1e-3 (with adam), did not really help;
  6. Increasing the latent space size to 20 or 100 did not really change much;


¯\_(ツ)_/¯ What I did not try:

  1. Ensembling or building meta-architectures;
  2. Conditional VAEs;
  3. Increasing network capacity;

Comparing obtained embeddings

In my case, with KNN (the simplest classifier there is), the following accuracies were obtained (a minimal evaluation sketch follows the list):

  • VAE and UMAP - 80%+;
  • PCA - 76%;
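
A minimal sketch of this comparison with scikit-learn - here emb and labels are placeholders for whichever embeddings (PCA / UMAP / VAE) and FMNIST labels you computed above:

from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

# emb:    (n_samples, latent_dim) embeddings from PCA, UMAP or the VAE encoder
# labels: the corresponding FMNIST class labels
emb_train, emb_test, y_train, y_test = train_test_split(
    emb, labels, test_size=0.2, random_state=42)

knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(emb_train, y_train)
print('KNN accuracy: {:.3f}'.format(knn.score(emb_test, y_test)))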

Further reading