TensorFlow Numpy中的axis的理解

2022-07-30 11:42:13 字數 3051 閱讀 5171

tensorflow中有很多函式涉及到axis,比如tf.reduce_mean(),其函式原型如下:

1

defreduce_mean(input_tensor,

2 axis=none,

3 keepdims=none,

4 name=none,

5 reduction_indices=none,

6 keep_dims=none):

其中axis表示的是,對該維度進行求均值(預設情況下,是對所有值求均值)。

除了tensorflow中,numpy中也經常遇到很多對矩陣操作的函式會涉及axis操作。比如np.mean(),其函式原型如下:

1

def mean(a, axis=none, dtype=none, out=none, keepdims=np._novalue):

想要弄清楚如何處理涉及axis(維度)的操作,必須先明白axis是什麼。

首先axis是維度,如果axis=0則對應著高; 如果axis=1則對應著行處理;如果axis=2則對應著列;如果axis=3…n(無法用直觀的圖來表示)。我相信很多人看到這還是會一頭霧水。什麼是高,行還有列。為了說明這個問題,我舉個列子:

data=[[[1,2,3],[11,22,33]],[[4,5,6],[44,55,66]],[[10,11,12],[100,110,120]],[[7,8,9],[77,88,99]]]

data_np=np.array(data)

print

(data_np)

[[[ 1 2 3]

[ 11 22 33]]

[[ 4 5 6]

[ 44 55 66]]

[[ 10 11 12]

[100 110 120]]

[[ 7 8 9]

[ 77 88 99]]]

如上面,可以將最外層[ ]去掉,可以發現有4組元素(這裡的元素是矩陣),你可以將其理解為高。

再從這3組元素中選取一組,比如選擇的是

[[ 1 2 3]

[ 11 22 33]]

然後將該組的最外層[ ]去掉,可以發現有2組元素分別為[ 1 2 3]和 [ 11 22 33],此時對應的是行。

在從這兩組元素中選組一組,比如選擇的是

[ 11 22 33]

現在無需去掉最外層的[ ]了,一眼就能看出裡面有3個元素。這就是對應的列。

理解了上面的分析後,很容易就知道(高,行,列)對應的其實就是改矩陣的shape.

print

(data_np.shape):

(4,2,3)

現在弄清楚了axis的值與(高,行,列)的關係後,再來分析tf.reduce_mean()或者np.mean()等函式是如何對axis進行操作的。

1 data=[[[1,2,3],[11,22,33]],[[4,5,6],[44,55,66]],[[10,11,12],[100,110,120]],[[7,8,9],[77,88,99]]]

23 data_tensor=tf.constant(data,dtype=tf.float32)

45 mean_axis0=tf.reduce_mean(data_tensor,axis=0)

6 mean_axis1=tf.reduce_mean(data_tensor,axis=1)

7 mean_axis2=tf.reduce_mean(data_tensor,axis=2)89

with tf.session() as sess:

10print

(sess.run(mean_axis0))

11print

(sess.run(mean_axis1))

12print(sess.run(mean_axis2))

針對上述**,我們先對axis=0維度的資料處理進行分析。

首先對上述data資料進行立體化變換,如下圖(本人本想用軟體來繪製3d的矩陣疊加效果,可惜找了很多軟體都不適合,也許是本人尋找的還不夠,歡迎有知道可以繪製3d的矩陣疊加效果的朋友們,能夠分享一下。感激…)

如上如,axis=0的維度資料求均值,

[[(1+4+10+7)/4         (2+5+11+8)/4       (3+6+12+9)/4]

[(11+44+100+77)/4 (22+55+110+88)/4 (33+66+120+99)/4]]

=[[ 5.5 6.5 7.5]

[58. 68.75 79.5 ]]

同理,對axis=1的維度資料求均值,

[[(1+11)/2    (2+22)/2    (3+33)/2]

[(4+44)/2 (5+55)/2 (6+66)/2]

[(10+100)/2 (11+110)/2 (12+120)/2]

[(7+77)/2 (8+88)/2 (9+99)/2]]

=[[ 6. 12. 18. ]

[24. 30. 36. ]

[55. 60.5 66. ]

[42. 48. 54. ]]

同理可得axis=2維度的資料平均值為(過程留給讀者去推,運算結果如下):

[[  2.  22.]

[ 5. 55.]

[ 11. 110.]

[ 8. 88.]]

在python的世界裡,有很多時候都需要對資料進行維度的操作,如果對axis理解的不透的話,很容易找不著方向。

理解numpy中的axis

對於m個元素一維陣列a,因為只有乙個軸,所以axis只能為0,和預設值效果相同,觀察的是0軸上0,1,i,m點對應的元素。產生的新集合就乙個元素。舉例 對於mxn的二維陣列 a,axis可以取值0或1。axis 0 相當於平面座標的y軸,變化的是 行 即觀察每一列不同行的元素。產生的新集合,其元素的...

Python中axis的意思

對於乙個 5,4,3,2 的陣列data dim 0 5 dim 1 4 dim 2 3 dim 3 2 axis 0,操作時只有第0維的下標變化其他不變,操作結束後變為 4,3,2 axis 1,操作時只有第1維的下標變化其他不變,操作結束後變為 5,3,2 axis 2,操作時只有第2維的下標變...

關於numpy中axis 0和axis的區別的問題

很多人在學習numpy時,遇到axis 0和axis 1的問題,究竟是如何定義的 如一下列子 import numpy as np ww np.arange 10 reshape 2,5 print ww 結果 array 0,1,2,3,4 5,6,7,8,9 ww.mean axis 0 沿著縱...