AnyTech Engineer Blog

AnyTech Engineer Blogは、AnyTechのエンジニアたちによる調査や成果、Tipsなどを公開するブログです。

ImageNetモデルを用いた異常検知手法の解説【第4回:PaDiM(a Patch Distribution Modeling Framework for Anomaly Detection and Localization)】

ImageNetモデルを用いた異常検知手法の解説【第4回:PaDiM(a Patch Distribution Modeling Framework for Anomaly Detection and Localization)】 こんにちは、AnyTechの木村と申します。
AnyTechにて、機械学習エンジニアやAIエンジニアといった役割にて、R&Dに日々従事しております。


この記事は、昨今流行しているImageNetモデルを応用した異常検知手法について、リサーチをする機会がありましたので、小職の備忘も兼ねまして、その解説をさせて頂くものです。



目次



シリーズ



はじめに

第3回に引き続き、近年流行しているImageNetモデルを応用した異常検知手法について、解説をさせて頂きます。
今回は、PaDiM(a Patch Distribution Modeling Framework for Anomaly Detection and Localization)についての解説となります。


PaDiMは個人的に、関連のアルゴリズムの中で最初に触れたものでして、思い入れが深いです。
当時の衝撃はなかなかのもので、「異常検知で、学習をしないの?」と、教えてくれた同僚に何度も聞いてしまいました。笑
「マハラノビス距離って…、何だっけ…?」とも聞いてしまいました。笑
また、個人的にそれまで敬遠気味であった異常検知というアプローチに対して、初めて手応えを感じたのも、PaDiMを通してでした。
後々詳しく知っていくと、ある種の割り切りも強いことに気付くのですが…、その辺りも含めて、この記事で解説をさせて頂こうと思います。



先ずはザックリ、アルゴリズム概要

PaDiMは、SPADEとは逆で、pixelレベルの異常位置セグメンテーションを前段に行います。
後段として、画像レベルの異常検知を行いますが、それはクエリ画像における異常位置セグメンテーション結果の最大スコアを採用して、完了する次第です。
その為、メインのアルゴリズムは、pixelレベルの異常位置セグメンテーションとなります。
また、最終層の特徴ベクトルは使用せず、中間層の特徴テンソルのみを使用します。

PaDiMのpixelレベルの異常位置セグメンテーションは、以下アルゴリズムにて実施されます。


  • 学習データ全件から、空間上の縦横の概念を持った中間層の特徴テンソルの抽出を行う
  • 中間層特徴テンソルの縦横サイズを (H, W) に合わせて、図のように縦に連結する
  • 縦に連結した次元を、ランダムに2/3程を間引き、1/3程の次元数に削減する
  • 中間層特徴テンソルの各pixel (i, j) 毎に、次元毎の平均(ベクトル)  \mu_{ij} と、分散共分散行列  \sum_{ij} とを計算する
  • クエリ画像から同様に「抽出/結合/次元削減」した中間層テンソルと、  \mu_{ij} と、  \sum_{ij} とを用いて、マハラノビス距離を計算する
  • 求めたマハラノビス距離のマップを、ガウシアンフィルタで滑らかにする

以上。

ここで、各異常スコアは相対的なものですので、評価データにて最適な閾値を探索し、それより大きいか小さいかで、pixelレベルで異常か正常かを判断する形となります。



次にじっくり、アルゴリズム詳細

ここからはじっくりと、PaDiMの詳細を紹介させて頂こうと思います。


PaDiMは、SPADEの後続になります。
或いは、Modeling the Distribution of Normal Data in Pre-Trained Deep Features for Anomaly Detectionの後続とも言えるでしょう。

Modeling the Distribution of Normal Data in Pre-Trained Deep Features for Anomaly Detectionをザックリ説明すると、ImageNetモデルの特徴と、異常検知に古くから用いられているマハラノビス距離という指標を用いて、画像レベルの異常検知を行うアルゴリズムです。
以下の記事が分かりやすかったので、参考までに紹介させて頂きます。
スクラッチ学習よりも高精度!?既存の画像分類モデルを利用して、異常検知モデルを作る方法を提案した論文を読み、今後のAI開発について考えました。 | 株式会社オープンストリーム

PaDiMは、上記2つのアルゴリズムを参考に発展されたアルゴリズムとなっていますが、一方で、ある種の割切りとも思える、アルゴリズムの単純化も施されていると思います。
その意味で、個人的には、SPADEのスピンオフというような印象です。


尚、大まかには、PaDiMはSPADEと比べて、以下の違いがあるかと思います。

  1. メモリ逼迫の課題については、学習時には依然として残るが、推論時にはキレイにクリアされている
  2. SPADEにはなかったpixel依存の問題があり、得意とする問題と苦手とする問題が、ハッキリと分かれる
  3. PaDiMでは、イメージレベルの異常検知と、pixelレベルの異常位置セグメンテーションとが、シームレスに行われる為、それら結果に一貫性がある


誤解を恐れずに、もう少し踏み込んだ個人的な見解を伝えさせて頂くと、PaDiMがMVTecで発揮している高い精度は、MVTec程に整備されたデータだからこそ、発揮されたものだと思われます。
後程、そう思う根拠となる実験結果を示しますが、左右上下のズレや、アングルの違い等がありますと、高い精度が発揮されません。
その為、PaDiMは、固定カメラ等、画角が一定である画像などに向いており、論文中でも産業検査を課題感として挙げています。

また、論文では、ShanghaiTech Campus (STC)データセットでの検証結果も載っていますが、こちらも固定カメラにて撮影されたものとなります。
Shanghaitech Vision and Intelligent Perception(SVIP) LAB

実際に運用を行う際にも、そういったシチュエーションの場合に、PaDiMを用いることをオススメします。


ちなみに、論文に記載されているMVTecでの精度と、paper with codeで紹介されているgithubリポジトリとの間には、乖離があります。

Image-level anomaly detection accuracy (ROCAUC)

Category PaDiM-Anomaly-Detection-Localization-master anomalib ind_knn_ad paper(※ind_knn_adの記載を引用)
Carpet 0.999 0.995 0.933 0.967
Grid 0.957 0.942 0.982 0.973
Leather 1.000 1.000 0.999 0.992
Tile 0.974 0.974 0.986 0.941
Wood 0.988 0.993 0.989 0.947
Bottle 0.998 0.999 0.998 0.983
Cable 0.922 0.878 0.933 0.967
Capsule 0.915 0.927 0.883 0.985
Hazelnut 0.933 0.964 0.837 0.982
Metal nut 0.992 0.989 0.994 0.972
Pill 0.944 0.939 0.890 0.957
Screw 0.844 0.845 0.830 0.985
Toothbrush 0.972 0.942 0.972 0.988
Transistor 0.978 0.976 0.968 0.975
Zipper 0.909 0.882 0.895 0.985


Pixel-level anomaly detection accuracy (ROCAUC)

Category PaDiM-Anomaly-Detection-Localization-master anomalib ind_knn_ad paper(※ind_knn_adの記載を引用)
Carpet 0.990 0.991 0.986 0.962
Grid 0.965 0.970 0.972 0.946
Leather 0.989 0.993 0.987 0.978
Tile 0.939 0.955 0.948 0.860
Wood 0.941 0.957 0.911 0.936
Bottle 0.982 0.985 0.978 0.948
Cable 0.968 0.970 0.961 0.888
Capsule 0.986 0.988 0.983 0.935
Hazelnut 0.979 0.985 0.975 0.926
Metal nut 0.971 0.982 0.965 0.856
Pill 0.961 0.966 0.932 0.927
Screw 0.983 0.988 0.978 0.944
Toothbrush 0.987 0.991 0.983 0.931
Transistor 0.975 0.976 0.972 0.845
Zipper 0.984 0.986 0.974 0.959


尚、論文中に、「ランダムローテーション(-10〜+10)と、ランダムクロップを適用する」というaugmentationに関する記載があり、それによって、pixel単位の異常位置セグメンテーション精度を犠牲に、画像単位の異常検知精度を引き上げているのではないかと思われます。
各リポジトリ間での精度差分は、コーディング上の細かい配慮の違いではないかと思われます。


さて、ここからは、PaDiMの具体的なアルゴリズム詳細を説明させて頂きます。
先ずは、アルゴリズムについての要約を記載します。


A. Embedding extraction

先ず、pre-trainedのCNNを使用し、パッチ埋め込みベクトルを取得する。
SPADEのそれと似ており、以下図のイメージ。

CNN特徴マップの位置 (i, j) に対する各画像パッチについて、PaDiMは学習用に、 N 個の異なる画像から、3層からの特徴ベクトルセット  X_{ij} = \lbrace x^k_{ij} , k ∈ \lbrack 1, N \rbrack \rbrace を取得し、そこからガウスパラメーター ( \mu_{ij} , \sum_{ij}) を導出する。


学習フェーズにおいて、正常画像の各パッチは、pre-trainedのCNNの中間層の特徴テンソル、或いは、アクティベーションマップの中の、空間的に対応する一本の特徴ベクトルに対応付けられる。
そして、その中間層特徴テンソル上の1pixel毎の特徴ベクトルの受容野が、画像パッチの広さに対応している。

また、PaDiMでは、CNNの層の深さ別に抽出される特徴ベクトルを縦に連結して、それを多様なセマンティックレベルと解像度からの情報を集約した特徴ベクトルとすることで、きめ細かなグローバルコンテキストを得る。
尚、アクティベーションマップは、入力画像よりも解像度が低い為、元の画像サイズにおける多くのピクセルが同じ埋め込みを持ち、元の画像解像度にオーバーラップすることのない画像パッチをする。
よって、入力画像は (i, j) ∈ \lbrack 1, W \rbrack × \lbrack 1, H \rbrack 位置のグリッドに分割ができる。
ここで、 WxH は、埋め込みの生成に使用される最大活性化マップの解像度(縦横大)とした。
そうして、このグリッド内の各パッチ位置 (i, j) は、前述のように計算された埋め込みベクトル X_{ij} に関連付けられる。

生成されたパッチ埋め込みベクトルは、チャンネル方向について、冗長な情報を運ぶ可能性がある為、そのサイズを縮小する可能性を実験で試す。
例えば、wide_resnet50_2では、縦連結した特徴ベクトルのチャンネル数、及び、1pixel毎におけるその長さが、1,792次元となっているが、CNNはこの1,792次元の内に似た特徴を作ってしまう性質がある為、似ているそれらは幾つかを捨てても、情報量は減らないだろうという形。
実験では、縦連結したチャンネル数の内、その次元をランダムに選択/削減することは、従来の主成分分析(PCA)アルゴリズムよりも効率的であることに気付いた。
この単純でランダムな次元削減により、最新のパフォーマンスを維持しながら、学習時間と推論時間との両方が軽減される。


B. Learning of the normality

先程の図に示すように、N 個の学習用正常画像について、位置 (i, j) での正常画像特性を学習するには、最初に (i, j) での画像パッチ特徴ベクトルのセット  X_{ij} = \lbrace x^k_{ij} , k ∈ \lbrack 1, N \rbrack \rbrace を計算する。

このセットによって運ばれる情報を要約するために、 X_{ij} が多変量ガウス分布 N( \mu_{ij},  \sum_{ij}) によって生成されると仮定する。
ここで、  \mu_{ij}X_{ij} の標本平均であり、標本共分散 \sum_{ij} は次のように推定されます。

ここで、正則化項 I は、学習用正常画像の画像パッチ特徴ベクトルセットの共分散行列  \sum_{ij} をフルランクで可逆にする。
また、可能な各パッチ位置は、先程の図に示すように、ガウスパラメーターの行列によって多変量ガウス分布に関連付けられ、かつ、その特徴ベクトルは、さまざまなセマンティックレベルから情報を集約している。
したがって、各推定多変量ガウス分布 N( \mu_{ij}, \sum_{ij}) も異なるレベルからの情報を取得し、 \sum_{ij} にはレベル間の相関が含まれる。
そして、pre-trainedのCNNにおける、異なるセマンティックレベル間のこれらの関係をモデル化することで、異常位置セグメンテーションのパフォーマンスが向上することが、実験的に示された。


C. Inference : computation of the anomaly map

先の研究における、異常マップの計算に着想を得て、マハラノビス距離 M(x_{ij}) を使用して、テスト画像の位置 (i, j) にあるパッチに異常スコアを与える。

M(x_{ij}) は、 x_{ij} が得られるテスト画像のパッチと、学習済み分布 N( \mu_{ij}, \sum_{ij}) の間の距離と解釈できる。
ここで、 M(x_{ij}) は、次のように計算される。

これを全pixelに展開すれば、異常スコアマップを形成するための、マハラノビス距離の行列 M =  ( M ( x_{ij} ) )_{1<i<W, 1<j<H} を計算できる。
このマップにて高いスコアを持つ座標は、異常な領域となる。
マハラノビス距離の適用によって、スケーラビリティの問題は大幅に解消する。

そして、画像全体の最終的な異常スコアは、異常マップ M の最大値とする。



以上が、論文に記載されているアルゴリズムの説明です。
新たに、マハラノビス距離という概念が入ってきている為、その点がDN2やSPADEと違い、難しく感じるかもしれません。
しかし、マハラノビス距離もまた、KNNとは別種の古典的な手法となります。
ImageNetモデルと、古典的な手法の組み合わせという意味では、DN2/SPADEに似たアルゴリズムの実現感ではないでしょうか。

マハラノビス距離については、例えば、メーカーさん勤務で品質管理などをされている方などは、親しみが深いのかもしれません。
ザックリとは、事象の発生といいますか、数値の大小の分布に対して、正規分布を仮定した上で、その分布の中心から逸脱度合いを測定する距離測定方法となります。
分布の中心から大きく逸脱している場合に、異常とみなされる次第です。
より詳細には、以下の記事が分かりやすかったので、よければ参考にして頂けたらと思います。
様々な手法で大活躍、マハラノビス距離の考え方 | シグマアイ-仕事で使える統計を-


さて、アルゴリズムを把握したところで、実際にコードを書き起こしていきましょう。
先ず、ImageNetモデルを用いる点と、中間特徴を引っ張ってくるところは、これまでと変わりませんので、その仕掛けをこれまでと同様に構築します。

import torch
import torchvision.models as models
from torchinfo import summary

device = torch.device('cuda:0')  # torch.device('cpu')

model = models.wide_resnet50_2(weights=models.Wide_ResNet50_2_Weights.IMAGENET1K_V1)  # DEFAULT
model.eval()
model.to(device)

print('model =', model)

# https://github.com/xiahaifeng1995/PaDiM-Anomaly-Detection-Localization-master/blob/main/main.py#L62
# set model's intermediate outputs
outputs = []

def hook(module, input, output):
    outputs.append(output.clone().detach().cpu().numpy())

model.layer1[-1].register_forward_hook(hook)
model.layer2[-1].register_forward_hook(hook)
model.layer3[-1].register_forward_hook(hook)

summary(model, input_size=(1, 3, 224, 224))


次に、データの読込を実施します。
読み込むデータは、DN2やSPADEと同様に、先ずは、Dogs vs Catsデータセットを用いてみます。

尚、途中で不安を覚えるかもしれませんので、Dogs vs CatsデータセットにPaDiMを適用するとどうなるかを、先に共有させて頂きますが、実は全然精度が出ません。
Dogs vs Catsデータセットの写り込みが多様である為、pixel依存を効かせているPaDiMにとっては、苦手な部類である様子です。
その事実を前もって認識した上で、実験を進めていきましょう。


以下、ファイル名の取得から、中間層の特徴テンソル取得までを、一気に実施します。

ファイル名取得

import os

path = './dogs-vs-cats/train/'
files = os.listdir(path)

files_cat = [os.path.join(path, f) for f in files
             if (os.path.isfile(os.path.join(path, f)) &
                 ('cat.' in f) & ('.jpg' in f))]
files_dog = [os.path.join(path, f) for f in files
             if (os.path.isfile(os.path.join(path, f)) &
                 ('dog.' in f) & ('.jpg' in f))]

files_cat = sorted(files_cat)
files_dog = sorted(files_dog)

print('len(files_cat) =', len(files_cat))
print('files_cat[:10] =\n', files_cat[:10])
print()
print('len(files_dog) =', len(files_dog))
print('files_dog[:10] =\n', files_dog[:10])


学習データと評価データの振分け

import random
import numpy as np

N_cat_train = 3000
N_cat_val = 1000
N_dog = 1000

random.seed(0)

files_cat = random.sample(files_cat, (N_cat_train + N_cat_val))
files_dog = random.sample(files_dog, N_dog)

files_train = np.array(files_cat[:N_cat_train])
files_val = np.array(files_cat[N_cat_train:] + files_dog)

y_val = np.concatenate([np.zeros([N_cat_val]), 
                        np.ones([N_dog])], axis=0).astype(np.int16)

print('len(files_train) =', len(files_train))
print('len(files_val) =', len(files_val))

import matplotlib.pyplot as plt
import japanize_matplotlib

plt.figure(figsize=(10, 4), dpi=100)
plt.plot(y_val, linewidth=5, alpha=0.5)
plt.grid()
plt.show()


学習用データの中間層特徴テンソルの取得
(※最終層の特徴ベクトルは不要なので取得なし)

import cv2
from tqdm.notebook import tqdm

N_batch = 200

outputs = []
x_batch = []

img_train = []
img_prep_train = []

for file in tqdm(files_train):
    img = cv2.imread(file)[..., ::-1]  # BGR2RGB
    img_prep = cv2.resize(img, (224, 224))

    x = img_prep
    x = x / 255
    x = x - np.array([[[0.485, 0.456, 0.406]]])
    x = x / np.array([[[0.229, 0.224, 0.225]]])
    x = torch.from_numpy(x.astype(np.float32)).unsqueeze(0).permute(0, 3, 1, 2)
    x = x.to(device)
    
    x_batch.append(x)

    if (len(x_batch) == N_batch) | (file == files_train[-1]):
        with torch.no_grad():
            _ = model(torch.vstack(x_batch))
            x_batch = []

    img_train.append(img)
    img_prep_train.append(img_prep)

img_prep_train = np.stack(img_prep_train)
f1_train = np.vstack(outputs[0::3])
f2_train = np.vstack(outputs[1::3])
f3_train = np.vstack(outputs[2::3])

print('len(img_train) =', len(img_train))
print('img_prep_train.shape =', img_prep_train.shape)
print('f1_train.shape =', f1_train.shape)
print('f2_train.shape =', f2_train.shape)
print('f3_train.shape =', f3_train.shape)


評価用データの中間層特徴テンソルの取得
(※最終層の特徴ベクトルは不要なので取得なし)

outputs = []
x_batch = []

img_val = []
img_prep_val = []

for file in tqdm(files_val):
    img = cv2.imread(file)[..., ::-1]  # BGR2RGB
    img_prep = cv2.resize(img, (224, 224))

    x = img_prep
    x = x / 255
    x = x - np.array([[[0.485, 0.456, 0.406]]])
    x = x / np.array([[[0.229, 0.224, 0.225]]])
    x = torch.from_numpy(x.astype(np.float32)).unsqueeze(0).permute(0, 3, 1, 2)
    x = x.to(device)
    
    x_batch.append(x)

    if (len(x_batch) == N_batch) | (file == files_val[-1]):
        with torch.no_grad():
            _ = model(torch.vstack(x_batch))
            x_batch = []

    img_val.append(img)
    img_prep_val.append(img_prep)

img_prep_val = np.stack(img_prep_val)
f1_val = np.vstack(outputs[0::3])
f2_val = np.vstack(outputs[1::3])
f3_val = np.vstack(outputs[2::3])

print('len(img_val) =', len(img_val))
print('img_prep_val.shape =', img_prep_val.shape)
print('f1_val.shape =', f1_val.shape)
print('f2_val.shape =', f2_val.shape)
print('f3_val.shape =', f3_val.shape)


さて、中間層の特徴テンソルが得られたところで、次に、その中間層の特徴テンソルの次元削減を行います。
PaDiMでは、中間層の特徴がチャンネル次元方向に冗長だろうという仮説の下、チャンネル次元の次元削減を、ランダムチョイスによって実現しています。
主成分分析による次元削減も試したが、ランダムチョイスの方が精度が高かったという、彼らの実験結果だったそうです。
また、この実施によって、メモリ逼迫問題を軽減することができます。
直接的な恩恵で言えば、学習データを増やすことができます。

尚、論文によれば、WideResNetにおいては、図に示されるような縦連結の特徴ベクトル 256 + 512 + 1024 = 1792 次元のものを、 550 次元にまで削減したとのことです。
ここでも、論文で行われた実験に沿って、 1792550 という次元削減を行おうと思います。

ランダムに、次元削減を行うコードが以下です。
ちなみに、縦連結をする前に、次元を削減した方がメモリパンクの心配が減る為、ここでの説明実装ではそのようにしています。

np.random.seed(0)
idx_tmp = np.sort(np.random.permutation(np.arange(256+512+1024))[:550])

f1_train = f1_train[:, idx_tmp[idx_tmp < 256]]
f2_train = f2_train[:, (idx_tmp[(256 <= idx_tmp) & (idx_tmp < (256 + 512))] - 256)]
f3_train = f3_train[:, (idx_tmp[(256 + 512) <= idx_tmp] - (256 + 512))]

f1_val = f1_val[:, idx_tmp[idx_tmp < 256]]
f2_val = f2_val[:, (idx_tmp[(256 <= idx_tmp) & (idx_tmp < (256 + 512))] - 256)]
f3_val = f3_val[:, (idx_tmp[(256 + 512) <= idx_tmp] - (256 + 512))]

print('f1_train.shape =', f1_train.shape)
print('f2_train.shape =', f2_train.shape)
print('f3_train.shape =', f3_train.shape)
print('f1_val.shape =', f1_val.shape)
print('f2_val.shape =', f2_val.shape)
print('f3_val.shape =', f3_val.shape)


そして、次には、それらのベクトルを縦連結します。
イメージとしては、中間層の特徴テンソル、或いは、論文表現におけるアクティベーションマップの、空間方向の縦横の長さを合わせて、それを縦に連結することで、 (学習データ数, 550, アクティベーションマップの高さ, アクティベーションマップの幅) という形状のテンソルを作ります。
これが、アクティベーションマップ上の1pixel単位で見た際には、縦連結の特徴ベクトルとなっていて、それが束になってアクティベーションマップを形成する次第です。

尚、レイヤー別のアクティベーションマップを連結させるためには、3つあるそれの内、いずれかの縦横の大きさに合わせる必要がります。
論文では、先程の図に示された通り、最も空間方向に大きいアクティベーションマップに合わせて、つまり、浅い層のアクティベーションマップに合わせて、深さが中間の層のそれと、深い層のそれとを、拡大することとしています。
私が実験した限りですと、ここのチューニングが結構大事な印象で、どちらかというと最も空間方向に小さいアクティベーションマップに合わせた方が、精度も計算利得も高いかと思われました。
その為、ここでの実装においては、最も空間方向に小さいアクティベーションマップに合わせる方針にて、縦連結の中間層の特徴テンソルを作成しようと思います。

連結を実施するコードは以下となります。
モードの数値を変更することで、連結後のアクティベーションマップの空間方向縦横の大きさを変更することができます。
また、この記事の実装では、最も縦横のサイズが小さいものに合わせて、中間層の特徴テンソルを結合しています。
私の手元の実験ではそれが一番精度が高かったのと、メモリ効率が良い為です。

import torch.nn.functional as F

mode = 2

if (mode == 0):
    f2_train = F.interpolate(torch.from_numpy(f2_train), size=56,
                             mode='bilinear', align_corners=False).numpy()
    f3_train = F.interpolate(torch.from_numpy(f3_train), size=56,
                             mode='bilinear', align_corners=False).numpy()
elif (mode == 1):
    f1_train = F.interpolate(torch.from_numpy(f1_train), size=28,
                             mode='bilinear', align_corners=False).numpy()
    f3_train = F.interpolate(torch.from_numpy(f3_train), size=28,
                             mode='bilinear', align_corners=False).numpy()
elif (mode == 2):
    f1_train = F.interpolate(torch.from_numpy(f1_train), size=14,
                             mode='bilinear', align_corners=False).numpy()
    f2_train = F.interpolate(torch.from_numpy(f2_train), size=14,
                             mode='bilinear', align_corners=False).numpy()

print('f1_train.shape =', f1_train.shape)
print('f2_train.shape =', f2_train.shape)
print('f3_train.shape =', f3_train.shape)

f123_train = np.concatenate([f1_train, f2_train, f3_train], axis=1)

print('f123_train.shape =', f123_train.shape)


ここで、メモリ逼迫が結構深刻になる可能性が高いですので、連結前のアクティベーションマップをメモリ上から削除します。
これによって、メモリパンクのリスクが軽減されるかと思われます。

del f1_train
del f2_train
del f3_train


次に、評価データよりのアクティベーションマップを縦連結します。
併せて、連結前のアクティベーションマップを削除します。

if (mode == 0):
    f2_val = F.interpolate(torch.from_numpy(f2_val), size=56,
                           mode='bilinear', align_corners=False).numpy()
    f3_val = F.interpolate(torch.from_numpy(f3_val), size=56,
                           mode='bilinear', align_corners=False).numpy()
elif (mode == 1):
    f1_val = F.interpolate(torch.from_numpy(f1_val), size=28,
                           mode='bilinear', align_corners=False).numpy()
    f3_val = F.interpolate(torch.from_numpy(f3_val), size=28,
                           mode='bilinear', align_corners=False).numpy()
elif (mode == 2):
    f1_val = F.interpolate(torch.from_numpy(f1_val), size=14,
                           mode='bilinear', align_corners=False).numpy()
    f2_val = F.interpolate(torch.from_numpy(f2_val), size=14,
                           mode='bilinear', align_corners=False).numpy()

print('f1_val.shape =', f1_val.shape)
print('f2_val.shape =', f2_val.shape)
print('f3_val.shape =', f3_val.shape)

f123_val = np.concatenate([f1_val, f2_val, f3_val], axis=1)

print('f123_val.shape =', f123_val.shape)

del f1_val
del f2_val
del f3_val


ここまでで、学習データと評価データにおける、縦連結のアクティベーションマップが作成できました。
もし、試して頂いた方で、メモリ逼迫によってプログラムが落ちてしまった人などがいらっしゃいましたら、読み込む件数を減らすなどして、対処頂けたらと思います。

次には、いよいよ学習処理を実施します。
学習処理では、マハラノビス距離を求めるための標本平均 \mu_{ij} と、標本共分散 \sum_{ij} を算出します。
全ての (i, j) について、算出します。

以下コードにて、それが実践されます。

from tqdm.notebook import tqdm

cov_inv = np.zeros([f123_train.shape[2], f123_train.shape[3], 550, 550])
mean = np.zeros([f123_train.shape[2], f123_train.shape[3], 550])

for i_h in tqdm(range(f123_train.shape[2])):
    for i_w in range(f123_train.shape[3]):
        f = f123_train[:, :, i_h, i_w].copy()
        mean[i_h, i_w] = np.mean(f, axis=0)

        f = f - mean[i_h, i_w][None]
        f = ((f.T @ f) / (len(f) - 1)) + (0.01 * np.eye(f.shape[1]))

        cov_inv[i_h, i_w] = np.linalg.inv(f)


これによって、マハラノビス距離を求める準備ができました。

尚、標本共分散 Σ_{ij} を算出するに当たっては、正則化項 I を用いていますが、それがコード中の (0.01 * np.eye(f.shape[1])) となっています。
0.01 を変更することで、正則化の強弱を調整します。
私が実験した限りでは、この正則化項が無いと、標本共分散の算出が失敗してしまうようでした。

次に、評価データについて、実際にマハラノビス距離を算出してみましょう。

from scipy.spatial.distance import mahalanobis

score_val = np.zeros([len(f123_val), f123_val.shape[2], f123_val.shape[3]])

for i_h in tqdm(range(f123_val.shape[2])):
    for i_w in range(f123_val.shape[3]):
        f = f123_val[:, :, i_h, i_w]
        score_tmp = [mahalanobis(sample, mean[i_h, i_w], cov_inv[i_h, i_w])
                     for sample in f]
        score_val[:, i_h, i_w] = np.array(score_tmp)

print('score_val.shape =', score_val.shape)
print('np.mean(score_val) =', np.mean(score_val))
print('np.std(score_val) =', np.std(score_val))
print('np.max(score_val) =', np.max(score_val))
print('np.min(score_val) =', np.min(score_val))


これで、異常スコアマップの算出ができました。
異常スコアマップに対して、閾値を適用すれば、異常位置のセグメンテーションが実現されます。

ちなみに、先程もお伝えしたように、Dogs vs Catsデータセットでは精度が出ません。
その為、pixel単位の異常位置セグメンテーションの可視化は、一旦行わずに、先に画像単位の異常検知へと進もうと思います。

画像単位の異常検知は、pixel単位の異常スコアマップにおける最大値を、画像レベルの異常スコアとして用います。
その為、以下のようなシンプルなコードで実現できます。

score_val_ = np.max(np.max(score_val, axis=-1), axis=-1)

print('score_val_.shape =', score_val_.shape)
print('np.mean(score_val_) =', np.mean(score_val_))
print('np.std(score_val_) =', np.std(score_val_))
print('np.max(score_val_) =', np.max(score_val_))
print('np.min(score_val_) =', np.min(score_val_))


次に、この画像レベルの異常スコアの分布を可視化と、閾値別の正解率を求めてみます。
以下コードになります。

plt.figure(figsize=(10, 8), dpi=100, facecolor='white')

plt.subplot(2, 1, 1)
plt.scatter(np.where(y_val == 0)[0], 
            score_val_[y_val == 0], alpha=0.5, label='猫画像')
plt.scatter(np.where(y_val == 1)[0], 
            score_val_[y_val == 1], alpha=0.5, label='犬画像')
plt.grid()
plt.legend()

plt.subplot(2, 1, 2)
plt.hist(score_val_[y_val == 0], alpha=0.5, bins=50, label='猫画像')
plt.hist(score_val_[y_val == 1], alpha=0.5, 
         bins=int(50*(N_dog/N_cat_val)), label='犬画像')
plt.grid()
plt.legend()

plt.show()

acc_list = []
thresh_list = np.arange(0, np.max(score_val_)+1e-10, np.max(score_val_)/50)

for thresh in thresh_list:

    acc = np.mean(np.concatenate([(score_val_[y_val == 0] < thresh),
                                  (score_val_[y_val == 1] > thresh)]))
    acc_list.append(acc)

acc_list = np.array(acc_list)

plt.figure(figsize=(10, 4), dpi=100, facecolor='white')
plt.plot(thresh_list, acc_list, alpha=0.5, linewidth=3)
plt.scatter(thresh_list, acc_list, alpha=0.5, s=30)
plt.grid()
plt.title('np.max(acc_list) = %.3f, thresh_list[np.argmax(acc_list) = %.2f' %
          (np.max(acc_list), thresh_list[np.argmax(acc_list)]))
plt.show()


正常データと異常データとの異常スコアの分布が、かなり重なってしまっていることが確認できるかと思います。
また、閾値調整を行った際の最高正解率が65.5%であり、お世辞にも高いとは言えません。
その理由としては、やはり、Dogs vs Catsデータセットの撮影アングルが安定していない為かと思われます。
試しに、幾つかの画像をピックアップしてみて、その平均を取ってみます。

やはり、アングルが安定していませんね…。


ここで、画像上の中心から離れた外側の部分には、猫以外の背景等が写り込んでいる可能性が高いかと思われる為、そこを削除して、異常検知を実施してみます。
尚、ここで実施するのは、評価データからのみと言いますか、算出後の異常スコアマップから除くのみです。
先程計算した異常スコアを、以下コードの要領で編集して、精度を測定してみます。

score_val__ = np.max(np.max(score_val[:, 6:-6, 6:-6], axis=-1), axis=-1)

plt.figure(figsize=(10, 8), dpi=100, facecolor='white')

plt.subplot(2, 1, 1)
plt.scatter(np.where(y_val == 0)[0], 
            score_val__[y_val == 0], alpha=0.5, label='猫画像')
plt.scatter(np.where(y_val == 1)[0], 
            score_val__[y_val == 1], alpha=0.5, label='犬画像')
plt.grid()
plt.legend()

plt.subplot(2, 1, 2)
plt.hist(score_val__[y_val == 0], alpha=0.5, bins=50, label='猫画像')
plt.hist(score_val__[y_val == 1], alpha=0.5, 
         bins=int(50*(N_dog/N_cat_val)), label='犬画像')
plt.grid()
plt.legend()

plt.show()

acc_list = []
thresh_list = np.arange(0, np.max(score_val__)+1e-10, np.max(score_val__)/50)

for thresh in thresh_list:

    acc = np.mean(np.concatenate([(score_val__[y_val == 0] < thresh),
                                  (score_val__[y_val == 1] > thresh)]))
    acc_list.append(acc)

acc_list = np.array(acc_list)

plt.figure(figsize=(10, 4), dpi=100, facecolor='white')
plt.plot(thresh_list, acc_list, alpha=0.5, linewidth=3)
plt.scatter(thresh_list, acc_list, alpha=0.5, s=30)
plt.grid()
plt.title('np.max(acc_list) = %.3f, thresh_list[np.argmax(acc_list) = %.2f' %
          (np.max(acc_list), thresh_list[np.argmax(acc_list)]))
plt.show()


結果は、ほとんど変わりませんでした。
画像の外側にて、誤検知が多数されているという訳でもないようです。


更なる試しとして、yolov5にて、 dog cat を検出し、その検出枠でcropした画像でも同様に試してみたのですが、精度向上はそこまで見られなかった次第です。
もっと丁寧にアライメントを行う必要がありそうです。

アライメントがされた様子

精度

精度(先程と同様に、スコアマップの外側を除いた場合)


悔しいので、アライメント改善をもう少し試みてみます。
以下のオックスフォード大学が発行して下さっているデータセットを用いてみます。
Visual Geometry Group - University of Oxford

アノテーション例が以下です。


このデータセットから、背景をマスクした上で、猫と犬の顔部分だけをクロップして、それを用いてみます。
アライメントは、結構キレイにされている印象です。

猫画像の例

犬画像の例


Dogs vs Catsデータセットにて行ったことと同様のことを、上記データにも実施してみます。
その結果、以下のように精度が向上しました。

精度

精度(先程と同様に、スコアマップの外側を除いた場合)


このことから、少々のアライメント改善ですと、あまり精度への影響は無いですが、大幅にアライメント改善を行えば、精度が改善することが伺えます。

ここで、正答している対象と、誤答している対象を確認してみましょう。
また、異常位置のセグメンテーションについては、実は先程算出した score_map によって実施ができています。
それも合わせて可視化してみましょう。


実施コードは以下です。
冒頭のフラグ値を調整すれば、犬/猫、スコアの昇順/降順が調整できます。

flg_dog = 0
flg_asc = True

score = np.max(np.max(score_val, axis=-1), axis=-1)[y_val == flg_dog]
score_max = np.max(score_val)
score_min = np.min(score_val)
if flg_asc:
    idx_sort = np.argsort(score)
else:
    idx_sort = np.argsort(-score)

for i in idx_sort[:10]:
    img = img_prep_val[y_val == flg_dog][i]

    score_map = score_val[y_val == flg_dog][i]
    score_map = (score_map - score_min) / (score_max - score_min)
    score_map = cv2.resize(score_map, (224, 224))

    plt.figure(figsize=(11, 2.5), dpi=100, facecolor='white')
    plt.subplot(1, 3, 1)
    plt.imshow(img)
    plt.subplot(1, 3, 2)
    plt.imshow(score_map)
    plt.colorbar()
    plt.subplot(1, 3, 3)
    plt.imshow(overlay_heatmap_on_image(img=img, heatmap=score_map))
    plt.title('regularized anomaly score = %.2f' % np.max(score_map))
    plt.show()

猫画像にて、大きく正答している対象TOP1

猫画像にて、大きく正答している対象TOP2

猫画像にて、大きく正答している対象TOP3

猫画像にて、大きく正答している対象TOP4

猫画像にて、大きく正答している対象TOP5

猫画像にて、大きく正答している対象TOP6

猫画像にて、大きく正答している対象TOP7

猫画像にて、大きく正答している対象TOP8

猫画像にて、大きく正答している対象TOP9

猫画像にて、大きく正答している対象TOP10


猫画像にて、大きく誤答している対象TOP1

猫画像にて、大きく誤答している対象TOP2

猫画像にて、大きく誤答している対象TOP3

猫画像にて、大きく誤答している対象TOP4

猫画像にて、大きく誤答している対象TOP5

猫画像にて、大きく誤答している対象TOP6

猫画像にて、大きく誤答している対象TOP7

猫画像にて、大きく誤答している対象TOP8

猫画像にて、大きく誤答している対象TOP9

猫画像にて、大きく誤答している対象TOP10


犬画像にて、大きく誤答している対象TOP1

犬画像にて、大きく誤答している対象TOP2

犬画像にて、大きく誤答している対象TOP3

犬画像にて、大きく誤答している対象TOP4

犬画像にて、大きく誤答している対象TOP5

犬画像にて、大きく誤答している対象TOP6

犬画像にて、大きく誤答している対象TOP7

犬画像にて、大きく誤答している対象TOP8

犬画像にて、大きく誤答している対象TOP9

犬画像にて、大きく誤答している対象TOP10


犬画像にて、大きく正答している対象TOP1

犬画像にて、大きく正答している対象TOP2

犬画像にて、大きく正答している対象TOP3

犬画像にて、大きく正答している対象TOP4

犬画像にて、大きく正答している対象TOP5

犬画像にて、大きく正答している対象TOP6

犬画像にて、大きく正答している対象TOP7

犬画像にて、大きく正答している対象TOP8

犬画像にて、大きく正答している対象TOP9

犬画像にて、大きく正答している対象TOP10


なかなか示唆深い結果となりました。


特に、犬画像にて正答している対象、即ち、異常反応を示せている対象については、猫にはない垂れた耳や、大きな鼻などに反応できていると思います。

全体を通しては、しっかりと正面を向いているとスコアが低くなる傾向があったり、その逆で横を向いているとスコアが高くなる傾向があることから、PaDiMのpixel依存の傾向が伺えるかと思います。
SPADEと違い、PaDiMは画像パッチ単位別にそれぞれ独立して異常検知を行っている為、多くの場合に真ん中に写り込んでいるような対象が、左右にズレて写り込んでいたりすると、異常を示しやすくなっているものと思われます。
この部分は、恐らくはMVTec向けに、ある程度割り切ってしまっているところなのかと思われます。


さて、ここで、いよいよMVTecデータセットにて、PaDiM適用を試してみようと思います。
尚、PaDiMはMVTecデータセットですと、非常に高い精度が出ます。
MVTecが、非常にアライメントが整っているデータである為かと思われます。

以下コードにて、Dogs vs Catsデータセットにて行ったことを、MVTecにて順次行っていきます。

ファイル名収集

path_parent = './mvtec_anomaly_detection/'
type_data = 'bottle'

path_train = os.path.join(path_parent, type_data, 'train/good')
files_train = [os.path.join(path_train, f) for f in os.listdir(path_train)
               if (os.path.isfile(os.path.join(path_train, f)) &
                   ('.png' in f))]
files_train = np.array(sorted(files_train))

print('len(files_train) =', len(files_train))
print('files_train[:5] =\n', files_train[:5])
print()

types_test = os.listdir(os.path.join(path_parent, type_data, 'test'))
types_test = np.array(sorted(types_test))

files_test = {}

for type_test in types_test:

    path_test = os.path.join(path_parent, type_data, 'test', type_test)
    files_test[type_test] = [os.path.join(path_test, f)
                             for f in os.listdir(path_test)
                             if (os.path.isfile(os.path.join(path_test, f)) &
                                 ('.png' in f))]
    files_test[type_test] = np.array(sorted(files_test[type_test]))

    print('len(files_test[%s]) =' % type_test, len(files_test[type_test]))
    print('files_test[%s][:5] =\n' % type_test, files_test[type_test][:5])
    print()

plt.figure(figsize=(12, 32), dpi=100, facecolor='white')

plt.subplot(8, 2, 1)
plt.imshow(cv2.imread(files_train[0])[..., ::-1])
plt.title(files_train[0])

for i_type_test, type_test in enumerate(types_test):
    plt.subplot(8, 2, (i_type_test + 3))
    plt.imshow(cv2.imread(files_test[type_test][0])[..., ::-1])
    plt.title(files_test[type_test][0])

plt.show()


学習データからの特徴収集

outputs = []

img_train = []
img_prep_train = []

for file in tqdm(files_train):
    img = cv2.imread(file)[..., ::-1]  # BGR2RGB
    img_prep = cv2.resize(img, (256, 256), interpolation=cv2.INTER_AREA)
    img_prep = img_prep[16:(256-16), 16:(256-16)]

    x = img_prep
    x = x / 255
    x = x - np.array([[[0.485, 0.456, 0.406]]])
    x = x / np.array([[[0.229, 0.224, 0.225]]])
    x = torch.from_numpy(x.astype(np.float32)).unsqueeze(0).permute(0, 3, 1, 2)
    x = x.to(device)

    with torch.no_grad():
        _ = model(x)

    img_train.append(img)
    img_prep_train.append(img_prep)

img_prep_train = np.stack(img_prep_train)
f1_train = np.vstack(outputs[0::3])
f2_train = np.vstack(outputs[1::3])
f3_train = np.vstack(outputs[2::3])

print('len(img_train) =', len(img_train))
print('img_prep_train.shape =', img_prep_train.shape)
print('f1_train.shape =', f1_train.shape)
print('f2_train.shape =', f2_train.shape)
print('f3_train.shape =', f3_train.shape)


テストデータからの特徴収集

img_test = {}
img_prep_test = {}
gt_test = {}

f1_test = {}
f2_test = {}
f3_test = {}
fl_test = {}

for type_test in types_test:

    outputs = []

    img_test[type_test] = []
    img_prep_test[type_test] = []
    gt_test[type_test] = []

    for file in tqdm(files_test[type_test]):
        img = cv2.imread(file)[..., ::-1]  # BGR2RGB
        img_prep = cv2.resize(img, (256, 256), interpolation=cv2.INTER_AREA)
        img_prep = img_prep[16:(256-16), 16:(256-16)]

        if (type_test == 'good'):
            gt = np.zeros_like(img_prep[..., 0], dtype=np.uint8)
        else:
            file_gt = file.replace('/test/', '/ground_truth/')
            file_gt = file_gt.replace('.png', '_mask.png')
            gt = cv2.imread(file_gt, cv2.IMREAD_GRAYSCALE)
            gt = cv2.resize(gt, (256, 256), interpolation=cv2.INTER_NEAREST)
            gt = gt[16:(256-16), 16:(256-16)]
            gt = (gt / np.max(gt)).astype(np.uint8)
        gt_test[type_test].append(gt)

        x = img_prep
        x = x / 255
        x = x - np.array([[[0.485, 0.456, 0.406]]])
        x = x / np.array([[[0.229, 0.224, 0.225]]])
        x = torch.from_numpy(x.astype(np.float32)).unsqueeze(0).permute(0, 3, 1, 2)
        x = x.to(device)

        with torch.no_grad():
            _ = model(x)

        img_test[type_test].append(img)
        img_prep_test[type_test].append(img_prep)

    img_prep_test[type_test] = np.stack(img_prep_test[type_test])
    gt_test[type_test] = np.stack(gt_test[type_test])
    f1_test[type_test] = np.vstack(outputs[0::3])
    f2_test[type_test] = np.vstack(outputs[1::3])
    f3_test[type_test] = np.vstack(outputs[2::3])

    print('len(img_test[%s]) =' % type_test, len(img_test[type_test]))
    print('img_prep_test[%s].shape =' % type_test, img_prep_test[type_test].shape)
    print('gt_test[%s].shape =' % type_test, gt_test[type_test].shape)
    print('f1_test[%s].shape =' % type_test, f1_test[type_test].shape)
    print('f2_test[%s].shape =' % type_test, f2_test[type_test].shape)
    print('f3_test[%s].shape =' % type_test, f3_test[type_test].shape)


ランダムチョイスによる特徴次元削減

np.random.seed(0)
idx_tmp = np.sort(np.random.permutation(np.arange(256+512+1024))[:550])

f1_train = f1_train[:, idx_tmp[idx_tmp < 256]]
f2_train = f2_train[:, (idx_tmp[(256 <= idx_tmp) & (idx_tmp < (256 + 512))] - 256)]
f3_train = f3_train[:, (idx_tmp[(256 + 512) <= idx_tmp] - (256 + 512))]

print('f1_train.shape =', f1_train.shape)
print('f2_train.shape =', f2_train.shape)
print('f3_train.shape =', f3_train.shape)

for type_test in types_test:

    f1_test[type_test] = f1_test[type_test][:, idx_tmp[idx_tmp < 256]]
    f2_test[type_test] = f2_test[type_test][:, (idx_tmp[(256 <= idx_tmp) & (idx_tmp < (256 + 512))] - 256)]
    f3_test[type_test] = f3_test[type_test][:, (idx_tmp[(256 + 512) <= idx_tmp] - (256 + 512))]

    print('f1_test[%s].shape =' % type_test, f1_test[type_test].shape)
    print('f2_test[%s].shape =' % type_test, f2_test[type_test].shape)
    print('f3_test[%s].shape =' % type_test, f3_test[type_test].shape)


各層のアクティベーションマップを縦連結
(※MVTecはデータ数が少ない為、変数削除は割愛)

import torch.nn.functional as F

mode = 2

if (mode == 0):
    f2_train = F.interpolate(torch.from_numpy(f2_train), size=56,
                             mode='bilinear', align_corners=False).numpy()
    f3_train = F.interpolate(torch.from_numpy(f3_train), size=56,
                             mode='bilinear', align_corners=False).numpy()
elif (mode == 1):
    f1_train = F.interpolate(torch.from_numpy(f1_train), size=28,
                             mode='bilinear', align_corners=False).numpy()
    f3_train = F.interpolate(torch.from_numpy(f3_train), size=28,
                             mode='bilinear', align_corners=False).numpy()
elif (mode == 2):
    f1_train = F.interpolate(torch.from_numpy(f1_train), size=14,
                             mode='bilinear', align_corners=False).numpy()
    f2_train = F.interpolate(torch.from_numpy(f2_train), size=14,
                             mode='bilinear', align_corners=False).numpy()

print('f1_train.shape =', f1_train.shape)
print('f2_train.shape =', f2_train.shape)
print('f3_train.shape =', f3_train.shape)

f123_train = np.concatenate([f1_train, f2_train, f3_train], axis=1)

print('f123_train.shape =', f123_train.shape)

f123_test = {}

for type_test in types_test:

    if (mode == 0):
        f2_test[type_test] = F.interpolate(torch.from_numpy(f2_test[type_test]),
                                           size=56, mode='bilinear',
                                           align_corners=False).numpy()
        f3_test[type_test] = F.interpolate(torch.from_numpy(f3_test[type_test]),
                                           size=56, mode='bilinear',
                                           align_corners=False).numpy()
    elif (mode == 1):
        f1_test[type_test] = F.interpolate(torch.from_numpy(f1_test[type_test]),
                                           size=28, mode='bilinear',
                                           align_corners=False).numpy()
        f3_test[type_test] = F.interpolate(torch.from_numpy(f3_test[type_test]),
                                           size=28, mode='bilinear',
                                           align_corners=False).numpy()
    elif (mode == 2):
        f1_test[type_test] = F.interpolate(torch.from_numpy(f1_test[type_test]),
                                           size=14, mode='bilinear',
                                           align_corners=False).numpy()
        f2_test[type_test] = F.interpolate(torch.from_numpy(f2_test[type_test]),
                                           size=14, mode='bilinear',
                                           align_corners=False).numpy()

    print('f2_test[%s].shape =' % type_test, f2_test[type_test].shape)
    print('f3_test[%s].shape =' % type_test, f3_test[type_test].shape)

    f123_test[type_test] = np.concatenate([f1_test[type_test], f2_test[type_test], 
                                           f3_test[type_test]], axis=1)

    print('f123_test[%s].shape =' % type_test, f123_test[type_test].shape)


マハラノビス距離算出用の標本平均  \mu_{ij} と、標本共分散  \sum_{ij} を算出

from tqdm.notebook import tqdm

cov_inv = np.zeros([f123_train.shape[2], f123_train.shape[3], 550, 550])
mean = np.zeros([f123_train.shape[2], f123_train.shape[3], 550])

for i_h in tqdm(range(f123_train.shape[2])):
    for i_w in range(f123_train.shape[3]):
        f = f123_train[:, :, i_h, i_w].copy()
        mean[i_h, i_w] = np.mean(f, axis=0)

        f = f - mean[i_h, i_w][None]
        f = ((f.T @ f) / (len(f) - 1)) + (0.01 * np.eye(f.shape[1]))

        cov_inv[i_h, i_w] = np.linalg.inv(f)


テストデータについて、マハラノビス距離を算出

score_test = {}

for type_test in types_test:

    score_test[type_test] = np.zeros([len(f123_test[type_test]),
                                      f123_train.shape[2], f123_train.shape[3]])

    for i_h in tqdm(range(f123_train.shape[2])):
        for i_w in range(f123_train.shape[3]):
            f = f123_test[type_test][:, :, i_h, i_w]
            score_tmp = [mahalanobis(sample, mean[i_h, i_w], cov_inv[i_h, i_w])
                         for sample in f]
            score_test[type_test][:, i_h, i_w] = np.array(score_tmp)

    print('score_test[%s].shape =' % type_test, score_test[type_test].shape)
    print('np.mean(score_test[%s]) =' % type_test, np.mean(score_test[type_test]))
    print('np.mean(np.abs(score_test[%s])) =' % type_test, np.mean(np.abs(score_test[type_test])))
    print('np.std(score_test[%s]) =' % type_test, np.std(score_test[type_test]))
    print('np.max(score_test[%s]) =' % type_test, np.max(score_test[type_test]))
    print('np.min(score_test[%s]) =' % type_test, np.min(score_test[type_test]))


テストデータについて、画像単位異常検知の予測分布可視化と精度算出を実施

y_hat_list = []
y_list = []
N_test = 0

type_test = 'good'

plt.figure(figsize=(10, 8), dpi=100, facecolor='white')

plt.subplot(2, 1, 1)
plt.scatter((np.arange(len(score_test[type_test])) + N_test),
            np.max(np.max(score_test[type_test], axis=-1), axis=-1),
            alpha=0.5, label=type_test)

plt.subplot(2, 1, 2)
plt.hist(np.max(np.max(score_test[type_test], axis=-1), axis=-1),
         alpha=0.5, bins=10, label=type_test)

y_hat_list.append(np.max(np.max(score_test[type_test], axis=-1), axis=-1))
y_list.append(np.zeros([len(score_test[type_test])], dtype=np.int16))
N_test += len(score_test[type_test])

for type_test in types_test[types_test != 'good']:

    plt.subplot(2, 1, 1)
    plt.scatter((np.arange(len(score_test[type_test])) + N_test),
                np.max(np.max(score_test[type_test], axis=-1), axis=-1),
                alpha=0.5, label=type_test)

    plt.subplot(2, 1, 2)
    plt.hist(np.max(np.max(score_test[type_test], axis=-1), axis=-1),
             alpha=0.5, bins=10, label=type_test)

    y_hat_list.append(np.max(np.max(score_test[type_test], axis=-1), axis=-1))
    y_list.append(np.ones([len(score_test[type_test])], dtype=np.int16))
    N_test += len(score_test[type_test])

y_list = np.hstack(y_list)
y_hat_list = np.hstack(y_hat_list)

plt.subplot(2, 1, 1)
plt.grid()
plt.legend()

plt.subplot(2, 1, 2)
plt.grid()
plt.legend()

plt.show()

from sklearn.metrics import roc_curve
from sklearn.metrics import roc_auc_score

# calculate per-image level ROCAUC
fpr, tpr, _ = roc_curve(y_list, y_hat_list)
per_image_rocauc = roc_auc_score(y_list, y_hat_list)

plt.figure(figsize=(10, 6), dpi=100)
plt.plot(fpr, tpr, label='%s ROCAUC: %.3f' % (type_data, per_image_rocauc))
plt.grid()
plt.legend()
plt.show()


なんと、閾値次第では全問正解という形で、ROCカーブのAUCが1.0となりました。
Dogs vs Catsデータセットでは全く歯が立たなかったPaDiMが、MVTecの bottle では最高精度を引き出した次第です。
また、試して頂くと分かるのですが、 bottle 以外のデータ種別においても、高い精度が出ます。

しかしながら、ネジのデータ種別 screw については、ROC-AUCが0.884に留まりました。
MVTecのデータ種別の中で、最も精度が低くなってしまうのが、 screw になります。
数字としては、決して悪くない精度ではあるのですが、MVTecにおける他のデータ種別での発揮精度に比べると、見劣りしてしまいます。
その様子ですが、以下となります。


元データの様子を見て頂くと、ネジがクルクルと回っていることが確認できるかと思います。
他にもクルクルと回ってしまう対象はあるのですが、それによる見えの変化が特に大きいのが screw かと思われます。

これは、Pixel依存に傾倒しているPaDiMアルゴリズムにとっては、やや苦手とする課題感です。
ただし、Dogs vs Catsデータセットに対するように、全く歯が立たない訳ではありません。
この辺りから、PaDiMがどのくらいまでの撮像揺れに耐えられるかが、伺い知れるところです。
或いは、PaDiMの適用を検討している場合には、撮影条件について、この特性を加味する必要があります。


さて、画像単位での異常検知精度は確認ができました。
次には、pixel単位での異常位置セグメンテーションの様子を確認してみましょう。
先程算出した異常スコアマップを用いて、可視化を実現することができます。

その実施コードが以下です。

# https://github.com/gsurma/cnn_explainer/blob/main/utils.py#L16
def overlay_heatmap_on_image(img, heatmap, ratio_img=0.5):
    img = img.astype(np.float32)

    heatmap = 1 - np.clip(heatmap, 0, 1)
    heatmap = cv2.applyColorMap((heatmap * 255).astype(np.uint8), cv2.COLORMAP_JET)
    heatmap = heatmap.astype(np.float32)

    overlay = (img * ratio_img) + (heatmap * (1 - ratio_img))
    overlay = np.clip(overlay, 0, 255)
    overlay = overlay.astype(np.uint8)
    return overlay


gt_flat_list = []
score_flat_list = []

score_max = max([np.max(score_test[type_test]) for type_test in types_test])

for type_test in types_test:
    for i, gt in tqdm(enumerate(gt_test[type_test]),
                      desc=('[verbose mode] visualize localization (case:%s)' %
                            type_test)):
        file = files_test[type_test][i]
        img = cv2.imread(file)[..., ::-1]  # BGR2RGB
        img = cv2.resize(img, (256, 256), interpolation=cv2.INTER_AREA)
        img = img[16:(256-16), 16:(256-16)]
        score = score_test[type_test][i]
        score = cv2.resize(score, (224, 224))

        # calculate per-pixel level ROCAUC
        gt_flat = gt.reshape(-1)
        score_flat = score.reshape(-1)
        if (np.sum(gt_flat > 0) == 0):
            rocauc = np.nan
        else:
            rocauc = roc_auc_score(gt_flat, score_flat)

        # stock for output result
        gt_flat_list.append(gt_flat)
        score_flat_list.append(score_flat)

        plt.figure(figsize=(9, 7.3), dpi=100, facecolor='white')
        plt.rcParams['font.size'] = 8
        plt.subplot2grid((3, 3), (0, 0), rowspan=1, colspan=1)
        plt.imshow(img)
        plt.title('%s : %s' % (file.split('/')[-2], file.split('/')[-1]))
        plt.subplot2grid((3, 3), (0, 1), rowspan=1, colspan=1)
        plt.imshow(gt)
        plt.subplot2grid((3, 3), (0, 2), rowspan=1, colspan=1)
        plt.imshow(score)
        plt.colorbar()
        plt.title('global max score : %.2f' % score_max)
        plt.subplot2grid((3, 4), (1, 0), rowspan=2, colspan=2)
        plt.imshow(overlay_heatmap_on_image(img, (score / score_max)))
        plt.title('per-pixel level ROCAUC: %.3f' % rocauc)
        plt.subplot2grid((3, 4), (1, 2), rowspan=2, colspan=2)
        plt.imshow((img.astype(np.float32) *
                    (score / score_max)[..., None]).astype(np.uint8))
        plt.show()

gt_flat_list = np.array(gt_flat_list).reshape(-1)
score_flat_list = np.array(score_flat_list).reshape(-1)

# calculate per-pixel level ROCAUC
fpr, tpr, _ = roc_curve(gt_flat_list, score_flat_list)
rocauc = roc_auc_score(gt_flat_list, score_flat_list)
print('%s per-pixel level ROCAUC: %.3f' % (type_data, rocauc))

plt.figure(figsize=(10, 6), dpi=100)
plt.plot(fpr, tpr, label='%s ROCAUC: %.3f' % (type_data, rocauc))
plt.grid()
plt.legend()
plt.show()


上手く異常位置のセグメンテーションが行えていることが確認できるかと思います。
bottle データ種別全体のpixel単位のROC-AUCも0.982と非常に高い数値です。
非常に優秀ですね。


ここで、最後に screw での異常位置セグメンテーションの結果も確認してみましょう。
実は、全pixel単位のROC-AUCとしては、0.983という高い数値が出ます。
しかし、よくよく結果を確認すると、「なぜこのピクセルに反応するのか?」という誤検知が散見されます。
アルゴリズムとしては優秀なものの、やはりpixel依存の課題感が垣間見える、というところです。
しかし、僅かなpixel数で間違いを重ねる分には、全pixel単位のROC-AUCには軽微な影響しか与えないために、精度が高くなっている次第です。
一方で、画像単位での検知精度となると、最大の異常スコアマップ値を画像単位の異常スコアとしている為に、精度への影響が強く表れることとなります。

以下、異常位置セグメンテーションの結果を、少し多めに添付します。


また、アルゴリズムを把握したところで、最後の捕捉として、論文にて行われていたaugmentation関して触れたいと思います。
個人的には微妙に議論の余地があるかと思っています。
実際に試してもみたのですが、効果が高いとは言えず、直感的に取り扱いづらい印象を受けました。

というのも、先ず1つには、それは敢えてpixel依存を崩しに行く行為であるからです。
つまり、結果から見ても分かるように、pixelレベルでの異常位置セグメンテーションの精度はどうしても下がってしまいます。
ただし、その精度低下が正則化の効果をもたらし、画像レベルでの異常検知の精度上昇は起こり得そうな気はします。
しかし、ちょっと実現したいアルゴリズム感との衝突を感じずにはいられません。

もう1つには、ImageNetを識別するモデルの学習中に、augmentationは行われているであろう点です。
ImageNetモデル学習での、augmentationの目的感は「多少の見えの違いであれば、同じものとして対応して欲しい」という気持ちである筈です。
これは、もう少し技術寄りに解釈を深めると、「多少見えの異なる対象からは、同様の特徴を抽出して欲しい」という気持ちに変わるかと思います。
以前の転移学習の回で実際に、同じ概念の対象から類似の特徴が抽出されている様子を確認して頂いたかと思います。
つまり、ImageNetの識別モデルとしては、多少のaugmentation実施をしても、最終的には冗長な同様の特徴を抽出することが是であると思います。
一方で、PaDiMがaugmentationによって発生させようとしているのは、多少の見えの違いからの特徴の揺れになりますが、ImageNetモデル学習時に行ったaugmentationの元々の目的感からすると、それは発生しづらいと考えるのが妥当ではないかと思います。
ただし、それは層の深さによっても傾向が異なってくるところではあり、かつ、PaDiMは浅い層と深い層の両方から特徴を取っている為、augmentationが精度向上を上手くもたらしてくれることも考えられるかと思います。
しかし、個人的には、調整すべきはアライメントの方、つまり、データ全般の位置合わせを改善する方が、PaDiMの本当の精度を引き出すためには、肝要なのではないかと考えます。


という訳で、以上でPaDiMアルゴリズムの解説を終えます。

pixel依存という課題感を若干強調してしまいましたが、基本的には非常に優秀なアルゴリズムかと思われます。
或いは、pixel依存をしても構わないほどに、アライメントがしっかりしているデータが対象であれば、積極的に使っても良いアルゴリズムかと思われます。
ハード側でアライメントの解決が可能な課題感などであれば、すぐにでも実用検討が可能なアルゴリズムかもしれません。

尚、実運用の際は、マハラノビス距離を求めるための標本平均  \mu_{ij} と、標本共分散  \sum_{ij} をロードすれば、すぐに異常検知推論を実施することができます。
RAMを逼迫する特徴ギャラリーをロードしなくても良い為、非常にリーズナブルです。

一方で、実際の課題感を考えると、pixel依存はやはり気になってしまう課題感かと思われます。
アライメントが難しい課題に取り組んでいる方からすると、やはりSPADE等が1stオプションになってしまうことでしょう。
その意味では、PaDiMはある種の割り切りによって、リーズナブルに精度を発揮することを検証した、1つの切り口と捉えると良いのかと思います。
個人的には、繰り返しになりますが、スピンオフという印象です。
pixel依存という割り切りによって高い精度が出せることを証明した、素晴らしい結果かと思います。


尚、この後に控える幾ばくかの哀愁を感じてしまう展開が、次なる異常検知手法のPatchCoreとなります。
PatchCoreは、SPADEの良さを追求することで、MVTecにおけるPaDiMの精度を追い抜いた、SOTAアルゴリズムです。
SPADE同様にpixel依存のないアルゴリズムながら、PaDiMの精度を凌ぐ結果を残した次第です。
ただ、MVTecの結果が全てではなく、データセットにおいてはPaDiMが精度を凌ぐこともあろうかと思いますので、シックスマンとして懐に持っておくことに意義があるかと思います。



おわりに

以上、今回の記事では、異常検知4手法の3つ目であるPaDiMについて、詳細に踏み込んだ解説をさせて頂きました。
ここまで読んでくださった方には、またまたまた感謝です。🙇
次回の記事では、異常検知4手法のラスト、PatchCoreについて説明をさせて頂こうと思います。
そちらもお読み頂けますと、有り難い次第です。🙇

AnyTechでは、流体のためのAI「DeepLiquid」の研究/開発/事業に携わってくれる方を募集しております。
ご興味のある方はこちらから、是非お問い合わせください。