In [ ]:
%pylab inline
import torch
import sys
sys.path.append('..')
sys.path.append('../..')
from data import load
# Seed torch's global RNG so weight initialisation (and anything else drawing
# from torch's default generator) is reproducible across Restart & Run All.
torch.manual_seed(0)

# Small sample for the shape checks below; resize=(32, 32) matches the resolution
# passed to train.train later. Presumably n_images caps the sample size — confirm in data.load.
train_data, train_label = load.get_dogs_and_cats_data(resize=(32, 32), n_images=10)

# A hard-coded 'cuda:2' raises "invalid device ordinal" on machines with fewer than
# three GPUs. Keep index 2 when it exists, otherwise clamp to the last available GPU,
# and fall back to the CPU when CUDA is unavailable.
PREFERRED_GPU = 2
if torch.cuda.is_available():
    device = torch.device('cuda', min(PREFERRED_GPU, torch.cuda.device_count() - 1))
else:
    device = torch.device('cpu')
print('device = ', device)
In [ ]:
class ConvNet(torch.nn.Module):
    """Fully-convolutional network that downsamples with strided convolutions.

    Each entry of ``layers`` adds a ``Conv2d -> ReLU`` pair whose convolution uses
    ``stride``; a final 1x1 convolution maps the features to a single channel, and
    ``forward`` averages that map into one scalar score per image.

    Args:
        layers: output channel count of each strided conv layer, in order. May be
            empty, in which case only the 1x1 head remains.
        n_input_channels: number of channels in the input images.
        kernel_size: spatial size of the strided convolutions. Padding is
            ``(kernel_size - 1) // 2``, so for odd sizes each conv produces
            ceil(H / stride) x ceil(W / stride) outputs.
        stride: stride of every strided convolution.
    """

    def __init__(self, layers=(), n_input_channels=3, kernel_size=3, stride=2):
        super().__init__()
        # Immutable default instead of ``[]``: avoids the shared-mutable-default
        # pitfall; ``layers`` is only iterated, so lists still work as before.
        modules = []
        in_channels = n_input_channels
        for out_channels in layers:
            modules.append(torch.nn.Conv2d(in_channels, out_channels, kernel_size,
                                           padding=(kernel_size - 1) // 2, stride=stride))
            modules.append(torch.nn.ReLU())
            in_channels = out_channels
        # 1x1 convolution head: collapse the features to a single score map.
        modules.append(torch.nn.Conv2d(in_channels, 1, kernel_size=1))
        self.layers = torch.nn.Sequential(*modules)

    def forward(self, x):
        """Return one scalar score per image.

        Args:
            x: input batch; presumably (B, C, H, W) — the mean over dims 1-3
                requires a rank-4 result.

        Returns:
            Tensor of shape (B,): the head output averaged over channel and
            spatial dimensions.
        """
        return self.layers(x).mean([1, 2, 3])

# Two strided conv stages (32 then 64 channels) followed by the 1x1 head.
net = ConvNet(layers=[32, 64])
In [ ]:
# Sanity-check shapes: a batch holding the first image, then the network's output for it.
print(train_data[0:1].shape)
print(net(train_data[0:1]).shape)
In [ ]:
# Three strided conv stages; with 3x3 kernels and stride 2 a 32x32 input shrinks to 4x4.
net2 = ConvNet(layers=[32, 64, 128])
print(net2(train_data[0:1]).shape)
In [ ]:
class ConvNet2(torch.nn.Module):
    """Fully-convolutional network that downsamples with max-pooling.

    Each entry of ``layers`` adds ``Conv2d -> ReLU -> MaxPool2d``. The convolution
    keeps stride 1 (padding ``(kernel_size - 1) // 2`` preserves the spatial size
    for odd kernels); downsampling happens in the 3x3 max-pool with padding 1,
    whose stride is ``stride`` and which yields ceil(H / stride) x ceil(W / stride).
    A final 1x1 convolution maps to one channel, and ``forward`` averages it into
    one scalar score per image.

    Args:
        layers: output channel count of each conv stage, in order. May be empty,
            in which case only the 1x1 head remains.
        n_input_channels: number of channels in the input images.
        kernel_size: spatial size of the stride-1 convolutions.
        stride: stride of the max-pooling layers (NOT of the convolutions).
    """

    def __init__(self, layers=(), n_input_channels=3, kernel_size=3, stride=2):
        super().__init__()
        # Immutable default instead of ``[]``: avoids the shared-mutable-default
        # pitfall; ``layers`` is only iterated, so lists still work as before.
        modules = []
        in_channels = n_input_channels
        for out_channels in layers:
            modules.append(torch.nn.Conv2d(in_channels, out_channels, kernel_size,
                                           padding=(kernel_size - 1) // 2))
            modules.append(torch.nn.ReLU())
            modules.append(torch.nn.MaxPool2d(3, padding=1, stride=stride))
            in_channels = out_channels
        # 1x1 convolution head: collapse the features to a single score map.
        modules.append(torch.nn.Conv2d(in_channels, 1, kernel_size=1))
        self.layers = torch.nn.Sequential(*modules)

    def forward(self, x):
        """Return one scalar score per image.

        Args:
            x: input batch; presumably (B, C, H, W) — the mean over dims 1-3
                requires a rank-4 result.

        Returns:
            Tensor of shape (B,): the head output averaged over channel and
            spatial dimensions.
        """
        return self.layers(x).mean([1, 2, 3])

# Same channel widths as net2, but downsampling via max-pooling instead of strided convs.
net3 = ConvNet2(layers=[32, 64, 128])
In [ ]:
# Output shape of the pooling variant for a batch holding the first image.
print(net3(train_data[0:1]).shape)
In [ ]:
# Launch TensorBoard inside the notebook, pointed at a fresh temporary log directory.
%load_ext tensorboard
import tempfile
# mkdtemp creates a new, empty directory on every run; it is not deleted automatically.
log_dir = tempfile.mkdtemp()
# {log_dir} is expanded by IPython; --reload_interval 1 makes TensorBoard re-read
# the logs every second so curves update while training runs.
%tensorboard --logdir {log_dir} --reload_interval 1
In [ ]:
from util import train

# Train the strided-conv model; its TensorBoard logs go to the net2 subdirectory.
train.train(
    net2,
    batch_size=128,
    resize=(32, 32),
    log_dir=log_dir + '/net2',
    device=device,
    n_epochs=100,
)
In [ ]:
from util import train

# Train the max-pooling model with identical settings; logs go to the net3 subdirectory.
train.train(
    net3,
    batch_size=128,
    resize=(32, 32),
    log_dir=log_dir + '/net3',
    device=device,
    n_epochs=100,
)