nykergoto’s blog

機械学習とpythonをメインに

画像の超解像度化をするモデル SRCNN を pytorch で実装してみた

画像の超解像度化というタスクがあります。 やることは低解像度(小さい画像)を高解像な画像に拡大するときにできるだけ綺麗に引き延ばす、というタスクです。

https://www.slideshare.net/HHiroto/deep-learning-106529202 ではDeep Learning の登場によって超解像度化技術がどのように進歩していったのかに関してとてもわかり易くまとまっていて良いです。

その中でも Deep Learning の活用が始まりだした 2015 に登場した SRCNN というモデルを実装してみました。

SRCNN の構造

SRCNN のアーキテクチャは3層の Convolution + Relu で構成されています。pytorch だと以下のような感じでとてもシンプルです。

from torch import nn


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=9, padding=4)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=1, padding=0)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=3, kernel_size=5, padding=2)
        self.activate = nn.ReLU()

    def forward(self, x):
        h = self.activate(self.conv1(x))
        h = self.activate(self.conv2(h))
        return self.conv3(h)

ちょっと注意なのは、このネットワークでは入力される画像サイズと出力される画像サイズは同じだということです(このネットワークでは拡大はしません)。 入力されるのは低解像度画像を拡大した荒い画像で、それを同サイズの綺麗な画像にして返すというふうになっています。

学習時には、綺麗な画像と荒い画像のペアのデータが必要になります。
元の論文では、元の画像に対してガウシアンぼかしを適用して荒い画像を生成していました。 なんですが色々と実装を見ていると、画像を半分など小さく縮小して古典手法で拡大する、という方法をとっているものが大半だったので実装ではこちらを採用しています。

どれだけ綺麗に戻せたかのロス関数にはピクセル間の MSE を用います。

実装

Github にあげています。詳しくは readme に書いていますが docker/docker-compose が入っていれば動かせる、はず

github.com

実験

実際に上記の実装を使って学習を回して、超解像度をやってみました。

データセット

学習につかうデータセットはSRCNN 論文に合わせて Anchored Neighborhood Regression for Fast Example-Based Super-Resolution で使われている 91 枚の画像をつかいました。実際にはこれらの画像を 64x64 の大きさにランダムに切り出しています。

検証には Set5 と呼ばれている画像をつかいました。いろんな論文に登場するので、この分野ではよく使われるデータセットみたいですね。

実験

学習したモデルを使って実際に超解像度化を行います。左が元画像を半分に縮小して人工的に作った低解像度画像、真ん中が SRCNN で綺麗にした画像、右が元の画像です。低解像度画像は x0.5 にリサイズした後 BICUBIC で拡大して作成しました。

f:id:dette:20190518105510p:plain
Bird

f:id:dette:20190518105541p:plain
Baby

これだとちょっとわかりにくいので画像の一部分を拡大してみます。

f:id:dette:20190518105629p:plain
Baby 若干まつげがくっきりしてる?

感想

モデルがとても単純なわりに、しっかり綺麗になるのでやはり CNN の構造はとても良い性質を持っているんだなと改めて思いました。

モデルが小さいので推論だけならば CPU で十分動くので、既存のサーバーとかにも組み込みやすいように思えます。計算量が軽いことを活かして、フロントエンド側で画像を勝手に綺麗にして表示するとかも面白そうです。

あと実装的な観点でいうと torchvision の transform がとても便利で、画像の拡大縮小のコードがとってもすっきりかけて気分が良かったです😄

気になった点としてはネットワークに BatchNormalization がないので若干学習が不安定になりがちな点ですが、これは BatchNormalization が出来たのが 2015 年なので当たり前でした。Deep 界隈の進歩の速さを象徴しているなーとあらためて思いました。

つぎやりたいこと

  • 今僕は 2015 年にはこれたので, もうちょっと未来のアーキテクチャを実装して差分をみてみたい

参考文献