2020年5月21日木曜日

最小二乗法

講義で使える統計素材」シリーズ.今回は,最小二乗法の説明に必要なデータ点とフィットした関数の図を掲載しています. 関数として直線と$\sin$関数を例に挙げました.

In [1]:
import numpy as np
import matplotlib.pyplot as plt
#import scipy.special as sp
from scipy.optimize import curve_fit
from mpl_toolkits.mplot3d import Axes3D

データ $y=ax+b$

np.random.normalは議事乱数なので実行するたびに結果が異なって面倒な場合はseed関数を指定する.中の数字は何でも良い.

In [2]:
delta = 1
np.random.seed(1)
x = np.arange(0, 10, delta)
y = 0.5*x+2 + np.random.normal(0, 0.5, x.shape[0])
plt.xlim(0, 10)
plt.ylim(0, 10)
plt.gca().set_aspect('equal', adjustable='box')

plt.plot(x, y, linestyle='None', marker='.', markersize=10)

plt.savefig('plot_out.svg', transparent=True)

最小二乗法 結果 $y=ax+b$

In [3]:
def err(x, a,b):
    return a*x+b

popt, pcov = curve_fit(err, x, y, p0=[1, 0])
print(popt)

delta=1/10
x_ = np.arange(0, 10, delta)
y_ = popt[0]*x_ + popt[1]
plt.xlim(0, 10)
plt.ylim(0, 10)
plt.gca().set_aspect('equal', adjustable='box')

plt.plot(x_, y_)
plt.plot(x, y, linestyle='None', marker='.', markersize=10)

plt.savefig('plot_out.svg', transparent=True)
[0.48113135 2.03633847]
In [4]:
def err(x, a,b):
    return a*x+b

popt, pcov = curve_fit(err, x, y, p0=[1, 0])
print(popt)

delta=1/10
x_ = np.arange(0, 10, delta)
y_ = popt[0]*x_ + popt[1]
plt.xlim(3, 7)
plt.ylim(3, 7)
plt.gca().set_aspect('equal', adjustable='box')

plt.plot(x_, y_)
plt.plot(x, y, linestyle='None', marker='.', markersize=10)
plt.vlines(np.mean(x_), ymin=0,ymax=10,linestyle="dashed")
plt.hlines(np.mean(y_), xmin=0,xmax=10,linestyle="dashed")

plt.savefig('plot_out.svg', transparent=True)
[0.48113135 2.03633847]

データ $\sin(x)$

In [5]:
delta = 1/2
x = np.arange(0, 2*np.pi, delta)
y = np.sin(x) + np.random.normal(0, 0.3, x.shape[0])
plt.plot(x, y, linestyle='None', marker='.', markersize=10)

plt.savefig('plot_out.svg', transparent=True)

最小二乗法 結果 $\sin(x)$

In [6]:
def func(x,a,b,c):
    return a*np.sin(b*x-c)

popt, pcov = curve_fit(func, x, y, p0=[1, 1, 0])
print(popt)

delta=1/10
x_ = np.arange(0, 2*np.pi, delta)
y_ = popt[0]*np.sin(popt[1]*x_-popt[2])
plt.plot(x_, y_)
plt.plot(x, y, linestyle='None', marker='.', markersize=10)

plt.savefig('plot_out.svg', transparent=True)
[0.95609538 1.04887765 0.05291554]