tensorflow學習筆記1 匯出和載入模型

2022-08-05 12:39:12 字數 2040 閱讀 1676

用乙個非常簡單的例子學習匯出和載入模型;

寫乙個y=a*x+b的運算,然後儲存graph;

import tensorflow as tf

from tensorflow.python.framework.graph_util import convert_variables_to_constants

with tf.session() as sess:

a = tf.variable(5.0, name='a')

x = tf.variable(6.0, name='x')

b = tf.variable(3.0, name='b')

y = tf.add(tf.multiply(a,x),b, name="y")

tf.global_variables_initializer().run()

print (a.eval()) # 5.0

print (x.eval()) # 6.0

print (b.eval()) # 3.0

print (y.eval()) # 33.0

graph = convert_variables_to_constants(sess, sess.graph_def, ["y"])

#writer = tf.summary.filewriter("logs/", graph)

tf.train.write_graph(graph, 'models/', 'test_graph.pb', as_text=false)

執行

在models目錄下生成了test_graph.pb;

注:convert_variables_to_constants操作是將模型引數froze(儲存)進graph中,這時的graph相當於是sess.graph_def + checkpoint,即有模型結構也有模型引數;

只載入,獲取各個變數的值

import tensorflow as tf

from tensorflow.python.platform import gfile

with gfile.fastgfile("models/test_graph.pb", 'rb') as f:

graph_def = tf.graphdef()

graph_def.parsefromstring(f.read())

output = tf.import_graph_def(graph_def, return_elements=['a:0', 'x:0', 'b:0','y:0'])

#print(output)

with tf.session() as sess:

result = sess.run(output)

print (result)

執行看以看到原本儲存的結果(因為幾個變數都已經帶入模型,又從模型中載入了出來)

載入的時候修改變數值

5*2+3=13,結果正確

執行時修改變數值

載入時用乙個佔位符替掉x常量,在session執行時再給佔位符填值;

5*3+3=18,也正確

修改計算結果

偷偷把結果給改了會怎麼樣?

呵呵,不知原因為何;以後鑽進**了再說;

參考:

tensorflow學習筆記1

在跑minist demo時,遇到了這幾句 batchsize 6 label tf.expand dims tf.constant 0,2,3,6,7,9 1 index tf.expand dims tf.range 0,batchsize 1 concated tf.concat 1,inde...

TensorFlow學習筆記1

1 tensorflow 谷歌第二代人工智慧學習系統 2 tensorflow顧名思義tensor flow。tensor的意思是 張量,flow的意思是 流動,合起來就是 張量的流動 3 系統架構及程式設計模型。其中系統架構如圖1所示,程式設計模型如圖2所示。圖1 tensorflow系統架構圖 ...

TensorFlow學習筆記1

編寫tensorflow的兩個步驟 構建計算圖graph 使用session去執行graph中的operation 這裡寫描述 三個基本概念 rank rank一般是指資料的維度,其與線性代數中的rank不是乙個概念。其常 用rank舉例如下。shape 指tensor每個維度資料的個數,可以用py...