The Algorithms logo
The Algorithms
关于捐赠

CNN Pytorch

H
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "wq_TKqjUmILg"
   },
   "source": [
    "# Convolutional Neural Network (CNN)\n",
    "\n",
    "## Resources\n",
    "\n",
    "- CNN: https://en.wikipedia.org/wiki/Convolutional_neural_network\n",
    "- PyTorch: https://pytorch.org/tutorials/beginner/basics/intro.html"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Start by loading the MNIST data with the datasets provided by torchvision and wrapping the training and testing sets in DataLoaders. You can set the training and testing batch sizes to whatever you feel is best."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "qGEvJYHnmILh"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torchvision import datasets, transforms\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data.dataloader import DataLoader\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 437,
     "referenced_widgets": [
      "f1ae475d1e48411ab7cb9f49f6f673c9",
      "6e770d6f5e5f4101b2c985f8182eb79e",
      "d784c3bdbaf543a299447b17b500c2a8",
      "51567646aecc4126a3b9cb96f97d5be5",
      "f65a4ec103c34c7186fb76d3e3507795",
      "af12215b0ba24c699e1111c132dee3fe",
      "908f60d17e1b49869ac3d4ee53a74a6e",
      "f6af65953a9841e1907a47fb27c01b96",
      "44f9d2dde4424058a337a1e3d585b5fe",
      "037faad73a8a40499555406f4e586731",
      "4b80b4415ce34130a048933b2df179a5",
      "c52318a206de4565a725388bc50bdfc0",
      "d054c2e9330c433ab6bb900c6fa7dac6",
      "7c8cf011ed684ef0841d41da76c9bfc8",
      "14d009d293864f9f85bd16b5e1b6b381",
      "553cf4ae8edb4b48a58ab4446036ae87",
      "a1dd0d1a7fb14235a271992bea0d233c",
      "b9fa1995789a49fdb6a0c3e8505733de",
      "d2e31ba6e42448839959f523cc56dbcb",
      "5ecbc7b6eb6544709da61283dfc8d3c6",
      "c75452bfc2264392bf3d842bfbd4eeee",
      "0554687d97424798a86ea0a4c56cdbf8",
      "2ee0648f054c49049fdf5d6ac81ec086",
      "4be171eff56046c2a401bf02b6d704c9",
      "c8c6df38201b470bb06bffc677873f93",
      "f51f06fbe28746dea09e075256e29451",
      "7c90cc0c4c3b4cd89dbdfffdb78d463d",
      "9ee4f57a58c14bb1880789196ff7ef63",
      "74d1d431fbb84c4f948ca510e397f4e2",
      "90f550d2f7344eb894208d0373ac6f8d",
      "7837b9def9dd470cb57f96b38afbfc18",
      "b3105726704240439adcf7f13bd48cca",
      "54ce0f87fc1f46d2b9f9850a393b46cf",
      "1a89067cd68c41fb96b74e0ebf3d1931",
      "91798827c9254ffc969228862c8ee37f",
      "cdbc8de66e8c4fa79100dcac1fcfed4d",
      "2773f678ca4a45158df219d2672fa646",
      "9f7585b1023d4e20ba3649efcfcdb881",
      "3d3af4241b714b9a87253ac87b2e31b1",
      "84bda6e8c38b4b669e2825c962e9dbfb",
      "780790f42e764ec5a0b8d441a290e5d3",
      "47b04fa0a67c4fbbbcdc3d1062430259",
      "480e369fb79f4326bd13d2355fddd890",
      "898cdc388a2d428ca742a9ba2365df20"
     ]
    },
    "id": "N_61-p6ymILj",
    "outputId": "765b9e09-2ce2-47ab-eac2-6978fa820170"
   },
   "outputs": [],
   "source": [
    "# Fetch the MNIST training split through torchvision; it is downloaded into ./data\n",
    "# when it is not already present. ToTensor() turns every image into a tensor.\n",
    "dataset = datasets.MNIST(\n",
    "    \"./data\", train=True, transform=transforms.ToTensor(), download=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "id": "UlRAhZOpmILk"
   },
   "outputs": [],
   "source": [
    "# Splitting the dataset into training and testing set.\n",
    "# A seeded generator makes the split identical on every run, so the reported\n",
    "# accuracies can be reproduced after restarting the kernel.\n",
    "SPLIT_SEED = 42\n",
    "train_dataset, test_dataset = torch.utils.data.random_split(\n",
    "    dataset, [50000, 10000], generator=torch.Generator().manual_seed(SPLIT_SEED)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 332
    },
    "id": "iQWzZCRFmILl",
    "outputId": "d2992c28-d667-4653-bbc2-9b323c82def0"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/anaconda3/lib/python3.8/site-packages/torchvision/datasets/mnist.py:52: UserWarning: train_labels has been renamed targets\n",
      "  warnings.warn(\"train_labels has been renamed targets\")\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Text(0.5, 1.0, 'Label : 3')"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAEICAYAAACZA4KlAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAAQCklEQVR4nO3dfbBU9X3H8fcHuII8OIIEgygi1PhQTUi8UUdpa2Jj1WnUTEYb2jS0NWKTaOuEOLW2HR3bTh1TJbHVGKhEND5OwZEkTBolbY1jtF6VCD6hUowIgoagQCry8O0fe2gvePfsZc/ZPcv9fV4zO7t3v+fhy3I/9+zub8/+FBGY2cA3qOoGzKw9HHazRDjsZolw2M0S4bCbJcJhN0uEw54gSf8h6YvtXteq5bDvwyStkvTbVffRX5Iuk7RS0juS1kiaLWlI1X2lwmG3dvoe8LGIOAA4DvgI8GfVtpQOh30AkjRa0vclvSnpl9ntQ/dYbIqk/5L0tqQHJI3ptf7Jkh6VtFHSzySdVkZfEfFKRGzctRtgJ/BrZWzbGnPYB6ZBwHeAw4GJwP8A/7zHMl8A/gQ4BNgO3AggaQLwA+DvgDHA14AFkj7QaKeSpkna2GCZ35f0DvAWtSP7t/v9r7JCHPYBKCJ+ERELIuJXEbEJ+Hvgt/ZY7I6IWB4RW4C/AS6QNBj4PLA4IhZHxM6IeBDoAc7ux34fiYgDGyxzV/Y0/kPALcC6vf4HWlMc9gFI0nBJ35b0anYUfRg4MAvzLq/1uv0q0AWMpfZs4PzsKfzG7Eg9DRhfZo8R8RLwLHBzmdu1+vxO6MA0CzgKOCki3pA0FXia2uvkXQ7rdXsisI3aU+vXqB31L2pDn0OAKW3Yj+Ej+0DQJWlYr8sQYBS11+kbszferupjvc9LOlbScOAa4F8jYgfwXeDTkn5H0uBsm6f18QbfXpP0RUnjstvHAn8JLCm6Xesfh33ft5hasHddrga+AexP7Uj9GPDDPta7A7gNeAMYRjYEFhGvAecCVwJvUjvSX04/flck/YakzTmLnAosk7Ql63txth9rA/nLK8zS4CO7WSIcdrNEOOxmiXDYzRLR1nH2/TQ0hjGinbs0S8q7bOG92Kq+aoXCLulM4JvAYOBfIuLavOWHMYKTdHqRXZpZjsej/scWmn4an3308ibgLOBYYHr2QQkz60BFXrOfCLwcESsj4j3gHmofxjCzDlQk7BPY/WSK1dl9u5E0U1KPpJ5tbC2wOzMrokjY+3oT4H0fx4uIORHRHRHdXQwtsDszK6JI2Fez+5lThwJrirVjZq1SJOxPAEdKOkLSfsDngEXltGVmZWt66C0itku6BPg3akNv8yLi2dI6M7NSFRpnj4hdpymaWYfzx2XNEuGwmyXCYTdLhMNulgiH3SwRDrtZIhx2s0Q47GaJcNjNEuGwmyXCYTdLhMNulgiH3SwRDrtZIhx2s0Q47GaJcNjNEuGwmyXCYTdLhMNulgiH3SwRbZ2y2Vrk5A/XLf33OflTZF/12fty6zesyJ91d9Oyg3LreaZc83Rufee77za9bXs/H9nNEuGwmyXCYTdLhMNulgiH3SwRDrtZIhx2s0R4nH0f8PoVp+TWF3/5urq1iUNGFtr3H5yQPw7PCc1ve9qTF+fWRyx4vPmN2/sUCrukVcAmYAewPSK6y2jKzMpXxpH9ExHxVgnbMbMW8mt2s0QUDXsAP5L0pKSZfS0gaaakHkk929hacHdm1qyiT+NPjYg1ksYBD0p6ISIe7r1ARMwB5gAcoDFRcH9m1qRCR/aIWJNdrwfuB04soykzK1/TYZc0QtKoXbeBM4DlZTVmZuUq8jT+YOB+Sbu2c1dE/LCUrmw3h89fmVtfM3P/urWJHfxJirnXz86tXzjkq7n1Ufc+VmY7A17TvwoRsRL4SIm9mFkLeejNLBEOu1kiHHazRDjsZolw2M0S0cEDM7bL9rVv5NYvnHtp3dpDX6p/
+ivA+AanwC7aMjy3fs6IX+XW8xyzX/62135qe2591L1N7zpJPrKbJcJhN0uEw26WCIfdLBEOu1kiHHazRDjsZonwOPsAcOg/PFq39p3p+d/1fOXYF3PrL2/9YP7OR+SfflvE0Tduzq3vbNmeByYf2c0S4bCbJcJhN0uEw26WCIfdLBEOu1kiHHazRHicfYBb+E+fzK3vvFS59b8e+0KZ7eyVncO6Ktv3QOQju1kiHHazRDjsZolw2M0S4bCbJcJhN0uEw26WCI+zD3AHzf1pbv2nDx2VW//697bl1i8f88pe99Rfm6/ZklsfeWbLdj0gNTyyS5onab2k5b3uGyPpQUkvZdejW9ummRXVn6fxtwF7/g29AlgSEUcCS7KfzayDNQx7RDwMbNjj7nOB+dnt+cB55bZlZmVr9g26gyNiLUB2Pa7egpJmSuqR1LONrU3uzsyKavm78RExJyK6I6K7i6Gt3p2Z1dFs2NdJGg+QXa8vryUza4Vmw74ImJHdngE8UE47ZtYqDcfZJd0NnAaMlbQauAq4FrhP0oXAz4HzW9mkNW/9Jafk1jcelz8H+qLR9zfYQ+teCW54LP8760fSuu+sH4gahj0iptcpnV5yL2bWQv64rFkiHHazRDjsZolw2M0S4bCbJcKnuO4D9PHjc+vnzf9x3doXDvhG7rrDB+3XYO/VHQ8mLdzzlIzdecrmveMju1kiHHazRDjsZolw2M0S4bCbJcJhN0uEw26WCI+z7wN+cfzI3PrvjXqpbm34oOFlt9M2L87K7/3IGbll24OP7GaJcNjNEuGwmyXCYTdLhMNulgiH3SwRDrtZIjzOvg8YMy9/2uVTDv1a3dpPLvp67rpjB49oqqd2GH/wxqpbGFB8ZDdLhMNulgiH3SwRDrtZIhx2s0Q47GaJcNjNEuFx9gFg4jWP1q19+uVZueu+e2Cxv/fR4Ddowazr6tamdOWfp2/lavg/LWmepPWSlve672pJr0taml3Obm2bZlZUf/6s3wac2cf9syNianZZXG5bZla2hmGPiIeB/Hl4zKzjFXnBdomkZ7Kn+aPrLSRppqQeST3b2Fpgd2ZWRLNh/xYwBZgKrAWur7dgRMyJiO6I6O5iaJO7M7Oimgp7RKyLiB0RsROYC5xYbltmVramwi5pfK8fPwMsr7esmXWGhuPsku4GTgPGSloNXAWcJmkqEMAq4OLWtWhFHHDXY/n1ojuQcstnTK5/rv0rF9ySu+6Xj/jP3Pqdx56eW9/x3Ircemoahj0ipvdx960t6MXMWsgflzVLhMNulgiH3SwRDrtZIhx2s0T4FFcrZND+++fWGw2v5dm0Y1j+Att3NL3tFPnIbpYIh90sEQ67WSIcdrNEOOxmiXDYzRLhsJslwuPsVsgLs3+9wRL1v+a6kdkLz8mtT1qRP5W17c5HdrNEOOxmiXDYzRLhsJslwmE3S4TDbpYIh90sER5n76chEw6pW3vv9sG567618LDc+ribmh+LbrUhkyfl1h86c3aDLTQ/LfPk+36ZW9/Z9JbT5CO7WSIcdrNEOOxmiXDYzRLhsJslwmE3S4TDbpaI/kzZfBhwO/BBakObcyLim5LGAPcCk6hN23xBROQPjO7D1txcf3Ljp4+5J3fdOZfUH6MH+O7rv5tbH7Fqc25959Ln6ta2f/KE3HU3HD00t/7ZP/1xbn1KV/Pj6Ed8/6Lc+tGv1P932d7rz5F9OzArIo4BTga+IulY4ApgSUQcCSzJfjazDtUw7BGxNiKeym5vAp4HJgDnAvOzxeYD57WoRzMrwV69Zpc0Cfgo8DhwcESshdofBGBc6d2ZWWn6HXZJI4EFwGUR8c5erDdTUo+knm1sbaZHMytBv8IuqYta0O+MiIXZ3eskjc/q44H1fa0bEXMiojsiurvIfzPIzFqnYdglCbgVeD4ibuhVWgTMyG7PAB4ovz0zK4siIn8BaRrwE2AZ/39W4ZXUXrffB0wEfg6cHxEb8rZ1gMbESTq9aM+V
2HrWx+vWPvy3S3PXvfGQJwrte8Hm+sN+ALe+Pq1u7abJ9+Wue0SBoTOAHZF/ouktbx9et/aDUybnb3vj2031lLLHYwnvxAb1VWs4zh4RjwB9rgzsm8k1S5A/QWeWCIfdLBEOu1kiHHazRDjsZolw2M0S0XCcvUz78jh7nhVz64/BAwxf2ZVbf/bSm8tsp62eee/d3Prlk05uUycG+ePsPrKbJcJhN0uEw26WCIfdLBEOu1kiHHazRDjsZonwlM0l+NBF+eerDxo+PLd+1MgvFdr/iOPrf43AU933Ftr2im1bcutf/eNLc+uDearQ/q08PrKbJcJhN0uEw26WCIfdLBEOu1kiHHazRDjsZonw+exmA4jPZzczh90sFQ67WSIcdrNEOOxmiXDYzRLhsJslomHYJR0m6d8lPS/pWUl/nt1/taTXJS3NLme3vl0za1Z/vrxiOzArIp6SNAp4UtKDWW12RPxj69ozs7I0DHtErAXWZrc3SXoemNDqxsysXHv1ml3SJOCjwOPZXZdIekbSPEmj66wzU1KPpJ5tbC3WrZk1rd9hlzQSWABcFhHvAN8CpgBTqR35r+9rvYiYExHdEdHdxdDiHZtZU/oVdkld1IJ+Z0QsBIiIdRGxIyJ2AnOBE1vXppkV1Z934wXcCjwfETf0un98r8U+Aywvvz0zK0t/3o0/FfhDYJmkpdl9VwLTJU0FAlgFXNyC/sysJP15N/4RoK/zYxeX346ZtYo/QWeWCIfdLBEOu1kiHHazRDjsZolw2M0S4bCbJcJhN0uEw26WCIfdLBEOu1kiHHazRDjsZolw2M0S0dYpmyW9Cbza666xwFtta2DvdGpvndoXuLdmldnb4RHxgb4KbQ37+3Yu9UREd2UN5OjU3jq1L3BvzWpXb34ab5YIh90sEVWHfU7F+8/Tqb11al/g3prVlt4qfc1uZu1T9ZHdzNrEYTdLRCVhl3SmpBclvSzpiip6qEfSKknLsmmoeyruZZ6k9ZKW97pvjKQHJb2UXfc5x15FvXXENN4504xX+thVPf1521+zSxoMrAA+BawGngCmR8RzbW2kDkmrgO6IqPwDGJJ+E9gM3B4Rx2X3XQdsiIhrsz+UoyPiLzqkt6uBzVVP453NVjS+9zTjwHnAH1HhY5fT1wW04XGr4sh+IvByRKyMiPeAe4BzK+ij40XEw8CGPe4+F5if3Z5P7Zel7er01hEiYm1EPJXd3gTsmma80scup6+2qCLsE4DXev28ms6a7z2AH0l6UtLMqpvpw8ERsRZqvzzAuIr72VPDabzbaY9pxjvmsWtm+vOiqgh7X1NJddL436kR8THgLOAr2dNV659+TePdLn1MM94Rmp3+vKgqwr4aOKzXz4cCayroo08RsSa7Xg/cT+dNRb1u1wy62fX6ivv5P500jXdf04zTAY9dldOfVxH2J4AjJR0haT/gc8CiCvp4H0kjsjdOkDQCOIPOm4p6ETAjuz0DeKDCXnbTKdN415tmnIofu8qnP4+Itl+As6m9I/8K8FdV9FCnr8nAz7LLs1X3BtxN7WndNmrPiC4EDgKWAC9l12M6qLc7gGXAM9SCNb6i3qZRe2n4DLA0u5xd9WOX01dbHjd/XNYsEf4EnVkiHHazRDjsZolw2M0S4bCbJcJhN0uEw26WiP8Fvji1zrt7lZQAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Visualizing a sample from dataset\n",
    "# `train_dataset` comes from random_split, so `.dataset` is the full MNIST object;\n",
    "# index 10 therefore refers to the underlying dataset, not to the training subset.\n",
    "sample_index = 10\n",
    "mnist_full = train_dataset.dataset\n",
    "plt.imshow(mnist_full.data[sample_index])\n",
    "# `targets` replaces the deprecated `train_labels` attribute, which emits a UserWarning.\n",
    "# The trailing semicolon keeps the Text(...) repr out of the cell output.\n",
    "plt.title(\"Label : \" + str(mnist_full.targets[sample_index].item()));"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "id": "9WhgIZF0mILl"
   },
   "outputs": [],
   "source": [
    "# Creating a DataLoader for training and testing\n",
    "train = DataLoader(train_dataset, batch_size=32, shuffle=True)\n",
    "# Evaluation does not need shuffling, and single-sample batches would cost 10,000\n",
    "# separate forward passes per epoch. 1000 divides the 10,000 held-out samples evenly,\n",
    "# so every batch has the same size and a mean of per-batch accuracies stays exact.\n",
    "test = DataLoader(test_dataset, batch_size=1000, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "pWhZXmm6mILm"
   },
   "source": [
    "Define a network with the following architecture:\n",
    "\n",
    "Conv2d (input channels = 1, output channels = 15, kernel size = 5)\n",
    "$\\rightarrow$\n",
    "MaxPool (kernel size = 2)\n",
    "$\\rightarrow$\n",
    "ReLU\n",
    "$\\rightarrow$\n",
    "Conv2d (input channels = 15, output channels = 30, kernel size = 5)\n",
    "$\\rightarrow$\n",
    "Dropout2d (p = 0.5)\n",
    "$\\rightarrow$\n",
    "MaxPool (kernel size = 2)\n",
    "$\\rightarrow$\n",
    "ReLU\n",
    "$\\rightarrow$\n",
    "Linear(input dimension = 480, hidden units = 64)\n",
    "$\\rightarrow$\n",
    "ReLU\n",
    "$\\rightarrow$\n",
    "Dropout (p=0.5)\n",
    "$\\rightarrow$\n",
    "Linear (input dimension = 64, output units = 10)\n",
    "$\\rightarrow$\n",
    "LogSoftmax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "id": "_-RTFbeKmILn"
   },
   "outputs": [],
   "source": [
    "class CNN(nn.Module):\n",
    "    \"\"\"Two-stage convolutional classifier that outputs log-probabilities for 10 classes.\n",
    "\n",
    "    The first linear layer takes 480 flattened features, which is what the two\n",
    "    convolution/pooling stages produce for a 1x28x28 input (30 channels of 4x4).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        # Convolutional feature extractor: conv -> pool -> ReLU, twice, with channel\n",
    "        # dropout after the second convolution.\n",
    "        feature_layers = [\n",
    "            nn.Conv2d(in_channels=1, out_channels=15, kernel_size=5),\n",
    "            nn.MaxPool2d(kernel_size=2, stride=2),\n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(in_channels=15, out_channels=30, kernel_size=5),\n",
    "            nn.Dropout2d(p=0.5),\n",
    "            nn.MaxPool2d(kernel_size=2, stride=2),\n",
    "            nn.ReLU(),\n",
    "        ]\n",
    "\n",
    "        # Fully connected head ending in log-softmax over the class dimension.\n",
    "        classifier_layers = [\n",
    "            nn.Flatten(),\n",
    "            nn.Linear(in_features=480, out_features=64),\n",
    "            nn.ReLU(),\n",
    "            nn.Dropout(p=0.5),\n",
    "            nn.Linear(in_features=64, out_features=10),\n",
    "            nn.LogSoftmax(dim=1),\n",
    "        ]\n",
    "\n",
    "        # A single Sequential holds all layers, in the order listed above.\n",
    "        self.cnn = nn.Sequential(*feature_layers, *classifier_layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the per-class log-probabilities computed for the batch `x`.\"\"\"\n",
    "        log_probs = self.cnn(x)\n",
    "        return log_probs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "jzwFIYidmILo"
   },
   "source": [
    "Train the network you defined in the previous question on MNIST, using the optimizer and the number of training epochs you deem appropriate. Use a cross-entropy loss. After each epoch, evaluate your model on the testing dataset and print the accuracy that you achieve.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "-hdIBNSzmILp",
    "outputId": "101815c2-c1c6-441b-bd05-ce08b702c044"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/anaconda3/lib/python3.8/site-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at  ../c10/core/TensorImpl.h:1156.)\n",
      "  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch :0 Loss : 0.44618519723548844 Train Accuracy:0.8599048256874084 Test Accuracy : 0.9340000152587891\n",
      "Epoch :1 Loss : 0.20212347584169404 Train Accuracy:0.9430182576179504 Test Accuracy : 0.9490000009536743\n",
      "Epoch :2 Loss : 0.16394136824853056 Train Accuracy:0.952215313911438 Test Accuracy : 0.9531000256538391\n",
      "Epoch :3 Loss : 0.14257464571382256 Train Accuracy:0.9595929384231567 Test Accuracy : 0.9588000178337097\n",
      "Epoch :4 Loss : 0.12598471959617277 Train Accuracy:0.9644113779067993 Test Accuracy : 0.9648000001907349\n",
      "Epoch :5 Loss : 0.11733871379403024 Train Accuracy:0.9659308791160583 Test Accuracy : 0.9660000205039978\n",
      "Epoch :6 Loss : 0.11000267015220401 Train Accuracy:0.9681901931762695 Test Accuracy : 0.9664999842643738\n",
      "Epoch :7 Loss : 0.10582590816269605 Train Accuracy:0.9684301018714905 Test Accuracy : 0.9631999731063843\n",
      "Epoch :8 Loss : 0.09624670598793283 Train Accuracy:0.9711892008781433 Test Accuracy : 0.9706000089645386\n",
      "Epoch :9 Loss : 0.09422491162643551 Train Accuracy:0.9726687669754028 Test Accuracy : 0.9679999947547913\n"
     ]
    }
   ],
   "source": [
    "model = CNN()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
    "\n",
    "# Running the model on GPU if available\n",
    "## Refer pytorch documentation for more details about copying model and data onto the device\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "model.to(device)\n",
    "\n",
    "cost = []  # mean training loss of every epoch\n",
    "epochs = 10\n",
    "\n",
    "\n",
    "def train_one_epoch(model, loader, optimizer, device):\n",
    "    \"\"\"Run one optimisation pass over `loader`.\n",
    "\n",
    "    Returns a `(mean_batch_loss, accuracy)` tuple. The accuracy is the fraction of\n",
    "    samples classified correctly during the pass, i.e. it is measured with dropout\n",
    "    active and while the weights are still being updated.\n",
    "    \"\"\"\n",
    "    # Training mode enables the Dropout / Dropout2d layers\n",
    "    model.train()\n",
    "    batch_losses = []\n",
    "    correct = 0\n",
    "    seen = 0\n",
    "\n",
    "    for x, y in loader:\n",
    "        x = x.to(device)\n",
    "        # cross_entropy needs integer (long) class indices on the same device as the output\n",
    "        y = y.to(device).long()\n",
    "\n",
    "        # Predicting the output\n",
    "        y_pred = model(x)\n",
    "\n",
    "        # Calculating the loss\n",
    "        loss = F.cross_entropy(y_pred, y)\n",
    "\n",
    "        # Backpropagation: zero the gradients, calculate them, update the weights\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        batch_losses.append(loss.item())\n",
    "\n",
    "        # Counting correct samples (instead of averaging per-batch accuracies) keeps\n",
    "        # the smaller final batch from being over-weighted\n",
    "        correct += (y_pred.argmax(dim=1) == y).sum().item()\n",
    "        seen += y.size(0)\n",
    "\n",
    "    return sum(batch_losses) / len(batch_losses), correct / seen\n",
    "\n",
    "\n",
    "def evaluate(model, loader, device):\n",
    "    \"\"\"Return the fraction of samples in `loader` that `model` classifies correctly.\"\"\"\n",
    "    # Evaluation mode disables dropout; without it the predictions are made by a\n",
    "    # randomly thinned network and the reported accuracy is noisy\n",
    "    model.eval()\n",
    "    correct = 0\n",
    "    seen = 0\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for x, y in loader:\n",
    "            x, y = x.to(device), y.to(device)\n",
    "            correct += (model(x).argmax(dim=1) == y).sum().item()\n",
    "            seen += y.size(0)\n",
    "\n",
    "    return correct / seen\n",
    "\n",
    "\n",
    "# Training the model and reporting the test accuracy after every epoch\n",
    "for epoch in range(epochs):\n",
    "    epoch_loss, train_accuracy = train_one_epoch(model, train, optimizer, device)\n",
    "    test_accuracy = evaluate(model, test, device)\n",
    "\n",
    "    print(\n",
    "        \"Epoch :{} Loss : {} Train Accuracy:{} Test Accuracy : {}\".format(\n",
    "            epoch, epoch_loss, train_accuracy, test_accuracy\n",
    "        )\n",
    "    )\n",
    "\n",
    "    cost.append(epoch_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 337
    },
    "id": "J7ecIEspmILq",
    "outputId": "5ec70ca4-45b0-4cf3-fedc-4f57daf7a8fa"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x7fa94052aa00>]"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 1080x360 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "### Plotting the cost vs epochs\n",
    "fig, ax = plt.subplots(figsize=(15, 5))\n",
    "ax.plot(np.array(cost))\n",
    "# Title and axis labels let the figure be read on its own; the trailing semicolon\n",
    "# keeps the return value's repr out of the cell output.\n",
    "ax.set(title=\"Training loss per epoch\", xlabel=\"Epoch\", ylabel=\"Mean cross-entropy loss\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "Assignment1-Step2_Harshit_Agarwal.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "037faad73a8a40499555406f4e586731": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "0554687d97424798a86ea0a4c56cdbf8": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "14d009d293864f9f85bd16b5e1b6b381": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_c75452bfc2264392bf3d842bfbd4eeee",
      "placeholder": "​",
      "style": "IPY_MODEL_0554687d97424798a86ea0a4c56cdbf8",
      "value": " 29696/? [00:00&lt;00:00, 395418.47it/s]"
     }
    },
    "1a89067cd68c41fb96b74e0ebf3d1931": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_91798827c9254ffc969228862c8ee37f",
       "IPY_MODEL_cdbc8de66e8c4fa79100dcac1fcfed4d",
       "IPY_MODEL_2773f678ca4a45158df219d2672fa646"
      ],
      "layout": "IPY_MODEL_9f7585b1023d4e20ba3649efcfcdb881"
     }
    },
    "2773f678ca4a45158df219d2672fa646": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_480e369fb79f4326bd13d2355fddd890",
      "placeholder": "​",
      "style": "IPY_MODEL_898cdc388a2d428ca742a9ba2365df20",
      "value": " 5120/? [00:00&lt;00:00, 57213.75it/s]"
     }
    },
    "2ee0648f054c49049fdf5d6ac81ec086": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_4be171eff56046c2a401bf02b6d704c9",
       "IPY_MODEL_c8c6df38201b470bb06bffc677873f93",
       "IPY_MODEL_f51f06fbe28746dea09e075256e29451"
      ],
      "layout": "IPY_MODEL_7c90cc0c4c3b4cd89dbdfffdb78d463d"
     }
    },
    "3d3af4241b714b9a87253ac87b2e31b1": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "44f9d2dde4424058a337a1e3d585b5fe": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": ""
     }
    },
    "47b04fa0a67c4fbbbcdc3d1062430259": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": ""
     }
    },
    "480e369fb79f4326bd13d2355fddd890": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "4b80b4415ce34130a048933b2df179a5": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "4be171eff56046c2a401bf02b6d704c9": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_9ee4f57a58c14bb1880789196ff7ef63",
      "placeholder": "​",
      "style": "IPY_MODEL_74d1d431fbb84c4f948ca510e397f4e2",
      "value": ""
     }
    },
    "51567646aecc4126a3b9cb96f97d5be5": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_037faad73a8a40499555406f4e586731",
      "placeholder": "​",
      "style": "IPY_MODEL_4b80b4415ce34130a048933b2df179a5",
      "value": " 9913344/? [00:00&lt;00:00, 33092818.93it/s]"
     }
    },
    "54ce0f87fc1f46d2b9f9850a393b46cf": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "553cf4ae8edb4b48a58ab4446036ae87": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "5ecbc7b6eb6544709da61283dfc8d3c6": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": ""
     }
    },
    "6e770d6f5e5f4101b2c985f8182eb79e": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_af12215b0ba24c699e1111c132dee3fe",
      "placeholder": "​",
      "style": "IPY_MODEL_908f60d17e1b49869ac3d4ee53a74a6e",
      "value": ""
     }
    },
    "74d1d431fbb84c4f948ca510e397f4e2": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "780790f42e764ec5a0b8d441a290e5d3": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "7837b9def9dd470cb57f96b38afbfc18": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": ""
     }
    },
    "7c8cf011ed684ef0841d41da76c9bfc8": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_d2e31ba6e42448839959f523cc56dbcb",
      "max": 28881,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_5ecbc7b6eb6544709da61283dfc8d3c6",
      "value": 28881
     }
    },
    "7c90cc0c4c3b4cd89dbdfffdb78d463d": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "84bda6e8c38b4b669e2825c962e9dbfb": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "898cdc388a2d428ca742a9ba2365df20": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "908f60d17e1b49869ac3d4ee53a74a6e": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "90f550d2f7344eb894208d0373ac6f8d": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "91798827c9254ffc969228862c8ee37f": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_3d3af4241b714b9a87253ac87b2e31b1",
      "placeholder": "​",
      "style": "IPY_MODEL_84bda6e8c38b4b669e2825c962e9dbfb",
      "value": ""
     }
    },
    "9ee4f57a58c14bb1880789196ff7ef63": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "9f7585b1023d4e20ba3649efcfcdb881": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "a1dd0d1a7fb14235a271992bea0d233c": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "af12215b0ba24c699e1111c132dee3fe": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "b3105726704240439adcf7f13bd48cca": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "b9fa1995789a49fdb6a0c3e8505733de": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "c52318a206de4565a725388bc50bdfc0": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_d054c2e9330c433ab6bb900c6fa7dac6",
       "IPY_MODEL_7c8cf011ed684ef0841d41da76c9bfc8",
       "IPY_MODEL_14d009d293864f9f85bd16b5e1b6b381"
      ],
      "layout": "IPY_MODEL_553cf4ae8edb4b48a58ab4446036ae87"
     }
    },
    "c75452bfc2264392bf3d842bfbd4eeee": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "c8c6df38201b470bb06bffc677873f93": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_90f550d2f7344eb894208d0373ac6f8d",
      "max": 1648877,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_7837b9def9dd470cb57f96b38afbfc18",
      "value": 1648877
     }
    },
    "cdbc8de66e8c4fa79100dcac1fcfed4d": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_780790f42e764ec5a0b8d441a290e5d3",
      "max": 4542,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_47b04fa0a67c4fbbbcdc3d1062430259",
      "value": 4542
     }
    },
    "d054c2e9330c433ab6bb900c6fa7dac6": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_a1dd0d1a7fb14235a271992bea0d233c",
      "placeholder": "​",
      "style": "IPY_MODEL_b9fa1995789a49fdb6a0c3e8505733de",
      "value": ""
     }
    },
    "d2e31ba6e42448839959f523cc56dbcb": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "d784c3bdbaf543a299447b17b500c2a8": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_f6af65953a9841e1907a47fb27c01b96",
      "max": 9912422,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_44f9d2dde4424058a337a1e3d585b5fe",
      "value": 9912422
     }
    },
    "f1ae475d1e48411ab7cb9f49f6f673c9": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_6e770d6f5e5f4101b2c985f8182eb79e",
       "IPY_MODEL_d784c3bdbaf543a299447b17b500c2a8",
       "IPY_MODEL_51567646aecc4126a3b9cb96f97d5be5"
      ],
      "layout": "IPY_MODEL_f65a4ec103c34c7186fb76d3e3507795"
     }
    },
    "f51f06fbe28746dea09e075256e29451": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_b3105726704240439adcf7f13bd48cca",
      "placeholder": "​",
      "style": "IPY_MODEL_54ce0f87fc1f46d2b9f9850a393b46cf",
      "value": " 1649664/? [00:00&lt;00:00, 2122626.57it/s]"
     }
    },
    "f65a4ec103c34c7186fb76d3e3507795": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "f6af65953a9841e1907a47fb27c01b96": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
关于这个算法

Convolutional Neural Network (CNN)

Resources

CNN : https://en.wikipedia.org/wiki/Convolutional_neural_network
Pytorch : https://pytorch.org/tutorials/beginner/basics/intro.html

Start by importing both the training and testing MNIST datasets using DataLoaders and the torchvision-provided datasets. You can set the training and testing batch sizes to whatever you feel is best.

import torch
from torchvision import datasets, transforms
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data.dataloader import DataLoader

import matplotlib.pyplot as plt
import numpy as np
# Download the MNIST training split through torchvision; ToTensor() is the
# transform applied to each image when samples are drawn from the dataset.
dataset = datasets.MNIST(
    root="./data",
    download=True,
    train=True,
    transform=transforms.ToTensor(),
)
# Randomly partition the samples into 50,000 for training and 10,000 for testing
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [50000, 10000])
# Visualize one sample. `train_dataset` is the object returned by random_split, so
# `.dataset` is the underlying MNIST object and index 10 addresses the raw,
# pre-split data rather than the 10th element of the training subset.
plt.imshow(train_dataset.dataset.data[10])
# Read the label from `targets`: the `train_labels` alias is deprecated and
# emits "UserWarning: train_labels has been renamed targets" on every access.
plt.title("Label : " + str(train_dataset.dataset.targets[10].item()))
/opt/anaconda3/lib/python3.8/site-packages/torchvision/datasets/mnist.py:52: UserWarning: train_labels has been renamed targets
  warnings.warn("train_labels has been renamed targets")
Text(0.5, 1.0, 'Label : 3')
# Wrap both splits in DataLoaders: shuffled mini-batches of 32 samples for
# training, and shuffled single-sample batches for testing.
train = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
test = DataLoader(dataset=test_dataset, batch_size=1, shuffle=True)

Define a network with the following architecture:

Conv2d (input channels = 1, output channels = 15, kernel size = 5) $\rightarrow$ MaxPool (kernel size = 2) $\rightarrow$ ReLU $\rightarrow$ Conv2d (input channels = 15, output channels = 30, kernel size = 5) $\rightarrow$ Dropout2d (p = 0.5) $\rightarrow$ MaxPool (kernel size = 2) $\rightarrow$ ReLU $\rightarrow$ Linear (input dimension = 480, hidden units = 64) $\rightarrow$ ReLU $\rightarrow$ Dropout (p = 0.5) $\rightarrow$ Linear (input dimension = 64, hidden units = 10) $\rightarrow$ LogSoftmax

class CNN(nn.Module):
    """Small convolutional classifier for single-channel images.

    Two convolution / max-pool / ReLU stages are followed by a two-layer
    fully connected head whose last layer emits log-probabilities over 10
    classes. The first linear layer takes 480 flattened features, which is
    what a 1x28x28 input produces after the two (5x5 conv, 2x2 pool) stages:
    30 channels of 4x4.
    """

    def __init__(self):
        super().__init__()

        # Convolutional feature extractor.
        feature_layers = [
            nn.Conv2d(in_channels=1, out_channels=15, kernel_size=5),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.ReLU(),
            nn.Conv2d(in_channels=15, out_channels=30, kernel_size=5),
            nn.Dropout2d(p=0.5),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.ReLU(),
        ]

        # Fully connected classification head.
        head_layers = [
            nn.Flatten(),
            nn.Linear(in_features=480, out_features=64),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(in_features=64, out_features=10),
            nn.LogSoftmax(dim=1),
        ]

        # A single flat Sequential, so sub-module indices (and therefore
        # state-dict keys such as "cnn.8.weight") follow the layer order above.
        self.cnn = nn.Sequential(*feature_layers, *head_layers)

    def forward(self, x):
        """Run `x` through the whole stack and return per-class log-probabilities."""
        log_probs = self.cnn(x)
        return log_probs

Train the network you defined in the previous question on MNIST, using the optimizer and the number of training epochs you deem appropriate. Use a cross-entropy loss. After each epoch, evaluate your model on the testing dataset and print the accuracy you achieve.

# Not read by the loop below: batches come from the `train` and `test`
# DataLoaders, which carry their own batch sizes. Kept so that any other cell
# referring to this name keeps working.
batch_size = 32

model = CNN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Running the model on GPU if available
## Refer pytorch documentation for more details about copying model and data onto the device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

cost = []  # mean training loss of each epoch, in epoch order
epochs = 10

# Training the model
for epoch in range(epochs):

    # Training mode: Dropout / Dropout2d randomly zero activations.
    model.train()

    loss_epoch = []  # loss of every training batch in this epoch
    train_correct = 0
    train_total = 0

    for x, y in train:
        x = x.to(device)
        # cross_entropy expects integer class indices (torch.long) on the
        # same device as the predictions.
        y = y.to(device=device, dtype=torch.long)

        # Predicting the output
        y_pred = model(x)

        # The network outputs log-probabilities; the predicted class is the
        # index of the largest one.
        t_preds = y_pred.argmax(dim=1)

        # Counting correct predictions per sample (rather than averaging
        # per-batch ratios) keeps a shorter final batch from being over-weighted.
        # Note this is a running accuracy measured with dropout active.
        train_correct += (t_preds == y).sum().item()
        train_total += y.size(0)

        # Calculating the loss.
        # cross_entropy applies log_softmax internally; applied to the model's
        # LogSoftmax output this is a no-op, because log-probabilities already
        # have a log-sum-exp of zero.
        loss = F.cross_entropy(y_pred, y)

        # Backpropagation: zero the gradients, compute new ones, update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Appending the loss of each batch to the epoch loss
        loss_epoch.append(loss.item())

    # Calculating test accuracy.
    # Evaluation mode turns dropout off so the reported accuracy reflects the
    # deterministic network; no_grad skips building the autograd graph.
    model.eval()
    test_correct = 0
    test_total = 0
    with torch.no_grad():
        for x, y in test:
            x = x.to(device)
            y = y.to(device)
            y_pred = model(x)
            t_preds = y_pred.argmax(dim=1)
            test_correct += (t_preds == y).sum().item()
            test_total += y.size(0)

    mean_loss = sum(loss_epoch) / len(loss_epoch)
    train_acc = train_correct / train_total
    test_acc = test_correct / test_total

    print(
        "Epoch :{} Loss : {} Train Accuracy:{} Test Accuracy : {}".format(
            epoch,
            mean_loss,
            train_acc,
            test_acc,
        )
    )

    cost.append(mean_loss)
/opt/anaconda3/lib/python3.8/site-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at  ../c10/core/TensorImpl.h:1156.)
  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
Epoch :0 Loss : 0.44618519723548844 Train Accuracy:0.8599048256874084 Test Accuracy : 0.9340000152587891
Epoch :1 Loss : 0.20212347584169404 Train Accuracy:0.9430182576179504 Test Accuracy : 0.9490000009536743
Epoch :2 Loss : 0.16394136824853056 Train Accuracy:0.952215313911438 Test Accuracy : 0.9531000256538391
Epoch :3 Loss : 0.14257464571382256 Train Accuracy:0.9595929384231567 Test Accuracy : 0.9588000178337097
Epoch :4 Loss : 0.12598471959617277 Train Accuracy:0.9644113779067993 Test Accuracy : 0.9648000001907349
Epoch :5 Loss : 0.11733871379403024 Train Accuracy:0.9659308791160583 Test Accuracy : 0.9660000205039978
Epoch :6 Loss : 0.11000267015220401 Train Accuracy:0.9681901931762695 Test Accuracy : 0.9664999842643738
Epoch :7 Loss : 0.10582590816269605 Train Accuracy:0.9684301018714905 Test Accuracy : 0.9631999731063843
Epoch :8 Loss : 0.09624670598793283 Train Accuracy:0.9711892008781433 Test Accuracy : 0.9706000089645386
Epoch :9 Loss : 0.09422491162643551 Train Accuracy:0.9726687669754028 Test Accuracy : 0.9679999947547913
### Plotting the cost vs epochs
# `cost` holds one mean training loss per epoch, so the x axis is the epoch index.
fig, ax = plt.subplots(figsize=(15, 5))
ax.plot(np.array(cost))
ax.set_xlabel("Epoch")
ax.set_ylabel("Mean training loss")
ax.set_title("Training loss per epoch")
# Render the figure explicitly so the cell does not echo the Line2D repr.
plt.show()
[&lt;matplotlib.lines.Line2D at 0x7fa94052aa00&gt;]