元々やろうとしたことは、transformerの出力から各sentenceの末尾にあたるtokenのembeddingsだけを取り出すこと
tensorflowは今までそこまで書いたことはなかったため、やってみると意外と時間がかかったためメモ
例として以下のような形を考える
# 元のtensor (inputと呼ぶ) <tf.Tensor: id=88, shape=(2, 3, 4), dtype=int32, numpy= array([[[ 1, 2, 3, 4], [ 11, 22, 33, 44], [111, 222, 333, 444]], [[ 5, 6, 7, 8], [ 55, 66, 77, 88], [555, 666, 777, 888]]], dtype=int32)> # 取り出したいtensor (outputと呼ぶ) <tf.Tensor: id=90, shape=(2, 4), dtype=int32, numpy= array([[ 11, 22, 33, 44], [555, 666, 777, 888]], dtype=int32)>
transformerの出力だとすると
- 3つのtokenを持つ2つのsentenceが存在
- embedding数は4
- (1つ目、2つ目)のsentenceからそれぞれ(2つ目、3つ目)のtokenのembeddingのみを取り出す
という感じ
必要な処理は
sentence_lengths = [1, 2] indices = list(zip(len(sentence_lengths), sentence_lengths)) output = tf.gather_nd(input, indices)
ちなみ上記の場合、sentence_lengthsはデータの前処理の際に作っておく前提
もしそれもデータから抽出したい場合は以下のようにできる
pad = 2 # token_idで表現されたinputとなるmini batch(ここから各sentenceで初めてpadが出てくるindexを取得したい) batch = [[4,pad,pad],[5,7,pad]] pad_indices = tf.where(tf.equal(b, pad)) >>> <tf.Tensor: id=155, shape=(3, 2), dtype=int64, numpy= array([[0, 1], [0, 2], [1, 2]])> sentence_lengths = tf.segment_min(pad_indices[:,1], pad_indices[:,0]) >>> <tf.Tensor: id=167, shape=(2,), dtype=int64, numpy=array([1, 2])>