Generator

A short article on GANs (generative adversarial networks) and how they work. This is a high-level overview that focuses on intuition rather than code or mathematical rigor. Some familiarity with tensors (matrices) and neural networks is assumed.


Definitions - from Merriam-Webster

Generative - Having the power or function of generating, originating, producing, or reproducing.

Adversarial (Adversary) - One that contends with, opposes, or resists: an enemy or opponent.

Net (Network) - A computer architecture in which a number of processors are interconnected in a manner suggestive of the connections between neurons in a human brain and which is able to learn by a process of trial and error.

Put together, a GAN is a neural network system that is able to generate synthetic data. It does this by training two neural networks that learn from, and compete against, each other; hence "adversarial".

The defending champion

Neural networks are commonly used to map inputs to outputs. In MNIST digit classification, for example, the input is an image of a handwritten digit and the output is the digit it shows, and the network learns this mapping. This is known as discriminative modelling, because the model discriminates between different classes of data. In a GAN, a model of this kind plays the role of the discriminator.

Fig 1: MNIST Classification
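To make the input-to-output mapping concrete, here is a minimal sketch of a discriminative model in plain Python: a single linear layer followed by a softmax, which turns a flattened 28x28 image into ten probabilities, one per digit. The random weights, the random stand-in image, and the function names are assumptions for illustration; a real MNIST classifier would be a trained and usually deeper network.

```python
import math
import random

NUM_PIXELS = 28 * 28  # MNIST images are 28x28 grayscale
NUM_CLASSES = 10      # digits 0-9


def softmax(scores):
    """Turn raw scores into probabilities that sum to 1."""
    top = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def classify(pixels, weights, biases):
    """Linear discriminative model: image pixels in, one probability per digit out."""
    scores = [
        sum(w * p for w, p in zip(weights[k], pixels)) + biases[k]
        for k in range(NUM_CLASSES)
    ]
    return softmax(scores)


# Untrained (random) parameters; training would adjust these so that the
# highest probability lands on the correct digit.
rng = random.Random(0)
weights = [[rng.gauss(0, 0.01) for _ in range(NUM_PIXELS)] for _ in range(NUM_CLASSES)]
biases = [0.0] * NUM_CLASSES

image = [rng.random() for _ in range(NUM_PIXELS)]  # stand-in for a flattened MNIST image
probs = classify(image, weights, biases)
predicted_digit = probs.index(max(probs))
```

A GAN's discriminator works the same way, except that it outputs a single probability (real or fake) instead of one per digit, typically through a sigmoid rather than a softmax.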


But what if…

What if you wanted the reverse: generating new images of handwritten digits instead?

Note that although the arrow in the figure points from the digits to the images, this explanation covers unconditional GANs, which do not take a digit as input. Conditional GANs are mentioned further below.

Fig 2: Reverse MNIST

The new challenger

Introducing: the Generator

The role of the generator is to create new images from an input, which in this case is a random noise tensor. This unconditional setup, mentioned above, is often called a vanilla GAN: it has no condition or constraint, such as which digit to draw. Other GANs take an additional input so that a digit of your choice can be generated; these are commonly called conditional GANs, or cGANs for short.
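The difference between the two kinds of generator input can be sketched in a few lines of Python. This is a minimal sketch under stated assumptions: the 100-dimensional standard-normal noise vector is a common but arbitrary choice, the function names are hypothetical, and appending a one-hot label to the noise is only one common way of conditioning a generator (real cGANs may embed the label differently, and usually give it to the discriminator as well).

```python
import random

LATENT_DIM = 100   # size of the noise vector (a common but arbitrary choice)
NUM_CLASSES = 10   # digits 0-9


def sample_noise(rng, dim=LATENT_DIM):
    """Random noise vector drawn from a standard normal distribution."""
    return [rng.gauss(0.0, 1.0) for _ in range(dim)]


def one_hot(digit, num_classes=NUM_CLASSES):
    """Encode the requested digit as a one-hot vector."""
    vec = [0.0] * num_classes
    vec[digit] = 1.0
    return vec


def vanilla_generator_input(rng):
    # Vanilla (unconditional) GAN: the generator sees only noise,
    # so you cannot choose which digit comes out.
    return sample_noise(rng)


def conditional_generator_input(rng, digit):
    # Conditional GAN: the label is appended to the noise, so the
    # generator can learn to draw the requested digit.
    return sample_noise(rng) + one_hot(digit)


rng = random.Random(42)
z = vanilla_generator_input(rng)
z_cond = conditional_generator_input(rng, digit=7)
```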


Round 1 begins

Fig 3: GAN architecture

  1. The generator goes first: it takes a random input and tries to generate a handwritten digit. It hasn’t been trained yet, so expect the output to be very poor.
  2. A real image is passed into the discriminator, which guesses whether it is real (it is). The discriminator returns a probability, real_proba.
  3. The generated image is then passed into the discriminator too; this time it should guess that the image is fake. The probability returned here is fake_proba.
  4. The generator loss is the binary cross-entropy (BCE) loss between fake_proba and a tensor of ones (the label for "real"). Remember, the generator's goal is to produce images that the discriminator accepts as real, which is why its loss is measured against 1.
  5. The discriminator loss is the sum of two BCE losses: real_proba measured against 1, and fake_proba measured against 0. The discriminator hasn’t been trained either, so expect this loss to be high at first.
  6. Finally, the losses are backpropagated. The generator's weights are updated using the generator loss, and the discriminator's weights using the discriminator loss. The generator loss flows back through the discriminator to reach the generator, but only the generator's weights are changed by it.
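Steps 1 to 6 can be traced in a deliberately tiny, hypothetical example in plain Python. To keep it self-contained, the "images" are single numbers, the generator is G(z) = a*z + b, the discriminator is a logistic regression D(x) = sigmoid(w*x + c), and backpropagation is replaced by hand-derived gradients with plain gradient descent. The class and parameter names are assumptions for illustration; a real GAN would use deep networks, mini-batches, and an autodiff framework.

```python
import math
import random


def sigmoid(s):
    return 1.0 / (1.0 + math.exp(-s))


def bce(target, proba, eps=1e-12):
    """Binary cross-entropy between a 0/1 target and a predicted probability."""
    proba = min(max(proba, eps), 1.0 - eps)  # clamp to avoid log(0)
    return -(target * math.log(proba) + (1 - target) * math.log(1 - proba))


class ToyGAN:
    """Hypothetical 1-D GAN: G(z) = a*z + b and D(x) = sigmoid(w*x + c)."""

    def __init__(self):
        self.a, self.b = 1.0, 0.0  # generator parameters
        self.w, self.c = 0.1, 0.0  # discriminator parameters

    def generate(self, z):
        return self.a * z + self.b

    def discriminate(self, x):
        return sigmoid(self.w * x + self.c)

    def losses(self, x_real, z):
        x_fake = self.generate(z)                          # step 1
        real_proba = self.discriminate(x_real)             # step 2
        fake_proba = self.discriminate(x_fake)             # step 3
        g_loss = bce(1, fake_proba)                        # step 4: fake vs. 1
        d_loss = bce(1, real_proba) + bce(0, fake_proba)   # step 5
        return g_loss, d_loss

    def train_step(self, x_real, z, lr=0.01):
        x_fake = self.generate(z)
        real_proba = self.discriminate(x_real)
        fake_proba = self.discriminate(x_fake)
        # Step 6, with hand-derived gradients standing in for backpropagation.
        # d_loss = -log D(x_real) - log(1 - D(x_fake)), differentiated w.r.t. w and c.
        grad_w = (real_proba - 1) * x_real + fake_proba * x_fake
        grad_c = (real_proba - 1) + fake_proba
        # g_loss = -log D(G(z)): the gradient passes through D to reach G,
        # but only the generator's parameters a and b are updated with it.
        dloss_dx = (fake_proba - 1) * self.w
        grad_a = dloss_dx * z
        grad_b = dloss_dx
        # Plain gradient descent; all gradients use the pre-update parameters.
        self.w -= lr * grad_w
        self.c -= lr * grad_c
        self.a -= lr * grad_a
        self.b -= lr * grad_b


rng = random.Random(0)
gan = ToyGAN()
for _ in range(5):
    x_real = rng.gauss(4.0, 1.0)  # hypothetical "real data" centred on 4.0
    z = rng.gauss(0.0, 1.0)       # random noise input for the generator
    gan.train_step(x_real, z)
```

With a linear discriminator, this toy can only tell real from fake by where a sample lies on the number line, so it illustrates the mechanics of a training step rather than a working GAN.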

The steps above make up one training step. If you would like to see a code implementation, TensorFlow's DCGAN tutorial (listed in Resources) has one.

When does training end?

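From the generator's point of view, the goal is a fake_proba close to 1, meaning it has successfully tricked the discriminator. However, real_proba and fake_proba both sitting at 1 would mean the discriminator is calling everything real, not that training is done. In the theoretical optimum described in the original GAN paper, the generated images are indistinguishable from real ones and the discriminator can do no better than guess, outputting 0.5 for both. In practice, care must also be taken to ensure that the generator is not simply overpowering the discriminator, which could lead to mode collapse (the generator producing only a few kinds of output) or low-quality outputs. I found Google's advanced machine learning course on GANs very useful for understanding mode collapse; it is listed in Resources below.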
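The numbers behind this can be checked with the BCE losses from steps 4 and 5. The original GAN paper shows that at the theoretical optimum the best the discriminator can do is output 0.5 for every image. The minimal Python sketch below (function names are illustrative, not from any library) plugs that value in and, for comparison, a discriminator that outputs close to 1 for both real and fake images.

```python
import math


def bce(target, proba):
    """Binary cross-entropy between a 0/1 target and a predicted probability."""
    return -(target * math.log(proba) + (1 - target) * math.log(1 - proba))


def gan_losses(real_proba, fake_proba):
    """Per-sample losses as defined in steps 4 and 5 of the training step."""
    generator_loss = bce(1, fake_proba)
    discriminator_loss = bce(1, real_proba) + bce(0, fake_proba)
    return generator_loss, discriminator_loss


# Theoretical optimum: the discriminator can only guess, 0.5 for everything.
# Generator loss is ln 2 (about 0.693), discriminator loss is 2 ln 2 (about 1.386).
g_eq, d_eq = gan_losses(real_proba=0.5, fake_proba=0.5)

# A discriminator that calls everything real gives the generator a tiny loss,
# but its own loss is large, so it will keep changing: not a stable end state.
g_all_real, d_all_real = gan_losses(real_proba=0.99, fake_proba=0.99)
```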

Common problems

Fig 4: Balanced training

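Striking a balance between the generator and discriminator is not easy. Neither should overpower the other, because each depends on the other for a useful training signal: if the discriminator becomes too good, for example, the generator receives very little gradient to learn from. If the balance is lost, training can become unstable.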
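One knob the original GAN paper exposes for this balance is how often each network is updated: its training algorithm runs k discriminator updates for every generator update (the authors used k = 1 in their experiments). The skeleton below is a hypothetical Python sketch of that schedule; `discriminator_step` and `generator_step` are placeholder names standing in for real update code.

```python
from typing import Callable, List


def alternate_training(
    num_iterations: int,
    k: int,
    discriminator_step: Callable[[], None],
    generator_step: Callable[[], None],
) -> None:
    """Run k discriminator updates for every generator update.

    Mirrors the outer loop of Algorithm 1 in the 2014 GAN paper;
    the two step functions are placeholders for real update code.
    """
    for _ in range(num_iterations):
        for _ in range(k):
            discriminator_step()
        generator_step()


# Record the order of updates instead of training anything.
log: List[str] = []
alternate_training(
    num_iterations=2,
    k=3,
    discriminator_step=lambda: log.append("D"),
    generator_step=lambda: log.append("G"),
)
```

Raising k gives the discriminator more updates per generator update. Whether that helps depends on the model, so it is best treated as a hyperparameter to tune rather than a guaranteed fix.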

To quote the training notes from the advanced course:

For a GAN, convergence is often a fleeting, rather than stable, state.

Conclusion

Thank you for reading this far. I hope this article helps you understand GANs better and piques your interest as well. My first GAN experiment was with pix2pix, and I ran into multiple issues, which I will cover when I write an article on it. It was time-consuming and challenging, but the process was very rewarding. If you’re interested, here is my repository.

Resources

TensorFlow DCGAN tutorial

Google Advanced ML Course

Generative Adversarial Networks (Goodfellow et al., 2014)

Fig 1: MNIST Classification

Fig 3: GAN architecture