pytorchでデータ数を増やすとやけに学習時間が増えるバグ

pytorchで学習する処理を書いた際、データセット内のデータ数の増加により学習時間が増えた。

データ数増加で学習時間が増えるのは当然だろうと思うかもしれないが、今回書いていた処理はデータセットのすべてを学習に使わないもので、指定したbatch数分だけminibatchを作って学習するのを1epochにしていた。なのでデータセット内のデータ数が増えても処理にかかる時間はほとんど変わらない想定だったのでおかしい。

その増加分というのもやけに大きく、cProfileで1epoch分を計測してtottime順にみると以下のようになった。

# 8000件のデータセット
   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
       20    1.491    0.075    1.491    0.075 {method 'run_backward' of 'torch._C._EngineBase' objects}
        1    0.112    0.112    0.112    0.112 {built-in method _tkinter.create}
25000件のデータセット
   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
       20    8.083    0.404    8.083    0.404 {method 'run_backward' of 'torch._C._EngineBase' objects}
        1    0.104    0.104    0.104    0.104 {built-in method _tkinter.create}

backwardの計算にほとんどの時間がかかっている。

これを念頭にコードを見てみるとデータをロードしてtensorとする部分でrequires_grad=Trueとしてしまっていることに気づいた。なのでこれをFalseとして実行したところデータ数による処理時間の変動はなくなった。

requires_grad=Trueとすることでbackwardを行った際に計算グラフから勾配が計算されるが、データ全部にそれが適用されてしまったためデータ数に応じてbackwardの処理時間が増えるという事態になっていた。