pytorchでエラー(Unexpected key(s) in state_dict: "bn1.num_batches_tracked")

gpuで学習して出力したデータを、cpuでloadしようとしたらタイトルのようなエラーが出た

フォーラムで同じ問題が質問されていたので簡単に解決

Unexpected key in state_dict: "bn1.num_batches_tracked" - PyTorch Forums

こんな感じでload_state_dictにstrict=Falseを指定するだけ

load_state_dict(torch.load(path_to_data, map_location=device_name), strict=False)

余分なキーが含まれていても無視してくれるオプションのよう

torch.nn — PyTorch master documentation