MNIST是一个入门级的计算机视觉数据集,如图14-6所示,它包含各种手写数字图片,它也包含每一张图片对应的标签,告诉我们这个是数字几。我们使用神经网络来识别由MNIST组成的验证码。完整演示代码请见本书GitHub上的14-1.py。
图14-6 MNIST数据集
1.数据搜集和数据清洗
在线抓取最新的MNIST,并将前60000个样本作为训练样本,剩下的作为测试样本。其中图片大小为28×28,所以输入参数个数为784。
如果出现样本下载失败,可以直接去MNIST网站上下载,具体信息可以参考第3章内容:
import matplotlib.pyplot as plt from sklearn.datasets import fetch_mldata from sklearn.neural_network import MLPClassifier mnist = fetch_mldata("MNIST original") X, y = mnist.data / 255., mnist.target X_train, X_test = X[:60000], X[60000:] y_train, y_test = y[:60000], y[60000:]
2.特征化
实例化神经网络算法,隐藏层为一层,神经元个数为50:
mlp = MLPClassifier(hidden_layer_sizes=(50,), max_iter=10, alpha=1e-4, solver='sgd', verbose=10, tol=1e-4, random_state=1, learning_rate_init=.1)
3.训练模型
mlp.fit(X_train, y_train)
4.效果验证
验证效果:
print("Training set score: %f" % mlp.score(X_train, y_train)) print("Test set score: %f" % mlp.score(X_test, y_test))
准确率达到了97%左右,效果非常不错:
Training set score: 0.985733 Test set score: 0.971000