Keras-rl统一了智能体类的API,便于大家使用。下面我们将重点介绍最常用的几个API。
1.Fit函数
Fit函数的定义如下,主要用于强化学习的训练过程。Keras-rl把训练阶段和测试阶段严格区分开来了,训练阶段基于强化学习算法充分训练保存Q函数的深度神经网络,生成对应的模型。测试阶段直接使用已经训练好的模型,验证训练的效果。
fit(self, env, nb_steps, action_repetition=1, callbacks=None, verbose=1, visualize=False, nb_max_start_steps=0, start_step_policy=None, log_interval=10000, nb_max_episode_steps=None)
其中,比较重要的几个参数的定义介绍如下:
·env,对应的OpenAI Gym环境对象。
·nb_steps,训练的步数,注意这个不是学习的次数。
·verbose,调试信息详细程度,0为不显示,2为全部显示。
·visualize,是否可视化,如果希望训练阶段可以看到对应的环境的图像,需要设置为True。
·nb_max_episode_steps,一个学习周期内最多可以执行多少步,默认一个学习周期内会一直学习下去直到游戏玩死。
2.Test函数
Test函数的定义如下,主要用于强化学习的测试过程。Keras-rl把训练阶段和测试阶段严格区分开来了。测试阶段直接使用已经训练好的模型,验证训练的效果。
test(self, env, nb_episodes=1, action_repetition=1, callbacks=None, visualize=True, nb_max_episode_steps=None, nb_max_start_steps=0, start_step_policy=None, verbose=1)
其中,比较重要的几个参数的定义介绍如下:
·env,对应的OpenAI Gym环境对象。
·nb_episodes,测试阶段测试的次数。
·verbose,调试信息详细程度,0为不显示,2为全部显示。
·visualize,是否可视化,如果希望训练阶段可以看到对应的环境的图像,需要设置为True。
3.Compile函数
Compile函数的定义如下,主要编译用户自定义的深度神经网络。从这个函数也可以很明显地表明底层使用的是Keras。
compile(self, optimizer, metrics=[])
其中,比较重要的几个参数的定义介绍如下:
·Optimizer,与Keras中的定义相同,常见的有SGD、RMSprop、Adagrad和Adam等,完整列表请参考相关文献 [1] 。
·metrics,列表类型,支持多选,与Keras中的定义相同,常见的有accuracy、mae和acc等,完整列表请参考相关文献 [2] 。
[1] https://keras.io/optimizers/
[2] https://keras.io/metrics/