用TensorFlow實現簡單線性回歸

2021-09-24 05:16:32 字數 2382 閱讀 2917

使用tensorflow 構造乙個神經元,簡單的線性回歸網路。

問題:現有一組有雜訊的樣本資料,共2000個,每乙個樣本 x 有 3 個特徵, 對應乙個標籤 y 值。從資料樣本中學習 y=w

×x+b

y=w\times x + b

y=w×x+

b 中的引數

首先我們來生成樣本資料,w_real 和 b_real 是控制樣本資料的引數的真實值,

x_data = np.random.randn(

2000,4

)w_real =

[0.2

,0.3

,0.1

,0.3

]b_real =

-0.3

noise = np.random.randn(1,

2000)*

0.1y_data = np.matmul(w_real, x_data.t)

+ b_real + noise

編寫神經網路

下面會用到的 tensorflow api

官方tensorflow 文件

全部源**實現:

import tensorflow as tf

import numpy as np

import matplotlib.pyplot as plt

# 建立資料模擬

x_data = np.random.randn(

2000,4

)w_real =

[0.2

,0.3

,0.1

,0.3

]b_real =

-0.3

noise = np.random.randn(1,

2000)*

0.1y_data = np.matmul(w_real, x_data.t)

+ b_real + noise

# 清除預設圖中的內容

tf.reset_default_graph(

)# 設定步數

num_step =

10# 學習率

learning_rate =

0.5# 建立圖

g = tf.graph(

)# 儲存wb

wb_sess =

with g.as_default():

# x, y_true 佔位符

x = tf.placeholder(tf.float32, name =

'x')

y_true = tf.placeholder(tf.float32, name =

'y_true'

)# w, b 變數

w = tf.variable([[

0,0,

0,0]

], dtype = tf.float32, name =

'w')

b = tf.variable(

0, dtype = tf.float32, name =

'b')

# **值 y = w * x + b

y_pred = tf.add(tf.matmul(w, tf.transpose(x)

), b, name =

'y_pred'

)# 損失 計算成員平均值

loss = tf.reduce_mean(tf.square(y_true - y_pred)

, name =

'loss'

)# 優化器,sgd

optimizer = tf.train.gradientdescentoptimizer(learning_rate, name=

'sgd'

) train = optimizer.minimize(loss, name =

'train'

)# 全域性初始化節點

init = tf.global_variables_initializer(

)with tf.session(

)as sess:

sess.run(init)

for step in

range

(num_step)

: sess.run(train,

)

[w, b]))

if(step %5==

0):print

(step +

1, sess.run(

[w, b]))

print

(num_step, sess.run(

[w, b]

))

總結:

這只會讓你了解tensorflow的一些api 特性,加強使用這些api,簡單模型。

用TensorFlow實現iris資料集線性回歸

本文將遍歷批量資料點並讓tensorflow更新斜率和y截距。這次將使用scikit learn的內建iris資料集。特別地,我們將用資料點 x值代表花瓣寬度,y值代表花瓣長度 找到最優直線。選擇這兩種特徵是因為它們具有線性關係,在後續結果中將會看到。本文將使用l2正則損失函式。用tensorflo...

用Tensorflow完成簡單的線性回歸模型

思路 在資料上選擇一條直線y wx b,在這條直線上附件隨機生成一些資料點如下圖,讓tensorflow建立回歸模型,去學習什麼樣的w和b能更好去擬合這些資料點。1 隨機生成1000個資料點,圍繞在y 0.1x 0.3 周圍,設定w 0.1,b 0.3,屆時看構建的模型是否能學習到w和b的值。imp...

tensorflow實現簡單的卷積網路

import tensorflow as tf import gc from tensorflow.examples.tutorials.mnist import input data mnist input data.read data sets f zxy python mnist data o...