Wasserstein GAN(WGAN)是另外一种GAN的优化变体。WGAN彻底解决了GAN训练不稳定的问题,不再需要小心平衡生成器和判别器的训练程度,而且不需要精心设计的网络架构,不像DCGAN必须有BatchNormalization(批量正则化)。相对GAN,WGAN的主要变化是:
·生成器和判别器的损失函数中不取log。
·每次更新判别器的参数之后把它们的绝对值截断到不超过一个固定常数c。
·优化算法使用RMSProp或者SGD。
下面我们将介绍如何基于WGAN实现生成MNIST数据集的功能,相关代码在GitHub的code/keras-wgan.py。
1.Generator
我们创建Generator,架构如图13-28所示,参数如下:
·输入层大小为100。
·两个全连接层,结点数分别为1024和128×7×7即6272,激活函数均为tanh。
·改变形状为(7,7,128)。
·使用(2,2)进行上采样。
·使用64个大小为(5,5)进行卷积处理。
·使用(2,2)进行上采样。
·使用一个(5,5)进行卷积处理,生成一个(28,28,1)的图像数据。
代码如下:
def generator_model(): model = Sequential() model.add(Dense(input_dim=100, units=1024)) model.add(Activation('tanh')) model.add(Dense(128*7*7)) model.add(Activation('tanh')) model.add(Reshape((7, 7, 128), input_shape=(128*7*7,))) model.add(UpSampling2D(size=(2, 2))) model.add(Conv2D(64, (5, 5), padding='same')) model.add(Activation('tanh')) model.add(UpSampling2D(size=(2, 2))) model.add(Conv2D(1, (5, 5), padding='same')) model.add(Activation('tanh')) return model
2.Discriminator
我们创建Discriminator,架构如图13-29所示,参数如下:
·输入层大小为(28,28,1)。
·64个大小为(3,3),采样间距为(2,2)的卷积处理。
·32个大小为(3,3),采样间距为(2,2)的卷积处理。
·16个大小为(3,3),采样间距为(2,2)的卷积处理。
·1个大小为(3,3)的卷积处理。
·池化取平均值。
代码如下:
def discriminator_model(): model = Sequential() model.add( Conv2D(64, (3, 3),strides=(2, 2),padding='same', input_shape=(28, 28, 1)) ) model.add(LeakyReLU(0.2)) model.add(Conv2D(32, (3, 3),strides=(2, 2),padding='same')) model.add(LeakyReLU(0.2)) model.add(Conv2D(16, (3, 3),strides=(2, 2),padding='same')) model.add(LeakyReLU(0.2)) model.add(Conv2D(1, (3, 3),padding='same')) model.add(GlobalAveragePooling2D()) return model
图13-28 WGAN的Generator
图13-29 WGAN的Discriminator
3.对抗模型
WGAN的对抗模型实现非常简单,如图13-30所示,把Generator和Discriminator连接即可,不过需要将Discriminator参数设置为只允许手工更新,只有当设置trainable为Ture时才根据训练结果自动更新参数,代码如下:
def generator_containing_discriminator(g, d): model = Sequential() model.add(g) d.trainable = False model.add(d) return model
图13-30 WGAN的对抗模型
4.训练过程
首先定义Wasserstein距离,后面将使用它定义损失函数:
def wasserstein(y_true, y_pred): return K.mean(y_true * y_pred)
定义优化器,使用RMSprop,学习速率为5E-5,即10-5:
d_optim = RMSprop(lr=5E-5) g_optim = RMSprop(lr=5E-5)
定义Discriminator参数截断的范围 [1] :
c_lower = -0.1 c_upper = 0.1
Generator的优化函数使用RMSprop,损失函数使用mse;Discriminator的优化函数也使用RMSprop,损失函数使用Wasserstein距离;对抗模型的优化函数也使用RMSprop,损失函数也使用Wasserstein距离,代码如下:
g.compile(loss='mse', optimizer=g_optim) #gan的损失函数使用wasserstein d_on_g.compile(loss=wasserstein, optimizer=g_optim) d.trainable = True #d的损失函数使用wasserstein d.compile(loss=wasserstein, optimizer=d_optim)
WGAN的训练过程分为两步:第一步,生成一个大小为(BATCH_SIZE,100)的在-1~1之间平均分布的噪声,使用Generator生成图像样本,然后和同样大小的真实MNIST图像样本合并,分别标记为0和1(也可以标记为1和-1,有区分即可),对Discriminator进行训练。这个过程中Discriminator的trainable状态为True,训练过程会更新其参数。每次训练完Discriminator,需要将其参数按照指定范围截断,这个可以使用NumPy的clip函数完成,代码如下:
noise = np.random.uniform(-1, 1, size=(BATCH_SIZE, 100)) image_batch = X_train[index*BATCH_SIZE:(index+1)*BATCH_SIZE] generated_images = g.predict(noise, verbose=0) X = np.concatenate((image_batch, generated_images)) y = [-1] * BATCH_SIZE + [1] * BATCH_SIZE d_loss = d.train_on_batch(X, y) # 训练d之后 修正参数 wgan的精髓之一 for l in d.layers: weights = l.get_weights() weights = [np.clip(w, c_lower, c_upper) for w in weights] l.set_weights(weights)
第二步,生成一个大小为(BATCH_SIZE,100)的在-1~1之间平均分布的噪声,使用Generator生成图像样本,标记为1(也可以标记为-1,只要和第一步的对应上即可),欺骗Discriminator,这个过程针对对抗模型进行训练。这个过程中Discriminator的trainable状态为False,训练过程不会更新其参数。训练完成后将重新将Discriminator的trainable状态为True,代码如下:
noise = np.random.uniform(-1, 1, (BATCH_SIZE, 100)) d.trainable = False g_loss = d_on_g.train_on_batch(noise, [1] * BATCH_SIZE) d.trainable = True print("batch %d g_loss : %f" % (index, g_loss))
5.训练结果
我们使用GPU服务器进行训练和生成,这次我们使用了Python的PIL库,可以在字符界面的服务器上进行图像处理,但是Keras自带的plot_model函数无法在字符界面服务器运行。所以生成网络结构的过程我们还是在Mac本上运行。在GPU服务器上需要注释掉plot_model函数的相关代码。图13-31是WGAN训练1轮的结果,图13-32是训练30轮的结果。
图13-31 WGAN训练1轮的结果
图13-32 WGAN训练30轮的结果
[1] https://arxiv.org/abs/1701.07875