pytorchでmulti step learningを実装

rainbowの一要素であるmulti step learningを実装したのでメモ。 rainbowの論文はこちら

https://arxiv.org/pdf/1710.02298.pdf

Multi-step learningの項目で説明されており、参照として載っているのがこちら

http://www.incompleteideas.net/sutton/book/ebook/node73.html

通常1ステップ分の報酬とその次の最適な行動からloss計算に使う推定値を決めるが、それをnステップ分の報酬まで使うようにしたもの。 lossの数式的な違いはこちらにわかりやすく書かれていた

【深層強化学習】Ape-X 実装・解説 - Qiita

rainbowの要素の中ではdistributional RLと並んで影響の大きな要素として挙げられている。

実装

nステップ経過または遷移の終端に到達したら、その間の報酬を割引しつつ和を取る。 今回の実装では、遷移が必要分たまるまでは専用のバッファに置いておき、nステップ分の遷移が溜まったらその分の和を計算してreplay memoryに登録する形にした。 (ちなみにdopamineなどの実装を見ると、replay memoryに遷移を足すのは単純に1ステップずつ行い、学習のためのミニバッチを取得する段階でnステップ分の報酬を計算して返していた)

multi step learningのために変更する部分だけ以下に抜粋

def _get_multi_step_transition(self, transition):
    # 計算用のバッファに遷移を登録
    self.multi_step_transitions.append(transition)
    if len(self.multi_step_transitions) < self.config.num_multi_step_reward:
        return None

    next_state = transition.next_state
    nstep_reward = 0
    for i in range(self.config.num_multi_step_reward):
        r = self.multi_step_transitions[i].reward
        nstep_reward += r * self.config.gamma ** i

        # 終端の場合、それ以降の遷移は次のepisodeのものなので計算しない
        if self.multi_step_transitions[i].next_state is None:
            next_state = None
            break

    # 最も古い遷移を捨てる
    state, action, _, _ = self.multi_step_transitions.pop(0)
   
    # 時刻tでのstateとaction、t+nでのstate、その間での報酬の和をreplay memoryに登録
    return Transition(state, action, next_state, nstep_reward)

loss計算のための推定値を算出する部分 nに応じて割引率を累乗 reward_batchはnステップ分の報酬、next_state_valuesが時刻t+nでのstateを表す

...
gamma = self.config.gamma ** self.config.num_multi_step_reward
expected_values = reward_batch + gamma * next_state_values
...

agentの実装全体はこちら

machine-learning-samples/agent.py at b8b247b43f1379d67abe1aff15f7c3b74f632b6c · y-kamiya/machine-learning-samples · GitHub

結果

cartpoleで動かしてみた結果がこちら baseline: 242 episodes, 20915.5 steps baseline + multi step learning(n=3): 211 episodes, 19829.5 steps

1 episodeあたりの平均ステップ数が150を超えるまでにかかったepisode数、steps数を計測した。それを20回分回した際の中央値が上記の値。 multi step learningを加えた方がわずかに結果が向上している。

それぞれの実行時のコマンドもメモしておく

baseline
dqn/cartpole_rainbow.py --epochs 20 --episodes 5000 --steps_to_update_target 50 --nosave --replay_interval 4 --steps_learning_start 1000 --replay_memory_capacity 10000 --epsilon_end_step 100

baseline + multi step learning(n=3)
dqn/cartpole_rainbow.py --epochs 20 --episodes 5000 --steps_to_update_target 50 --nosave --replay_interval 4 --steps_learning_start 1000 --replay_memory_capacity 10000 --epsilon_end_step 100 --num_multi_step_reward 3