欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

上课笔记篇---用Pytorch实现简单的线性回归

程序员文章站 2022-05-26 20:43:50
...

@用Pytorch实现简单线性回归

用Pytorch实现简单的线性回归

之前关于深学啥的,看了一些书和论文,但真正上手,还没做过。现在项目有需求,就在微信上报了个班,开始学习Pytorch上手的东西。这里做一个记录,以备自己以后查看。

直接上代码吧

这个代码是教程上的,版权归教程所有,如果侵权,我会马上删掉。.

// An highlighted block
lr=0.1 #学习率
x=torch.rand(20,1) * 10 #随机生成的X
y=2*x+(5+torch.rand(20,1)) #在X基础上,添加随机量,生成Y

w=torch.randn((1),requires_grad=True) #随机给个W,后面的参数是自动求导时用的
b=torch.zeros((1),requires_grad=True) #b先给的是零

for iteration in range(1000):
	wx = torch.mul(w,x) # 给所有的X先乘个w
    y_pred = torch.add(wx, b) #然后根据y=wx+b求出预测的y
    
	#根据Y和预测的Y_Pred求出两个之间的平均误差
    loss=(0.5 * (y-y_pred) ** 2).mean() 
    #再根据这个误差进行反向传播,也就是梯度下降
    loss.backward()
	# 根据梯度下降获得的值更新w和b
    b.data.sub_(lr * b.grad)
    w.data.sub_(lr * w.grad)
    
	#这里主要是绘制图像,将散点和求出的直线绘制出来
    if iteration % 20 == 0:
       # 原代码中缺了下面这一句,导致画图会出问题
       plt.clf()
       plt.scatter(x.data.numpy(),y.data.numpy())
       plt.plot(x.data.numpy(),y_pred.data.numpy(),"r-",lw=5)
       plt.text(2,20,'Loss*%.4f'% loss.data.numpy(),fontdict={'size':20,'color':'red'})
       plt.xlim(1.5,10)
       plt.ylim(8,28)
       plt.title("Iteration : {}\n w: {} b: {}".format(iteration,w.data.numpy(),b.data.numpy()))
       plt.pause(0.5)
	   # 如果误差小到一定范围,则退出循环
       if loss.data.numpy() < 1:
           break

我是代码说明

上面这段代码,实现的是对随机生成的一堆散点实现曲线的拟合。具体的解释本来想写在这呢,后来还是觉得直接写在代码注释中好一些。