colaboratoryでapexのinstallに失敗する

タイトルの通りでapexのinstallに失敗したのでメモ。

実行したコマンドはこちら

!git clone https://github.com/NVIDIA/apex.git
!git checkout ebcd7f084bba96bdb0c3fdf396c3c6b02e745042 # 2021/09/18時点での最新
%cd apex
!pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./

エラーにかかわる部分の出力がこちら

 torch.__version__  = 1.9.0+cu102


    /tmp/pip-req-build-uufz822x/setup.py:67: UserWarning: Option --pyprof not specified. Not installing PyProf dependencies!
      warnings.warn("Option --pyprof not specified. Not installing PyProf dependencies!")

    Compiling cuda extensions with
    nvcc: NVIDIA (R) Cuda compiler driver
    Copyright (c) 2005-2020 NVIDIA Corporation
    Built on Mon_Oct_12_20:09:46_PDT_2020
    Cuda compilation tools, release 11.1, V11.1.105
    Build cuda_11.1.TC455_06.29190527_0
    from /usr/local/cuda/bin

    Traceback (most recent call last):
      File "<string>", line 1, in <module>
      File "/tmp/pip-req-build-uufz822x/setup.py", line 171, in <module>
        check_cuda_torch_binary_vs_bare_metal(torch.utils.cpp_extension.CUDA_HOME)
      File "/tmp/pip-req-build-uufz822x/setup.py", line 106, in check_cuda_torch_binary_vs_bare_metal
        "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798.  "
    RuntimeError: Cuda extensions are being compiled with a version of Cuda that does not match the version used to compile Pytorch binaries.  Pytorch binaries were compiled with Cuda 10.2.
    In some cases, a minor-version mismatch will not cause later errors:  https://github.com/NVIDIA/apex/pull/323#discussion_r287021798.  You can try commenting out this check (at your own risk).
    Running setup.py install for apex ... error

colab上で使われるcudaが11.1だが、pytorchがcuda10.2でビルドされたものであるため失敗。ということでapex自体は特に関係なく、デフォルトで入ってるpytorchがなぜかcolab上で使われるcudaのバージョンと異なっているためだった。

ちなみにcudaとpytorchのバージョン確認方法はこちらの通り
https://qiita.com/ysit/items/a601cb59523cc1961556

!nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2020 NVIDIA Corporation
Built on Mon_Oct_12_20:09:46_PDT_2020
Cuda compilation tools, release 11.1, V11.1.105
Build cuda_11.1.TC455_06.29190527_0

!python -c 'import torch; print(torch.__version__) '
1.9.0+cu102

pytorchの公式ページからcuda11.1向けのインストール方法を確認してpip入れる
https://pytorch.org/

!pip3 install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html

これでapexのインストールは成功するようになった。