張量形狀的理解與相關(guān)操作
一.張量的形狀的判斷

Drawing (1).png
這里的A,B,C分布表示維度0,1,2
那如何理解shape呢,由左圖我們可以看到
和A這個方括號同維度的有[1 2],[4 5]兩個,所以維度0的長度為2,
而和B同維度的有0,2兩個,所以維度1的長度為2
所以左邊張量的shape=[2,2]
同理右圖的張量,
和A方括號同維度的只有[[1 2 3] [4 5 6]],所以維度0長度為1
和B方括號同維度的有[1 2 3], [4 5 6]兩個,所以維度1長度為2
和C同維度的有1,2,3,所以維度2長度為3
所以shape=[1,2,3]
二. tf.squeeze(input, squeeze_dims=None, name=None),維度去除
1.去掉所有長度為1的維度(相當(dāng)于去除那個維度的括號)。
舉個栗子:
#coding=utf-8
import tensorflow as tf;
import numpy as np;
B = np.array([[[[1],[2],[3] ],[[4],[5],[6] ]]])
#去除維度0和維度3,因為這兩個維度長度都為1
y = tf.squeeze(B,0)
with tf.Session() as sess:
print (sess.run(y),'\n')
輸出:
[[1 2 3]
[4 5 6]]
2.也可以去掉指定索引的維度(該維度長度必須為1):
#coding=utf-8
import tensorflow as tf;
import numpy as np;
#shape=[1,2,3,1]
B = np.array([[[[1],[2],[3] ],[[4],[5],[6] ]]])
#去除維度0
y = tf.squeeze(B,[0])
with tf.Session() as sess:
print (sess.run(y),'\n')
輸出:
[[[1]
[2]
[3]]
[[4]
[5]
[6]]]
三. tf.expand_dims(input, dim, name=None),擴(kuò)展維度
作用:跟squeeze作用相反,它在維度dim上擴(kuò)展一個長度為1的維度,原維度dim則被排在后面
#coding=utf-8
import tensorflow as tf;
import numpy as np;
B = np.array([[3,4],[5,6]])
# 在維度0的元素前面加括號
y = tf.expand_dims(B,0)
y1 = tf.expand_dims(B,2)
#-1表示最后一維
y2 = tf.expand_dims(B,-1)
print(B,'\n')
with tf.Session() as sess:
print ('y:shape=',y.shape,'\n',sess.run(y),'\n')
print ('y1:shape=',y1.shape,'\n',sess.run(y1),'\n')
print ('y2:shape=',y2.shape,'\n',sess.run(y2),'\n')
輸出:
y:shape= (1, 2, 2)
[[[3 4]
[5 6]]]
y1:shape= (2, 2, 1)
[[[3]
[4]]
[[5]
[6]]]
y2:shape= (2, 2, 1)
[[[3]
[4]]
[[5]
[6]]]
四.tf.transpose(input, [dimension_1, dimenaion_2,..,dimension_n])
作用:交換維度
舉個栗子:
A = np.array([[[1,2,3],[4,5,6]]])
#即:
[[
[1 2 3]
[4 5 6]
]]
如果x=tf.transpose(A, [0,2,1])
1.那么首先找到維度0,2,1的長度
維度0的長度: 1
維度1的長度: 2
維度2的長度: 3
2.再按序?qū)懗?,2維的形狀的張量:
[
[
[ ]
[ ]
[ ]
]
]
3.若x的最后一維長度比A最后一維的長度小,則取將同列的元素按序放入x的最后一維,否則將x的同行元素按序放入最后一維,這里x和A的最后一維長度分別為2,3,所以將同列寫入最后一維,最后的結(jié)果為:
[[[1 4]
[2 5]
[3 6]]]
代碼驗證:
import tensorflow as tf;
import numpy as np;
A = np.array([[[1,2,3],[4,5,6]]])
x = tf.transpose(A, [0,2,1])
y = tf.transpose(A, [0,1,2])
with tf.Session() as sess:
print ('A:\n',A,'\n')
print ('x:\n',sess.run(x),'\n')
print ('y:\n',sess.run(y),'\n')
輸出:
A:
[[[1 2 3]
[4 5 6]]]
x:
[[[1 4]
[2 5]
[3 6]]]
y:
[[[1 2 3]
[4 5 6]]]