AnyTech Engineer Blog

AnyTech Engineer Blogは、AnyTechのエンジニアたちによる調査や成果、Tipsなどを公開するブログです。

ImageNetモデルを用いた異常検知手法の解説【第5回:PatchCore(Towards Total Recall in Industrial Anomaly Detection)】

ImageNetモデルを用いた異常検知手法の解説【第5回:PatchCore(Towards Total Recall in Industrial Anomaly Detection)】 こんにちは、AnyTechの木村と申します。
AnyTechにて、機械学習エンジニアやAIエンジニアといった役割にて、R&Dに日々従事しております。


この記事は、昨今流行しているImageNetモデルを応用した異常検知手法について、リサーチをする機会がありましたので、小職の備忘も兼ねまして、その解説をさせて頂くものです。



目次



シリーズ



はじめに

今回で最終回となります。
近年流行しているImageNetモデルを応用した異常検知手法について、解説をさせて頂いております。
今回は、PatchCore(Towards Total Recall in Industrial Anomaly Detection)についての解説となります。


PatchCoreは、この記事を書いている2023年3月の時点で、画像レベルの異常検知精度にてTOP1となっているSOTA手法です。
論文の発表は、version1が2021年の6月、version2が2022年の5月とのことでした。
[2106.08265] Towards Total Recall in Industrial Anomaly Detection

また、概ねはSPADEのキレイな上位互換という印象ですが、PaDiMにて導入された工夫も取り入れており、かなり良いところ取りのアルゴリズムとなっています。
また、PatchCoreにて導入された新規アイデアもあります。
1つの集大成と言って良いかと、個人的には思います。
一方で面白いことに、デグレードのような課題感も、実は存在します。
諸々含め、この記事で上手く解説をさせて頂こうと思います。



先ずはザックリ、アルゴリズム概要

PatchCoreは、PaDiMと同様に、pixelレベルの異常位置セグメンテーションを最初に実施します。
そして、クエリ画像におけるpixelレベルの異常スコアの最大値を、画像レベルの異常スコアとして採用します。
その為、PaDiM同様、メインのアルゴリズムは、pixelレベルの異常位置セグメンテーションです。
また、PaDiM同様、最終層の特徴ベクトルは使用しません。
その上、浅い層の中間層の特徴テンソルも使用しません。
使用するのは、中間の深さの層の特徴テンソルと、深めの層の特徴テンソルの2層となります。


また、SPADEと同じく、中間層特徴テンソルの各pixelを用いてKNNを実施し、類似の画像パッチを探索するようなアプローチを取りますが、その探索先が学習データ画像の全パッチ分の特徴となります。
SPADEでは、K個のご近所さん画像の中間層特徴テンソルのみが探索対象でしたが、PatchCoreでは、ご近所さんのみならず全ての画像の中間層特徴テンソルを探索対象とする訳です。
しかし、学習データ画像が多い場合には、コンピューターリソースの問題から、その実現が難しくなる為、コアセットサンプリングを行って、探索対象の中間層特徴テンソルの間引きを行います。
この間引きには、間引きの前後で探索結果が変わらないことを目指し、K-center greedyアルゴリズムを適用します。

論文より引用:toyデータによるコアセットサンプリングのイメージ


尚、中間層特徴テンソルは、SPADEのように層別に取り扱うのではなく、PaDiMのように縦に結合し、中間特徴テンソル各1pixelに対応するそれを1本のみとします。
少しPaDiMと結合の仕方が異なる点として、浅い層と深い層とでチャンネル数を合わせてから結合する方法を取って、各影響力を均しています。
加えて、パラメータによって、受容野の大きさをコントロールできるように、中間特徴テンソル上の注目pixelを中心に、任意の広さpxp pixelの特徴を取得するようにしています。

論文より引用:1x1の特徴取得ではなく、注目pixelを中心とした3x3の特徴を取得するイメージ


上記以外のアルゴリズムについては、SPADEと同じになります。
以下、全体感です。

論文より引用:アルゴリズムオーバービュー、矢印・フローの調整等を少しだけ実施



次にじっくり、アルゴリズム詳細

それでは、ここからじっくりと、PatchCoreの詳細を紹介させて頂きます。


先ずは、アルゴリズムについての論文説明の要約を記載します。


PatchCoreのメソッドは、大きく3つの構成となっている。

  1. メモリバンクに集約するための、局所的な画像パッチ特徴
  2. メモリバンクでのKNN探索効率を高めるためのコアセット削減方法
  3. 画像レベルでの異常検知と、pixelレベルでの異常位置ローカライゼーション


1. Locally aware patch features

DN2、SPADE、PaDiMに続いて、PatchCoreでも、ImageNetで事前にトレーニングされたネットワーク  \phi を使用する。
特定のネットワーク階層の特徴が重要な役割を果たす為、  \phi_{i,j} = \phi_j(x_i) という、データセット X 中の画像 x_i ∈ X を入力とする表現にて、事前学習済みネットワーク \phi の階層レベル j の特徴を示すとする。
特に明記されていない場合は既存文献と同様に、 j は、ResNet50やWideResnet-50等のアーキテクチャにおける特徴マップにインデックスを付け、 j ∈ \lbrace 1, 2, 3, 4 \rbrace は最終的なそれぞれの空間解像度ブロックの出力とする。

用いる特徴の選択肢の1つは、ネットワークの最終層の特徴ベクトルであり、DN2やSPADEは、それを採用している。
しかし、これには次の2つの問題が伴う。

  • (1) PaDiMでも提唱されるように、よりローカライズされた特徴の情報が失われ、かつ、テスト時に遭遇する異常タイプは先天的に分からない為、異常検出パフォーマンスに悪影響を及ぼす
  • (2) ImageNetの事前トレーニング済みネットワークにおける、深い抽象的な特徴は、自然な画像分類のタスクに偏っている為、コールドスタートの産業異常検出タスク、及び、手元にある評価データと、その概念が殆ど重ならない(※DN2論文にて、意外と大丈夫とされていた主張への反論となります)


よって、提供された特徴を利用するためには、中間層か、または、中間層特徴表現を含むパッチレベルのメモリバンク M を使用し、一般的過ぎる特徴や、ImageNet分類に偏った特徴の回避を提案する。
また、理想的には、各パッチ表現は、局所的な空間変動にロバストな、意味のある異常なコンテキストを説明するのに十分な大きさの受容野サイズで動作する。
これは、ストライドプーリングとネットワーク階層を、深い方向に下ることで達成できるが、それによって作成されたパッチ表現は、ImageNet分類に偏ったものとなり、目前の異常検出タスクとの関連性が低くなってしまう。
そんな背景から、各パッチレベルの特徴表現を構成する際には、空間解像度や特徴マップの使いやすさを失うことなく、かつ、小さな空間偏差に対する受容野のサイズとロバスト性を高めるために、ローカル近傍集約が動機付けられる。
つまりは、中間特徴の特定のpixelについて、その周辺の特徴表現も一緒に巻き込んだ形で、1つの特徴を作り上げる形とする。
尚、巻き込んだpixel分だけ、特徴の情報量が増えてしまうが、PatchCoreでは、Adaptive Average Poolingを使用して、特徴ベクトルの集約を行う。


経験的に、或いは、SPADEやPaDiMと同様に、複数の機能階層の集約が何らかの利点を提供することが分かっている。
ただし、使用される特徴の一般性と空間解像度を保持するために、PatchCoreでは2つの中間特徴階層 jj + 1 のみを使用する。
そして、これら中間層の特徴テンソルは、それらの内の最高の解像度で集約することによって、結合が可能となる。
つまり、 \phi_{i,j+1}\phi_{i,j} の内、解像度が低い \phi_{i,j+1} を拡大し、サイズを合わせて結合し、集計を行う。


2. Coreset-reduced patch-feature memory bank

正常画像の量が大きくなると、特徴ギャラリーが大きくなってしまい、テストデータを評価するための推論時間と必要なストレージの両方が増加してしまう。
この問題は、低解像度の特徴マップと高解像度の特徴マップとを両方利用する、SPADEでの異常位置セグメンテーションにおいて、既に指摘されている。
計算上の制限により、SPADEでは、pixelレベルの異常位置セグメンテーションのために、最終層の特徴ベクトルにおける、低解像度の画像類似度を頼りに、特徴ギャラリーへの登録対象の絞り込みを行っている。
この絞り込みにより、最終層の特徴ベクトルという、最低解像度、かつ、ImageNetバイアスな表現を用いることとなり、それが精度への悪影響を及ぼす可能性がある。


その解決に向けては、特徴ギャラリーの情報量を削減した、メモリバンク M を作成する必要がある。
残念ながら、ランダムなサンプリングでは、精度が維持できないことを実験で確認している。
そこで我々は、コアセットのサブサンプリングを使用して M を削減し、精度を維持しながら、推論時間を短縮できることを図り、それに成功した。
概念的には、コアセットの選択は、 A に対する問題の解が、 S に対して計算されたものによって最も厳密に、特により迅速に近似できるように、サブセット S ⊂ A を見つけることを目的としている。
尚、課題感に応じて、対象となるコアセットは異なるが、PatchCoreは最近傍計算を使用する為、引用論文より、「minimax facility location coreset selection」を使用して、 M のコアセット M_C を作成し、それが元のメモリーバンク M とほぼ同様のカバレッジを保証する。

image


M^∗_C の正確な計算は、NP困難である為、引用論文で提案されているような反復貪欲近似を使用する。
また、コアセットの選択時間をさらに短縮するために、別途引用論文に従い、Johnson-Lindenstraussの定理を利用して、ランダムな線形射影  ψ : R^d → R^{d^∗} with d^∗ < d を通じて、要素 m ∈ M の次元を減らす。
メモリ バンクの削減は、以下アルゴリズムの通り。 表記には、$PatchCore-n% を使用して、元のメモリバンクがサブサンプリングされたパーセンテージ n を示す。(例えば、 M の100倍の削減である PatchCore-1%

image


以下図では、ランダム選択と比較した、貪欲なコアセットサブサンプリングの空間カバレッジの視覚的な印象を示す。

image

  • 2Dデータにて実験
  • コアセット(上段) vs ランダムサンプリング(下段)
  • ケース(a):多峰分布、ケース(b): 一様分布
  • 視覚的には、コアセットサブサンプリングは空間サポートをより適切に近似
  • ランダムサブサンプリングはマルチモーダルケースでクラスターを見逃すし、一様分布にて均一性も低くなる


3. Anomaly Detection with PatchCore

公称パッチ画像特徴メモリバンク M を使用する。
テスト画像 x^{test} の画像レベルの異常スコアは、テスト画像の全パッチから得られる異常スコア s ∈ R の最大値 s^∗ による推定とする。
テスト画像の各パッチの異常スコアは、テスト画像のパッチ画像特徴 P(x^{test}) = P_{s,p}(\phi_j(x^{test})) と、メモリバンク M にあるパッチ画像特徴コレクション中の最近傍 m^* との距離とする。

image


s の取得にあたって、我々はスケーリング ws^∗ に適用して、隣接パッチの背景を説明する。
もしも、メモリバンク特徴が異常候補 m^{test,\ast}m^∗ に最も近い場合で、かつ、それらがメモリバンク全体の特徴からも離れている場合、言うなれば、 m^* がメモリバンク全体の特徴内で既にレアなサンプルである場合、その対象の異常スコアが高くなるようにする。

image


N_b(m^∗) は、 M 内における、テスト画像パッチ特徴 m^{test,\ast} との最近傍のTOP b 個(TOP1が、 m^∗ )。
この再重み付けは、最大パッチ距離よりも堅牢であることが分かった。


こうして、 s が与えられると、異常位置セグメンテーションへと直接にシームレスに連結できる。
画像レベルの異常スコアは、arg max-operationによる各パッチの異常スコア計算が必要。
異常位置セグメンテーションマップは、それぞれの空間位置に基づいて計算された、パッチ異常スコアを再調整することにより、PaDiMと同様の手順で計算できる。
元の入力解像度と一致させるために、バイリニア補間によって結果を拡大する。
さらに、カーネル幅 σ = 4 のガウス分布で結果を平滑化したが、このパラメーターは最適化していない。



以上が、論文に記載されているPatchCoreアルゴリズムの説明となります。
補足をさせて頂くと、SPADEに対して、以下4つの発展を加えたものとなっています。

  1. 最終層の特徴ベクトルと、浅い中間層の特徴テンソルは使わずに、やや深めの中間層の特徴テンソルのみを用いる
  2. 中間層の特徴テンソル1pixel毎の受容野の広さについて、パラメータでの調整を試みる機能として、中間層の特徴テンソルを1pixel取得する際に、周辺のpixelも巻き込んで取得するような機能を追加
  3. 最終層特徴ベクトルでのKNNによって得られたご近所さん画像のみから、pixelレベルの異常スコアマップ算出用の中間層特徴テンソルを取得するのではなく、極力全ての学習データ画像からそれを取得するようにする
  4. 必要に応じて、pixelレベルの異常スコアについて、KNNの結果を利用する形で再重み付けをする


そして、要所要所には、PaDiMにて導入された考え方が、PatchCoreに導入されています。
例えば、中間層の特徴テンソルを縦に連結させる考え方や、pixelレベルの異常スコアマップを算出した上でその最大値を画像レベルの異常スコアとする考え方などです。

ImageNetモデルとKNNを用いた異常検知手法の良いとこ取りが、上手く行われている印象です。


さて、それでは、実装例を示していきたいと思います。
先ずは、例によって、特徴抽出用のImageNetモデルを構築します。
ここで、特徴を抽出する層は、深めの中間層2つのみとします。
これがSPADEからの発展の1つ目、「1. 最終層の特徴ベクトルと、浅い中間層の特徴テンソルは使わずに、やや深めの中間層の特徴テンソルのみを用いる」となります。

import torch
import torchvision.models as models
from torchinfo import summary

device = torch.device('cuda:0')  # torch.device('cpu')

model = models.wide_resnet50_2(weights=models.Wide_ResNet50_2_Weights.IMAGENET1K_V1)  # DEFAULT
model.eval()
model.to(device)

print('model =', model)

# set model's intermediate outputs
outputs = []

def hook(module, input, output):
    outputs.append(output)

model.layer2[-1].register_forward_hook(hook)
model.layer3[-1].register_forward_hook(hook)

summary(model, input_size=(1, 3, 224, 224))


次に、こちらも例によって、データのファイル名称取得から、学習データ画像の読込までを実施します。
今回も、Dogs vs Catsのデータセットから試してみます。

ファイル名称取得

import os

path = './dogs-vs-cats/train/'
files = os.listdir(path)

files_cat = [os.path.join(path, f) for f in files
             if (os.path.isfile(os.path.join(path, f)) &
                 ('cat.' in f) & ('.jpg' in f))]
files_dog = [os.path.join(path, f) for f in files
             if (os.path.isfile(os.path.join(path, f)) &
                 ('dog.' in f) & ('.jpg' in f))]

files_cat = sorted(files_cat)
files_dog = sorted(files_dog)

print('len(files_cat) =', len(files_cat))
print('files_cat[:10] =\n', files_cat[:10])
print()
print('len(files_dog) =', len(files_dog))
print('files_dog[:10] =\n', files_dog[:10])


今回は、seedをしっかりfixさせる

import random
import numpy as np

def torch_fix_seed(seed=0):
    # Python random
    random.seed(seed)
    # Numpy
    np.random.seed(seed)
    # Pytorch
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.use_deterministic_algorithms = True

torch_fix_seed()


ファイル名称を学習用と評価用に分割

N_cat_train = 3000
N_cat_val = 1000
N_dog = 1000

files_cat = random.sample(files_cat, (N_cat_train + N_cat_val))
files_dog = random.sample(files_dog, N_dog)

files_train = np.array(files_cat[:N_cat_train])
files_val = np.array(files_cat[N_cat_train:] + files_dog)

y_val = np.concatenate([np.zeros([N_cat_val]), 
                        np.ones([N_dog])], axis=0).astype(np.int16)

print('len(files_train) =', len(files_train))
print('len(files_val) =', len(files_val))

import matplotlib.pyplot as plt

plt.figure(figsize=(10, 4), dpi=100)
plt.plot(y_val, linewidth=5, alpha=0.5)
plt.grid()
plt.show()


学習データ画像の一括読み込み

import cv2
from tqdm.notebook import tqdm

import torch.nn.functional as F

outputs = []

img_train = []
img_prep_train = []

for file in tqdm(files_train):
    img = cv2.imread(file)[..., ::-1]  # BGR2RGB
    img_prep = cv2.resize(img, (224, 224))

    img_train.append(img)
    img_prep_train.append(img_prep)

img_prep_train = np.stack(img_prep_train)

print('len(img_train) =', len(img_train))
print('img_prep_train.shape =', img_prep_train.shape)


さて、次にですが、少し複雑なことを実施します。
例によって、特徴抽出を行うのですが、その際に、各画像パッチ毎の特徴がベクトル状の1x1となるようにではなく、任意のpixel数指定をpとして、テンソル状のpxpとなるように収集をしていきます。
pxpの収集は、1x1の周辺pixelを巻き込むようにして行います。

これがSPADEからの発展の2つ目、「2. 中間層の特徴テンソル1pixel毎の受容野の広さを、後付にてパラメーター調整を試みるような機能を追加」となります。
このpの値が大きいほど、後付にて、各特徴の受容野を大きくすることができるという考え方です。

例を示しながら、説明をさせて頂きます。
猫の画像から、 縦8pixel x 横8pixel の特徴を抽出するとし、その内の1channelに注目するとします。


この時、特徴1pixel毎の受容野に対応する画像パッチを考えると、厳密にはキレイに区切るのは難しいのですが、PatchCoreの論文でもそう割り切られているように、仮に以下のように区切ってしまいます。
座標の数値が対応となっています。


上記が、SPADEやPaDiMで行われているような、特徴の1pixelに対する画像パッチを考える方式です。
一方でPatchCoreは、抽出した特徴からの切り出しを1x1pixelではなく、任意のpxp pixel、例えば、3x3pixelなどで実施をします。
つまり、特徴の縦横上を、3x3のカーネルでスライディングさせながら、特徴値を収集する形となります。
図に示すと、以下のようになります。
特徴上をスライディングさせた結果、対応する画像パッチが左記のように割り切られる、という概念の順序だけ注意して下さい。


各チャンネルで、3x3pixelで収集された特徴情報は、最終的にはベクトル状に並べ替えた上で、AlexNetのお尻の層にもあったAdaptive Average Pooling層の1D版の適用によって、任意指定次元の特徴ベクトルへと圧縮されます。
尚、圧縮の際にはメモリ配置としては、3x3pixelが末端に来る形にて実施される為、1x1pixelから特徴ベクトルを生成する場合にも、3x3pixelから特徴ベクトルを生成する場合にも、比較的似ている特徴が仕上がるかと思いますが、需要野を広げた先にCNNが強く反応するような特徴的な対象が写り込む場合に限っては、似て非なる特徴ベクトルとなることが期待されるかと思います。

尚、これを中間層特徴テンソル2つに対して別々に実施した後、それらを縦に連結して、その後に再び1DのAdaptive Average Pooling層を適用して、任意の次元に圧縮をします。
そうして特徴の次元をコントロールすることで、メモリバンクに登録する情報量の多少、つまり、RAMの占有感/逼迫感をコントロールしているようです。

ちなみに、PatchCoreのofficial実装では、中間特徴に対するAdaptive Average Poolingの圧縮次元指定が1,024で、それらを縦連結した後のAdaptive Average Poolingの圧縮次元指定も1,024となっています。
つまり、各pixel毎について、中間特徴から1,024次元の特徴ベクトルが計2つ生成され、それが縦結合されて2,048次元の特徴ベクトルに集約されて、更にそれを半分の1,024次元に圧縮される形です。

尚、PaDiMでは、各pixel毎について、3つの中間特徴の1x1pixel特徴ベクトルを縦連結して、256+512+1,024=1,792次元の特徴ベクトルを作り出した上で、その1,792次元からランダムに選択した550次元を用いる、というものでした。
それと比較すると、PatchCoreの方が特徴ベクトルの情報量としてはリッチかと思われますが、Adaptive Average Poolingの影響から、ややnon peakyな特徴となっている可能性はあるかと思います。
それらの良し悪しは、適用するデータセットに寄るのかもしれません。


さて、それでは、今説明をしたPatchCore式の特徴ベクトル生成をコードで実装していこうと思います。
実装の方法は、PatchCoreのオフィシャル実装に倣って行っていきます。

本格的に特徴抽出を行う前に、ダミーデータでのfeedforwardを実施して、特徴の縦横サイズを取得します。
本格的に特徴抽出する際に、この数値を使用します。

実施は以下コードです。

# set param
N_batch = 100
patchsize = 3
stride = 1
padding = int((patchsize - 1) / 2)

outputs = []

with torch.no_grad():
    _ = model(torch.randn(N_batch, 3, 224, 224).to(device))

f1 = outputs[0].clone()  # (B, C, H, W)
f2 = outputs[1].clone()  # (B, C, H, W)
feat = [f1, f2]
shapes = [f1.shape, f2.shape]

patch_shapes = []
for i in range(len(feat)):
    number_of_total_patches = []
    for s in shapes[i][-2:]:
        n_patches = (s + 2 * padding - 1 * (patchsize - 1) - 1) / stride + 1
        number_of_total_patches.append(int(n_patches))
    patch_shapes.append(number_of_total_patches)
print('patch_shapes =', patch_shapes)

ref_num_patches = patch_shapes[0]
print('ref_num_patches =', ref_num_patches)


これにて、中間層特徴テンソルの縦横サイズ情報を取得します。
抽出した特徴の内、浅めの層の特徴が 縦28x横28 で、深めの層の特徴が 縦14x横14 となります。
そして、ベクトルを縦連結する際は、解像度の小さい特徴 縦14x横14縦28x横28 にリサイズします。

それらの情報が得られたら、試しにPatchCore仕様の各pixel毎の特徴ベクトルの作り方を、先程抽出したダミー特徴にて実践してみます。
その際に、unfoldをする独立層 unfolder も作成しておきます。

以下コードにて、特徴ベクトル作成は実現されます。

pretrain_embed_dimension = 1024
target_embed_dimension = 1024

print('[before]')
print('------------------------------------------------')
print('len(feat) =', len(feat))
for i in range(len(feat)):
    print('feat[%d].shape =' % i, feat[i].shape)
print('------------------------------------------------')

unfolder = torch.nn.Unfold(
    kernel_size=patchsize, stride=stride, padding=padding, dilation=1
)

# patchify
print('\n\n')
print('[unfold]')
for i in range(len(feat)):
    print('------------------------------------------------')
    print('feat[%d].shape =' % i, feat[i].shape)
    # (B, C, H, W) -> (B, C, H, W, PH, PW)
    with torch.no_grad():
        feat[i] = unfolder(feat[i])
    print('feat[%d].shape =' % i, feat[i].shape)
    # (B, C, H, W, PH, PW) -> (B, C, PH, PW, HW)
    feat[i] = feat[i].reshape(*shapes[i][:2],
                                      patchsize, patchsize, -1)
    print('feat[%d].shape =' % i, feat[i].shape)
    # (B, C, PH, PW, HW) -> (B, HW, C, PW, HW)
    feat[i] = feat[i].permute(0, 4, 1, 2, 3)
    print('feat[%d].shape =' % i, feat[i].shape)
    print('------------------------------------------------')

print('\n\n')
print('[expand small feat]')
for i in range(1, len(feat)):
    print('------------------------------------------------')
    print('feat[%d].shape =' % i, feat[i].shape)

    _feat = feat[i]
    patch_dims = patch_shapes[i]

    print('_feat.shape =', _feat.shape)
    # (B, HW, C, PW, HW) -> (B, H, W, C, PH, PW)
    _feat = _feat.reshape(_feat.shape[0], patch_dims[0],
                                  patch_dims[1], *_feat.shape[2:])
    print('_feat.shape =', _feat.shape)
    # (B, H, W, C, PH, PW) -> (B, C, PH, PW, H, W)
    _feat = _feat.permute(0, -3, -2, -1, 1, 2)
    print('_feat.shape =', _feat.shape)
    perm_base_shape = _feat.shape
    # (B, C, PH, PW, H, W) -> (BCPHPW, H, W)
    _feat = _feat.reshape(-1, *_feat.shape[-2:])
    print('_feat.shape =', _feat.shape)
    # (BCPHPW, H, W) -> (BCPHPW, H_max, W_max)
    _feat = F.interpolate(_feat.unsqueeze(1),
                              size=(ref_num_patches[0], ref_num_patches[1]),
                              mode="bilinear", align_corners=False)
    print('_feat.shape =', _feat.shape)
    _feat = _feat.squeeze(1)
    print('_feat.shape =', _feat.shape)
    # (BCPHPW, H_max, W_max) -> (B, C, PH, PW, H_max, W_max)
    _feat = _feat.reshape(*perm_base_shape[:-2], 
                                  ref_num_patches[0], ref_num_patches[1])
    print('_feat.shape =', _feat.shape)
    # (B, C, PH, PW, H_max, W_max) -> (B, H_max, W_max, C, PH, PW)
    _feat = _feat.permute(0, -2, -1, 1, 2, 3)
    print('_feat.shape =', _feat.shape)
    # (B, H_max, W_max, C, PH, PW) -> (B, H_maxW_max, C, PH, PW)
    _feat = _feat.reshape(len(_feat), -1, *_feat.shape[-3:])
    feat[i] = _feat
    print('feat[%d].shape =' % i, feat[i].shape)
    print('------------------------------------------------')

print('\n\n')
print('[collect by feature vector]')
print('------------------------------------------------')
for i in range(2):
    print('feat[%d].shape =' % i, feat[i].shape)
# (B, H, W, C, PH, PW) -> (BHW, C, PH, PW)
feat = [x.reshape(-1, *x.shape[-3:]) for x in feat]
for i in range(2):
    print('feat[%d].shape =' % i, feat[i].shape)
print('------------------------------------------------')

print('\n\n')
print('[adaptive average pooling for each feature vector]')
for i in range(len(feat)):
    print('------------------------------------------------')
    print('feat[%d].shape =' % i, feat[i].shape)
    _feat = feat[i]
    print('_feat.shape =', _feat.shape)
    # (BHW, C, PH, PW) -> (BHW, 1, CPHPW)
    _feat = _feat.reshape(len(_feat), 1, -1)
    print('_feat.shape =', _feat.shape)
    # (BHW, 1, CPHPW) -> (BHW, D_p)
    _feat = F.adaptive_avg_pool1d(_feat, 
                                      pretrain_embed_dimension).squeeze(1)
    feat[i] = _feat
    print('feat[%d].shape =' % i, feat[i].shape)
    print('------------------------------------------------')

print('\n\n')
print('[concat the two feature vectors and adaptive average pooling]')
print('------------------------------------------------')
for i in range(2):
    print('feat[%d].shape =' % i, feat[i].shape)
# (BHW, D_p) -> (BHW, D_p*2)
feat = torch.stack(feat, dim=1)
print('feat.shape =', feat.shape)
"""Returns reshaped and average pooled feat."""
# batchsize x number_of_layers x input_dim -> batchsize x target_dim
# (BHW, D_p*2) -> (BHW, D_t)
feat = feat.reshape(len(feat), 1, -1)
print('feat.shape =', feat.shape)
feat = F.adaptive_avg_pool1d(feat, target_embed_dimension)
print('feat.shape =', feat.shape)
feat = feat.reshape(len(feat), -1)
print('feat.shape =', feat.shape)
print('------------------------------------------------')


上記コードでは、デバッグを冗長目に行っていますので、それによってベクトルの編集推移が追いやすいかと思います。
後は、ダミーデータで行ったことを、実際の学習データ画像全件に対して行い、その特徴ベクトルを収集すれば、データ全件でのメモリバンクが完成します。

以下が、学習データに画像全件に対する特徴収集コードとなります。

MEAN = torch.from_numpy(np.array([[[0.485, 0.456, 0.406]]]))
MEAN = MEAN.to(torch.float).to(device)
STD = torch.from_numpy(np.array([[[0.229, 0.224, 0.225]]]))
STD = STD.to(torch.float).to(device)

feat_train = []

outputs = []

for i_batch in tqdm(range(0, len(img_prep_train), N_batch)):

    img_batch = img_prep_train[i_batch:(i_batch + N_batch)]
    x = torch.from_numpy(img_batch).to(torch.float).to(device)
    x = x / 255
    x = x - MEAN
    x = x / STD
    x = x.permute(0, 3, 1, 2)

    with torch.no_grad():
        _ = model(x)

    f1 = outputs[0].clone()  # (B, C, H, W)
    f2 = outputs[1].clone()  # (B, C, H, W)
    feat = [f1, f2]
    outputs = []

    # patchify
    for i in range(len(feat)):
        # (B, C, H, W) -> (B, C, H, W, PH, PW)
        with torch.no_grad():
            feat[i] = unfolder(feat[i])
        # (B, C, H, W, PH, PW) -> (B, C, PH, PW, HW)
        feat[i] = feat[i].reshape(*shapes[i][:2],
                                          patchsize, patchsize, -1)
        # (B, C, PH, PW, HW) -> (B, HW, C, PW, HW)
        feat[i] = feat[i].permute(0, 4, 1, 2, 3)

    for i in range(1, len(feat)):
        _feat = feat[i]
        patch_dims = patch_shapes[i]
        # (B, HW, C, PW, HW) -> (B, H, W, C, PH, PW)
        _feat = _feat.reshape(_feat.shape[0], patch_dims[0],
                                      patch_dims[1], *_feat.shape[2:])
        # (B, H, W, C, PH, PW) -> (B, C, PH, PW, H, W)
        _feat = _feat.permute(0, -3, -2, -1, 1, 2)
        perm_base_shape = _feat.shape
        # (B, C, PH, PW, H, W) -> (BCPHPW, H, W)
        _feat = _feat.reshape(-1, *_feat.shape[-2:])
        # (BCPHPW, H, W) -> (BCPHPW, H_max, W_max)
        _feat = F.interpolate(_feat.unsqueeze(1),
                                  size=(ref_num_patches[0], ref_num_patches[1]),
                                  mode="bilinear", align_corners=False)
        _feat = _feat.squeeze(1)
        # (BCPHPW, H_max, W_max) -> (B, C, PH, PW, H_max, W_max)
        _feat = _feat.reshape(*perm_base_shape[:-2], 
                                      ref_num_patches[0], ref_num_patches[1])
        # (B, C, PH, PW, H_max, W_max) -> (B, H_max, W_max, C, PH, PW)
        _feat = _feat.permute(0, -2, -1, 1, 2, 3)
        # (B, H_max, W_max, C, PH, PW) -> (B, H_maxW_max, C, PH, PW)
        _feat = _feat.reshape(len(_feat), -1, *_feat.shape[-3:])
        feat[i] = _feat

    # (B, H, W, C, PH, PW) -> (BHW, C, PH, PW)
    feat = [x.reshape(-1, *x.shape[-3:]) for x in feat]

    for i in range(len(feat)):
        _feat = feat[i]
        # (BHW, C, PH, PW) -> (BHW, 1, CPHPW)
        _feat = _feat.reshape(len(_feat), 1, -1)
        # (BHW, 1, CPHPW) -> (BHW, D_p)
        _feat = F.adaptive_avg_pool1d(_feat, 
                                          pretrain_embed_dimension).squeeze(1)
        feat[i] = _feat

    # (BHW, D_p) -> (BHW, D_p*2)
    feat = torch.stack(feat, dim=1)
    """Returns reshaped and average pooled feat."""
    # batchsize x number_of_layers x input_dim -> batchsize x target_dim
    # (BHW, D_p*2) -> (BHW, D_t)
    feat = feat.reshape(len(feat), 1, -1)
    feat = F.adaptive_avg_pool1d(feat, target_embed_dimension)
    feat = feat.reshape(len(feat), -1)

    feat_train.append(feat.cpu())

feat_train = torch.vstack(feat_train)

print('feat_train.shape =', feat_train.shape)


これにて、メモリバンクが作成されました。
しかし、その容量を見てみると、かなりの量があることが確認できるかと思います。
コードをそのまま実行すると、 画像3,000枚 × 縦224pixel × 横224pixel から、 特徴ベクトル本数2,352,000本 × 特徴ベクトル次元1,024次元 という量になります。
そして、この後の工程としては、メモリバンクとクエリ画像から取れる中間層特徴テンソルの各pixel毎の特徴ベクトルとの、総当たりの距離計算、即ち、KNNの実践が待っています。 それが、SPADEからの発展の3つ目、「3. 最終層特徴ベクトルでのKNNによって得られたご近所さん画像のみから、pixelレベルの異常スコアマップ算出用の中間層特徴テンソルを取得するのではなく、極力全ての学習データ画像からそれを取得するようにする」というポイントになります。

つまりは、理想的には、学習データ画像全件のメモリバンクとのKNNを実施したいものの、データ量が多過ぎる為、それがなかなか難しいという訳です。
例えば、計算に時間がかかり過ぎたり、RAMが逼迫してしまう等です。
また、メモリバンクの容量を減らそうと、特徴の次元数を下げる等してしまうと、今度はKNNの精度が下がることが懸念されます。
何とか特徴表現の豊かさは維持しつつ、メモリバンクの容量を減らしたいところです。

実は、SPADEからの発展の3つ目、「3. 最終層特徴ベクトルでのKNNによって得られたご近所さん画像のみから、pixelレベルの異常スコアマップ算出用の中間層特徴テンソルを取得するのではなく、極力全ての学習データ画像からそれを取得するようにする」という記載の内の、「極力」という言葉が、この展開の伏線となっておりました。
先程の論文の要約でも、抽象的には説明をさせて頂いていました。
全件のメモリバンクでは、やはりKNN実践が難しい為、PatchCoreではメモリバンクからコアセットを切り出すという工程を踏みます。

コアセットとは、データ件数は元のデータ集合より少ないのだけど、理想的には、元のデータセットと同じKNNの結果を返すような、そんな部分的なデータ集合となります。
というのも、メモリバンクにある特徴ベクトルは、特に定点カメラで撮影された画像であったり、照明が一定であったり等で、写り込み変動要素が少ない場合には、ほぼ同様の特徴ベクトルが複数存在しているかと思われます。
即ち、かなり登録データが冗長であることが予測される訳です。
その為、そういった冗長さを省けば、元のメモリバンクが持つ情報量を維持しつつ、データ量の削減が可能となるのではないか、という仮説が立ちます。

その具体的な実践方法として、PatchCoreでは、K-center greedyアルゴリズムを採用しています。
K-means++というアルゴリズムをご存知の方は理解が早いかと思いますが、K-center greedyアルゴリズムとは、以下の手順のアルゴリズムとなります。


(K-center greedyアルゴリズム)

  1. ランダムに任意数のサンプルを選択する
  2. 1で選択したサンプルと全サンプルとの距離を計算する
  3. 2で計算した距離の平均を、任意数で割る形で計算する
  4. 3で計算した平均距離が最も大きいサンプルを選択し、コアセットサンプルとする
  5. 4で選択したサンプルと全サンプルとの距離を計算する
  6. サンプル毎に5で計算した距離と、3で計算した距離の最小値を得る(何れの選択サンプルよりも遠い新サンプルを探すために、何れかのサンプルまでの距離の最小値を得る)
  7. 6で得た距離の最小値が、最大であるサンプルを選択し、コアセットサンプルとする
  8. 7で選択したサンプルと全サンプルとの距離を計算する
  9. サンプル毎にこれまで計算した距離の最小値を得る(再度、何れの選択サンプルよりも遠い新サンプルを探すために、何れかのサンプルまでの距離の最小値を得る)
  10. 以降、指定のサンプル数に至る迄、7〜9を繰り返しながら、コアセットサンプルの数を増やしていく

以上です。
補足として、最初に選ぶ任意数のサンプルは、コアセットサンプルとはならない点に注意して下さい。
これをコアセットサンプルとしても良い気もしますが、PatchCoreのオフィシャル実装においては、コアセットサンプルとしていませんでしたので、今回はそれに倣います。

尚、K-means++の初期位置設定の際には、何れの選択サンプルよりも遠いサンプルを、高い確率で選択するというもので、逆に言えば、近いサンプルでも低い確率ながらも選択される可能性がある、というものとなっています。
K-center greedyアルゴリズムは、その確率的な要素を排除して、最も遠いサンプル一択で選び続けていくアルゴリズムとなっています。

ちなみに、K-center greedyアルゴリズムのイメージについては、同僚の方に教えてもらったものが非常に分かりやすかったですが、これはいわゆる施設配置問題(Facility Location Problem)というもので、その最適化手法の1つがK-center greedyアルゴリズムなのだそうです。
そして、このアルゴリズムに従って、施設配置位置を決めると、何れのポイントからも等距離範囲内に施設を配置することが可能となりそうです。

また、上記のアルゴリズムより、K-center greedyアルゴリズムの際の距離計算においても、データ量の課題感が出てきます。
そこでPatchCoreが取っている手法は、抽出した特徴ベクトルの1,024次元を、ランダムな128次元のベクトルを用いて射影をすることで、圧縮がされた128次元上でコアセットのサンプリングを実施します。
これについては、そのように実施している前例があるそうです。
尚、距離計算式にはユークリッド距離を用いています。


さて、それでは全件のメモリバンクから、実際にコアセットを取り出してみましょう。
K-center greedyアルゴリズムの実践です。
先ずは、以下コードにて、必要なパラメータをセットします。

percentage = 0.01
dimension_to_project_features_to = 128
number_of_starting_points = 10


ここで、 percentage は、全件のメモリバンクからのデータの抽出割合となります。
0.01 という値ですと、 1% のデータ量にまで減らすという意味になります。
もしも、全件のメモリバンクに特徴ベクトルが、1,000,000本登録されていたとしたら、そこから10,000本だけを取り出す形となります。
dimension_to_project_features_to は、コアセットの抽出を行う際に、特徴を圧縮射影する次元数となります。
これが小さい程、コアセットサンプリングの速度が上がり、メモリ逼迫度合いが軽減しますが、コアセットサンプリングの正確さが低下する可能性が増えていきます。
number_of_starting_points は、K-center greedyアルゴリズムを行うに当たって、最初にランダムに選択するサンプルの数となります。


パラメータがセットできましたら、次に、特徴の圧縮射影を行います。
コアセットのサンプリングに関して、この圧縮射影をされた特徴を用いていきます。

mapper = torch.nn.Linear(feat_train.shape[1], dimension_to_project_features_to,
                         bias=False).to(device)

print('mapper =', mapper)

feat_train = feat_train.to(device)

with torch.no_grad():
    feat_train_proj = mapper(feat_train)

print('feat_train.shape =', feat_train.shape)
print('feat_train_proj.shape =', feat_train_proj.shape)


次に、基準となる初期サンプルをランダムに選択します。
繰り返しになりますが、このサンプルはコアセットサンプルとはなりません。

number_of_starting_points = np.clip(number_of_starting_points,
                                    None, len(feat_train_proj))

print('number_of_starting_points =', number_of_starting_points)

start_points = np.random.choice(len(feat_train_proj), number_of_starting_points, 
                                replace=False).tolist()

print('len(start_points) =', len(start_points))
print('start_points =', start_points)


初期サンプルを選びましたら、それらと特徴全件との距離計算を総当りで実施します。

matrix_a = feat_train_proj
matrix_b = feat_train_proj[start_points]

print('matrix_a.shape =', matrix_a.shape)
print('matrix_b.shape =', matrix_b.shape)
print()

print('matrix_a.unsqueeze(1).shape =', matrix_a.unsqueeze(1).shape)
print('matrix_a.unsqueeze(2).shape =', matrix_a.unsqueeze(2).shape)
print('matrix_b.unsqueeze(1).shape =', matrix_b.unsqueeze(1).shape)
print('matrix_b.unsqueeze(2).shape =', matrix_b.unsqueeze(2).shape)
print()

"""Computes batchwise Euclidean distances using PyTorch."""
a_times_a = matrix_a.unsqueeze(1).bmm(matrix_a.unsqueeze(2)).reshape(-1, 1)
b_times_b = matrix_b.unsqueeze(1).bmm(matrix_b.unsqueeze(2)).reshape(1, -1)
a_times_b = matrix_a.mm(matrix_b.T)

print('a_times_a.shape =', a_times_a.shape)
print('b_times_b.shape =', b_times_b.shape)
print('a_times_b.shape =', a_times_b.shape)
print()

approximate_distance_matrix = (-2 * a_times_b + a_times_a + b_times_b).clamp(0, None)  # .sqrt()

print('approximate_distance_matrix.shape =', approximate_distance_matrix.shape)


この際、PatchCoreのオフィシャル実装では、ユークリッド距離の計算について、一般的には \sqrt{(a+b)^2} と計算するところを、あえて \sqrt{a^2+b^2+2ab} と計算しています。
詳細は省きますが、これは、torchのbmmというメソッドを使って、メモリ逼迫のリスクと、データが一定量を超えた際に急に発生する処理速度低下のリスクとの回避のためかと思われます。
私の手元で実験してみた限りで、torchのbmmを用いた方が、処理速度が安定しており、致命的な処理遅延が発生しづらい印象でした。
尚、この距離計算においては、距離の大小判断だけを行いたく、絶対的なスケールには意義が無い為、処理速度向上のために、ルートの計算 .sqrt() は実施を控えます。


初期サンプルとの総当り距離計算が終わりましたら、その平均距離を算出します。
直感的にはこれで、初期サンプルの何れとも遠いサンプルをあぶり出すことができます。
それが、最初の1つ目のコアセットサンプルとなります。

approximate_coreset_anchor_distances = torch.mean(approximate_distance_matrix,
                                                  axis=-1, keepdims=True)

print('approximate_coreset_anchor_distances.shape =', approximate_coreset_anchor_distances.shape)


残る処理は、先程紹介したK-center greedyアルゴリズムに従って、何れのサンプルからも遠いサンプルを、目標数に達するまで繰り返して行く形になります。
それを実践するのが以下のコードとなります。

coreset_indices = []
num_coreset_samples = int(len(feat_train_proj) * percentage)

with torch.no_grad():
    for _ in tqdm(range(num_coreset_samples), desc="Subsampling..."):
        select_idx = torch.argmax(approximate_coreset_anchor_distances).item()
        coreset_indices.append(select_idx)

        matrix_a = feat_train_proj
        matrix_b = feat_train_proj[[select_idx]]
        """Computes batchwise Euclidean distances using PyTorch."""
        a_times_a = matrix_a.unsqueeze(1).bmm(matrix_a.unsqueeze(2)).reshape(-1, 1)
        b_times_b = matrix_b.unsqueeze(1).bmm(matrix_b.unsqueeze(2)).reshape(1, -1)
        a_times_b = matrix_a.mm(matrix_b.T)
        coreset_select_distance = (-2 * a_times_b + a_times_a + b_times_b).clamp(0, None)  # .sqrt()

        approximate_coreset_anchor_distances = torch.cat(
            [approximate_coreset_anchor_distances, coreset_select_distance],
            dim=-1,
        )
        approximate_coreset_anchor_distances = torch.min(
            approximate_coreset_anchor_distances, dim=1
        ).values.reshape(-1, 1)

coreset_indices = np.array(coreset_indices)

print('len(coreset_indices) =', len(coreset_indices))
print('coreset_indices[:10] =', coreset_indices[:10])


こうして、コアセットサンプルのインデックスを得ることができました。
最後は、そのインデックスを用いて、全件のメモリバンクからの絞り込み抽出を行います。

feat_train_coreset = feat_train[coreset_indices]

print('feat_train_coreset.shape =', feat_train_coreset.shape)


これにて、コアセットの抽出が完了した次第です。
また、必要に応じて、全件のメモリバンクはRAM上から削除します。

del feat_train
torch.cuda.empty_cache()


尚、中間層特徴テンソルの1pixelの受容野が元画像の縦横45pixel程だと仮置して、コアセットに選ばれた特徴ベクトルの抽出元となっている画像パッチがどんなものか、一部だけ眺めてみましょう。


コアセットサンプリングの意図通り、似て非なる画像パッチが選ばれていることが確認できるかと思います。
尚、上からコアセットサンプルとして選ばれた順に並べています。
確かに冗長ではないですね。

ちなみに、コアセットサンプリングを行う場合には、以下のように冗長に画像パッチが収集される形になります。


さて、一先ずは、コアセットのサンプリング割合を0.01、即ち、1%として、猫画像と犬画像の選り分けがどのくらい上手くいくかを試してみます。
尚、学習データの猫画像3,000枚から、各画像毎に画像パッチ特徴ベクトルが784本取れる為、全件のメモリバンクにはが3,000×784=2,352,000本の特徴ベクトルが含まれます。
その内の1%をサンプリングするので、コアセットには23,520本の特徴ベクトルが含まれる形になります。
これは、学習データの猫画像30枚分の特徴ベクトル量となり、SPADEにて特徴ベクトル収集対象のご近所さんを、k=30として収集した際の特徴ベクトル本数と同等になります。

先程のサンプルコードにて、コアセットのサンプリングは1%で行っていましたので、そうして抽出したコアセットサンプルをfaissのインデックスに登録をして、L2のKNNを実践してみます。
以下コードでそれは実践されます。

インデックスへの登録

import faiss

feat_train_coreset = feat_train_coreset.cpu().numpy()

search_index = faiss.GpuIndexFlatL2(faiss.StandardGpuResources(),
                                    feat_train_coreset.shape[1],
                                    faiss.GpuIndexFlatConfig())
search_index.add(feat_train_coreset)

評価データの画像読込

outputs = []

img_val = []
img_prep_val = []

for file in tqdm(files_val):
    img = cv2.imread(file)[..., ::-1]  # BGR2RGB
    img_prep = cv2.resize(img, (224, 224))

    img_val.append(img)
    img_prep_val.append(img_prep)

img_prep_val = np.stack(img_prep_val)

print('len(img_val) =', len(img_val))
print('img_prep_val.shape =', img_prep_val.shape)


評価データ画像からImageNetモデル特徴を抽出

# set param
N_batch = 25

feat_val = []

outputs = []

for i_batch in tqdm(range(0, len(img_prep_val), N_batch)):

    img_batch = img_prep_val[i_batch:(i_batch + N_batch)]
    x = torch.from_numpy(img_batch).to(torch.float).to(device)
    x = x / 255
    x = x - MEAN
    x = x / STD
    x = x.permute(0, 3, 1, 2)

    with torch.no_grad():
        _ = model(x)

    f1 = outputs[0].clone()  # (B, C, H, W)
    f2 = outputs[1].clone()  # (B, C, H, W)
    feat = [f1, f2]
    shapes = [f1.shape, f2.shape]

    outputs = []

    # patchify
    for i in range(len(feat)):
        # (B, C, H, W) -> (B, C, H, W, PH, PW)
        with torch.no_grad():
            feat[i] = unfolder(feat[i])
        # (B, C, H, W, PH, PW) -> (B, C, PH, PW, HW)
        feat[i] = feat[i].reshape(*shapes[i][:2], patchsize, patchsize, -1)
        # (B, C, PH, PW, HW) -> (B, HW, C, PW, HW)
        feat[i] = feat[i].permute(0, 4, 1, 2, 3)

    for i in range(1, len(feat)):
        _feat = feat[i]
        patch_dims = patch_shapes[i]
        # (B, HW, C, PW, HW) -> (B, H, W, C, PH, PW)
        _feat = _feat.reshape(_feat.shape[0], patch_dims[0],
                                      patch_dims[1], *_feat.shape[2:])
        # (B, H, W, C, PH, PW) -> (B, C, PH, PW, H, W)
        _feat = _feat.permute(0, -3, -2, -1, 1, 2)
        perm_base_shape = _feat.shape
        # (B, C, PH, PW, H, W) -> (BCPHPW, H, W)
        _feat = _feat.reshape(-1, *_feat.shape[-2:])
        # (BCPHPW, H, W) -> (BCPHPW, H_max, W_max)
        _feat = F.interpolate(_feat.unsqueeze(1),
                                  size=(ref_num_patches[0], ref_num_patches[1]),
                                  mode="bilinear", align_corners=False)
        _feat = _feat.squeeze(1)
        # (BCPHPW, H_max, W_max) -> (B, C, PH, PW, H_max, W_max)
        _feat = _feat.reshape(*perm_base_shape[:-2], 
                                      ref_num_patches[0], ref_num_patches[1])
        # (B, C, PH, PW, H_max, W_max) -> (B, H_max, W_max, C, PH, PW)
        _feat = _feat.permute(0, -2, -1, 1, 2, 3)
        # (B, H_max, W_max, C, PH, PW) -> (B, H_maxW_max, C, PH, PW)
        _feat = _feat.reshape(len(_feat), -1, *_feat.shape[-3:])
        feat[i] = _feat

    # (B, H, W, C, PH, PW) -> (BHW, C, PH, PW)
    feat = [x.reshape(-1, *x.shape[-3:]) for x in feat]

    for i in range(len(feat)):
        _feat = feat[i]
        # (BHW, C, PH, PW) -> (BHW, 1, CPHPW)
        _feat = _feat.reshape(len(_feat), 1, -1)
        # (BHW, 1, CPHPW) -> (BHW, D_p)
        _feat = F.adaptive_avg_pool1d(_feat, 
                                          pretrain_embed_dimension).squeeze(1)
        feat[i] = _feat

    # (BHW, D_p) -> (BHW, D_p*2)
    feat = torch.stack(feat, dim=1)
    """Returns reshaped and average pooled feat."""
    # batchsize x number_of_layers x input_dim -> batchsize x target_dim
    # (BHW, D_p*2) -> (BHW, D_t)
    feat = feat.reshape(len(feat), 1, -1)
    feat = F.adaptive_avg_pool1d(feat, target_embed_dimension)
    feat = feat.reshape(len(feat), -1)

    feat_val.append(feat.cpu())

feat_val = torch.vstack(feat_val).numpy()

print('feat_val.shape =', feat_val.shape)


距離計算、及び、画像毎にそれを集計

k = 1

D_val, _ = search_index.search(feat_val, k)
D_val = D_val.reshape(len(y_val), -1)

print('D_val.shape =', D_val.shape)


後は、この結果を可視化/精度評価してみます。
PaDiMとPatchCoreにおいては、pixelレベルの異常スコアの内の最大値を、画像レベルの異常スコアとするアルゴリズムとなっている為、そのようにして精度評価を行います。
以下コードです。

plt.figure(figsize=(10, 8), dpi=100, facecolor='white')

plt.subplot(2, 1, 1)
plt.scatter(np.where(y_val == 0)[0], np.max(D_val, axis=-1)[y_val == 0],
            alpha=0.5, label='cat')
plt.scatter(np.where(y_val == 1)[0], np.max(D_val, axis=-1)[y_val == 1],
            alpha=0.5, label='dog')
plt.grid()
plt.legend()

plt.subplot(2, 1, 2)
plt.hist(np.max(D_val, axis=1)[y_val == 0], alpha=0.5, bins=50, label='cat')
plt.hist(np.max(D_val, axis=1)[y_val == 1], alpha=0.5, bins=50, label='dog')
plt.grid()
plt.legend()

plt.show()


分布が割れていないことが確認できるかと思います。
この分布に上手いこと閾値を設けた際に、どのくらいの正解率となるかも見てみます。

acc_list = []
thresh_list = np.arange(0, np.max(np.max(D_val, axis=-1))+1e-10,
                           np.max(np.max(D_val, axis=-1))/50)

for thresh in thresh_list:

    acc = np.mean(np.concatenate([(np.max(D_val, axis=-1)[y_val == 0] < thresh),
                                  (np.max(D_val, axis=-1)[y_val == 1] > thresh)]))
    acc_list.append(acc)

acc_list = np.array(acc_list)

plt.figure(figsize=(10, 4), dpi=100, facecolor='white')
plt.plot(thresh_list, acc_list, alpha=0.5, linewidth=3)
plt.scatter(thresh_list, acc_list, alpha=0.5, s=30)
plt.grid()
plt.title('np.max(acc_list) = %.3f, thresh_list[np.argmax(acc_list) = %.2f' %
          (np.max(acc_list), thresh_list[np.argmax(acc_list)]))
plt.show()


画像レベルので異常検知精度が60.8%となりました。
低い精度です。
ここで、コアセットのサンプリング割合を10%に変更してみますと、精度は以下のようになりました。


次に、コアセットサンプリングを行わずに、全件のメモリバンクにて異常検知を実施してみますと、精度は以下のようになりました。


何れのケースも精度が低いです。
一方、概ね同等の精度が出ており、このことからコアセットサンプリングの有効性は伺えそうです。

ここで、画像上の中心から離れた外側の部分には、猫以外の背景等が写り込んでいる可能性が高いかと思われる為、そこを削除して精度測定をしてみます。
PaDiMの時と同様、先程計算した距離値から、以下コードの要領で外側を除去します。

D_val = D_val.reshape(len(y_val), 28, 28)
D_val = D_val[:, 6:-6, 6:-6]
D_val = D_val.reshape(len(y_val), -1)

print('D_val.shape =', D_val.shape)


すると、精度は以下のようになりました。

コアセットサンプリング割合1%:

コアセットサンプリング割合10%:

全件のメモリバンク:


PaDiMの時はほとんど効果がありませんでしたが、PatchCoreの場合は精度改善が見られます。
このことから、クエリ画像からノイズを除けば、精度改善が実現できそうなことが伺えます。
PaDiMと違って、pixel依存が無いことが、この結果に繋がっているものかと思われます。


次に、PaDiMでも行ったように、yolov5にて、 dog cat を検出し、その検出枠でcropした画像でも同様に試してみます。
すると、以下のような精度改善が見られました。

crop画像、かつ、コアセットサンプリング割合1%:

crop画像、かつ、コアセットサンプリング割合1%、かつ、画像の外側除去:


cropによるアライメント改善は、PaDiMと同様に精度が向上しませんでした。
メモリバンクに、背景を除いた猫のパッチ画像が増えていそうなので、精度が改善されると思ったのですが…。
メモリバンクの特徴と対応付けられる画像パッチを確認してみます。


yolov5によってcropする前と後で、幾らか猫のパッチ画像が増えていますが、まだまだ背景のパッチ画像が多いことが確認できます。
なるほど、確かにこれでは精度が上がらなそうです。

ここで、正答している対象と、誤答している対象を確認してみましょう。
また、異常位置のセグメンテーションについては、実は先程算出した score_map によって実施ができていますので、それも合わせて可視化します。

flg_dog = 0
flg_asc = True

score = np.max(D_val, axis=-1)[y_val == flg_dog]
score_max = np.max(D_val)
score_min = np.min(D_val)
if flg_asc:
    idx_sort = np.argsort(score)
else:
    idx_sort = np.argsort(-score)

for i in idx_sort[:10]:
    img = img_prep_val[y_val == flg_dog][i]

    score_map = D_val[y_val == flg_dog][i].reshape(16, 16)
    score_map = (score_map - score_min) / (score_max - score_min)
    score_map = np.pad(score_map, [(6, 6), (6, 6)])
    score_map = cv2.resize(score_map, (224, 224))

    plt.figure(figsize=(11, 2.5), dpi=100, facecolor='white')
    plt.subplot(1, 3, 1)
    plt.imshow(img)
    plt.subplot(1, 3, 2)
    plt.imshow(score_map)
    plt.colorbar()
    plt.title('resularized anomaly score (minimum distance with %d%% coreset of memory bank) = %.3f\n' % (percentage*100, np.max(score_map)))
    plt.subplot(1, 3, 3)
    plt.imshow(overlay_heatmap_on_image(img=img, heatmap=score_map))
    plt.show()

猫画像にて、大きく正答している対象TOP1

猫画像にて、大きく正答している対象TOP2

猫画像にて、大きく正答している対象TOP3

猫画像にて、大きく正答している対象TOP4

猫画像にて、大きく正答している対象TOP5

猫画像にて、大きく正答している対象TOP6

猫画像にて、大きく正答している対象TOP7

猫画像にて、大きく正答している対象TOP8

猫画像にて、大きく正答している対象TOP9

猫画像にて、大きく正答している対象TOP10


猫画像にて、大きく誤答している対象TOP1

猫画像にて、大きく誤答している対象TOP2

猫画像にて、大きく誤答している対象TOP3

猫画像にて、大きく誤答している対象TOP4

猫画像にて、大きく誤答している対象TOP5

猫画像にて、大きく誤答している対象TOP6

猫画像にて、大きく誤答している対象TOP7

猫画像にて、大きく誤答している対象TOP8

猫画像にて、大きく誤答している対象TOP9

猫画像にて、大きく誤答している対象TOP10


犬画像にて、大きく誤答している対象TOP1

犬画像にて、大きく誤答している対象TOP2

犬画像にて、大きく誤答している対象TOP3

犬画像にて、大きく誤答している対象TOP4

犬画像にて、大きく誤答している対象TOP5

犬画像にて、大きく誤答している対象TOP6

犬画像にて、大きく誤答している対象TOP7

犬画像にて、大きく誤答している対象TOP8

犬画像にて、大きく誤答している対象TOP9

犬画像にて、大きく誤答している対象TOP10


犬画像にて、大きく正答している対象TOP1

犬画像にて、大きく正答している対象TOP2

犬画像にて、大きく正答している対象TOP3

犬画像にて、大きく正答している対象TOP4

犬画像にて、大きく正答している対象TOP5

犬画像にて、大きく正答している対象TOP6

犬画像にて、大きく正答している対象TOP7

犬画像にて、大きく正答している対象TOP8

犬画像にて、大きく正答している対象TOP9

犬画像にて、大きく正答している対象TOP10


示唆深い結果となりました。

総じて、Dogs vs Catsのデータセットでは、猫らしさや犬らしさよりも、背景のテクスチャや、首輪などのアクセサリーに対して、強く反応してしまうことが伺えます。
猫らしさや犬らしさよりも、それら異常に対する距離が目立ちすぎてしまっています。
猫らしさや犬らしさに反応してもらうためには、アライメントの調整や、背景等のノイズ除去を行う必要がありそうです。


ここで、PaDiMでの実験と同様に、PatchCoreにおいても、オックスフォード大学が発行して下さっているデータセットを用いてみます。
Visual Geometry Group - University of Oxford


コアセットサンプリング割合1%:

コアセットサンプリング割合1%、かつ、画像の外側除去:

コアセットサンプリング割合10%:

コアセットサンプリング割合10%、かつ、画像の外側除去:

全件のメモリバンク:

全件のメモリバンク、かつ、画像の外側除去:


PaDiMよりも高い精度が出ました。
また、オックスフォードのデータセットは、比較的件数が少ないものとなるのですが、その理由からか、コアセットのサンプリング割合が高い方が精度が高くなる傾向が見て取れます。
このことから、データが少ない場合は、コアセットサンプリングを実施しない方が良いことが、推察できます。
これは、直感的にも納得がいく傾向かと思われます。

尚、コアセットサンプリング割合が0.01、即ち、1%の際のメモリバンク特徴に対応するパッチ画像を確認してみると、以下のようになっていました。


背景等のパッチ画像は少なく、猫の部位を切り出したパッチ画像が多いかと思います。
これを見ると、Dogs vs Catsデータセットよりも精度が出ることが腑に落ちます。


その後、学習データに対して、反転と回転(-10度と+10度)のaugmentationを行い、データを増幅して、精度測定を行ってみたところ、精度は以下のようになりました。
尚、コアセットサンプリングは行わず、全件のメモリバンクを使用しました。

全件のメモリバンク:

全件のメモリバンク、かつ、画像の外側除去(中間層特徴テンソルの上下左右各外側の3pixelずつを除去):

全件のメモリバンク、かつ、画像の外側除去(中間層特徴テンソルの上下左右各外側の6pixelずつを除去):


中間層特徴テンソルの除去についても、これまで上下左右各6pixelずつ実施してきたのに加えて、3pixelずつの除去についても評価してみました。
アライメントによるメモリバンクの質向上と、augmentationによるメモリバンク量の拡充、そして、そのメモリバンクを全件用いることで、徐々に精度が上がってきた次第です。
また、ノイズの写り込みが多いであろう外側のpixelは、やはり除去すると精度が上がることが確認できます。

一方で、メモリバンクが大きくなることで、推論の速度は下がってしまいます。
ここに、精度と速度のトレードオフがあります。
処理速度を上げたい場合には、精度を諦めるか、撮像のアライメントをより厳しく統制してメモリバンクを軽くする必要などがあるかと思います。


仕上げに、猫らしさや犬らしさを捉えられているかどうかを、可視化を行って確認してみましょう。
augmentationを行った上で、中間層特徴テンソル 縦28pixel x 横28pixel から、外側3pixelずつと除いて縦25pixel x 横25pixel としたケースにて、異常位置のセグメンテーションの可視化を実施します。
精度として、最大の正解率で80.2%を発揮していたケースです。


猫画像にて、大きく正答している対象TOP1

猫画像にて、大きく正答している対象TOP2

猫画像にて、大きく正答している対象TOP3

猫画像にて、大きく正答している対象TOP4

猫画像にて、大きく正答している対象TOP5

猫画像にて、大きく正答している対象TOP6

猫画像にて、大きく正答している対象TOP7

猫画像にて、大きく正答している対象TOP8

猫画像にて、大きく正答している対象TOP9

猫画像にて、大きく正答している対象TOP10


猫画像にて、大きく誤答している対象TOP1

猫画像にて、大きく誤答している対象TOP2

猫画像にて、大きく誤答している対象TOP3

猫画像にて、大きく誤答している対象TOP4

猫画像にて、大きく誤答している対象TOP5

猫画像にて、大きく誤答している対象TOP6

猫画像にて、大きく誤答している対象TOP7

猫画像にて、大きく誤答している対象TOP8

猫画像にて、大きく誤答している対象TOP9

猫画像にて、大きく誤答している対象TOP10


犬画像にて、大きく誤答している対象TOP1

犬画像にて、大きく誤答している対象TOP2

犬画像にて、大きく誤答している対象TOP3

犬画像にて、大きく誤答している対象TOP4

犬画像にて、大きく誤答している対象TOP5

犬画像にて、大きく誤答している対象TOP6

犬画像にて、大きく誤答している対象TOP7

犬画像にて、大きく誤答している対象TOP8

犬画像にて、大きく誤答している対象TOP9

犬画像にて、大きく誤答している対象TOP10


犬画像にて、大きく正答している対象TOP1

犬画像にて、大きく正答している対象TOP2

犬画像にて、大きく正答している対象TOP3

犬画像にて、大きく正答している対象TOP4

犬画像にて、大きく正答している対象TOP5

犬画像にて、大きく正答している対象TOP6

犬画像にて、大きく正答している対象TOP7

犬画像にて、大きく正答している対象TOP8

犬画像にて、大きく正答している対象TOP9

犬画像にて、大きく正答している対象TOP10


PaDiMと違って、例えば、縦や横にズレていたり、回転していたり等、アライメントが多少整っていなくても、正答の傾向が強いことが確認できるかと思います。
また、画像の外側除去によって精度が上がるだけあり、猫画像における誤答の原因が外側に多いことも確認できます。
また、上目遣いの目や、舌、大きく開けた口等が、猫における誤答の原因になっていることから、そういった画像を補填すれば、精度向上が図れるであろうことが推察できます。
その意味では、精度が低い要因、即ち、精度を上げるためのトライアル方針が立てやすいことは、PatchCoreの強みと言えます。
PaDiMの場合は、そういう対象があったとしても、統計的にその誤答要因を平均に近付けなければならない為、簡単には精度向上が図れない次第です。

犬にて誤答してしまっている対象、即ち、正常としてしまっている対象については、毛がフサフサしていたり、可愛い見た目の子犬だったり、耳が立っていたりと、猫と近い特徴を持っている場合と言えそうです。
また、犬にて強く異常と判定している対象については、垂れた耳や、大きな鼻、大きく開けた口、舌、首輪、タレ目等々が確認できます。
横を向いている犬の画像も、大きく異常と捉えられていますが、これは目から鼻までがスッと伸びていることや、併せて口が大きいことが反応の理由ではないかと思われます。


以上、犬と猫でのPatchCoreの説明、及び、実験でした。
最後の最後に、参考までのテクニック紹介をさせて頂きたいのですが、この手の課題感の場合は、部品のキズ等のケースと違い、全体感を見て、正常か異常かを判定するケースとなりますので、実は異常スコアマップの最大値ではなく、平均を画像レベルの異常スコアとした方が、精度が上がったりします。
以下が、その方針にて精度測定をした結果となり、実際に上手くスコアが上がりました。
課題感に合わせて、調整できるポイントかと思います。

全件のメモリバンク(平均Ver):

全件のメモリバンク、かつ、画像の外側除去(中間層特徴テンソルの上下左右各外側の3pixelずつを除去)(平均Ver):

全件のメモリバンク、かつ、画像の外側除去(中間層特徴テンソルの上下左右各外側の6pixelずつを除去)(平均Ver):


さて、PatchCoreのイメージが大分掴めたかと思います。
ここで、MVTecに対して適用をしてみます。
先程、Dogs vs Catsに適用したコードを、MVTecに置き換える形で実装をしていきます。

特徴抽出器の準備

import torch
import torchvision.models as models
from torchinfo import summary

device = torch.device('cuda:0')  # torch.device('cpu')

model = models.wide_resnet50_2(weights=models.Wide_ResNet50_2_Weights.IMAGENET1K_V1)  # DEFAULT
model.eval()
model.to(device)

print('model =', model)

# set model's intermediate outputs
outputs = []

def hook(module, input, output):
    outputs.append(output)

model.layer2[-1].register_forward_hook(hook)
model.layer3[-1].register_forward_hook(hook)

summary(model, input_size=(1, 3, 224, 224))


seedをしっかり固定

import random
import numpy as np

def torch_fix_seed(seed=0):
    # Python random
    random.seed(seed)
    # Numpy
    np.random.seed(seed)
    # Pytorch
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.use_deterministic_algorithms = True

torch_fix_seed()


ファイル名称取得

import os
import numpy as np
import cv2
import matplotlib.pyplot as plt

path_parent = './mvtec_anomaly_detection/'
type_data = 'bottle'  # 'screw'

path_train = os.path.join(path_parent, type_data, 'train/good')
files_train = [os.path.join(path_train, f) for f in os.listdir(path_train)
               if (os.path.isfile(os.path.join(path_train, f)) &
                   ('.png' in f))]
files_train = np.array(sorted(files_train))

print('len(files_train) =', len(files_train))
print('files_train[:5] =\n', files_train[:5])
print()

types_test = os.listdir(os.path.join(path_parent, type_data, 'test'))
types_test = np.array(sorted(types_test))

files_test = {}

for type_test in types_test:

    path_test = os.path.join(path_parent, type_data, 'test', type_test)
    files_test[type_test] = [os.path.join(path_test, f)
                             for f in os.listdir(path_test)
                             if (os.path.isfile(os.path.join(path_test, f)) &
                                 ('.png' in f))]
    files_test[type_test] = np.array(sorted(files_test[type_test]))

    print('len(files_test[%s]) =' % type_test, len(files_test[type_test]))
    print('files_test[%s][:5] =\n' % type_test, files_test[type_test][:5])
    print()

plt.figure(figsize=(12, 32), dpi=100, facecolor='white')

plt.subplot(8, 2, 1)
plt.imshow(cv2.imread(files_train[0])[..., ::-1])
plt.title(files_train[0])

for i_type_test, type_test in enumerate(types_test):
    plt.subplot(8, 2, (i_type_test + 3))
    plt.imshow(cv2.imread(files_test[type_test][0])[..., ::-1])
    plt.title(files_test[type_test][0])

plt.show()


学習データ画像読込

import cv2
from tqdm.notebook import tqdm

img_prep_train = []

for file in tqdm(files_train):
    img = cv2.imread(file)[..., ::-1]  # BGR2RGB
    img_prep = cv2.resize(img, (256, 256), interpolation=cv2.INTER_AREA)
    img_prep = img_prep[16:(256-16), 16:(256-16)]

    img_prep_train.append(img_prep)

    if (np.random.rand() < 0.05):
        plt.figure(figsize=(10, 4), dpi=100)
        plt.subplot(1, 2, 1)
        plt.imshow(img)
        plt.title(file)
        plt.subplot(1, 2, 2)
        plt.imshow(img_prep)
        plt.show()

img_prep_train = np.stack(img_prep_train)

print('img_prep_train.shape =', img_prep_train.shape)


学習データ画像から、特徴抽出をする前段の処理

# set param
N_batch = 100
patchsize = 3
stride = 1
padding = int((patchsize - 1) / 2)

unfolder = torch.nn.Unfold(
    kernel_size=patchsize, stride=stride, padding=padding, dilation=1
)

outputs = []

with torch.no_grad():
    _ = model(torch.randn(N_batch, 3, 224, 224).to(device))

f1 = outputs[0].clone()  # (B, C, H, W)
f2 = outputs[1].clone()  # (B, C, H, W)
feat = [f1, f2]
shapes = [f1.shape, f2.shape]

patch_shapes = []
for i in range(len(feat)):
    number_of_total_patches = []
    for s in shapes[i][-2:]:
        n_patches = (s + 2 * padding - 1 * (patchsize - 1) - 1) / stride + 1
        number_of_total_patches.append(int(n_patches))
    patch_shapes.append(number_of_total_patches)
print('patch_shapes =', patch_shapes)

ref_num_patches = patch_shapes[0]
print('ref_num_patches =', ref_num_patches)


学習データ画像からの特徴抽出

import torch.nn.functional as F

pretrain_embed_dimension = 1024
target_embed_dimension = 1024

MEAN = torch.from_numpy(np.array([[[0.485, 0.456, 0.406]]]))
MEAN = MEAN.to(torch.float).to(device)
STD = torch.from_numpy(np.array([[[0.229, 0.224, 0.225]]]))
STD = STD.to(torch.float).to(device)

feat_train = []

outputs = []

for i_batch in tqdm(range(0, len(img_prep_train), N_batch)):

    img_batch = img_prep_train[i_batch:(i_batch + N_batch)]
    x = torch.from_numpy(img_batch).to(torch.float).to(device)
    x = x / 255
    x = x - MEAN
    x = x / STD
    x = x.permute(0, 3, 1, 2)

    with torch.no_grad():
        _ = model(x)

    f1 = outputs[0].clone()  # (B, C, H, W)
    f2 = outputs[1].clone()  # (B, C, H, W)
    feat = [f1, f2]
    shapes = [f1.shape, f2.shape]

    outputs = []

    # patchify
    for i in range(len(feat)):
        # (B, C, H, W) -> (B, C, H, W, PH, PW)
        with torch.no_grad():
            feat[i] = unfolder(feat[i])
        # (B, C, H, W, PH, PW) -> (B, C, PH, PW, HW)
        feat[i] = feat[i].reshape(*shapes[i][:2],
                                          patchsize, patchsize, -1)
        # (B, C, PH, PW, HW) -> (B, HW, C, PW, HW)
        feat[i] = feat[i].permute(0, 4, 1, 2, 3)

    for i in range(1, len(feat)):
        _feat = feat[i]
        patch_dims = patch_shapes[i]
        # (B, HW, C, PW, HW) -> (B, H, W, C, PH, PW)
        _feat = _feat.reshape(_feat.shape[0], patch_dims[0],
                                      patch_dims[1], *_feat.shape[2:])
        # (B, H, W, C, PH, PW) -> (B, C, PH, PW, H, W)
        _feat = _feat.permute(0, -3, -2, -1, 1, 2)
        perm_base_shape = _feat.shape
        # (B, C, PH, PW, H, W) -> (BCPHPW, H, W)
        _feat = _feat.reshape(-1, *_feat.shape[-2:])
        # (BCPHPW, H, W) -> (BCPHPW, H_max, W_max)
        _feat = F.interpolate(_feat.unsqueeze(1),
                                  size=(ref_num_patches[0], ref_num_patches[1]),
                                  mode="bilinear", align_corners=False)
        _feat = _feat.squeeze(1)
        # (BCPHPW, H_max, W_max) -> (B, C, PH, PW, H_max, W_max)
        _feat = _feat.reshape(*perm_base_shape[:-2], 
                                      ref_num_patches[0], ref_num_patches[1])
        # (B, C, PH, PW, H_max, W_max) -> (B, H_max, W_max, C, PH, PW)
        _feat = _feat.permute(0, -2, -1, 1, 2, 3)
        # (B, H_max, W_max, C, PH, PW) -> (B, H_maxW_max, C, PH, PW)
        _feat = _feat.reshape(len(_feat), -1, *_feat.shape[-3:])
        feat[i] = _feat

    # (B, H, W, C, PH, PW) -> (BHW, C, PH, PW)
    feat = [x.reshape(-1, *x.shape[-3:]) for x in feat]

    for i in range(len(feat)):
        _feat = feat[i]
        # (BHW, C, PH, PW) -> (BHW, 1, CPHPW)
        _feat = _feat.reshape(len(_feat), 1, -1)
        # (BHW, 1, CPHPW) -> (BHW, D_p)
        _feat = F.adaptive_avg_pool1d(_feat, 
                                          pretrain_embed_dimension).squeeze(1)
        feat[i] = _feat

    # (BHW, D_p) -> (BHW, D_p*2)
    feat = torch.stack(feat, dim=1)
    """Returns reshaped and average pooled feat."""
    # batchsize x number_of_layers x input_dim -> batchsize x target_dim
    # (BHW, D_p*2) -> (BHW, D_t)
    feat = feat.reshape(len(feat), 1, -1)
    feat = F.adaptive_avg_pool1d(feat, target_embed_dimension)
    feat = feat.reshape(len(feat), -1)

    feat_train.append(feat.cpu())

feat_train = torch.vstack(feat_train)

print('feat_train.shape =', feat_train.shape)


コアセットサンプリング

percentage = 0.1
dimension_to_project_features_to = 128
number_of_starting_points = 10

torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0)
mapper = torch.nn.Linear(feat_train.shape[1], dimension_to_project_features_to,
                         bias=False).to(device)

print('mapper =', mapper)

feat_train = feat_train.to(device)

with torch.no_grad():
    feat_train_proj = mapper(feat_train)

print('feat_train.shape =', feat_train.shape)
print('feat_train_proj.shape =', feat_train_proj.shape)

number_of_starting_points = np.clip(number_of_starting_points,
                                    None, len(feat_train_proj))

print('number_of_starting_points =', number_of_starting_points)

np.random.seed(0)
start_points = np.random.choice(len(feat_train_proj), number_of_starting_points, 
                                replace=False).tolist()

print('len(start_points) =', len(start_points))
print('start_points =', start_points)

matrix_a = feat_train_proj
matrix_b = feat_train_proj[start_points]

print('matrix_a.shape =', matrix_a.shape)
print('matrix_b.shape =', matrix_b.shape)
print()

print('matrix_a.unsqueeze(1).shape =', matrix_a.unsqueeze(1).shape)
print('matrix_a.unsqueeze(2).shape =', matrix_a.unsqueeze(2).shape)
print('matrix_b.unsqueeze(1).shape =', matrix_b.unsqueeze(1).shape)
print('matrix_b.unsqueeze(2).shape =', matrix_b.unsqueeze(2).shape)
print()

"""Computes batchwise Euclidean distances using PyTorch."""
a_times_a = matrix_a.unsqueeze(1).bmm(matrix_a.unsqueeze(2)).reshape(-1, 1)
b_times_b = matrix_b.unsqueeze(1).bmm(matrix_b.unsqueeze(2)).reshape(1, -1)
a_times_b = matrix_a.mm(matrix_b.T)

print('a_times_a.shape =', a_times_a.shape)
print('b_times_b.shape =', b_times_b.shape)
print('a_times_b.shape =', a_times_b.shape)
print()

approximate_distance_matrix = (-2 * a_times_b + a_times_a + b_times_b).clamp(0, None)  # .sqrt()

print('approximate_distance_matrix.shape =', approximate_distance_matrix.shape)

approximate_coreset_anchor_distances = torch.mean(approximate_distance_matrix,
                                                  axis=-1, keepdims=True)

print('approximate_coreset_anchor_distances.shape =', approximate_coreset_anchor_distances.shape)

coreset_indices = []
num_coreset_samples = int(len(feat_train_proj) * percentage)

with torch.no_grad():
    for _ in tqdm(range(num_coreset_samples), desc="Subsampling..."):
        select_idx = torch.argmax(approximate_coreset_anchor_distances).item()
        coreset_indices.append(select_idx)

        matrix_a = feat_train_proj
        matrix_b = feat_train_proj[[select_idx]]
        """Computes batchwise Euclidean distances using PyTorch."""
        a_times_a = matrix_a.unsqueeze(1).bmm(matrix_a.unsqueeze(2)).reshape(-1, 1)
        b_times_b = matrix_b.unsqueeze(1).bmm(matrix_b.unsqueeze(2)).reshape(1, -1)
        a_times_b = matrix_a.mm(matrix_b.T)
        coreset_select_distance = (-2 * a_times_b + a_times_a + b_times_b).clamp(0, None)  # .sqrt()

        approximate_coreset_anchor_distances = torch.cat(
            [approximate_coreset_anchor_distances, coreset_select_distance],
            dim=-1,
        )
        approximate_coreset_anchor_distances = torch.min(
            approximate_coreset_anchor_distances, dim=1
        ).values.reshape(-1, 1)

coreset_indices = np.array(coreset_indices)

print('len(coreset_indices) =', len(coreset_indices))
print('coreset_indices[:10] =', coreset_indices[:10])

feat_train_coreset = feat_train[coreset_indices]
feat_train_coreset = feat_train_coreset.cpu().numpy()

del feat_train
torch.cuda.empty_cache()

print('feat_train_coreset.shape =', feat_train_coreset.shape)


KNNインデックスの作成

import faiss

search_index = faiss.GpuIndexFlatL2(faiss.StandardGpuResources(),
                                    feat_train_coreset.shape[1],
                                    faiss.GpuIndexFlatConfig())
search_index.add(feat_train_coreset)


テストデータ画像の読込

img_prep_test = {}
gt_test = {}

for type_test in types_test:

    img_prep_test[type_test] = []
    gt_test[type_test] = []

    for file in tqdm(files_test[type_test]):
        img = cv2.imread(file)[..., ::-1]  # BGR2RGB
        img_prep = cv2.resize(img, (256, 256), interpolation=cv2.INTER_AREA)
        img_prep = img_prep[16:(256-16), 16:(256-16)]
        img_prep_test[type_test].append(img_prep)

        if (type_test == 'good'):
            gt = np.zeros_like(img_prep[..., 0], dtype=np.uint8)
        else:
            file_gt = file.replace('/test/', '/ground_truth/')
            file_gt = file_gt.replace('.png', '_mask.png')
            gt = cv2.imread(file_gt, cv2.IMREAD_GRAYSCALE)
            gt = cv2.resize(gt, (256, 256), interpolation=cv2.INTER_NEAREST)
            gt = gt[16:(256-16), 16:(256-16)]
            gt = (gt / np.max(gt)).astype(np.uint8)
        gt_test[type_test].append(gt)

    img_prep_test[type_test] = np.stack(img_prep_test[type_test])
    gt_test[type_test] = np.stack(gt_test[type_test])

    print('img_prep_test[%s].shape =' % type_test, img_prep_test[type_test].shape)
    print('gt_test[%s].shape =' % type_test, gt_test[type_test].shape)


テストデータ画像からの特徴抽出

# set param
N_batch = 25

feat_test = {}

for type_test in types_test:

    outputs = []
    feat_test[type_test] = []

    for i_batch in tqdm(range(0, len(img_prep_test[type_test]), N_batch)):

        img_batch = img_prep_test[type_test][i_batch:(i_batch + N_batch)]
        x = torch.from_numpy(img_batch).to(torch.float).to(device)
        x = x / 255
        x = x - MEAN
        x = x / STD
        x = x.permute(0, 3, 1, 2)

        with torch.no_grad():
            _ = model(x)

        f1 = outputs[0].clone()  # (B, C, H, W)
        f2 = outputs[1].clone()  # (B, C, H, W)
        feat = [f1, f2]
        shapes = [f1.shape, f2.shape]

        outputs = []

        # patchify
        for i in range(len(feat)):
            # (B, C, H, W) -> (B, C, H, W, PH, PW)
            with torch.no_grad():
                feat[i] = unfolder(feat[i])
            # (B, C, H, W, PH, PW) -> (B, C, PH, PW, HW)
            feat[i] = feat[i].reshape(*shapes[i][:2], patchsize, patchsize, -1)
            # (B, C, PH, PW, HW) -> (B, HW, C, PW, HW)
            feat[i] = feat[i].permute(0, 4, 1, 2, 3)

        for i in range(1, len(feat)):
            _feat = feat[i]
            patch_dims = patch_shapes[i]
            # (B, HW, C, PW, HW) -> (B, H, W, C, PH, PW)
            _feat = _feat.reshape(_feat.shape[0], patch_dims[0],
                                          patch_dims[1], *_feat.shape[2:])
            # (B, H, W, C, PH, PW) -> (B, C, PH, PW, H, W)
            _feat = _feat.permute(0, -3, -2, -1, 1, 2)
            perm_base_shape = _feat.shape
            # (B, C, PH, PW, H, W) -> (BCPHPW, H, W)
            _feat = _feat.reshape(-1, *_feat.shape[-2:])
            # (BCPHPW, H, W) -> (BCPHPW, H_max, W_max)
            _feat = F.interpolate(_feat.unsqueeze(1),
                                      size=(ref_num_patches[0], ref_num_patches[1]),
                                      mode="bilinear", align_corners=False)
            _feat = _feat.squeeze(1)
            # (BCPHPW, H_max, W_max) -> (B, C, PH, PW, H_max, W_max)
            _feat = _feat.reshape(*perm_base_shape[:-2], 
                                          ref_num_patches[0], ref_num_patches[1])
            # (B, C, PH, PW, H_max, W_max) -> (B, H_max, W_max, C, PH, PW)
            _feat = _feat.permute(0, -2, -1, 1, 2, 3)
            # (B, H_max, W_max, C, PH, PW) -> (B, H_maxW_max, C, PH, PW)
            _feat = _feat.reshape(len(_feat), -1, *_feat.shape[-3:])
            feat[i] = _feat

        # (B, H, W, C, PH, PW) -> (BHW, C, PH, PW)
        feat = [x.reshape(-1, *x.shape[-3:]) for x in feat]

        for i in range(len(feat)):
            _feat = feat[i]
            # (BHW, C, PH, PW) -> (BHW, 1, CPHPW)
            _feat = _feat.reshape(len(_feat), 1, -1)
            # (BHW, 1, CPHPW) -> (BHW, D_p)
            _feat = F.adaptive_avg_pool1d(_feat, 
                                          pretrain_embed_dimension).squeeze(1)
            feat[i] = _feat

        # (BHW, D_p) -> (BHW, D_p*2)
        feat = torch.stack(feat, dim=1)
        """Returns reshaped and average pooled feat."""
        # batchsize x number_of_layers x input_dim -> batchsize x target_dim
        # (BHW, D_p*2) -> (BHW, D_t)
        feat = feat.reshape(len(feat), 1, -1)
        feat = F.adaptive_avg_pool1d(feat, target_embed_dimension)
        feat = feat.reshape(len(feat), -1)

        feat_test[type_test].append(feat.cpu())

    feat_test[type_test] = torch.vstack(feat_test[type_test]).numpy()

    print('feat_test.shape[%s] =' % type_test, feat_test[type_test].shape)


テスト画像特徴と、コアセットサンプルとのKNN実施をして、異常スコアマップ取得

k = 1

score_test = {}

for type_test in types_test:

    score_test[type_test], _ = search_index.search(feat_test[type_test], k)
    score_test[type_test] = score_test[type_test].reshape(-1, 28, 28)

    print('score_test[%s].shape =' % type_test, score_test[type_test].shape)
    print('np.mean(score_test[%s]) =' % type_test, np.mean(score_test[type_test]))
    print('np.mean(np.abs(score_test[%s])) =' % type_test, np.mean(np.abs(score_test[type_test])))
    print('np.std(score_test[%s]) =' % type_test, np.std(score_test[type_test]))
    print('np.max(score_test[%s]) =' % type_test, np.max(score_test[type_test]))
    print('np.min(score_test[%s]) =' % type_test, np.min(score_test[type_test]))


画像レベルの予測分布確認

y_hat_list = []
y_list = []
N_test = 0

type_test = 'good'

plt.figure(figsize=(10, 8), dpi=100, facecolor='white')

plt.subplot(2, 1, 1)
plt.scatter((np.arange(len(score_test[type_test])) + N_test),
            np.max(np.max(score_test[type_test], axis=-1), axis=-1),
            alpha=0.5, label=type_test)

plt.subplot(2, 1, 2)
plt.hist(np.max(np.max(score_test[type_test], axis=-1), axis=-1),
         alpha=0.5, bins=10, label=type_test)

y_hat_list.append(np.max(np.max(score_test[type_test], axis=-1), axis=-1))
y_list.append(np.zeros([len(score_test[type_test])], dtype=np.int16))
N_test += len(score_test[type_test])

for type_test in types_test[types_test != 'good']:

    plt.subplot(2, 1, 1)
    plt.scatter((np.arange(len(score_test[type_test])) + N_test),
                np.max(np.max(score_test[type_test], axis=-1), axis=-1),
                alpha=0.5, label=type_test)

    plt.subplot(2, 1, 2)
    plt.hist(np.max(np.max(score_test[type_test], axis=-1), axis=-1),
             alpha=0.5, bins=10, label=type_test)

    y_hat_list.append(np.max(np.max(score_test[type_test], axis=-1), axis=-1))
    y_list.append(np.ones([len(score_test[type_test])], dtype=np.int16))
    N_test += len(score_test[type_test])

y_list = np.hstack(y_list)
y_hat_list = np.hstack(y_hat_list)

plt.subplot(2, 1, 1)
plt.grid()
plt.legend()

plt.subplot(2, 1, 2)
plt.grid()
plt.legend()

plt.show()


画像レベルでのAUC精度評価

from sklearn.metrics import roc_curve
from sklearn.metrics import roc_auc_score

# calculate per-pixel level ROCAUC
fpr, tpr, _ = roc_curve(y_list, y_hat_list)
per_pixel_rocauc = roc_auc_score(y_list, y_hat_list)

plt.figure(figsize=(10, 6), dpi=100)
plt.plot(fpr, tpr, label='%s ROCAUC: %.3f' % (type_data, per_pixel_rocauc))
plt.grid()
plt.legend()
plt.show()


pixelレベルでの異常スコアマップ可視化と精度評価

# https://github.com/gsurma/cnn_explainer/blob/main/utils.py
def overlay_heatmap_on_image(img, heatmap, ratio_img=0.5):
    img = img.astype(np.float32)

    heatmap = 1 - np.clip(heatmap, 0, 1)
    heatmap = cv2.applyColorMap((heatmap * 255).astype(np.uint8), cv2.COLORMAP_JET)
    heatmap = heatmap.astype(np.float32)

    overlay = (img * ratio_img) + (heatmap * (1 - ratio_img))
    overlay = np.clip(overlay, 0, 255)
    overlay = overlay.astype(np.uint8)
    return overlay


gt_flat_list = []
score_flat_list = []

score_max = max([np.max(score_test[type_test]) for type_test in types_test])

for type_test in types_test:
    for i, gt in tqdm(enumerate(gt_test[type_test]),
                      desc=('[verbose mode] visualize localization (case:%s)' %
                            type_test)):
        file = files_test[type_test][i]
        img = cv2.imread(file)[..., ::-1]  # BGR2RGB
        img = cv2.resize(img, (256, 256), interpolation=cv2.INTER_AREA)
        img = img[16:(256-16), 16:(256-16)]
        score = score_test[type_test][i]
        score = cv2.resize(score, (224, 224))

        # calculate per-pixel level ROCAUC
        gt_flat = gt.reshape(-1)
        score_flat = score.reshape(-1)
        if (np.sum(gt_flat > 0) == 0):
            rocauc = np.nan
        else:
            rocauc = roc_auc_score(gt_flat, score_flat)

        # stock for output result
        gt_flat_list.append(gt_flat)
        score_flat_list.append(score_flat)

        plt.figure(figsize=(9, 7.3), dpi=100, facecolor='white')
        plt.rcParams['font.size'] = 8
        plt.subplot2grid((3, 3), (0, 0), rowspan=1, colspan=1)
        plt.imshow(img)
        plt.title('%s : %s' % (file.split('/')[-2], file.split('/')[-1]))
        plt.subplot2grid((3, 3), (0, 1), rowspan=1, colspan=1)
        plt.imshow(gt)
        plt.subplot2grid((3, 3), (0, 2), rowspan=1, colspan=1)
        plt.imshow(score)
        plt.colorbar()
        plt.title('global max score : %.2f' % score_max)
        plt.subplot2grid((3, 4), (1, 0), rowspan=2, colspan=2)
        plt.imshow(overlay_heatmap_on_image(img, (score / score_max)))
        plt.title('per-pixel level ROCAUC: %.3f' % rocauc)
        plt.subplot2grid((3, 4), (1, 2), rowspan=2, colspan=2)
        plt.imshow((img.astype(np.float32) *
                    (score / score_max)[..., None]).astype(np.uint8))
        plt.show()

gt_flat_list = np.array(gt_flat_list).reshape(-1)
score_flat_list = np.array(score_flat_list).reshape(-1)

# calculate per-pixel level ROCAUC
fpr, tpr, _ = roc_curve(gt_flat_list, score_flat_list)
rocauc = roc_auc_score(gt_flat_list, score_flat_list)
print('%s per-pixel level ROCAUC: %.3f' % (type_data, rocauc))

plt.figure(figsize=(10, 6), dpi=100)
plt.plot(fpr, tpr, label='%s ROCAUC: %.3f' % (type_data, rocauc))
plt.grid()
plt.legend()
plt.show()


以上です。
尚、bottle以外のデータ種別についても試してみましたが、以下ページに記載をされている精度と非常に近い精度が、全てのデータ種別について算出されました。
GitHub - hcw-00/PatchCore_anomaly_detection: Unofficial implementation of PatchCore anomaly detection

その内、PaDiMの際に課題感として挙げた、screwについては、ここに結果を記載します。
比較的上手く検知できているかと思います。

画像レベルの異常検知・予測分布


画像レベルの異常検知・精度


pixelレベルの異常位置セグメンテーション・予測の様子


pixelレベルの異常位置セグメンテーション・精度


尚、実装の通りですが、MVTecでのコアセットのサンプリング割合は、10%としています。


という訳で、これにて、PatchCoreの解説を終えます。
非常に優秀なアルゴリズムであることが分かって頂いたかと思います。
また、DN2から始まった流れにて、現状の終着点がPatchCoreという形になっていますが、一方でまだまだ課題感が多いことも、Dogs vs Catsデータセットの例などを通して、感じて頂けたかとも思います。
前処理やハード側設定によって、発揮される精度は大きく変わってくるかと思います。

尚、MVTecにおいては、PatchCoreがSOTAではあります。
しかし、その他のデータによっては、DN2の方が精度が高くなったり、SPADEやPaDiMの方がリーズナブルであったりということはあるかもしれません。
ここで説明させて頂いたアルゴリズムの特性を背景に考察をして頂き、実験をして頂き、最適なアルゴリズムを見つけてもらえたらと思います。



おわりに

以上、今回の記事にて、異常検知4手法のラストであるPatchCoreについて、詳細に踏み込んだ解説をさせて頂きました。
ここまで読んでくださった方には、またまたまたまた感謝です。🙇
本当にありがとうございました。 という訳で、これにて、ImageNetモデルを用いた異常検知手法の紹介、及び、解説を終えさせて頂きます。🙇

AnyTechでは、流体のためのAI「DeepLiquid」の研究/開発/事業に携わってくれる方を募集しております。
ご興味のある方はこちらから、是非お問い合わせください。