Coding 7: Residual Networks on CIFAR10

In this exercise, we will learn to use skip connections to build residual networks.

skip_connection

Image credits: https://i.stack.imgur.com/msvse.png.

Your residual network must have at least 8 convolutional layers.

Tensorboard (run this only once).

In [ ]:
%load_ext tensorboard

Load CIFAR-10.

In [ ]:
import torch
import torchvision
import torchvision.transforms as transforms

  
def fetch_dataloader(batch_size):
    """
    Build train/validation DataLoaders for CIFAR-10.

    The official 50k-image CIFAR-10 training set is split in place:
    the first 40k images form the training split, the last 10k the
    validation split. Each loader yields (image, label) batches.

    batch_size: Number of samples per batch in both loaders.
    Returns: (train_loader, val_loader) tuple.
    """
    # Per-channel mean/std are approximate CIFAR-10 statistics.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.49, 0.47, 0.42), (0.24, 0.23, 0.24)),
    ])

    def _split(lo, hi):
        # Both splits come from the *train* portion of CIFAR-10; the
        # raw arrays are sliced in place to carve out the range [lo:hi).
        data = torchvision.datasets.CIFAR10(
            root='./data', train=True, download=True, transform=transform)
        data.data = data.data[lo:hi]
        data.targets = data.targets[lo:hi]
        return data

    # Shuffle only the training split; validation order is fixed.
    trn_loader = torch.utils.data.DataLoader(
        _split(None, 40000), batch_size=batch_size, shuffle=True, num_workers=2)
    val_loader = torch.utils.data.DataLoader(
        _split(40000, None), batch_size=batch_size, shuffle=False, num_workers=2)

    return trn_loader, val_loader

Implement the model here.

In [ ]:
from torch import nn

class ResNet(nn.Module):
    """
    Small residual network for CIFAR-10-sized images.

    A 3x3 conv stem is followed by one residual Block per entry of
    `layers` (each block downsamples by stride 2), then global average
    pooling and a linear classifier. With layers=[4, 8, 16, 32] the
    network contains at least 9 convolutional layers.
    """

    class Block(nn.Module):
        def __init__(self, n_input, n_output, stride):
            super().__init__()
            """
            n_input: Number of input channels in residual block.
            n_output: Number of output channels in residual block.
            stride: Convolutional stride in residual block.
            """
            # Main path: two 3x3 convs; only the first one strides.
            self.net = nn.Sequential(
                nn.Conv2d(n_input, n_output, kernel_size=3,
                          padding=1, stride=stride, bias=False),
                nn.BatchNorm2d(n_output),
                nn.ReLU(),
                nn.Conv2d(n_output, n_output, kernel_size=3,
                          padding=1, bias=False),
                nn.BatchNorm2d(n_output),
            )
            # Skip path: identity when shapes match, otherwise a 1x1
            # conv projection so the addition is well-defined.
            if stride != 1 or n_input != n_output:
                self.skip = nn.Sequential(
                    nn.Conv2d(n_input, n_output, kernel_size=1,
                              stride=stride, bias=False),
                    nn.BatchNorm2d(n_output),
                )
            else:
                self.skip = nn.Identity()
            self.relu = nn.ReLU()

        def forward(self, x):
            # Residual connection: F(x) + x, then nonlinearity.
            return self.relu(self.net(x) + self.skip(x))

    def __init__(self, layers, n_input_channels, out_classes):
        super().__init__()
        """
        layers: Number of residual blocks in the network (channel width
            of each block, one block per entry).
        n_input_channels: Number of input channels in image.
        out_classes: Number of classes in dataset.
        """
        c = layers[0]
        # Stem: bring the image into feature space at full resolution.
        modules = [
            nn.Conv2d(n_input_channels, c, kernel_size=3,
                      padding=1, bias=False),
            nn.BatchNorm2d(c),
            nn.ReLU(),
        ]
        # One downsampling residual block per requested width.
        for width in layers:
            modules.append(self.Block(c, width, stride=2))
            c = width
        self.network = nn.Sequential(*modules)
        self.classifier = nn.Linear(c, out_classes)

    def forward(self, x):
        """
        x: (batch, n_input_channels, H, W) image tensor.
        Returns: (batch, out_classes) class logits.
        """
        z = self.network(x)
        # Global average pooling over the spatial dimensions.
        z = z.mean(dim=(2, 3))
        return self.classifier(z)
In [ ]:
import numpy as np

def train(model, train_data, val_data, writer, device, lr=0.01, n_epochs=20):
    """
    Train `model` with Adam and cross-entropy loss, logging to tensorboard.

    model: nn.Module mapping image batches to class logits.
    train_data: iterable of (image, label) training batches.
    val_data: iterable of (image, label) validation batches.
    writer: tensorboard SummaryWriter (only add_scalar is used).
    device: torch.device to train on.
    lr: Adam learning rate.
    n_epochs: number of passes over train_data.
    """
    optim = torch.optim.Adam(model.parameters(), lr=lr)
    loss_func = nn.CrossEntropyLoss()
    steps = 0  # global batch counter, shared across epochs for logging

    model.to(device)

    for epoch in range(n_epochs):
        train_acc = []
        val_acc = []

        # Train: enable training-mode behavior (dropout, batchnorm
        # batch statistics).
        model.train()
        for x, y in train_data:
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            loss = loss_func(y_pred, y)

            # Per-sample correctness for this batch.
            iter_acc = (y_pred.argmax(dim=1) == y).float().cpu().numpy()
            train_acc.extend(iter_acc)

            optim.zero_grad()
            loss.backward()
            optim.step()

            steps += 1
            writer.add_scalar('loss/train_batch', loss.item(), steps)

        # Validation: eval mode so batchnorm uses running statistics,
        # and no_grad so no autograd graph is built.
        model.eval()
        with torch.no_grad():
            for x, y in val_data:
                x, y = x.to(device), y.to(device)
                y_pred = model(x)

                val_acc.extend((y_pred.argmax(dim=1) == y).float().cpu().numpy())

        ep_trn_acc = np.mean(np.stack(train_acc, 0))
        ep_val_acc = np.mean(np.stack(val_acc, 0))
        writer.add_scalar('accuracy/train_epoch', ep_trn_acc, epoch)
        writer.add_scalar('accuracy/valid_epoch', ep_val_acc, epoch)

Train your network

In [ ]:
# Start tensorboard inside the notebook, watching the `log` directory.
%reload_ext tensorboard
%tensorboard --logdir log --reload_interval 1

import time
import torch.utils.tensorboard as tb

# Build the 40k/10k train/validation loaders once; batches of 250 images.
train_data, val_data = fetch_dataloader(batch_size=250)
In [ ]:
# Channel widths for the residual stages of the network.
stage_widths = [4, 8, 16, 32]
model = ResNet(layers=stage_widths, n_input_channels=3, out_classes=10)

# One tensorboard run directory per launch, stamped with the start time.
writer = tb.SummaryWriter(f"log/{time.strftime('%m-%d-%H-%M')}")
device = torch.device('cuda')

train(model, train_data, val_data, writer, device)