スチールウールの活動記録

初心者に向けて(私もだけど)電子工作やプログラミングや関連することをいろいろ。

多項式曲線フィッティング

黄色い本として有名な パターン認識機械学習(通称:PRML) を勉強しているので、できたらとりあえずブログにメモしていきたいと思います。細かくかかないと僕はわすれてしまうので無駄に書いていきたいと思います。

多項式曲線

1.1の例としてでてくる多項式曲線フィッティングをやったので書いておきます。
これは一般的には最小二乗法と呼ばれるものです。
N個の観測値 {x} からそれぞれに対応する {t} をもとに {w} のパラメータを求めて、できるだけデータ集合に当てはまる曲線を作ろうという話です。
まず以下の多項式を使ってデータのフィッティングを行います。

$$ y(x,\mathbf{w})=w_0+w_1x+w_2x^{2}+...+w_Mx^{M} = \sum_{j=0}^{M}w_jx^{j} $$

{M}多項式の次数で、係数 {w_0,...,w_M} をまとめて {\mathbf{w}} というベクトルとします。

{\mathbf{w}} は、関数 {y(x,\mathbf{w})}と訓練データの誤差関数を最小化すれば求められます。

\begin{equation} E(\mathbf{w})=\frac{1}{2}\sum_{n=1}^N{y(x_n,\mathbf{w})-t_n}^2 \end{equation}

これを最小化するには {\mathbf{w}} をそれぞれ偏微分して、0にすればできます。

例えば、M=3の場合、 {\mathbf{w}}は {{w_0,w_1,w_2,w_3}}の集合になります。
そして、(1)の式は

$$ y(x,\mathbf{w})=w_0+w_1x+w_2x^{2}+w_3x^{3} $$

になり,(2)の式を展開すると、

$$ \frac{1}{2}\sum_{n=1}^{N}{(w_0+w_1x_n+w_2x_n^{2}+w_3x_n^{3})-t_n}^{2} $$ となります。

これを {\mathbf{w}} のそれぞれで偏微分します。

\begin{align} \frac{\partial}{\partial w_0}\left(\frac{1}{2}\sum_{n=1}^{N}{(w_0+w_1x_n+w_2x_n^{2}+w_3x_n^{3})-t_n}^{2}\right) \\ = \sum_{n=1}^{N}{(w_0+w_1x_n+w_2x_n^{2}+w_3x_n^{3})-t_n} \\ \frac{\partial}{\partial w_1}\left(\frac{1}{2}\sum_{n=1}^{N}{(w_0+w_1x_n+w_2x_n^{2}+w_3x_n^{3})-t_n}^{2}\right) \\ =\sum_{n=1}^{N}x_n{(w_0+w_1x_n+w_2x_n^{2}+w_3x_n^{3})-t_n} \\ \frac{\partial}{\partial w_2}\left(\frac{1}{2}\sum_{n=1}^{N}{(w_0+w_1x_n+w_2x_n^{2}+w_3x_n^{3})-t_n}^{2}\right) \\ =\sum_{n=1}^{N}x_n^{2}{(w_0+w_1x_n+w_2x_n^{2}+w_3x_n^{3})-t_n} \\ \frac{\partial}{\partial w_3}\left(\frac{1}{2}\sum_{n=1}^{N}{(w_0+w_1x_n+w_2x_n^{2}+w_3x_n^{3})-t_n}^{2}\right) \\ =\sum_{n=1}^{N}x_n^{3}{(w_0+w_1x_n+w_2x_n^{2}+w_3x_n^{3})-t_n} \\ \end{align}

式を変形すると、

\begin{align} \sum_{n=1}^{N}(w_0+w_1x_n+w_2x_n^{2}+w_3x_n^{3}) &= \sum_{n=1}^{N}t_n \\ \sum_{n=1}^{N}x_n(w_0+w_1x_n+w_2x_n^{2}+w_3x_n^{3})&=\sum_{n=1}^{N}{x_nt_n} \\ \sum_{n=1}^{N}x_n^{2}(w_0+w_1x_n+w_2x_n^{2}+w_3x_n^{3})&=\sum_{n=1}^{N}{x_n^{2}t_n} \\ \sum_{n=1}^{N}x_n^{3}(w_0+w_1x_n+w_2x_n^{2}+w_3x_n^{3})&=\sum_{n=1}^{N}{x_n^{3}t_n} \\ \end{align}

となります。

さらにまとめると、

\begin{align} \sum_{j=0}^M\left(\left(\sum_{n=1}^Nx_n^j\right) w_j\right)&=\sum_{n=1}^N{t_n} \\ \sum_{j=0}^M\left(\left(\sum_{n=1}^Nx_nx_n^j\right) w_j\right)&=\sum_{n=1}^N{x_nt_n} \\ \sum_{j=0}^M\left(\left(\sum_{n=1}^Nx_n^2x_n^j\right) w_j\right)&=\sum_{n=1}^N{x_n^2t_n} \\ \sum_{j=0}^M\left(\left(\sum_{n=1}^Nx_n^3x_n^j\right) w_j\right)&=\sum_{n=1}^N{x_n^3t_n} \\ \end{align}

ここで

\begin{align} A_i &= \sum_{n=1}^N x_n^{(i+j)} \\ T_i &= \sum_{n=1}^N x_n^it_n \end{align}

とすると、{w_i}において

\begin{align} \sum_{j=0}^MA_{ij}w_j &= T_i \end{align}

となります。

M=3のとき

\begin{align} \sum_{j=0}^MA_{0j}w_j &= T_0 \\ \sum_{j=0}^MA_{1j}w_j &= T_1 \\ \sum_{j=0}^MA_{2j}w_j &= T_2 \\ \sum_{j=0}^MA_{3j}w_j &= T_3 \\ \end{align}

なので、

\begin{align} A_{00}w_0+A_{01}w_1+A_{02}w_2+A_{03}w_3 &= T_0 \\ A_{10}w_0+A_{11}w_1+A_{12}w_2+A_{13}w_3 &= T_1 \\ A_{20}w_0+A_{21}w_1+A_{22}w_2+A_{23}w_3 &= T_2 \\ A_{30}w_0+A_{31}w_1+A_{32}w_2+A_{33}w_3 &= T_3 \\ \end{align}

行列表現にすると、

\begin{align} \left[ \begin{array}{rrrr} A_{00} & A_{01} & A_{02} & A_{03} \\ A_{10} & A_{11} & A_{12} & A_{13} \\ A_{20} & A_{21} & A_{22} & A_{23} \\ A_{30} & A_{31} & A_{32} & A_{33} \\ \end{array} \right] \left[ \begin{array}{r} w_0 \\ w_1 \\ w_2 \\ w_3 \\ \end{array} \right] &= \left[ \begin{array}{r} T_0 \\ T_1 \\ T_2 \\ T_3 \\ \end{array} \right] \end{align}

となります。wを求めるには、

\begin{align} \left[ \begin{array}{r} w_0 \\ w_1 \\ w_2 \\ w_3 \\ \end{array} \right] &= \left[ \begin{array}{rrrr} A_{00} & A_{01} & A_{02} & A_{03} \\ A_{10} & A_{11} & A_{12} & A_{13} \\ A_{20} & A_{21} & A_{22} & A_{23} \\ A_{30} & A_{31} & A_{32} & A_{33} \\ \end{array} \right]^{-1} \left[ \begin{array}{r} T_0 \\ T_1 \\ T_2 \\ T_3 \\ \end{array} \right] \end{align}

つまり、AとTで方程式を解けば求めることができます。

実装

行列などが簡単に計算できて、グラフも出力できるPythonを使います。

# coding:utf-8
import numpy as np
import matplotlib.pyplot as plt

def calc_y(x, w):
    return np.array([np.sum([w[i] * (x[n] ** i) for i in xrange(w.size)]) for n in xrange(x.size)])

# wを計算
def estimate(x, t, M):
    A = np.array([[(x ** (i + j)).sum() for j in xrange(M + 1)] for i in xrange(M + 1)])
    T = np.array([(x ** (i) * t).sum() for i in xrange(M + 1)])
    return np.linalg.solve(A, T)


if __name__ == '__main__':
    # 多項式の次数
    M = 3
    
    # データ点を10個作る
    x_data = np.linspace(0, 1, 10)
    # 分散0.2のノイズを与える
    t_data = np.sin(2 * np.pi * x_data) + np.random.normal(0, 0.2, x_data.size)

    x_true = np.linspace(0, 1, 500)
    y_true = np.sin(2 * np.pi * x_true)

    w = estimate(x_data, t_data, M)
    y_data = calc_y(x_true, w)

    plt.plot(x_data, t_data, 'o')
    plt.plot(x_true, y_true, 'g')
    plt.plot(x_true, y_data, 'r')

    plt.show()

f:id:bhjkkk:20160207062654p:plain

緑が正しいsin波形、赤がデータ点によってフィッティングしたものです。

うまくフィッティングができていると思います。

罰金項

罰金項を付け加えることで、係数が大きな値になることをある程度防ぐことができるようになります。
罰金項で最も単純な物は、係数を2乗して和をとったもので、誤差関数が

$$ \tilde{E}(\mathbf{w})=\frac{1}{2}\sum_{n=1}^N{y(x_n,\mathbf{w})-t_n}^2+\frac{\lambda}{2}|\mathbf{w}|^2 $$

となります。

{|\mathbf{w}|^2\equiv\mathbf{w}^{\mathbf{T}}\mathbf{w}=w_0^2+w_1^2+...+w_M^2} で、{\lambda}正則化項と二乗誤差の和の項との相対的な重要度の調節をします。
これを加えて{\mathbf{w}}を計算するには、同じように{\tilde{E}(\mathbf{w})}偏微分して0になるようにします。

まず、{\displaystyle \frac{\lambda}{2}|\mathbf{w}|^2}に注目します。
M=3のときを考えると、

\begin{align} \frac{\lambda}{2}(w_0^2+w_1^2+w_2^2+w_3^2) \end{align}

です。これを偏微分すると

\begin{align} \frac{\partial}{\partial w_0}&\left(\frac{\lambda}{2}(w_0^2+w_1^2+w_2^2+w_3^2)\right) \\ &=\lambda w_0 \\ \frac{\partial}{\partial w_1}&\left(\frac{\lambda}{2}(w_0^2+w_1^2+w_2^2+w_3^2)\right) \\ &=\lambda w_1 \\ \frac{\partial}{\partial w_2}&\left(\frac{\lambda}{2}(w_0^2+w_1^2+w_2^2+w_3^2)\right) \\ &=\lambda w_2 \\ \frac{\partial}{\partial w_3}&\left(\frac{\lambda}{2}(w_0^2+w_1^2+w_2^2+w_3^2)\right) \\ &=\lambda w_3 \\ \end{align}

になります。つまり、上記の {A_{ij}}{i=j} のときに罰金項の偏微分の値が足されるので、

$$ A_{ij} = \begin{cases} \displaystyle \lambda+\sum_{n=1}^N x_n^{(i+j)} & (i=j) \\ \displaystyle \sum_{n=1}^N x_n^{(i+j)} & (i!=j) \end{cases} $$

となります。

Pythonのコードに以下の関数を追加します。
PRMLには {\ln{\lambda}=-18} とあるので、{\lambda=\mathrm{e}^{-18}} にします。

# 罰金項ありでwを計算
def estimate_la(x, t, M):
    la = np.exp(-18)
    A = np.array([[(x ** (i + j)).sum() if i != j else la + (x ** (i + j)).sum() for j in xrange(M + 1)] for i in
                  xrange(M + 1)])
    T = np.array([(x ** (i) * t).sum() for i in xrange(M + 1)])
    return np.linalg.solve(A, T)

f:id:bhjkkk:20160207090155p:plain

緑が正しい線、赤がただの二乗誤差関数による線、青が罰金項を付け加えた線です。

M=9でやっていますが、過学習が抑えられていい感じになってますね。