欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

深度学习零基础使用 PyTorch 框架跑 MNIST 数据集的第四天:单例测试

程序员文章站 2022-07-14 18:46:47
...

1. Introduction

今天是尝试用 PyTorch 框架来跑 MNIST 手写数字数据集的第四天,主要学习导入模型并进行单例测试。本 blog 主要记录一个学习的路径以及学习资料的汇总。

注意:这是用 Python 2.7 版本写的代码

第一天:https://blog.csdn.net/qq_36627158/article/details/108098147

第二天:

第三天:https://blog.csdn.net/qq_36627158/article/details/108163693

第四天:https://blog.csdn.net/qq_36627158/article/details/108183655

 

 

 

2. Code(mnist_classify.py)

感谢 凯神 提供的代码与耐心指导!

from torchvision import transforms
from PIL import Image, ImageOps
from mnist_train import *


classes = ('0', '1', '2', '3', '4', '5', '6', '7', '8', '9')
model = Net()


def load_checkpoint(checkpoint_path, model):
    state = torch.load(checkpoint_path)
    model.load_state_dict(state['model'])


if __name__ == '__main__':
    load_checkpoint(
        'module/pytorch-mnist-batch-128-1407.pth',
        model
    )

    model = model.to(device)
    model.eval()

    img = Image.open("/home/ubuntu/Downloads/C6/3.jpg")
    img = ImageOps.invert(img)

    # rgb -> single channel image
    if len(img.split()) > 1:
        img = img.split()[0]

    plt.figure()
    plt.imshow(img)
    plt.show()

    trans = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
    ])
    img = trans(img)

    img = img.to(device)

    img = img.unsqueeze(0)

    output = model(img)
    prob = F.softmax(output, dim=1)

    max_value, max_index = torch.max(prob, 1)

    pred_class = classes[max_index.item()]
    print 'predicted class is', pred_class, ', probability is', round(max_value.item(), 6) * 100

 

 

 

3. Details

1、im.split()

r, g, b=im.split()   该函数用来将RGB图片分割成三个通道的图片

Python-Image 基本的图像处理操作

 

2、torch.unsqueeze()

为 Torch Tensor 添加维度

https://blog.csdn.net/xiexu911/article/details/80820028