20.2  Deep Convolutional Generative Adversarial Networks

In Section 20.1, we introduced the basic ideas behind how GANs work. We showed that they can draw samples from some simple, easy-to-sample distribution, like a uniform or normal distribution, and transform them into samples that appear to match the distribution of some dataset. And while our example of matching a 2D Gaussian distribution got the point across, it is not especially exciting.

In this section, we will demonstrate how you can use GANs to generate photorealistic images. We will base our models on the deep convolutional GANs (DCGAN) introduced by Radford et al. (2015). We will borrow the convolutional architectures that have proven so successful for discriminative computer vision problems and show how, via GANs, they can be leveraged to generate photorealistic images.

# PyTorch
from d2l import torch as d2l
import torch
import torchvision
from torch import nn
import warnings

# TensorFlow
from d2l import tensorflow as d2l
import tensorflow as tf

# JAX
%matplotlib inline
from d2l import jax as d2l
import jax
from jax import numpy as jnp
from flax import linen as nn
import optax
import numpy as np
from PIL import Image
import tensorflow as tf
import os

# MXNet
from mxnet import gluon, init, np, npx
from mxnet.gluon import nn
from d2l import mxnet as d2l

npx.set_np()

20.2.1 The Pokemon Dataset

The dataset we will use is a collection of Pokemon sprites obtained from pokemondb. First download, extract and load this dataset.

# PyTorch
d2l.DATA_HUB['pokemon'] = (d2l.DATA_URL + 'pokemon.zip',
                           'c065c0e2593b8b161a2d7873e42418bf6a21106c')

data_dir = d2l.download_extract('pokemon')
pokemon = torchvision.datasets.ImageFolder(data_dir)

# TensorFlow
d2l.DATA_HUB['pokemon'] = (d2l.DATA_URL + 'pokemon.zip',
                           'c065c0e2593b8b161a2d7873e42418bf6a21106c')

data_dir = d2l.download_extract('pokemon')
batch_size = 256
pokemon = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir, batch_size=batch_size, image_size=(64, 64))
Found 40597 files belonging to 721 classes.

# JAX
d2l.DATA_HUB['pokemon'] = (d2l.DATA_URL + 'pokemon.zip',
                           'c065c0e2593b8b161a2d7873e42418bf6a21106c')

data_dir = d2l.download_extract('pokemon')

# MXNet
d2l.DATA_HUB['pokemon'] = (d2l.DATA_URL + 'pokemon.zip',
                           'c065c0e2593b8b161a2d7873e42418bf6a21106c')

data_dir = d2l.download_extract('pokemon')
pokemon = gluon.data.vision.datasets.ImageFolderDataset(data_dir)

We resize each image to \(64\times 64\). The ToTensor transformation maps pixel values into \([0, 1]\), while our generator will use the tanh function to produce outputs in \([-1, 1]\). We therefore normalize the data with mean \(0.5\) and standard deviation \(0.5\) to match this value range.

# PyTorch
batch_size = 256
transformer = torchvision.transforms.Compose([
    torchvision.transforms.Resize((64, 64)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(0.5, 0.5)
])
pokemon.transform = transformer
data_iter = torch.utils.data.DataLoader(
    pokemon, batch_size=batch_size,
    shuffle=True, num_workers=d2l.get_dataloader_workers())

# TensorFlow
def transform_func(X):
    X = X / 255.
    X = (X - 0.5) / (0.5)
    return X

# For TF>=2.4 use `num_parallel_calls = tf.data.AUTOTUNE`
data_iter = pokemon.map(lambda x, y: (transform_func(x), y),
                        num_parallel_calls=tf.data.experimental.AUTOTUNE)
data_iter = data_iter.cache().shuffle(buffer_size=1000).prefetch(
    buffer_size=tf.data.experimental.AUTOTUNE)

# JAX
batch_size = 256

# Load all Pokemon images via PIL, resize to 64x64, normalise to [-1, 1]
_all_images = []
for root, dirs, files in os.walk(data_dir):
    for fname in sorted(files):
        if fname.lower().endswith(('.png', '.jpg', '.jpeg')):
            img = Image.open(os.path.join(root, fname)).convert('RGB')
            img = img.resize((64, 64))
            arr = np.array(img, dtype=np.float32) / 255.0
            arr = (arr - 0.5) / 0.5  # normalise to [-1, 1]
            _all_images.append(arr)  # (H, W, C)

_all_images = np.stack(_all_images)  # (N, H, W, C)
_all_labels = np.zeros(len(_all_images), dtype=np.int32)  # dummy labels

data_iter = d2l.load_array((_all_images, _all_labels), batch_size,
                           is_train=True)

# MXNet
batch_size = 256
transformer = gluon.data.vision.transforms.Compose([
    gluon.data.vision.transforms.Resize(64),
    gluon.data.vision.transforms.ToTensor(),
    gluon.data.vision.transforms.Normalize(0.5, 0.5)
])
data_iter = gluon.data.DataLoader(
    pokemon.transform_first(transformer), batch_size=batch_size,
    shuffle=True, num_workers=d2l.get_dataloader_workers())
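As a quick check of the value ranges described above, the following NumPy sketch mimics the preprocessing on raw 8-bit pixel values: it divides by 255 (the scaling that `ToTensor` performs) and then computes `(x - 0.5) / 0.5` (what `Normalize(0.5, 0.5)` applies per channel). It also applies the inverse `x / 2 + 0.5` that the display code below uses. The helper names `to_model_range` and `to_display_range` are hypothetical, chosen only for illustration.

```python
import numpy as np

def to_model_range(pixels):
    """Map uint8 pixels in [0, 255] to floats in [-1, 1]."""
    x = pixels.astype(np.float32) / 255.0  # ToTensor-style scaling to [0, 1]
    return (x - 0.5) / 0.5                 # Normalize(mean=0.5, std=0.5)

def to_display_range(x):
    """Invert the normalization: map [-1, 1] back to [0, 1] for plotting."""
    return x / 2 + 0.5

pixels = np.array([0, 128, 255], dtype=np.uint8)
x = to_model_range(pixels)
print(x)
print(to_display_range(x))
```

The endpoints 0 and 255 land exactly on -1 and 1, the range of the generator's tanh output.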

Let’s visualize the first 20 images.

warnings.filterwarnings('ignore')
d2l.set_figsize((4, 4))
for X, y in data_iter:
    imgs = X[:20,:,:,:].permute(0, 2, 3, 1)/2+0.5
    d2l.show_images(imgs, num_rows=4, num_cols=5)
    break

d2l.set_figsize(figsize=(4, 4))
for X, y in data_iter.take(1):
    imgs = X[:20, :, :, :] / 2 + 0.5
    d2l.show_images(imgs, num_rows=4, num_cols=5)

d2l.set_figsize((4, 4))
for batch in data_iter:
    X = np.array(batch[0])  # (N, H, W, C), values in [-1, 1]
    imgs = X[:20] / 2 + 0.5
    d2l.show_images(imgs, num_rows=4, num_cols=5)
    break

d2l.set_figsize((4, 4))
for X, y in data_iter:
    imgs = X[:20,:,:,:].transpose(0, 2, 3, 1)/2+0.5
    d2l.show_images(imgs, num_rows=4, num_cols=5)
    break
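The display code above rescales images from \([-1, 1]\) back to \([0, 1]\) and, for the channel-first PyTorch and MXNet versions, moves the channel axis last, since the plotting helpers expect (height, width, channel) arrays. The NumPy sketch below reproduces that bookkeeping on random data, together with the \(3 \times 7\) tiling of 21 images that the training loop later in this section performs with nested concatenations. The array sizes and random inputs are illustrative assumptions, not real Pokemon images.

```python
import numpy as np

rng = np.random.default_rng(0)
# A fake batch in channel-first layout (N, C, H, W) with values in [-1, 1]
batch = rng.uniform(-1, 1, size=(21, 3, 64, 64)).astype(np.float32)

# Move channels last and rescale to [0, 1] for display
imgs = batch.transpose(0, 2, 3, 1) / 2 + 0.5  # (21, 64, 64, 3)

# Tile into a 3 x 7 grid: join 7 images along the width, then stack the rows
rows = [np.concatenate([imgs[i * 7 + j] for j in range(7)], axis=1)
        for i in range(len(imgs) // 7)]
grid = np.concatenate(rows, axis=0)  # (3 * 64, 7 * 64, 3)
print(imgs.shape, grid.shape)  # (21, 64, 64, 3) (192, 448, 3)
```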


20.2.2 The Generator

The generator needs to map the noise variable \(\mathbf z\in\mathbb R^d\), a length-\(d\) vector, to an RGB image of size \(64\times 64\). In Section 14.11 we introduced the fully convolutional network, which uses a transposed convolution layer (refer to Section 14.10) to enlarge the input size. The basic block of the generator consists of a transposed convolution layer followed by batch normalization and a ReLU activation.

class G_block(nn.Module):
    def __init__(self, out_channels, in_channels=3, kernel_size=4, strides=2,
                 padding=1, **kwargs):
        super(G_block, self).__init__(**kwargs)
        self.conv2d_trans = nn.ConvTranspose2d(in_channels, out_channels,
                                kernel_size, strides, padding, bias=False)
        self.batch_norm = nn.BatchNorm2d(out_channels)
        self.activation = nn.ReLU()

    def forward(self, X):
        return self.activation(self.batch_norm(self.conv2d_trans(X)))
class G_block(tf.keras.layers.Layer):
    def __init__(self, out_channels, kernel_size=4, strides=2, padding="same",
                 **kwargs):
        super().__init__(**kwargs)
        self.conv2d_trans = tf.keras.layers.Conv2DTranspose(
            out_channels, kernel_size, strides, padding, use_bias=False)
        self.batch_norm = tf.keras.layers.BatchNormalization()
        self.activation = tf.keras.layers.ReLU()

    def call(self, X):
        return self.activation(self.batch_norm(self.conv2d_trans(X)))
class G_block(nn.Module):
    out_channels: int
    kernel_size: int = 4
    strides: int = 2
    padding: str = 'SAME'
    use_running_average: bool = False

    @nn.compact
    def __call__(self, X):
        X = nn.ConvTranspose(
            self.out_channels, kernel_size=(self.kernel_size, self.kernel_size),
            strides=(self.strides, self.strides), padding=self.padding,
            use_bias=False,
            kernel_init=nn.initializers.normal(0.02))(X)
        X = nn.BatchNorm(
            use_running_average=self.use_running_average)(X)
        X = nn.relu(X)
        return X
class G_block(nn.Block):
    def __init__(self, channels, kernel_size=4,
                 strides=2, padding=1, **kwargs):
        super(G_block, self).__init__(**kwargs)
        self.conv2d_trans = nn.Conv2DTranspose(
            channels, kernel_size, strides, padding, use_bias=False)
        self.batch_norm = nn.BatchNorm()
        self.activation = nn.Activation('relu')

    def forward(self, X):
        return self.activation(self.batch_norm(self.conv2d_trans(X)))

By default, the transposed convolution layer uses a \(k_h = k_w = 4\) kernel, \(s_h = s_w = 2\) strides, and \(p_h = p_w = 1\) padding. (The TensorFlow and JAX versions specify "same" padding instead, which for these settings produces the same output size.) Given an input of shape \(n_h \times n_w = 16 \times 16\), the generator block doubles the input's width and height:

\[ \begin{aligned} n_h^{'} \times n_w^{'} &= [n_h k_h - (n_h-1)(k_h-s_h) - 2p_h] \times [n_w k_w - (n_w-1)(k_w-s_w) - 2p_w]\\ &= [k_h + s_h (n_h-1) - 2p_h] \times [k_w + s_w (n_w-1) - 2p_w]\\ &= [4 + 2 \times (16-1) - 2 \times 1] \times [4 + 2 \times (16-1) - 2 \times 1]\\ &= 32 \times 32 .\\ \end{aligned} \]

x = torch.zeros((2, 3, 16, 16))
g_blk = G_block(20)
g_blk(x).shape
torch.Size([2, 20, 32, 32])
x = tf.zeros((2, 16, 16, 3))  # Channel last convention
g_blk = G_block(20)
g_blk(x).shape
TensorShape([2, 32, 32, 20])
x = jnp.zeros((2, 16, 16, 3))  # Channel last convention
g_blk = G_block(out_channels=20)
params = g_blk.init(jax.random.PRNGKey(0), x)
g_blk.apply(params, x, mutable=['batch_stats'])[0].shape
(2, 32, 32, 20)
x = np.zeros((2, 3, 16, 16))
g_blk = G_block(20)
g_blk.initialize()
g_blk(x).shape
(2, 20, 32, 32)
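The arithmetic in the equation above can be written as a small helper for one spatial dimension. This is a sketch of the formula \(k + s(n-1) - 2p\) only (the name `trans_conv_out` is ours), assuming no output padding and no dilation. It confirms the \(16 \to 32\) example and that the default block always doubles its input, since \(4 + 2(n-1) - 2 = 2n\).

```python
def trans_conv_out(n, k, s, p):
    """Output length of a transposed convolution along one dimension,
    n' = k + s * (n - 1) - 2 * p, assuming no output padding or dilation."""
    return k + s * (n - 1) - 2 * p

print(trans_conv_out(16, k=4, s=2, p=1))  # 32
# With k=4, s=2, p=1 the output is always twice the input: 4 + 2(n-1) - 2 = 2n
print(all(trans_conv_out(n, k=4, s=2, p=1) == 2 * n for n in range(1, 65)))  # True
```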

If we change the transposed convolution layer to a \(4\times 4\) kernel, \(1\times 1\) strides, and zero padding, then an input of size \(1 \times 1\) yields an output whose width and height are each increased by 3, i.e., a \(4 \times 4\) output.

x = torch.zeros((2, 3, 1, 1))
g_blk = G_block(20, strides=1, padding=0)
g_blk(x).shape
torch.Size([2, 20, 4, 4])
x = tf.zeros((2, 1, 1, 3))
# `padding="valid"` corresponds to no padding
g_blk = G_block(20, strides=1, padding="valid")
g_blk(x).shape
TensorShape([2, 4, 4, 20])
x = jnp.zeros((2, 1, 1, 3))
# `padding="VALID"` corresponds to no padding
g_blk = G_block(out_channels=20, strides=1, padding='VALID')
params = g_blk.init(jax.random.PRNGKey(0), x)
g_blk.apply(params, x, mutable=['batch_stats'])[0].shape
(2, 4, 4, 20)
x = np.zeros((2, 3, 1, 1))
g_blk = G_block(20, strides=1, padding=0)
g_blk.initialize()
g_blk(x).shape
(2, 20, 4, 4)
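To see where the extra 3 comes from, recall that a transposed convolution scatters a copy of the kernel, scaled by each input element, into the output, with successive copies shifted by the stride; padding \(p\) then crops \(p\) rows and columns from each border. The single-channel NumPy sketch below implements that definition directly for illustration only; it is not how the frameworks implement the layer, and channels and batches are left out. With a \(1 \times 1\) input, stride 1, and no padding, the output is simply the \(4 \times 4\) kernel scaled by the input value.

```python
import numpy as np

def trans_conv2d(X, K, stride=1, padding=0):
    """Naive single-channel 2D transposed convolution.

    Each input element X[i, j] adds X[i, j] * K into the output window
    starting at (i * stride, j * stride); `padding` rows and columns are
    then cropped from every border.
    """
    n_h, n_w = X.shape
    k_h, k_w = K.shape
    Y = np.zeros(((n_h - 1) * stride + k_h, (n_w - 1) * stride + k_w))
    for i in range(n_h):
        for j in range(n_w):
            Y[i * stride:i * stride + k_h,
              j * stride:j * stride + k_w] += X[i, j] * K
    if padding > 0:
        Y = Y[padding:-padding, padding:-padding]
    return Y

K = np.arange(16, dtype=float).reshape(4, 4)
print(trans_conv2d(np.array([[2.0]]), K).shape)                       # (4, 4)
print(trans_conv2d(np.ones((16, 16)), K, stride=2, padding=1).shape)  # (32, 32)
```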

The generator consists of four basic blocks that increase the input's width and height from 1 to 32. At the same time, it first projects the latent variable into \(64\times 8\) channels and then halves the number of channels in each subsequent block. Finally, a transposed convolution layer generates the output: it doubles the width and height once more to match the desired \(64\times 64\) shape and reduces the number of channels to \(3\). The tanh activation function is applied to project output values into the \((-1, 1)\) range.

n_G = 64
net_G = nn.Sequential(
    G_block(in_channels=100, out_channels=n_G*8,
            strides=1, padding=0),                  # Output: (64 * 8, 4, 4)
    G_block(in_channels=n_G*8, out_channels=n_G*4), # Output: (64 * 4, 8, 8)
    G_block(in_channels=n_G*4, out_channels=n_G*2), # Output: (64 * 2, 16, 16)
    G_block(in_channels=n_G*2, out_channels=n_G),   # Output: (64, 32, 32)
    nn.ConvTranspose2d(in_channels=n_G, out_channels=3,
                       kernel_size=4, stride=2, padding=1, bias=False),
    nn.Tanh())  # Output: (3, 64, 64)
n_G = 64
net_G = tf.keras.Sequential([
    # Output: (4, 4, 64 * 8)
    G_block(out_channels=n_G*8, strides=1, padding="valid"),
    G_block(out_channels=n_G*4), # Output: (8, 8, 64 * 4)
    G_block(out_channels=n_G*2), # Output: (16, 16, 64 * 2)
    G_block(out_channels=n_G), # Output: (32, 32, 64)
    # Output: (64, 64, 3)
    tf.keras.layers.Conv2DTranspose(
        3, kernel_size=4, strides=2, padding="same", use_bias=False,
        activation="tanh")
])
n_G = 64

class Generator(nn.Module):
    n_G: int = 64
    use_running_average: bool = False

    @nn.compact
    def __call__(self, X):
        X = G_block(out_channels=self.n_G*8, strides=1, padding='VALID',
                     use_running_average=self.use_running_average)(X)
        # Output: (4, 4, 64 * 8)
        X = G_block(out_channels=self.n_G*4,
                     use_running_average=self.use_running_average)(X)
        # Output: (8, 8, 64 * 4)
        X = G_block(out_channels=self.n_G*2,
                     use_running_average=self.use_running_average)(X)
        # Output: (16, 16, 64 * 2)
        X = G_block(out_channels=self.n_G,
                     use_running_average=self.use_running_average)(X)
        # Output: (32, 32, 64)
        X = nn.ConvTranspose(
            3, kernel_size=(4, 4), strides=(2, 2), padding='SAME',
            use_bias=False,
            kernel_init=nn.initializers.normal(0.02))(X)
        X = nn.tanh(X)
        # Output: (64, 64, 3)
        return X

net_G = Generator(n_G=n_G)
n_G = 64
net_G = nn.Sequential()
net_G.add(G_block(n_G*8, strides=1, padding=0),  # Output: (64 * 8, 4, 4)
          G_block(n_G*4),  # Output: (64 * 4, 8, 8)
          G_block(n_G*2),  # Output: (64 * 2, 16, 16)
          G_block(n_G),    # Output: (64, 32, 32)
          nn.Conv2DTranspose(
              3, kernel_size=4, strides=2, padding=1, use_bias=False,
              activation='tanh'))  # Output: (3, 64, 64)

Generate a 100-dimensional latent variable to verify the generator's output shape.

x = torch.zeros((1, 100, 1, 1))
net_G(x).shape
torch.Size([1, 3, 64, 64])
x = tf.zeros((1, 1, 1, 100))
net_G(x).shape
TensorShape([1, 64, 64, 3])
x = jnp.zeros((1, 1, 1, 100))
params_G = net_G.init(jax.random.PRNGKey(0), x)
net_G.apply(params_G, x, mutable=['batch_stats'])[0].shape
(1, 64, 64, 3)
x = np.zeros((1, 100, 1, 1))
net_G.initialize()
net_G(x).shape
(1, 3, 64, 64)
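The shape bookkeeping described for the generator can also be traced by hand with the output-size formula \(n' = k + s(n-1) - 2p\). The pure-Python sketch below walks a 100-channel \(1 \times 1\) latent input through the five transposed convolutions in channel-first notation; the layer list is our own transcription of the hyperparameters in the code above, with `n_G = 64`.

```python
n_G, latent_dim = 64, 100
# (out_channels, kernel, stride, padding) for each transposed convolution
layers = [(n_G * 8, 4, 1, 0),  # project the latent vector to 4 x 4
          (n_G * 4, 4, 2, 1),
          (n_G * 2, 4, 2, 1),
          (n_G, 4, 2, 1),
          (3, 4, 2, 1)]        # final layer to RGB, followed by tanh

shape = (latent_dim, 1, 1)     # (channels, height, width)
shapes = [shape]
for c_out, k, s, p in layers:
    n = k + s * (shape[1] - 1) - 2 * p  # square inputs, so height == width
    shape = (c_out, n, n)
    shapes.append(shape)
for sh in shapes:
    print(sh)
```

Each stride-2 layer doubles the spatial size, so four of them take the \(4 \times 4\) projection to \(64 \times 64\).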

20.2.3 Discriminator

The discriminator is a normal convolutional network except that it uses a leaky ReLU as its activation function. Given \(\alpha \in[0, 1]\), its definition is

\[\textrm{leaky ReLU}(x) = \begin{cases}x & \textrm{if}\ x > 0\\ \alpha x &\textrm{otherwise}\end{cases}. \tag{20.2.1}\]

As can be seen, it is the standard ReLU if \(\alpha=0\), and the identity function if \(\alpha=1\). For \(\alpha \in (0, 1)\), leaky ReLU is a nonlinear function that gives a nonzero output for a negative input. It aims to fix the "dying ReLU" problem, in which a neuron's input to the ReLU is always negative, so its output is always 0; since the gradient of ReLU is also 0 there, the neuron cannot make any progress.

alphas = [0, .2, .4, .6, .8, 1]
x = d2l.arange(-2, 1, 0.1)
Y = [d2l.numpy(nn.LeakyReLU(alpha)(x)) for alpha in alphas]
d2l.plot(d2l.numpy(x), Y, 'x', 'y', alphas)

alphas = [0, .2, .4, .6, .8, 1]
x = tf.range(-2, 1, 0.1)
Y = [tf.keras.layers.LeakyReLU(alpha)(x).numpy() for alpha in alphas]
d2l.plot(x.numpy(), Y, 'x', 'y', alphas)

alphas = [0, .2, .4, .6, .8, 1]
x = jnp.arange(-2, 1, 0.1)
Y = [np.array(nn.leaky_relu(x, negative_slope=alpha)) for alpha in alphas]
d2l.plot(np.array(x), Y, 'x', 'y', alphas)

alphas = [0, .2, .4, .6, .8, 1]
x = d2l.arange(-2, 1, 0.1)
Y = [d2l.numpy(nn.LeakyReLU(alpha)(x)) for alpha in alphas]
d2l.plot(d2l.numpy(x), Y, 'x', 'y', alphas)
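The NumPy sketch below makes the dying-ReLU argument concrete by evaluating leaky ReLU and its derivative on a few inputs (taking the derivative at \(x = 0\) to be \(\alpha\), a convention assumed here). With \(\alpha = 0\) every negative input receives zero gradient, whereas with \(\alpha = 0.2\), the default slope of the discriminator blocks defined next, a gradient of 0.2 still flows back.

```python
import numpy as np

def leaky_relu(x, alpha):
    """Elementwise leaky ReLU: x for x > 0, alpha * x otherwise."""
    return np.where(x > 0, x, alpha * x)

def leaky_relu_grad(x, alpha):
    """Derivative of leaky ReLU: 1 for x > 0, alpha otherwise."""
    return np.where(x > 0, 1.0, alpha)

x = np.array([-2.0, -0.5, 0.5, 2.0])
for alpha in (0.0, 0.2, 1.0):
    print(alpha, leaky_relu(x, alpha), leaky_relu_grad(x, alpha))
```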


The basic block of the discriminator is a convolution layer followed by a batch normalization layer and a leaky ReLU activation. The hyperparameters of the convolution layer are similar to those of the transposed convolution layer in the generator block.

class D_block(nn.Module):
    def __init__(self, out_channels, in_channels=3, kernel_size=4, strides=2,
                 padding=1, alpha=0.2, **kwargs):
        super(D_block, self).__init__(**kwargs)
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size,
                                strides, padding, bias=False)
        self.batch_norm = nn.BatchNorm2d(out_channels)
        self.activation = nn.LeakyReLU(alpha, inplace=True)

    def forward(self, X):
        return self.activation(self.batch_norm(self.conv2d(X)))
class D_block(tf.keras.layers.Layer):
    def __init__(self, out_channels, kernel_size=4, strides=2, padding="same",
                 alpha=0.2, **kwargs):
        super().__init__(**kwargs)
        self.conv2d = tf.keras.layers.Conv2D(out_channels, kernel_size,
                                             strides, padding, use_bias=False)
        self.batch_norm = tf.keras.layers.BatchNormalization()
        self.activation = tf.keras.layers.LeakyReLU(alpha)

    def call(self, X):
        return self.activation(self.batch_norm(self.conv2d(X)))
class D_block(nn.Module):
    out_channels: int
    kernel_size: int = 4
    strides: int = 2
    padding: str = 'SAME'
    alpha: float = 0.2
    use_running_average: bool = False

    @nn.compact
    def __call__(self, X):
        X = nn.Conv(
            self.out_channels, kernel_size=(self.kernel_size, self.kernel_size),
            strides=(self.strides, self.strides), padding=self.padding,
            use_bias=False,
            kernel_init=nn.initializers.normal(0.02))(X)
        X = nn.BatchNorm(
            use_running_average=self.use_running_average)(X)
        X = nn.leaky_relu(X, negative_slope=self.alpha)
        return X
class D_block(nn.Block):
    def __init__(self, channels, kernel_size=4, strides=2,
                 padding=1, alpha=0.2, **kwargs):
        super(D_block, self).__init__(**kwargs)
        self.conv2d = nn.Conv2D(
            channels, kernel_size, strides, padding, use_bias=False)
        self.batch_norm = nn.BatchNorm()
        self.activation = nn.LeakyReLU(alpha)

    def forward(self, X):
        return self.activation(self.batch_norm(self.conv2d(X)))

A basic block with default settings halves the width and height of its input, as we demonstrated in Section 7.3. For example, given an input shape \(n_h = n_w = 16\), a kernel shape \(k_h = k_w = 4\), a stride shape \(s_h = s_w = 2\), and a padding shape \(p_h = p_w = 1\), the output shape will be:

\[ \begin{aligned} n_h^{'} \times n_w^{'} &= \lfloor(n_h-k_h+2p_h+s_h)/s_h\rfloor \times \lfloor(n_w-k_w+2p_w+s_w)/s_w\rfloor\\ &= \lfloor(16-4+2\times 1+2)/2\rfloor \times \lfloor(16-4+2\times 1+2)/2\rfloor\\ &= 8 \times 8 .\\ \end{aligned} \]

x = torch.zeros((2, 3, 16, 16))
d_blk = D_block(20)
d_blk(x).shape
torch.Size([2, 20, 8, 8])
x = tf.zeros((2, 16, 16, 3))
d_blk = D_block(20)
d_blk(x).shape
TensorShape([2, 8, 8, 20])
x = jnp.zeros((2, 16, 16, 3))
d_blk = D_block(out_channels=20)
params = d_blk.init(jax.random.PRNGKey(0), x)
d_blk.apply(params, x, mutable=['batch_stats'])[0].shape
(2, 8, 8, 20)
x = np.zeros((2, 3, 16, 16))
d_blk = D_block(20)
d_blk.initialize()
d_blk(x).shape
(2, 20, 8, 8)
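The same check can be done for the whole discriminator. The pure-Python sketch below codes the formula above (with \(p\) counted per side, as in the text) as a helper named `conv_out` (our name), then walks a \(3 \times 64 \times 64\) image through four default blocks and a final \(4 \times 4\) convolution with stride 1 and no padding. The layer list is our transcription of the discriminator defined below with `n_D = 64`.

```python
def conv_out(n, k, s, p):
    """Output length of a strided convolution along one dimension:
    floor((n - k + 2p + s) / s), with p zeros added on each side."""
    return (n - k + 2 * p + s) // s

print(conv_out(16, k=4, s=2, p=1))  # 8, the example above

n_D = 64
# (out_channels, kernel, stride, padding) for each convolution
layers = [(n_D, 4, 2, 1), (n_D * 2, 4, 2, 1), (n_D * 4, 4, 2, 1),
          (n_D * 8, 4, 2, 1),
          (1, 4, 1, 0)]  # final convolution: one output channel, no padding
shape = (3, 64, 64)      # (channels, height, width)
shapes = [shape]
for c_out, k, s, p in layers:
    n = conv_out(shape[1], k, s, p)
    shape = (c_out, n, n)
    shapes.append(shape)
print(shapes)
```

Reading the spatial sizes backwards (1, 4, 8, 16, 32, 64) gives exactly the generator's progression, which is the sense in which the discriminator mirrors the generator.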

The discriminator is a mirror of the generator.

n_D = 64
net_D = nn.Sequential(
    D_block(n_D),  # Output: (64, 32, 32)
    D_block(in_channels=n_D, out_channels=n_D*2),  # Output: (64 * 2, 16, 16)
    D_block(in_channels=n_D*2, out_channels=n_D*4),  # Output: (64 * 4, 8, 8)
    D_block(in_channels=n_D*4, out_channels=n_D*8),  # Output: (64 * 8, 4, 4)
    nn.Conv2d(in_channels=n_D*8, out_channels=1,
              kernel_size=4, bias=False))  # Output: (1, 1, 1)
n_D = 64
net_D = tf.keras.Sequential([
    D_block(n_D), # Output: (32, 32, 64)
    D_block(out_channels=n_D*2), # Output: (16, 16, 64 * 2)
    D_block(out_channels=n_D*4), # Output: (8, 8, 64 * 4)
    D_block(out_channels=n_D*8), # Output: (4, 4, 64 * 8)
    # Output: (1, 1, 1)
    tf.keras.layers.Conv2D(1, kernel_size=4, use_bias=False)
])
n_D = 64

class Discriminator(nn.Module):
    n_D: int = 64
    use_running_average: bool = False

    @nn.compact
    def __call__(self, X):
        X = D_block(out_channels=self.n_D,
                     use_running_average=self.use_running_average)(X)
        # Output: (32, 32, 64)
        X = D_block(out_channels=self.n_D*2,
                     use_running_average=self.use_running_average)(X)
        # Output: (16, 16, 64 * 2)
        X = D_block(out_channels=self.n_D*4,
                     use_running_average=self.use_running_average)(X)
        # Output: (8, 8, 64 * 4)
        X = D_block(out_channels=self.n_D*8,
                     use_running_average=self.use_running_average)(X)
        # Output: (4, 4, 64 * 8)
        X = nn.Conv(
            1, kernel_size=(4, 4), padding='VALID', use_bias=False,
            kernel_init=nn.initializers.normal(0.02))(X)
        # Output: (1, 1, 1)
        return X

net_D = Discriminator(n_D=n_D)
n_D = 64
net_D = nn.Sequential()
net_D.add(D_block(n_D),   # Output: (64, 32, 32)
          D_block(n_D*2),  # Output: (64 * 2, 16, 16)
          D_block(n_D*4),  # Output: (64 * 4, 8, 8)
          D_block(n_D*8),  # Output: (64 * 8, 4, 4)
          nn.Conv2D(1, kernel_size=4, use_bias=False))  # Output: (1, 1, 1)

It uses a convolution layer with a single output channel as the last layer to obtain a single prediction value.

x = torch.zeros((1, 3, 64, 64))
net_D(x).shape
torch.Size([1, 1, 1, 1])
x = tf.zeros((1, 64, 64, 3))
net_D(x).shape
TensorShape([1, 1, 1, 1])
x = jnp.zeros((1, 64, 64, 3))
params_D = net_D.init(jax.random.PRNGKey(0), x)
net_D.apply(params_D, x, mutable=['batch_stats'])[0].shape
(1, 1, 1, 1)
x = np.zeros((1, 3, 64, 64))
net_D.initialize()
net_D(x).shape
(1, 1, 1, 1)
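That single value is a logit rather than a probability: the training code below applies the sigmoid inside the loss (`nn.BCEWithLogitsLoss`, `BinaryCrossentropy(from_logits=True)`, `optax.sigmoid_binary_cross_entropy`, and `gluon.loss.SigmoidBCELoss`). The NumPy sketch below spells out that loss in a numerically stable form, together with the two objectives from Section 20.1: the discriminator labels real images 1 and generated images 0, and the generator is trained to make the discriminator output 1 on its samples. Summing over the batch and halving the discriminator loss mirror the JAX training loop in this section; the logit values are made up for illustration.

```python
import numpy as np

def bce_with_logits(logits, labels):
    """Sum of binary cross-entropy losses computed from raw logits.

    Uses max(z, 0) - z * y + log(1 + exp(-|z|)), which equals
    -[y log sigmoid(z) + (1 - y) log(1 - sigmoid(z))] but does not
    overflow for large |z|.
    """
    z, y = np.asarray(logits, float), np.asarray(labels, float)
    return np.sum(np.maximum(z, 0) - z * y + np.log1p(np.exp(-np.abs(z))))

real_logits = np.array([3.0, 2.5, 4.0])    # discriminator outputs on real images
fake_logits = np.array([-2.0, -3.5, 0.5])  # discriminator outputs on generated images

loss_D = (bce_with_logits(real_logits, np.ones(3)) +
          bce_with_logits(fake_logits, np.zeros(3))) / 2
loss_G = bce_with_logits(fake_logits, np.ones(3))  # generator wants "real" labels
print(loss_D, loss_G)
```

Here the discriminator separates the two sets well, so its loss is small while the generator's loss is large.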

20.2.4 Training

Compared to the basic GAN in Section 20.1, we use the same learning rate for both the generator and the discriminator since they are similar to each other. In addition, we change \(\beta_1\) in Adam (Section 12.10) from \(0.9\) to \(0.5\). This reduces the smoothing of the momentum, i.e., the exponentially weighted moving average of past gradients, so that the optimizer can keep up with the rapidly changing gradients that arise as the generator and the discriminator compete with each other. Finally, the randomly generated noise Z is a 4-D tensor, and we use a GPU to accelerate the computation.
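To see what a smaller \(\beta_1\) does, the pure-Python sketch below tracks only Adam's first-moment estimate \(m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t\) (bias correction and the second moment are omitted for brevity) on a made-up gradient sequence whose sign flips halfway through. With \(\beta_1 = 0.5\) the estimate follows the new sign after a single step; with \(\beta_1 = 0.9\) it lags for several steps. The helper names are hypothetical.

```python
def first_moment(grads, beta1):
    """Adam's (bias-uncorrected) first-moment estimate after each gradient."""
    m, history = 0.0, []
    for g in grads:
        m = beta1 * m + (1 - beta1) * g
        history.append(m)
    return history

def steps_to_follow_flip(grads, beta1, flip_at):
    """Steps after index `flip_at` until the estimate turns negative."""
    history = first_moment(grads, beta1)
    for t in range(flip_at, len(history)):
        if history[t] < 0:
            return t - flip_at + 1
    return None

grads = [1.0] * 10 + [-1.0] * 10  # gradient sign flips at index 10
for beta1 in (0.9, 0.5):
    # prints 5 for beta1=0.9 and 1 for beta1=0.5
    print(beta1, steps_to_follow_flip(grads, beta1, flip_at=10))
```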

def train(net_D, net_G, data_iter, num_epochs, lr, latent_dim,
          device=d2l.try_gpu()):
    loss = nn.BCEWithLogitsLoss(reduction='sum')
    for w in net_D.parameters():
        nn.init.normal_(w, 0, 0.02)
    for w in net_G.parameters():
        nn.init.normal_(w, 0, 0.02)
    net_D, net_G = net_D.to(device), net_G.to(device)
    trainer_hp = {'lr': lr, 'betas': [0.5,0.999]}
    trainer_D = torch.optim.Adam(net_D.parameters(), **trainer_hp)
    trainer_G = torch.optim.Adam(net_G.parameters(), **trainer_hp)
    animator = d2l.Animator(xlabel='epoch', ylabel='loss',
                            xlim=[1, num_epochs], nrows=2, figsize=(5, 5),
                            legend=['discriminator', 'generator'])
    animator.fig.subplots_adjust(hspace=0.3)
    for epoch in range(1, num_epochs + 1):
        # Train one epoch
        timer = d2l.Timer()
        metric = d2l.Accumulator(3)  # loss_D, loss_G, num_examples
        for X, _ in data_iter:
            batch_size = X.shape[0]
            Z = torch.normal(0, 1, size=(batch_size, latent_dim, 1, 1))
            X, Z = X.to(device), Z.to(device)
            metric.add(d2l.update_D(X, Z, net_D, net_G, loss, trainer_D),
                       d2l.update_G(Z, net_D, net_G, loss, trainer_G),
                       batch_size)
        # Show generated examples
        Z = torch.normal(0, 1, size=(21, latent_dim, 1, 1), device=device)
        # Rescale the synthetic images from [-1, 1] to [0, 1] for display
        fake_x = net_G(Z).permute(0, 2, 3, 1) / 2 + 0.5
        imgs = torch.cat(
            [torch.cat([
                fake_x[i * 7 + j].cpu().detach() for j in range(7)], dim=1)
             for i in range(len(fake_x)//7)], dim=0)
        animator.axes[1].cla()
        animator.axes[1].imshow(imgs)
        # Show the losses
        loss_D, loss_G = metric[0] / metric[2], metric[1] / metric[2]
        animator.add(epoch, (loss_D, loss_G))
    print(f'loss_D {loss_D:.3f}, loss_G {loss_G:.3f}, '
          f'{metric[2] / timer.stop():.1f} examples/sec on {str(device)}')
def train(net_D, net_G, data_iter, num_epochs, lr, latent_dim,
          device=d2l.try_gpu()):
    loss = tf.keras.losses.BinaryCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.SUM)

    for w in net_D.trainable_variables:
        w.assign(tf.random.normal(mean=0, stddev=0.02, shape=w.shape))
    for w in net_G.trainable_variables:
        w.assign(tf.random.normal(mean=0, stddev=0.02, shape=w.shape))

    optimizer_hp = {"learning_rate": lr, "beta_1": 0.5, "beta_2": 0.999}
    optimizer_D = tf.keras.optimizers.Adam(**optimizer_hp)
    optimizer_G = tf.keras.optimizers.Adam(**optimizer_hp)

    animator = d2l.Animator(xlabel='epoch', ylabel='loss',
                            xlim=[1, num_epochs], nrows=2, figsize=(5, 5),
                            legend=['discriminator', 'generator'])
    animator.fig.subplots_adjust(hspace=0.3)

    for epoch in range(1, num_epochs + 1):
        # Train one epoch
        timer = d2l.Timer()
        metric = d2l.Accumulator(3) # loss_D, loss_G, num_examples
        for X, _ in data_iter:
            batch_size = X.shape[0]
            Z = tf.random.normal(mean=0, stddev=1,
                                 shape=(batch_size, 1, 1, latent_dim))
            metric.add(d2l.update_D(X, Z, net_D, net_G, loss, optimizer_D),
                       d2l.update_G(Z, net_D, net_G, loss, optimizer_G),
                       batch_size)

        # Show generated examples
        Z = tf.random.normal(mean=0, stddev=1, shape=(21, 1, 1, latent_dim))
        # Rescale the synthetic images from [-1, 1] to [0, 1] for display
        fake_x = net_G(Z) / 2 + 0.5
        imgs = tf.concat([tf.concat([fake_x[i * 7 + j] for j in range(7)],
                                    axis=1)
                          for i in range(len(fake_x) // 7)], axis=0)
        animator.axes[1].cla()
        animator.axes[1].imshow(imgs)
        # Show the losses
        loss_D, loss_G = metric[0] / metric[2], metric[1] / metric[2]
        animator.add(epoch, (loss_D, loss_G))
    print(f'loss_D {loss_D:.3f}, loss_G {loss_G:.3f}, '
          f'{metric[2] / timer.stop():.1f} examples/sec on {str(device._device_name)}')
def train(net_D, net_G, data_iter, num_epochs, lr, latent_dim):
    key = jax.random.PRNGKey(0)

    # Initialize generator and discriminator parameters
    dummy_Z = jnp.ones((1, 1, 1, latent_dim))
    dummy_X = jnp.ones((1, 64, 64, 3))
    key, key_G, key_D = jax.random.split(key, 3)
    variables_G = net_G.init(key_G, dummy_Z)
    variables_D = net_D.init(key_D, dummy_X)

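    # Reinitialize every parameter with N(0, 0.02), drawing a separate key per
    # parameter array so that same-shaped arrays do not receive identical values
    def normal_init(params, key):
        leaves, treedef = jax.tree.flatten(params)
        keys = jax.random.split(key, len(leaves))
        return jax.tree.unflatten(
            treedef, [jax.random.normal(k, p.shape) * 0.02
                      for k, p in zip(keys, leaves)])

    params_G = normal_init(variables_G['params'], key_G)
    batch_stats_G = variables_G.get('batch_stats', {})
    params_D = normal_init(variables_D['params'], key_D)
    batch_stats_D = variables_D.get('batch_stats', {})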

    optimizer_D = optax.adam(lr, b1=0.5, b2=0.999)
    optimizer_G = optax.adam(lr, b1=0.5, b2=0.999)
    opt_state_D = optimizer_D.init(params_D)
    opt_state_G = optimizer_G.init(params_G)

    animator = d2l.Animator(xlabel='epoch', ylabel='loss',
                            xlim=[1, num_epochs], nrows=2, figsize=(5, 5),
                            legend=['discriminator', 'generator'])
    animator.fig.subplots_adjust(hspace=0.3)

    for epoch in range(1, num_epochs + 1):
        timer = d2l.Timer()
        metric = d2l.Accumulator(3)  # loss_D, loss_G, num_examples
        for batch in data_iter:
            X = jnp.array(batch[0])  # Already (N, H, W, C)
            batch_size = X.shape[0]
            key, subkey = jax.random.split(key)
            Z = jax.random.normal(subkey, (batch_size, 1, 1, latent_dim))

            # Update discriminator
            fake_X, updates_G = net_G.apply(
                {'params': params_G, 'batch_stats': batch_stats_G},
                Z, mutable=['batch_stats'])
            batch_stats_G = updates_G['batch_stats']

            def loss_D_fn(params_D):
                real_Y, updates_D = net_D.apply(
                    {'params': params_D, 'batch_stats': batch_stats_D},
                    X, mutable=['batch_stats'])
                fake_Y, _ = net_D.apply(
                    {'params': params_D, 'batch_stats': batch_stats_D},
                    fake_X, mutable=['batch_stats'])
                ones = jnp.ones((batch_size,))
                zeros = jnp.zeros((batch_size,))
                loss_D = (jnp.sum(optax.sigmoid_binary_cross_entropy(
                              real_Y.squeeze(), ones)) +
                          jnp.sum(optax.sigmoid_binary_cross_entropy(
                              fake_Y.squeeze(), zeros))) / 2
                return loss_D, updates_D

            (loss_D_val, updates_D), grads_D = jax.value_and_grad(
                loss_D_fn, has_aux=True)(params_D)
            batch_stats_D = updates_D['batch_stats']
            updates_optax_D, opt_state_D = optimizer_D.update(
                grads_D, opt_state_D, params_D)
            params_D = optax.apply_updates(params_D, updates_optax_D)

            # Update generator
            def loss_G_fn(params_G):
                fake_X, updates_G = net_G.apply(
                    {'params': params_G, 'batch_stats': batch_stats_G},
                    Z, mutable=['batch_stats'])
                fake_Y, _ = net_D.apply(
                    {'params': params_D, 'batch_stats': batch_stats_D},
                    fake_X, mutable=['batch_stats'])
                ones = jnp.ones((batch_size,))
                loss_G = jnp.sum(optax.sigmoid_binary_cross_entropy(
                    fake_Y.squeeze(), ones))
                return loss_G, updates_G

            (loss_G_val, updates_G), grads_G = jax.value_and_grad(
                loss_G_fn, has_aux=True)(params_G)
            batch_stats_G = updates_G['batch_stats']
            updates_optax_G, opt_state_G = optimizer_G.update(
                grads_G, opt_state_G, params_G)
            params_G = optax.apply_updates(params_G, updates_optax_G)

            metric.add(loss_D_val, loss_G_val, batch_size)

        # Show generated examples
        key, subkey = jax.random.split(key)
        Z = jax.random.normal(subkey, (21, 1, 1, latent_dim))
        fake_x, _ = net_G.apply(
            {'params': params_G, 'batch_stats': batch_stats_G},
            Z, mutable=['batch_stats'])
        fake_x = fake_x / 2 + 0.5
        imgs = jnp.concatenate(
            [jnp.concatenate([fake_x[i * 7 + j] for j in range(7)], axis=1)
             for i in range(len(fake_x) // 7)], axis=0)
        animator.axes[1].cla()
        animator.axes[1].imshow(np.array(imgs))
        # Show the losses
        loss_D, loss_G = metric[0] / metric[2], metric[1] / metric[2]
        animator.add(epoch, (loss_D, loss_G))
    print(f'loss_D {loss_D:.3f}, loss_G {loss_G:.3f}, '
          f'{metric[2] / timer.stop():.1f} examples/sec')
def train(net_D, net_G, data_iter, num_epochs, lr, latent_dim,
          device=d2l.try_gpu()):
    loss = gluon.loss.SigmoidBCELoss()
    net_D.initialize(init=init.Normal(0.02), force_reinit=True, ctx=device)
    net_G.initialize(init=init.Normal(0.02), force_reinit=True, ctx=device)
    trainer_hp = {'learning_rate': lr, 'beta1': 0.5}
    trainer_D = gluon.Trainer(net_D.collect_params(), 'adam', trainer_hp)
    trainer_G = gluon.Trainer(net_G.collect_params(), 'adam', trainer_hp)
    animator = d2l.Animator(xlabel='epoch', ylabel='loss',
                            xlim=[1, num_epochs], nrows=2, figsize=(5, 5),
                            legend=['discriminator', 'generator'])
    animator.fig.subplots_adjust(hspace=0.3)
    for epoch in range(1, num_epochs + 1):
        # Train one epoch
        timer = d2l.Timer()
        metric = d2l.Accumulator(3)  # loss_D, loss_G, num_examples
        for X, _ in data_iter:
            batch_size = X.shape[0]
            Z = np.random.normal(0, 1, size=(batch_size, latent_dim, 1, 1))
            X, Z = X.as_in_ctx(device), Z.as_in_ctx(device)
            metric.add(d2l.update_D(X, Z, net_D, net_G, loss, trainer_D),
                       d2l.update_G(Z, net_D, net_G, loss, trainer_G),
                       batch_size)
        # Show generated examples
        Z = np.random.normal(0, 1, size=(21, latent_dim, 1, 1), ctx=device)
        # Rescale the synthetic images from [-1, 1] to [0, 1] for display
        fake_x = net_G(Z).transpose(0, 2, 3, 1) / 2 + 0.5
        imgs = np.concatenate(
            [np.concatenate([fake_x[i * 7 + j] for j in range(7)], axis=1)
             for i in range(len(fake_x)//7)], axis=0)
        animator.axes[1].cla()
        animator.axes[1].imshow(imgs.asnumpy())
        # Show the losses
        loss_D, loss_G = metric[0] / metric[2], metric[1] / metric[2]
        animator.add(epoch, (loss_D, loss_G))
    print(f'loss_D {loss_D:.3f}, loss_G {loss_G:.3f}, '
          f'{metric[2] / timer.stop():.1f} examples/sec on {str(device)}')

We train the model with a small number of epochs just for demonstration. For better performance, the variable num_epochs can be set to a larger number.

latent_dim, lr, num_epochs = 100, 0.005, 20
train(net_D, net_G, data_iter, num_epochs, lr, latent_dim)
loss_D 0.071, loss_G 7.131, 5662.5 examples/sec on cuda:0

latent_dim, lr, num_epochs = 100, 0.0005, 40
train(net_D, net_G, data_iter, num_epochs, lr, latent_dim)
loss_D 0.251, loss_G 3.113, 1673.3 examples/sec on /GPU:0

latent_dim, lr, num_epochs = 100, 0.005, 20
train(net_D, net_G, data_iter, num_epochs, lr, latent_dim)
loss_D 0.022, loss_G 7.704, 527.9 examples/sec

latent_dim, lr, num_epochs = 100, 0.005, 20
train(net_D, net_G, data_iter, num_epochs, lr, latent_dim)
loss_D 0.249, loss_G 5.777, 2404.9 examples/sec on gpu(0)


20.2.5 Summary

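  • The DCGAN architecture uses strided convolutional layers in the discriminator and "fractionally-strided" (transposed) convolutional layers in the generator. In our implementation, each network stacks four basic blocks followed by one final (transposed) convolution layer.
  • Each discriminator block is a strided convolution followed by batch normalization and a leaky ReLU activation. The original DCGAN omits batch normalization on the discriminator's input layer, whereas the D_block used here applies it in every block.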
  • Leaky ReLU is a nonlinear function that gives a nonzero output for a negative input. It aims to fix the "dying ReLU" problem and helps gradients flow more easily through the architecture.

20.2.6 Exercises

  1. What will happen if we use standard ReLU activation rather than leaky ReLU?
  2. Apply DCGAN on Fashion-MNIST and see which category works well and which does not.