pytorchでindexのリストを指定してtensorの要素を取得
Double DQNの実装に必要になるちょっとした計算についてメモ
2つの2次元tensor x, yを用意し、"xの各行において最大の値を持つ要素"と同じ位置にあるyの要素を取得する
>>> x = torch.rand(3,5) >>> x tensor([[ 0.0778, 0.6633, 0.4953, 0.1461, 0.4691], [ 0.3024, 0.0295, 0.3526, 0.6040, 0.7512], [ 0.1778, 0.7783, 0.1738, 0.5278, 0.0372]]) >>> x.argmax(1) tensor([ 1, 4, 1]) >>> x.argmax(1, keepdim=True) tensor([[ 1], [ 4], [ 1]]) >>> y = torch.rand(3,5) >>> y tensor([[ 0.4005, 0.3994, 0.1083, 0.8888, 0.9239], [ 0.6046, 0.5906, 0.3089, 0.4983, 0.2159], [ 0.4500, 0.9791, 0.4029, 0.9614, 0.5124]]) >>> y.gather(1, x.argmax(1, keepdim=True)) tensor([[ 0.3994], [ 0.2159], [ 0.9791]])
ちなみに、最後にsqueezeをかければ単純な1次元のリストにできる