tensorflowで指定の位置の値だけを抽出して次元を減らす
元々やろうとしたことは、transformerの出力から各sentenceの末尾にあたるtokenのembeddingsだけを取り出すこと
tensorflowは今までそこまで書いたことはなかったため、やってみると意外と時間がかかったためメモ
例として以下のような形を考える
# 元のtensor (inputと呼ぶ)
<tf.Tensor: id=88, shape=(2, 3, 4), dtype=int32, numpy=
array([[[ 1, 2, 3, 4],
[ 11, 22, 33, 44],
[111, 222, 333, 444]],
[[ 5, 6, 7, 8],
[ 55, 66, 77, 88],
[555, 666, 777, 888]]], dtype=int32)>
# 取り出したいtensor (outputと呼ぶ)
<tf.Tensor: id=90, shape=(2, 4), dtype=int32, numpy=
array([[ 11, 22, 33, 44],
[555, 666, 777, 888]], dtype=int32)>
transformerの出力だとすると
- 3つのtokenを持つ2つのsentenceが存在
- embedding数は4
- (1つ目、2つ目)のsentenceからそれぞれ(2つ目、3つ目)のtokenのembeddingのみを取り出す
という感じ
必要な処理は
sentence_lengths = [1, 2] indices = list(zip(len(sentence_lengths), sentence_lengths)) output = tf.gather_nd(input, indices)
ちなみ上記の場合、sentence_lengthsはデータの前処理の際に作っておく前提
もしそれもデータから抽出したい場合は以下のようにできる
pad = 2
# token_idで表現されたinputとなるmini batch(ここから各sentenceで初めてpadが出てくるindexを取得したい)
batch = [[4,pad,pad],[5,7,pad]]
pad_indices = tf.where(tf.equal(b, pad))
>>> <tf.Tensor: id=155, shape=(3, 2), dtype=int64, numpy=
array([[0, 1],
[0, 2],
[1, 2]])>
sentence_lengths = tf.segment_min(pad_indices[:,1], pad_indices[:,0])
>>> <tf.Tensor: id=167, shape=(2,), dtype=int64, numpy=array([1, 2])>