tf.gather和tf.gather_nd的詳細用法--tensorflow通過索引取tensor里的數(shù)據(jù)

在numpy里取矩陣數(shù)據(jù)非常方便,比如:

a = np.random.random((5, 4))
indices = np.array([0,2,4])

print(a)
#array([[0.47122875, 0.37836802, 0.18210801, 0.341471  ],
#      [0.56551837, 0.27328607, 0.50911876, 0.01179739],
#       [0.75350208, 0.9967817 , 0.94043434, 0.15640884],
#      [0.09511502, 0.96345098, 0.6500849 , 0.04084285],
#       [0.93815553, 0.04821088, 0.10792035, 0.27093746]])
print(a[indices])
#array([[0.47122875, 0.37836802, 0.18210801, 0.341471  ],
#      [0.75350208, 0.9967817 , 0.94043434, 0.15640884],
#       [0.93815553, 0.04821088, 0.10792035, 0.27093746]])

這樣就把矩陣a中的1,3,5行取出來了。

如果是只取某一維中單個索引的數(shù)據(jù)可以直接寫成tensor[:, 2], 但如果要提取的索引不連續(xù)的話,在tensorflow里面的用法就要用到tf.gather.

import tensorflow as tf
sess = tf.Session()
b = tf.gather(tf.constant(a), indices)                                                                                              

sess.run(b)                                                                                                                         
#Output
array([[0.47122875, 0.37836802, 0.18210801, 0.341471  ],
       [0.75350208, 0.9967817 , 0.94043434, 0.15640884],
       [0.93815553, 0.04821088, 0.10792035, 0.27093746]])

tf.gather_nd允許在多維上進行索引:
matrix中直接通過坐標取數(shù)(索引維度與tensor維度相同):

    indices = [[0, 0], [1, 1]]
    params = [['a', 'b'], ['c', 'd']]
    output = ['a', 'd']

取第二行和第一行:

    indices = [[1], [0]]
    params = [['a', 'b'], ['c', 'd']]
    output = [['c', 'd'], ['a', 'b']]

3維tensor的結果:

    indices = [[1]]
    params = [[['a0', 'b0'], ['c0', 'd0']],
              [['a1', 'b1'], ['c1', 'd1']]]
    output = [[['a1', 'b1'], ['c1', 'd1']]]


    indices = [[0, 1], [1, 0]]
    params = [[['a0', 'b0'], ['c0', 'd0']],
              [['a1', 'b1'], ['c1', 'd1']]]
    output = [['c0', 'd0'], ['a1', 'b1']]

另外還有tf.batch_gather的用法如下:
tf.batch_gather(params, indices, name=None)
Gather slices from params according to indices with leading batch dims.

This operation assumes that the leading dimensions of indices are dense,
and the gathers on the axis corresponding to the last dimension of indices.

#tf.batch_gather按如下運算:
result[i1, ..., in] = params[i1, ..., in-1, indices[i1, ..., in]]

Therefore params should be a Tensor of shape [A1, ..., AN, B1, ..., BM],
indices should be a Tensor of shape [A1, ..., AN-1, C] and result will be
a Tensor of size [A1, ..., AN-1, C, B1, ..., BM].

如果索引是一維的tensor,結果和tf.gather 是一樣的.

?著作權歸作者所有,轉載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務。

相關閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容