pytorchでmulti step learningを実装
rainbowの一要素であるmulti step learningを実装したのでメモ。 rainbowの論文はこちら
https://arxiv.org/pdf/1710.02298.pdf
Multi-step learningの項目で説明されており、参照として載っているのがこちら
http://www.incompleteideas.net/sutton/book/ebook/node73.html
通常1ステップ分の報酬とその次の最適な行動からloss計算に使う推定値を決めるが、それをnステップ分の報酬まで使うようにしたもの。 lossの数式的な違いはこちらにわかりやすく書かれていた
rainbowの要素の中ではdistributional RLと並んで影響の大きな要素として挙げられている。
実装
nステップ経過または遷移の終端に到達したら、その間の報酬を割引しつつ和を取る。 今回の実装では、遷移が必要分たまるまでは専用のバッファに置いておき、nステップ分の遷移が溜まったらその分の和を計算してreplay memoryに登録する形にした。 (ちなみにdopamineなどの実装を見ると、replay memoryに遷移を足すのは単純に1ステップずつ行い、学習のためのミニバッチを取得する段階でnステップ分の報酬を計算して返していた)
multi step learningのために変更する部分だけ以下に抜粋
def _get_multi_step_transition(self, transition): # 計算用のバッファに遷移を登録 self.multi_step_transitions.append(transition) if len(self.multi_step_transitions) < self.config.num_multi_step_reward: return None next_state = transition.next_state nstep_reward = 0 for i in range(self.config.num_multi_step_reward): r = self.multi_step_transitions[i].reward nstep_reward += r * self.config.gamma ** i # 終端の場合、それ以降の遷移は次のepisodeのものなので計算しない if self.multi_step_transitions[i].next_state is None: next_state = None break # 最も古い遷移を捨てる state, action, _, _ = self.multi_step_transitions.pop(0) # 時刻tでのstateとaction、t+nでのstate、その間での報酬の和をreplay memoryに登録 return Transition(state, action, next_state, nstep_reward)
loss計算のための推定値を算出する部分 nに応じて割引率を累乗 reward_batchはnステップ分の報酬、next_state_valuesが時刻t+nでのstateを表す
... gamma = self.config.gamma ** self.config.num_multi_step_reward expected_values = reward_batch + gamma * next_state_values ...
agentの実装全体はこちら
結果
cartpoleで動かしてみた結果がこちら baseline: 242 episodes, 20915.5 steps baseline + multi step learning(n=3): 211 episodes, 19829.5 steps
1 episodeあたりの平均ステップ数が150を超えるまでにかかったepisode数、steps数を計測した。それを20回分回した際の中央値が上記の値。 multi step learningを加えた方がわずかに結果が向上している。
それぞれの実行時のコマンドもメモしておく
baseline dqn/cartpole_rainbow.py --epochs 20 --episodes 5000 --steps_to_update_target 50 --nosave --replay_interval 4 --steps_learning_start 1000 --replay_memory_capacity 10000 --epsilon_end_step 100 baseline + multi step learning(n=3) dqn/cartpole_rainbow.py --epochs 20 --episodes 5000 --steps_to_update_target 50 --nosave --replay_interval 4 --steps_learning_start 1000 --replay_memory_capacity 10000 --epsilon_end_step 100 --num_multi_step_reward 3