RL Atari

Gym 是 OpenAI 做的 RL tasks 模拟库 (虽然 Play Atari w/ Deep RL 等都是 Deepmind 发的),比如 CartPole, Robot Locomotion, Atari 2600 都有模拟。除了 Python 外在 C++ 也实现了此 API。
其中 ALE (Arcade-Learning-Environment) 是 Atari 2600 part,stable-baselines3 在这个模拟库的逻辑之上实现 RL 经典算法。

Gymnasium

https://gymnasium.farama.org/

為所有單 Agent RL Environments 提供 API,並實現了一些常用環境。

Env

Env 是 Gym 的核心,是個實現了 Markov 決策過程的類,能生成 initial state、給定 action 转移到下一个 state、可视化 environment。Wrapper 用来修改 environment,即修改 agent observations, rewards, actions。

使用 Gym 有四个关键函数 make(), Env.reset(), Env.step(), Env.render()

初始化环境

import gymnasium as gym
env = gym.make('CartPole-v1')

make() 返回一个可交互的 Env 对象,pprint_registry() 查看所有可以创建的环境。

可以指定 environment 内部定义的 render_mode,可以为 None, 字符串比如 human, rgb_array, rgb_array_list 等,以 rgb_arrayrgb_array_list 为例:

frame = env.render(render_mode='rgb_array')  # 返回单个图像帧 (3 维 numpy 数组)
frames = env.render(render_mode='rgb_array_list') # 从环境开始到当前 time step 的所有图像,[frame_1, frame_2, ..., frame_n],每个 frame 都是 3d numpy array

用 register 初始化的话, 可以用

gym.register(id="namespace/mandatory name-optional version", entry_point=class)

最好用 register 更灵活的扩展 register_envs,比如 gym.register(ale_py) 这种语法在 register 里不行。

gymnasium.pprint_registry() 能打印所有注册的环境,注册过才可以用 gymnasium.make() 创建,gymnasium.make_vec(, num_envs=n) 可以生成 n 个并行的相同环境。

环境交互

参考上面经典的 “agent-environment loop”,Agent 会收到一个环境的 observation,然后选择一个 action,environment 用它来决定 reward 和下一个 observation。

!pip install swig "gymnasium[box2d]"
import gymnasium as gym

env = gym.make("CartPole-v1", render_mode="human")
observation, info = env.reset() # reset(seed, options), 特定 environment 可以输入一个 dict 选择初始化的方式,reset 即得到 first observation

episode_over = False # 结束循环 flag
while not episode_over:
action = env.action_space.sample() # agent policy that uses the observation and info,随机从 action space 选一个
observation, reward, terminated, truncated, info = env.step(action) # 执行指定 action,一次 action-observation 的 exchange (给 Env 一个 action 交易到一个新 state / observation) 是一个 timestep

episode_over = terminated or truncated # terminated 是环境终止,比如 robot 掉进 lava 这种,truncated 用于在一定 timesteps 之后终止 Env

env.close() # 有时不希望 close Env 就用 reset 重新开始

动作和状态空间

任何 environment 都指定了有效的 actions 和 observations,存在 attributes action_spaceobservation_space 里,有助于理解环境期望的输入输出。

Env.action_spaceEnv.observation_space 都属于 Space 类的 instances,主要有两个函数 Space.contains()Space.sample()

Gym 支持很多种数据结构的 spaces:

  • Box
  • Discrete
  • MultiBinary
  • MultiDiscrete
  • Text
  • Dict
  • Tuple
  • Graph
  • Sequence

修改环境

Wrappers 可以修改 environment 而不改变其代码,大多数用 gymnasium.make() 产生的环境都是默认被 TimeLimit (定义一个 truncated timestep)、OrderEnforcing (reset 之前 step 或 render 会报错)、PassiveEnvChecker (检查 reset 是否返回符合规范的初始状态 obs 和信息 info,检查 step 是否返回格式正确的 (obs, reward, done(terminated, truncated), info),检查 render (可选) 是否是合理的可视化)。

其中 PassiveEnvChecker 这层 wrapper 可以在 make 时通过 , disable_env_checker=True 去掉。

除此之外比如 FlattenObservation,

import gymnasium as gym
from gymnasium.wrappers import FlattenObservation
env = gym.make("CarRacing-v3")
env.observation_space.shape
(96, 96, 3)
wrapped_env = FlattenObservation(env)
wrapped_env.observation_space.shape
(27648,)
wrapped_env
<FlattenObservation<TimeLimit<OrderEnforcing<PassiveEnvChecker<CarRacing<CarRacing-v3>>>>>>
wrapped_env.unwrapped
<gymnasium.envs.box2d.car_racing.CarRacing object at 0x7f04efcb8850>

unwrapped 可以得到最里面的 base environment。

AtariPreprocessing 也是个 warpper,会完成 Noop Reset (返回执行过随机次 no-op 的 initial state, default max 30 no-ops), Frame Skipping (4 by default), Resize to a square image (210x180 to 84x84 by default), Grayscale observation (默认转灰度图), Grayscale new axis (保留 channel 维度为 1,默认不保留) 等。

自定义环境

  • class CustomEnv(gym.Env) 实现 __init__ (包括定义 agent / target location 和 actions)
  • 构建 observations
  • 实现reset, step
  • 包装 wrappers
  • register / make

ALE

http://ale.farama.org/

rom_file = "Breakout.bin"  # WARNING: Possibly unsupported ROM: mismatched MD5.
from ale_py import ALEInterface, roms
ale = ALEInterface()
ale.loadROM(roms.get_rom_path("breakout")) # /Users/v2beach/.pyenv/versions/miniforge3-4.10/lib/python3.9/site-packages/ale_py/roms/breakout.bin
print(ale.getLegalActionSet()) # [<Action.NOOP: 0>, <Action.FIRE: 1>, <Action.UP: 2>, <Action.RIGHT: 3>, <Action.LEFT: 4>, <Action.DOWN: 5>, <Action.UPRIGHT: 6>, <Action.UPLEFT: 7>, <Action.DOWNRIGHT: 8>, <Action.DOWNLEFT: 9>, <Action.UPFIRE: 10>, <Action.RIGHTFIRE: 11>, <Action.LEFTFIRE: 12>, <Action.DOWNFIRE: 13>, <Action.UPRIGHTFIRE: 14>, <Action.UPLEFTFIRE: 15>, <Action.DOWNRIGHTFIRE: 16>, <Action.DOWNLEFTFIRE: 17>]
ale.act(0)
rgb_image = ale.getScreenRGB()
import matplotlib.pyplot as plt
plt.imshow(rgb_image)
plt.show()

可以这样每次执行一个动作都用 ALE Python interface 保存 screenshots。

用 Gym

用原始 API 比较方便,不用学两套 function,

import gymnasium as gym
import ale_py
gym.register_envs(ale_py)
env = gym.make_vec("ALE/Breakout-v5", num_envs=4)

# 或者make_vec_env,跟 register 一样是过时版本,但可能会遇到
make_vec_env("ALE/Breakout-v5", n_envs=4, vec_env_cls=stable_baselines3.common.vec_env.DummyVecEnv)

这个 env 其实就是对上面 ALEInterface() 的 wrapper 封装,用下面的方法可以直接操作 base env ALEInterface:

ale_interface = env.unwrapped.ale
current_state = ale_interface.cloneSystemState()

print(ale_interface.getAvailableModes()) # [0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44]
print(ale_interface.getAvailableDifficulties()) # [0, 1]
print(current_state.getCurrentMode())
print(current_state.getDifficulty())

ale_interface.setMode(ale_interface.getAvailableModes()[3])
ale_interface.setDifficulty(ale_interface.getAvailableDifficulties()[1]) # 但修改了模式和难度游戏内没发现变化

ale_interface = env.unwrapped.ale
current_state = ale_interface.cloneSystemState()
print(current_state.getCurrentMode())
print(current_state.getDifficulty())
# /Users/v2beach/.pyenv/versions/miniforge3-4.10/lib/python3.9/site-packages/ale_py/ 除了 roms/breakout.bin 就是 env.py
"""
render_mode: str => One of { 'human', 'rgb_array' }.
If `human` we'll interactively display the screen and enable game sounds. This will lock emulation to the ROMs specified FPS
If `rgb_array` we'll return the `rgb` key in step metadata with the current environment RGB frame.
"""

可以在训练的时候用 rgb array,推理 render 的时候用 human。

Frame Skipping and Preprocessing

这篇文章写得相当好。

Training an Agent

https://gymnasium.farama.org/introduction/train_agent/

Stable Baseline 3

pip install "stable-baselines3[extra]"

直接看代码,反正 Gym ALE Stable-baseline 的 doc 也是直接贴代码。

pre-trained zoo 里 breakout 模型都训了 10M 个 timesteps,作者是这个,打砖块代码分别存在 1, 2, 3 里面。

import gymnasium as gym
import ale_py # 这是 Env 类
gym.register_envs(ale_py)

# env = gym.make_vec('ALE/Breakout-v5', num_envs=4, frameskip=1, render_mode="human")
# env = gym.wrappers.AtariPreprocessing(env)

from gymnasium.wrappers import AtariPreprocessing
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv

def make_env():
env = gym.make('ALE/Breakout-v5', frameskip=1, render_mode="human")
env = AtariPreprocessing(env)
env = gym.wrappers.FrameStackObservation(env, stack_size=4)
return env
env = DummyVecEnv([make_env for _ in range(4)])

# print(type(env.observation_space), env.observation_space.shape, type(env.action_space), env.action_space)

from stable_baselines3 import DQN
# model = DQN('CnnPolicy', env, verbose=1)
# model.learn(total_timesteps=10_000_000, progress_bar=True)
# model.save("xxx-BreakoutNoFrameskip-v4")
model = DQN.load("dqn-BreakoutNoFrameskip-v4", env=env, custom_objects={
"optimize_memory_usage": False,
"handle_timeout_termination": False
})
# env = model.get_env() # 这里需要保证跟 model 的环境一致,用非 vec 的写法可以生成 model 但其实跟训练环境不一致
obs = env.reset()
done = False

while not done:
actions, _states = model.predict(obs, deterministic=True)
obs, rewards, dones, info = env.step(actions)
env.render("human") # img = env.render(mode='rgb_array') if rgb_array, 而且确保在初始化环境时正确配置了渲染模式, make(,render_mode=)

评论

Your browser is out-of-date!

Update your browser to view this website correctly. Update my browser now

×