#### GANs PyTorch

{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "oDUanjQCdtoV"
},
"source": [
"### What is a GAN?\n",
"\n",
"In 2014, [Goodfellow et al.](https://arxiv.org/abs/1406.2661) presented a method for training generative models called Generative Adversarial Networks (GANs for short). In a GAN, we build two different neural networks. Our first network is a traditional classification network, called the **discriminator**. We will train the discriminator to take images, and classify them as being real (belonging to the training set) or fake (not present in the training set). Our other network, called the **generator**, will take random noise as input and transform it using a neural network to produce images. The goal of the generator is to fool the discriminator into thinking the images it produced are real.\n",
"\n",
"We can think of this back and forth process of the generator ($G$) trying to fool the discriminator ($D$), and the discriminator trying to correctly classify real vs. fake as a minimax game:\n",
"$$\\underset{G}{\\text{minimize}}\\; \\underset{D}{\\text{maximize}}\\; \\mathbb{E}_{x \\sim p_\\text{data}}\\left[\\log D(x)\\right] + \\mathbb{E}_{z \\sim p(z)}\\left[\\log \\left(1-D(G(z))\\right)\\right]$$\n",
"where $z \\sim p(z)$ are the random noise samples, $G(z)$ are the generated images using the neural network generator $G$, and $D$ is the output of the discriminator, specifying the probability of an input being real. In [Goodfellow et al.](https://arxiv.org/abs/1406.2661), they analyze this minimax game and show how it relates to minimizing the Jensen-Shannon divergence between the training data distribution and the generated samples from $G$.\n",
"\n",
"To optimize this minimax game, we will alternate between taking gradient *descent* steps on the objective for $G$ and gradient *ascent* steps on the objective for $D$:\n",
"1. Update the **generator** ($G$) to minimize the probability of the __discriminator making the correct choice__.\n",
"2. Update the **discriminator** ($D$) to maximize the probability of the __discriminator making the correct choice__.\n",
"\n",
"While these updates are useful for analysis, they do not perform well in practice. Instead, we will use a different objective when we update the generator: maximize the probability of the **discriminator making the incorrect choice**. This small change helps alleviate the vanishing generator gradient that occurs when the discriminator is confident. This is the standard update used in most GAN papers, and it was used in the original paper from [Goodfellow et al.](https://arxiv.org/abs/1406.2661).\n",
"\n",
"In this assignment, we will alternate the following updates:\n",
"1. Update the generator ($G$) to maximize the probability of the discriminator making the incorrect choice on generated data:\n",
"$$\\underset{G}{\\text{maximize}}\\; \\mathbb{E}_{z \\sim p(z)}\\left[\\log D(G(z))\\right]$$\n",
"2. Update the discriminator ($D$), to maximize the probability of the discriminator making the correct choice on real and generated data:\n",
"$$\\underset{D}{\\text{maximize}}\\; \\mathbb{E}_{x \\sim p_\\text{data}}\\left[\\log D(x)\\right] + \\mathbb{E}_{z \\sim p(z)}\\left[\\log \\left(1-D(G(z))\\right)\\right]$$\n",
"\n",
"### What else is there in this notebook?\n",
"![caption](gan_outputs_pytorch.png)"
]
},
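The vanishing-gradient argument above can be checked numerically. The pure-Python sketch below is an illustration and is not part of the assignment code. It assumes the discriminator produces a raw score (logit) `a` with $D = \sigma(a)$, and it compares the derivative with respect to `a` of the original generator objective $\log(1 - D(G(z)))$ (minimized) against the derivative of the non-saturating loss $-\log D(G(z))$ (minimized).

```python
import math


def sigmoid(a):
    """Logistic function, written to avoid overflow for large |a|."""
    if a >= 0:
        return 1.0 / (1.0 + math.exp(-a))
    e = math.exp(a)
    return e / (1.0 + e)


def grad_saturating(a):
    """d/da of log(1 - sigmoid(a)): the original minimax generator objective."""
    return -sigmoid(a)


def grad_non_saturating(a):
    """d/da of -log(sigmoid(a)): the generator loss used in practice."""
    return sigmoid(a) - 1.0


# A very negative score means D is confident that the generated sample is fake.
for a in (-10.0, -2.0, 0.0, 2.0):
    print(f"a={a:+5.1f}  saturating={grad_saturating(a):+.5f}  "
          f"non-saturating={grad_non_saturating(a):+.5f}")
```

At `a = -10` the saturating gradient is on the order of 1e-5, while the non-saturating gradient is close to -1. This is the effect the paragraph describes.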
{
"cell_type": "markdown",
"metadata": {
"id": "OgrXJSMmdtoW"
},
"source": [
"## Setup"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "CYVwNTuFdtoX"
},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"from torch.nn import init\n",
"import torchvision\n",
"import torchvision.transforms as T\n",
"import torch.optim as optim\n",
"from torch.utils.data import sampler\n",
"import torchvision.datasets as dset\n",
"\n",
"import numpy as np\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.gridspec as gridspec\n",
"\n",
"%matplotlib inline\n",
"plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots\n",
"plt.rcParams['image.interpolation'] = 'nearest'\n",
"plt.rcParams['image.cmap'] = 'gray'\n",
"\n",
"def show_images(images):\n",
"    images = np.reshape(images, [images.shape[0], -1])  # images reshape to (batch_size, D)\n",
"    sqrtn = int(np.ceil(np.sqrt(images.shape[0])))\n",
"    sqrtimg = int(np.ceil(np.sqrt(images.shape[1])))\n",
"\n",
"    fig = plt.figure(figsize=(sqrtn, sqrtn))\n",
"    gs = gridspec.GridSpec(sqrtn, sqrtn)\n",
"    gs.update(wspace=0.05, hspace=0.05)\n",
"\n",
"    for i, img in enumerate(images):\n",
"        ax = plt.subplot(gs[i])\n",
"        plt.axis('off')\n",
"        ax.set_xticklabels([])\n",
"        ax.set_yticklabels([])\n",
"        ax.set_aspect('equal')\n",
"        plt.imshow(img.reshape([sqrtimg,sqrtimg]))\n",
"    return \n",
"\n",
"def preprocess_img(x):\n",
"    return 2 * x - 1.0\n",
"\n",
"def deprocess_img(x):\n",
"    return (x + 1.0) / 2.0\n",
"\n",
"def rel_error(x,y):\n",
"    return np.max(np.abs(x - y) / (np.maximum(1e-8, np.abs(x) + np.abs(y))))\n",
"\n",
"def count_params(model):\n",
"    \"\"\"Count the number of parameters in a PyTorch model.\"\"\"\n",
"    param_count = np.sum([np.prod(p.size()) for p in model.parameters()])\n",
"    return param_count"
]
},
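`show_images` arranges images on a square grid. It uses `sqrtn = ceil(sqrt(batch))` tiles per side and reshapes each flattened image of length `D` to `ceil(sqrt(D))` pixels per side. That reshape only works when `D` is a perfect square, as with 784 = 28 × 28 for MNIST. The pure-Python sketch below reproduces that arithmetic; the helper name `grid_layout` is hypothetical.

A batch of 128 gives a 12 × 12 grid and the 16 training samples give a 4 × 4 grid. If the inline backend renders at 72 dpi, these grids correspond to the `864x864` and `288x288` figure sizes reported in the outputs.

```python
import math


def grid_layout(num_images, image_dim):
    """Mirror the sizing logic of show_images (hypothetical helper).

    Returns (tiles_per_side, pixels_per_side) for a batch of flattened images.
    """
    tiles_per_side = int(math.ceil(math.sqrt(num_images)))
    pixels_per_side = int(math.ceil(math.sqrt(image_dim)))
    # show_images reshapes each image to [s, s]; that fails unless image_dim == s * s.
    if pixels_per_side * pixels_per_side != image_dim:
        raise ValueError(f"image_dim={image_dim} is not a perfect square")
    return tiles_per_side, pixels_per_side


print(grid_layout(128, 784))  # (12, 28): one full MNIST batch
print(grid_layout(16, 784))   # (4, 28): the samples shown during training
```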
{
"cell_type": "markdown",
"metadata": {
"id": "BC8RGqopdtob"
},
"source": [
"## Dataset"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000,
"referenced_widgets": [
"3de4b11b43e04731bb23455e0f368565",
"42d3d6f4bc874ef5bbc98b4632e29e38",
"a6be8502a9844ec494d322042e214d77",
"17caedca349a47d1a992e8bbb60b5642",
"ee639b5a0e294d9883e43314ea8d6702",
"c6e99d483fa24dfea0772971de3316fe",
"9224ea3208fe494b915250904bfa3eb6",
"bd1077ce31694e2a90080b6d51b484cc",
"255e59e324e2406ba342673af8f2a105",
"b7d43cde0bc6456ab63d6ab2f8422bec",
"ca345e79a4384c5489e1a642afae551f",
"261c9a1cb1f4421292fc632c50b78e20",
"be079b45d9f04ff38d9eba6c9e289c11",
"7ba6e77f6cbb4ca785ef608df8b5ef5c",
"8aae89186587409e96bfa45fe48b0885",
"d0bc3366acda44029c562ec8899651e8",
"f6c446977fc841f6b8615582c048740d",
"c0e0136b54044de498bf141dac1db574",
"8e9b9e1b5e464cbc9d35c27671ed9959",
"86504c47c60948639ed318e4386cbbd2",
"489b59fc12cd4dd99e66f1f80caecfea",
"51e877e40f8141b4a1aa7c9fd63ed03d",
"7f2fb3876b0f442999db456fea411d37",
"b7dbfcbf7c9c489b8975edf939632b16",
"51d715eb014e4163a1bb57b15e5b4bc4",
"78f6e9e772434dc891521d57062bf1b2",
"0149a792bcdb441eb23471e3d733878e",
"58d3a6ea0d6c49b68a22693d026ecc7d"
]
},
"id": "cxkhjwB6dtob",
"outputId": "4ed97823-c3fa-4c58-e9d6-a7df2380ee49",
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3de4b11b43e04731bb23455e0f368565",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))"
]
},
"tags": []
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting ./utils/datasets/MNIST_data/MNIST/raw/train-images-idx3-ubyte.gz to ./utils/datasets/MNIST_data/MNIST/raw\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9224ea3208fe494b915250904bfa3eb6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))"
]
},
"tags": []
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting ./utils/datasets/MNIST_data/MNIST/raw/train-labels-idx1-ubyte.gz to ./utils/datasets/MNIST_data/MNIST/raw\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8aae89186587409e96bfa45fe48b0885",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))"
]
},
"tags": []
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting ./utils/datasets/MNIST_data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./utils/datasets/MNIST_data/MNIST/raw\n",
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "51e877e40f8141b4a1aa7c9fd63ed03d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))"
]
},
"tags": []
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting ./utils/datasets/MNIST_data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./utils/datasets/MNIST_data/MNIST/raw\n",
"Processing...\n",
"Done!\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.6/dist-packages/torchvision/datasets/mnist.py:469: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at  /pytorch/torch/csrc/utils/tensor_numpy.cpp:141.)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 864x864 with 128 Axes>"
]
},
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
}
],
"source": [
"class ChunkSampler(sampler.Sampler):\n",
"    \"\"\"Samples elements sequentially from some offset. \n",
"    Arguments:\n",
"        num_samples: # of desired datapoints\n",
"        start: offset where we should start selecting from\n",
"    \"\"\"\n",
"    def __init__(self, num_samples, start=0):\n",
"        self.num_samples = num_samples\n",
"        self.start = start\n",
"\n",
"    def __iter__(self):\n",
"        return iter(range(self.start, self.start + self.num_samples))\n",
"\n",
"    def __len__(self):\n",
"        return self.num_samples\n",
"\n",
"NUM_TRAIN = 50000\n",
"NUM_VAL = 5000\n",
"\n",
"NOISE_DIM = 96\n",
"batch_size = 128\n",
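"\n",
"mnist_train = dset.MNIST('./utils/datasets/MNIST_data', train=True, download=True,\n",
"                           transform=T.ToTensor())\n",
"loader_train = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size,\n",
"                          sampler=ChunkSampler(NUM_TRAIN, 0))\n",
"\n",
"# The validation images also come from the MNIST training split (indices 50000 to 54999).\n",
"mnist_val = dset.MNIST('./utils/datasets/MNIST_data', train=True, download=True,\n",
"                           transform=T.ToTensor())\n",
"loader_val = torch.utils.data.DataLoader(mnist_val, batch_size=batch_size,\n",
"                        sampler=ChunkSampler(NUM_VAL, NUM_TRAIN))\n",
"\n",
"# Display one batch of training images.\n",
"imgs = next(iter(loader_train))[0].view(batch_size, 784).numpy()\n",
"show_images(imgs)"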
]
},
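`ChunkSampler` yields a contiguous range of indices. The training loader therefore reads images 0 to 49999 and the validation loader reads images 50000 to 54999 of the same 60000-image MNIST training split, and the two ranges do not overlap. The stdlib sketch below mirrors that logic without PyTorch; `ChunkIndices` is a hypothetical stand-in for `ChunkSampler`.

```python
class ChunkIndices:
    """Hypothetical torch-free stand-in for ChunkSampler: yields start .. start+num_samples-1."""

    def __init__(self, num_samples, start=0):
        self.num_samples = num_samples
        self.start = start

    def __iter__(self):
        return iter(range(self.start, self.start + self.num_samples))

    def __len__(self):
        return self.num_samples


NUM_TRAIN, NUM_VAL = 50000, 5000
train_idx = list(ChunkIndices(NUM_TRAIN, 0))
val_idx = list(ChunkIndices(NUM_VAL, NUM_TRAIN))

print(train_idx[0], train_idx[-1])          # 0 49999
print(val_idx[0], val_idx[-1])              # 50000 54999
print(set(train_idx).isdisjoint(val_idx))   # True
```

The sampler is deterministic, so the training loader visits the images in the same order every epoch. No shuffling is applied.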
{
"cell_type": "markdown",
"metadata": {
"id": "oXmeqMF_dtoe"
},
"source": [
"## Random Noise\n",
"Generate uniform noise from -1 to 1 with shape [batch_size, dim]."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "mTSBJLnDdtoe"
},
"outputs": [],
"source": [
"def sample_noise(batch_size, dim):\n",
"    \"\"\"\n",
"    Generate a PyTorch Tensor of uniform random noise.\n",
"\n",
"    Input:\n",
"    - batch_size: Integer giving the batch size of noise to generate.\n",
"    - dim: Integer giving the dimension of noise to generate.\n",
"    \n",
"    Output:\n",
"    - A PyTorch Tensor of shape (batch_size, dim) containing uniform\n",
"      random noise in the range (-1, 1).\n",
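"    \"\"\"\n",
"    # torch.rand samples U[0, 1); the affine map 2u - 1 rescales it to U[-1, 1).\n",
"    return 2 * torch.rand(batch_size, dim) - 1"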
]
},
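Uniform noise on $[-1, 1)$ is an affine rescaling of $U[0, 1)$: if $u \sim U[0, 1)$, then $2u - 1 \sim U[-1, 1)$. The stdlib sketch below applies that map to plain Python lists; the name `sample_noise_py` is hypothetical. The PyTorch version applies the same idea to a `torch.rand(batch_size, dim)` tensor.

```python
import random


def sample_noise_py(batch_size, dim, rng=random):
    """Return a batch_size x dim nested list of floats drawn from U[-1, 1)."""
    # rng.random() is uniform on [0, 1); 2u - 1 maps it onto [-1, 1).
    return [[2.0 * rng.random() - 1.0 for _ in range(dim)] for _ in range(batch_size)]


rng = random.Random(231)
z = sample_noise_py(3, 4, rng)
flat = [v for row in z for v in row]
print(len(z), len(z[0]))                     # 3 4
print(all(-1.0 <= v < 1.0 for v in flat))    # True
```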
{
"cell_type": "markdown",
"metadata": {
"id": "pyZvX4kYdtoh"
},
"source": [
"Check that the noise has the correct shape and type:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "36d3917f-d5cd-43ef-8bb2-ac3ae6896d84"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"All tests passed!\n"
]
}
],
"source": [
"def test_sample_noise():\n",
"    batch_size = 3\n",
"    dim = 4\n",
"    torch.manual_seed(231)\n",
"    z = sample_noise(batch_size, dim)\n",
"    np_z = z.cpu().numpy()\n",
"    assert np_z.shape == (batch_size, dim)\n",
"    assert torch.is_tensor(z)\n",
"    assert np.all(np_z >= -1.0) and np.all(np_z <= 1.0)\n",
"    assert np.any(np_z < 0.0) and np.any(np_z > 0.0)\n",
"    print('All tests passed!')\n",
"    \n",
"test_sample_noise()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SS_F4WRHdtok"
},
"source": [
"## Flatten"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"id": "AliSaexBdtok"
},
"outputs": [],
"source": [
"class Flatten(nn.Module):\n",
"    def forward(self, x):\n",
"        N, C, H, W = x.size() # read in N, C, H, W\n",
"        return x.view(N, -1)  # \"flatten\" the C * H * W values into a single vector per image\n",
"    \n",
"class Unflatten(nn.Module):\n",
"    \"\"\"\n",
"    An Unflatten module receives an input of shape (N, C*H*W) and reshapes it\n",
"    to produce an output of shape (N, C, H, W).\n",
"    \"\"\"\n",
"    def __init__(self, N=-1, C=128, H=7, W=7):\n",
"        super(Unflatten, self).__init__()\n",
"        self.N = N\n",
"        self.C = C\n",
"        self.H = H\n",
"        self.W = W\n",
"    def forward(self, x):\n",
"        return x.view(self.N, self.C, self.H, self.W)\n",
"\n",
"def initialize_weights(m):\n",
"    if isinstance(m, nn.Linear) or isinstance(m, nn.ConvTranspose2d):\n",
"        init.xavier_uniform_(m.weight.data)"
]
},
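`initialize_weights` applies `init.xavier_uniform_`. With the default gain of 1, this draws weights from $U(-a, a)$ with $a = \sqrt{6 / (\text{fan\_in} + \text{fan\_out})}$. For an `nn.Linear(in, out)` layer, fan_in is `in` and fan_out is `out`. The stdlib sketch below computes that bound for the layer shapes used in this notebook; the helper name is hypothetical.

```python
import math


def xavier_uniform_bound(fan_in, fan_out, gain=1.0):
    """Half-width a of the U(-a, a) range used by Xavier/Glorot uniform init."""
    return gain * math.sqrt(6.0 / (fan_in + fan_out))


# Linear layer shapes from the discriminator and generator (noise_dim = 96).
for fan_in, fan_out in [(784, 256), (256, 256), (256, 1), (96, 1024), (1024, 784)]:
    print(f"Linear({fan_in}, {fan_out}): a = {xavier_uniform_bound(fan_in, fan_out):.4f}")
```

To apply the initializer to every layer of a model, the usual PyTorch idiom is `model.apply(initialize_weights)`.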
{
"cell_type": "markdown",
"metadata": {
"id": "cjQipV5idton"
},
"source": [
"## CPU / GPU"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"id": "Ss-M5fZwdton"
},
"outputs": [],
"source": [
"# Use the GPU when one is available; otherwise fall back to the CPU.\n",
"dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "OcpYcDNLdtoq"
},
"source": [
"# Discriminator"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"id": "gT4rloGkdtor"
},
"outputs": [],
"source": [
"def discriminator():\n",
"    \"\"\"\n",
"    Build and return a PyTorch model implementing the architecture.\n",
"    \"\"\"\n",
"    model = nn.Sequential( Flatten(),\n",
"                           nn.Linear(784, 256),\n",
"                           nn.LeakyReLU(inplace=True),\n",
"                           nn.Linear(256,256),\n",
"                           nn.LeakyReLU(inplace=True),\n",
"                           nn.Linear(256,1)\n",
"                         )\n",
"    return model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0MbstME3dtot"
},
"source": [
"Test to make sure the number of parameters in the discriminator is correct:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"id": "APf0Nevndtot",
"outputId": "01740507-7878-47be-f656-be677d6079a7"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Correct number of parameters in discriminator.\n"
]
}
],
"source": [
"def test_discriminator(true_count=267009):\n",
"    model = discriminator()\n",
"    cur_count = count_params(model)\n",
"    if cur_count != true_count:\n",
"        print('Incorrect number of parameters in discriminator. Check your architecture.')\n",
"    else:\n",
"        print('Correct number of parameters in discriminator.')     \n",
"\n",
"test_discriminator()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "K03pVpqqdtow"
},
"source": [
"# Generator"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"id": "4GdEZ7grdtow"
},
"outputs": [],
"source": [
"def generator(noise_dim=NOISE_DIM):\n",
"    \"\"\"\n",
"    Build and return a PyTorch model implementing the architecture.\n",
"    \"\"\"\n",
"    model = nn.Sequential( nn.Linear(noise_dim,1024),\n",
"                           nn.ReLU(inplace=True),\n",
"                           nn.Linear(1024,1024),\n",
"                           nn.ReLU(inplace=True),\n",
"                           nn.Linear(1024,784),\n",
"                           nn.Tanh()\n",
"                         )\n",
"    return model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cBIjphBKdtoz"
},
"source": [
"Test to make sure the number of parameters in the generator is correct:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"id": "Lfc_zIWJdtoz",
"outputId": "0c1c13cb-7307-4a91-abc2-3460b3886cc8"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Correct number of parameters in generator.\n"
]
}
],
"source": [
"def test_generator(true_count=1858320):\n",
"    model = generator(4)\n",
"    cur_count = count_params(model)\n",
"    if cur_count != true_count:\n",
"        print('Incorrect number of parameters in generator. Check your architecture.')\n",
"    else:\n",
"        print('Correct number of parameters in generator.')\n",
"\n",
"test_generator()"
]
},
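The expected counts in both tests follow from two facts. An `nn.Linear(in, out)` layer has `in * out` weights plus `out` biases, and `Flatten`, `LeakyReLU`, `ReLU` and `Tanh` have no parameters. The stdlib check below uses a hypothetical helper name.

```python
def linear_params(fan_in, fan_out):
    """Weights plus biases of nn.Linear(fan_in, fan_out)."""
    return fan_in * fan_out + fan_out


# Discriminator: 784 -> 256 -> 256 -> 1
d_count = linear_params(784, 256) + linear_params(256, 256) + linear_params(256, 1)

# Generator as built by test_generator with noise_dim=4: 4 -> 1024 -> 1024 -> 784
g_count = linear_params(4, 1024) + linear_params(1024, 1024) + linear_params(1024, 784)

print(d_count)  # 267009
print(g_count)  # 1858320
```

With the default `NOISE_DIM = 96`, only the generator's first layer changes. It has 96 × 1024 + 1024 = 99328 parameters instead of 5120.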
{
"cell_type": "markdown",
"metadata": {
"id": "xnMGNozNdto2"
},
"source": [
"# GAN Loss\n",
"\n",
"Compute the generator and discriminator loss. The generator loss is:\n",
"$$\\ell_G = -\\mathbb{E}_{z \\sim p(z)}\\left[\\log D(G(z))\\right]$$\n",
"and the discriminator loss is:\n",
"$$\\ell_D = -\\mathbb{E}_{x \\sim p_\\text{data}}\\left[\\log D(x)\\right] - \\mathbb{E}_{z \\sim p(z)}\\left[\\log \\left(1-D(G(z))\\right)\\right]$$"
]
},
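Both losses are binary cross-entropies on the discriminator's raw scores. Write $D(\cdot) = \sigma(s(\cdot))$, where $s$ is the discriminator's score. Define $\mathrm{BCE}(s, y) = -y \log \sigma(s) - (1 - y) \log(1 - \sigma(s))$. A target of $y = 1$ keeps only the first term and $y = 0$ keeps only the second, so the two losses above can be rewritten as follows (expectations are estimated by minibatch means):

$$
\begin{aligned}
\ell_D &= \mathbb{E}_{x \sim p_\text{data}}\left[\mathrm{BCE}\left(s(x), 1\right)\right] + \mathbb{E}_{z \sim p(z)}\left[\mathrm{BCE}\left(s(G(z)), 0\right)\right], \\
\ell_G &= \mathbb{E}_{z \sim p(z)}\left[\mathrm{BCE}\left(s(G(z)), 1\right)\right].
\end{aligned}
$$

This is why `discriminator_loss` pairs real scores with a target of ones and fake scores with zeros, and why `generator_loss` pairs fake scores with ones.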
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"id": "9yu9yAO6dto2"
},
"outputs": [],
"source": [
"def bce_loss(input, target):\n",
"    \"\"\"\n",
"    Numerically stable binary cross-entropy loss computed from raw scores (logits).\n",
"\n",
"    Inputs:\n",
"    - input: PyTorch Tensor of shape (N, ) giving scores.\n",
"    - target: PyTorch Tensor of shape (N,) containing 0 and 1 giving targets.\n",
"\n",
"    Returns:\n",
"    - A PyTorch Tensor containing the mean BCE loss over the minibatch of input data.\n",
"    \"\"\"\n",
"    neg_abs = - input.abs()\n",
"    loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log()\n",
"    return loss.mean()"
]
},
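`bce_loss` uses the standard numerically stable rewrite of sigmoid cross-entropy. For a score $x$ and target $z$, $-z\log\sigma(x) - (1-z)\log(1-\sigma(x)) = \max(x, 0) - xz + \log(1 + e^{-|x|})$. The right-hand side never exponentiates a positive number, so it cannot overflow. The stdlib sketch below compares the two forms on scalars; the function names are hypothetical.

```python
import math


def bce_naive(x, z):
    """Textbook BCE on a logit x; overflows or hits log(0) for large |x|."""
    p = 1.0 / (1.0 + math.exp(-x))
    return -z * math.log(p) - (1.0 - z) * math.log(1.0 - p)


def bce_stable(x, z):
    """Same formula as bce_loss; log1p is a small precision refinement of log(1 + .)."""
    return max(x, 0.0) - x * z + math.log1p(math.exp(-abs(x)))


for x, z in [(-3.0, 1.0), (0.5, 0.0), (2.0, 1.0)]:
    print(x, z, bce_naive(x, z), bce_stable(x, z))

# For a very confident score, the naive form would evaluate log(0); the stable one does not.
print(bce_stable(1000.0, 0.0))  # 1000.0
```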
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"id": "AOCyZALXdto5"
},
"outputs": [],
"source": [
"def discriminator_loss(logits_real, logits_fake):\n",
"    \"\"\"\n",
"    Computes the discriminator loss described above.\n",
"    \n",
"    Inputs:\n",
"    - logits_real: PyTorch Tensor of shape (N,) giving scores for the real data.\n",
"    - logits_fake: PyTorch Tensor of shape (N,) giving scores for the fake data.\n",
"    \n",
"    Returns:\n",
"    - loss: PyTorch Tensor containing (scalar) the loss for the discriminator.\n",
"    \"\"\"\n",
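"    # ones_like / zeros_like match the logits' shape, dtype and device, so an\n",
"    # (N, 1) score tensor is never broadcast against an (N,) target into (N, N).\n",
"    loss = (bce_loss(logits_real, torch.ones_like(logits_real))\n",
"            + bce_loss(logits_fake, torch.zeros_like(logits_fake)))\n",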
"    return loss\n",
"\n",
"def generator_loss(logits_fake):\n",
"    \"\"\"\n",
"    Computes the generator loss described above.\n",
"\n",
"    Inputs:\n",
"    - logits_fake: PyTorch Tensor of shape (N,) giving scores for the fake data.\n",
"    \n",
"    Returns:\n",
"    - loss: PyTorch Tensor containing the (scalar) loss for the generator.\n",
"    \"\"\"\n",
"    # ones_like matches the logits' shape, dtype and device (no broadcasting surprises).\n",
"    loss = bce_loss(logits_fake, torch.ones_like(logits_fake))\n",
"    return loss"
]
},
{
"cell_type": "markdown",
"metadata": {
},
"source": [
"Check the generator and discriminator losses. The errors should be < 1e-7."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"id": "9qVTG21-dto7",
"outputId": "fd6dbf37-e87d-4e0a-9781-58de0fc98eea"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Maximum error in d_loss: 2.83811e-08\n"
]
}
],
"source": [
"def test_discriminator_loss(logits_real, logits_fake, d_loss_true):\n",
"    d_loss = discriminator_loss(torch.Tensor(logits_real).type(dtype),\n",
"                                torch.Tensor(logits_fake).type(dtype)).cpu().numpy()\n",
"    print(\"Maximum error in d_loss: %g\"%rel_error(d_loss_true, d_loss))\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"id": "AK2fPRgNdto-",
"outputId": "6b71b6d9-92e2-4206-a311-5cda0d1060d0"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Maximum error in g_loss: 3.4188e-08\n"
]
}
],
"source": [
"def test_generator_loss(logits_fake, g_loss_true):\n",
"    g_loss = generator_loss(torch.Tensor(logits_fake).type(dtype)).cpu().numpy()\n",
"    print(\"Maximum error in g_loss: %g\"%rel_error(g_loss_true, g_loss))\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hZ9a-AOgdtpA"
},
"source": [
"# Optimizing our loss"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"id": "sJeiH6ZJdtpA"
},
"outputs": [],
"source": [
"def get_optimizer(model):\n",
"    \"\"\"\n",
"    Construct and return an Adam optimizer for the model with learning rate 1e-3,\n",
"    beta1=0.5, and beta2=0.999.\n",
"    \n",
"    Input:\n",
"    - model: A PyTorch model that we want to optimize.\n",
"    \n",
"    Returns:\n",
"    - An Adam optimizer for the model with the desired hyperparameters.\n",
"    \"\"\"\n",
"    optimizer = optim.Adam(model.parameters(), lr = 1e-3, betas = (0.5,0.999))\n",
"    return optimizer"
]
},
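`get_optimizer` uses Adam with $\beta_1 = 0.5$ instead of PyTorch's default of 0.9. This shortens the momentum's memory: the first-moment average has an effective window of roughly $1/(1-\beta_1) = 2$ steps instead of 10. The stdlib sketch below is a scalar re-implementation of the textbook Adam update (Kingma and Ba), for illustration only. The toy objective $f(w) = (w - 3)^2$ is an assumption chosen for the demo.

```python
import math


class ScalarAdam:
    """Textbook Adam update for a single float parameter (illustration only)."""

    def __init__(self, lr=1e-3, beta1=0.5, beta2=0.999, eps=1e-8):
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps
        self.m = 0.0  # running mean of gradients (first moment)
        self.v = 0.0  # running mean of squared gradients (second moment)
        self.t = 0    # step counter, used for bias correction

    def step(self, param, grad):
        self.t += 1
        self.m = self.beta1 * self.m + (1 - self.beta1) * grad
        self.v = self.beta2 * self.v + (1 - self.beta2) * grad * grad
        m_hat = self.m / (1 - self.beta1 ** self.t)
        v_hat = self.v / (1 - self.beta2 ** self.t)
        return param - self.lr * m_hat / (math.sqrt(v_hat) + self.eps)


# Minimize f(w) = (w - 3)^2, whose gradient is 2 * (w - 3).
opt = ScalarAdam(lr=0.1)
w = 0.0
first = opt.step(w, 2 * (w - 3.0))
print(first)  # close to 0.1: the first step has size about lr, whatever the gradient's scale
w = first
for _ in range(500):
    w = opt.step(w, 2 * (w - 3.0))
print(w)
```

Adam divides by the running RMS of the gradient, so the very first step has magnitude close to `lr` regardless of how large the gradient is.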
{
"cell_type": "markdown",
"metadata": {
"id": "5eeMguyGdtpD"
},
"source": [
"# Training a GAN!"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"id": "0Qj-KOMBdtpE"
},
"outputs": [],
"source": [
"def run_a_gan(D, G, D_solver, G_solver, discriminator_loss, generator_loss, show_every=250, \n",
"              batch_size=128, noise_size=96, num_epochs=10):\n",
"    \"\"\"\n",
"    Train a GAN!\n",
"    \n",
"    Inputs:\n",
"    - D, G: PyTorch models for the discriminator and generator\n",
"    - D_solver, G_solver: torch.optim Optimizers to use for training the\n",
"      discriminator and generator.\n",
"    - discriminator_loss, generator_loss: Functions to use for computing the generator and\n",
"      discriminator loss, respectively.\n",
"    - show_every: Show samples after every show_every iterations.\n",
"    - batch_size: Batch size to use for training.\n",
"    - noise_size: Dimension of the noise to use as input to the generator.\n",
"    - num_epochs: Number of epochs over the training dataset to use for training.\n",
"    \"\"\"\n",
"    iter_count = 0\n",
"    for epoch in range(num_epochs):\n",
"        for x, _ in loader_train:\n",
"            if len(x) != batch_size:\n",
"                continue\n",
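"            # Discriminator update: clear stale gradients first.\n",
"            D_solver.zero_grad()\n",
"            real_data = x.type(dtype)\n",
"            # Rescale real pixels from [0, 1] to [-1, 1] to match the generator's tanh range.\n",
"            logits_real = D(2 * (real_data - 0.5)).type(dtype)\n",
"\n",
"            g_fake_seed = sample_noise(batch_size, noise_size).type(dtype)\n",
"            fake_images = G(g_fake_seed).detach()  # detach: no gradients flow into G here\n",
"            logits_fake = D(fake_images.view(batch_size, 1, 28, 28))\n",
"\n",
"            d_total_error = discriminator_loss(logits_real, logits_fake)\n",
"            d_total_error.backward()\n",
"            D_solver.step()\n",
"\n",
"            # Generator update: clear G's gradients, then backprop through D into G.\n",
"            # (Gradients this leaves in D are cleared by D_solver.zero_grad() next iteration.)\n",
"            G_solver.zero_grad()\n",
"            g_fake_seed = sample_noise(batch_size, noise_size).type(dtype)\n",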
"            fake_images = G(g_fake_seed)\n",
"\n",
"            gen_logits_fake = D(fake_images.view(batch_size, 1, 28, 28))\n",
"            g_error = generator_loss(gen_logits_fake)\n",
"            g_error.backward()\n",
"            G_solver.step()\n",
"\n",
"            if (iter_count % show_every == 0):\n",
"                print('Iter: {}, D: {:.4}, G:{:.4}'.format(iter_count,d_total_error.item(),g_error.item()))\n",
"                imgs_numpy = fake_images.data.cpu().numpy()\n",
"                show_images(imgs_numpy[0:16])\n",
"                plt.show()\n",
"                print()\n",
"            iter_count += 1"
]
},
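`run_a_gan` skips any minibatch whose size differs from `batch_size`. With `NUM_TRAIN = 50000` and `batch_size = 128`, each epoch therefore runs 390 full batches and drops the final 80 images. Over the default 10 epochs that gives 3900 iterations. With `show_every=250`, samples are printed at iterations 0, 250, ..., 3750, which matches the log of the training run. A quick stdlib check of this arithmetic:

```python
NUM_TRAIN = 50000
batch_size = 128
num_epochs = 10
show_every = 250

full_batches, dropped = divmod(NUM_TRAIN, batch_size)  # the short last batch is skipped
total_iters = full_batches * num_epochs
report_iters = list(range(0, total_iters, show_every))

print(full_batches, dropped)                                 # 390 80
print(total_iters)                                           # 3900
print(report_iters[0], report_iters[-1], len(report_iters))  # 0 3750 16
```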
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "B9miV1qfdtpG",
"outputId": "dbb0c084-b3f3-4983-ecb8-bee24e7d8dac",
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Iter: 0, D: 1.328, G:0.7202\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 288x288 with 16 Axes>"
]
},
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Iter: 250, D: 1.43, G:0.6752\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 288x288 with 16 Axes>"
]
},
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Iter: 500, D: 1.181, G:1.414\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 288x288 with 16 Axes>"
]
},
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Iter: 750, D: 1.204, G:1.556\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 288x288 with 16 Axes>"
]
},
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Iter: 1000, D: 1.174, G:1.126\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 288x288 with 16 Axes>"
]
},
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Iter: 1250, D: 1.255, G:1.068\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 288x288 with 16 Axes>"
]
},
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Iter: 1500, D: 1.136, G:0.971\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 288x288 with 16 Axes>"
]
},
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Iter: 1750, D: 1.317, G:0.7927\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 288x288 with 16 Axes>"
]
},
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Iter: 2000, D: 1.274, G:0.9762\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 288x288 with 16 Axes>"
]
},
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Iter: 2250, D: 1.258, G:0.9521\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 288x288 with 16 Axes>"
]
},
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Iter: 2500, D: 1.202, G:0.833\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 288x288 with 16 Axes>"
]
},
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Iter: 2750, D: 1.288, G:0.8659\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 288x288 with 16 Axes>"
]
},
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Iter: 3000, D: 1.379, G:0.824\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 288x288 with 16 Axes>"
]
},
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Iter: 3250, D: 1.392, G:0.8353\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 288x288 with 16 Axes>"
]
},
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Iter: 3500, D: 1.296, G:0.8011\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 288x288 with 16 Axes>"
]
},
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Iter: 3750, D: 1.221, G:0.841\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 288x288 with 16 Axes>"
]
},
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# Make the discriminator\n",
"D = discriminator().type(dtype)\n",
"\n",
"# Make the generator\n",
"G = generator().type(dtype)\n",
"\n",
"# Use the function you wrote earlier to get optimizers for the Discriminator and the Generator\n",
"D_solver = get_optimizer(D)\n",
"G_solver = get_optimizer(G)\n",
"# Run it!\n",
"run_a_gan(D, G, D_solver, G_solver, discriminator_loss, generator_loss)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vnLQHE1VdtpJ"
},
"source": [
"In the first few hundred iterations you should see black backgrounds. As you approach iteration 1000 the shapes should become fuzzy, and past iteration 3000 they should be decent, with about half of them sharp and clearly recognizable."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RAITXp5ZdtpK"
},
"source": [
"# Least Squares GAN\n",
"We'll now look at [Least Squares GAN](https://arxiv.org/abs/1611.04076), a newer, more stable alternative to the original GAN loss function. For this part, all we have to do is change the loss function and retrain the model. We'll implement equation (9) in the paper, with the generator loss:\n",
"$$\\ell_G = \\frac{1}{2}\\mathbb{E}_{z \\sim p(z)}\\left[\\left(D(G(z))-1\\right)^2\\right]$$\n",
"and the discriminator loss:\n",
"$$\\ell_D = \\frac{1}{2}\\mathbb{E}_{x \\sim p_\\text{data}}\\left[\\left(D(x)-1\\right)^2\\right] + \\frac{1}{2}\\mathbb{E}_{z \\sim p(z)}\\left[ \\left(D(G(z))\\right)^2\\right]$$"
]
},
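Unlike the BCE losses, the least-squares losses act directly on the raw scores, with no sigmoid. The stdlib sketch below applies equation (9) to plain lists, with expectations replaced by minibatch means as in the PyTorch code. The function names are hypothetical.

```python
def ls_d_loss(scores_real, scores_fake):
    """0.5 * mean((D(x) - 1)^2) + 0.5 * mean(D(G(z))^2)."""
    real_term = sum((s - 1.0) ** 2 for s in scores_real) / len(scores_real)
    fake_term = sum(s ** 2 for s in scores_fake) / len(scores_fake)
    return 0.5 * real_term + 0.5 * fake_term


def ls_g_loss(scores_fake):
    """0.5 * mean((D(G(z)) - 1)^2)."""
    return 0.5 * sum((s - 1.0) ** 2 for s in scores_fake) / len(scores_fake)


# A perfect discriminator (real -> 1, fake -> 0) has zero loss,
# while the generator pays 0.5 until D scores its samples as 1.
print(ls_d_loss([1.0, 1.0], [0.0, 0.0]))  # 0.0
print(ls_g_loss([0.0, 0.0]))              # 0.5
print(ls_g_loss([1.0, 1.0]))              # 0.0
```

For each sample, the generator's gradient with respect to the score is $D(G(z)) - 1$, up to the $1/N$ factor. It grows linearly with the distance from the target, so it does not saturate the way $\log(1 - D(G(z)))$ does when the discriminator is confident.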
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"id": "nuOMD1TWdtpK"
},
"outputs": [],
"source": [
"def ls_discriminator_loss(scores_real, scores_fake):\n",
"    \"\"\"\n",
"    Compute the Least-Squares GAN loss for the discriminator.\n",
"    \n",
"    Inputs:\n",
"    - scores_real: PyTorch Tensor of shape (N,) giving scores for the real data.\n",
"    - scores_fake: PyTorch Tensor of shape (N,) giving scores for the fake data.\n",
"    \n",
"    Outputs:\n",
"    - loss: A PyTorch Tensor containing the loss.\n",
"    \"\"\"\n",
"    # Subtracting the scalar target 1 broadcasts cleanly for both (N,) and (N, 1) scores.\n",
"    loss = 0.5 * torch.mean((scores_real - 1) ** 2) + 0.5 * torch.mean(scores_fake ** 2)\n",
"    return loss\n",
"\n",
"def ls_generator_loss(scores_fake):\n",
"    \"\"\"\n",
"    Computes the Least-Squares GAN loss for the generator.\n",
"    \n",
"    Inputs:\n",
"    - scores_fake: PyTorch Tensor of shape (N,) giving scores for the fake data.\n",
"    \n",
"    Outputs:\n",
"    - loss: A PyTorch Tensor containing the loss.\n",
"    \"\"\"\n",
"    loss = 0.5 * torch.mean((scores_fake - 1) ** 2)\n",
"    return loss"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "krGClF97dtpM"
},
"source": [
"Before running a GAN with our new loss function, let's check it:"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 52
},
"id": "Wo_nel7-dtpM",
"outputId": "101f7a56-1b28-4236-8633-4171ae4283ee"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Maximum error in d_loss: 1.64377e-08\n",
"Maximum error in g_loss: 2.7837e-09\n"
]
}
],
"source": [
"def test_lsgan_loss(score_real, score_fake, d_loss_true, g_loss_true):\n",
"    score_real = torch.Tensor(score_real).type(dtype)\n",
"    score_fake = torch.Tensor(score_fake).type(dtype)\n",
"    d_loss = ls_discriminator_loss(score_real, score_fake).cpu().numpy()\n",
"    g_loss = ls_generator_loss(score_fake).cpu().numpy()\n",
"    print(\"Maximum error in d_loss: %g\"%rel_error(d_loss_true, d_loss))\n",
"    print(\"Maximum error in g_loss: %g\"%rel_error(g_loss_true, g_loss))\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "q82122yedtpO"
},
"source": [
"Run the following cell to train the model!"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "htEifHj5dtpP",
"outputId": "3f4566d9-02c7-4948-db13-ba5934e86437",
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Iter: 0, D: 0.5689, G:0.51\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 288x288 with 16 Axes>"
]
},
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Iter: 250, D: 0.1481, G:0.3264\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 288x288 with 16 Axes>"
]
},
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Iter: 500, D: 0.2063, G:0.4708\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 288x288 with 16 Axes>"
]
},
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Iter: 750, D: 0.1258, G:0.2649\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 288x288 with 16 Axes>"
]
},
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Iter: 1000, D: 0.152, G:0.4361\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 288x288 with 16 Axes>"
]
},
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Iter: 1250, D: 0.1842, G:0.2598\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 288x288 with 16 Axes>"
]
},
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Iter: 1500, D: 0.1986, G:0.2422\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 288x288 with 16 Axes>"
]
},
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Iter: 1750, D: 0.2018, G:0.2362\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 288x288 with 16 Axes>"
]
},
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Iter: 2000, D: 0.2339, G:0.1912\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 288x288 with 16 Axes>"
]
},
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Iter: 2250, D: 0.2559, G:0.2198\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 288x288 with 16 Axes>"
]
},
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Iter: 2500, D: 0.2503, G:0.1511\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 288x288 with 16 Axes>"
]
},
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Iter: 2750, D: 0.2112, G:0.1597\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 288x288 with 16 Axes>"
]
},
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Iter: 3000, D: 0.2393, G:0.1796\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 288x288 with 16 Axes>"
]
},
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Iter: 3250, D: 0.2336, G:0.1621\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 288x288 with 16 Axes>"
]
},
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Iter: 3500, D: 0.2206, G:0.1707\n"
]
},
{
"data": {
"image/png": "