SGDにおける「順番」の問題
この記事は atma Advent Calendar
の 12/1 分の記事です。大分遅くなってしまいましたがこの記事では Stochastic Gradient Descent における順番が与える影響とそれにまつわる論文をいくつか紹介したいと思います。
Stochastic Gradient Descent とはなにか
Stochastic Gradient Descent (SGD) は連続変数に対する最適化手法の一つで、ベースとなっているのは勾配法 (Gradient Descent) です。ここで以下の説明のためにいくつか用語を定義します。まず考える最適化問題は以下のように関数 $f_n$ の和で表される関数 $F$ の最適化です。
$$ {\rm minimize}_{x \in \mathbb{R}^d} F(x) := \sum_{n=1}^N f_n (x) $$
こういう関数は例えば機械学習・統計分析で頻出です。というのも $\sum$ の部分というのは期待値の計算とみなせるので、期待値を取ったなにかを最大化・最小化する問題もこの形式で記述できるためです。
また尤度関数の最大化問題も対数を取ってあげると上記の形式に帰結することができます。この場合 $N$ はデータの数に相当します。典型的なニューラルネットワークや線形モデルによる学習はすべてこれに該当する問題です。
最適化手法・勾配法では $i$ 番目の iteration において、今の最適化変数 $x_i$ の周りで計算した目的関数全体の勾配 $\nabla F$ を用いて, 全体の勾配の意味で最も下がる方向へ更新します。 $$ \begin{aligned} x_{i + 1} &= x_i - \alpha_i \nabla F(x_i) \\ &= x_i - \alpha_i \sum_{n=1}^N f_n(x_i) \end{aligned} $$
ここで $\alpha_i \in \mathbb{R}$ は $i$ 番目の更新で用いるステップサイズです。
SGD では全体の勾配の代わりに適当なインデックス $k \in \{1, 2, ... N\}$ を選んでそれを用いて更新を行います。すなわち以下のとおりです。
$$ x_{i + 1} = x_i - \alpha_i \nabla f_{k}(x_i) $$
更新時にはただひとつの勾配を計算すれば良いだけなので、GD に比べ一回あたりの更新の計算量が小さいのが特徴です。特に $N$ が膨大な値になる場合には GD では都度 $N$ 回の計算が必要になるのに比べて有利です。一方で、都度ただひとつのデータに対しての勾配しか計算しませんので、目的関数全体に対する正確な勾配との乖離が起こり、収束は遅くなります。
SGD における「順番」の問題
前置きが長くなってしまいましたが、ここからが本題です。先ほど SGD では毎回更新に使うデータの index $k$ を選ぶ、という話をしました。
ではこの $k$ はどのように選ぶのが良いのでしょうか。
まずナイーブに思いつくのは $k$ を $1, 2, \cdots, N, 1, 2, \cdots$ の用に順繰りに選んでいく方法です。この方法を以下では Cyclic と呼びます。
他のやり方は都度 ${1, 2, ..., N}$ を取りうる一様な分布から一つサンプルする方法です。この方法を以下では Random と呼びます。
$n$ ごとに重みをつけて偏りがある分布からサンプルする方法 (Importance Sampling) という方法もありますが以下では確率はいじらずに、一様な分布からサンプリングする方法に絞って議論していきます。
さらに上記の中間のような感じで、都度ランダムにサンプルはせず $N$ 回に1回だけ順番を入れ替え(シャッフル) して、そのあとはその順番通り選んでいく方法もあるでしょう。これを RR: Randomly-ReShuffle と呼びます。
実はやり方で収束速度が違う
この3つの方法、どれも各 $n$ が選ばれる確率は $i \to \infty$ では $1/N$ で同じなのですが、面白いことに収束レートは違うことが経験的に知られています。
Curiously Fast Convergence of some Stochastic Gradient Descent Algorithms 2009 では Logistic Regression に対して、上記3つの方法を用いた場合の収束レートを CCAT Dataset で実験した様子が報告されていて、Random < Cyclic < RR の順番で収束が早くなっています。更に Cyclic / Shuffle の場合には理論研究で得られている収束レート $t^{-1}$ よりも早く収束しています。
SGD に対する理論研究では $k$ の選び方は、前後の関係によらず独立にサンプルされたものである、という立場にたって分析されています。一方で Cyclic / Shuffle は前後に依存関係があるという違いがあり、この違いがレートの違いを生んでいるとして、選び方に対してより分析の余地があるのではないか? とコメントされています。
この問題に $f_n$ が二次関数で定義されるもので、一般の凸関数に対する議論ではないという限定的な課題設定ではあるものの、理論的な証明を与えたものが Why Random Reshuffling Beats Stochastic Gradient Descent です。
証明の気持ち
証明のアイディアは Cyclic や RR では $N$ 回ごとに必ず1度すべての関数 $f_n$ を使って計算が行われる、という点です。この性質を利用して各 $N$ 回の更新ごとの場所を $x_j$ のように定義して $x_j$ と $x_{j+1}$ の位置関係に対して成り立つ収束レートを元に、 Polyak-Rupper Averaging と呼ばれる方法で使われているような「過去 k 回の平均値を使う」ことで RR の方がレートとして早くなることを証明しています。
ライブラリでの実装はどうなっている?
順番に関する実装ですが、pytorch や keras では普通に実装するとデータセットを1度全部見るまでは順番が固定で取り出して、一周するごとに順番を並び替える用になっています。これはまさに Randomly Reshuffle ですね。
まとめ
- SGD には順番をどう選ぶか、という問題が存在しています。
- 特定の関数に対してではあるが RR が Random よりも早いという証明を与えた論文を紹介しました。
最後論文をガッツリ紹介しようと思っていたのですが、内容も詳しく追えておらず(最適化特有の数式 & 僕のパワーのなさ…) アイディアの一欠片ぐらいしかご紹介できませんでした。が、こういう問題意識や、その周辺の議論もあるんだなーというのを知っていただいて、あわよくば最適化おもろいな、と興味を持っていただけたのであれば幸いです。
参考文献
pyspark (mmlspark) で LightGBM 使うときのメモ
Spark 上で Machine Learning を行うためのツールを Microsoft が MMLSpark (Microsoft Machine Learning for Apache Spark) というパッケージで公開しています。
https://github.com/Azure/mmlspark
この中に lightGBM on spark も含まれており python を使って spark 上で lightGBM の学習を回すことができます。 これを使っていて、わからなかったこと・困ったことがいくつかありましたので、調べたメモ書きを共有します。
mmlspark での LightGBM のパラメータ指定
mmlspark の lightGBM と本家本元の lightGBM のパラメータは引数の名前が違います。本家の lightGBM では無限にエイリアスが貼られているので雰囲気で名前を与えても動きますが mmlspark はそうではありません。
パラメータに関するドキュメントが公式にあればよいのですが見当たらないようです (おそらく sコードを読んで追いかけろということ?)。以下 Classifier
の方に関してざっと作ってみました。本家の方では fit
の方で指定するパラメータも __init__
で指定するようになっていることに注意してください。
class LightGBMClassifier: def __init__(self, baggingFraction=1.0, baggingFreq=0, baggingSeed=3, boostFromAverage=True, boostingType='gbdt', categoricalSlotIndexes=None, categoricalSlotNames=None, defaultListenPort=12400, earlyStoppingRound=0, featureFraction=1.0, featuresCol='features', groupCol=None, initScoreCol=None, isProvideTrainingMetric=False, labelCol='label', labelGain=[], lambdaL1=0.0, lambdaL2=0.0, learningRate=0.1, maxBin=255, maxDepth=-1, maxPosition=20, minSumHessianInLeaf=0.001, modelString='', numBatches=0, numIterations=100, numLeaves=31, objective='lambdarank', parallelism='data_parallel', predictionCol='prediction', timeout=1200.0, useBarrierExecutionMode=False, validationIndicatorCol=None, verbosity=1, weightCol=None): """ Args: # tree 成長に関わるパラメータ numIterations: 大事. learning rate との兼ね合いで決まります. rate を小さくすると iterations は多くしましょう. learningRate: 大事. 0.3 ~ 0.1 ぐらいが普通. 小さくすると精度は上がる傾向があるけれど時間もかかります maxDepth: 3 ~ 8 ぐらい. 問題によって最適な値が違うことが多いですが 3=8位を設定しておけば大概はOK。 # 正則化 lambdaL1: lambdaL2: L1/L2 正則化. 大きくすると木の成長がゆっくりになります (極端な予測を持つ葉ができにくくなる) numLeaves: 葉の数. 2 ** max_depth より小さくすることで木の成長をゆっくりにできる # ロバスト性 / 正則化 baggingFraction: bagging (subsamples) の割合 0 ~ 1 featureFraction: columns のサンプリング割合. 0.3 / 0.5 / 0.7 とかを使うことが多いです. baggingFreq: bagging を実行する頻度. たとえば 10 とすると tree を 10 個作るごとに bagging を実行します baggingSeed: bagging でランダムに選ぶときの seed 値 boostingType: `"gbdt"` で問題ないと思います minSumHessianInLeaf: カジュアルに言うと一つの予測集合の最小の大きさみたいなものです。(Objective="mse" のとき一致します) 小さすぎると細かい粒度での分割を許可してしまうのである程度大きい値を指定してあげても良いと思いいます。 (集合サイズなので学習させるデータの大きさに依存して良い値が変わってくるのに注意) # eval set 作るとき validationIndicatorCol: ちょっと指定のしかたが lightGBM (python) と違っていて, evaluation set を作りたい場合, fit にわたすデータに valid data を concat して train or valid を表す 0-1 の flag を `validationIndicatorCol` に指定します. earlyStoppingRound: 指定された回数以上 objective が改善しないときに学習を中断します. 50 とか 100 を指定することが多いです。 """
References
2: evaluation set をつかって early stopping をする
学習させるとなると early stopping を使いたいでしょう。本家 lightGBM では fit
を呼び出す際に eval_set
に validation 用のデータを渡せば early stopping を行ってくれますが mmlspark の場合そうではなく若干クセがあります。これも公式ドキュメントに記載が何もなかったので作ってみました。大きな流れは以下のとおりです
- validation かどうかのフラグをデータにもたせる
- evaluation data を train data と merge (pyspark method でいうと union) する
validationIndicatorCol
に作ったフラグのカラムを指定する
python api の流儀と異なっているので注意してください。
Reference
- https://github.com/Azure/mmlspark/issues/435
- https://github.com/Azure/mmlspark/blob/master/src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMParams.scala#L310
Sample Code
train_df = # (some pyspark dataframe for train) eval_df = # (me too. for test) # 1. `"validation"` column にこのデータが train / eval であることを表すフラグを作る # python api の用に fit で eval_set= とは渡さないことに注意 _train = train_df.withColumn('validation', F.col('target') * 0)[['target', 'feature', 'validation']] _test = eval_df.withColumn('validation', F.col('target') * 0 + 1)[['target', 'feature', 'validation']] input_df = _train.union(_test) from mmlspark.lightgbm import LightGBMClassifier params = { 'numIterations': 200, 'learningRate': .2, 'maxDepth': 7, 'lambdaL2': 1., 'numLeaves': 31, 'baggingFraction': .7, 'featureFraction': .5, 'baggingFreq': 3, 'minSumHessianInLeaf': 100. } # 2. `validationIndicatorCol` を指定する clf = LightGBMClassifier( validationIndicatorCol='validation', earlyStoppingRound=50, labelCol='target', featuresCol='feature', **params ) # 3. call `fit`. (fit のときにはたんに事前に作った data-frame を入れる) clf = clf.fit(input_df)
note
validation data での objective 及び metric 評価は spark 上で分散計算されないようです。そのため training data と同じぐらい大きいデータセットを validation に指定すると速攻で Out-Of-Memory になりますので注意してください。
解釈可能な機械学習モデルを作るライブラリ Interpret を Docker で動かすときのメモ
解釈可能なモデリングを目的としたライブラリ interpret
を使うときに Docker で利用していると動かないという現象があったのでその解決方法です。
環境
python=3.7.7 / conda==4.8.3 で動作する docker を利用しています。interpret自体のversionは interpret==0.2.1 です。
FROM registry.gitlab.com/nyker510/analysis-template/cpu:1.0.4 RUN pip install -U pip && \ pip install \ interpret==0.2.1
現象
docker で起動している jupyter から interpret show を実行しても何も出てこない (本来は分析結果の画面が表示される)
from interpret import show from interpret.data import Marginal import pandas as pd df = pd.DataFrame(np.random.uniform(size=(10, 100))) marginal = Marginal().explain_data(df, df[0], name='train data') show(marginal)
原因
interpret は output として静的なコンテンツを生成しておらず、逐次 interpret がバックグラウンドで起動している API からデータを取得しているようです。試しに F12 の検証から開いてみると http://127.0.0.1:7001/140350020331408/
に対してリクエストを送っていることがわかります。
interpert は docker 内部で実行されているため, 内部での localhost:7001 と jupyter の実行環境の localhost が一致していないためリクエストが通らず、エラーになっています。
解決方法
- docker 起動時に interpret のための port を開放しておく (例えば 7001番など、これをAとします)
- interpret のサーバーを
0.0.0.0:A
に変更する (たとえば0.0.0.0:7001
)
変更は from interpret import set_show_addr
から行えます。例えば 7121
port でつなぐなら以下のような感じ
from interpret import set_show_addr set_show_addr(('0.0.0.0', 7121))
これを実行すると画面が表示されます。やったね。
リモートサーバーの時
適当なリモートサーバーで jupyter を起動している時もあると思います。その場合には set_show_addr
で指定するホストをそのサーバーの名前にしましょう。例えば https://www.example.com
で繋いているのであれば以下のような感じです。
from interpret import set_show_addr set_show_addr((https://www.example.com', 7121))
python: loggingの出力値を文字列として取得したい
python の logging で出力した info とかを文字列として取得したい! という場合の方法についてのメモです。
下準備
- 単純な logger と stream handler (コンソールへの出力のハンドラ) を用意します。
- 詳しくは https://docs.python.org/ja/3/howto/logging.html#logging-advanced-tutorial など参考にしてみてください。
今回は logger / hander 両方に INFO level をつけましたので info よりも重要度が高いものだけ console に output されるようになっています。
from logging import getLogger, StreamHandler, Formatter handler = StreamHandler() handler.setLevel('INFO') logger = getLogger('nyk.510') logger.setLevel('INFO') logger.addHandler(handler)
https://docs.python.org/ja/3/howto/logging.html#loggers
組み込みの深刻度の中では DEBUG が一番低く、 CRITICAL が一番高くなります。たとえば、深刻度が INFO と設定されたロガーは INFO, WARNING, ERROR, CRITICAL のメッセージしか扱わず、 DEBUG メッセージは無視します。
なるほど。というわけで、一旦試してみましょう。
logger.warning('warn') logger.info('foo') logger.debug('debug') # debug はでないよ warn foo
確かに debug は出ないようになっていますね。
loggingの出力値を文字列として取得
さて本題の logging の出力をテキストとして取得する、です。これは要するに上記の例で言うと warn / foo / debug みたいな文字列を取得したい、ということです。 結論をいうと StringIO を stream にもつような handler を作成して logger に付与すればOKです。 テキストとして取得するっていうのはだいたい log をどこかに保存したいとかいう気持ちがあると思いますので、ちょっとおしゃれな formatter にして時間等も取得できるようにしています。
log_capture_io = io.StringIO() stream_handler = StreamHandler(stream=log_capture_io) # オシャに formatting formatter = Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') stream_handler.setFormatter(formatter) stream_handler.setLevel('INFO') logger.addHandler(stream_handler)
この状態で logger に先と同じように log を記録します。
logger.warning('warn') logger.info('foo') logger.debug('debug') # debug は console にでないよ
log の取得
作成した StreamIO から getvalue
すればOKです。
- 通常のコンソールアウトプットは単に文字列だったが,
formatter
をリッチにしているので取得される文字列には何時 log が記録されたかなどの情報も入っている - コンソールの方も普通の
StreamHandler
に formatter を設定すれば時間も表示できる.
s = log_capture_io.getvalue() s.splitlines() ['2020-08-01 07:59:41,132 - nyk.510 - WARNING - warn', '2020-08-01 07:59:41,135 - nyk.510 - INFO - foo']
後片付け
ずっと handler が付いていると記録され続けるので、いらなくなったら消しましょう
- io の close
- handler のひも付けを logger から削除
removeHandler
# 終わったら消しましょう log_capture_io.close() logger.removeHandler(stream_handler) log_capture_io.closed # True
close してしまうと value はもう取れませんので注意
log_capture_io.getvalue() --------------------------------------------------------------------------- ValueError Traceback (most recent call last) <ipython-input-10-b5f4bf9c6e8d> in <module> ----> 1 log_capture_io.getvalue() ValueError: I/O operation on closed file
上記コードは gist にもありますので参考にしてください ;). Logging Recording · GitHub