MNIST是一个入门级的计算机视觉数据集,它包含各种手写数字图片,也包含每一张图片对应的标签,告诉我们这个是数字几。我们使用朴素贝叶斯来识别由MNIST组成的验证码。完整演示代码请见本书GitHub上的7-6.py。
1.数据搜集和数据清洗
使用MNIST离线版的数据,具体下载地址以及数据集的介绍请参考第3章相关内容:
def load_data(): with gzip.open('../data/MNIST/mnist.pkl.gz') as fp: training_data, valid_data, test_data = pickle.load(fp) return training_data, valid_data, test_data
2.特征化
MNIST已经将24×24的图片特征化成长度为784的一维向量。
3.训练模型
使用NB训练:
training_data, valid_data, test_data=load_data() x1,y1=training_data x2,y2=test_data clf = GaussianNB() clf.fit(x1, y1)
4.效果验证
验证效果:
print cross_validation.cross_val_score(clf, x2, y2, scoring="accuracy")
准确率55%左右,效果非常不理想,事实上NB算法在非黑即白的二分类问题上使用广泛,但是在多分类问题上的表现确实不如其他算法:
[ 0.53684841 0.58385839 0.6043857 ]