nykergoto’s blog

機械学習とpythonをメインに

scikit-learn の grid-search を sample_weight と同時に使用する場合の問題点

以下の記事によると scikit-learn の BaseSearchCV の実装には問題があり意図しない動作をしている可能性がある、とのことが報告されています。いつも scikit learn を使う身としては気になる話題なので、少し詳しく見ていきます。

deaktator.github.io

BaseSearchCVと sample weight の問題点

記事中で指摘されている問題点をまとめると以下のようになります。

  • BaseSearchCV を継承した CV class において, 各 CV ごとの学習自体は sample_weight が適用される
  • しかし validation set に対する score の計算では sample_weight が適用されない

検証

試しにさっくり手元の環境でもやってみました。使用するバージョンは最新の scikit-learn==0.24.0 です。

cv = np.array([0, 0, 1, 1])
X = np.ones(shape=(4, 1))
y = np.array([1, 0, 1, 0])

fold = np.array([
    [[0, 1], [2, 3]],
    [[2, 3], [0, 1]],
])

# negative に対して weight を 999 / 1000 で与える
sample_weight_for_zeros = [
    1, 999, 1, 999
]

# grid search で学習. 入力が const の linear model なので positive の割合が weight になるはず.
grid = GridSearchCV(
    estimator=LogisticRegression(), 
    param_grid={ 'random_state': [42] }, 
    scoring='accuracy', 
    cv=fold, 
    return_train_score=True
)

grid.fit(X, y, sample_weight=sample_weight_for_zeros)

上記の設定だと sample_weight は negative に多く設定されています。したがって重み付きのスコアは 0.999 になっていてほしいです。 scikit-learn に実装されている accuracy では sample_weight arg が用意されていて、手動で計算すると以下のようになります.

from sklearn.metrics import accuracy_score

# 本当は 0.999 になってほしい
accuracy_score(y, grid.predict(X), sample_weight=sample_weight_for_zeros)  # 0.999

しかし実際には 0.5 になっています

# 実際には weight がないときと同じスコアになる
grid.best_score_ #  0.5

sample_weight が適用されていないのは validation でのスコア計算の部分で, 学習自体は意図した動作です.

# 予測値は negative が 0.999 が出力される (学習自体には sample_weight が適用されているため.)
grid.predict_proba(X) 

array([[0.99900002, 0.00099998],
       [0.99900002, 0.00099998],
       [0.99900002, 0.00099998],
       [0.99900002, 0.00099998]])

まとめ

  • 学習自体は sample_weight つきで期待通り実行されるが, 算出される score は sample_weight を無視して計算されている.
  • GridSearchCV や RandomSaerchCV など, BaseSearchCV を継承したパラメータサーチでは score の意味で最も良いパラメータが選ばれる.
  • したがって, sample_weight を考慮して最も良い parameter が知りたい場合でも, sample_weight を無視したスコアの意味でもっとも良いモデルが選ばれるため、問題.

sample_weight を使っていてかつ sciki-learn の枠組み上でパラメータ最適化を行う場合には注意したほうが良いかもしれません。

以下使用したコードです。

gist.github.com