tensorflow2 0的一些高階函式用法

2021-09-29 18:22:13 字數 2287 閱讀 4577

最近在學習tensorflow2.0的時候看到一些特別好用的高階函式,這裡來記錄一下它們的用法

1.tf.gather()

tf.gather(params,indices,validate_indices=none,name=none,axis=0)

簡單的理解一下,首先傳入乙個需要處理的張量,然後傳入對他的選擇操作,也就是乙個索引張量。

下面舉個例子:

考慮班級成績冊的例子,共有 4 個班級,每個班級 35 個學生,8 門科目,儲存成績冊的張量 shape 為[4,35,8]。

#建立成績冊

record=tf.random.uniform([4,35,8],maxval=100)

record.numpy

如果現在需要收集第 1,2 兩個班級的成績冊,我們可以通過切片操作

record1_2=record[0:2]

record1_2.numpy

也可以使用tf.gather()得到一樣的結果

#從第乙個維度(班級)選擇前兩個班級

record1_2=tf.gather(record,[0,1],axis=0)

record1_2.numpy

但是換個要求,需要抽查所有班級的第 1,4,9,12,13,27 號同學的成績,這時候用切片就不好得到結果了,用gather還是很容易的

#從第二個維度(學生)抽取

score=tf.gather(record,[0,3,8,11,12,26],axis=1)

score.numpy

2.tf.gather_nd()

通過 tf.gather_nd(),可以通過指定每次取樣的座標來實現取樣多個點的目的

例子:得到班級 1,學生 1 的科目 2;班級 2,學生 2 的科目 3;班級 3,學生 3 的科目 4 的成績

score=tf.gather_nd(record,[[0,0,1],[1,1,2],[2,2,3]])

score.numpy

3.tf.scatter_nd()

通過 tf.scatter_nd(indices, updates, shape)可以高效地重新整理張量的部分資料,但是只能在全 0 張量的白板上面重新整理,因此可能需要結合其他操作來實現現有張量的資料重新整理功能。

#需要重新整理的位置

indices = tf.constant([[4], [3], [1], [7]])

# 構造需要寫入的資料

updates = tf.constant([4.4, 3.3, 1.1, 7.7]) 

# 在長度為 8 的全 0 向量上根據 indices 寫入 updates

tf.scatter_nd(indices, updates, [8])

4.tf.meshgrid()

通過 tf.meshgrid 可以方便地生成二維網格取樣點座標,或者可以理解成為了滿足矩陣相乘,把x按行重複y的列次,y按列重複x的行次(廣播機制)

例子:實現

z=sin(x2+y2)x2+y2 z=\frac

z= x 

2+y 

2sin(x 

2+y 2)

​    

import matplotlib.pyplot as plt

from mpl_toolkits.mplot3d import axes3d

plt.rcparams['axes.unicode_minus']=false

x = tf.linspace(-8.,8,100) # 設定 x 座標的間隔

y = tf.linspace(-8.,8,100) # 設定 y 座標的間隔

x,y = tf.meshgrid(x,y) # 生成網格點,並拆分後返回

print(x.shape,y.shape) # 列印拆分後的所有點的 x,y 座標張量 shape

z = tf.sqrt(x**2+y**2) 

z = tf.sin(z)/z # sinc 函式實現

fig = plt.figure()

ax = axes3d(fig)

# 根據網格點繪製 sinc 函式 3d 曲面

ax.contour3d(x.numpy(), y.numpy(), z.numpy(), 50)

plt.show()

或者來個簡單的例子更能體現它的變換

x=tf.constant([1,2,3])

y=tf.constant([3,4,5])

x,y = tf.meshgrid(x,y) 

print(x.numpy,y.numpy)

這樣meshgrid的作用就一目了然

Tensorflow2 0簡單應用 一

我是初學者 參考 1.匯入tf.keras tensorflow2推薦使用keras構建網路,常見的神經網路都包含在keras.layer中 最新的tf.keras的版本可能和keras不同 import tensorflow as tf from tensorflow.keras import l...

tensorflow2 0視訊記憶體設定

遇到乙個問題 新買顯示卡視訊記憶體8g但是tensorflow執行的時候介面顯示只有約6.3g的視訊記憶體可用,如下圖 即限制了我的視訊記憶體,具體原因為什麼我也不知道,但原來的視訊記憶體小一些的顯示卡就沒有這個問題。目前的解決辦法是 官方文件解決 然後對應的中文部落格 總結一下,就是下面的兩個辦法...

Tensorflow2 0 啟用函式

常用啟用函式及對應特點 神經網路結構的輸出為所有輸入的加權和,這導致整個神經網路是乙個線性模型。而線性模型不能解決異或問題,且面對多分類問題,也顯得束手無策。所以為了解決非線性的分類或回歸問題,啟用函式必須是非線性函式。神經網路中啟用函式的主要作用是提供網路的非線性建模能力。這是因為反向傳播演算法就...