PRML の本読みをしています @section3
最近(というか今日から) 会社でPRML勉強会をやっています。ふつう第1章からやるのが普通ですがPRMLはちょっと重たいので息切れすると良くないよねということでいきなり第3章から始めるという方針をとっています。*1
今回は僕が担当で、主に
- 線形モデルの導入
- 正則化 (なんで Lasso はスパースになるのかの話)
- バイアス・バリアンス
の話をしていました。たぶんこの話の中で一番一般的に使われるのは「バイアスバリアンス」の話だと思っていて、改めてまとめてみて個人的に良かったです。ざっと今日話した内容含めてまとめると以下のような感じかな?と思ってます。
前提
- $D$: データの集合
- $E_D$ データ集合での期待値を取ったもの
- $y$: 予測関数
- $E_D[ y(x;D)]$ データの集合で期待値を取った予測関数。いろんなデータでモデルを作ってアンサンブルしてるイメージ
- $h$: 理想的な予測値
バイアス (モデルの傾向)
- データを沢山取ってアンサンブル平均した $E_D[ y(x;D)]$ が理想的な関数 $h$ にどれだけ近いか
- ロジック自体の傾向が強い(biasが強い)と大きくなる
- [低] 表現力が大きく多様性があるモデル
- [高] 表現力が小さいモデル (いくら平均しても理想に近づけない)
- たとえば常に 1 を返すようなモデル $y(x) = 1$ などはハイバイアスです。
- 癖が強い(データに合わせる気がない)というイメージ
バリアンス (Variance: モデルの分散)
- 今回のデータでの予測値と平均的なデータでの予測値の差分
- 毎回のデータでどれぐらいばらつくか
おまけ
バイアスバリアンスの説明に使われている Figure 3.5 にあたる絵を書く python script を書いてみました。
自分で書くと色々と考えるので楽しいですね。
import numpy as np import matplotlib.pyplot as plt def gaussian_kernel(x, basis=None): if basis is None: basis = np.linspace(-1.2, 1.2, 101) # parameter is my choice >_< phi = np.exp(- (x.reshape(-1, 1) - basis) ** 2 * 250) # add bias basis phi = np.hstack([phi, np.ones_like(phi[:, 0]).reshape(-1, 1)]) return phi def estimate_ml_weight(x, t, lam, xx): basis = np.linspace(0, 1, 24) phi = gaussian_kernel(x, basis=basis) w_ml = np.linalg.inv(phi.T.dot(phi) + lam * np.eye(len(basis) + 1)).dot(phi.T).dot(t) # bias があるので +1 しています xx_phi = gaussian_kernel(xx, basis=basis) pred = xx_phi.dot(w_ml) return pred n_samples = 100 fig, axes = plt.subplots(ncols=2, nrows=3, figsize=(10, 12), sharey=True, sharex=True) for i, l in enumerate([2.6, -.31, -2.4]): ax = axes[i] preds = [] for n in range(n_samples): x = np.random.uniform(0, 1, 40) xx = np.linspace(0, 1, 101) t = np.sin(x * 2 * np.pi) + .2 * np.random.normal(size=len(x)) pred = estimate_ml_weight(x, t, lam=np.exp(l), xx=xx) if n < 20: ax[0].plot(xx, pred, c='black', alpha=.8, linewidth=1) preds.append(pred) ax[1].plot(xx, np.sin(2 * xx * np.pi), c='black', label=f'Lambda = {l}') ax[1].plot(xx, np.mean(preds, axis=0), '--', c='black') ax[1].legend() fig.tight_layout() fig.savefig('bias_variance.png', dpi=120)
*1:若干荒業感がありますが、一人以外は一度は読んだことがあるというのとやはり息切れが怖いのでちょっと先に応用が出てきそうなところから取り組んでいます