如何解决二分类中的样本不平衡问题
程序员文章站
2022-05-26 19:07:50
...
在搭建模型时,二分类中,经常会遇到目标变量的分类数量相差很大,比如分类是1的数量是5000,分类是0的数量是100,这样如果对数据的不平衡性不做处理,模型的效果也会很差。今天用一个案例来进行实操:
1、案例目的:
找出有资金需求的中小企业借贷户并销售其贷款产品
2、背景:
对于中小企业而言,要快速成长最需要的就是资金。若能找出这些有资金需求的中小企业公司户并销售其贷款产品,将能为银行带来不少的营收,并改善中小企业的经营。
3、数据说明
训练数据包含26,144笔客户资料;每笔客户资料包含26个字段(1个客户ID字段、24个输入字段及1个目标字段-VV(是否为SME公司户;1代表有资金需求的中小企业公司户,0则代表没有资金需求的中小企业公司户)。
![每个字段的含义如图]
(https://img-blog.csdnimg.cn/fd912d79e0634a9c819e03772516aa81.png?x-oss-process=image/watermark,type_d3F5LXplbmhlaQ,shadow_50,text_Q1NETiBAYnJ1Y2VsaW5fMTg4,size_20,color_FFFFFF,t_70,g_se,x_16)
需要说明一下area:区域代码第1码为大分类, 第1+2码为中分类, 依此类推
测试数据包含6,537笔客户资料;字段个数与训练数据相同,只有目标字段的值全部填“Withheld”。
建立一个分类预测模型,找出有资金需求的中小企业借贷户,并输出一个测试结果的档案(考生姓名_results.csv)。考生姓名_results.csv中只有两个字段,分别是客户ID以及预测客户是否是有资金需求的中小企业公司户。3.results.csv的形式如下:
ID,Predicted_Results
1,1
2,0
8,0
4、模型评估方式:
F-Measure,值越大越好
#导入第三方包
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity='all'
from scipy import stats
import pandas as pd
import numpy as np
from statsmodels.formula.api import ols
from statsmodels.stats.anova import anova_lm
from statsmodels.stats.multicomp import pairwise_tukeyhsd
import matplotlib.pyplot as plt
#导入数据
df_train = pd.read_csv('Training.csv')
#探查数据,数据总计26144行,每个字段的数据类型如下,
df_train.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 26144 entries, 0 to 26143
Data columns (total 26 columns):
ID 26144 non-null int64
area 24423 non-null object
ck 24503 non-null object
ck-saveall 24504 non-null object
ck-drawall 24449 non-null object
ck-savetime 24518 non-null object
ck-drawtime 24495 non-null object
ck-saveavg 24485 non-null object
ck-drawavg 24454 non-null object
ck-avg 24512 non-null object
dep-saveall 24572 non-null object
dep-drawall 24471 non-null object
dep-savetime 24546 non-null object
dep-drawtime 24370 non-null object
depsaveavg 24482 non-null object
depdrawavg 24426 non-null object
dep-avg 24498 non-null object
dep-9201 24504 non-null object
fed-9201 24510 non-null object
fed-avg 24522 non-null object
comp 24495 non-null object
ck-changame 24507 non-null object
ck-changtime 24466 non-null object
dep-changtime 24517 non-null object
VV 26144 non-null int64
dtypes: int64(2), object(24)
memory usage: 5.2+ MB
#通过查看数据,发现area的数据存在?和.符号脏数据,这里用众数填充一下:
df_train.groupby('area').size()
df_train["area"] = df_train["area"].replace("?","104")
df_train["area"] = df_train["area"].replace(".","104")
area
. 1623
0 6
100 445
10000 259
103 180
10300 154
104 951
10400 472
105 553
10500 326
106 869
10600 471
108 77
10800 78
110 404
11000 255
111 244
1110 1
11100 90
112 63
11200 30
114 104
11400 62
115 37
979 1
981 2
98100 1
98300 1
? 3288
Length: 489, dtype: int64
#剩余的空值和符号数据(,和.都用0来填充)
df_train= df_train.fillna('0') #填充空值
df_train = df_train.replace("?","0") #处理脏数据
df_train = df_train.replace(".","0")
#删除重复数据
df_train.drop_duplicates(keep = 'first', inplace = True)
#查看目标变量的分布比例
df_train.groupby('VV').size()
VV
0 25554
1 590
dtype: int64
#目标变量的分类严重失衡,所以要进行平衡样本
data_train = df_train.copy() #备份数据,方便修改和使用
data_train_Y = data_train.pop("VV").values.reshape([-1, ])
data_train_X = data_train.values
#重点来了,平衡二分类中的目标变量数据经常采用降采样和升采样,由于此次训练数据数量还可以,所以采用降采样
from imblearn.under_sampling import ClusterCentroids, RandomUnderSampler, EditedNearestNeighbours #降采样经常用的几种方式
cc = ClusterCentroids() #经测试,三个降采样中这个评分最高
X_resampled, y_resampled = cc.fit_sample(data_train_X, data_train_Y.reshape([-1]))
from collections import Counter
sorted(Counter(y_resampled).items())
data_train_X, data_train_Y = X_resampled, y_resampled
[(0, 590), (1, 590)]
from sklearn.model_selection import train_test_split
X_train, X_test, Y_train, Y_test = train_test_split(data_train_X, data_train_Y, test_size=0.2, random_state=0)
print(X_train.shape, X_test.shape, Y_train.shape, Y_test.shape)
(944, 25) (236, 25) (944,) (236,)
# # 特征标准化,采用最大最小值标准化,转化后的值范围(0,1)
from sklearn.preprocessing import MinMaxScaler
min_max_scaler = MinMaxScaler(copy=True, feature_range=(0, 1))
new_X_train = X_train
new_X_test = X_test
from sklearn.preprocessing import Normalizer
normalizer = Normalizer(copy=True, norm='l2').fit(new_X_train)
new_X_train = normalizer.transform(new_X_train)
new_X_test = normalizer.transform(new_X_test)
from sklearn.ensemble import RandomForestClassifier #集成boosting有以下几种常用model:AdaBoostClassifier,GradientBoostingClassifier,RandomForestClassifier,其他的分类模型例如逻辑回归、决策树、SVM等都可以拿来试试。
clf1 = RandomForestClassifier()
clf1.fit(new_X_train, Y_train)
_y = clf1.predict(new_X_test)
#导入评价指标F1、精确率、响应率
from sklearn.metrics import f1_score, precision_score, recall_score
y_true = Y_test.reshape([-1,])
y_pred = _y
print(f1_score(y_true, y_pred, labels=None, pos_label=1, average='binary', sample_weight=None))
print(precision_score(y_true, y_pred))
print(recall_score(y_true, y_pred))
0.894736842105
0.971428571429
0.829268292683
new_df_train_X, new_df_train_Y = ss_X.transform(data_train_X), data_train_Y # 所有数据
clf2 = RandomForestClassifier() #选定的模型
clf2.fit(new_df_train_X, new_df_train_Y) #new_df_train_X和new_df_train_Y是降采样后的结果
#导入测试数据,数据处理同训练数据的填充空值、处理脏数据、删除重复值等步骤,不再重复以上步骤
df_test = pd.read_csv('Test.csv')
df_test.info() #总计6537行
data_test_Y = df_test.pop("VV").values.reshape([-1, ]) # 转成 (120000, 1) array
data_test_X = df_test.values
origin_df = df_test.copy()
new_X = origin_df.values
# X = best_fit.transform(new_X)
new_fit_X=min_max_scaler.fit_transform(new_X)
new_fit_X = normalizer.transform(new_fit_X)
new_predict_y_result = cf1.predict(new_fit_X)
new_predict_y = pd.DataFrame({'predict': np.array(new_predict_y_result.reshape(-1,))})
new_predict_y
predict
0 0
1 1
2 1
3 1
4 1
5 1
6 1
7 0
8 1
9 1
10 1
11 1
12 1
13 1
14 1
15 1
16 1
17 1
18 1
19 1
20 1
21 1
22 1
23 1
24 1
25 1
26 1
27 1
28 1
29 1
... ...
6507 1
6508 1
6509 1
6510 1
6511 1
6512 1
6513 1
6514 1
6515 1
6516 1
6517 1
6518 1
6519 1
6520 1
6521 1
6522 1
6523 1
6524 1
6525 1
6526 1
6527 1
6528 1
6529 1
6530 1
6531 1
6532 1
6533 1
6534 1
6535 1
6536 1
6537 rows × 1 columns
new_predict_y.groupby('predict').size()
predict
0 134
1 6403
dtype: int64
这次没有做特征选择,大家可以选择一下特征,跟目标变量关系大的自变量,比如,ID这个自变量就跟目标变量的关系性很弱,所以可以删掉。
推荐阅读
-
如何解决hover在ie6中的兼容性问题
-
javascript中引用传递的问题如何解决
-
如何解决婚姻中的问题?教你三种非常好用的方法
-
APP运营中如何解决用户自主传播分享的问题?
-
解析如何在PHP下载文件名中解决乱码的问题
-
H5如何解决安卓中input输入框获取焦点后,底部固定定位的按钮顶起问题?
-
Android开发中如何解决Fragment +Viewpager滑动页面重复加载的问题
-
机器学习算法的分类:关于如何选择机器学习算法和适用解决的问题
-
企业网站推广中容易遇到什么样的问题 如何解决
-
IDEA,maven运行时缺少junit啊,servlet啊各种包,但是自己设置的仓库中明明有啊,这种问题如何解决