rainbowの一要素であるmulti step learningを実装したのでメモ。
rainbowの論文はこちら
https://arxiv.org/pdf/1710.02298.pdf
Multi-step learningの項目で説明されており、参照として載っているのがこちら
http://www.incompleteideas.net/sutton/book/ebook/node73.html
通常1ステップ分の報酬とその次の最適な行動からloss計算に使う推定値を決めるが、それをnステップ分の報酬まで使うようにしたもの。
lossの数式的な違いはこちらにわかりやすく書かれていた
【深層強化学習】Ape-X 実装・解説 - Qiita
rainbowの要素の中ではdistributional RLと並んで影響の大きな要素として挙げられている。
実装
nステップ経過または遷移の終端に到達したら、その間の報酬を割引しつつ和を取る。
今回の実装では、遷移が必要分たまるまでは専用のバッファに置いておき、nステップ分の遷移が溜まったらその分の和を計算してreplay memoryに登録する形にした。
(ちなみにdopamineなどの実装を見ると、replay memoryに遷移を足すのは単純に1ステップずつ行い、学習のためのミニバッチを取得する段階でnステップ分の報酬を計算して返していた)
multi step learningのために変更する部分だけ以下に抜粋
def _get_multi_step_transition(self, transition):
# 計算用のバッファに遷移を登録
self.multi_step_transitions.append(transition)
if len(self.multi_step_transitions) < self.config.num_multi_step_reward:
return None
next_state = transition.next_state
nstep_reward = 0
for i in range(self.config.num_multi_step_reward):
r = self.multi_step_transitions[i].reward
nstep_reward += r * self.config.gamma ** i
# 終端の場合、それ以降の遷移は次のepisodeのものなので計算しない
if self.multi_step_transitions[i].next_state is None:
next_state = None
break
# 最も古い遷移を捨てる
state, action, _, _ = self.multi_step_transitions.pop(0)
# 時刻tでのstateとaction、t+nでのstate、その間での報酬の和をreplay memoryに登録
return Transition(state, action, next_state, nstep_reward)
loss計算のための推定値を算出する部分
nに応じて割引率を累乗
reward_batchはnステップ分の報酬、next_state_valuesが時刻t+nでのstateを表す
...
gamma = self.config.gamma ** self.config.num_multi_step_reward
expected_values = reward_batch + gamma * next_state_values
...
agentの実装全体はこちら
machine-learning-samples/agent.py at b8b247b43f1379d67abe1aff15f7c3b74f632b6c · y-kamiya/machine-learning-samples · GitHub
結果
cartpoleで動かしてみた結果がこちら
baseline: 242 episodes, 20915.5 steps
baseline + multi step learning(n=3): 211 episodes, 19829.5 steps
1 episodeあたりの平均ステップ数が150を超えるまでにかかったepisode数、steps数を計測した。それを20回分回した際の中央値が上記の値。
multi step learningを加えた方がわずかに結果が向上している。
それぞれの実行時のコマンドもメモしておく
baseline
dqn/cartpole_rainbow.py --epochs 20 --episodes 5000 --steps_to_update_target 50 --nosave --replay_interval 4 --steps_learning_start 1000 --replay_memory_capacity 10000 --epsilon_end_step 100
baseline + multi step learning(n=3)
dqn/cartpole_rainbow.py --epochs 20 --episodes 5000 --steps_to_update_target 50 --nosave --replay_interval 4 --steps_learning_start 1000 --replay_memory_capacity 10000 --epsilon_end_step 100 --num_multi_step_reward 3