欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

python学习—简单线性回归模型

程序员文章站 2024-02-16 09:00:58
...

初学线性回归,个人表示比较蒙,做下笔记,供自己回看

1.加载需要的模块

import pandas as pd
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
plt.rcParams["font.sans-serif"] = "Simhei"

2.读取显示数据

data = pd.read_excel("data/LinearRegression.xlsx")
data
>>>>>
	ID	visitors	sales
0	1   	3  	    10
1	2	    5	    31
2	3	    8	    45
3	4		8		50
4	5		14		75
5	6		15		85
6	7		16		109
7	8		18		117
8	9		22		138
9	10		24		145

3.画出data数据散点图

plt.scatter(x=data.visitors,y=data.sales,c="r",marker="o",edgecolors="b")
plt.xlabel("访客量")
plt.ylabel("销售额")
plt.title("访客量和销售额散点图")

python学习—简单线性回归模型
4.简单线性回归模型

#特征,访客数量
x = data.visitors.values.reshape(-1,1)
#目标,销售额
y = data.sales.values.reshape(-1,1)
#模型
model = LinearRegression().fit(x,y)
#斜率,轴距
print("斜率:{:.2f},轴距为{:.2f}".format(model.coef_[0,0],model.intercept_[0]))
print("类似于一次函数y={:.2f}x{:.2f}".format(model.coef_[0,0],model.intercept_[0]))

>>>>>>>>>>>>>>>>>>>>>>>>>
>print的输出结果
斜率:6.45,轴距为-5.30
类似于一次函数y=6.45x-5.30
>>>>>>>>>>>>>>>>>>>>>>>>>

5.画出模型的拟合线

x1 = x
y1 = model.predict(x)
plt.scatter(x=data.visitors,y=data.sales,c="r",marker="o",edgecolors="b")
plt.plot(x1,y1,linewidth=3)
plt.xlabel("访客量")
plt.ylabel("销售额")
plt.title("模型拟合线")
plt.savefig("out/2.png",dpi=100)

python学习—简单线性回归模型
6.预测

#预测客流量为100时候的销售额
Sales = model.predict([[100]])
print("预测客流量为100时候销售额为:{:.2f}".format(Sales[0,0]))

>>>>>>>>>>>>>>
>print的输出结果
预测客流量为100时候销售额为:639.82
>>>>>>>>>>>>>>

7.拟合优度

model.score(x,y)

>>>>>>>
0.9806125530325467
>>>>>>>