pytorchでnoisy networkを実装

元の論文はこちら

[1706.10295] Noisy Networks for Exploration

常にその時点で価値の高い行動を取り続けた場合、最初に価値が高くなった行動が取られ続け、別の行動を取る可能性がなくなってしまう。それを防ぐため元のDQNではε-greedy法と呼ばれる手法を用いている。これは一定確率で価値に依らずランダムで行動を選択することにより、選ばれる行動の可能性を広げようとするものである。

noisy networkはこの部分を改良するもので、ネットワークそのものに学習可能なパラメータと共に外乱を与え、それも含めて学習させていくことでより長期的で広範囲に探索を進めようというもの。

We propose a simple alternative approach, called NoisyNet, where learned perturbations of the network weights are used to drive exploration. The key insight is that a single change to the weight vector can induce a consistent, and potentially very complex, state-dependent change in policy over multiple time steps

noiseを含んだ計算式はこちら

 \displaystyle y = (µ^{w} + σ^{w} ⦿ ε^{w}) x + µ^{b} + σ^{b} ⦿ ε^{b}

通常のLinearの計算y = wx + bのw, bの項の部分をそれぞれnoiseを含んだパラメータに置き換えたもの。

添字としてついているwとbは要素数を表しており、入力数をp、出力数をqとすると、wが付くものはq * p個、bが付くものはq個の要素を持つtensor。μとσは学習させるパラメータ、εは計算のたびに生成する乱数(ガウス分布)、⦿の演算子は要素毎の積を表す。

また、乱数の取り方について2種類言及されているが、DQNに適用するため論文に合わせてFactorised Gaussian noiseの方法で実装した。

We explore two options: Independent Gaussian noise, which uses an independent Gaussian noise entry per weight and Factorised Gaussian noise, which uses an independent noise per each output and another independent noise per each input. The main reason to use factorised Gaussian noise is to reduce the compute time of random number generation in our algorithms. This computational overhead is especially prohibitive in the case of single-thread agents such as DQN and Duelling

乱数生成の計算によるオーバーヘッドをへらすためというのがFactorised Gaussian noiseを使う理由。

実装

noiseを含んだLinear layerの実装はこちら

class FactorizedNoisy(nn.Module):
    def __init__(self, in_features, out_features):
        super(FactorizedNoisy, self).__init__()
        self.in_features = in_features
        self.out_features = out_features

        # 学習パラメータを生成
        self.u_w = nn.Parameter(torch.Tensor(out_features, in_features))
        self.sigma_w  = nn.Parameter(torch.Tensor(out_features, in_features))
        self.u_b = nn.Parameter(torch.Tensor(out_features))
        self.sigma_b = nn.Parameter(torch.Tensor(out_features))
        self.reset_parameters()

    def reset_parameters(self):
        # 初期値設定
        stdv = 1. / math.sqrt(self.u_w.size(1))
        self.u_w.data.uniform_(-stdv, stdv)
        self.u_b.data.uniform_(-stdv, stdv)

        initial_sigma = 0.5 * stdv
        self.sigma_w.data.fill_(initial_sigma)
        self.sigma_b.data.fill_(initial_sigma)

    def forward(self, x):
        # 毎回乱数を生成
        rand_in = self._f(torch.randn(1, self.in_features, device=self.u_w.device))
        rand_out = self._f(torch.randn(self.out_features, 1, device=self.u_w.device))
        epsilon_w = torch.matmul(rand_out, rand_in)
        epsilon_b = rand_out.squeeze()

        w = self.u_w + self.sigma_w * epsilon_w
        b = self.u_b + self.sigma_b * epsilon_b
        return F.linear(x, w, b)

   def _f(self, x):
       return torch.sign(x) * torch.sqrt(torch.abs(x))

初期値の設定は論文の3.2

For factorised noisy networks, each element µi,j was initialised by a sample from an independent uniform distributions  \displaystyle U[−\frac{1}{\sqrt{p}}, +\frac{1}{\sqrt{p}}] and each element  σ_{i,j} was initialised to a constant  \displaystyle \frac{σ_0}{\sqrt{p}}. The hyperparameter  σ_0 is set to 0.5.

これを使ってnoisy networkを作成したのが以下(dueling networkの実装も含む) ただし、ノード数などは完全には合わせてないので注意

class DuelingNetConv2d(nn.Module):
    def __init__(self, num_states, num_actions, is_noisy=False):
        super(DuelingNetConv2d, self).__init__()
        self.num_states = num_states
        self.num_actions = num_actions

        self.conv1 = nn.Conv2d(num_states, 16, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=4, stride=2)
        if is_noisy:
            # 9 * 9 * 32 = 2592
            self.fcV1 = FactorizedNoisy(2592, 256)
            self.fcA1 = FactorizedNoisy(2592, 256)
            self.fcV2 = FactorizedNoisy(256, 1)
            self.fcA2 = FactorizedNoisy(256, num_actions)
        else:
            self.fcV1 = nn.Linear(2592, 256)
            self.fcA1 = nn.Linear(2592, 256)
            self.fcV2 = nn.Linear(256, 1)
            self.fcA2 = nn.Linear(256, num_actions)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view([-1, 2592])
        V = self.fcV2(self.fcV1(x))
        A = self.fcA2(self.fcA1(x))

        averageA = A.mean(1).unsqueeze(1)
        return V.expand(-1, self.num_actions) + (A - averageA.expand(-1, self.num_actions))

元々Linear layerを使っていた部分をすべて置き換えた。noisy networkの論文では3.1にこうかかれていた

We apply the following modifications to both DQN and Dueling: first, ε-greedy is no longer used, but instead the policy greedily optimises the (randomised) action-value function. Secondly, the fully connected layers of the value network are parameterised as a noisy network,

なのでこの通りであればdueling network中のvalue network側のみをnoisyなものに置き換えるのが正しそうだが、今回はrainbowを実装するための一要素として実装していたので、こちらの論文を参考にすべてのLinear layerを置き換えた。

[1710.02298] Rainbow: Combining Improvements in Deep Reinforcement Learning

(The Integrated Agentの項目の最後)

We then replace all linear layers with their noisy equivalent described in Equation (4). Within these noisy linear layers we use factorised Gaussian noise (Fortunato et al. 2017) to reduce the number of independent noise variables

結果

atariのbreakoutを20000 episodes分学習させた結果がこちら 縦軸が100 episodes毎の平均reward(崩したブロックの数)、横軸がepisode数 f:id:y-kamiya:20181013105913p:plain

ちなみに、prioritized experience reply(とimportance sampling)の実装も含んだ状態で学習させた。noisy networkを適用した方が立ち上がりが遅くなっているものの、baselineに比べて後半が右肩上がりの状態になっているように見える。rainbowの関連情報などを見ると、noisy networkは最終的なパフォーマンスを向上させるのに寄与すると言われているので、それに沿った結果にはなっていそう。

ただし注意点として、上記の学習はたった20000 episodes(steps数は1.5M程度)しかやっていないため、論文にかかれているようなスケール(100M stepsとか)からするとほとんど差が出ていない初期の部分に当てはまる(noisy networkの論文のグラフを見るとbaselineとnoisy network適用時の結果に差が出始めているのは10M steps経過くらいで、それまではほとんど差が見えない)

追記

noiseリセットのタイミングが間違っていることに気づいたため修正しました

reset noise appropriately by y-kamiya · Pull Request #1 · y-kamiya/machine-learning-samples · GitHub

元の実装だと、全結合層のforward計算毎にnoiseをリセットしていました。 しかし、論文のP.15のAlgorithm 1: NoisyNet-DQN / NoisyNet-Duelingにかかれているアルゴリズムの流れに従うと、一連に繋がっている各全結合層は同じnoiseを使うようになっていました。そのため、agent側から適切なタイミングでnoiseのリセット処理を呼ぶようにしました。

参考

[1706.10295] Noisy Networks for Exploration

Extending PyTorch — PyTorch master documentation

RL-Adventure/5.noisy dqn.ipynb at master · higgsfield/RL-Adventure · GitHub

Rainbow/model.py at master · Kaixhin/Rainbow · GitHub

速習 強化学習 ―基礎理論とアルゴリズム―

速習 強化学習 ―基礎理論とアルゴリズム―

  • 作者: Csaba Szepesvari,小山田創哲,前田新一,小山雅典,池田春之介,大渡勝己,芝慎太朗,関根嵩之,高山晃一,田中一樹,西村直樹,藤田康博,望月駿一
  • 出版社/メーカー: 共立出版
  • 発売日: 2017/09/21
  • メディア: 単行本
  • この商品を含むブログを見る