In [ ]:
%pylab inline
import torch
import torchvision
from PIL import Image
import json
# Class-index lookup used to decode predictions: keys are the class index as a
# string, and element [1] of each entry is the human-readable label (as used below).
# A context manager closes the file instead of leaking the handle.
with open("imagenet_class_index.json", encoding="utf-8") as class_index_file:
    class_idx = json.load(class_index_file)

# Sample images. Converting to RGB guarantees a 3-channel tensor after ToTensor;
# grayscale or RGBA files would otherwise break the 3-channel Normalize step.
I1 = Image.open('dog.jpg').convert('RGB')
I2 = Image.open('cat.jpg').convert('RGB')
In [ ]:
# Pretrained ResNeXt-101 (32x8d) classifier; progress=False keeps the weight
# download from printing a progress bar into the notebook output.
# Lighter alternative for quick experiments (note: its head is `classifier`,
# not `fc`, so the head-replacement cell below would need adapting):
#   model = torchvision.models.mobilenet_v2(pretrained=True, progress=False)
model = torchvision.models.resnext101_32x8d(progress=False, pretrained=True)
In [ ]:
# Standard ImageNet preprocessing:
#   1. resize the shorter side to 224 px;
#   2. take the central 224x224 crop;
#   3. convert to a [0, 1] tensor;
#   4. normalise with the ImageNet channel mean/std the pretrained weights expect.
transform = torchvision.transforms.Compose([
    # Resize replaces transforms.Scale, which was removed from torchvision.
    torchvision.transforms.Resize(224),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

# Put the model in inference mode (e.g. BatchNorm uses running statistics).
model.eval()


def top_k_labels(image, k=5):
    """Return the class names of the model's k highest-scoring outputs for a PIL image.

    The image is converted to RGB so ToTensor always yields 3 channels, then
    preprocessed with `transform`. It is run as a batch of one under
    `torch.no_grad()`, so no autograd graph is built. Scores are ranked
    directly, highest first; softmax is not needed for ranking.
    """
    batch = transform(image.convert('RGB'))[None]  # add a leading batch dimension
    with torch.no_grad():
        scores = model(batch)[0]
    return [class_idx[str(int(i))][1] for i in scores.argsort(descending=True)[:k]]


print(' , '.join(top_k_labels(I1)))
print(' , '.join(top_k_labels(I2)))
In [ ]:
# Replace the final classification layer with a fresh, untrained 10-way head
# (e.g. as a starting point for fine-tuning). The input width is read from the
# existing layer instead of hard-coding 2048, so this also works for other
# torchvision models that expose their head as `model.fc`.
NUM_CLASSES = 10
model.fc = torch.nn.Linear(model.fc.in_features, NUM_CLASSES)
In [ ]:
# Sanity-check the replaced head on one sample image. The original cell used the
# undefined name `I`; the loaded images are `I1` and `I2`. Outputs are arbitrary
# until the new layer is trained, so only the shape is meaningful here.
with torch.no_grad():
    new_head_logits = model(transform(I1.convert('RGB'))[None])
assert new_head_logits.shape == (1, model.fc.out_features), new_head_logits.shape
new_head_logits
In [ ]: