Coding 2: K Nearest Neighbors

The k-nearest neighbors algorithm (KNN) is a non-parametric method: it makes predictions directly from stored training examples rather than from learned parameters.

KNN relies on the observation that (in an appropriate embedding space) similar images lie close together. A new test image is classified by looking at nearby training images and aggregating their labels.

The number of neighbors used for prediction can change the outcome. For example, the green circle below is classified as a red triangle by its 3 nearest neighbors, but as a blue square by its 5 nearest neighbors.

[Figure: KNN example]
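The decision rule described above can be sketched on toy 2-D points. This is an illustrative sketch only (the toy data and the `knn_predict` helper are made up here, not part of the assignment):

```python
import torch

def knn_predict(x_trn, y_trn, query, k):
    """Classify `query` by majority vote among its k nearest training points."""
    dists = torch.norm(x_trn - query, dim=1)      # Euclidean distance to each training point
    _, idx = torch.topk(dists, k, largest=False)  # indices of the k smallest distances
    votes = y_trn[idx]                            # labels of the k nearest neighbors
    return torch.mode(votes).values.item()        # most frequent label wins

# Toy data: label 0 clusters near the origin, label 1 near (5, 5).
x_trn = torch.tensor([[0., 0.], [0., 1.], [1., 0.], [5., 5.], [5., 6.], [6., 5.]])
y_trn = torch.tensor([0, 0, 0, 1, 1, 1])
print(knn_predict(x_trn, y_trn, torch.tensor([0.5, 0.5]), k=3))  # → 0
```

Note how the prediction depends only on which k training points are closest; nothing is "learned" ahead of time.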

In this exercise, we will apply KNN to classify the CIFAR-10 dataset. CIFAR-10 consists of 32x32 color images from 10 classes. The training set contains 50k images and the test set contains 10k images. The following are sample images from each class:

[Figure: CIFAR-10 samples]

In [ ]:
%matplotlib inline
import matplotlib.pyplot as plt
import torch
import torchvision
import torchvision.transforms as transforms

Function to load CIFAR-10.

You do not need to fully understand this.

In [ ]:
def fetch_dataloader(transform=None, batch_size=-1, is_train=True):
    """
    Loads data from disk and returns a DataLoader.
    A DataLoader is similar to a list of (image, label) tuples.
    You do not need to fully understand this code to do this assignment; we're happy to explain, though.
    """
    data = torchvision.datasets.CIFAR10(root='./data',
                                        train=is_train, download=True, transform=transform)
    # A batch_size of -1 means "load the whole dataset as one batch".
    batch = len(data) if batch_size == -1 else batch_size
    loader = torch.utils.data.DataLoader(data, batch_size=batch,
                                         shuffle=True, num_workers=4)
    return loader

Fetch and preprocess data.

You do not need to fully understand this.

In [ ]:
classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = -1

loader_trn = fetch_dataloader(transform, batch_size, is_train=True)

x_trn, y_trn = next(iter(loader_trn))
print(x_trn.shape, y_trn.shape)
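KNN compares images as flat vectors, so each 3x32x32 image must be reshaped into a single 3072-dimensional vector before distances can be computed (the test set is flattened this way later in this notebook). A small illustration with a made-up batch:

```python
import torch

x = torch.randn(4, 3, 32, 32)    # a toy batch of 4 CIFAR-sized images
x_flat = x.view(x.shape[0], -1)  # flatten all but the batch dimension
print(x_flat.shape)              # → torch.Size([4, 3072])
```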

A class for implementing K Nearest Neighbors algorithm.

In [ ]:
class KNearestNeighborClassifier(object):

    def __init__(self, k):
        self.k = k

    def train(self, x_trn, y_trn):
        """
        Implement this function.
        "Train" your KNN classifier.
        Hint: no computation is involved.
        """
        pass

    def predict(self, image):
        """
        Implement this function.
        Compute distances between a test image and all train samples.
        Predict the label by voting on the K nearest train samples.
        """
        pass

Initialize an instance of KNearestNeighbor

In [ ]:
# Use a subset of the dataset for practicing implementation.
x_trn = x_trn[:1000]
y_trn = y_trn[:1000]
K = 3

model = KNearestNeighborClassifier(k=K)

Implement training here.

In [ ]:
# Flatten training images to vectors, matching the flattened test set below.
model.train(x_trn.view(x_trn.shape[0], -1), y_trn)

Evaluate on test data.

In [ ]:
loader_tst = fetch_dataloader(transform, batch_size, is_train=False)
x_tst, y_tst = next(iter(loader_tst))
x_tst = x_tst.view(x_tst.shape[0], -1)
print(x_tst.shape, y_tst.shape)

def compute_accuracy(model, x_tst, y_tst):
    """
    Implement this function.
    Compute predictions on all test samples and report accuracy.
    """
    pass

compute_accuracy(model, x_tst, y_tst)
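One possible shape for the accuracy computation, assuming `model.predict(image)` returns an integer class label for a single flattened image (the `accuracy_sketch` name and the loop-based approach are illustrative, not the required implementation):

```python
import torch

def accuracy_sketch(model, x_tst, y_tst):
    # Predict each test image individually, then compare against true labels.
    preds = torch.tensor([model.predict(img) for img in x_tst])
    return (preds == y_tst).float().mean().item()
```

A vectorized implementation that computes all pairwise distances at once would be much faster, but the per-image loop above is the simplest correct starting point.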


KNN visual example is modeled after the KNN Wikipedia article.

CIFAR10 data montage is borrowed from the CIFAR10 homepage.

Dataloader and visualizer borrowed from the PyTorch Tutorials.
