pytorchでcategorical(distributional) dqnを実装

categorical dqnの論文はこちら A Distributional Perspective on Reinforcement Learning

https://arxiv.org/pdf/1707.06887.pdf

元のdqnでは報酬を一つの値として扱っているが、分布として扱うことによって学習のパフォーマンスが向上したというもの。

また、今回はrainbowの要素の一つとして実装しているのでこちらの論文も参考にしている

https://arxiv.org/pdf/1710.02298.pdf

実装

実装の全体はこちら https://github.com/y-kamiya/machine-learning-samples/blob/1e07bf023eb624b5f3b4fd9ef0fbcfa22ff98096/python3/reinforcement/dqn/atari_rainbow.py https://github.com/y-kamiya/machine-learning-samples/blob/1e07bf023eb624b5f3b4fd9ef0fbcfa22ff98096/python3/reinforcement/dqn/agent.py

categorical dqnに特に関係する部分は以下にまとめた。

また、実装やデバッグの際にはこちらのコードに大変お世話になったので先に載せておく https://github.com/Kaixhin/Rainbow/blob/41b781f4ad6e8219443d23eded1a62bab3afd8c9/agent.py

model

rainbowのアルゴリズム実装の一つとしてやっているため、既にdueling networkの実装は入っている状態。 categorical dqnのために変わった部分は以下2点

  • 分布を表現するために各actionをatoms数の要素を持つリストとする
  • 確率分布を表すために出力はsoftmaxを通したものとする
    • log_softmaxはloss計算のためのもの
    def forward(self, x, apply_softmax=ApplySoftmax.NONE):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view([-1, 2592])
        V = self.fcV2(F.relu(self.fcV1(x)))
        A = self.fcA2(F.relu(self.fcA1(x)))

        v = V.view(-1, 1, self.num_atoms)
        a = A.view(-1, self.num_actions, self.num_atoms)

        averageA = a.mean(1, keepdim=True)
        # 単にv + a - averageAだけだと、cuda環境ではエラーになったためexpandして次元をあわせた
        output = v.expand(-1, self.num_actions, self.num_atoms) + (a - averageA.expand(-1, self.num_actions, self.num_atoms))

        if apply_softmax == ApplySoftmax.NORMAL:
            return F.softmax(output, dim=2)

        if apply_softmax == ApplySoftmax.LOG:
            return F.log_softmax(output, dim=2)

        # num_atoms == 1 in this case
        return output.squeeze()

lossの適用部分

   def replay(self, episode):
       ...

        if self.config.use_categorical:
            # categorical dqnのメイン部分
            losses = self.loss_categorical(transitions)
            # PERのpriorityとしてKL lossをそのまま使う
            self._update_memory(indexes, losses)
            loss = (losses * torch.from_numpy(weights)).mean() if self.config.use_IS else losses.mean()
        else:
            ...

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

rainbowの論文の記述(p.3 The Integrated Agent)に従って、prioritized experience replay(PER)のpriorityとしてKL lossをそのまま使う

However, in our experiments
all distributional Rainbow variants prioritize transitions by
the KL loss,

Q valueの計算

    def _get_Q(self, model, model_input):
        model.reset_noise()

        if not self.config.use_categorical:
            return model(model_input)

        model_output = model(model_input, ApplySoftmax.NORMAL)
        # 確率分布model_outputとそれに対応するatomsの値をかけ合わせてQ値を算出
        return torch.sum(model_output * self.support, dim=2)

categorical dqnのメイン部分

論文のAlgorithm 1として書かれているものと同じように実装する。 以下は論文からスクショで引用

f:id:y-kamiya:20181225185830p:plain

    def loss_categorical(self, transitions):
        ...

        with torch.no_grad():
            # terminal state以外のstateだけ抽出
            non_final_next_state = torch.cat(next_states).to(torch.float32)

            # 最も価値の高いactionを抽出
            best_actions = self._get_Q(self.model, non_final_next_state).argmax(dim=1)

            self.target_model.reset_noise()
            p_next = self.target_model(non_final_next_state, ApplySoftmax.NORMAL)

            # (1) terminal state用に確率分布としてすべてのatomに同じ値を与えておく
            p_next_best = torch.zeros(0).to(self.config.device, dtype=torch.float32).new_full((batch_size, num_atoms), 1.0 / num_atoms)
            # terminal state以外はDDQNで計算したもので上書き
            p_next_best[non_final_mask] = p_next[range(len(non_final_next_state)), best_actions]

            # (2) terminal stateの場合にγ = 0.0とする
            gamma = torch.zeros(batch_size, num_atoms).to(self.config.device)
            gamma[non_final_mask] = GAMMA

            # 報酬を分布に直す
            Tz = (reward_batch.unsqueeze(1) + gamma * self.support.unsqueeze(0)).clamp(self.Vmin, self.Vmax)
            b = (Tz - self.Vmin) / self.delta_z
            l = b.floor().long()
            u = b.ceil().long()

            # (3) bの値がちょうど整数値だった場合にmの要素値が0となってしまうことを回避
            l[(l == u) * (0 < l)] -= 1
            # ↑の処理によってlの値は既に変更済みなため、↓の処理が同時に行われてしまうことはない
            u[(l == u) * (u < num_atoms - 1)] += 1

            m = torch.zeros(batch_size, num_atoms).to(self.config.device, dtype=torch.float32)
            # (4) ミニバッチの各要素毎に和を持っておくため、offsetを計算した上でmを一次元のリストにして扱う
            offset = torch.linspace(0, ((batch_size-1) * num_atoms), batch_size).unsqueeze(1).expand(batch_size, num_atoms).to(l)
            m.view(-1).index_add_(0, (l + offset).view(-1), (p_next_best * (u.float() - b)).view(-1))
            m.view(-1).index_add_(0, (u + offset).view(-1), (p_next_best * (b - l.float())).view(-1))

        self.model.reset_noise()
        log_p = self.model(state_batch, ApplySoftmax.LOG)
        log_p_a = log_p[range(batch_size), action_batch.squeeze()]

        # ミニバッチの要素毎にcross entropyを算出
        return -torch.sum(m * log_p_a, dim=1)

アルゴリズムの表にかかれているものとなるべく変数名を合わせて書いたのでなんとなく対応していることがわかるはず。

(1)について

terminal stateで得たrewardを正しくlossに反映させるために一様な確率分布を与えておく。 最初ここをtorch.zerosで生成していて学習がなかなか進まずハマった。ここが0だと失敗or成功した場合のrewardがすべて0として計算されることになるため、間違った方向に学習が進んでしまう。実際ここが0の場合は一時的に学習が進んでもその後未学習と同じ状態に戻るという挙動だった。

ちなみに上記のように一様な分布を与えておくことにより、mの値が以下のような形になる(atoms=11の場合で失敗時)

[ 0.0000,  0.0000,  0.0000,  0.0000,  0.5000,  0.5000,  0.0000, 0.0000,  0.0000,  0.0000,  0.0000]

4, 5番目のatomに0.5が入っているのは、失敗時のreward=-1.0, Vmin=-10, Vmax=10で学習させているため。 これによってterminal stateでQ=-1.0を表す分布が学習され、それがシミュレーション開始のstateに向けて伝播していく。

(2)について

categorical dqnの論文の以下に従っている(p.16 C. Algorithmic Details)

 Transitions to a terminal state are handled
with γt = 0.
(3)について

bの値とu, lの値の差分を隣接したatomに配分していくことで分布を作り出しているため、bとそのfloor, ceilを取った値が等しい場合(つまりちょうど整数のとき)に配分される値が0になって消えてしまう。これを防ぐために等しい場合はlの値を-1している。ただし、等しくなったindexが0番目のものだった場合にlの値がおかしくなってしまうため、その場合はuを+1する。

(4)について

真の値としてcross entropyの計算に使うmを算出する部分。 アルゴリズム表ではfor loopによって書かれているところだが、ここでは行列計算として処理しているのでpytorchのindex_addを使っている。 offsetは少しわかりづらいが、index_addで各ミニバッチの要素を区別して正しい位置に和をとっていくため必要。

index_addのドキュメント torch.Tensor — PyTorch master documentation

結果

atariのbreakoutで10000 episodes実行して結果を比較してみた f:id:y-kamiya:20181225224511p:plain

どちらにもDDQNとdueling networkの実装が含まれている。 atoms 51は名前の通りatom数を51にして実行したもの。 categorical dqnの方が立ち上がりも早く、最終結果も高くなった。

ちなみにatoms51を実行した際のコマンドはこちら

python dqn/atari-rainbow.py   --learning_rate 0.00025  --episodes 10000 --steps_to_update_target 10000 --steps_learning_start 20000 --replay_memory_capacity 350000  --atoms 51

google datalabからgpuインスタンス上で実行し、5時間程度の学習で1.3M steps程度

参考

Rainbow/agent.py at master · Kaixhin/Rainbow · GitHub

https://github.com/frankibem/categorical_dqn/blob/master/categorical/agent.py

正規分布間のKLダイバージェンス - Qiita

Softmax Cross Entropyを計算する - Qiita

https://web.stanford.edu/class/psych209/Readings/MnihEtAlHassibis15NatureControlDeepRL.pdf

DQNからRainbowまで 〜深層強化学習の最新動向〜

分散型DQNの論文を読む - mabonki0725の日記