完整演示代码请见本书GitHub上的16-1.py。
1.数据清洗与特征化
我们继续使用MNIST数据集,MNIST数据集详细介绍请阅读第3章相关内容。这次,我们利用TFLearn提供的API来获取MNIST数据集:
X, Y, testX, testY = mnist.load_data(one_hot=True)
第一次调用这个API的时候,会自动下载MNIST数据集到默认目录:
Downloading MNIST... Succesfully downloaded train-images-idx3-ubyte.gz 9912422 bytes. Extracting mnist/train-images-idx3-ubyte.gz Downloading MNIST... Succesfully downloaded train-labels-idx1-ubyte.gz 28881 bytes. Extracting mnist/train-labels-idx1-ubyte.gz Downloading MNIST... Succesfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes. Extracting mnist/t10k-images-idx3-ubyte.gz Downloading MNIST... Succesfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes. Extracting mnist/t10k-labels-idx1-ubyte.gz
后继再调用该API时会自动从默认目录中加载对应的文件:
Extracting mnist/train-images-idx3-ubyte.gz Extracting mnist/train-labels-idx1-ubyte.gz Extracting mnist/t10k-images-idx3-ubyte.gz Extracting mnist/t10k-labels-idx1-ubyte.gz
其中需要注意的是,原有数据集中将数据标记为0~9,分别代表0~9这10个数字。在神经网络中为了简化设计,一般将最后一层设计成0~1型开关结构,所以0~9可以用一个长度为10的向量表示,每一位分别代表不同的数字,这种编码叫做one-hot,也称为独热编码,编码转换关系如表16-1所示。
表16-1 MNIST标记数据one-hot编码
原有的图片文件大小是28×28,在DNN以及MLP的案例中,我们将其转换成维度为784的特征向量,如图16-5所示。
在RNN模型中,特别适合处理时序型数据,所以我们需要将数据格式进行转换,如图16-6所示。
图16-5 DNN中针对MNIST数据的处理
图16-6 RNN中针对MNIST数据的处理
在Python环境下实现这样的转换非常方便,只需要调用NumPy的函数np.reshape即可:
X = np.reshape(X, (-1, 28, 28)) testX = np.reshape(testX, (-1, 28, 28))
2.训练样本
构造RNN神经网络,使用LSTM算法。
设置输入参数的形状为28×28:
net = tflearn.input_data(shape=[None, 28, 28])
设置使用LSTM算法:
net = tflearn.lstm(net, 128, return_seq=True) net = tflearn.lstm(net, 128)
设置全连接网络:
net = tflearn.fully_connected(net, 10, activation='softmax')
设置输出节点,优化算法使用adam,损失函数使用categorical_crossentropy:
net = tflearn.regression(net, optimizer='adam', loss='categorical_crossentropy', name="output1")
创建神经网络实体:
model = tflearn.DNN(net, tensorboard_verbose=2)
调用fit函数训练样本:
model.fit(X, Y, n_epoch=1, validation_set=0.1, show_metric=True, snapshot_step=100)
其中主要参数的含义为:
·n_epoch,整个数据集合训练的次数;
·validation_set,验证数据集的比例,也可以直接填写集合,比如(testX,testY),
·show_metric,是否展现完整训练过程,
·snapshot_step,snapshot的训练步长。
3.验证效果
运行程序,效果演示如图16-7所示。
图16-7 RNN识别MNIST数据
其中训练数据集个数为55000,测试数据集合个数为10000:
Training samples: 55000 Validation samples: 10000
最终准确率达到了95%左右,相当不错了:
Training Step: 860 | total loss: 0.19170 | time: 179.915s | Adam | epoch: 001 | loss: 0.19170 - acc: 0.9450 | val_loss: 0.16047 - val_acc:0.9515 -- iter: 55000/55000