tensorflowで指定の位置の値だけを抽出して次元を減らす

元々やろうとしたことは、transformerの出力から各sentenceの末尾にあたるtokenのembeddingsだけを取り出すこと

tensorflowは今までそこまで書いたことはなかったため、やってみると意外と時間がかかったためメモ

例として以下のような形を考える

# 元のtensor (inputと呼ぶ)
<tf.Tensor: id=88, shape=(2, 3, 4), dtype=int32, numpy=
array([[[  1,   2,   3,   4],
        [ 11,  22,  33,  44],
        [111, 222, 333, 444]],

       [[  5,   6,   7,   8],
        [ 55,  66,  77,  88],
        [555, 666, 777, 888]]], dtype=int32)>

# 取り出したいtensor (outputと呼ぶ)
<tf.Tensor: id=90, shape=(2, 4), dtype=int32, numpy=
array([[ 11,  22,  33,  44],
       [555, 666, 777, 888]], dtype=int32)>

transformerの出力だとすると

  • 3つのtokenを持つ2つのsentenceが存在
  • embedding数は4
  • (1つ目、2つ目)のsentenceからそれぞれ(2つ目、3つ目)のtokenのembeddingのみを取り出す

という感じ

必要な処理は

sentence_lengths = [1, 2]
indices = list(zip(len(sentence_lengths), sentence_lengths))
output = tf.gather_nd(input, indices)

ちなみ上記の場合、sentence_lengthsはデータの前処理の際に作っておく前提

もしそれもデータから抽出したい場合は以下のようにできる

pad = 2
# token_idで表現されたinputとなるmini batch(ここから各sentenceで初めてpadが出てくるindexを取得したい)
batch = [[4,pad,pad],[5,7,pad]]

pad_indices = tf.where(tf.equal(b, pad))
>>> <tf.Tensor: id=155, shape=(3, 2), dtype=int64, numpy=
array([[0, 1],
       [0, 2],
       [1, 2]])>

sentence_lengths = tf.segment_min(pad_indices[:,1], pad_indices[:,0])
>>> <tf.Tensor: id=167, shape=(2,), dtype=int64, numpy=array([1, 2])>

こちらを参考にした
https://stackoverflow.com/questions/42184663/how-to-find-an-index-of-the-first-matching-element-in-tensorflow/42190780