深度学习零基础使用 PyTorch 框架跑 MNIST 数据集的第四天:单例测试
程序员文章站
2022-07-14 18:46:47
...
1. Introduction
今天是尝试用 PyTorch 框架来跑 MNIST 手写数字数据集的第四天,主要学习导入模型并进行单例测试。本 blog 主要记录一个学习的路径以及学习资料的汇总。
注意:这是用 Python 2.7 版本写的代码
第一天:https://blog.csdn.net/qq_36627158/article/details/108098147
第二天:
第三天:https://blog.csdn.net/qq_36627158/article/details/108163693
第四天:https://blog.csdn.net/qq_36627158/article/details/108183655
2. Code(mnist_classify.py)
感谢 凯神 提供的代码与耐心指导!
from torchvision import transforms
from PIL import Image, ImageOps
from mnist_train import *
classes = ('0', '1', '2', '3', '4', '5', '6', '7', '8', '9')
model = Net()
def load_checkpoint(checkpoint_path, model):
state = torch.load(checkpoint_path)
model.load_state_dict(state['model'])
if __name__ == '__main__':
load_checkpoint(
'module/pytorch-mnist-batch-128-1407.pth',
model
)
model = model.to(device)
model.eval()
img = Image.open("/home/ubuntu/Downloads/C6/3.jpg")
img = ImageOps.invert(img)
# rgb -> single channel image
if len(img.split()) > 1:
img = img.split()[0]
plt.figure()
plt.imshow(img)
plt.show()
trans = transforms.Compose([
transforms.Resize((28, 28)),
transforms.ToTensor(),
])
img = trans(img)
img = img.to(device)
img = img.unsqueeze(0)
output = model(img)
prob = F.softmax(output, dim=1)
max_value, max_index = torch.max(prob, 1)
pred_class = classes[max_index.item()]
print 'predicted class is', pred_class, ', probability is', round(max_value.item(), 6) * 100
3. Details
1、im.split()
r, g, b=im.split() 该函数用来将RGB图片分割成三个通道的图片
2、torch.unsqueeze()
为 Torch Tensor 添加维度
上一篇: Colab使用方法