pytorchでprioritized experience replyを実装

元の論文はこちら

[1511.05952] Prioritized Experience Replay

DQNで学習を進めるための重要なテクニックとしてexperience replyというものがあり、これはメモリにためておいたstateやactionの記録をmini batchとしてランダムに取り出して学習させるというもの。

prioritized experience reply (PER)はこれを改善したもので、mini batchとして取り出すデータをランダムではなくTD誤差の大きさに基づく確率によって決める。TD誤差が大きい=理想的な出力と実際の差が大きい、ということなので優先的に学習に使うべきという考え方。

実装で特に参考にさせてもらったのはこちらのブログ

Let’s make a DQN: Double Learning and Prioritized Experience Replay | ヤロミル

保存した経験を効率的に取り出すための二分木(sum tree)の説明や実装もあり、とてもわかりやすかった。sum treeの実装についてはそのまま使わせて頂きました。

AI-blog/SumTree.py at 5aa9f0ba91f12ab4e24043134c0b33900a2f6236 · jaara/AI-blog · GitHub

実装の全体はこちら

environment

https://github.com/y-kamiya/machine-learning-samples/blob/7b6792ce37cc69051e9053afeddc6d485ad34e79/python3/reinforcement/dqn/cartpole_rainbow.py

agent

https://github.com/y-kamiya/machine-learning-samples/blob/7b6792ce37cc69051e9053afeddc6d485ad34e79/python3/reinforcement/dqn/agent.py

Prioritized Experience Reply

下記クラスを元々experience replyに使っていたRandomMemoryの代わりに使えばOK

なお、priorityの付け方としてProportional prioritizationの方法で実装した。

class PERMemory:
    epsilon = 0.0001
    alpha = 0.6
    size = 0

    # SumTreeについては参考にしたブログから拝借して必要なものをつけたし
    def __init__(self, capacity):
        self.tree = SumTree(capacity)

    # Proportional prioritizationによるpriorityの計算
    def _getPriority(self, td_error):
        return (td_error + self.epsilon) ** self.alpha

    # 新しい経験を入れる際は、必ず一度はreplyされるようにその時点で最大のpriorityで
    # reply開始前の場合は論文に従いpriority=1とした
    def push(self, transition):
        self.size += 1

        priority = self.tree.max()
        if priority <= 0:
            priority = 1

        self.tree.add(priority, transition)

    # 0 ~ priorityの合計値の間でbatch sizeの分だけ乱数を生成し、
    # それに合致するデータを取得する
    def sample(self, size):
        list = []
        indexes = []
        for rand in np.random.uniform(0, self.tree.total(), size):
            (idx, _, data) = self.tree.get(rand)
            list.append(data)
            indexes.append(idx)

        return (indexes, list)

    # 再生した経験のpriorityを更新
    def update(self, idx, td_error):
        priority = self._getPriority(td_error)
        self.tree.update(idx, priority)

    def __len__(self):
        return self.size

reply部分にpriorityの更新処理を追加

def reply(self):
    if len(self.memory) < self.config.steps_learning_start:
        return

    self.model.train()

    # batch size分だけデータを取り出す
    indexes, transitions = self.memory.sample(BATCH_SIZE)
    # lossの計算に必要な形に(DDQNとかの処理はこの中でやっている)
    values, expected_values = self._get_state_action_values(transitions)

    loss = F.smooth_l1_loss(values, expected_values)
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()

    # replyしたデータのpriorityを更新
    if (indexes != None):
        for i, value in enumerate(values):
            td_error = abs(expected_values[i].item() - value.item())
            self.memory.update(indexes[i], td_error)

SumTreeに追加した処理はpriorityの最大値を取得する部分 リーフに該当するindexの中で最大値を取得する

# class SumTree
# self.index_leaf_start = capacity - 1
def max(self):
    return self.tree[self.index_leaf_start:].max()

bufferが大きくなるほど時間がかかるはずなので、問題になるかもと思って時間を測ったみた(bufferのcapacityが105の場合)

$ python -m timeit -s "import numpy as np; values = [float(x) for x in range(200000)]; npvalues = np.fromiter(values, np.float)" "npvalues[100000:].max()"
10000 loops, best of 3: 36 usec per loop

なので問題なし

Importance Sampling

priorityに基づく確率によりreplayに使うデータをサンプリングすることにより、priorityが大きなものほど学習に使われる頻度が高くなる。そのため、priorityの大きなデータほど学習への寄与が大きくなる(=bias)。

論文より引用

The estimation of the expected value with stochastic updates relies on those updates corresponding to the same distribution as its expectation. Prioritized replay introduces bias because it changes this distribution in an uncontrolled fashion

このbaisを是正するために、サンプリングされる確率の大きいデータほど、lossへの寄与が小さくなるように補正をかける。これにより全体としての期待値を、一様にサンプリングした場合に近づけることができる。

# class PERMemory
def sample(self, size, episode):
    list = []
    indexes = []
    weights = np.empty(size, dtype='float32')
    total = self.tree.total()
    # 初期の設定値から1に向けてannealing
    # episode単位で更新して最後のepisode時に1となるように
    beta = self.BETA + (1 - self.BETA) * episode / self.config.num_episodes

    for i, rand in enumerate(np.random.uniform(0, total, size)):
        (idx, priority, data) = self.tree.get(rand)
        list.append(data)
        indexes.append(idx)
        # 各データのωを計算
        weights[i] = (self.capacity * priority / total) ** (-beta)

    #  必ず1以下の値となるようωの最大値で割る
    return (indexes, list, weights / weights.max())

上記の呼び出し側

def loss(self, input, target, weights):
    if self.config.use_IS:
        # 論文だとω * δをすべてのデータで足し合わせているがここでは平均値を取る
        loss = torch.abs(target - input) * torch.from_numpy(weights).to(device=self.config.device)
        return loss.mean()

    return F.smooth_l1_loss(input, target)

def replay(self, episode):
    if len(self.memory) < self.config.steps_learning_start:
        return

    self.model.train()

    indexes, transitions, weights = self.memory.sample(BATCH_SIZE, episode)
    values, expected_values = self._get_state_action_values(transitions)

    # 各データに重みをかけるため独自の処理に
    loss = self.loss(values, expected_values, weights)
    ...

lossの計算のところは論文(のAlgorithm 1)だと以下のような式になっている

Accumulate weight-change ∆ ← ∆ + wj · δj · ∇θQ(Sj−1, Aj−1)

ただ、この最後の項については論文内で明確には説明されておらずよく理解できていない。 なので単純にω * δの和をとって試してみたのだが、学習がまったく進まなくなってしまった。そのため、ISを行わない場合に使っていたhubor lossと同じ程度のlossとなるようにTD誤差の絶対値を使い、かつ全体の和ではなく平均値を取るようにしたところ、学習が進むようになった。

結果

とりあえず簡単に計算が終わるcartpoleで試してみた 直近100episodeの経過steps数が150を超えた段階で終了とみなし、そこに到達するまでにかかったepisode数を比較した。同じ操作を100回行った平均値が以下

  • baseline:140
  • baseline + PER:140
  • baseline + PER + IS:120

ちなみにbaselineにはDDQNとDueling networkが含まれている。

PERだけを適用しても結果が変わらなかったが、ISも適用することで学習が早まった。ISも適用した場合については明らかに少ないepisode数で終わっているのでpriorityによる重み付けが効いているのだと思われる。PERのみ適用した場合は変化がなかったが、これはデータの偏りを補正しないことによる悪影響とpriorityをつけたことによる学習の立ち上がりの向上が相殺してそうなったのだろうか。(もしくは実装がおかしい?)

cartpoleでは学習する対象が小さいためこれらの効果見づらくなっている可能性はあるためatariでも試してみることにする。

atariで動かしてみた結果がこちら f:id:y-kamiya:20181008110450p:plain

縦軸は100 episodes毎に得たreward(壊したブロック数)の平均。 実行時間をへらすため10000episodesしかやってないが、PERを入れた方が明らかに学習の立ち上がりが早くなっている。 また、論文では108ステップとか学習させているので、10000 episodesくらいならannealingしないのと同じといえるかと思い、no anealingというデータもとってみた。

注意点として、episode数によって学習の終わりを決めていたので、終了までに行った総ステップ数(=学習時間)は一致していないこと。最も学習の進んだPERのみのもので900000、baselineのものは500000程度だった。なので論文に出ているstep数による取得rewardのグラフと比べる意味はないので参考までに。

一定ステップ数を経過したら評価用のepisodeを設けるという形でのデータもそのうちとってみる。