Coding 9: Human pose regression

In this assignment, we will train a regression network for human pose estimation on the LSP dataset.

drawing

In [ ]:
import torch
import torchvision.models as models
import torchvision.transforms as transforms
import pickle
from torchvision.transforms import functional as F
import torch.utils.tensorboard as tb
import time

# Download dataset
!wget http://www.cs.utexas.edu/~zhouxy/lsp_dataset.pkl

Implement the data augmentation below (to be covered in the Friday lecture).

In [ ]:
class LSPDataset(torch.utils.data.Dataset):
    """
    LSP human-pose dataset wrapper.

    data (N x 128 x 128 x 3): images.
    targets (N x 14 x 2): 14 human joint locations (x, y) in pixels,
      normalized to [-1, 1). x=-1 is the left side of the image, x=1 the right.
    The 14 joints, in order:
      [Right ankle, Right knee, Right hip, Left hip, Left knee, Left ankle,
       Right wrist, Right elbow, Right shoulder, Left shoulder, Left elbow,
       Left wrist, Neck, Head top]
    """

    def __init__(self, data, targets, img_transforms, img_target_transforms):
        """
        img_transforms: list of functions that only change the image.
        img_target_transforms: list of functions that also transform the labels.
        """
        super().__init__()
        self.data, self.targets = data, targets
        self.img_transforms = img_transforms
        self.img_target_transforms = img_target_transforms

    def __getitem__(self, index):
        """Return (image, flattened 28-vector of joint coordinates)."""
        image, joints = self.data[index], self.targets[index]
        # Label-aware transforms run first, while image and joints still agree...
        for transform in self.img_target_transforms:
            image, joints = transform(image, joints)
        # ...then image-only transforms (e.g. ToTensor).
        for transform in self.img_transforms:
            image = transform(image)
        return image, joints.reshape(-1)

    def __len__(self):
        """Number of samples in this split."""
        return len(self.data)

def flip_augment(data, target):
    """
    Randomly mirror the image and its joint labels horizontally (p=0.5).

    With probability 0.5 the sample is returned unchanged. Otherwise:
      * the image is flipped left-right,
      * left/right joints are swapped via flipped_index so that e.g. the
        "Right ankle" slot still describes the subject's right ankle,
      * the x coordinate is negated (coordinates are normalized to [-1, 1),
        so mirroring maps x to approximately -x).

    data: H x W x C image (assumed numpy HWC, pre-ToTensor -- TODO confirm).
    target: 14 x 2 array of (x, y) joint locations normalized to [-1, 1).
    Returns the (possibly flipped) (data, target) pair.
    """
    if torch.rand(1).item() < 0.5:
        return data, target
    # Slot i of the flipped target takes the value of joint flipped_index[i].
    flipped_index = [5, 4, 3, 2, 1, 0, 11, 10, 9, 8, 7, 6, 12, 13]
    data = data[:, ::-1].copy()  # horizontal flip of an HWC array
    target = target[flipped_index]  # fancy indexing copies, safe to mutate
    target[:, 0] = -target[:, 0]  # mirror x; y is unchanged
    return data, target

def fetch_dataloader(batch_size=64, num_workers=6):
    """
    Build DataLoaders over the pickled LSP dataset.

    Yields one DataLoader per split, in the order train, val, test.
    Only the train split gets label-aware augmentations (flip_augment);
    every split gets ToTensor. The test loader uses batch_size=1 and no
    shuffling so per-sample evaluation is deterministic.

    batch_size: batch size for the train and val loaders.
    num_workers: worker processes per DataLoader.
    """
    # NOTE(review): pickle.load executes arbitrary code if the downloaded
    # file is tampered with -- acceptable for this exercise only.
    with open('lsp_dataset.pkl', 'rb') as f:  # close the file handle promptly
        data, targets = pickle.load(f)
    splits = {
        'train': [0, 1000], 'val': [1000, 1500], 'test': [1500, 2000]}
    for s in splits:
        img_transforms, img_target_transforms = [], []
        if s == 'train':
            img_target_transforms.append(flip_augment)
            # TODO: add more augmentations here
        img_transforms.append(transforms.ToTensor())
        lo, hi = splits[s]
        dataset = LSPDataset(data[lo:hi], targets[lo:hi],
                             img_transforms, img_target_transforms)
        yield torch.utils.data.DataLoader(
            dataset, batch_size=batch_size if s != 'test' else 1,
            shuffle=(s != 'test'), num_workers=num_workers)

Visualization and evaluation code. Feel free to ignore.

In [ ]:
def mpjpe(pred, gt):
    """
    Mean per joint position error (MPJPE), in pixels.

    Both inputs hold flattened joint coordinates normalized to [-1, 1);
    they are reshaped to (batch, 14, 2) and mapped back to the original
    128x128 pixel space (x_pix = 64 * x_norm + 64) before averaging the
    Euclidean distance between predicted and ground-truth joints.
    """
    pred_px = pred.view(-1, 14, 2) * 64 + 64
    gt_px = gt.view(-1, 14, 2) * 64 + 64
    diff = pred_px - gt_px
    return diff.pow(2).sum(dim=2).sqrt().mean()
In [ ]:
def train(model, train_data, val_data, lr=0.001, n_epochs=30):
    """
    Train `model` with Adam on (image, flattened-joints) batches.

    model: network mapping a batch of images to 28 outputs (14 joints x 2).
    train_data / val_data: iterables of (image, label) batches.
    lr: initial learning rate; divided by 10 once, after epoch 25.
    n_epochs: number of passes over train_data.

    Logs batch loss, per-epoch MPJPE, and pose visualizations (via the
    externally defined draw_pose) to tensorboard under log/<timestamp>.
    """
    # Fall back to CPU so the loop also runs on machines without a GPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    writer = tb.SummaryWriter(
        'log/{}'.format(time.strftime('%m-%d-%H-%M')), flush_secs=5)

    optim = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=5e-6)
    # MSE regression loss on the normalized joint coordinates.
    loss_func = torch.nn.MSELoss()
    steps = 0

    # Standard training loop.
    model.to(device)
    for epoch in range(n_epochs):
        train_mpjpe, val_mpjpe = [], []
        model.train()
        for idx, (x, y) in enumerate(train_data):
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            loss = loss_func(y_pred, y)
            # Detach so the stored metric does not keep the autograd graph alive.
            train_mpjpe.append(mpjpe(y_pred.detach(), y))

            optim.zero_grad()
            loss.backward()
            optim.step()

            steps += 1
            writer.add_scalar('loss/train_batch', loss.item(), steps)
        for i in range(4):
            writer.add_image('train/pose%d'%i, draw_pose(
                x[i].cpu(),  y[i].cpu(), y_pred[i].detach().cpu()), epoch)

        model.eval()
        with torch.no_grad():  # no gradients needed for validation
            for x, y in val_data:
                x, y = x.to(device), y.to(device)
                y_pred = model(x)
                val_mpjpe.append(mpjpe(y_pred, y))
        for i in range(4):
            writer.add_image(
                'valid/pose%d'%i, draw_pose(
                    x[i].cpu(), y[i].cpu(), y_pred[i].detach().cpu()), epoch)

        ep_trn_mpjpe = torch.mean(torch.stack(train_mpjpe))
        ep_val_mpjpe = torch.mean(torch.stack(val_mpjpe))
        writer.add_scalar('mpjpe/train_epoch', ep_trn_mpjpe, epoch)
        writer.add_scalar('mpjpe/valid_epoch', ep_val_mpjpe, epoch)

        print('Epoch: %d, %.3f, %.3f' % (epoch, ep_trn_mpjpe, ep_val_mpjpe))

        if epoch in [25]:
            # One-time 10x learning-rate decay near the end of training.
            lr = lr / 10
            for param in optim.param_groups:
                param['lr'] = lr

# Build the model: a pre-trained ResNet-18 with its final fc layer replaced
# by a 28-output linear head (14 joints x 2 coordinates each).
# A pre-trained torchvision model is allowed for this exercise.
# Remember, this is NOT allowed for the homework.
model = models.resnet18(pretrained=True)
model.fc = torch.nn.Linear(model.fc.in_features, 28)

# Train your model
train_data, val_data, test_data = fetch_dataloader(batch_size=32)
train(model, train_data, val_data)