TensorFlow2 0 自定義層與自定義網路

2021-09-27 01:37:25 字數 4231 閱讀 5043

自定義層需要繼承 layers.Layer,自定義網路需要繼承 keras.Model。

其內部需要定義兩個函式:

1、__init__初始化函式,內部需要定義構造形式;

2、call函式,內部需要定義計算形式及返回值。

#self def layer

class mydense(layers.Layer):  # inherit keras.layers.Layer
    """Custom fully-connected layer: out = inputs @ W + b.

    Weights are registered in __init__ via add_weight so the layer
    tracks them as trainable variables. No activation is applied here;
    the caller applies one (e.g. tf.nn.relu) after the layer.
    """

    def __init__(self, input_dim, output_dim):
        """Create a dense layer mapping input_dim features to output_dim."""
        super(mydense, self).__init__()
        # add_weight replaces the deprecated add_variable from early TF2.
        self.kernel = self.add_weight('w', [input_dim, output_dim])
        self.bias = self.add_weight('b', [output_dim])

    def call(self, inputs, training=None):
        """Compute the linear transform for a batch of inputs."""
        out = inputs @ self.kernel + self.bias
        return out

#self def network

class mymodel(keras.Model):  # inherit keras.Model
    """Six-layer MLP classifier (784 -> 512 -> 256 -> 128 -> 64 -> 32 -> 10).

    Built from the custom mydense layer; ReLU between hidden layers,
    raw logits from the final layer (pair with from_logits=True loss).
    """

    def __init__(self):
        """Instantiate the six dense layers."""
        super(mymodel, self).__init__()
        self.fc1 = mydense(input_dim=28 * 28, output_dim=512)
        self.fc2 = mydense(input_dim=512, output_dim=256)
        self.fc3 = mydense(input_dim=256, output_dim=128)
        self.fc4 = mydense(input_dim=128, output_dim=64)
        self.fc5 = mydense(input_dim=64, output_dim=32)
        self.fc6 = mydense(input_dim=32, output_dim=10)

    def call(self, inputs, training=None):
        """Forward pass; inputs.shape = [b, 28*28]. Returns [b, 10] logits."""
        x = tf.nn.relu(self.fc1(inputs))
        x = tf.nn.relu(self.fc2(x))
        x = tf.nn.relu(self.fc3(x))
        x = tf.nn.relu(self.fc4(x))
        x = tf.nn.relu(self.fc5(x))
        x = self.fc6(x)  # logits; no activation on the output layer
        return x

自定義的層和網路在使用上與正常一樣,並無任何區別。

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Sequential, datasets, layers, metrics, optimizers

def preprocess(x, y):
    """Map one MNIST sample to model-ready tensors.

    x: image, flattened to a 784-vector and scaled from [0, 255] to [0, 1].
    y: integer label, converted to a depth-10 one-hot vector.
    """
    x = tf.cast(tf.reshape(x, [-1]), dtype=tf.float32) / 255.
    y = tf.cast(tf.one_hot(y, depth=10), dtype=tf.int32)
    return x, y

# load data: MNIST train/validation splits as tf.data pipelines
(x_train, y_train), (x_val, y_val) = datasets.mnist.load_data()
print('data: ', x_train.shape, y_train.shape, x_val.shape, y_val.shape)

# tf.data.Dataset (class name is capitalized) feeds batches to fit().
db = tf.data.Dataset.from_tensor_slices((x_train, y_train))
db = db.map(preprocess).shuffle(60000).batch(128)

db_val = tf.data.Dataset.from_tensor_slices((x_val, y_val))
db_val = db_val.map(preprocess).batch(128)  # no shuffle for validation

#self def layer

class mydense(layers.Layer):  # inherit keras.layers.Layer
    """Custom fully-connected layer: out = inputs @ W + b.

    Weights are registered in __init__ via add_weight so the layer
    tracks them as trainable variables. No activation is applied here;
    the caller applies one (e.g. tf.nn.relu) after the layer.
    """

    def __init__(self, input_dim, output_dim):
        """Create a dense layer mapping input_dim features to output_dim."""
        super(mydense, self).__init__()
        # add_weight replaces the deprecated add_variable from early TF2.
        self.kernel = self.add_weight('w', [input_dim, output_dim])
        self.bias = self.add_weight('b', [output_dim])

    def call(self, inputs, training=None):
        """Compute the linear transform for a batch of inputs."""
        out = inputs @ self.kernel + self.bias
        return out

#self def network

class mymodel(keras.Model):  # inherit keras.Model
    """Six-layer MLP classifier (784 -> 512 -> 256 -> 128 -> 64 -> 32 -> 10).

    Built from the custom mydense layer; ReLU between hidden layers,
    raw logits from the final layer (pair with from_logits=True loss).
    """

    def __init__(self):
        """Instantiate the six dense layers."""
        super(mymodel, self).__init__()
        self.fc1 = mydense(input_dim=28 * 28, output_dim=512)
        self.fc2 = mydense(input_dim=512, output_dim=256)
        self.fc3 = mydense(input_dim=256, output_dim=128)
        self.fc4 = mydense(input_dim=128, output_dim=64)
        self.fc5 = mydense(input_dim=64, output_dim=32)
        self.fc6 = mydense(input_dim=32, output_dim=10)

    def call(self, inputs, training=None):
        """Forward pass; inputs.shape = [b, 28*28]. Returns [b, 10] logits."""
        x = tf.nn.relu(self.fc1(inputs))
        x = tf.nn.relu(self.fc2(x))
        x = tf.nn.relu(self.fc3(x))
        x = tf.nn.relu(self.fc4(x))
        x = tf.nn.relu(self.fc5(x))
        x = self.fc6(x)  # logits; no activation on the output layer
        return x

# Subclassed model: build() materializes weights for the given input shape
# so summary() can report parameter counts before training.
network = mymodel()
network.build(input_shape=[None, 28 * 28])
network.summary()

# Equivalent network using built-in Dense layers and Sequential
# (shadows the subclassed model above; the article shows both forms).
network = Sequential([
    layers.Dense(512, activation=tf.nn.relu),
    layers.Dense(256, activation=tf.nn.relu),
    layers.Dense(128, activation=tf.nn.relu),
    layers.Dense(64, activation=tf.nn.relu),
    layers.Dense(32, activation=tf.nn.relu),
    layers.Dense(10),  # logits; loss below uses from_logits=True
])
network.build(input_shape=[None, 28 * 28])
network.summary()

# training configuration: optimizer, loss, metrics
network.compile(optimizer=optimizers.Adam(learning_rate=1e-2),
                loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])

# run training with validation after every epoch
network.fit(db, epochs=20, validation_data=db_val, validation_freq=1)

tensorflow2 0 自定義層

無引數的自定義層可以使用 keras.layers.Lambda 函式,例如 customized_softplus = keras.layers.Lambda(lambda x: tf.nn.softplus(x));print(customized_softplus([1.0, 1.0, 1.0, 0.0, 0.1, 0.2]))...

tensorflow2 0學習筆記 自定義求導

tensorflow2.0建立神經網路模型,tensorflow近似求導與keras.optimizers結合使用,實現自定義求導,使得模型訓練更加靈活。tensorflow2.0學習筆記 應用tensorflow近似求導介紹tensorflow求導的基本用法。import matplotlib a...

tensorflow2 0視訊記憶體設定

遇到乙個問題 新買顯示卡視訊記憶體8g但是tensorflow執行的時候介面顯示只有約6.3g的視訊記憶體可用,如下圖 即限制了我的視訊記憶體,具體原因為什麼我也不知道,但原來的視訊記憶體小一些的顯示卡就沒有這個問題。目前的解決辦法是 官方文件解決 然後對應的中文部落格 總結一下,就是下面的兩個辦法...