Generative Adversarial Networks
April 13, 2018
Transcript of the video
Many of the examples we have seen in this series so far were discriminative models. Sometimes we need models that can not only recognize what's in the input but also generate new samples. Such models are called generative models.
A discriminative model, such as a model that recognizes cat and dog pictures, wouldn't be able to generate a cat picture, although it might have a notion of what a cat looks like.
A recently proposed approach, called generative adversarial networks, has shown great success in building models that can generate samples that are similar to the ones in a given dataset.
Generative adversarial networks, or GANs for short, consist of two components that compete against each other during training. One network tries to generate realistic samples that have never been seen before. The other one tries to tell whether its inputs are real or fake. The first network is called the generator, and the second one is called the discriminator.
This setting draws inspiration from game theory, where two agents play a game against each other. The goal of the generator is to fool the discriminator, and the goal of the discriminator is not to be fooled.
At each training step, the generator takes a code vector as input and generates a sample. The code vector can be a vector of random values; this randomness prevents the model from producing the same sample every single time. The discriminator gets either a randomly sampled image that actually comes from the dataset or a synthetic sample produced by the generator. Given this input, it outputs the probability that the input is a real training example rather than a fake sample generated by the model. As both networks are trained in this setting, the generator and the discriminator each get better at achieving their adversarial goals.
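To make the flow of one training step concrete, here is a minimal PyTorch-style sketch. The network sizes, the 100-dimensional code vector, and the flattened 28x28 images are placeholder assumptions for illustration, not details from the video.

```python
import torch
import torch.nn as nn

# Placeholder sizes (assumptions for illustration): a 100-dim code vector
# and flattened 28x28 grayscale images.
CODE_DIM, IMG_DIM = 100, 28 * 28

# Generator: code vector -> synthetic sample.
G = nn.Sequential(nn.Linear(CODE_DIM, 256), nn.ReLU(),
                  nn.Linear(256, IMG_DIM), nn.Tanh())

# Discriminator: sample -> probability that it is a real training example.
D = nn.Sequential(nn.Linear(IMG_DIM, 256), nn.ReLU(),
                  nn.Linear(256, 1), nn.Sigmoid())

batch_size = 16
real_images = torch.rand(batch_size, IMG_DIM)   # stand-in for a batch from the dataset

z = torch.randn(batch_size, CODE_DIM)           # random code vectors
fake_images = G(z)                              # generator produces synthetic samples

p_real = D(real_images)                         # probability the real batch is real
p_fake = D(fake_images)                         # probability the generated batch is real
```

In a full training loop, the discriminator's parameters would be updated using both p_real and p_fake, while the generator's parameters would be updated through p_fake.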
In the beginning, both the generator and the discriminator perform poorly: the generator produces very unrealistic samples, yet the discriminator still fails to catch them. The two networks then improve together, each learning how to beat a stronger opponent. Ideally, they would reach an equilibrium where the generator produces perfectly realistic samples and the discriminator, unable to tell real from fake, always outputs a probability of 0.5. Once training is complete, you can throw away the discriminator and use the generator to generate new samples.
The generator and discriminator can be any differentiable functions, such as feedforward convolutional neural networks. Earlier, we talked about how convolutional neural networks can be used as feature extractors, where a model takes an image as input and outputs a code vector that summarizes it. This code vector usually encodes the essence of the input while some of the details get lost. What a generator network does is roughly the opposite: given a code vector, which might simply be sampled from a simple random distribution, it tries to generate a representative sample, like an image. The generator needs to learn to add realistic details to be able to fool the discriminator.
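One common way to build such a generator is to stack transposed convolutions that progressively upsample the code vector into an image. The sketch below is an illustrative DCGAN-style layout; the specific layer sizes and the 64x64 output resolution are assumptions, not an architecture described in the video.

```python
import torch
import torch.nn as nn

# Illustrative generator (sizes are assumptions): it maps a 100-dim code
# vector to a 3x64x64 image by repeatedly upsampling with transposed
# convolutions, the rough "opposite" of a convolutional encoder.
generator = nn.Sequential(
    nn.ConvTranspose2d(100, 256, kernel_size=4, stride=1, padding=0),  # 1x1 -> 4x4
    nn.BatchNorm2d(256), nn.ReLU(),
    nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),  # 4x4 -> 8x8
    nn.BatchNorm2d(128), nn.ReLU(),
    nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),   # 8x8 -> 16x16
    nn.BatchNorm2d(64), nn.ReLU(),
    nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),    # 16x16 -> 32x32
    nn.BatchNorm2d(32), nn.ReLU(),
    nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),     # 32x32 -> 64x64
    nn.Tanh(),
)

z = torch.randn(8, 100, 1, 1)   # a batch of random code vectors
images = generator(z)           # shape: (8, 3, 64, 64)
```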
You might wonder why we need the discriminator network at all. Why not train a generator directly to minimize the mean squared error between the generated and real images? Optimizing a model for mean squared error focuses on making a large number of pixels similar to the real images on average. This can result in ignoring important details simply because they make up only a small portion of the input.
One interesting thing about adversarial training is that the model is forced to work harder on the areas where it's failing. The discriminator constantly looks for the generator's weaknesses in order to catch its fake samples, and the generator is forced to address those weaknesses. Once they are addressed, the discriminator looks for other weaknesses, and so on.
In generative adversarial networks, both the generator and discriminator functions try to minimize their own loss functions. The discriminator loss can be a cross-entropy function that defines the classification error. A very simple generator loss would be just the negation of the discriminator loss. When we optimize this model, essentially, the discriminator function tries to minimize the classification error, while the generator tries to maximize it. There are other, more stable ways to do this, but this is the basic idea.
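As a rough sketch of these losses: assume p_real and p_fake are the discriminator's outputs on a real batch and a generated batch, as in the earlier snippet. The discriminator loss below is the standard binary cross-entropy, and the generator loss is simply its negation; practical implementations often prefer a more stable non-saturating variant that maximizes log D(G(z)) instead.

```python
import torch

def discriminator_loss(p_real, p_fake):
    # Cross-entropy classification error: real samples should be scored 1,
    # generated samples should be scored 0.
    return -(torch.log(p_real) + torch.log(1.0 - p_fake)).mean()

def generator_loss(p_real, p_fake):
    # Simplest adversarial generator loss: the negation of the discriminator's
    # loss, so the generator maximizes the discriminator's classification error.
    return -discriminator_loss(p_real, p_fake)

# p_real / p_fake stand for discriminator outputs on real and generated batches.
p_real = torch.tensor([0.9, 0.8])
p_fake = torch.tensor([0.2, 0.3])
print(discriminator_loss(p_real, p_fake), generator_loss(p_real, p_fake))
```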
GANs can be used anywhere we need to generate data, particularly images and video. They can generate photorealistic images that look like the ones in a training set yet are not identical to any of them, such as photos of celebrities who don't exist. They can also learn to transform data, such as converting text into pictures or translating pictures into other kinds of pictures. You can find more information about these applications in the description below.
Let's take a look at one of these applications: image-to-image translation. CycleGAN is a fairly recent model architecture that translates images from one domain to another without requiring a dataset of image pairs with one-to-one correspondence. It has been shown to translate between Monet paintings and photographs, horses and zebras, and summer and winter scenes, among many other pairs of domains.
CycleGAN is based on the idea that if we translate an image from domain A to domain B and then back from domain B to domain A, the result should be similar to the original input. During training, the model minimizes the distance between the original and the reconstructed input. In addition, a discriminator is trained to tell whether the translated image actually belongs to domain B or not. The same thing is also done in the opposite direction. This cycle consistency loss prevents the model from ignoring the input and generating an arbitrary image from domain B.
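Here is a schematic sketch of the cycle-consistency term under stated assumptions: the two translators G_ab and G_ba are stand-in linear modules purely for illustration (real CycleGAN generators are convolutional image-to-image networks), and the adversarial discriminator losses are omitted.

```python
import torch
import torch.nn as nn

# Stand-in translators (assumptions for illustration).
G_ab = nn.Linear(64, 64)   # domain A -> domain B
G_ba = nn.Linear(64, 64)   # domain B -> domain A

def cycle_consistency_loss(real_a, real_b):
    # Translate A -> B -> A and B -> A -> B, then penalize the distance
    # between each original input and its reconstruction.
    rec_a = G_ba(G_ab(real_a))
    rec_b = G_ab(G_ba(real_b))
    return (rec_a - real_a).abs().mean() + (rec_b - real_b).abs().mean()

real_a = torch.randn(4, 64)   # a batch from domain A
real_b = torch.randn(4, 64)   # a batch from domain B
loss = cycle_consistency_loss(real_a, real_b)
```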
That's all for today. The next video will focus on practical methodology. We will talk about how to test, debug, and improve models.
Thanks for watching, stay tuned, and see you next time.