Coding 8: Data augmentation

In this assignment, we'll focus on data augmentation and pushing the validation performance of a CNN as high as possible.

Load CIFAR-10.

Do data augmentation here.

In [ ]:
import torch
import torchvision
import torchvision.transforms as transforms

  
def fetch_dataloader(batch_size=256):
    """
    Iterable of (image, label) tuples.
    """
    transform = list()

    for s in ['train', 'valid']:
        transform = [
            transforms.ToTensor(),
            transforms.Normalize((0.49, 0.47, 0.42), (0.24, 0.23, 0.24))
        ]
        if s == 'train':
            # TODO: Do data augmentation here (one possible pipeline is
            # sketched in the cell below)
            pass

        transform = transforms.Compose(transform)

        # Both splits come from the CIFAR-10 training set: the first 40k images
        # for training, the remaining 10k held out for validation.
        data = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
        if s == 'train':
            data.data = data.data[:40000]
            data.targets = data.targets[:40000]
        else:
            data.data = data.data[40000:]
            data.targets = data.targets[40000:]
        yield torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=(s == 'train'), num_workers=6)
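
For reference, here is a minimal sketch of one common CIFAR-10 augmentation pipeline (random crops with padding plus horizontal flips). The choice of transforms and their parameters is the point of this assignment, so treat it only as a starting point; note that these PIL-based transforms belong before ToTensor.

In [ ]:
import torchvision.transforms as transforms

# A sketch of one possible training-time augmentation pipeline (parameters are
# illustrative, not the required answer). Geometric transforms here run on PIL
# images, so they go before ToTensor; Normalize stays last.
augment = transforms.Compose([
    transforms.RandomCrop(32, padding=4),        # pad by 4 pixels, crop back to 32x32
    transforms.RandomHorizontalFlip(p=0.5),      # mirror half the images at random
    transforms.ToTensor(),
    transforms.Normalize((0.49, 0.47, 0.42), (0.24, 0.23, 0.24)),
])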

A residual network from last week

You do not need to update or change this in this week's assignment.

In [ ]:
from torch import nn

class Block(nn.Module):
    def __init__(self, n_input, n_output, stride):
        super().__init__()
        # Two 3x3 convolutions; the first may downsample via its stride.
        self.net = nn.Sequential(
            nn.Conv2d(n_input, n_output, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(n_output),
            nn.ReLU(),
            nn.Conv2d(n_output, n_output, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(n_output))

        # Project the identity when the shape changes so the skip connection matches.
        self.downsample = None
        if (n_input != n_output) or (stride != 1):
            self.downsample = nn.Sequential(
                nn.Conv2d(n_input, n_output, kernel_size=3, stride=stride, padding=1, bias=False),
                nn.BatchNorm2d(n_output))

        self.relu = nn.ReLU()

    def forward(self, x):
        identity = x

        if self.downsample is not None:
            identity = self.downsample(identity)

        return self.relu(self.net(x) + identity)

        
class ResNet(nn.Module):
    def __init__(self, layers, n_input_channels, out_classes):
        super().__init__()
        # Stem: two convolutions; the first halves the spatial resolution.
        self.L = [nn.Conv2d(n_input_channels, layers[0], kernel_size=3, stride=2, padding=1, bias=False),
                  nn.BatchNorm2d(layers[0]),
                  nn.ReLU(),
                  nn.Conv2d(layers[0], layers[0], kernel_size=2, stride=1, padding=1, bias=False),
                  nn.BatchNorm2d(layers[0]),
                  nn.ReLU()]
        inpc = layers[0]

        # One residual block per entry in `layers`, each downsampling by 2.
        for l in layers:
            self.L.append(Block(n_input=inpc, n_output=l, stride=2))
            inpc = l

        self.net = nn.Sequential(*self.L)
        self.classifier = nn.Linear(layers[-1], out_classes)

    def forward(self, x):
        features = self.net(x)

        # Global average pooling over the spatial dimensions, then a linear classifier.
        return self.classifier(features.mean(dim=(2, 3)))
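
As a quick sanity check (not part of the assignment), the network can be run on a dummy CIFAR-sized batch to confirm that the output shape matches the number of classes:

In [ ]:
import torch

# Forward a dummy batch of two 3x32x32 images through the network defined above.
net = ResNet(layers=[32] * 6, n_input_channels=3, out_classes=10)
print(net(torch.zeros(2, 3, 32, 32)).shape)  # expected: torch.Size([2, 10])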
In [ ]:
def train(model, lr=0.001, n_epochs=30):
    import torch.utils.tensorboard as tb
    import time
    import numpy as np
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    writer = tb.SummaryWriter('log/{}'.format(time.strftime('%m-%d-%H-%M')), flush_secs=5)
    train_data, val_data = fetch_dataloader(batch_size=256)

    optim = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=5e-6)
    loss_func = nn.CrossEntropyLoss()
    steps = 0
    
    model.to(device)

    for epoch in range(n_epochs):
        train_acc = []
        val_acc = []
        
        model.train()
                
        # Train  
        for idx, (x, y) in enumerate(train_data):            
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            loss = loss_func(y_pred, y)
          
            iter_acc = (y_pred.argmax(dim=1) == y).float().cpu().numpy()
            train_acc.extend(iter_acc)
          
            optim.zero_grad()
            loss.backward()
            optim.step()
          
            steps += 1
            writer.add_scalar('loss/train_batch', loss.item(), steps)
        model.eval()
        
        # Validation (no gradients needed for evaluation)
        with torch.no_grad():
            for x, y in val_data:
                x, y = x.to(device), y.to(device)
                y_pred = model(x)

                val_acc.extend((y_pred.argmax(dim=1) == y).float().cpu().numpy())
            
        ep_trn_acc = np.mean(train_acc)
        ep_val_acc = np.mean(val_acc)
        writer.add_scalar('accuracy/train_epoch', ep_trn_acc, epoch)
        writer.add_scalar('accuracy/valid_epoch', ep_val_acc, epoch)
        
        print('Epoch %d: train acc %.3f, valid acc %.3f' % (epoch, ep_trn_acc, ep_val_acc))
model = ResNet(layers=[32]*6, n_input_channels=3, out_classes=10)
train(model)
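
Training and validation curves are written to the log/ directory by the SummaryWriter above. In a notebook, they can be inspected with the TensorBoard extension (assuming tensorboard is installed):

In [ ]:
# Load the TensorBoard notebook extension and point it at the log directory
# used by train() above.
%load_ext tensorboard
%tensorboard --logdir log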