用一個非常簡單的例子學習匯出和載入模型;
寫一個y=a*x+b的運算,然後儲存graph;
import tensorflow as tf執行from tensorflow.python.framework.graph_util import convert_variables_to_constants
with tf.session() as sess:
a = tf.variable(5.0, name='a')
x = tf.variable(6.0, name='x')
b = tf.variable(3.0, name='b')
y = tf.add(tf.multiply(a,x),b, name="y")
tf.global_variables_initializer().run()
print (a.eval()) # 5.0
print (x.eval()) # 6.0
print (b.eval()) # 3.0
print (y.eval()) # 33.0
graph = convert_variables_to_constants(sess, sess.graph_def, ["y"])
#writer = tf.summary.filewriter("logs/", graph)
tf.train.write_graph(graph, 'models/', 'test_graph.pb', as_text=false)

在models目錄下生成了test_graph.pb;
注:convert_variables_to_constants操作是將模型引數froze(儲存)進graph中,這時的graph相當於是sess.graph_def + checkpoint,即有模型結構也有模型引數;
只載入,獲取各個變數的值
import tensorflow as tf執行看以看到原本儲存的結果(因為幾個變數都已經帶入模型,又從模型中載入了出來)from tensorflow.python.platform import gfile
with gfile.fastgfile("models/test_graph.pb", 'rb') as f:
graph_def = tf.graphdef()
graph_def.parsefromstring(f.read())
output = tf.import_graph_def(graph_def, return_elements=['a:0', 'x:0', 'b:0','y:0'])
#print(output)
with tf.session() as sess:
result = sess.run(output)
print (result)
載入的時候修改變數值
5*2+3=13,結果正確
執行時修改變數值
載入時用一個佔位符替掉x常量,在session執行時再給佔位符填值;
5*3+3=18,也正確
修改計算結果
偷偷把結果給改了會怎麼樣?
呵呵,不知原因為何;以後鑽進**了再說;
參考: