nykergoto’s blog

機械学習とpythonをメインに

ニューラルネットは何を見ているか ~ ブラックボックスモデルの解釈可能性

ディープラーニングによる予測は特に画像分野において顕著な性能を示していることはご案内のとおりです。これは ResNet や BatchNormalization といった技術の開発により多数のレイヤが重なった大きなモデルに対しても学習を行うことが可能になったことが理由の一つです。

一方で、ネットワーク構造が複雑になってしまったがゆえに別の問題も生じています。 それはなぜディープラーニングが性能が良いのか、どうやって判別を行っているのかということについてよくわからない、という問題です。 最近では人間が見ると元の画像とはほとんど変わらないのに出力結果が大きく変化してしまう画像(敵対的サンプル)を生成できてしまうことを示した論文なども登場しています。

というわけで今回は、内部でどうなっているかよくわからないモデルに対して、その解釈性を与えることに関する論文 Interpretable Explanations of Black Boxes by Meaningful Perturbation を紹介します。

解釈性とは何か

さて僕はさらっと解釈性、という単語を用いましたが、一体解釈性とは何でしょうか。ここで中身がどうなっているかわからないブラックボックスな予測モデル $f$ を以下のように定義します。

$$ f: X \to Y $$

ここで $X$ は入力値の取りうる空間, $Y$ は出力の空間です. 例えば画像の二値分類であれば入力は rgb の画像 $x : {\cal A} \times \mathbb{R}^3$ となります。ここで ${\cal A} = \left\{1, \ldots, H \right\} \times \left\{1, \ldots, W \right\}$ で $H, W \in \mathbb{N}$ は画像の縦横のピクセル数です. 一方出力 $Y$ は $\left\{-1, +1\right\}$ のラベルとなります.

解釈性とは特定の入力画像にたいしての出力が予測するルールのこと、と表現することができます。 例えば以下のような条件を満たす $Q_1$ というルールを得ることができた!とします。

$$ Q_1(x ;f) = \left\{ x \in X_c \iff f(x) = +1 \right\} $$

ここで $X_c \subset X$ は入力画像のうち $+1$ を取る画像集合であり, $X$ の部分集合です. このルール $Q_1$ の要素が $1$ の値を取る画像であれば $f$ も 1 を返し, 逆も成り立ちます。

以下の期待値を考えること、このルールがどの程度正しいのかを知ることも可能です.

$$ {\cal L_1} = \mathbb{E}_{x \sim p(x)} [1 - \delta_{Q_1}] $$

ここで $\delta_Q$ は $Q$ に$x$ が含まれているときに 1 その他のとき 0 を返す標示関数です.

${\cal L_1}$ は $Q_1$ が成立しない確率です。すなわち $f$ が $+1$ の画像に対して正しく分類ができてるの? という解釈性になっている、というわけです。そしてなぜ $f$ に仮定をおかず解釈性の議論をすることが出来るかというと条件の左側 $x \in X_c$ という関係が明瞭であるからなのです。

同じような形式で、出力同士を比べることでも、解釈性を考えることは可能です。 例として, $f$ がある回転角度 $\theta$ に対して不変な出力を持つのかということを確かめたいとします. この時のルール $Q_2$ は以下のようになります.

$$ Q_2(x, x'; f) = \left\{ x \sim_{\theta} x' \Rightarrow f(x) = f(x') \right\} $$

先と同様に期待値を取ればルールの正しさを知ることができます。

$$ {\cal L_2} = \mathbb{E}[1 - \delta_{Q_2}] $$

この ${\cal L_2}$ は $f$ が $\theta$ の回転に対して出力が変わらないのか? という解釈がどの程度成立するか、ということの表現になっています。

ではこれを拡張して, 画像の特定の領域を削ってしまうような関数 $h: X \to X$ があるとします。 そしてこの $h$ が中身がよくわからない $f$ の出力に関与しているのか? ということを知りたいとしましょう。

先程までの議論同様に考えると, 以下のようなルール $Q_3$ を考えれば良いことがわかります.

$$ Q_3(x; f, h) = \left\{ x' = h(x) \Rightarrow f(x) \neq f(x') \right\} $$

このルールが成立する確率が高くなるような $h$ が削っている領域が $f$ の出力をより決定している領域です。 ざっくりと言ってしまうと $f$ がよく見ている部分 とでも言うことが出来るでしょうか。

ルールが決まったのであとはロスを考えれば良いのですが、$h$ が極端な値を取られると困ります。例えば画像を 1px 間隔で削るような関数があったとします。自然に考えて人間はそのように画像を解釈していませんから、解としてはあまりよろしくないですよね。このような不自然な削り方に対してペナルティを与えるため一般的な機械学習同様に $h$ に対して正則化の効果を与えるような項を付け加えます。

結局以下の最小化問題を解けば良いことになります。

$$ \min_{h \in {\cal H}} {\cal L}_3 := \mathbb{E}_{x \sim p}[1 - \delta_{Q_3(x; h, f)}] + {\cal R} (h) $$

ここで ${\cal H}$ は関数 $h$ が取りうる範囲を決める関数集合, ${\cal R}: {\cal H} \to \mathbb{R}$ はペナルティを計算する正則化を計算する関数です。

論文での目的関数

前置きが長くなりましたが論文の話に戻ります。

論文中では $h$ をどの領域をどのぐらい削るかで表現するために画像の各ピクセルに対して 0 から 1 の値を取るような $m: {\cal A} \to [0, 1]$ を定義し以下のような関数 $\Phi: X \to X$ を用いて画像情報を削除しています。

$$ [\Phi(x_0; m)](u) = m(u) x_0(u) + (1- m(u))\mu_0 $$

ここで $u \in {\cal A}$ は画像の各要素を表しており, $\mu_0$ は平均の色の値を表します。 これを用いたロス関数は以下のようになっています。 論文ではこれの他にガウスノイズを乗せたものや ぼかしフィルタによる定式化も提案されていますが、ノリとしてはほぼ一緒です。

これを用いて目的関数を記述すると以下のようになります。

$$ {\cal L}(x) = \lambda_1 \| 1 - m\|_1 + \lambda_2 \sum_{u \in X} \| \nabla m(u) \| + \mathbb{E}_{\tau \sim [0, 4]} \left[ f(\Phi(x_0(- \bullet \tau): m)) \right] $$

第一項は必要な部分だけを 0 にしているか (すべての要素を 0 にすれば入力の情報がゼロになり出力が変わるのは自明) という意味です。

第二項は $m$ の変化が自然かどうか (画像方向に対して勾配がきびしくないか) という正則化項です。

第三項は $\Phi$ による入力画像の加工に加えて, $[0, 4)$ の一様分布に従う確率変数 $\tau$ によって入力画像の信号が弱まる様になっています。これは入力画像の一部分に対する過学習を起こさないために付け加えられています。

この目的関数を用いて勾配法をつかって最適化を行います。

実験

github 上に pytorch 実装が上がっていたので, fork していろいろいじってみました。

https://github.com/nyk510/pytorch-explain-black-box

pytorch が動く環境であれば以下のコマンドで実験を行います。第一引数が対象となる画像です. サンプルに入っている画像は下の猿の画像です。吠えてます。

f:id:dette:20180207170713j:plain

これを使って学習をしてみます。

python main.py examples/macaque.jpg
run with gpu
------------------------------
Category with highest probability: 373 - 96.07 %
Optimizing..
finish!
------------------------------
after masked probability: 1.26 %
save to output/macaqueperturbated.png
save to output/macaqueheat_map.png
save to output/macaquemask.png
save to output/macaquecam.png

はじめ 96% で分類できていた画像が加工後には 1,26% にまで落ち込んでいます。

学習されたマスクをかけた画像は以下のようになっています。

f:id:dette:20180207170827p:plain

人間の目にはほとんど変化が無いように見えます。敵対的サンプルの画像と似たような画像を生成していることがわかります。 マスクを掛けている領域はこんな感じ。猿の頭部分を見て判定しているのかな?といった印象ですね。

f:id:dette:20180207171220p:plain

感想

  • 正則化の与え方が論文を読んだ感じかなりヒューリスティックなので改良の余地はあるかも
  • 最適化方法を二次の情報を用いた手法 (たとえば L-BFGS など)にすればより早く収束しそう