How to build a simple Generative Adversarial Network (GAN) using Keras?


What is a Generative Adversarial Network?

Have you ever heard about deep learning or neural networks? A generative adversarial network (GAN) consists of two neural networks, a generator and a discriminator, that learn from each other.

The generator tries to generate new samples that resemble the original data. The discriminator has to classify each sample as generated or original (fake or real).

Why GAN?

When you are getting started with deep learning, it can be hard to understand and to visualize. Therefore, I suggest building a simple GAN to see how deep learning works.


All you need is Anaconda with the Keras package installed.

Let's start

First, we import the libraries we need:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import math

from keras.models import Sequential
from keras.layers import Dense, Dropout
from tqdm.notebook import tqdm as tqdm_notebook
from sklearn.preprocessing import StandardScaler

The real data is created from:

x = f(t) = 2π * t (0 <= t < 1)

y = f(x) = sin(x)

def sample_data(batch_size):
  # sample x uniformly in [0, 2*pi) and pair it with sin(x)
  x = 2 * math.pi * np.random.random_sample(batch_size)
  data = []
  for i in x:
    data.append([i, math.sin(i)])
  return np.array(data)
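A quick sanity check on the real data helps here: each batch is a `(batch_size, 2)` array whose second column is the sine of the first. The snippet below restates `sample_data` so it runs on its own; the check itself is just an illustration, not part of the tutorial's pipeline.

```python
import math
import numpy as np

def sample_data(batch_size):
    # sample x uniformly in [0, 2*pi) and pair it with sin(x)
    x = 2 * math.pi * np.random.random_sample(batch_size)
    return np.array([[i, math.sin(i)] for i in x])

batch = sample_data(256)
print(batch.shape)                                    # (256, 2)
print(np.allclose(batch[:, 1], np.sin(batch[:, 0])))  # True
```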

Generate the noise

def create_noise(batch_size):
  return np.random.uniform(-1., 1., size=[batch_size, 2])
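Each noise sample is a 2-D point with both coordinates drawn uniformly from [-1, 1). The small check below (restating `create_noise` so it is self-contained) confirms the shape and range:

```python
import numpy as np

def create_noise(batch_size):
    return np.random.uniform(-1., 1., size=[batch_size, 2])

noise = create_noise(256)
print(noise.shape)                             # (256, 2)
print(noise.min() >= -1.0, noise.max() < 1.0)  # True True
```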

Create a generator with two hidden layers and a discriminator with three hidden layers.

def create_generator():
  model = Sequential([
    Dense(16, input_dim=2, activation='relu'),
    Dense(32, activation='relu'),
    Dense(2)  # output: a 2-D point (x, y)
  ])
  model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
  return model

def create_discriminator():
  model = Sequential([
    Dense(256, input_dim=2, activation='relu'),
    Dense(128, activation='relu'),
    Dense(64, activation='relu'),
    Dense(1, activation='sigmoid')  # probability that the input is real
  ])
  model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
  return model

Define GAN with generator and discriminator combined.

def create_gan(generator, discriminator):
  # freeze the discriminator so only the generator is updated
  # when training through the combined model
  discriminator.trainable = False
  model = Sequential([generator, discriminator])
  model.compile(loss='binary_crossentropy', optimizer='adam')
  return model

We train for 10,000 epochs with a batch size of 256. Every 100 epochs, we plot a graph to see the result.


epochs = 10000
batch_size = 256
plot_step = 100

generator = create_generator()
discriminator = create_discriminator()
gan = create_gan(generator, discriminator)

for i in tqdm_notebook(range(epochs)):
  # get real batch
  real_batch = sample_data(batch_size)
  # fake batch
  random_noise = create_noise(batch_size)
  fake_batch = generator.predict(random_noise)
  # plot intermediate results (assumes a 'plots/' directory exists)
  if i % plot_step == 0:
    plt.clf()  # clear the previous scatter so plots don't accumulate
    plt.scatter(real_batch[:,0], real_batch[:,1], color='g')
    plt.scatter(fake_batch[:,0], fake_batch[:,1], color='r')
    plt.title('Epoch: %04d' % i)
    plt.xlim([-1, 7])
    plt.ylim([-1.5, 1.5])
    plt.savefig('plots/epoch-%04d.png' % i)
  # merge batch
  batch = np.concatenate((real_batch, fake_batch))
  y = np.concatenate((np.ones(batch_size), np.zeros(batch_size)))
  # train discriminator
  discriminator.train_on_batch(batch, y)
  # train generator
  gan.train_on_batch(random_noise, np.ones(batch_size))
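The labeling inside the loop is the heart of the adversarial setup, so here it is in isolation, with NumPy placeholder arrays standing in for `sample_data` and the generator's output (no Keras needed): real samples get label 1, generated samples get label 0, and the generator is later trained through the combined model against all-ones labels so that it learns to fool the discriminator.

```python
import numpy as np

batch_size = 4
real_batch = np.random.rand(batch_size, 2)  # stand-in for sample_data(batch_size)
fake_batch = np.random.rand(batch_size, 2)  # stand-in for generator.predict(noise)

# merge real and fake samples into one discriminator training batch
batch = np.concatenate((real_batch, fake_batch))
# real -> 1, fake -> 0
y = np.concatenate((np.ones(batch_size), np.zeros(batch_size)))

print(batch.shape)  # (8, 2)
print(y)            # [1. 1. 1. 1. 0. 0. 0. 0.]
```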

Wait and see the result:

As you can see, the GAN did not generate new data covering the whole original curve. It found a few points that always fool the discriminator, a failure known as mode collapse. There is a lot of work left to do for a better GAN.


You can find the notebook at:

This article presented a simple GAN, not an optimized one. You can improve it or move on to another GAN project of your own.

Thanks for reading.

By: Anh Hao