小白學PyTorch 18 TF2構建自定義模型

2021-10-09 23:26:34 字數 3435 閱讀 999

之前講過了如何用tensorflow構建資料集,然後這一節課講解如何用tensorflow2.0來建立模型。

tf2.0中建立模型的api基本上都放到了它的keras中了,keras可以理解為tf的高階api,裡面封裝了很多的常見網路層、常見損失函式等。 後續會詳細介紹keras的全面功能,本篇文章講解如何構建模型。

import tensorflow as tf

import tensorflow.keras as keras

class

mylayer

(keras.layers.layer)

:def

__init__

(self, input_dim=

32, output_dim=32)

:super

(mylayer, self)

.__init__(

) w_init = tf.random_normal_initializer(

) self.weight = tf.variable(

initial_value=w_init(shape=

(input_dim, output_dim)

, dtype=tf.float32)

, trainable=

true

)# 如果是false則是不參與梯度下降的變數

b_init = tf.zeros_initializer(

) self.bias = tf.variable(initial_value=b_init(

shape=

(output_dim)

, dtype=tf.float32)

, trainable=

true

)def

call

(self, inputs)

:return tf.matmul(inputs, self.weight)

+ self.bias

x = tf.ones((3

,5))

my_layer = mylayer(input_dim=5,

output_dim=10)

out = my_layer(x)

print

(out.shape)

>>

>(3

,10)

這個就是定義了乙個tf的網路層,其實可以看出來和pytorch定義的方式非常的類似:

上面**中實現的是乙個全連線層的定義,其中可以看到使用tf.random_normal_initializer()來作為引數的初始化器,然後用tf.variable來產生網路層中的權重變數,通過trainable=true這個引數說明這個權重變數是乙個參與梯度下降的可以訓練的變數。

我通過tf.ones((3,5))產生乙個shape為[3,5]的乙個全是1的張量,這裡面第一維度的3表示有3個樣本,第二維度的5就是表示要放入全連線層的資料(全連線層的輸入是5個神經元);然後設定的全連線層的輸出神經元數量是10,所以最後的輸出是(3,10)。

import tensorflow as tf

import tensorflow.keras as keras

class

cbr(keras.layers.layer)

:def

__init__

(self,output_dim)

:super

(cbr,self)

.__init__(

) self.conv = keras.layers.conv2d(filters=output_dim, kernel_size=

4, padding=

'same'

, strides=1)

self.bn = keras.layers.batchnormalization(axis=3)

self.relu = keras.layers.relu(

)def

call

(self, inputs)

: inputs = self.conv(inputs)

inputs = self.relu(self.bn(inputs)

)return inputs

class

mynet

(keras.model)

:def

__init__

(self,input_dim=3)

:super

(mynet,self)

.__init__(

) self.cbr1 = cbr(16)

self.maxpool1 = keras.layers.maxpool2d(pool_size=(2

,2))

self.cbr2 = cbr(32)

self.maxpool2 = keras.layers.maxpool2d(pool_size=(2

,2))

defcall

(self, inputs)

: inputs = self.maxpool1(self.cbr1(inputs)

) inputs = self.maxpool2(self.cbr2(inputs)

)return inputs

model = mynet(3)

data = tf.random.normal((16

,224

,224,3

))output = model(data)

print

(output.shape)

>>

>(16

,56,56

,32)

這個是構建了乙個非常簡單的卷積網路,結構是常見的:卷積層+bn層+relu層。可以發現這裡繼承的乙個tf.keras.model這個類。

model比layer的功能更多,反過來說,layer的功能更精簡專一。

現在說一說上面的**和pytorch中的區別,作為乙個對比學習、也作為乙個對pytorch的回顧:

總之,學了pytorch之後,再看keras的話,對照的keras的api,很多東西都直接就會了,兩者的api越來越相似了。

上面最後輸出是(16, 56, 56, 32),輸入的是224

×224

224\times 224

224×22

4的維度,然後經過兩個最大池化層,就變成了56×56

56\times 56

56×56了。

到此為止,我們現在應該是可以用keras來構建模型了。

C 小白學指標2

內容依然來自於英文版的 c primer 小弟愚鈍 各路大神多多指教 demo int i 88 int r i 是引用符號 int p 表示p是乙個指標 p i 是address of 取位址 符號 p i 是dereference符號 int r2 p 是宣告的部分 是dereference符號...

小白學C語言基礎2

語句 1.順序執行語句 2.分支選擇語句 條件成立則執行 if 二者執行其一 if else 多種情況 if else if else if 多種情況擇其一 if else if else if else if注意事項 i 如果分支語句只有一條語句時 可以省略 但是不建議省略 ii if condi...

小白學Linux 實踐2(開機啟動)

vsftpd服務自啟動的三種方法 linux系統ftp工具是必備軟體,vsfpt是諸多ftp工具中最受站長歡迎,使用非常方便的工具之一。我們都不想系統重啟或者某些原因導致ftp不能正常工作,那麼將vsftp加入開機啟動是非常必要的。vsftpd有兩種啟動方式 自啟動或者由xinetd服務啟動,修改配...