6.2 Keras-rl智能体通用API

Keras-rl统一了智能体类的API,便于大家使用。下面我们将重点介绍最常用的几个API。

1.Fit函数

Fit函数的定义如下,主要用于强化学习的训练过程。Keras-rl把训练阶段和测试阶段严格区分开来了,训练阶段基于强化学习算法充分训练保存Q函数的深度神经网络,生成对应的模型。测试阶段直接使用已经训练好的模型,验证训练的效果。


fit(self, env, nb_steps, action_repetition=1, callbacks=None, verbose=1, visualize=False, nb_max_start_steps=0, start_step_policy=None, log_interval=10000, nb_max_episode_steps=None)

其中,比较重要的几个参数的定义介绍如下:

·env,对应的OpenAI Gym环境对象。

·nb_steps,训练的步数,注意这个不是学习的次数。

·verbose,调试信息详细程度,0为不显示,2为全部显示。

·visualize,是否可视化,如果希望训练阶段可以看到对应的环境的图像,需要设置为True。

·nb_max_episode_steps,一个学习周期内最多可以执行多少步,默认一个学习周期内会一直学习下去直到游戏玩死。

2.Test函数

Test函数的定义如下,主要用于强化学习的测试过程。Keras-rl把训练阶段和测试阶段严格区分开来了。测试阶段直接使用已经训练好的模型,验证训练的效果。


test(self, env, nb_episodes=1, action_repetition=1, callbacks=None, visualize=True, nb_max_episode_steps=None, nb_max_start_steps=0, start_step_policy=None, verbose=1)

其中,比较重要的几个参数的定义介绍如下:

·env,对应的OpenAI Gym环境对象。

·nb_episodes,测试阶段测试的次数。

·verbose,调试信息详细程度,0为不显示,2为全部显示。

·visualize,是否可视化,如果希望训练阶段可以看到对应的环境的图像,需要设置为True。

3.Compile函数

Compile函数的定义如下,主要编译用户自定义的深度神经网络。从这个函数也可以很明显地表明底层使用的是Keras。


compile(self, optimizer, metrics=[])

其中,比较重要的几个参数的定义介绍如下:

·Optimizer,与Keras中的定义相同,常见的有SGD、RMSprop、Adagrad和Adam等,完整列表请参考相关文献 [1]

·metrics,列表类型,支持多选,与Keras中的定义相同,常见的有accuracy、mae和acc等,完整列表请参考相关文献 [2]

[1] https://keras.io/optimizers/

[2] https://keras.io/metrics/