16.4 示例:生成城市名称

RNN具有记忆性,在经过大量训练后可以学习到时序数据的潜在规律,并且可以使用这种规律随机生成新的序列。下面我们举个非常有趣的例子,把美国现有的城市名称录入RNN,RNN学习到城市名称的潜在规律后,随机生成新的城市名称。图16-11是全美棒球联盟的队徽,几乎每个城市都有自己的棒球队,并且会以城市的名称命名球队。完整演示代码请见本书GitHub上的16-4.py。

图16-11 以城市命名的美国棒球队队徽

1.数据清洗与特征化

下载美国城市名称,下载链接为:

https://raw.GitHubusercontent.com/tflearn/tflearn.GitHub.io/master/resources/US_Cities.txt

保存文件为US_Cities.txt,文件内容如下:


Abbeville
Abbotsford
Abbott
Abbottsburg
Abbottstown
Abbyville
Abell
Abercrombie
Aberdeen
Aberfoil
Abernant
Abernathy

约定城市名称最长不超过20,逐行读取城市名称,将数据向量化,并生成对应的样本以及标记、字典:


path = "../data/US_Cities.txt"
maxlen = 20
string_utf8 = open(path, "r").read().decode('utf-8')
X, Y, char_idx = \
    string_to_semi_redundant_sequences(string_utf8, seq_maxlen=maxlen, redun_step=3)

2.训练样本

构造RNN,使用LSTM算法:


g = tflearn.input_data(shape=[None, maxlen, len(char_idx)])
g = tflearn.lstm(g, 512, return_seq=True)
g = tflearn.dropout(g, 0.5)
g = tflearn.lstm(g, 512)
g = tflearn.dropout(g, 0.5)
g = tflearn.fully_connected(g, len(char_idx), activation='softmax')
g = tflearn.regression(g, optimizer='adam', loss='categorical_crossentropy',
                       learning_rate=0.001)

对应的RNN结构如图16-12所示。

图16-12 自动生成城市名称的RNN结构示意图

实例化基于RNN的序列生成器,并使用对应的字典:


m = tflearn.SequenceGenerator(g, dictionary=char_idx,
                              seq_maxlen=maxlen,
                              clip_gradients=5.0,
                              checkpoint_path='model_us_cities')

3.验证效果

使用随机种子,通过RNN模型随机生成城市名称:


for i in range(40):
    seed = random_sequence_from_string(file_lines, maxlen)
    m.fit(X, Y, validation_set=0.1, batch_size=128,
          n_epoch=1, run_id='us_cities')
    print("-- TESTING...")
    print("-- Test with temperature of 1.2 --")
    print(m.generate(30, temperature=1.2, seq_seed=seed))
    print("-- Test with temperature of 1.0 --")
    print(m.generate(30, temperature=1.0, seq_seed=seed))
    print("-- Test with temperature of 0.5 --")
    print(m.generate(30, temperature=0.5, seq_seed=seed))

运行程序,学习现有城市名称,训练数据62106条,校验数据6901条:


Training samples: 62106
Validation samples: 6901

Temperature定义为新颖程度,Temperature越小,自动生成的城市名称越接近样本中的城市名称,Temperature越大,自动生成的城市名称与样本中的城市名称差别越大。当Temperature为1.2:


Efland
Egan
Egbert
Saagawie
Watwe ariplrna
Eifrac

Temperature为1.0:


Efland
Egan
Egbert
Fonlen Weno
Betemsli
Becmutnnd

Temperature为0.5:


Efland
Egan
Egbert
Landou
Parnsasn
Maron
Larerun