pytorch3dのTextureUVでSMPLのmeshにtextureを貼る

3dの人体形状モデルとして有名なものにSMPLというのがあり、形状(shape)と姿勢(pose)を指定することで様々な人体形状のmeshを作り出せる。
https://smpl.is.tue.mpg.de/

SMPL自体はmeshを生成するものだが、対応するtexture mapも配布されていて、3d生成関連の論文読んでいるとよく使われているのを見かける。なので勉強がてらpytorch3dを使ってSMPLのmodelにtextureを貼ってみたのでメモ。

pytorch3dのinstall

手順はこちらに従う
https://github.com/facebookresearch/pytorch3d/blob/28f914b/INSTALL.md

mac環境なのでcudaは気にせずpipで入れた

pip install torch==1.13.0 torchvision
MACOSX_DEPLOYMENT_TARGET=10.14 CC=clang CXX=clang++  pip install "git+https://github.com/facebookresearch/pytorch3d.git@stable"

2.0系だとビルド時にエラーが出たので1.13.0を使っておいた。pytorch3dのinstall時点でpytorchが入ってないと依存でエラーになるので先にpytorchだけ入れておく。

ちなみにエラー内容としてはc+17が有効になってないというものだった。私の環境ではclang 14.0が入っていてc+17がデフォルトになってないversionだったためと思われる。
https://cpprefjp.github.io/implementation.html#clang

clang 16.0以降を使う、またはビルド時にc+17を使うようフラグを渡せばpytorch2系でもOKなはずだが目的から外れるのでとりあえずこのまま。

SMPLの準備

python環境かつuv座標の情報が必要なので SMPL for Python Users > Download UV map in OBJ format からダウンロードする。ちなみに取得にはメアドでのユーザ登録が必要。
https://smpl.is.tue.mpg.de/download.php

zipを開くとobjファイルと一緒にtexture mapのサンプルも入ってるのでこれを使うことにする。

pytorch3dによる描画

denseposeのtextureをSMPLに貼り付けるチュートリアルがあるのでこれと同じ形でOK
https://pytorch3d.org/tutorials/render_densepose

densepose -> smplへのuv map変換処理を除いて単純にTexture, Meshを作ればよいはず

with Image.open(tex_path) as image:
    np_image = np.asarray(image.convert("RGB")).astype(np.float32)
tex = torch.from_numpy(np_image / 255.)[None].to(device)

verts, faces, aux = load_obj(obj_path)
texture = renderer.TexturesUV(maps=tex, faces_uvs=[faces.verts_idx], verts_uvs=[aux.verts_uvs])
mesh = Meshes([verts], [faces.verts_idx], texture)

ということで描画してみるとこうなる 網目が完全に崩れているので何かおかしい

TextureUVの定義はこれ
https://github.com/facebookresearch/pytorch3d/blob/297020a4b1d7492190cb4a909cafbd2c81a12cb5/pytorch3d/renderer/mesh/textures.py#L612

  • faces_uvs: 各faceを形作る頂点indexのtensor
  • verts_uvs: 頂点毎のuv座標を与えるtensor

と書いてあるので良さそうに思える。

実際に渡しているデータを調べてみるとvertsとverts_uvsのデータ数が合っていない

verts, faces, aux = load_obj(obj_path)
print(verts.shape)    # torch.Size([6890, 3])
print(aux.verts_uvs)  # torch.Size([7576, 2])

SMPLのmeshの頂点数は6890個なのでverts_uvsのデータ数がそれより多くなっている模様。facesのデータを見てみるとこう

verts, faces, aux = load_obj(obj_path)
print(faces.verts_idx.shape)          # torch.Size([13776, 3])
print(torch.max(faces.verts_idx))     # tensor(6889)
print(faces.textures_idx.shape)       # torch.Size([13776, 3]) 
print(torch.max(faces.textures_idx))  # tensor(7575)            

faces.textures_idxがaux.verts_uvsのindexと対応しているということだった。

見る方向による色の違いを表現するため、mesh頂点毎にuv座標を指定するのではなく、同じ頂点でも別のfaceの場合は異なるuv座標が対応するとのこと
https://github.com/pmh47/dirt/issues/46#issuecomment-540392773

TextureUVの引数を修正して

texture = renderer.TexturesUV(maps=tex, faces_uvs=[faces.textures_idx], verts_uvs=[aux.verts_uvs])

とすることでちゃんと網目の模様になった

import os
import torch
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from pytorch3d.structures import Meshes
import pytorch3d.renderer as renderer
from pytorch3d.io import load_obj


device = torch.device("cpu")
    

def create_renderer():
    R, T = renderer.look_at_view_transform(1.0, 0, 0) 
    cameras = renderer.FoVPerspectiveCameras(device=device, R=R, T=T)

    raster_settings = renderer.RasterizationSettings(
        image_size=512, 
        blur_radius=0.0, 
        faces_per_pixel=1, 
    )
    lights = renderer.PointLights(device=device, location=[[0.0, 0.0, 2.0]])

    return renderer.MeshRenderer(
        rasterizer=renderer.MeshRasterizer(
            cameras=cameras, 
            raster_settings=raster_settings
        ),
        shader=renderer.SoftPhongShader(
            device=device, 
            cameras=cameras,
            lights=lights
        )
    )


# downloadしたSMPLのzipを展開してこのpathに置いておく
DATA_DIR = "./data"
tex_path = os.path.join(DATA_DIR, "smpl/smpl_uv_20200910.png")
obj_path = os.path.join(DATA_DIR, "smpl/smpl_uv.obj")


with Image.open(tex_path) as image:
    np_image = np.asarray(image.convert("RGB")).astype(np.float32)
tex = torch.from_numpy(np_image / 255.)[None].to(device)

verts, faces, aux = load_obj(obj_path)
texture = renderer.TexturesUV(maps=tex, faces_uvs=[faces.verts_idx], verts_uvs=[aux.verts_uvs])
mesh = Meshes([verts], [faces.verts_idx], texture)

renderer = create_renderer()
images = renderer(mesh)
plt.figure(figsize=(10, 10))
plt.imshow(images[0, ..., :3].cpu().numpy())
plt.axis("off");
plt.show()