tensorflow基礎知識3,回歸例子

2021-09-10 12:40:41 字數 1857 閱讀 6923

1、線性回歸例子

import tensorflow as tf

import numpy as np

import matplotlib.pyplot as plt

#使用numpy生成200個隨機點

x_data=np.linspace(-0.5,0.5,2000)[:,np.newaxis]#在-0.5到0.5之間生成200個隨機數,裡是增加乙個維度,200行1列

noise=np.random.normal(0,0.2,x_data.shape)#干擾項,從正態分佈中輸出隨機值,shape是檢視矩陣或者陣列的維數,生成與x_data陣列一樣的形狀

y_data=np.square(x_data)+noise#得到乙個u形影象

#定義兩個placeholder佔位符

x=tf.placeholder(tf.float32,[none,1])#行不確定,列是1

y=tf.placeholder(tf.float32,[none,1])

#定義神經網路中間層

weights_l1=tf.variable(tf.random_normal([1,10]))#權值,定義乙個隨機變數,隨機賦值,形狀是1行10列,權值是連線輸入層和中間層,輸入層是1個神經元,中間層是10個神經元

biases_l1=tf.variable(tf.zeros([1,10]))#偏置值,初始化為0,注意zero函式的用途

wx_plus_b_l1=tf.matmul(x,weights_l1)+biases_l1#matmul是矩陣相乘,中間層訊號的總和

l1=tf.nn.tanh(wx_plus_b_l1)#l1相當於中間層的輸出,雙曲正切函式tanh作為啟用函式.啟用函式可以使輸出一定非線性函式

#定義神經網路輸出層

weights_l2=tf.variable(tf.random_normal([10,1]))#中間層是10個神經元,輸出層是1個神經元,所以weights_l2是10行1列的矩陣

biases_l2=tf.variable(tf.zeros([1,1]))

wx_plus_b_l2=tf.matmul(l1,weights_l2)+biases_l2#輸出層的輸入就是中間層的輸出l1,輸出層訊號的總和

prediction=tf.nn.tanh(wx_plus_b_l2)

#二次代價函式

loss=tf.reduce_mean(tf.square(y-prediction))

#使用梯度下降法訓練,得到最小化代價函式值

train_step=tf.train.gradientdescentoptimizer(0.1).minimize(loss)

with tf.session() as sess:

#變數初始化,只要有變數就一定要初始化

sess.run(tf.global_variables_initializer())

for _ in range(2000):#訓練2000次

sess.run(train_step,feed_dict=)#feed的資料以字典的形式傳入

#獲得**值

prediction_value=sess.run(prediction,feed_dict=)#此處的prediction是上面經過2000次訓練得到的最優解

#畫圖plt.figure()

plt.scatter(x_data,y_data)#用散點圖的方式把樣本點顯示出來

plt.plot(x_data,prediction_value,'r-',lw=5)#紅色實現,線寬為5

tensorflow 基礎知識

變數不執行,只列印行列引數,sess.run a 列印出變數實際內容 y int x0 x1 1 for x0,x1 in x 這樣能夠判斷x生成的資料兩個數的和小於1,那麼y就是1,否則是0 rng np.random.randomstate seed x rng.rand 32,2 生成32行2...

TensorFlow基礎知識

從helloword開始 mkdir mooc 新建乙個mooc資料夾 cd mooc mkdir 1.helloworld 新建乙個helloworld資料夾 cd 1.helloworld touch helloworld.py coding utf 8 引入 tensorflow 庫 impo...

TensorFlow基礎知識

title tensorflow基礎知識 date 2018 03 31 14 13 12 categories tensorflow是乙個採用資料流圖 data flow graphs 用於數值計算的開源軟體庫。下圖就是乙個資料流圖。資料流圖是乙個用來描述數學計算的由 結點 nodes 和 線 e...