17.2 Example: Hello World! Convolutional Neural Network

The complete demo code is available as 17-1.py in this book's GitHub repository.

1. Data Cleaning and Featurization

We continue to use the MNIST dataset; see Chapter 3 for a detailed introduction to it. This time we obtain MNIST through the API provided by TFLearn:


import tflearn.datasets.mnist as mnist
X, Y, testX, testY = mnist.load_data(one_hot=True)
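With one_hot=True, each label comes back as a length-10 vector that is 1 at the digit's index and 0 elsewhere, rather than as a single integer. The minimal NumPy sketch below illustrates that encoding; the helper name to_one_hot is hypothetical and is not part of TFLearn.

```python
import numpy as np

def to_one_hot(labels, num_classes=10):
    """Hypothetical helper: turn integer class labels into one-hot rows."""
    labels = np.asarray(labels, dtype=np.int64)
    one_hot = np.zeros((labels.shape[0], num_classes), dtype=np.float32)
    # Set the column matching each sample's label to 1
    one_hot[np.arange(labels.shape[0]), labels] = 1.0
    return one_hot

y = to_one_hot([3, 0, 9])
print(y.shape)  # (3, 10)
print(y[0])     # [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
```

Taking np.argmax along axis 1 recovers the original integer labels, which is also how predicted classes are read off a softmax output.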

The first time this API is called, it automatically downloads the MNIST dataset to the default directory:


Downloading MNIST...
Succesfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.
Extracting mnist/train-images-idx3-ubyte.gz
Downloading MNIST...
Succesfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.
Extracting mnist/train-labels-idx1-ubyte.gz
Downloading MNIST...
Succesfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
Extracting mnist/t10k-images-idx3-ubyte.gz
Downloading MNIST...
Succesfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
Extracting mnist/t10k-labels-idx1-ubyte.gz

On subsequent calls, the API loads the corresponding files from the default directory instead:


Extracting mnist/train-images-idx3-ubyte.gz
Extracting mnist/train-labels-idx1-ubyte.gz
Extracting mnist/t10k-images-idx3-ubyte.gz
Extracting mnist/t10k-labels-idx1-ubyte.gz

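Each image is 28×28 pixels, and TFLearn flattens it into a 784-dimensional vector by default. Because the convolutional layers expect image-shaped input, we reshape the data back into 28×28 grids with a single (grayscale) channel, i.e. shape [28, 28, 1] per sample: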


X = X.reshape([-1, 28, 28, 1])
testX = testX.reshape([-1, 28, 28, 1])
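To see what the reshape does, the sketch below applies it to a small synthetic array standing in for the real MNIST data (5 samples instead of the full set). The -1 tells NumPy to infer the batch size, and row-major ordering means flat pixel index r × 28 + c ends up at row r, column c.

```python
import numpy as np

# Synthetic stand-in for flattened MNIST images: 5 samples x 784 pixels
flat = np.arange(5 * 784, dtype=np.float32).reshape(5, 784)

# -1 lets NumPy infer the batch size; the trailing 1 is the single grayscale channel
images = flat.reshape([-1, 28, 28, 1])
print(images.shape)  # (5, 28, 28, 1)

# Row-major order: flat pixel index r * 28 + c lands at row r, column c
r, c = 10, 7
assert images[2, r, c, 0] == flat[2, r * 28 + c]
```

No data is copied or reordered; only the view of the same 784 values per sample changes.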

2. Building the Network

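Build the CNN, starting with an input layer that accepts 28×28×1 images (the leading None leaves the batch size unspecified):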


import tflearn
net = tflearn.input_data(shape=[None, 28, 28, 1])

Add a 2D convolutional layer with 64 filters of size 3×3, ReLU activation, and no bias term:


net = tflearn.conv_2d(net, 64, 3, activation='relu', bias=False)
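To show roughly what one of these 64 filters computes, here is a minimal NumPy sketch of a single-channel 3×3 convolution with stride 1, zero ("same") padding, no bias, followed by ReLU. It assumes TFLearn's conv_2d defaults of strides=1 and padding='same', so the 28×28 size is preserved. As in most deep-learning libraries, the operation is really cross-correlation (the kernel is not flipped). The kernel here is fixed by hand for illustration, whereas the real layer learns its weights.

```python
import numpy as np

def conv2d_same_relu(image, kernel):
    """Single-channel 2D cross-correlation, stride 1, 'same' zero padding, no bias, then ReLU."""
    kh, kw = kernel.shape
    ph, pw = kh // 2, kw // 2
    padded = np.pad(image, ((ph, ph), (pw, pw)), mode="constant")
    out = np.zeros_like(image, dtype=np.float64)
    for i in range(image.shape[0]):
        for j in range(image.shape[1]):
            # Weighted sum of the kernel and the window centered on (i, j)
            out[i, j] = np.sum(padded[i:i + kh, j:j + kw] * kernel)
    return np.maximum(out, 0.0)  # ReLU

image = np.zeros((28, 28))
image[:, 14:] = 1.0                          # dark left half, bright right half
kernel = np.array([[-1.0, 0.0, 1.0]] * 3)    # responds to dark-to-bright edges
feature_map = conv2d_same_relu(image, kernel)
print(feature_map.shape)  # (28, 28): 'same' padding keeps the spatial size
```

The feature map is positive only in the columns next to the edge. With one input channel and no bias, the real layer has 3 × 3 × 1 × 64 = 576 weights.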

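Assemble the rest of the network: a stack of residual bottleneck blocks, followed by batch normalization, ReLU, global average pooling, and a 10-way softmax output layer. The model is trained with the momentum optimizer, categorical cross-entropy loss, and a learning rate of 0.1: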


net = tflearn.residual_bottleneck(net, 3, 16, 64)
net = tflearn.residual_bottleneck(net, 1, 32, 128, downsample=True)
net = tflearn.residual_bottleneck(net, 2, 32, 128)
net = tflearn.residual_bottleneck(net, 1, 64, 256, downsample=True)
net = tflearn.residual_bottleneck(net, 2, 64, 256)
net = tflearn.batch_normalization(net)
net = tflearn.activation(net, 'relu')
net = tflearn.global_avg_pool(net)
net = tflearn.fully_connected(net, 10, activation='softmax')
net = tflearn.regression(net, optimizer='momentum',
                         loss='categorical_crossentropy',
                         learning_rate=0.1)
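The core idea of residual_bottleneck is to squeeze the channels with a 1×1 convolution, apply a 3×3 convolution in the narrow space, expand back with another 1×1 convolution, and add the block's input to the result. The NumPy sketch below is a simplified, forward-only version of one such block, using the sizes of the first stage above (64 channels in and out, bottleneck of 16). It assumes random weights, omits batch normalization, and uses an identity shortcut; TFLearn's own residual_bottleneck adds batch normalization inside the block and handles channel or size changes on the shortcut, so this is an illustration of the technique, not its implementation.

```python
import numpy as np

rng = np.random.default_rng(0)

def conv1x1(x, w):
    """1x1 convolution: a per-pixel linear map over channels. x: (H, W, C_in), w: (C_in, C_out)."""
    return x @ w

def conv3x3_same(x, w):
    """3x3 convolution, stride 1, 'same' zero padding. x: (H, W, C_in), w: (3, 3, C_in, C_out)."""
    h, wd, _ = x.shape
    padded = np.pad(x, ((1, 1), (1, 1), (0, 0)))
    out = np.zeros((h, wd, w.shape[3]))
    for di in range(3):
        for dj in range(3):
            # Add the contribution of kernel tap (di, dj) at every pixel
            out += padded[di:di + h, dj:dj + wd, :] @ w[di, dj]
    return out

def relu(x):
    return np.maximum(x, 0.0)

def bottleneck_block(x, w_reduce, w_mid, w_expand):
    """Simplified bottleneck residual block (no batch norm, identity shortcut)."""
    y = relu(conv1x1(x, w_reduce))    # 64 -> 16 channels
    y = relu(conv3x3_same(y, w_mid))  # 16 -> 16 channels, spatial size kept
    y = conv1x1(y, w_expand)          # 16 -> 64 channels
    return relu(x + y)                # add the shortcut, then activate

x = rng.standard_normal((28, 28, 64))
w_reduce = rng.standard_normal((64, 16)) * 0.1
w_mid = rng.standard_normal((3, 3, 16, 16)) * 0.1
w_expand = rng.standard_normal((16, 64)) * 0.1
out = bottleneck_block(x, w_reduce, w_mid, w_expand)
print(out.shape)  # (28, 28, 64)
```

Because of the shortcut, a block whose residual branch outputs zero simply passes its input through, which is what makes deep stacks of these blocks easier to train. In the network above, downsample=True uses a stride of 2 by default, so the feature maps shrink from 28×28 to 14×14 and then to 7×7 before global_avg_pool reduces each of the 256 channels to a single number.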

3. Validating the Results

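Train the model, using the test set as the validation set (this is a hold-out check, not cross-validation). Validation accuracy exceeds 99%, which is very good: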


model = tflearn.DNN(net, checkpoint_path='model_resnet_mnist',
                    max_checkpoints=10, tensorboard_verbose=0)
model.fit(X, Y, n_epoch=100, validation_set=(testX, testY),
          show_metric=True, batch_size=256, run_id='resnet_mnist')
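The accuracy metric shown during training is the fraction of samples whose highest-probability class matches the one-hot label. The sketch below computes it with NumPy on made-up softmax outputs, which stand in for real predictions such as those returned by model.predict(testX); the function name accuracy is hypothetical.

```python
import numpy as np

def accuracy(probs, one_hot_labels):
    """Fraction of rows where the arg-max prediction matches the one-hot label."""
    predicted = np.argmax(probs, axis=1)
    actual = np.argmax(one_hot_labels, axis=1)
    return float(np.mean(predicted == actual))

# Made-up softmax outputs for 4 samples over 10 classes (each row sums to 1)
probs = np.full((4, 10), 0.01)
probs[[0, 1, 2, 3], [7, 2, 1, 5]] = 0.91  # predicted classes: 7, 2, 1, 5
labels = np.eye(10)[[7, 2, 1, 0]]          # true classes: 7, 2, 1, 0

print(accuracy(probs, labels))  # 0.75
```

An accuracy above 99% on the 10,000 MNIST test images therefore means fewer than 100 of them are misclassified.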