Improved Techniques for Training GANs
Paper summary

* They suggest some small changes to the GAN training scheme that lead to visually improved results.
* They suggest a new scoring method to compare the results of different GAN models with each other.

### How

* Feature Matching
  * Usually G would be trained to mislead D as often as possible, i.e. to maximize D's output.
  * Now they train G to minimize the feature distance between real and fake images. I.e. they do:
    1. Pick a layer $l$ from D.
    2. Forward real images through D and extract the features from layer $l$.
    3. Forward fake images through D and extract the features from layer $l$.
    4. Compute the squared Euclidean distance between the two sets of features (in the paper, each averaged over the batch) and backpropagate.
* Minibatch discrimination
  * They allow D to look at multiple images in the same minibatch.
  * That is, they feed the features (of each image) extracted by an intermediate layer of D through a linear operation, resulting in a matrix per image.
  * They then compute the L1-distances between these matrices across the images of the minibatch.
  * They then let D make its judgement (fake/real image) based on the features extracted from the image and these distances.
  * They add this mechanism so that the diversity of images generated by G increases (which should also prevent collapses).
* Historical averaging
  * They add a penalty term that punishes weights which are far away from their historical average values.
  * I.e. the cost is `distance(current parameters, average of parameters over the last t batches)`.
  * They argue that this can help the network to find equilibria that normal gradient descent would not find.
* One-sided label smoothing
  * Usually one would use the labels 0 (image is fake) and 1 (image is real).
  * Using smoother labels (0.1 and 0.9) seems to make networks more resistant to adversarial examples.
  * So they smooth the labels of real images (apparently to 0.9?).
  * Smoothing the labels of fake images would lead to (mathematical) problems in some cases, so they keep these at 0.
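Two of the losses described above, feature matching (for G) and one-sided label smoothing (for D), can be sketched in NumPy as follows. This is a minimal illustration and not the authors' code: the function names, the `real_target=0.9` default and the `eps` clipping constant are assumptions made for the example, and the features are assumed to be already extracted from layer $l$ of D as arrays of shape `(batch size, feature dim)`.

```python
import numpy as np


def feature_matching_loss(real_features, fake_features):
    """Squared Euclidean distance between batch-mean features.

    Both inputs are assumed to have shape (batch size, feature dim) and to
    come from the same intermediate layer of D (hypothetical inputs).
    """
    real_mean = np.mean(real_features, axis=0)
    fake_mean = np.mean(fake_features, axis=0)
    return float(np.sum((real_mean - fake_mean) ** 2))


def one_sided_smoothed_d_loss(d_real, d_fake, real_target=0.9, eps=1e-12):
    """Binary cross-entropy for D with smoothed real labels only.

    d_real / d_fake are D's predicted probabilities of "real" for real and
    generated images. Real targets are smoothed to real_target, fake
    targets stay at 0 (the one-sided part).
    """
    d_real = np.clip(d_real, eps, 1.0 - eps)
    d_fake = np.clip(d_fake, eps, 1.0 - eps)
    loss_real = -(real_target * np.log(d_real)
                  + (1.0 - real_target) * np.log(1.0 - d_real))
    loss_fake = -np.log(1.0 - d_fake)
    return float(np.mean(loss_real) + np.mean(loss_fake))
```

With smoothed targets, the real-image term is smallest when D outputs `real_target` rather than 1, so D is no longer rewarded for extreme confidence on real images. In an actual training loop both losses would be written in an autodiff framework so that gradients can flow back to G and D.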
* Virtual Batch Normalization (VBN)
  * Usually BN normalizes each example with respect to the other examples in the same batch.
  * They instead normalize each example with respect to the examples in a reference batch, which was picked once at the start of the training.
  * VBN is intended to reduce the dependence of each example on the other examples in the batch.
  * VBN is computationally expensive, because it requires forwarding of two minibatches.
  * They use VBN for their G.
* Inception Scoring
  * They introduce a new scoring method for GAN results.
  * Their method is based on feeding the generated images through another network, here they use Inception.
  * For an image `x` and predicted classes `y` (softmax-output of Inception):
    * They argue that they want `p(y|x)` to have low entropy, i.e. the model should be rather certain of seeing a class (or few classes) in the image.
    * They argue that they want `p(y)` to have high entropy, i.e. the predicted classes (and therefore image contents) should have high diversity. (This seems like something that is quite a bit dependent on the used dataset?)
  * They combine both measurements to the final score of `exp(E_x[KL(p(y|x) || p(y))])`, where the KL term for an image `xi` is `<sum over classes y> p(y|xi) * (log(p(y|xi)) - log(p(y)))` and the expectation is the mean over images.
    * `p(y)` can be approximated as the mean of the softmax-outputs over many examples.
    * Relevant python code that they use (where `part` seems to be of shape `(batch size, number of classes)`, i.e. the softmax outputs): `kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0))); kl = np.mean(np.sum(kl, 1)); scores.append(np.exp(kl));`
  * They average this score over 50,000 generated images.
* Semi-supervised Learning
  * For a dataset with K classes they extend D by K outputs (leading to K+1 outputs total).
  * They then optimize two loss functions jointly:
    * Unsupervised loss: The classic GAN loss, i.e. D has to predict the fake/real output correctly. (The other outputs seem to not influence this loss.)
    * Supervised loss: D must correctly predict the image's class label, if it happens to be a real image and if it was annotated with a class.
  * They note that training G with feature matching produces the best results for semi-supervised classification.
  * They note that training G with minibatch discrimination produces significantly worse results for semi-supervised classification. (But visually the samples look better.)
  * They note that using semi-supervised learning overall results in higher image quality than not using it. They speculate that this has to do with the class labels containing information about image statistics that are important to humans.

### Results

* MNIST
  * They use weight normalization and white noise in D.
  * Samples of high visual quality when using minibatch discrimination with semi-supervised learning.
  * Very good results in semi-supervised learning when using feature matching.
  * Using feature matching decreases visual quality of generated images, but improves results of semi-supervised learning.
* CIFAR-10
  * D: 9-layer CNN with dropout, weight normalization.
  * G: 4-layer CNN with batch normalization (so no VBN?).
  * Visually very good generated samples when using minibatch discrimination with semi-supervised learning. (Probably new record quality.)
    * Note: No comparison with nearest neighbours from the dataset.
  * When using feature matching the results are visually not as good.
  * Again, very good results in semi-supervised learning when using feature matching.
* SVHN
  * Same setup as in CIFAR-10 and similar results.
* ImageNet
  * They tried to generate 128x128 images and compared to DCGAN.
  * They improved from "total garbage" to "garbage" (they now hit some textures, but structure is still wildly off).

*Figure (CIFAR-10 Examples): Generated CIFAR-10-like images (with minibatch discrimination and semi-supervised learning).*
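The Inception score computation quoted in the summary can be wrapped into a small self-contained function. The sketch below follows the quoted one-liner (per-image KL divergence to the mean prediction, averaged over images, then exponentiated). The function name `inception_score` and the `eps` constant, added to avoid `log(0)`, are assumptions for this example. It takes precomputed softmax outputs and does not run the Inception network itself, and it leaves out the per-split `scores.append(...)` bookkeeping that the quoted code hints at.

```python
import numpy as np


def inception_score(probs, eps=1e-12):
    """Inception score from softmax outputs.

    probs: array of shape (number of images, number of classes), each row
    being p(y|x) for one generated image. eps is a small constant
    (an assumption of this sketch) that guards against log(0).
    """
    probs = np.asarray(probs, dtype=np.float64)
    # p(y): mean prediction over all images, kept 2-D for broadcasting.
    marginal = np.mean(probs, axis=0, keepdims=True)
    # Per-image, per-class contribution to KL(p(y|x) || p(y)).
    kl = probs * (np.log(probs + eps) - np.log(marginal + eps))
    # Sum over classes, average over images, then exponentiate.
    return float(np.exp(np.mean(np.sum(kl, axis=1))))
```

Two limiting cases illustrate the two entropy arguments: if every image receives the same prediction, `p(y|x)` equals `p(y)` and the score is 1. If each image is assigned confidently to one class and the classes are used equally often, the score approaches the number of classes.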

Summary by Udibr 3 years ago
I love the format of this summary. Thanks! The historical averaging idea is interesting. This is basically just a momentum update rule right?

That's what I understood (I don't have first-hand experience).

In minibatch discrimination, we have these $M$ matrices by multiplying with the $T$ tensor. What is the $T$ tensor? In the code it looks like you initialise it like a weight matrix, which means you learn it?
