Selecting rows from a matrix is very convenient in numpy, for example:
import numpy as np
a = np.random.random((5, 4))  # random values, so your output will differ
indices = np.array([0, 2, 4])
print(a)
#array([[0.47122875, 0.37836802, 0.18210801, 0.341471 ],
# [0.56551837, 0.27328607, 0.50911876, 0.01179739],
# [0.75350208, 0.9967817 , 0.94043434, 0.15640884],
# [0.09511502, 0.96345098, 0.6500849 , 0.04084285],
# [0.93815553, 0.04821088, 0.10792035, 0.27093746]])
print(a[indices])
#array([[0.47122875, 0.37836802, 0.18210801, 0.341471 ],
# [0.75350208, 0.9967817 , 0.94043434, 0.15640884],
# [0.93815553, 0.04821088, 0.10792035, 0.27093746]])
This extracts rows 1, 3 and 5 of matrix a (indices 0, 2 and 4).
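To make the indexing result checkable, here is a small sketch that uses a deterministic matrix (np.arange instead of random values; the names rows, cols and col2 are only for illustration). It contrasts gathering non-contiguous rows or columns with an integer array against selecting a single column, which drops that dimension.

```python
import numpy as np

# Deterministic 5x4 matrix: row i holds 4*i .. 4*i+3.
a = np.arange(20).reshape(5, 4)
indices = np.array([0, 2, 4])

rows = a[indices]        # non-contiguous rows 0, 2, 4 -> shape (3, 4)
cols = a[:, [0, 3]]      # non-contiguous columns 0 and 3 -> shape (5, 2)
col2 = a[:, 2]           # a single integer index drops the axis -> shape (5,)

print(rows.shape, cols.shape, col2.shape)  # (3, 4) (5, 2) (5,)
```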
If you only need a single index along one dimension, you can simply write tensor[:, 2]; but when the indices to extract are not contiguous, TensorFlow requires tf.gather.
import tensorflow as tf
# TensorFlow 1.x graph-mode API; in TF 2.x, Session is only available as tf.compat.v1.Session
sess = tf.Session()
b = tf.gather(tf.constant(a), indices)
sess.run(b)
# Output:
# array([[0.47122875, 0.37836802, 0.18210801, 0.341471  ],
#        [0.75350208, 0.9967817 , 0.94043434, 0.15640884],
#        [0.93815553, 0.04821088, 0.10792035, 0.27093746]])
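tf.gather also accepts an axis argument (default 0), so it is not limited to rows. As a rough sketch that runs without TensorFlow, NumPy's np.take behaves analogously for these simple cases. The mapping to tf.gather in the comments is an illustration, not a statement that the two functions are identical in every respect.

```python
import numpy as np

a = np.arange(20).reshape(5, 4)
indices = np.array([0, 2, 4])

# axis=0 selects whole rows, comparable to tf.gather(params, indices).
rows = np.take(a, indices, axis=0)
# axis=1 selects non-contiguous columns, comparable to tf.gather(params, [1, 3], axis=1).
cols = np.take(a, [1, 3], axis=1)

print(np.array_equal(rows, a[indices]), cols.shape)  # True (5, 2)
```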
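tf.gather_nd allows indexing across several dimensions at once.
Taking elements of a matrix directly by coordinates (each index has as many entries as the tensor has dimensions):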
indices = [[0, 0], [1, 1]]
params = [['a', 'b'], ['c', 'd']]
output = ['a', 'd']
Taking the second row, then the first row:
indices = [[1], [0]]
params = [['a', 'b'], ['c', 'd']]
output = [['c', 'd'], ['a', 'b']]
Results on a 3-D tensor:
indices = [[1]]
params = [[['a0', 'b0'], ['c0', 'd0']],
[['a1', 'b1'], ['c1', 'd1']]]
output = [[['a1', 'b1'], ['c1', 'd1']]]
indices = [[0, 1], [1, 0]]
params = [[['a0', 'b0'], ['c0', 'd0']],
[['a1', 'b1'], ['c1', 'd1']]]
output = [['c0', 'd0'], ['a1', 'b1']]
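The rule behind all four examples is: the last axis of indices holds a (possibly partial) coordinate of length K into params, and each coordinate picks out params[c0, ..., cK-1], which is a scalar when K equals the rank of params and a slice otherwise. The helper gather_nd_np below is a hypothetical NumPy re-implementation of that rule (without tf.gather_nd's batch_dims option), written only to check the outputs listed above.

```python
import numpy as np

def gather_nd_np(params, indices):
    """Hypothetical NumPy helper mimicking tf.gather_nd (no batch_dims)."""
    params = np.asarray(params)
    indices = np.asarray(indices)
    k = indices.shape[-1]                              # length of each coordinate
    flat = indices.reshape(-1, k)                      # one coordinate per row
    picked = [params[tuple(coord)] for coord in flat]  # element or slice per coordinate
    out = np.array(picked)
    # Result shape: indices.shape[:-1] + the dimensions not consumed by the coordinate.
    return out.reshape(indices.shape[:-1] + params.shape[k:])

params3 = [[['a0', 'b0'], ['c0', 'd0']],
           [['a1', 'b1'], ['c1', 'd1']]]
print(gather_nd_np(params3, [[0, 1], [1, 0]]).tolist())  # [['c0', 'd0'], ['a1', 'b1']]
```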
There is also tf.batch_gather, which is used as follows:
tf.batch_gather(params, indices, name=None)
Gather slices from params according to indices with leading batch dims.
This operation assumes that the leading dimensions of indices are dense,
and gathers along the axis corresponding to the last dimension of indices.
tf.batch_gather computes:
result[i1, ..., in] = params[i1, ..., in-1, indices[i1, ..., in]]
Therefore params should be a Tensor of shape [A1, ..., AN, B1, ..., BM],
indices should be a Tensor of shape [A1, ..., AN-1, C] and result will be
a Tensor of size [A1, ..., AN-1, C, B1, ..., BM].
If indices is a 1-D tensor, the result is the same as tf.gather.
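To see the shape rule and the 1-D special case concretely, here is a minimal NumPy sketch of the formula above. The function batch_gather_np is hypothetical and only emulates the documented semantics; it is not TensorFlow code. As far as I know, later TensorFlow releases deprecate tf.batch_gather in favour of tf.gather with its batch_dims argument, so check the documentation of your version.

```python
import numpy as np

def batch_gather_np(params, indices):
    """Hypothetical NumPy emulation of the tf.batch_gather rule
    result[i1, ..., in] = params[i1, ..., in-1, indices[i1, ..., in]]."""
    params = np.asarray(params)
    indices = np.asarray(indices)
    n_batch = indices.ndim - 1                  # leading batch dims A1..AN-1
    batch_shape = indices.shape[:-1]
    if params.shape[:n_batch] != batch_shape:
        raise ValueError("leading dimensions of params and indices must match")
    # Collapse the batch dims into one axis so each batch is gathered separately.
    flat_params = params.reshape((-1,) + params.shape[n_batch:])   # [prod(A), AN, B1..BM]
    flat_idx = indices.reshape(-1, indices.shape[-1])              # [prod(A), C]
    out = np.stack([p[i] for p, i in zip(flat_params, flat_idx)])  # [prod(A), C, B1..BM]
    return out.reshape(batch_shape + out.shape[1:])                # [A1..AN-1, C, B1..BM]

params = np.arange(24).reshape(2, 3, 4)   # A1=2, AN=3, B1=4
idx = np.array([[2, 0], [1, 1]])          # A1=2, C=2
out = batch_gather_np(params, idx)
print(out.shape)  # (2, 2, 4)

# With 1-D indices there are no batch dims, so this reduces to an ordinary row gather.
m = np.arange(12).reshape(3, 4)
print(np.array_equal(batch_gather_np(m, [2, 0]), m[[2, 0]]))  # True
```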