pix2pixを理解するために実装

基本的なGANの実装はやってみたので、今度は少し複雑になったpix2pixを実装してみる。 pix2pixは論文著者による実装が公開されており中身が実際にどうなっているのか勉強するはとても都合がよい。

著者の実装はcycleGANと共通になっており、また実験のための様々なオプションがついていたりするため汎用化されている部分が多く、ひと目で全体を把握という感じにはいかない。 今回は理解することを目的にpix2pixの部分だけを1ファイルに抜き出すような形で実装してみる。

pix2pixの論文やコードはここから辿れる

Image-to-Image Translation with Conditional Adversarial Networks

また、動かしてみるだけならこちらに従えば簡単にできる

https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/docker.md

なお、結果の項目にあるもの以外の英文と画像は論文から引用してあります。

pix2pixとは

abstractに書いてあるのはこんな感じ。

様々な種類の画像の変換タスクがあるものの、その中でやっていることは変換元の画像のpixelから変換先の画像のpixelを予測することであり、予測を最適化するための手法をタスクに応じて人間が考えて設定していた。pix2pixではcGANの手法で損失を学習させることで同じ方法で各種タスクの画像変換を可能にする。

論文で挙げてある様々な種類のタスク

  • ヒートマップからそれが示すリアル画像への変換
  • 白黒からカラーへの変換
  • 航空写真から地図画像への変換
  • 昼夜の変換
  • 輪郭画像から色付けした画像への変換

これらの変換を入力データを変えるだけで同じモデルで行える。

loss

generatorはdiscriminatorを騙すように、discriminatorはgeneratorの出力であることを見破るように学習させるためのGAN loss。それに加えてgeneratorの出力を正解ラベルとなる画像に近づけるためL1 lossを使って学習させた。

naiveなGANでは元の画像に近しい画像は出力できなかったため、conditional GANの手法(正解ラベルも与えて学習させる)をとった

We also test the effect of removing conditioning from the discriminator (labeled as GAN). In this case, the loss does not penalize mismatch between the input and output; it only cares that the output look realistic. This variant results in very poor performance;

また、2乗誤差(L2)を使うよりも絶対値の差(L1)を使う方がよりシャープ

Previous approaches to conditional GANs have found it beneficial to mix the GAN objective with a more traditional loss, such as L2 distance [29]. The discriminator’s job remains unchanged, but the generator is tasked to not only fool the discriminator but also to be near the ground truth output in an L2 sense. We also explore this option, using L1 distance rather than L2 as L1 encourages less blurring:

UNetというネットワーク構成

encoder-decoder型の構成のネットワークを使っている。UNetはencoderとdecoderの対応する層をskip connectionでつないだもの。

画像変換の場合、変換前の画像と変換後の画像は一定レベルで同じ要素を持つものになっている。例えば輪郭から色を塗るタスクの場合なら輪郭部分はまったく同じになることが望まれる。

そのため、完全に情報を圧縮した後のbottleneck層の情報だけを使うのではなく、encoderの各層の出力を対応するdecoderの入力にダイレクトに使うことで、より具体的な情報を元にしてdecodeできるようにした。

f:id:y-kamiya:20190407161416p:plain

PatchGAN

discriminatorによる判定の際、画像全体を一度にみて判定するのではなく、小さな領域に分けた上で各領域の判定を行い、それらを足し合わせることで全体を判定する。論文では1 * 1から256 * 256までの出力を比較した結果が載っており、その中で最も結果の良かった70 * 70を採用している。

f:id:y-kamiya:20190407161412p:plain

実装

ハンドバッグの輪郭から色付けをするタスクを例題として使う。 論文だとこのタスクには137Kの画像で15 epochs学習させたようだが、ちょっと多すぎるので400枚だけ使うことにし、その分epoch数を増やすことにする。

画像はこちらを参考にdownloadした

https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/docker.md

また、基本的に元の実装でデフォルトになっている設定に従って実装している。 pix2pixの学習時に明示的に渡しているオプション

https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/18a40e606eb5ef5214db84b1bb24b9f0e3641371/scripts/test_pix2pix.sh

各オプションのデフォルト値設定

https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/18a40e606eb5ef5214db84b1bb24b9f0e3641371/options/train_options.py

https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/18a40e606eb5ef5214db84b1bb24b9f0e3641371/options/base_options.py

https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/18a40e606eb5ef5214db84b1bb24b9f0e3641371/models/pix2pix_model.py#L32,L35

今回の実装についての全体はこちら

https://github.com/y-kamiya/machine-learning-samples/blob/51ff880b39ddd918a1cd31baa68bc809a9e1cb66/python3/deep/pytorch/pix2pix.py

ちょっと解説

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        # RGB画像なのでチャネル数は3
        self.down0 = nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1)

        # セットになっているものをup, downとしてまとめた
        self.down1 = self.__down(64, 128)
        self.down2 = self.__down(128, 256)
        self.down3 = self.__down(256, 512)
        self.down4 = self.__down(512, 512)
        self.down5 = self.__down(512, 512)
        self.down6 = self.__down(512, 512)
        self.down7 = self.__down(512, 512, use_norm=False)

        self.up7 = self.__up(512, 512)
        # down側の出力もconcatして入力にするため2倍の大きさ
        self.up6 = self.__up(1024, 512, use_dropout=True)
        self.up5 = self.__up(1024, 512, use_dropout=True)
        self.up4 = self.__up(1024, 512, use_dropout=True)
        self.up3 = self.__up(1024, 256)
        self.up2 = self.__up(512, 128)
        self.up1 = self.__up(256, 64)

        self.up0 = nn.Sequential(
            # RGB画像なのでチャネル数は3
            self.__up(128, 3, use_norm=False),
            nn.Tanh(),
        )

unet_256のタイプで実装したので、encoderもdecoderもそれぞれ8層。decoderにはdropout層をつける。norm_layerは何も設定されてないのでbatchとなる。

discriminatorで注意する点は入力チャネル数が6になること。discriminatorへの入力は、判定したい画像と正解画像をconcatしたものであるためRGB*2。 generatorとdiscriminatorの元の実装はここ

https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/18a40e606eb5ef5214db84b1bb24b9f0e3641371/models/networks.py#L149

https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/18a40e606eb5ef5214db84b1bb24b9f0e3641371/models/networks.py#L189

datasetの作成部分は元の実装のAlignedDataset。ここから今回の処理に必要なものだけとってきた。

 def __transform(self, param):
        list = []

        load_size = self.config.load_size
        list.append(transforms.Resize([load_size, load_size], Image.BICUBIC))

        (x, y) = param['crop_pos']
        crop_size = self.config.crop_size
        list.append(transforms.Lambda(lambda img: img.crop((x, y, x + crop_size, y + crop_size))))

        if param['flip']:
            list.append(transforms.Lambda(lambda img: img.transpose(Image.FLIP_LEFT_RIGHT)))

        list += [transforms.ToTensor(),
                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]

        return transforms.Compose(list)

    def __transform_param(self):
        x_max = self.config.load_size - self.config.crop_size
        x = random.randint(0, np.maximum(0, x_max))
        y = random.randint(0, np.maximum(0, x_max))

        flip = random.random() > 0.5

        return {'crop_pos': (x, y), 'flip': flip}

     def __getitem__(self, index):
        AB_path = self.AB_paths[index]
        AB = Image.open(AB_path).convert('RGB')

        param = self.__transform_param()
        w, h = AB.size
        w2 = int(w / 2)
        # AとBのtransform時は全く同じにcrop & flipする必要があるため先にparamを計算しておいて渡す
        transform = self.__transform(param)
        A = transform(AB.crop((0, 0, w2, h)))
        B = transform(AB.crop((w2, 0, w, h)))

        return {'A': A, 'B': B, 'A_paths': AB_path, 'B_paths': AB_path}

入力画像を一度286 * 286にresizeし、その中から256 * 256をランダムで切り取って学習に使う。

PatchGANはどこに?

実装中、画像を分割して渡したりしてる部分は存在しないことに気づいた。調べてみるそのままスバリ書いてくれているブログがあった

【DeepLearning】Patch GANのPatchとは? - 0.5から始める機械学習

元となる解説は論文著者のgithubのissueとして上がっているこれ

https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/39

畳み込みで処理した結果として得られた特徴マップの1要素というのは、元となる画像のある範囲の領域だけの値から導出されたものとなる。なので以下の2つは等価

  • 入力画像をn2個の小さい領域に分割した上で、それぞれを1要素となるまで畳み込む
  • 入力画像をそのまま畳み込んでn * nの大きさとなるまで畳み込む

patchの大きさの計算はこちらのようになるらしい

https://github.com/phillipi/pix2pix/blob/master/scripts/receptive_field_sizes.m

今回実装したのは元の実装内でn_layers=3となっているものなのでこれに当てはまる

f = @(output_size, ksize, stride) (output_size - 1) * stride + ksize;

%% n=3 discriminator

% fix the output size to 1 and derive the receptive field in the input
out = ...
f(f(f(f(f(1, 4, 1), ...   % conv4 -> conv5
             4, 1), ...   % conv3 -> conv4
             4, 2), ...   % conv2 -> conv3
             4, 2), ...   % conv1 -> conv2
             4, 2);       % input -> conv1

fprintf('n=3 discriminator receptive field size: %d\n', out);

上記に従って計算するとpatchの大きさは70となる。

結果

以下のコマンドで実行して学習させた

python pix2pix.py --epochs 1000 --save_data_interval 100  --save_image_interval 100 --batch_size 4 --output_dir output --dataroot data 

google colab上で実行したので、終わらなかった部分は以下のようにmodelをロードしてさらに学習

python pix2pix.py --epochs 2000 --save_data_interval 100  --save_image_interval 100 --batch_size 4 --output_dir output --dataroot data
--generator output/pix2pix_G_epoch_1000 --discriminator output/pix2pix_G_epoch_1000

1000 epochで9.5h程度だった。ちなみにある程度近い画像が出力されるまで学習を進めるつもりだったので、論文とは違いlearning rateは減衰させずに進めた。

出力がこちら 10 epoch f:id:y-kamiya:20190407150522p:plain

100 epoch f:id:y-kamiya:20190407150525p:plain

1000 epoch f:id:y-kamiya:20190407150723p:plain

2000 epoch f:id:y-kamiya:20190407150735p:plain