用numpy構造的乙個簡單BP

2021-08-28 17:10:36 字數 3273 閱讀 1407

# -*- coding: utf-8 -*-

"""created on thu oct 4 08:28:15 2018

@author: 37989

"""import numpy as np

import pandas as pd

from matplotlib import pyplot

# 標準化

def standard(x):

x_mean = x.mean(axis=0)

x_std = (((x - x_mean) ** 2).sum(axis=0) / len(x)) ** 0.5

return (x - x_mean) / x_std

# 歸一化

def maxmin(x):

x_min = x.min(axis=0)

x_max = x.max(axis=0)

return (x - x_min) / (x_max - x_min)

# 啟用函式

def sigmod(x):

return 1 / (1 + np.power(np.e, -x))

# 正切函式

def tanh(x):

return 1 - 2 / (np.power(np.e, 2 * x) + 1)

# 建立模型

# 預設隨機[-1,1]初始化權值閾值

def build(netnum):

w, q= ,

for i in range(len(netnum) - 1):

return w, q

# **模型

def predict(w, q, x):

for i in range(len(w)):

x = np.dot(x, w[i]) - q[i]

x = sigmod(x)

return x

# 訓練集,測試集劃分

def datasplit(x, y, testper):

testnum = int(len(x) * testper)

rand = np.random.permutation(len(x))

x_test = x[rand[:testnum]]

y_test = y[rand[:testnum]]

x_train = x[rand[testnum:]]

y_train = y[rand[testnum:]]

return x_train, y_train, x_test, y_test

if __name__ == '__main__':

# maxmin 歸一化

# sigmod 啟用函式

## 載入資料

data = pd.read_csv(r'f:\temp\python\scrapy\myspider\myspider\tensorflow\database\data_continuous.csv')

x_label = ['f1', 'f2', 'f3', 'f4']

y_label = ['r']

x = data[x_label].values

y = data[y_label].values

netnum = [len(x_label)]

# 相關引數設定

learnrate = 0.05 # 學習率

echos = 100 # 迭代次數

netnum += [10] # 隱層節點

testper = 0.3 # 測試集比例

# 標準化

x = maxmin(x)

y = maxmin(y)

# 建立模型

w, q = build(netnum)

x_train, y_train, x_test, y_test = datasplit(x, y, testper)

# 訓練

for num in range(echos):

for position in range(len(x_train)):

x = x_train[position].reshape(1,netnum[0])

y = y_train[position]

out_x = [x]

for i in range(len(w)):

# 隱層結果

temp = np.dot(x, w[i]) - q[i]

x = sigmod(temp)

# 最終輸出

output_y = out_x[-1]

# 更新

e = output_y * (1 - output_y) * (y - output_y)

for i in range(len(w)-1,-1,-1):

x = out_x[i]

n = netnum[i]

w[i] += learnrate * np.dot(x.reshape(n, 1), e)

q[i] -= learnrate * e

e = (w[i] * e).sum(axis=1) * x * (1 - x)

# 驗證網路模型

predict_y = predict(w, q, x_train)

mse = ((y_train - predict_y) ** 2).sum() / len(x_train)

print("第 %d 次訓練,均方誤差為:%f" % (num+1, mse))

# 模型評估

predict_y = predict(w, q, x_test)

mse = ((y_test - predict_y) ** 2).sum() / len(y_test)

print("測試集均方誤差為:%f" % mse)

pyplot.scatter(predict_y, y_test)

pyplot.show()

乙個簡單的bp回歸模型,

資料格式:

f1     f2     f3     f4  .......    r1    r2     r3

21.5  4.5     5       6  .......     2        5       4

可以構造多層神經網路,隱層結構和引數netnum中設定:

比如netnum = [8,6,9],那麼網路結構為  inputnum -> 8 -> 6 -> 9 -> outputnum

可以設定迭代次數 echos 、學習率learnrate、測試集比重testper

訓練效果一般,僅供練習**。

bp演算法的乙個簡單例子

視覺機器學習20講 中簡單講解了一下bp演算法的基本原理,公式推導看完後不是特別能理解,在網上找到乙個不錯的例子 bp演算法 error back propagation 對bp演算法的理解非常有幫助。於是為了加強記憶,將文中的示例 用python重新寫了一遍。使用梯度下降演算法不斷迭代更新引數w,...

用AJAX編寫乙個簡單的相簿

xml問題終於在今天還是解決了。最後在firefox裡還是使用了dom的一些老方法。我這裡就具體解釋一下方法吧.var xmlhttp 用來定義乙個xmlhttprequest物件 上面這段話是判斷當前瀏覽器版本,以定義給xmlhttp不同的xmlhttprequest物件.如果伺服器的響應沒有xm...

用redis做乙個簡單的秒殺

下面是乙個簡單的下單操作 include mmysql.class.php configarr host port user passwd dbname db new mmysql configarr sql select from sdb b2c products where product id...