'''
决策树分类:决策树分类模型会找到与样本特征匹配的叶子节点然后以投票的方式进行分类。
在样本文件中统计了小汽车的常见特征信息及小汽车的分类,使用这些数据基于决策树分类算法训练模型预测小汽车等级。
特征信息: 汽车价格 维修费用 车门数量 载客数 后备箱 安全性 汽车级别
案例:基于决策树分类算法训练模型预测小汽车等级。
1.读取文本数据,对每列进行标签编码,基于随机森林分类器进行交叉验证,模型训练.
2.自定义测试集,使用已训练的模型对测试集进行测试,输出结果。
'''
import numpy as np
import matplotlib.pyplot as mp
import sklearn.preprocessing as sp
import sklearn.ensemble as se
import sklearn.model_selection as ms
import sklearn.metrics as sm
import warnings
warnings.filterwarnings('ignore')
data =
[]
with open('./ml_data/car.txt',
'r') as f:
for line
in f.readlines():
sample = line[:-1].split(
',')
data.append(sample)
data =
np.array(data)
# print(data.shape)
# 整理好每一列的标签编码器encoders
# 整理好训练输入集与输出集
data =
data.T
# print(data.shape)
encoders =
[]
train_x, train_y =
[], []
for row
in range(len(data)):
encoder =
sp.LabelEncoder()
if row < len(data) - 1:
# 不是最后列
train_x.append(encoder.fit_transform(data[row]))
else:
# 是最后一列,作为输出集
train_y =
encoder.fit_transform(data[row])
encoders.append(encoder)
train_x =
np.array(train_x).T
# 训练随机森林分类器
model = se.RandomForestClassifier(max_depth=6, n_estimators=200, random_state=7
)
# 训练之前进行交叉验证
cv = ms.cross_val_score(model, train_x, train_y, cv=4, scoring=
'f1_weighted')
print(cv.mean())
model.fit(train_x, train_y)
# 自定义测试集,预测小汽车的等级
# 保证每个特征使用的标签编码器与训练时使用的标签编码器匹配
data =
[
['high',
'med',
'5more',
'4',
'big',
'low',
'unacc'],
['high',
'high',
'4',
'4',
'med',
'med',
'acc'],
['low',
'low',
'2',
'4',
'small',
'high',
'good'],
['low',
'med',
'3',
'4',
'med',
'high',
'vgood']]
data =
np.array(data).T
test_x, test_y =
[], []
for row
in range(len(data)):
encoder = encoders[row]
# 每列对应的标签编码器
if row < len(data) - 1
:
test_x.append(encoder.transform(data[row])) # 这里需要训练了,直接转换
else:
test_y =
encoder.transform(data[row])
test_x =
np.array(test_x).T
pred_test_y =
model.predict(test_x)
print(pred_test_y)
pred_test_y = encoders[-1
].inverse_transform(pred_test_y)
test_y = encoders[-1
].inverse_transform(test_y)
print(pred_test_y)
print(test_y)
输出结果:
0.7465877061619401
[2 0 0 3
]
['unacc' 'acc' 'acc' 'vgood']
['unacc' 'acc' 'good' 'vgood']
转载于:https://www.cnblogs.com/yuxiangyang/p/11194146.html