高斯朴素贝叶斯(GaussianNB)假设样本的特征符合高斯分布(正态分布)。
# Demo: fit a Gaussian naive Bayes classifier on synthetic blob data and
# draw its decision regions together with the training and test points.
import matplotlib.pyplot as plt  # plotting tools
import numpy as np
from sklearn.datasets import make_blobs  # dataset generator
from sklearn.model_selection import train_test_split  # dataset splitter
from sklearn.naive_bayes import GaussianNB

# Generate 500 samples grouped around 5 centers, then split them into
# training and test sets (fixed seeds keep the run reproducible).
X, y = make_blobs(n_samples=500, centers=5, random_state=8)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=8)

# Fit the classifier and report its score on the held-out test set.
gnb = GaussianNB()
gnb.fit(X_train, y_train)
test_score = gnb.score(X_test, y_test)
print('模型得分:{:.3f}'.format(test_score))

# Axis limits: X has two columns, so the column-wise extrema, padded by
# 0.5 on each side, give the x- and y-ranges of the plot.
x_min, y_min = X.min(axis=0) - 0.5
x_max, y_max = X.max(axis=0) + 0.5

# Predict a class for every node of a dense grid (step 0.02) so that the
# background colour shows which class the model assigns to each region.
xx, yy = np.meshgrid(
    np.arange(x_min, x_max, .02),
    np.arange(y_min, y_max, .02),
)
grid_points = np.column_stack((xx.ravel(), yy.ravel()))
z = gnb.predict(grid_points).reshape(xx.shape)
plt.pcolormesh(xx, yy, z, cmap=plt.cm.Pastel1)

# Overlay the samples: default markers for training points, stars for
# test points; both are coloured by their true label.
point_style = dict(cmap=plt.cm.cool, edgecolor='k')
plt.scatter(X_train[:, 0], X_train[:, 1], c=y_train, **point_style)
plt.scatter(X_test[:, 0], X_test[:, 1], c=y_test, marker='*', **point_style)

plt.xlim(xx.min(), xx.max())
plt.ylim(yy.min(), yy.max())
plt.title('Classifier: GaussianNB')
plt.show()