AnyTech Engineer Blog

AnyTech Engineer Blogは、AnyTechのエンジニアたちによる調査や成果、Tipsなどを公開するブログです。

LSTMがVision Transformerに戦いを挑んだら?

LSTMがVision Transformerに戦いを挑んだら? こんにちは、AnyTechの立浪です。

このたび、私の主著論文が「NeurIPS 2022」に採択されましたので、今回はこちらの論文について解説させていただきます。私が博士後期の学生として所属している立教大学と弊社でプレスリリースを出しております。立教大学のプレスリリースも是非ご覧ください。

論文は以下のURLから確認可能です。

[2205.01972] Sequencer: Deep LSTM for Image Classification

Sequencer: Deep LSTM for Image Classification

導入

昨今、Computer Visionの業界においては、Vision Transformer[ICLR2020]に始まり、多くのタスクにTransformerが進出しているような状況ですよね。このような流れの中で、MLP-Mixer[NeurIPS2021]CycleMLP[ICLR2022]のようにTransformerのSelf-Attentionの部分を、Self-Attention以外の何かに置き換えるアーキテクチャが提案されています。ConvNeXt[CVPR2022]RepLKNet[CVPR2022]のように、Self-Attentionの部分をカーネルサイズを大きめのDepthwise Convolutionに置き換え、CNNに戻る流れもあります。これらの流れは、いずれもMetaFormerという共通のフレームワークに従っていることが知られています。今回私達が提案したアーキテクチャも、MetaFormerを踏襲し、Self-AttentionをLSTMベースのモジュールに置き換えました。このアーキテクチャをSequencerと名付けています。Sequencerは、画像分類において、Swin Transformer[ICCV2021]やConvNeXtのような強力な手法と張り合える精度が得られます。

初学の読者の中で、Self-AttentionやTransformer、Vision Transformer、MLP-Mixer、MetaFormerについて、初めて聞いたり、詳しくない方もいらっしゃるかもしれません。これらにつきましては、Qiitaなどに詳しい記事ございますのでそちらに委ねたいと思います。

Sequencerのアーキテクチャ

全体は以下のように非常にシンプルで、よくあるアーキテクチャです。原著から拝借しました。

アーキテクチャ


入力画像

Vision Transformer同様のPatch Embeddingを行いますが、Flattenはしません。各パッチの縦横の意味がある状態のままにしておきます。つまり、カーネルサイズとストライドが等しいConvolutionと等価ですね。Sequencerではカーネルサイズ: 7\times7、ストライド (7, 7)を使用しています。

イメージ


Sequencerブロック

Sequencerブロックの処理と次のパッチマージの処理で、1つのステージを構成します。Sequencerは4つのステージから成ります。マルチヘッドAttention以外はVision TransformerのTransformerブロックと同じです。Sequencerブロックは、TransformerブロックのマルチヘッドAttention(左図)をBiLSTM2Dと名付けたレイヤー(右図)で置き換えたブロックです。各ステージごとに、複数個のSequencerブロックがスタックされています。

TransformerブロックSequencerブロック


What is BiLSTM2D?

BiLSTM2Dは横方向の双方向LSTMと縦方向の双方向LSTMが並列に配置されております。これらのLSTMの隠れ状態(上下左右)を連結して、線形層で融合します。 式を使って示します。まず、縦方向用の双方向LSTMを入力します。

 \displaystyle
\mathbf{H}^{\rm ver}_{:, w, :} = {\rm BiLSTM}(\mathbf{X}_{:, w, :}).

ここで、縦方向のシーケンスは、横方向に配置されているパッチの数だけあることに注意してください。各シーケンスに独立の双方向LSTMを用意するのではなく、重み共有をしています。並行して、横方向用の双方向LSTMにもテンソルを入力します。

 \displaystyle
\mathbf{H}^{\rm hor}_{h, :, :} = {\rm BiLSTM}(\mathbf{X}_{h, :, :}).

そうすると、縦方向用の双方向LSTMと横方向用の双方向LSTMから、隠れ状態が得られるので、これらを連結します。

 \displaystyle
\mathbf{H} = {\tt concatenate}(\mathbf{H}^{\rm ver}, \mathbf{H}^{\rm hor}).

最後に、結合した状態をPointWiseな全結合層(つまり  1\times 1 Convolution)に入力します。

 \displaystyle
\hat{\mathbf{X}} = {\rm PointWise FC}(\mathbf{H}).

この操作の意味は、2つの双方向LSTMの出力の融合や、元のテンソル \mathbf{X}に次元を合わせる意味もあります。

これらを図で示しますと、こんな感じです。

BiLSTM2D


パッチマージ

さて、Sequencerは複数のSequencerブロックから構成されるステージが4つあります。最初のステージの最後はパッチマージ、第2、第3ステージの最後はPointWiseな全結合層が加えられています。 パッチマージの処理は、入力画像と同様の処理です。第1ステージから第2ステージに移行する際はカーネルサイズ 2\times2、ストライド (2, 2)のConvolutionで、それ以外はPointWiseな全結合層です。

分類器

Layer Normalizationの後、線形分類器がついています。一般的な画像分類と同様ですので、ここで特筆すべきことはありません。

Sequencerの実験結果

すべてを取り上げるわけにもいけないので、事前学習・ファインチューニング・セマンティックについて取り上げます。

事前学習

Sequencerの事前学習では、ILSVRC-2012 ImageNet、いわゆるImageNet-1kを使用しています。約120万枚程度の自然画像から成る、1000クラスに分類されているデータセットです最近はImageNet-21kを事前学習に使っている論文も多いですよね。ImageNet-21kで満足に訓練できるくらいのお金ください

モデルサイズは3種類用意しました。

  • Sequencer2D-S
  • Sequencer2D-M
  • Sequencer2D-L
分類結果表


以上が結果ですが、同程度のパラメータであれば、Swin TransformerやConvNeXtなどの精度を上回っています。一方で、スループットは悪いです。これはリカレントニューラルネットワークを使っているため仕方ないことでもあります。NeurIPSの査読においてもこの点は多数指摘されました。当初、独自のRNN層を使うことも考えたのですが、Python+PyTorchの実装ではスループットに悪影響を及ぼしてしまうため、今回はスコープ外としました。今回はRNN+MetaFormerについての初回の試みということ、Sequencerが気軽に使いやすいモデルであるために、あえてPyTorchのLSTMだけで構成できるようにしました。

表の中にSequencer2D-L↑というモデルもあります。これは、事前学習されたSequencer2D-Lの重みを使い、より解像度が高い画像でファインチューニングしたモデルの結果です。この場合は、84.6%のTop 1精度が出せています。

アブレーションスタディも豊富に実施しています。この記事ではその詳細を割愛しますが、以下のような疑問に対して、実験的に解を与えています。

  • 双方向LSTMを縦と横に分ける必要があるのか?
  • LSTMが双方向であることに意味があるのか?
  • LSTMの隠れ次元はどのような影響を及ぼすのか?
  • GRUやSimple RNNだったらどうなるのか?

ファインチューニング

画像分類のデータセットに対するファインチューニングは、以下の4種類のデータセットに対して試しています。

  • CIFAR-10
  • CIFAR-100
  • Stanford Cars
  • Oxford Flowers-102
ファインチューニング


セマンティックセグメンテーション

ADE20Kデータセットを使用したセマンティックセグメンテーションの実験も行っています。この実験は査読を受けて追加した実験なので、一度論文を読んだ方でも知らない方がいらっしゃるかもしれません。精度はPoolFormer[CVPR2022]などよりも有意に高く、セマンティックセグメンテーションでは有効な手段と言えると思います。

セマンティックセグメンテーション


一方、物体検出には不向きのようです(原著・付録参照)

分析

分析の一部を紹介したいと思います。

解像度適応性

Seuqnecerは入力画像の解像度を、訓練したときの解像度から半分から2倍に変化させても、精度が落ちにくいという特性があります。以下の図では、水平軸を解像度、垂直軸を相対Top-1精度としています。なお、解像度 224\times 224の画像に対して推論を実施したときの精度を、相対Top-1精度の基準にします。この図からは、Sequencerが多くの解像度の場合、他の手法よりも精度が高いことが解ります。

解像度適応性


推論時のメモリ消費について

解像度を上げたときにメモリ消費がどう変化するのでしょうか?以下はそれを表した図ですが、これによると、解像度をあげていくと、DeiTどころかCNNであるConvNeXtよりもメモリ効率が良いですね。推論が遅いことが許容できる環境で、かつメモリが制限されている環境であれば、Seuqnecerも選択肢の一つになるかと思います。

推論時のメモリ消費


考察やまとめ

RNNで何故うまくいくのでしょうか?これはパッチの距離が近いほどパッチ間の関連性が高い傾向にあること、そして直線的なパッチの連続性がほどよくLSTMの長期記憶に頼れるようになっているからだと推察します。

これまでいろいろ書いてきましたが結局、この論文は、いままで十分に検討されてこなかったRNNを画像分類に使う可能性について検討し、そしてImageNet-1kのような規模のデータセットでその有効性を示したことに意義があります。実は、2015年ごろまではLeNetの畳み込みをLSTMで置き換えたReNetというアーキテクチャが研究されていたようです。しかしながら、この頃の研究は、SVHNなどの小さいデータセットでの研究であり、ImageNet-1kのような大量のデータで検証されてはいませんでしたし、MetaFormerのような残差接続構造も使われていませんでした。今回、Sequencerという温故知新的アーキテクチャを提案し、ImageNet-1kの事前学習を含む多くの実験で、最先端のアーキテクチャに対抗しうることを示しました。SequencerのようなRNNベースのアーキテクチャの高速化の試みが行われることや画像認識の幅広いタスクでSequencerのアイディアが使われていくことを願っています。

蛇足

  • Sequencerは公式実装もありますが、最新版のtimmではSequencerが使えます。
  • Sequencerという名前は、シーケンスモデリングのためのモジュールであるLSTMを使用したことに加え、Synthesizer[ICML 2021]というTransformerの派生アーキテクチャにインスパイアされて名付けました。

さいごに

さて、こちらのブログを始めて、最初の技術記事を投稿しました。この研究は主に大学の方で行った研究なので、AnyTechの売りである流体らしさが薄いかもしれません。ただ、AnyTechにも、Deep Learningのメインの領域に何か影響を与えようとしている人がいるんだと知っていただけたら幸いです。このAnyTech Engineer Blogでは、これからも定期的にコンピュータビジョンを始めとする技術情報を始めとする幅広い技術情報を発信していきます。是非ともフォローをお願いいたします!

また、AnyTechでは、流体のためのAI「DeepLiquid」の研究/開発/事業に携わってくれる方を募集しております。ご興味のある方はこちらからぜひお問い合わせください。