nykergoto’s blog

機械学習とpythonをメインに

PRML の本読みをしています @section3

最近(というか今日から) 会社でPRML勉強会をやっています。ふつう第1章からやるのが普通ですがPRMLはちょっと重たいので息切れすると良くないよねということでいきなり第3章から始めるという方針をとっています。*1

今回は僕が担当で、主に

  • 線形モデルの導入
  • 正則化 (なんで Lasso はスパースになるのかの話)
  • バイアス・バリアンス

の話をしていました。たぶんこの話の中で一番一般的に使われるのは「バイアスバリアンス」の話だと思っていて、改めてまとめてみて個人的に良かったです。ざっと今日話した内容含めてまとめると以下のような感じかな?と思ってます。

f:id:dette:20191112005503p:plain

前提

  • $D$: データの集合
  • $E_D$ データ集合での期待値を取ったもの
  • $y$: 予測関数
  • $E_D[ y(x;D)]$ データの集合で期待値を取った予測関数。いろんなデータでモデルを作ってアンサンブルしてるイメージ
  • $h$: 理想的な予測値

バイアス (モデルの傾向)

  • データを沢山取ってアンサンブル平均した $E_D[ y(x;D)]$ が理想的な関数 $h$ にどれだけ近いか
  • ロジック自体の傾向が強い(biasが強い)と大きくなる
    • [低] 表現力が大きく多様性があるモデル
    • [高] 表現力が小さいモデル (いくら平均しても理想に近づけない)
      • たとえば常に 1 を返すようなモデル $y(x) = 1$ などはハイバイアスです。
      • 癖が強い(データに合わせる気がない)というイメージ

バリアンス (Variance: モデルの分散)

  • 今回のデータでの予測値と平均的なデータでの予測値の差分
  • 毎回のデータでどれぐらいばらつくか
    • [低] データによって予測が変わらないもの
      • たとえば学習を全くしない y=1 を常に返すアルゴリズムは Variance=0 です。反対にバイアスはめっちゃ大きい。
    • [高] データに依存して予測値が良く変わるモデル
      • たとえばパラメータ過多の最尤推定のようにデータに過剰に fitting してしまうアルゴリズムはハイバリアンスです

おまけ

バイアスバリアンスの説明に使われている Figure 3.5 にあたる絵を書く python script を書いてみました。

自分で書くと色々と考えるので楽しいですね。

import numpy as np

import matplotlib.pyplot as plt

def gaussian_kernel(x, basis=None):
    if basis is None:
        basis = np.linspace(-1.2, 1.2, 101)
    
    # parameter is my choice >_<
    phi = np.exp(- (x.reshape(-1, 1) - basis) ** 2 * 250)
    
    # add bias basis
    phi = np.hstack([phi, np.ones_like(phi[:, 0]).reshape(-1, 1)])    
    return phi

def estimate_ml_weight(x, t, lam, xx):
    basis = np.linspace(0, 1, 24)
    phi = gaussian_kernel(x, basis=basis)
    w_ml = np.linalg.inv(phi.T.dot(phi) + lam * np.eye(len(basis) + 1)).dot(phi.T).dot(t) # bias があるので +1 しています
    xx_phi = gaussian_kernel(xx, basis=basis)
    pred = xx_phi.dot(w_ml)
    return pred

n_samples = 100

fig, axes = plt.subplots(ncols=2, nrows=3, figsize=(10, 12), sharey=True, sharex=True)

for i, l in enumerate([2.6, -.31, -2.4]):
    ax = axes[i]
    preds = []
    for n in range(n_samples):
        x = np.random.uniform(0, 1, 40)
        xx = np.linspace(0, 1, 101)
        t = np.sin(x * 2 * np.pi) + .2 * np.random.normal(size=len(x))
        pred = estimate_ml_weight(x, t, lam=np.exp(l), xx=xx)
        
        if n < 20:
            ax[0].plot(xx, pred, c='black', alpha=.8, linewidth=1)

        preds.append(pred)

    ax[1].plot(xx, np.sin(2 * xx * np.pi), c='black', label=f'Lambda = {l}')
    ax[1].plot(xx, np.mean(preds, axis=0), '--', c='black')
    ax[1].legend()
    
fig.tight_layout()
fig.savefig('bias_variance.png', dpi=120)

f:id:dette:20191112002247p:plain
バイアス・バリアンスのやつ

*1:若干荒業感がありますが、一人以外は一度は読んだことがあるというのとやはり息切れが怖いのでちょっと先に応用が出てきそうなところから取り組んでいます