データコンペサイトを作る: フロントエンド編
この記事はオンサイトデータコンペ atmaCup のシステムぐるぐる https://www.guruguru.ml/ のフロントエンドの話です。
atmaCupとは?
オンサイトデータコンペatmaCupとは実際に会場に集まり、準備されたデータをテーマに沿って分析・予測を行い、その精度を競うイベントです。 データコンペで有名なのはKaggleですが、みんなで実際に集まり、かつ時間もその日のうち8時間など短いのが特徴で、 参加者のスキルがオンラインのデータコンペより強く結果に表れます。
直近 atmaCup#4 はリテールAI研究会様 との共同開催でした。期間が 3/1 ~ 8 の1周間、すべての分析はAzure上で行うというものでした。 今までのatmaCupでは1日でかつ環境は自由というものだったので、期間・環境ともにはじめての試みでした。 環境構築が大変だったかたや途中トラブルなどもありましたが、無事に終わることができ主催としてはホッとしているところです。
https://www.guruguru.ml/competitions/9/leaderboard より最終結果をみることができます。
このatmaCupですが「実際に集まる」と言ってもスコア計算用のファイル(以下では kaggle に合わせて submission file と呼びます)をUSBか何かでもらうわけには行かないので、submission file を upload してスコア計算し、チームごとにランキングが出るようなウェブアプリケーションぐるぐる https://www.guruguru.ml/ を別途作っています。この記事はぐるぐるのフロントエンド周りの話です。
バックエンドについては以前記事にしています のでそちらもどうぞ。
全体構成
まずは全体構成ですが、バックエンドのDRFに対して、SPA(SinglePageApplication)で実装されたフロントエンドがいて、適宜バックエンドに必要な情報をとっている、という感じです。
https://cloudcraft.co/view/45c7af86-888a-49cd-a7ff-6ca026e9e6e0?key=opGf3vj8B6oy9f0Z2zJwyQ
バックエンド、フロントともに ECS でデプロイされていて適宜スケーリングするような設定になっています。
フロントは Nuxt.js (Vue.jsをSPAとして適用しやすくするためのフレームワーク) で実装されていて、master merge のタイミングで ECS へのデプロイを gitlab CI で動かすような感じです。
フロントの構成は大きく
- Vue.js
- Nuxt.js
- Vuetify.js
に依存していますので、まずはその話から。
Vue.js
ぐるぐるフロントではメインのjsフレームワークとしてVue.jsを採用しています。Vueはウェブページに表示されるテンプレート部とデータ部分の管理をするためのライブラリで、データに合わせて動的に表示を変えたりすることが簡単に出来るのが特徴です。
普通のjsで画面を操作しようと思うと、データが変わったあとに自分で画面を書き換える操作が必要です(jQueryなどはそう)が、Vueはその整合性をVue側で担保してくれます。そのため僕を含めた実装者はデータの操作部分の整合性に注力するだけで済む用になります。
Nuxt.js
Nuxt.js とは 「Vue.js のゆにばーさるあぷりけーしょん」を作れるフレームワークです。ゆにばーさるあぷりけーしょんってなんやねんということなんですが要するにVueで動くようなウェブページを便利に作れるぐらいな認識で良いと思います。これを使うことで自分でhtmlを書いてVueをレンダリングする処理をまったく書かずにVue Compnentだけ記述してウェブページを作成することができます。
Server Side Rendering (略してSSR. サーバー側で js を読み込んでページを作ってからクライアントに返すためクライアント側での展開が不要になる。) にも対応しているモダンな子です。
Nuxt のいいところを列挙すると以下のような感じです。
- 基本的に使うライブラリが初期で入っている
- 強い制約がある
- 使い回しを考えてコードを書きやすい
基本的に使うライブラリが初期で入っている
これはたとえばSPAアプリケーション全体での変数を保持する仕組みを提供してくれる vuex
や vue-router
等を指します。 Nuxt.js を入れるだけでこれらのライブラリは一通り揃いますし、これらのライブラリを使って実装することを想定してライフサイクルが定義されているので、自分でごちゃごちゃと config をいじることがほとんどありません。
なのでNuxt.jsを入れればあとはアプリケーションの実装に集中できてとても便利です。
強い制約がある
もうひとつ大きないいところとして「強い制約がある」という点があります。たとえば /hoge
にアクセスした時に表示するページを定義したい場合、 pages
というディレクトリに hoge.vue
というファイルを作ることを強制されています。これは最初始めるときに学習コストがかかるというデメリットはありますが、だれが書いても同じところに書かざるを得ないという素晴らしいメリットがあります。
この辺は Django / DRF にも通じるところがありますが、フレームワークの本質は強い制約を与えることで同じ書き方しかできなくすることだと思います。これによって他のプロジェクトに持っていく・持ち込むことが容易になりますし、他人のコードがとても読みやすく、コラボレーションのコストを下げてくれるのが良いところです。
使い回しを考えてコードを書きやすい
これもほぼ前項と一緒なのですが制約故に特定の機能を作りたい時、特定の場所に書くことが義務づけられているので、他のNuxtプロジェクトでの使い回しを考えつつコードを書けるのはメリットかなと思っています。
Vuetify
Vuetifyとは
Vuetifyは多数のコンポーネント(たとえばボタン要素やリストでの表示部分など)やグリッドレイアウトの仕組みが整ったVue.jsのためのフレームワークです。デザインにGoogleがかくあるべしを定義しているマテリアルデザインが採用されていてわりと最近人気です。そもそも僕がマテリアルデザインが好きというのも有るのですが、それ以外にもとてもいい点があります。
vuetify使っていていいなーと思うところをまとめると、以下のようになります。
実装済みコンポーネントの種類が圧倒的に多い
一番最初に Vuetify を使い始めたのはこれが一番大きかったと思います。本当に種類が多くてボタンとかフォーム見たいな基本的な物は勿論のこと、画像の遅延ローディング(最初サムネイルを読み込んだあと綺麗な画像を読み込むやつ)とかTimelineの表示、ステップでの入力などなど多彩です。あまりに種類が多いので、ちょっと前までゴリゴリコンポーネントが増えていた時には、開発途中で知らない便利コンポーネントが生えていて「??」となることも多々ありました。最近は出揃った感もあって安定している気がします。
https://vuetifyjs.com/en/components/calendars/ このあたりから見て見てるとワクワク感が伝わるかなと思います。リンク先はカレンダーですが、パット見Googleカレンダーにしか見えない…。
種類が多いとやはり初速が出るので、ざっくりページを作る時非常に助かります。またデザインのカスタムもある程度考慮して定義されている(一部フォーム系は癖がありますが)ので、最初プロトタイプを作ってデザインを当て込んで精緻にしていく、という様なMVPに沿った開発ができるため重宝しています。
ありえんぐらいのコミット数
本当にありえないぐらい活発にコミットされています。僕のgithubアカウントの通知はほとんどVuetifyで埋まっています。
最近V2になったのですが、その前後などは「こういうコンポーネントあったら便利だよね」というのが次々実装されていっていて、勢いを感じました。(今現在OpenしているIssueが1000弱あり、PullRequestは66個あります。) それだけ発展途上という見方もできますが、次々にバグや使い勝手の良い機能が充実していくのでユーザーとして非常にありがたいです。
一緒についてくるCSSとレイアウトコンポーネントが優秀
Vuetifyがいくらコンポーネントがあると言っても、自分が使うときには少しデザインを直したいことが多いです。たとえばボタンと他の要素との距離を取りたいとか、画面の中央に浮かせたいとかですね。
そういう時 style で書いたりcssでクラス定義してやることが多いと思います。最近はBEMに従ってクラス名を書いて、Vueのscoped styleで外部に影響を及ぼさないようにして、というふうに書くわけですがどうしてもクラス間でのスタイルのバッティングや、paddingのpxを埋め込んでしまって後で直すのが大変だったりします。sassを使って変数定義しても良いのですが僕は面倒くさがりなので余りやりたくないのです…
Vuetifyではコンポーネントデザイン用のCSSと別に、スタイル調整用CSSがついています。このCSSがfunctional(名前どおりにスタイルが適用される)に実装されていて、これがとても使い勝手が良いのです。
たとえば class="pt-2"
と書くと padding-top が 4 * 2 = 8px
だけ付く、見たいな感じです。これを使うことでその場限りのデザイン用のスタイルをささっと書くことができます。また Vuetify のコンポーネントのスタイルよりもこのCSSのほうが強く付く様に定義されているため、自分が書いたclassが反映されない!!というよくある悲しい事件がないため、非常に使い勝手が良いです。
こだわりポイント
つらつらとフレームワークのいいとこを書きましたので、次にぐるぐるに入れている僕のこだわりポイントを紹介したいなと思います。
submission でスコアが上がった時の通知
通称?「緑の画面」と呼ばれているものです。これはsubmissionしたときのスコアが過去最高だったとき、画面全体を使っておめでとうコメントとスコアをカウントアップして出してくれる機能です。いいスコアがでたら喜びたいよね?というので作りました。結構反響があってわりと嬉しいです:D
使っているちょっと込み入った?こととして数値のカウントアップ部分があります。これは Tween を使って自作しています。
<template> <span>{{ text | toFixed(fixed) }}<template v-if="showChange"> {{ change | toFixed(fixed) }} </template> </span> </template> <script> import TWEEN from '@tweenjs/tween.js' function animate(time) { requestAnimationFrame(animate) TWEEN.update(time) } requestAnimationFrame(animate) export default { props: { value: { type: [Number, String], default: null }, from: { type: [Number, String], default: null }, duration: { type: [Number, String], default: 1000 }, delay: { type: [Number, String], default: 0 }, fixed: { type: [Number], default: 4 }, showChange: { type: Boolean, default: false } }, data() { return { current: null, change: null } }, computed: { text() { return (+this.current).toFixed(this.fixed) } }, watch: { value(newVal, oldValue) { if (!!this.from) this.countUp(+this.from, newVal) if (oldValue) this.countUp(+oldValue || 0, newVal) } }, mounted() { this.current = this.value }, methods: { run() { this.current = this.from this.countUp(this.from, this.value) }, countUp(start, end) { this.current = +start this.change = +end - this.current const updateObject = { value: +start } new TWEEN.Tween(updateObject) .delay(+this.delay) .to({ value: +end }, +this.duration) .onUpdate(() => { this.current = updateObject.value }) .start() } } } </script>
実際の画面は参加してsubmitしてからのお楽しみということで、ここには貼らないでおきます😌
emoji で可愛く
基本的にコンペって殺伐としているじゃないですか。*1なのでちょっとでも可愛い要素を入れたいなと思い絵文字を導入しています。
これは twitter の絵文字をレンダリングする twemoji というライブラリを使って実装していてカスタムディレクティブとして global に定義して
import Vue from 'vue' import twemoji from 'twemoji' Vue.directive('emoji', { inserted(el) { el.innerHTML = twemoji.parse(el.innerHTML) } })
こんな感じでつかいます
<template> <div v-emoji>🚀</div> </template>
こうすると v-emoji
で囲った部分に含まれる絵文字を twitter 社が出している絵文字の画像 dom に変換してくれるので、どの環境でみても同じような表示になってとてもハッピーです。
たとえばですがパスワードリセットページ https://www.guruguru.ml/auth/reset などに使っています。
LeaderBoard Chart Race [NEW!]
棒グラフが時系列でどのように変化するかをアニメーションで表示したのが barchart race と呼ばれるものです。最近 twitter とかでも良く見かけていてとてもわかり易いし見ていて楽しいので「リーダーボードにも適用したら面白んじゃないか?」と思って前回の atmaCup#3 のときにウェブサービスを使って作成したのが好評だったため、自作してみました。
今回からランキングの推移を見れるページを作ってみました! (summaryページ https://t.co/TstfYk1GlR から見れます。)添付は public の推移ですが private も見ることできます。いつの時点では、どんな感じのランキングだったか見れますので振り返りに使っていただければと思います😆 #atmaCup pic.twitter.com/J1mE3NNMEU
— ニューヨーカーGOTO (@nyker_goto) March 8, 2020
見ていただくのが早いので動画を貼りましたが、やっていることは単純で
- コンペ開始から終了までを一定間隔で区切って
- それぞれの時間までの一番良いsubmissionの値をチームごとに保存
- スライダーで設定された時間のスコアを並び替えて表示
という感じです。この画面ですがVue.jsのデータバインディングの効果が実に遺憾なく発揮されていて、僕は上記のチームごとのスコアの配列を作ることだけに集中できたので、結構さっくり作ることができました。むしろアニメーションや画面上の配置、色などのデザイン部分のほうが時間がかかっていた気がします。
とはいえ、アニメーションも Vue.js のリストトランジション機能を使っているので実質3行ぐらいで実装できてしまうのでVue.jsおそろしやという感じです。
まとめ
ぐるぐるのフロントエンドについて書きました。データ分析をしていてフロントをやっている人はあまりいないようなイメージがありますが(やっていてもAPI開発のバックエンドの方が多いような肌感)、データを見やすいようにグラフに落とし込んでいく作業と、ユーザーが使って楽しい・使いやすいサイトにするにはどう配置をしたら良いかを考えつつ実装していくフロントエンドの親和性はかなり高いなと感じています。
この記事をみて少しでもフロントエンド面白いな・このシステムが動いているatmaCup出てみたいなと思っていただければ幸いです。😆
atmaCup#5も開催できるように色々と練っているところですので、参加いただけると大変うれしいです。
参考リンク
- atma株式会社 https://www.atma.co.jp/
- atmaCupを運営している会社です。
- ぐるぐる: https://www.guruguru.ml/
- 上記で紹介したコンペサイトです。Nuxt.js + Vuetify で動作しています。
- リテールAI研究会: https://retail-ai.or.jp/
- atmaCup#4を共同で開催しました。リテール分野のAI活用に関して積極的に活動されています。
- https://www.guruguru.ml/competitions/9/leaderboard より結果を見ることができます。優勝は前々回優勝の pao さんでした。
*1:諸説あります
データコンペサイトを作る DjangoRestFramework編
この記事は atma Advent Calendar 2019 - Qiita 2019/12/21 の記事です。
今年自社のサービスとして オンサイトのデータコンペティション atmaCup をはじめました。 オンサイトデータコンペとは実際に会場に集まり、準備されたデータをテーマに沿って分析・予測を行い、その精度を競うイベントです。 データコンペで有名なのはKaggleですが、みんなで実際に集まり、かつ時間もその日の8時間と短いのが特徴で、 参加者のスキルがオンラインのデータコンペより強く結果に表れます。
このatmaCupですが当然やろうと思うとコンペ用のシステムも必要です。というわけで裏側のシステム 「ぐるぐる」 を僕が作っています。
この時記事ではそのバックエンド部分を担っている DjangoRestFramework
についてその便利さとどういう機能を使ってぐるぐるを作っているか、を少し紹介したいと思います。
おことわり
この記事ではコードを一部書いていますがプロジェクト全体の構成などは記述していません。ごめんなさい(力尽きました)
DjangoRestFrameworkでやってみたい!と思った方は Django REST Frameworkを使って爆速でAPIを実装する - Qiita こちらの記事などを参考に是非チャレンジしてみて下さい。損はしないです。
どういう構成か
先ずぐるぐるの全体のざっくりとした構成から紹介します。
インフラ
ぐるぐるではAWSを使っていて ECS (Fargate) によってコンテナとしてデプロイしています。 更新は gitlab の master merge のタイミングで gitlab-CI によって ECR に push & 新しいイメージを使って ECS を更新という CI を組んでいます。
Static File はすべて CloudFront 経由で S3 につなげていて、データベースは private subnet 上の RDS に接続というよくある構成です。
開発環境
docker / docker-compose でアプリケーションごとにイメージを作っています。
アプリケーション
フロントエンドとバックエンドを切り離したRESTFULLな構成です。特にコンペはスコア計算時にかなり重たい処理が入りますので、フロントと分離するのは自然かなと思っています。 フロントエンドはNuxt.jsでバックエンドはDjangoRestFrameworkを採用しています。
DjangoRestFramework
DjangoRestFramework (略して DRFということもあります) は python の web framework の django を rest api 用に拡張したライブラリです。 python で書かれていることもあり、機械学習でいつも python を使っている僕にとってはかなり馴染みやすいフレームワークでした。
個人的に押しの子なのですが Qiita でも余り人気がなかったり日本でははやっていないようで少し残念です…少しでもユーザーを増やしたいな、という思いもあってこれを書いていたりします。
DjangoRestFramework の立ち位置
DjangoRestFramework は python の web framework の中ではかなりカバー範囲の広いフレームワークです。カバー範囲が広いのでこれだけですべて完結できる便利さがある一方その分重たいという欠点はあります。 python の web framework には他にも flask や falcon などより軽量な物があり、機械学習用途ではこちらが使われていることが多い印象です。
とはいえそれを上回る実装の速度感が魅力で、この記事ではその一部でもお伝えできればなと思っています。
DRF のいいところ
- DBのことを気にしなくても良い設計
- Model に関連した機能が豊富
- Adminsite がほとんどなにもしなくても出来る
- Document ページがマジで何もしなくても出来る
の4点にあると思っています。
DB のことを気にしなくていい
DRFを使っている時、DBとのやり取りを気にすることはほぼありません。たとえばですが各コンペ用のテーブルが作りたいな、と思ったとします。 Submission に必要そうな情報というと
- コンペのタイトル
- 説明文
- 締切日
など要りそうですね。この場合であればモデルとして以下を定義して
from django.db import models class Competition(models.Model): title = models.CharField(max_length=128, help_text='foo') description = models.TextField(max_length=10000, help_text='説明文') finished_at = models.DateTimeField(null=True, blank=True, auto_created=True, help_text='コンペ終了の日時')
migration をするだけで簡単に table 作成をやってくれます。
# migration ファイルの作成 [django@700e14ec6a91 django]$ python manage.py makemigrations competition Migrations for 'competition': diary/competition/migrations/0001_initial.py - Create model Competition # DBへの反映 [django@700e14ec6a91 django]$ python manage.py migrate Operations to perform: Apply all migrations: admin, auth, competition, contenttypes, sessions, v1 Running migrations: Applying competition.0001_initial... OK
あとは makemigration 時に出来たファイルを git に乗せて、本番用サーバーで python manage.py migrate
だけ動くようにしておけば model で定義した最新の状態にまで DB の状態を変えてくれるのです。楽ちんですね。
なおこの migration はかなり賢いのでたとえば field の名前を間違えたので直したい、とか新しく field を定義したい、みたいなことぐらいであれば自動的に検知して DB の column 名を変更するような SQL を発行してくれます。(今回は特に時間がなかったので基本的に何も考えずにとにかくmigrationしているので competition だけで40ぐらいファイルができていますが不具合は一度もありませんでした)。
Model に関連した機能が豊富
上記の Model と migration などは Django の機能なんですが (import も django からなのがわかると思います), ここから API のエンドポイント作成までが DRF が担う部分です。 このとき Model に関連した view を作るとき DRF の爆速感が本領を発揮します。
まずDRFでは大きく2つの概念があります。
この serializer がやっている処理は地道で面倒ですが必要です、というのもユーザーは良からぬリクエストをしてくる可能性がありますから int にしたいときは int へのキャスト処理を書いて〜とやる必要があります。 これを自分で書いていると案外面倒です。
もっというとこの変換はデータベース(すなわち Django の Model) に関連したものが多いです。 たとえばチーム作成をする時を考えてみてください。作成用のエンドポイントでは、チームのmodelに関連した情報をユーザーは送ってくるので、チームのmodelで定義したフィールドのそれぞれに対して json-> python object への変換を書くことになるはずです。 要するに Model の定義と Serialzier の変換は対応関係にある場合が多いのですね。
そこで DRF だと Model の定義に合わせて自動的に変換する便利クラス ModelSerializer
があります。これを使うと先の Competition は
from rest_framework import serializers from .models import Competition class CompetitionSerializer(serializers.ModelSerializer): class Meta: model = Competition fields = '__all__'
みたくかけます。これだけで Model に定義された _all__
field, ようするに title, description, finished_at の3つのフィールドそれぞれをいい感じに変換してくれます。
(場合によっては全部必要ない場合もありますのでその場合は exclude
などで除外フィールドを書くことも可能です)
あとは view を作るだけですが、こちらもモデルから作成してくれる ModelViewSet
を指定して先の serializer と model objects を登録した viewset を作って
from rest_framework.viewsets import ModelViewSet from .models import Competition from .serializers import CompetitionSerializer class CompetitionViewSet(ModelViewSet): serializer_class = CompetitionSerializer queryset = Competition.objects.all()
これを urls に加えます.
from rest_framework.routers import DefaultRouter from . import views router = DefaultRouter() router.register('competitions', views.CompetitionViewSet, basename='competitions') urlpatterns = [ *router.urls ]
基本的には, これだけで competition に対する CRUD のエンドポイントが出来てしまいます。簡単!
あとは承認周り(どんなユーザーがエンドポイントにアクセスできるか)やValidation(どういう状態のリクエストは許可するか)なども DRF 側である程度指定されていてるので、それに沿って書くことで何がどこにあるのかよくわからない状態を防げますので、複数人開発や規模が大きいプロジェクトの場合にも強いのかなと思っています。
管理者用の画面
データベースに直接見に行くのでもいいですがwebから見れる管理者側の画面があるとなにか変なことがあったときやノンエンジニアの人に頼んだりする時とても便利です。 今回のプロジェクトは時間は無いのでフロントを作らずすべて Django Admin で実装しています。
実装と言ってもめっちゃ簡単で、最低 admin site に表示するだけなら下記の4行あれば作れてしまいます。 これで competition の CRUD は admin 側から行えるようになります。簡単。
from django.contrib import admin from . import models @admin.register(models.Competition) class CompetitionAdmin(admin.ModelAdmin): pass
これだとたとえば「いま終わっているコンペだけ表示したい!」とかの要望に答えられないのでちょっと手を入れて実際の master branch は以下のようになっています。 手を入れると言っても表示する項目書いているだけなんですけどね…
@admin.register(models.Competition) class CompetitionAdmin(admin.ModelAdmin): list_display = ( 'id', 'published', 'private_rank_is_confirmed', 'title', 'finished_at', 'applying_teams', 'total_submissions', 'is_finished', 'calculate_status') ordering = ('-finished_at',) search_fields = ('description',)
Documentページ
実際にこのAPIをデプロイしたとしましょう。そうすると次はフロントエンドとの繋ぎこみをする必要があります。 この時結構厄介なのがAPIの仕様をFront作業者にどう伝えるのか、という問題です。
よくある方法とその問題点
一般には仕様書的なものを書いて、それを共有したりするんだと思いますがそれだとどうしても仕様と実装の二重管理になるため
- 仕様と同じようにAPIが作られているかがわからない(APIがまちがっている)
- そもそも仕様が古くて今のAPIとずれている(仕様書がまちがっている)
- 今はまだそのエンドポイントができていない(未着手だった)
といった理由で上手くAPIを使えないことがあります。
DRF のドキュメント生成機能
DRFは今の実装状態から document を自動生成して、それもエンドポイントに付け加える機能があります。 要するに今の実装で使えるエンドポイント一覧をいいかんじに表示するアプリケーションとしても機能するということです。これは見てもらったほうが早いと思うので実際の表示画面をお見せしたいと思います。 このドキュメント機能を使う場合は urls に一行追加すれば OK です。簡単ですね。
from django.contrib import admin from django.urls import path, include from rest_framework.documentation import include_docs_urls urlpatterns = [ path('admin/', admin.site.urls), # 中略 path('docs', include_docs_urls(title='documentation')) # これが document 用の url ]
この状態で /docs
にアクセスすると実装されている API 一覧が見れます。
この場からエンドポイントを叩くことも可能です。たとえば nyk510
が join しているチーム一覧などは以下のような感じ。
またちっちゃいですが左下にログインをするタブがあり、そこから特定のユーザーとしてログインした状態でエンドポイントを叩くことも可能です。 これがあるので、最近は Postman みたいな API 叩くようクライアントを使う機会がとても減りました。
また表示されるのは「実装されている」エンドポイントであることも重要です。これによって 常に実装を正としてAPIを参照することが可能です。 ドキュメントを作っている時に比べて、フロントエンドとのコミュニケーションがかなりスムーズに行なるようになってかなりいいなーと思っています。
作った機能
というわけで便利なDRFで色々と作りました。 基本的な機能は Kaggle を参考にチームマージ機能やディスカッション、それに紐づくコメントといいね、通知などを実装しています。
そのなかでもポイントを挙げると以下のような感じでしょうか。
- Publicは即座に・PrivateScoreはコンペが終わってから
- Privateスコア計算をadminsiteから実行できるように
- discussionをnotebookから作成
- 通知機能
- LB周りのテスト
Publicは即座に・PrivateScoreはコンペが終わってから
これは完全に僕のこだわりなのですが、なんとなく主催者だけが最終スコアがわかっているのって卑怯な感じが無いですか? という理由からコンペが終わるまで submission の private score は計算されないような仕組みになっています。 (そのせいでコンペ終了後のスコア計算が上手く回るかどうか毎回めちゃくちゃドキドキしています。これとても心臓に悪いので、この仕様は将来的には変えるかも知れません…)
反対に PublicScore はできるだけ素早く返せるよう、あえて他のインスタンスで計算する構成にしていません。これは正直コンペ開催中ぐらいであれば ECS のクラスタ数を増やすなどで対応できるだろうという読みと、なにより僕のユーザー体験としてスコアがなかなか計算されない辛さを体感しているのが大きいです。実際見ていても数百万行の AUC とかならまだしも 1万行程度の RMSE ならそんなに計算コストもかからなさそうなので、当分はこの方針で行きたいと思っています。
Privateスコア計算を adminsite から実行できるように
PrivateScoreは上記に書いたように計算冴えていませんので、どこかのタイミングで計算を実行させる必要があります。 はじめは terminal からやるようにしていたのですが atmaCup#2 から admin site からボタンひとつで実行できるようにしました。 といってもやることはかなり簡単で、 adminsite 用のテンプレートをちょっと編集して
{% extends 'admin/change_form.html' %} {% block submit_buttons_bottom %} {{ block.super }} <h2>Jobs</h2> <div class=""> <div> <input type="submit" value="Run Calculate Score" name="_calculate_score"> <div style="color: #7b7b7b; padding: 8px">コンペに提出されたサブミットすべてに対して, private/public のスコア計算を実行します。<br> [NOTE] コンペが終了していない時実行できません.</div> </div> </div> {% endblock %}
あとはこのファイルを admin site の change_form_template
に加えるだけです。
@admin.register(models.Competition) class CompetitionAdmin(admin.ModelAdmin): # 中略 change_form_template = 'admin/competition/change_form.html' def response_change(self, request, obj): if '_calculate_score' not in request.POST: return super(CompetitionAdmin, self).response_change(request, obj) if not obj.is_finished: self.message_user(request, message='コンペが終了していません. 計算を行えるのはコンペ終了後です. ', level='warning') return HttpResponseRedirect('.') self.message_user(request, message=f'{obj.title} の submission に対するスコア計算を開始しました. (合計{obj.total_submissions}件)') t = threading.Thread(target=run_score_job, args=(obj,)) t.setDaemon(True) t.start() return self._response_post_save(request, obj)
これでボタンが押されたタイミングで admin class に指定された method (今の場合は _calculate_score
) が動くようになります。
self.message_user
を使うとオシャレに popup でエラーなどのメッセージも出せてとても便利でした。
discussion を notebookから作成
これは途中で discussion 作った後に「notebookから直接作れたら便利じゃない?」という声をうけて作ったものです。
めんどくさそうですが nbconvert
を使うと案外楽です(画像以外は)。
今回の要件として最終的な format が markdown だったので MarkdownExporter
を使っていますが他のフォーマットでも同様に作れると思います。
若干面倒なのは markdown 生成後の画像ファイルがもとのままだとテキストとして代入されているところでしょうか。
今回のDRFではstaticfileは別途保存するようにしていたので ContentFile
で画像をobject化して別のモデル UploadFile
を使って s3 などに upload するという作業を入れています。
from django.core.files.base import ContentFile from nbconvert.exporters import MarkdownExporter from rest_framework.decorators import api_view, parser_classes, permission_classes from rest_framework.parsers import FormParser, MultiPartParser from rest_framework.permissions import IsAuthenticated from rest_framework.views import Response, status from main.discussions.models import UploadFile @api_view(['POST']) @parser_classes([MultiPartParser, FormParser]) @permission_classes([IsAuthenticated]) def notebook_to_other_format(request): """ Jupyter Notebook 形式のファイルから markwon を作成します。 このエンドポイントにはログインしているユーザーでないとアクセスすることはできません。 ### Parameter * `file`: 変換したい File Object です. 必ず拡張子は `.ipynb` である必要があります """ try: file_obj = request.data.get('file', None) if file_obj is None: raise ValueError('file is required') name, ext = file_obj.name.split('.') if ext != 'ipynb': raise ValueError('File Extension Must Be `ipynb`') txt, metadata = MarkdownExporter().from_file(file_obj) outputs = metadata['outputs'] metadata['filename'] = name for key, img in outputs.items(): upload_file = UploadFile(name=key) content = ContentFile(img) upload_file.file.save(key, content) upload_file.save() txt = txt.replace(key, upload_file.file.url) return Response({ 'body': txt, 'metadata': metadata }) except Exception as e: return Response(data=dict(error=str(e)), status=status.HTTP_400_BAD_REQUEST)
通知機能
自分が参加しているコンペに新しい discussion ができたら通知してほしいなーって思ったのでつけた機能です。
主に django-notification
の機能を使っています。基本的に通知をしたいイベントをトリガにして notify.send
を呼び出すだけでOKです。
以下の場合だと Discussion (ユーザーがいろんな議論をするディスカッションを管理するためのモデル) が作成された時に、コンペに参加しているユーザー全員に対して通知を送るようにしています。 複数ユーザーへの対応も notification 側でやってくれるのでかなり便利です。
from notifications.signals import notify @receiver(post_save, sender=Discussion) def notify_discussion_create_handler(sender, instance: Discussion, created, **kwargs): """ ディスカッションが作成されたことを通知する handler """ if not created or not instance.competition: return # ディスカッション作成者以外で, コンペに参加しているユーザーに対して通知する target_users = User.objects.filter(Q(teams_as_owner__competition=instance.competition) | Q(teams_as_member__competition=instance.competition) ).exclude(pk=instance.created_by.pk).distinct() notify.send(instance.created_by, recipient=target_users, verb='create', description=f'新しい discussion {instance.short_title} が追加されました', target=instance)
あとは notification の model に対してアクセスする endpoint と既読や削除をするような endpoint を作ればOKです。
from notifications.models import Notification from rest_framework.decorators import action from rest_framework.pagination import PageNumberPagination from rest_framework.permissions import IsAuthenticated from rest_framework.response import Response from rest_framework.schemas import ManualSchema from rest_framework.viewsets import ReadOnlyModelViewSet from .serializers import NotificationSerializer class NotificationViewSet(ReadOnlyModelViewSet): serializer_class = NotificationSerializer queryset = Notification.objects.order_by('-timestamp') \ .select_related('actor_content_type', 'target_content_type', 'action_object_content_type').all() permission_classes = (IsAuthenticated,) def get_queryset(self): return self.queryset.filter(recipient=self.request.user) @action(methods=['PUT'], detail=False, url_path='mark-all-as-read', schema=ManualSchema(fields=[], description='ユーザーに紐づくすべての通知を既読に更新します')) def all_as_read(self, *args, **kwargs): qs = self.get_queryset() not_read = qs.filter(unread=True).count() qs.mark_all_as_read() return Response(data={'count': not_read}) @action(methods=['PUT'], detail=False, url_path='mark-all-as-delete', schema=ManualSchema(fields=[], description='ユーザーに紐づくすべての通知を既読に更新します')) def all_as_delete(self, *args, **kwargs): qs = self.get_queryset() n_active = qs.filter(deleted=False).count() qs.mark_all_as_deleted() return Response(data={'count': n_active})
フロントとつなげるとこんな感じに通知を出すことが出来ます。ちゃんと object に「どこから通知がきたのか」が保存されていますのでフロント側ですこし処理を入れればいいねされたディスカッションのページに移動することが出来ます。
これ以外にも以下のようなイベントで通知を行うようにしています。
- 自分が作ったディスカッションにコメントがついた時
- 自分が作ったコメントにコメントがついた or いいねがついた時
- チームマージのリクエストが来た時
これで少しでも見逃しに気づいてくれることがふえたらいいなーと思っています。(コンペやってる時は時間に余裕がないので、あえて煩いぐらいに出そうと思っています)
LB周りのテスト
ちょっとめんどくさいのが submission や Leader Board が絡むロジック部分のテストです。 たとえば RMSE でスコアリングするようなコンペがあって、このコンペでの順位計算がちゃんとできていることをテストしたい! と思ったとします。
この時やらないといけないのはざっとあげても以下の要件があります
- private/public のスコア計算がちゃんと出来るか
- private/public のランキングが正しいか
- 小さい順にちゃんとなっているか
- 1回もsubmitしていない人(private scoreがNullの人)が順位がつかないようになっているか
- private スコアの計算対象が選択した submit になっているか
- 同じスコアの人が居た場合 submit が早い人が順位が高くなるか
- team merge した時に submit のひも付けもちゃんと merge してその時の public score も過去最高のものが選ばれるか
- マージしたのにスコアがマージされないと困る
- LateSubmission がランキングに影響しないか
- あとで submission した値で更新されると困る
そして順序などは RMSE のように小さければ良いものと AUC のように大きければ良い物の2つがあるため、これらは別々のテストケースとしてテストする必要があります。 これを実現するために
- 特定の metric で特定の public/private スコアになるような submission を作成する Fixture クラスを作成
- そのクラスに対して metric の値通りに計算ができているかどうかのテストケースを個別に作成
- 上記の Fixture クラスを使ってランキング計算等のエンドポイントのテストを作成
という2段階のテストをすることにしました。正直これをやっている時が一番大変だったかもしれません、がミスっていると一番困るところでもあるので致し方なし…
それ以外にも
- 自分の submission 以外にはアクセスできないこと
- マージリクエストを許可できるのはオーナー権限だけ
- privateLB はコンペ終了かつ管理者がOKしていないとアクセスできないこと
みたいな基本的な権限周りのテストも入れています。
使っているライブラリ
主に pytest
と parameterized
というテスト時にパラメータを指定できるライブラリを使って実現しています。
また late submission のように時間がからむテストでは未来の状態を作る必要があります(たとえば今日の18:00終了のコンペに対して明日の16:00にアクセスしたらOKみたいなこと)ので、時間をそのタイミングだけ変更することが出来るライブラリ freezegun
を使っています。
# freezegan を使っているテストの一例. with freeze_time のあたり. def test_not_allow_change_selected_submit_after_competition_is_finished(self): """コンペ終了後に submit file の選択を変えられないこと""" # 10回 submit する submit_ids = [] for _ in range(10): response = self._submit_to_my_team() sub_id = response.json()['id'] submit_ids.append(sub_id) # 一番最後の submit を選択できる res = self.client.patch(self.url_from_pk(submit_ids[-1]), data=dict(selected=True)) self.assertEqual(res.status_code, 200) self.assertIs(res.json()['selected'], True) # コンペ終了のちょっと前 (1時間前) なら選択を変えられる with freeze_time(self.competition.finished_at - timedelta(hours=1)): res = self.client.patch(self.url_from_pk(submit_ids[-2]), data=dict(selected=True)) self.assertEqual(res.status_code, 200) self.assertIs(res.json()['selected'], True) # コンペ終了後には選択 submit を変えられない with freeze_time(self.competition.finished_at + timedelta(hours=2)): res = self.client.patch(self.url_from_pk(submit_ids[-3]), data=dict(selected=True)) self.assertEqual(res.status_code, 400, res.json())
まとめ
早足気味ですがぐるぐるの裏側で使っているDRFの良い所と、実際に実装したときのお気持ちやポイント的なものを書いてみました。 データコンペサイトは普通の web-application に比べて考えることが結構あるので大変ですが、頭を使うのでとても楽しかったなというのが振り返ってみての印象です。やはりいつも使っているサービスを実装するぞ、っていう意気込みもあるのかも知れないですね。
これを見て DRF 面白そう!使ってみようかなー!と思っていただければ幸いです。:D
Kaggle Days Tokyo のオンサイトコンペに参加しました! #kaggledaystokyo
Kaggle Days Tokyo で開催されたオンサイトコンペに参加してきました!! 結果としては全体 88 チーム中 private で 56 位という悔しさの残る結果になりました。が同時に反省点と学びもとても多い素晴らしいコンペだったので、感想兼反省文を書いていこうと思います。
Kaggle Days Tokyo は2日間ありましたが、僕は体調不良のため2日目しか参加できませんでしたので1日目のことに関しては他の方の記事を参照していただければと思います。
どんなコンペだったか
日経電子版のログデータをもとにして、閲覧しているユーザーの年齢を当てるというタスクです。メインのテーブルには
- ユーザーid
- 見ているデバイス情報
- 記事のid
などがありこの記事 id が記事データと紐づくような構造になっています。記事データには
- タイトル (1/2/3)
- キーワード
- 本文
- 記事の発行された日時
などがありました。
基本的な戦略
僕は弊社でインターンをしてくれている もーぐりくん と一緒にチームで参加しました。 僕はオレオレのフレームワーク vivid があり初速がある程度出せるだろうという見積もりがあったので、僕が基本的なモデリングを行い feature importance など特徴量の重要な話などは共有しつつ、最後にマージしましょうという作戦を取りました。
NYK510 TimeLine
ML Bearさん が当日の timeline を書かれていて、後で反省するのに良いと思ったので僕も真似します。 正直かなり切羽詰まっていたのでほぼ覚えていないのですが git の message とともに振り返っていこうと思います。
- 10:35 initial commit
- atma-cup #2 のリポジトリをコピペしてスタートしました。
- 11:22
[update] version-1 feature
- データの構造と submission すべき情報を見て submission までの雛形を作成していました。このときにつかっているのはメインのテーブルだけでした。
- 11:50 [update] first submit
- 最初のサブミット. 単体モデルでは
Objective="poisson"
の LightGBM が一番良かったためそれで submit しました. この日で一番余裕があったタイミングだったと思います。 - 一旦 10 モデルぐらい作成して ridge 回帰も作っていたのですがそれぞれのモデルのチューニングが適当すぎたのか CV と LB の差が激しく撃沈
- 最初のサブミット. 単体モデルでは
- 12:01: FastText のモデル作成完了
- 13:10: [update] add pseudo labeling
- pseudo labeling をやってみたかったので実装してみていました。盛大にバグって精度が出ませんでした。このあたりから焦り始めます。
- ~ 14:30 [fix] bug
- NN を無理やり入れようと苦戦して bug と戦う羽目になりました。また optuna での tuning を回しながら特徴を作ろう、であったり記事のテキストを整形しよう、みたいな欲張りをしてそちらでも bug と戦っていました。この間の進捗は虚無でした。
- このあたりで一回もーぐりくんとMTGをして正気を若干取り戻し、記事情報を少し入れました。またもーぐりくんアイディアの特徴も入れてスコアは良くなりました。
- 15:17: [update] swem enbedding
- FastText で作った特徴量で記事情報の SWEM をしました。この時点でかなり上位陣と差をつけられていて差分がわからず混乱に陥っていました。(後で判明しましたがこのときロジックのミスでちゃんと本文情報を使えていなかったみたいです)
- 16:47 [update] genre
- 記事のジャンル情報を入れてみて若干のスコア改善。もすぐに抜かされまくるという状況。
- ~ 18:30 気づくとコンペはおわっていた
ふりかえり
悪かったこと
チームのメリットを活かせなかった
最初はまだしも、途中からは自分のことでいっぱいになり、全くチームとして機能していませんでした。
基本的に今回僕は独創的な特徴量は考えられませんでしたが、すこしの議論しかできませんでした。 一方でもーぐりくんはユーザーがどういう気持で閲覧するかなどを考慮した集計方法などを提案してくれていて、非常に頼もしかったです。 そのアイディアのパワーを生かしきれなかったのは反省です。
実装の正確性
今回は家のPCに ssh 接続して分析を行っていたのですが noetbook のバグ? かなにかで jupyter 上のブラウザからささっとファイル閲覧ができなくなり、最終的な output の特徴量をエクセルで確認するような作業をふっとばしていました。 これによって、よくある「概ねはあっているが、1行違うために生成される特徴量がおかしい」というパターンに陥って、性能がでないという自体になっていました。
次の日コードを見て気づいて修正すると Private で 11.70366
→ 11.41330
(13位相当) になりました。悔しい。
知っている内容を使えなかったこと
今回の solution ではほとんどすべてのチームが TargetEncoding を使ってカテゴリの埋め込みを行っていて、どのチームもとても良く効いたとのことでした。 僕はリークを恐れて最後まで「TargetEncodingを使う」ことが頭の中の選択の一つにも上がっておらず、ここで大きな差をつけられてしまった感があります。
ではこの情報に僕が到達できなかったかというと否で、なんなら atmaCup#2 でも target encoding の使い方が肝でしたし分析コンペLT会のハクビシンさんの発表も、なんなら前日の Jack さんの発表でも TargetEncoding 大事だよーということは言われていて twitter などでも耳にしていましたから「TargetEncodingの実力を過小評価していた」自分の落ち度です。
hakubishin さんのスライド: Target Encoding はなぜ有効なのか
Jack さんのスライド: How to encode categorical features for GBDT
ちなみにメインテーブルのカテゴリ変数に対して TargetEncoding を投入すると Private で 11.70366
→ 11.25799
(3位相当) になりました。悔しいなあ。
良かったこと
もーぐりくんがチームメイトだったこと
多分僕が二人のチームだったらもっと悲惨でした。ありがとう。
vivid で完走できたこと
今回のコンペで与えられていたテーブルは予測する際に一度ユーザーで集計をする必要があるデータでした。 そのようなデータは vivid を作ったときには全く想定していなかったのですが特徴変換の部分とそれを pipeline 的に行う部分を分離して実装していたおかげで対応できたので構造化がある程度できていたのかなと思っています。 またある程度の初速を出せたのは(あとでめちゃ抜かされてはいますが)うりではあるのでそれも良かった点でしょうか。
使っているうちに出てきたダメポイントについても、体を張ったテストだと思って、全部 issue にしてより強くしていきたいと思います。
たくさんの Kaggler と一緒にオンサイトコンペに参加できたこと
これは何者にも変えられない体験でした。特に夕方のある程度形勢が定まってきたような段階でも皆黙々と作業に打ち込んでいる様子は流石だなあと思っていました。 同じ場に参加できたこと、大変嬉しく思います。
またコンペのタスク・データについても非常に噛みごたえのある面白いデータでした。提供いただいた日経さん、また設計を主に担当されたであろう u++ さんに感謝です。
学び
精神的なこと
いつもは開催者側にいるのでなんとなーく大変なことはわかっていましたが、時間制限付きのコンペの大変さは僕の想像を遥かに上回っていました。*1
もとから知っていはいましたが僕がとても予想外・いつもと違う状況に弱いことを改めて知ることができました(悲しいことにほとんど覚えていないのですが夕方からコンペ終了までの僕はかなり狼狽していたと思います)。 そもそも、そういう状況になることを思って行動していない & 腹をくくって思考を切り替えられないのは分析コンペだけではなく必要な能力だと思います。正直ちょっとどうやって鍛えたら良いのかがわかっていませんが、精進します。
追試験
勉強のため上記で触れたような WordEmbedding・Pseudo Labeling*2 まわりなどのバグの修正 & TargetEncoding の追加等の修正など行い LateSubmission をしてみました。
マシンパワーの関係で LightGBM の SingleModel しか試せていませんが 一番良かったのは Poisson の LightGBM で Private 11.17052
(2位相当) となりました(SeedAveraging や Stacking など行えばまだ上がるかもしれないです)。
五月雨ですがやったことで効いた・効かなかったはざっくり以下のような感じになりました。
- 効いたこと
- target encoding
- pseudo labeling
- やるごとにじわっと良くなる (private で 0.02程度?)イメージでした。3回ぐらいまでは有効でした。(それ以上はサチって止まってしまう感じ)
- SWEM (simple word embedding) での記事タイトルや本文、タイトル、キーワードの埋め込み
- Objective=RMSE 以外で解く
- 自分の範囲内では Poisson が一番良かったです
- いま Log 変換を試していないことに気が付きました。後でやる。
- ユーザーと記事との関係から記事の embedding の作成
- seed averaging
- キーワード全体での count で置き換えて集計 (mean/sum/std)
- ユーザーごとのアクセス時間のヒストグラム特徴量 (もーぐりくん案)
- (自分の範囲内では)効かなかったこと
- one-hot 化して多クラス分類として解いた後に期待値に直す
- 画像からユーザーの年齢を当てるタスクの論文で上記の方法のほうが regression よりも有効だった旨が書いてあったのを思ってやってみましたが余り効果ありませんでした。*3
- そもそも学習させたモデルが LightGBM だったので駄目だったのかも知れません
- one-hot 化して多クラス分類として解いた後に期待値に直す
- 効果が微妙だったこと
- カテゴリ変数の one-hot encoding
- 水準数がとても多いカテゴリも存在していたので Target Encoding のほうが効率もよくあえて one-hot する必要はなかったかなと思っています。
- カテゴリ変数の one-hot encoding
コードは以下においてありますのでもしよろしければ。vivid 使ってます 😌
やっていて、当日の8時間の中でこれをやってしまう上位のチームのパワーに圧倒されるばかりでした。しかも優勝チームの pocket さんに話を伺ったとき、全部スクラッチで書いてーという話をされていて、圧倒的な自力の差を感じました。
最後に
このような機会を頂いたKaggle Days Tokyoの運営の方々, データ提供頂いた日経様, スポンサード頂いている会社様, ありがとうございました。 次回以降も機会あれば是非参加したいです!
参考文献
*1:あらためてatmaCupもそうですがオンサイトで勝つ方のタフさを感じました
*2:限定的ですが Pseudo Labeling は効果がありました
*3:ちょっと読んだのが昔なのでうろ覚えですがたぶん DAGER: Deep Age, Gender and Emotion Recognition Using Convolutional Neural Network この論文
RMSE を Fold ごとに取ると全体の値より小さくなる証明
この記事を書く前に twitter でお話をしている流れで、まますさんに的確な証明を頂くことができました! 証明にはこちら RMSE.pdf - Google ドライブ からアクセスできます。(まますさんありがとうございましたmm)
そもそも
この記事のお題は RMSE を Fold ごとに取ると全体の値より小さくなる証明をやります
ということです。
これをやろうと思ったきっかけは #かぐるーど での kaggle本の本読みです。 前回は第5章だったのですが、その5.2.2で次のような記述があります。
クロスバリデーションでモデルの汎化性能を評価する際は、通常は各foldにおけるスコアを平均して行いますが、それぞれのfoldの目的変数と予測値を集めてデータ全体で計算する方法もあります。なお、評価指標によっては各foldのスコアの平均と、データ全体で目的変数と予測値から計算したスコアが一致しません。例えば、MAEやloglossではそれらが一致しますが、RMSEでは各foldのスコアの平均はデータ全体で計 算するより低くなります。
要するに K-Fold それぞれの rmse 平均値と、データセット全体での rmse の値だとデータセット全体のほうが大きい (K Fold のほうが良いように見積もられてしまう) という話です。たしかに Root をとる操作を毎回やるのと、全体で合わせた後やるのだと前者のほうが小さい値になような感じはしますよね。
これ一般的に示せるかなーという議論があり、僕が「関数の凸性とイェンゼンの不等式でいけますよ」と言ったところじゃあやってほしい!と言われたのが当エントリの経緯になります。
せっかくなので簡単にですが凸関数とイェンゼンの不等式にも触れつつ、お話できればと思っています。
NOTE: 若干細かい定義域についてやイェンゼンの不等式の導出についてなどは省略していますので、それらは別途文献など見ていただけば幸いです。
凸関数とは
下準備として、凸関数というのを定義します。凸関数というのは色々な定義がありますが、以下を満たすような関数 $f: \mathbb{R} \to (-\infty,+\infty]$ のことです
$$ f(t x_1 + (1 - t)x_2) \leq t f(x_1) + (1-t) f(x_2) $$
ただし $x_1, x_2$ は任意の実数 $\mathbb{R}$ の点で $t$ には 0以上1以下の制約がついています。
要するに $x_1$ と $x_2$ の内分点での $f$ の値と最初に $f$ で計算してしまってから $f(x_1)$ と $f(x_2)$ の内分を取るのとだと、後者のほうが大きいような関数、って言うことです。
また $f$ が微分可能な場合 $f$ の二階微分 $f'' \geq 0$ であることと上記の等式は同値になります。
イェンゼンの不等式
これをちょっと発展させて内分点の部分を2つ以上の点に拡張したのがイェンゼンの不等式です。
イェンゼンの不等式は上記の式と同じく特定の関数 $f$ が凸関数である必要十分条件を表した式で, $f$が凸ならば任意の自然数 $n$ と$\sum_{i=1}^n p_i = 1, p_i \geq 0$ を満たすような $p_i$ に対して次の式
$$ f(p_1 x_1 + p_2 x_2 \cdots + p_n x_n) \leq p_1 f(x_1) + p_2 f(x_2) \cdots + p_n f(x_n) $$
がなりたつ、という定理です。
たとえば$n=2$の時を考えてもらうと先ほどの凸関数の定義そのままであることはすぐわかると思いますので、凸関数の定義を変数 $n$ 個の場合に拡張したようなイメージです。
RMSE を考える
RMSE とは入力とラベルの誤差の2乗和を $M$ とした時に
$$ {\rm RMSE}(M) = M^{\frac{1}{2}} $$
で計算される値です。これの二階微分を考えると
$$ {\rm RMSE}''(x) = - \frac{1}{4} M^{- \frac{3}{2}} < 0 $$
です。すなわちRMSEの二階微分は常に負の値となります。これは凸関数と全く正反対の性質で一般に凹関数 (concave) と呼ばれ先のイェンゼンの不等式とちょうど不等号が反対の不等式が成立します。
K-Foldしたときの RMSE
今Fold を$ K$ 個に分割して、それぞれが $n_k$ 個のデータを持っているとします。(データセット全体では $N$ 個とします。) この時各 Fold での MSE (Mean Squared Error) を $M_k$ とすると Fold ごとのデータの数で重みづけた ${\rm RMSE}_{\rm fold}$ は
$$ {\rm RMSE}_{\rm fold} = \sum_{k=1}^K \frac{n_k}{N} \sqrt{M_k} $$
となります。一方で通常の RMSE に関しては
$$ {\rm RMSE} = \sqrt{\frac{1}{N} \sum_{k=1}^K n_k M_k} = \sqrt{\sum_{k=1}^K \frac{n_k}{N} M_k} $$
となります。ここで $M_k$ に $n_k$ をかけているのは $M_k$ が既に Mean Squared Error なので要素の数を掛けて和にになおして全体の $N$ で割算をするためです。
ここで $p_k = n_k / N$, $f(x) = \sqrt{x}$ と考えると $f$ は凹関数でかつ $\sum p_k = 1$ ですのでイェンゼンの不等式が用いることが出来て
$$ {\rm RMSE} = \sqrt{\sum_{k=1}^K \frac{n_k}{N} M_k} \geq \sum_{k=1}^K \frac{n_k}{N} \sqrt{M_k} = {\rm RMSE}_{\rm fold} $$
が成立します。即ち fold ごとで RMSE を計算して重み付きの平均を取った値のほうが、データセット全体での RMSE の値より小さくなることがわかりました。
RMSE 以外でも…
上記の証明を追っていただくと分かるようにこの証明はロス関数の値がデータごとに計算できること、及びそれをデータセット全体の平均したあとに凹関数に代入する、という構造が保たれている限り同様の議論をすることが可能です。
ですので Log を取ってから Root を取る RMSLE (Root Mean Squared Log Error) なども同様の議論が可能です。
参考文献
以下は本記事を書くにあたって使用した凸関数に関する話題や凸最適化に関する日本語の参考文献です。
- wikipedia: 凸関数: 凸関数 - Wikipedia
- 非線形計画法: 山下 信雄
- 非線形最適化の基礎: 福島 雅夫
- 工学基礎 最適化とその応用: 矢部 博
この記事は kaggle その2 advent calendar 2019 の記事です。