transformerを理解するために実装

自分の勉強のためにtransformerを実装してみたので引っかかったところなどメモ

実装の参考にしたのは主にこの3つ

各所わかりやすいものを参考にして実装した。ちなみにtransfomerの公式実装はあまり参考にしていないため、厳密にはtransformerと異なっているかもしれないがあしからず。

XLMを参考にしたのは、そもそもの始まりとして去年の夏頃にXLMを見て、その元となるtransformer自体を理解しようと始めたのがきっかけだったため (途中まで実装したものの、その後違うことを始めてしまって放っておいたためこんな時期になっただけで、そんなに大きな実装をしたわけではないですw)

コード全体はこちら

https://github.com/y-kamiya/transformer/tree/24d1d98eda4d69516b2255b3502a436beadb78a7

実際に翻訳学習用のデータを使った学習もやる予定だが、今回はテスト用のcopy taskが動くところまで
(copy task = 入力と同じもの出力するように学習)

モデルやコード全体の解説は既にいろんなところに上がっているので、自分で実装していて引っかかったところだけ書いておく。

maskについて

maskの役割としては以下の2つ

  1. paddingを計算から外すためのmask
  2. 未来の情報が見えないようにするためのmask

2.はdecoderのself-attentionで使われるもので、decodeの際に自分より後に来る単語の情報を使わないようにする。

1.はencoder, decoderどちらでも必要だが、実装上以下の2つが必要になる

  • encoder, decoderそれぞれのinputに対するmask
  • decoderのsource attentionに使われる翻訳元言語のinputに対するmask

source attentionは翻訳先言語の文とencode済みの翻訳元言語の文の間の関連度を計算して適切な文字を引いてくる処理なので、翻訳元でpaddingとなっているところは無視する必要がある。

私は最初に実装した際にこの2つを区別して考えておらず同じmaskを与えていたためうまくいかなかった。

decode時のinputの取扱い

decoderに渡すinputは翻訳先言語の文(=教師データ)であり、後続の文字を隠すmaskを使うことにより、i番目までの単語列からi+1番目の単語を予測して出力していることになる。

以下のように処理している

def __predict(self, x, y, causal):
    enc_output = self.encoder(x)
    # 最後の単語を除いた教師データを渡す
    dec_output = self.decoder(y[:, :-1], enc_output, x == PAD_ID, causal)

    # decoder.predictはembedding表現をvocabulary idにマッピングしているだけ
    return self.decoder.predict(dec_output)

def step(self, x, y):
    self.encoder.train()
    self.decoder.train()

    scores = self.__predict(x, y, True)
    nwords = (y[:, 1:] != PAD_ID).sum().item()
    # 出てきたscoreを最初の単語を除いた教師データと比較してlossを計算
    loss = self.criterion(scores, y[:, 1:], nwords)

最初に実装した際はこのあたりを理解しておらず、すべてのwordを渡してdecodeした結果に対し、教師データをそのまま突き合わせてlossを計算していた。 そのためBOSの次のwordはBOSとなるよう学習されてしまい、常にBOSばかりが生成されるという結果になってしまった。

翻訳文のgenerate

以下のような形でgenerateしていきEOSが出力された時点で終了

  • BOSのみをinputとしてdecodeして2 word目を取得
  • BOSの後ろにそのwordをつなげてinputとしdecodeして3 word目を取得
  • ...

処理はこんな感じ

def __generate(self, x):
    self.encoder.eval()
    self.decoder.eval()

    # EOSが出力されなかった場合の最大長
    max_len = self.config.n_words

    src_mask = x == PAD_ID
    enc_output = self.encoder(x)

    batch_size, _ = x.shape
    # 一文字目だけBOSで他はPADとして初期化
    generated = torch.empty(batch_size, max_len).fill_(PAD_ID)
    generated[:,0] = BOS_ID
    generated = generated.to(self.config.device)

    # 文毎にEOSが出力されるタイミングが異なるため記録しておく用
    unfinished_sents = torch.ones(batch_size, device=self.config.device)

    for i in range(1, max_len):
        dec_output = self.decoder(generated[:, :i], enc_output, src_mask, True)
        gen_output = self.decoder.predict(dec_output[:, -1])
        _, next_words = torch.max(gen_output, dim=1)

        # EOS以降の文字はすべてPADになるよう調整
        generated[:, i] = next_words * unfinished_sents + PAD_ID * (1 - unfinished_sents)

        unfinished_sents.mul_(next_words.ne(EOS_ID).long())
        if unfinished_sents.max() == 0:
            break

    return generated.to(dtype=torch.int)

copy taskを実行

一応実装した結果として書いておく

以下のように実行した結果がこちら

$ python transformer.py --train_test  --dataroot data/copy --batch_size 32 --epochs 100 --warmup_steps 500 --name test
start epoch 1
step: 5, loss: 8.18, tokens/sec: 5200.2, lr: 0.000316
step: 10, loss: 7.39, tokens/sec: 5733.7, lr: 0.000632
...
step: 3195, loss: 0.12, tokens/sec: 5471.3, lr: 0.012510
step: 3200, loss: 0.09, tokens/sec: 5266.6, lr: 0.012500
save model to data/copy/test.pth

f:id:y-kamiya:20200426105850p:plain 無事学習できた模様

# テスト用データ
$ cat data/copy/test.en
1 5 6 7 2
1 7 4 7 5 5 2
1 4 5 7 7 4 5 7 5 6

$ python transformer.py --generate_test  --dataroot data/copy --name test
input : [1, 5, 6, 7, 2, 3, 3, 3, 3, 3]
output: [1, 5, 6, 7, 4, 2, 3, 3, 3, 3]

input : [1, 7, 4, 7, 5, 5, 2, 3, 3, 3]
output: [1, 7, 4, 7, 5, 5, 2, 3, 3, 3]

input : [1, 4, 5, 7, 7, 4, 5, 7, 5, 6]
output: [1, 4, 5, 7, 7, 4, 5, 7, 5, 2]

うまくコピーできていた (最後の文字が6 -> 2となっているのは最後までEOSが出なかった場合に最終文字をEOSに変換してgenerateしているため)