統計學第一章 最小二乘擬合正弦函式,正則化

2021-09-10 23:13:30 字數 2594 閱讀 1239

#coding:utf-8

import numpy as np

import scipy as sp

from scipy.optimize import leastsq

import matplotlib.pyplot as plt

# 目標函式

def real_func(x):

return np.sin(2*np.pi*x)

# 多項式

def fit_func(p, x):

f = np.poly1d(p)

# print('f=',f)

return f(x)

# 殘差

def residuals_func(p, x, y):

ret = fit_func(p, x) - y

return ret

# 十個點

x = np.linspace(0, 1, 10)

x_points = np.linspace(0, 1, 1000)

# 加上正態分佈噪音的目標函式的值

y_ = real_func(x)

y = [np.random.normal(0, 0.1) + y1 for y1 in y_]

def fitting(m=0):

"""m 為 多項式的次數

"""# 隨機初始化多項式引數

p_init = np.random.rand(m + 1)

# 最小二乘法

p_lsq = leastsq(residuals_func, p_init, args=(x, y))

print('fitting parameters:', p_lsq[0])

## 視覺化

plt.plot(x_points, real_func(x_points), label='real')

plt.plot(x_points, fit_func(p_lsq[0], x_points), label='fitted curve')

plt.plot(x, y, 'bo', label='noise')

plt.legend()

plt.show()

return p_lsq

# m=0

p_lsq_0 = fitting(m=0)

# m=1

p_lsq_1 = fitting(m=1)

# m=3

p_lsq_3 = fitting(m=3)

# m=9

p_lsq_9 = fitting(m=9)

m分別為0,1,3,9時的多項式係數。 

m=0,即多項式為常數時 

m=1, 即多項式為一次項時

m=3,即多項式為三次項時,可看出擬合的比較不錯

m=9時,可看出過擬合了

引入正則化

#加入正則

regularization = 0.0001

def residuals_func_regularization(p, x, y):

ret = fit_func(p, x) - y

return ret

# 最小二乘法,加正則化項

p_init = np.random.rand(9+1)

p_lsq_regularization = leastsq(residuals_func_regularization, p_init, args=(x, y))

plt.plot(x_points, real_func(x_points), label='real')

plt.plot(x_points, fit_func(p_lsq_9[0], x_points), label='fitted curve')

plt.plot(x_points, fit_func(p_lsq_regularization[0], x_points), label='regularization')

plt.plot(x, y, 'bo', label='noise')

plt.legend()

plt.show()

可看出:正則化有效 

統計學 第一章導論

什麼是統計學?統計學是一門收集 整理 顯示和分析解釋資料並從資料中得出結論的科學。通俗的講,統計就是利用資料,讓資料本身說話,根據資料建立模型從而得出結論。學統計有什麼用?統計學的分類 統計學大體上可以分為描述性統計和推斷統計。描述性統計主要研究資料的收集,整理,彙總,圖表描述,通過給出一些統計量 ...

統計學習方法 第一章

1.統計學習的特點 2.統計學習的物件 對資料的基本假設 同類資料具有一定的統計規律性 3.統計學習的目的 4.統計學習方法 1.基本概念 x x 1,x 2,x i x n t x i x i 1 x i 2 x in t t x 1 y 1 x 2 y 2 x n y n 2.聯合概率分布 3....

統計學習方法第一章

1.numpy.poly1d 1,2,3 import numpy as np np.poly1d 1 2,3 poly1d 1 2,3 r np.poly1d 1 2,3 print r 1 62.from scipy.optimize import leastsq 表示scipy.optimiz...