dueling networkを調べつつpytorchで実装

強化学習の性能を改善する方法の一つであるdueling networkについて調べたのでメモ。まず参考にしたのはこちらのqiitaの記事

【深層強化学習】Dueling Network 実装・解説

詳細な部分についてはこちらの論文を参考にした

[1511.06581] Dueling Network Architectures for Deep Reinforcement Learning

大枠としてやっていることは

  • ある状態においてagentのactionが結果に影響を与える場合と与えない場合を区別して計算する

ということ。

例えばレースゲームにおいて、真横に他の車がいる場合にそちらハンドルを切ればぶつかってクラッシュしてしまう。つまり自分の行動次第で得られる報酬は大きく変化する。一方、まわりに車がいない状態であればハンドルを切るという行動は結果にさほど影響を与えないといえる。

通常のDQNでは一つの全結合層を挟んでaction数分の出力を得るのみだが、dueling networkではそれに加えてもう一つ別の全結合層を用意して1つの出力も得る。

前者の出力をadvantage function A(s,a)、後者の出力をvalue function V(s)として、それらを足し合わせた結果として最終的な出力であるQ function Q(s,a)を得る。

\displaystyle Q(s,a) = V(s) + A(s,a) - \frac{1}{|A|}\sum_{a} A(s,a)

論文ではもっと厳密に書いてあるが、簡単のため直接関わらないθなどのパラメータは省略した。最後の項はすべてのactionに対してAの和をとってその要素数で割っているので、要するにAの平均値。

*追記
論文で書いてある実装と異なっていることに気づいたためネットワーク構成を修正しました。valueとadvantageの両方にそれぞれhidden layerが存在する形へ。

The value and advantage streams both have a fullyconnected layer with 512 units. The final hidden layers of the value and advantage streams are both fully-connected with the value stream having one output and the advantage as many outputs as there are valid actions2

これを実装したのが以下

from torch import nn
import torch.nn.functional as F

class DuelingNetFC(nn.Module):
    def __init__(self, num_states, num_actions):
        super(DuelingNetFC, self).__init__()
        self.num_states = num_states
        self.num_actions = num_actions

        self.fc1 = nn.Linear(self.num_states, 32)
        self.fcV1 = nn.Linear(32, 32)
        self.fcA1 = nn.Linear(32, 32)
        self.fcV2 = nn.Linear(32, 1)
        self.fcA2 = nn.Linear(32, self.num_actions)

    def forward(self, x):
        x = F.relu(self.fc1(x))

        V = self.fcV2(self.fcV1(x))
        A = self.fcA2(self.fcA1(x))

        averageA = A.mean(1).unsqueeze(1)
        return V.expand(-1, self.num_actions) + (A - averageA.expand(-1, self.num_actions))
        ```
OpenAI gymのCartPole-v0で動かしてみたため、convolutionは使わずすべて全結合層のネットワークになっている。通常のDQNでは上記コードのself.fcAの出力をそのままreturnすることになるが、それに対してVの出力の結果を足し合わせている。

ちなみにpytorchではinput, output共にミニバッチとして処理されるため、expandの第一引数はバッチ数を変化させないことを示す-1。バッチの中身一つで見ればVはスカラー値、Aはaction数分のtensorなので、Vの値をexpandでaction数分のtensorにしている。

これをcartpoleで動かしてみたところ、通常のDQNだと10回連続で成功するまでにかかるepisodeは200程度だったが、dueling networkを入れたところ130程度になった。
動かしたコードの全体はこちら

[https://github.com/y-kamiya/machine-learning-samples/blob/c261e42bc18cb09cb8575396045c9abd49faad9e/python3/reinforcement/cartpole_rainbow.py:title]

## なぜこの計算式でよいのか?
qiitaの記事などで調べたが、以下の式でQ(s,a)を計算する理由がよくわからなかったので論文を読んでみた。

 [tex:\displaystyle Q(s,a) = V(s) + A(s,a) - \frac{1}{|A|}\sum_{a} A(s,a)]

以下はこの答えに当てはまる部分を論文から引用したもの。直訳ではない部分ありますがあしからず。

>Using the definition of advantage, we might be tempted to
construct the aggregating module as follows:

>Q(s, a; θ, α, β) = V (s; θ, β) + A(s, a; θ, α), (7)

単純にこう算出したくなるが、それは以下の理由でよくない
>However, we need to keep in mind that Q(s, a; θ, α, β)
is only a parameterized estimate of the true Q-function.
Moreover, it would be wrong to conclude that V (s; θ, β)
is a good estimator of the state-value function, or likewise
that A(s, a; θ, α) provides a reasonable estimate of the advantage
function.

>Equation (7) is unidentifiable in the sense that given Q
we cannot recover V and A uniquely. To see this, add a
constant to V (s; θ, β) and subtract the same constant from
A(s, a; θ, α). This constant cancels out resulting in the
same Q value. This lack of identifiability is mirrored by
poor practical performance when this equation is used directly.

まず忘れてはならないのはQ, V, Aはどれもある条件における推定値であって真の値ではないということ。
また、Qの値からV, Aの値を一意に決めることはできない。(Vの値からある定数値を引き、かつAの値にその同じ定数値を足せば、同じQの値となる)そのため、これをそのまま使うと計算の効率が悪い。なので

>To address this issue of identifiability, we can force the advantage
function estimator to have zero advantage at the
chosen action. That is, we let the last module of the network
implement the forward mapping

> [tex:\displaystyle Q(s, a; θ, α, β) = V (s; θ, β) + A(s, a; θ, α) − max_{a'∈|A|} A(s, a'; θ, α)]. (8)

これを解決するために特定のactionをとった際にadvantage Aが0となるようにした。

>Now, for

> [tex:\displaystyle a* = arg max_{a'∈A} Q(s, a'; θ, α, β)] 

> =  [tex:\displaystyle  arg max_{a'∈A} A(s, a'; θ, α)]

>, we obtain Q(s, a∗; θ, α, β) =V (s; θ, β).

>Hence, the stream V (s; θ, β) provides an estimate
of the value function, while the other stream produces
an estimate of the advantage function.

特定のactionとはそのstateで取りうるactionの中でQを最大にするもので、このときQ(s,a) = V(s)となる。したがってVを算出した全結合層の出力はvalue functionの推定値と考えることができ、他方をadvantage functionの推定値と考えることができる。


不勉強なため上記の”したがって”が直感的には理解できなかったが、Qの最大値を引いた値を出力として扱うことによって、lossをネットワークにフィードバックする際の更新が効率よくできるようになる、ということだと思われる。

>An alternative module replaces the max operator with an
average:

> [tex:\displaystyle Q(s, a; θ, α, β) = V (s; θ, β) +A(s, a; θ, α) −
\frac{1}{|A|} \sum_{a'} A(s, a'; θ, α)]
. (9)

>On the one hand this loses the original semantics of V and
A because they are now off-target by a constant, but on
the other hand it increases the stability of the optimization:
with (9) the advantages only need to change as fast as the
mean, instead of having to compensate any change to the
optimal action’s advantage in (8).

また、Qの最大値を使うかわりにQの平均値を使うこともできる。
この場合、VとAの元々の意味合いは失われてしまうが、その分計算の安定性が向上する。lossをネットワークの最適化に適用する際、Qの最大値を使う場合に比べて平均値を使う場合の方が変化の割合が安定する。



以上のことより、ネットワークの重み更新を効率よくするために論文のような形の式で出力を扱っていることがわかった。