Coding 8: Data augmentation

In this assignment, we'll focus on data augmentation and on pushing the validation performance of a CNN as high as possible.

Load CIFAR-10.

Do data augmentation here.

In [26]:
import torch
import torchvision
import torchvision.transforms as transforms

  
def fetch_dataloader(batch_size=256):
    """
    Yields one DataLoader of (image, label) batches per split: train, then valid.
    """
    for s in ['train', 'valid']:
        if s == 'train':
            # Augmentation is only applied to the training split.
            transform = [
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, padding=4, fill=(127, 127, 127)),
                # transforms.ColorJitter(0.5, 0.5, 0.5, 0),
                transforms.ToTensor(),
                # transforms.Normalize((0.49, 0.47, 0.42), (0.24, 0.23, 0.24))
            ]
        else:
            transform = [
                transforms.ToTensor(),
                # transforms.Normalize((0.49, 0.47, 0.42), (0.24, 0.23, 0.24))
            ]
        transform = transforms.Compose(transform)

        # Both splits come from the CIFAR-10 training set: the first 40000
        # images are used for training, the remaining 10000 for validation.
        data = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
        if s == 'train':
            data.data = data.data[:40000]
            data.targets = data.targets[:40000]
        else:
            data.data = data.data[40000:]
            data.targets = data.targets[40000:]
        yield torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True, num_workers=6)
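The commented-out Normalize lines above look like per-channel statistics of the training images. If you want to recompute them, here is a minimal sketch (my addition; it assumes the same ./data download and reads the dataset's raw uint8 array directly):

raw = torchvision.datasets.CIFAR10(root='./data', train=True, download=True)
imgs = torch.tensor(raw.data, dtype=torch.float32) / 255.0   # (50000, 32, 32, 3), values in [0, 1]
print(imgs.mean(dim=(0, 1, 2)))   # per-channel means
print(imgs.std(dim=(0, 1, 2)))    # per-channel standard deviations

The printed values should land close to the constants in the comments above.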

A residual network from last week

You do not need to update or change this in this week's assignment.

In [27]:
from torch import nn

class Block(nn.Module):
    def __init__(self, n_input, n_output, stride):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(n_input, n_output, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(n_output),
            nn.ReLU(),
            nn.Conv2d(n_output, n_output, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(n_output))

        # Project the identity whenever the shape changes, so the residual add works.
        self.downsample = None
        if (n_input != n_output) or (stride != 1):
            self.downsample = nn.Sequential(
                nn.Conv2d(n_input, n_output, kernel_size=3, stride=stride, padding=1, bias=False),
                nn.BatchNorm2d(n_output))

        self.relu = nn.ReLU()

    def forward(self, x):
        identity = x

        if self.downsample is not None:
            identity = self.downsample(identity)

        return self.relu(self.net(x) + identity)

        
class ResNet(nn.Module):

    def __init__(self, layers, n_input_channels, out_classes):
        super().__init__()
        # Stem: input normalization plus two conv layers, downsampling once.
        self.L = [nn.BatchNorm2d(3),
                  nn.Conv2d(n_input_channels, layers[0], kernel_size=3, stride=2, padding=1, bias=False),
                  nn.BatchNorm2d(layers[0]),
                  nn.ReLU(),
                  nn.Conv2d(layers[0], layers[0], kernel_size=2, stride=1, padding=1, bias=False),
                  nn.BatchNorm2d(layers[0]),
                  nn.ReLU()]
        inpc = layers[0]

        # One residual block per entry in `layers`, each downsampling by 2.
        for l in layers:
            self.L.append(Block(n_input=inpc, n_output=l, stride=2))
            inpc = l

        self.net = nn.Sequential(*self.L)
        self.classifier = nn.Linear(layers[-1], out_classes)

    def forward(self, x):
        # Feature extractor, then global average pooling, then a linear classifier.
        features = self.net(x)
        return self.classifier(features.mean(dim=(2, 3)))
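As a quick sanity check (my addition, not part of the assignment), you can run a dummy batch through the network: the stride-2 stem and six stride-2 blocks collapse a 32x32 input to a 1x1 feature map, so after the global average pool the classifier sees a (batch, 32) vector.

_net = ResNet(layers=[32] * 6, n_input_channels=3, out_classes=10)
print(_net(torch.randn(2, 3, 32, 32)).shape)   # expected: torch.Size([2, 10])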
In [30]:
def train(model, lr=0.001, n_epochs=300):
    import torch.utils.tensorboard as tb
    import time
    import numpy as np
    device = torch.device('cuda')
    writer = tb.SummaryWriter('log/{}'.format(time.strftime('%m-%d-%H-%M')), flush_secs=5)
    train_data, val_data = fetch_dataloader(batch_size=256)

    optim = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=5e-6)
    loss_func = nn.CrossEntropyLoss()
    steps = 0

    model.to(device)

    for epoch in range(n_epochs):
        train_acc = []
        val_acc = []

        model.train()

        # Train
        for x, y in train_data:
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            loss = loss_func(y_pred, y)

            iter_acc = (y_pred.argmax(dim=1) == y).float().cpu().numpy()
            train_acc.extend(iter_acc)

            optim.zero_grad()
            loss.backward()
            optim.step()

            steps += 1
            writer.add_scalar('loss/train_batch', loss.item(), steps)
        # Log a few augmented training images from the last batch of the epoch.
        for i in range(4):
            writer.add_image('train/nice_picture%d' % i, x[i].cpu(), epoch)
        model.eval()

        # Validation (no gradients needed)
        with torch.no_grad():
            for x, y in val_data:
                x, y = x.to(device), y.to(device)
                y_pred = model(x)

                val_acc.extend((y_pred.argmax(dim=1) == y).float().cpu().numpy())
        for i in range(4):
            writer.add_image('valid/nice_picture%d' % i, x[i].cpu(), epoch)

        ep_trn_acc = np.mean(train_acc)
        ep_val_acc = np.mean(val_acc)
        writer.add_scalar('accuracy/train_epoch', ep_trn_acc, epoch)
        writer.add_scalar('accuracy/valid_epoch', ep_val_acc, epoch)

        print('Epoch: %d, %.3f, %.3f' % (epoch, ep_trn_acc, ep_val_acc))

model = ResNet(layers=[32]*6, n_input_channels=3, out_classes=10)
train(model)
Files already downloaded and verified
Files already downloaded and verified
Epoch: 0, 0.332, 0.447
Epoch: 1, 0.467, 0.502
Epoch: 2, 0.524, 0.544
Epoch: 3, 0.570, 0.589
Epoch: 4, 0.600, 0.594
Epoch: 5, 0.623, 0.637
... (epochs 6 through 294 omitted for brevity; training accuracy climbs steadily from 0.646 to 0.919 and validation accuracy from 0.638 to 0.838) ...
Epoch: 295, 0.917, 0.840
Epoch: 296, 0.920, 0.844
Epoch: 297, 0.920, 0.834
Epoch: 298, 0.919, 0.838
Epoch: 299, 0.921, 0.844
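Since the assignment is about pushing validation accuracy with augmentation, one further trick worth trying is MixUp, which blends pairs of images and their labels. Below is a minimal sketch under my own naming (mixup_batch is a hypothetical helper, not part of the assignment code); whether it helps here would have to be measured.

import numpy as np

def mixup_batch(x, y, alpha=0.2):
    # Convexly combine each image with a randomly chosen partner from the batch.
    lam = float(np.random.beta(alpha, alpha))
    perm = torch.randperm(x.size(0), device=x.device)
    return lam * x + (1 - lam) * x[perm], y, y[perm], lam

Inside the training loop, the loss then becomes the matching convex combination of the two labels:

x_mix, y_a, y_b, lam = mixup_batch(x, y)
y_pred = model(x_mix)
loss = lam * loss_func(y_pred, y_a) + (1 - lam) * loss_func(y_pred, y_b)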
In [38]:
def valid_accuracy(m):
    import numpy as np
    device = torch.device('cuda')
    train_data, val_data = fetch_dataloader(batch_size=256)
    m.eval()
    m.to(device)

    val_acc = []
    with torch.no_grad():
        for x, y in val_data:
            x, y = x.to(device), y.to(device)
            y_pred = m(x)

            val_acc.extend((y_pred.argmax(dim=1) == y).float().cpu().numpy())
    return np.mean(val_acc)

print(valid_accuracy(model))
Files already downloaded and verified
Files already downloaded and verified
0.8439
In [42]:
class DataAug(torch.nn.Module):
    """Test-time augmentation: average the logits of an image and its horizontal flip."""
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        c1 = self.model(x)
        c2 = self.model(torch.flip(x, [3]))  # flip along the width dimension
        return (c1 + c2) / 2

aug_model = DataAug(model)
print(valid_accuracy(aug_model))
Files already downloaded and verified
Files already downloaded and verified
0.8607
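Averaging the flipped and unflipped logits already buys about 1.7 points of validation accuracy. A natural extension (a sketch, assuming a torchvision recent enough that transforms accept batched tensors; MultiAug is my name, and whether it beats the flip-only average would need to be measured) is to average softmax probabilities over additional random-crop views:

class MultiAug(torch.nn.Module):
    # Hypothetical extension of DataAug: average class probabilities over many views.
    def __init__(self, model, n_crops=4):
        super().__init__()
        self.model = model
        self.n_crops = n_crops
        self.crop = transforms.RandomCrop(32, padding=4)

    def forward(self, x):
        views = [x, torch.flip(x, [3])]                        # original + horizontal flip
        views += [self.crop(x) for _ in range(self.n_crops)]   # random padded crops
        probs = [self.model(v).softmax(dim=1) for v in views]
        return torch.stack(probs).mean(dim=0)

# print(valid_accuracy(MultiAug(model)))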