Tensorflow 變數的共享

2022-07-07 03:54:19 字數 2392 閱讀 7725

tensorflow-exp/example/sparse-tensor-classification/train-validate.py

當你需要train的過程中validate的時候,如果用placeholder來接收輸入資料

那麼乙個compute graph可以完成這個任務。如果你用的是tfrecord的方式

輸入嵌入到compute graph,那麼對應input(for train), input_1(for validate),就會產生兩個compute graph,但是要注意的是validate過程中需要share使用等同於train過程的w_h等變數,如果直接build兩次graph就回闡釋下面的示意圖

這種並沒有共享 w_h等資料,因此validate 會有問題(注意input_1裡面對應的w_h_1)

cost, accuracy=build_graph(x, label)

_, accuracy_test=build_graph((index_test, value_test), label_test)

train_op=gen_optimizer(cost, flags.learning_rate)

#train_op_test = gen_optimizer(cost_test, flags.learning_rate)

來自 <>

這裡tf.get_variable_scope().reuse_variables()並不起作用,因為build_graph裡面並沒有使用ge_variable機制

第一種解決方案

用類 self.w_h

解決此類問題的方法之一就是使用類來建立模組,在需要的地方使用類來小心地管理他們需要的變數. 乙個更高明的做法,不用呼叫類,而是利用tensorflow 提供了變數作用域

機制,當構建乙個檢視時,很容易就可以共享命名過的變數.

來自 <>

使用類的方式,共享w_h等變數

class mlp(object):

def __init__(self):

hidden_size = 200

num_features = num_features

num_classes = num_classes

with tf.device('/cpu:0'):

self.w_h = init_weights([num_features, hidden_size], name = 'w_h')

self.b_h = init_bias([hidden_size], name = 'b_h')

self.w_o = init_weights([hidden_size, num_classes], name = 'w_o')

self.b_o = init_bias([num_classes], name = 'b_o')

def model(self, x, w_h, b_h, w_o, b_o):

h = tf.nn.relu(matmul(x, w_h) + b_h)

return tf.matmul(h, w_o) + b_o

def forward(self, x):

py_x = self.model(x, self.w_h, self.b_h, self.w_o, self.b_o)

return py_x

x = (index, value)

algo = mlp()

cost, accuracy = build_graph(x, label, algo)

cost_test, accuracy_test = build_graph((index_test, value_test), label_test, algo)

train_op = gen_optimizer(cost, flags.learning_rate)

類似這種做法的例子tensorflow/tensorflow/models/embedding/word2vec.py

第二中變數共享

變數作用域機制在

tensorflow

中主要由兩部分組成:

方法 tf.get_variable()

用來獲取或建立乙個變數,而不是直接呼叫

tf.variable

.它採用的不是像

`tf.variable

這樣直接獲取值來初始化的方法

.乙個初始化就是乙個方法,建立其形狀並且為這個形狀提供乙個張量

.這裡有一些在

tensorflow

中使用的初始化變數:

**train-validate-share.py

來自 <>

tensorflow 共享變數

import tensorflow as tf 設定隨機種子,使得每次隨機初始化都一樣 tf.set random seed 1234 這是我們要共享的變數函式 def share variable input weight tf.get variable weight 2,2 return wei...

Tensorflow共享變數

使用variable宣告變數,同名變數的name後會自動加 1,可以賦初始值,但是需要在session初始化後才會生效。import tensorflow as tf var1 tf.variable 1.0,name firstvar print var1 var1.name var1 tf.va...

Tensorflow 變數的共享

tensorflow exp example sparse tensor classification train validate.py 當你需要train的過程中validate的時候,如果用placeholder來接收輸入資料 那麼乙個compute graph可以完成這個任務。如果你用的是t...