欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

Pytorch入坑指南

程序员文章站 2022-07-14 20:24:26
...

Pytorch入坑指南

pytorch使用tensorboard

  1. 安装Pytorch、Torchvision、Tensorboard
pip install --upgrade torch torchvision
pip install tensorboard
  1. 运行代码测试
import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets,transforms

# 设定输出路径,不写则默认  
writer=SummaryWriter('/home/bobo/Download/tensorboardDir')

# 设定  数据预处理(归一化、增广等)  的步骤
transform=transforms.Compose([
    transforms.ToTensor(),  #归一化
    transforms.Normalize((0.5,),(0.5,))]) #均值 方差

# 采用 内置数据集
trainset=datasets.MNIST('mnist_train',train=True,download=True,transform=transform)
#数据集加载器
trainloader=torch.utils.data.DataLoader(trainset,batch_size=64,shuffle=True)

# 采用 内置网络  False代表 不使用预训练模型(因为要改第一层,故 预训练权重不适用)
model=torchvision.models.resnet50(False)

# 由于数据集是灰度图,故修改网络输入,由RGB改为灰度图
model.conv1=torch.nn.Conv2d(in_channels=1,out_channels=64,kernel_size=7,stride=2,padding=3,bias=False)

#加载 训练数据
images,labels=next(iter(trainloader))

# 设定网格  将一个batch的图像 转化为 一张网格图像
grid=torchvision.utils.make_grid(images)

# 展示训练图像
writer.add_image('images',grid,0)
#展示 模型结构图
writer.add_graph(model,images)


#一定要加
writer.close()
  1. 启动tensorflow
tensorboard --logdir=/home/bobo/Download/tensorboardDir  --port 6006

Pytorch入坑指南
注意:
(1)Python代码及启动命令要明确指定路径,因为tensorboard会显示该文件夹下的所有内容。
(2) 若网页无法打开,可能是该端口占用。使用ps -a查看,并使用kill -9 端口号即可。

参考
PyTorch 自带 TensorBoard 使用教程
PyTorch 1.1 or 1.2 使用Tensorboard