Coding 10: Human pose estimation with heatmaps

In this assignment, we will train a network to estimate human keypoints on the LSP dataset using heatmaps.

Implement heatmaps2vectors below.

In [0]:
class LSPDataset(torch.utils.data.Dataset):
    # data (N x 128 x 128 x 3): Images
    # targets (N x 14 x 2): 14 human joint locations (x, y) in the image plane,
    #   normalized to [-1, 1). x=-1 is the left side of the image, x=1 the right.
    def __init__(self, data, targets, img_transforms, img_target_transforms):
        super().__init__()
        self.data, self.targets = data, targets
        self.img_transforms = img_transforms
        self.img_target_transforms = img_target_transforms

    def __getitem__(self, index):
        data, target = self.data[index], self.targets[index]
        for img_target_transform in self.img_target_transforms:
            data, target = img_target_transform(data, target)
        for img_transform in self.img_transforms:
            data = img_transform(data)
        heatmap = vector2heatmap(target)
        return data, target.reshape(-1), heatmap

    def __len__(self):
        return len(self.data)

def _draw_one_point(heatmap, point, radius=2):
    """
    Draw one point on a single heatmap channel. Adapted from the homework.
    Input: heatmap (32 x 32): one channel of the heatmap
           point (2): (x, y) normalized coordinates of the point, in -1 ~ 1.
    Output: None. The heatmap is modified in place.
    """
    cx, cy, R = int(point[0] * 16 + 16), int(point[1] * 16 + 16), int(2 * radius)
    heat_crop = heatmap[max(cy - R, 0):cy + R + 1, max(cx - R, 0):cx + R + 1]
    g_x = (-((torch.arange(heat_crop.size(1)).float() - min(R, cx)) / radius) ** 2).exp()
    g_y = (-((torch.arange(heat_crop.size(0)).float() - min(R, cy)) / radius) ** 2).exp()
    g = g_x[None, :] * g_y[:, None]
    heat_crop[...] = torch.max(heat_crop, g)

def vector2heatmap(target):
    """
    Render the heatmaps. Similar to the homework, but there is only one peak per channel.
    input: target (14 x 2): keypoint coordinates in normalized space (-1 ~ 1)
    output: heatmap (14 x 32 x 32): heatmaps with a single peak in each channel.
    """
    target = target.reshape((14, 2))
    heatmap = torch.zeros((14, 32, 32)).float()
    for j in range(target.shape[0]):
        _draw_one_point(heatmap[j], target[j])
    return heatmap

def heatmaps2vectors(heatmaps):
    """
    TODO: implement this.
    input: heatmaps (B x 14 x 32 x 32): a batch of heatmaps.
    output: target (B x 14 x 2): a batch of keypoint coordinates in (-1 ~ 1)
    Note: Implement the model and loss first before implementing this.
    Note: mpjpe will be invalid until this is implemented.
    Hint: There is only one peak in each heatmap, so torch.argmax is sufficient.
    Hint: Don't use a for loop; it is too slow.
    """
    target = torch.zeros((heatmaps.shape[0], 14, 2))
    return target

def fetch_dataloader(batch_size=64):
    data, targets = pickle.load(open('lsp_dataset.pkl', 'rb'))
    splits = {'train': [0, 1000], 'val': [1000, 1500], 'test': [1500, 2000]}
    for s in splits:
        img_transforms, img_target_transforms = [transforms.ToTensor()], []
        if s == 'train':
            img_target_transforms.append(flip_augment)
        dataset = LSPDataset(data[splits[s][0]:splits[s][1]], 
            targets[splits[s][0]:splits[s][1]], 
            img_transforms, img_target_transforms)
        yield torch.utils.data.DataLoader(
            dataset, batch_size=batch_size if s != 'test' else 1, 
            shuffle=(s != 'test'), num_workers=6)
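
A quick way to sanity-check the heatmap rendering, and later your heatmaps2vectors implementation, is a round-trip test: render random keypoints into heatmaps, recover the coordinates, and compare. Because each peak is snapped to an integer cell of the 32 x 32 grid, the recovered coordinates can differ from the originals by up to roughly 1/16 in normalized units. The cell below is an optional sketch along those lines (dummy_target and recovered are just illustrative names); it assumes the cell above has already been run.

In [0]:
# Optional sanity check (assumes the functions defined above are available).
dummy_target = (torch.rand(14, 2) * 2 - 1) * 0.9   # random keypoints in (-0.9, 0.9)
heatmap = vector2heatmap(dummy_target)
print(heatmap.shape)                                # torch.Size([14, 32, 32])
print(heatmap.view(14, -1).max(dim=1)[0])           # every channel should peak at exactly 1.0

# Round-trip through heatmaps2vectors (only meaningful once it is implemented):
recovered = heatmaps2vectors(heatmap[None])         # add a batch dimension -> (1, 14, 2)
print((recovered[0] - dummy_target).abs().max())    # should be below ~1/16 due to grid quantization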
In [0]:
%reload_ext tensorboard
%tensorboard --logdir log --reload_interval 1
In [0]:
def train(model, train_data, val_data, lr=0.001, n_epochs=30):
    device = torch.device('cuda')
    writer = tb.SummaryWriter(
        'log/{}'.format(time.strftime('%m-%d-%H-%M')), flush_secs=5)
    
    optim = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=5e-6)
    # TODO: define training loss here
    # Hint: visualize the network output in TensorBoard (Images tab)
    loss_func = None
    steps = 0
    
    # Standard training loop.
    model.to(device)
    for epoch in range(n_epochs):
        train_mpjpe, val_mpjpe, train_loss = [], [], []
        model.train()
        for idx, (x, y, heatmap) in enumerate(train_data):
            x, y, heatmap = x.to(device), y, heatmap.to(device)
            heatmap_pred = model(x)
            loss = loss_func(heatmap_pred, heatmap)
            y_pred = heatmaps2vectors(heatmap_pred)
            train_mpjpe.append(mpjpe(y_pred, y))
            train_loss.append(loss.detach())
          
            optim.zero_grad()
            loss.backward()
            optim.step()
          
            steps += 1
            writer.add_scalar('loss/train_batch', loss.item(), steps)
        for i in range(4):
            writer.add_image('train/pose%d'%i, draw_pose(
                x[i].cpu(), y[i].cpu(), y_pred[i].detach().cpu()), epoch)
            writer.add_image('train/gt_heatmap%d'%i,
                heatmap[i].max(dim=0, keepdim=True)[0].detach().cpu(), epoch)
            writer.add_image('train/pred_heatmap%d'%i,
                heatmap_pred[i].max(dim=0, keepdim=True)[0].sigmoid().detach().cpu(), epoch)
            
        model.eval()
        with torch.no_grad():
            for x, y, heatmap in val_data:
                x, y, heatmap = x.to(device), y, heatmap.to(device)
                heatmap_pred = model(x)
                y_pred = heatmaps2vectors(heatmap_pred)
                val_mpjpe.append(mpjpe(y_pred, y))
        for i in range(4):
            writer.add_image(
                'valid/pose%d'%i, draw_pose(
                    x[i].cpu(), y[i].cpu(), y_pred[i].detach().cpu()), epoch)
            writer.add_image('valid/gt_heatmap%d'%i, 
                heatmap[i].max(dim=0, keepdim=True)[0].detach().cpu(), epoch)
            writer.add_image('valid/pred_heatmap%d'%i,
                heatmap_pred[i].max(dim=0, keepdim=True)[0].sigmoid().detach().cpu(), epoch)
            
        ep_trn_mpjpe = torch.mean(torch.stack(train_mpjpe))
        ep_val_mpjpe = torch.mean(torch.stack(val_mpjpe))
        ep_train_loss = torch.mean(torch.stack(train_loss))
        writer.add_scalar('mpjpe/train_epoch', ep_trn_mpjpe, epoch)
        writer.add_scalar('mpjpe/valid_epoch', ep_val_mpjpe, epoch)
        
        print('Epoch: %d, train loss: %.4f, train mpjpe: %.3f, val mpjpe: %.3f'\
              % (epoch, ep_train_loss, ep_trn_mpjpe, ep_val_mpjpe))
        
        if epoch in [25]:
            lr = lr / 10
            for param in optim.param_groups:
                param['lr'] = lr


model = models.resnet18(pretrained=True)
# TODO: define the up-convolutional layers here.
#   We take the output of resnet18's stem plus its first 3 residual stages, which has shape (256 x 8 x 8),
#   and add ConvTranspose2d layers to produce an output of shape (14 x 32 x 32).
up_layers = torch.nn.Sequential()

# This line truncates the regular resnet18 after the 3rd residual stage and appends our up layers.
model = torch.nn.Sequential(*list(model.children())[:-3], up_layers)
train_data, val_data, test_data = fetch_dataloader(batch_size=32)
train(model, train_data, val_data)
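
The two TODOs in this cell (up_layers and loss_func) are left open for you to design. As a reference point only, here is one possible configuration that matches the stated shapes: two stride-2 ConvTranspose2d layers upsample the 8 x 8 features to 32 x 32 while reducing the channels 256 -> 128 -> 14, and the loss treats the network output as per-pixel logits, which is consistent with the sigmoid applied to the predicted heatmaps in the visualization above. The exact depth, kernel sizes, and loss are design choices, not the required solution; the _example suffixes below are only there to avoid clobbering the TODO placeholders.

In [0]:
# One possible sketch, not necessarily the intended solution.
up_layers_example = torch.nn.Sequential(
    torch.nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),  # 8x8 -> 16x16
    torch.nn.ReLU(inplace=True),
    torch.nn.ConvTranspose2d(128, 14, kernel_size=4, stride=2, padding=1),   # 16x16 -> 32x32
)
# Matching loss: soft-target binary cross-entropy on the raw (logit) heatmaps.
loss_func_example = torch.nn.BCEWithLogitsLoss()
# A plain torch.nn.MSELoss() on the heatmaps is another common choice.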