AnyTech Engineer Blog

AnyTech Engineer Blogは、AnyTechのエンジニアたちによる調査や成果、Tipsなどを公開するブログです。

ImageNetモデルを用いた異常検知手法の解説【第2回:DN2(Deep Nearest Neighbor Anomaly Detection)】

ImageNetモデルを用いた異常検知手法の解説【第2回:DN2(Deep Nearest Neighbor Anomaly Detection)】

こんにちは、AnyTechの木村と申します。
AnyTechにて、機械学習エンジニアやAIエンジニアといった役割にて、R&Dに日々従事しております。


この記事は、近年流行しているImageNetモデルを応用した異常検知手法について、リサーチをする機会がありましたので、小職の備忘も兼ねまして、その解説をさせて頂くものです。



目次



シリーズ



はじめに

第1回に引き続き、近年流行しているImageNetモデルを応用した異常検知手法について、解説をさせて頂きます。
今回は、DN2(Deep Nearest Neighbor Anomaly Detection)についての解説となります。


私事になりますが、DN2は凄く好きなアルゴリズムです。
先進性を含みながらのシンプルな手法でありながら、CNNの原点や旧来のフィルタリングといった考え方に基づく深みがあります。
思わず、ハッとさせられるような魅力があります。
今回の記事で、アルゴリズムの詳細は勿論、DN2の魅力までを上手く伝えられたらと思います。



先ずはザックリ、アルゴリズム概要

先ずは、ザックリとDN2アルゴリズムについて説明をさせて頂こうと思います。


DN2は、CNNの最終層の特徴ベクトルによるKNNを実施し、それによって得られた距離値を異常スコアとして、その大小より異常/正常の判断を行うアルゴリズムとなります。

比較対象は、任意のクエリ画像と、大量の正常画像です。
つまり、各正常画像から抽出した最終層の特徴ベクトル達の中に、任意のクエリ画像から抽出したそれと近いのものがあれば正常と判断され、近いものがなければ異常と判断されます。
尚、距離値は課題感によって揺れる、非常に相対的なものとなりますので、評価データにて最適な閾値を探索し、それより大きいか小さいかで、異常か正常かを判断する形となります。



次にじっくり、アルゴリズム詳細

さて、ここからはDN2について、じっくり詳細を解説させて頂きます。


先に、この論文が発表された時期について触れさせて欲しいのですが、なんと2020年となります。
これは、個人的にとても驚いたことです。

DN2の概要をお伝えすると、それは先程説明させて頂いたような、K-NearestNeighborによる転移学習に少しアレンジを加えて、異常検知に応用したアルゴリズムです。
また、K-NearestNeighborや転移学習という手法は古くから知られているものです。
それらの手法の組み合わせが、つい最近まで実現されていなかったことが驚きでした。

具体的には、K-NearestNeighborは、wikiによれば、1991年の論文がルーツであるようです。
転移学習は、wikiによれば、1976年の論文がルーツであるようです。
また、転移学習については、2016年のNIPSにて、かの有名なアンドリューさんから、その重要性について改めて強調されていたとのことです。

仮に、2016年から数えたとしても、DN2の論文が発表されるまでには、4年が経過しています。
Deep Learning盛況の昨今において、このシンプルな発想の見過ごしは、かなりの盲点ではないかと思います。

ただし、だからこそ、上手く隙間を突いた秀逸な発明かと思います。
個人的には、こういった間隙論文が大好きです。
ちなみに、論文中にも、「ディープラーニングコミュニティから殆ど注目されてこなかった、異常検知へのKNN適応を新たに提示」との記載があります。
或いは、「全ての新しいworkを、この単純な方法と比較してみてくれ」という記載もあります。
発明者らとしても、間隙感を認識/堪能されている様子です。笑

確かに、CNNについてよくよく考えれば、無限のテンプレートマッチングの実現が難しい為に、目的に合わせて効率良くその代替実現を目指すところから始まっていると思います。
そして、現状でもやはり無限のテンプレートマッチングは難しい中で、効率よく抽象化された特徴によって、それに近しいことを行おうという発想は、正に地に足ついた分析手法であると、強くそう思います。


具体の前に、もう少しだけ、DN2のアルゴリズムの概要に触れさせて下さい。
個人的な見解としまして、この発明が特にユニークな点は、近さではなく遠さに注目をした点ではないかと思います。

転移学習とKNNとの組合せで言うと、先程のプログラム実行でも示した通りで、クエリ画像とのご近所さんをK個ピックアップした上で、そのご近所さん達をクエリ画像と同質であると捉え、性質を推し量るというものでした。
つまり、ご近所さんでない遠い対象は、性質を推し量る上で参考としていません。

一方で、DN2が対象としているタスクは異常検知であり、判定の拠り所を遠さとしています。
DN2は、ImageNetモデルの特徴を転移学習のように用いて、KNNアルゴリズムでご近所さんをK個ピックアップし、そのご近所さん全てとの距離、或いは、その内の幾つかとの距離が遠かった場合に、クエリ画像を異常であると判断するアルゴリズムとなっています。
つまり、最たるご近所さんですら遠いという場合に、クエリが異常であると判断する次第です。

尚、遠さについては、正常/異常を共に含む評価データにて、距離を集計した上で、最適な閾値を策定します。
KNNでピックアップされたクエリ画像とのK個のご近所さんとの距離の平均が、閾値を上回れば正常で、閾値を下回れば異常と判断される形になります。


さて、ここからは、具体的なアルゴリズムの解説へと進みます。
以下は、論文に記載されている説明を、概ねシンプルに和訳したものとなります。


DN2は、一連の入力画像 X_{train} = x_1, x_2, ..., x_N を取り込み、これらを全て正常画像であるとする。
そして、 X_{train} 全件から、ImageNetモデル等の特徴抽出器 F を用いて、 f_i = F(x_i) という形で特徴抽出を行う。
尚、特徴抽出器には、ResNetを使用する。

抽出した一連の特徴達は、 F_{train} = f_1, f_2, ..., f_N と展開される。
また、この一連の特徴 F_{train} を保存しておくことで、学習画像の読込みや、そこからの特徴抽出処理は、省略することが可能となる。

新しいサンプル y が異常かどうかを推測するためには、先ず、そのサンプルから f_y = F(y) という形で、特徴を抽出する。
次に、そのKNN距離を計算し、それを異常スコアとして扱う。

尚、KNN距離算出関数を d とし、距離計算方法にはユークリッド距離を採用する。
つまり、d(y) = \frac{1}{k}\sum_{f∈N_k(f_y)}||f-f_y||^2 という式となる。
尚、ユークリッド距離でなくても、DN2は実現可能となる。

上記式の N_k(f_y) は、学習データから抽出した特徴 F_{train} の内の、 f_y に最も近いK個の特徴を示す。
その為、 d(y) は、K個のご近所さんとの距離の平均となる。
その d(y) に対して閾値を適用し、上回れば異常、下回れば正常との判断をする。


以上が、論文に沿ったアルゴリズムの説明となります。
非常にシンプルなアルゴリズム説明です。
転移学習に近いという感覚も、理解頂けたかと思います。

尚、上記アルゴリズムが破綻するとしたら、距離の平均 d(y) を考えた場合には、K個の内の1つの距離が極端に大きいが為に距離平均が大きくなってしまうという事態かと思われますが、先の結果からImageNetにて学習されたモデルの優秀さが伺える点と、Kを調整すれば回避が可能な点、データの質と量で解決可能な点から、そういった事象は恐らくは発生しづらいものと思われます。


それでは、実際にプログラムでも実践してみましょう。
データは先ず、Cats vs Dogsのデータセットを用います。
猫画像5,000枚を正常画像として X_{train} に、別の猫画像1,000枚と犬画像1,000枚とを評価データにしてみます。
つまり、猫画像を正常、犬画像を異常と識別できれば、精度が高くなる形です。
ImageNetモデルは、先ずは、AlexNetでやってみます。


最初に、モデルをロードします。
そして、特徴抽出器とするための仕掛けとして、最終特徴ベクトルをピックアップするようにhookを仕掛けます。

import torch
import torchvision.models as models
from torchinfo import summary

device = torch.device('cuda:0')  # torch.device('cpu')

model = models.alexnet(weights=models.AlexNet_Weights.DEFAULT)
model.eval()
model.to(device)

print('model =', model)

# set model's intermediate outputs
outputs = []

def hook(module, input, output):
    outputs.append(output)

model.classifier[5].register_forward_hook(hook)

print(summary(model,input_size=(1, 3, 224, 224)))


次に、転移学習の際と同様に、全データのファイル名を収集します。

import os

path = './dogs-vs-cats/train/'
files = os.listdir(path)

files_cat = [os.path.join(path, f) for f in files
             if (os.path.isfile(os.path.join(path, f)) &
                 ('cat.' in f) & ('.jpg' in f))]
files_dog = [os.path.join(path, f) for f in files
             if (os.path.isfile(os.path.join(path, f)) &
                 ('dog.' in f) & ('.jpg' in f))]

files_cat = sorted(files_cat)
files_dog = sorted(files_dog)

print('len(files_cat) =', len(files_cat))
print('files_cat[:10] =\n', files_cat[:10])
print()
print('len(files_dog) =', len(files_dog))
print('files_dog[:10] =\n', files_dog[:10])


ファイル名が取得できたら、正常とする猫画像の枚数を変数 N_train_cat に、評価のための猫画像の枚数を変数 N_val_cat に、評価のための犬画像の枚数を変数 N_dog に、それぞれセットします。

N_train_cat = 5000
N_val_cat = 1000
N_dog = 1000


セットした各枚数に従って、画像の読込、及び、モデルに入力するための前処理を実施したデータを、猫画像/犬画像について収集します。

猫画像からの特徴ベクトル抽出

import cv2
import numpy as np
import random
from tqdm.notebook import tqdm

random.seed(0)

outputs = []

img_cat = []
img_prep_cat = []

files_cat = random.sample(files_cat, (N_train_cat + N_val_cat))
for file_cat in tqdm(files_cat):
    img = cv2.imread(file_cat)[..., ::-1]  # BGR2RGB
    img_prep = cv2.resize(img, (224, 224))

    img_cat.append(img)
    img_prep_cat.append(img_prep)

    x = img_prep
    x = x / 255
    x = x - np.array([[[0.485, 0.456, 0.406]]])
    x = x / np.array([[[0.229, 0.224, 0.225]]])
    x = torch.from_numpy(x.astype(np.float32)).unsqueeze(0).permute(0, 3, 1, 2)
    x = x.to(device)

    with torch.no_grad():
        _ = model(x)

x_cat = torch.vstack(outputs).reshape(len(outputs), -1).detach().cpu().numpy()

print('x_cat.shape =', x_cat.shape)


犬画像からの特徴ベクトル抽出

outputs = []

img_dog = []
img_prep_dog = []

files_dog = random.sample(files_dog, N_dog)
for file_dog in tqdm(files_dog):
    img = cv2.imread(file_dog)[..., ::-1]  # BGR2RGB
    img_prep = cv2.resize(img, (224, 224))

    img_dog.append(img)
    img_prep_dog.append(img_prep)

    x = img_prep
    x = x / 255
    x = x - np.array([[[0.485, 0.456, 0.406]]])
    x = x / np.array([[[0.229, 0.224, 0.225]]])
    x = torch.from_numpy(x.astype(np.float32)).unsqueeze(0).permute(0, 3, 1, 2)
    x = x.to(device)

    with torch.no_grad():
        _ = model(x)

x_dog = torch.vstack(outputs).reshape(len(outputs), -1).detach().cpu().numpy()

print('x_dog.shape =', x_dog.shape)


収集が終わったら、それらを学習データと評価データに分割します。 検証用に、ファイル名称も学習データと評価データに分割します。

x_train = x_cat[:N_train_cat]
x_val = np.concatenate([x_cat[N_train_cat:], x_dog], axis=0)
y_val = np.concatenate([np.zeros([N_val_cat], dtype=np.int16), 
                        np.ones([N_dog], dtype=np.int16)], axis=0)
files_train = np.array(files_cat[:N_train_cat])
files_val = np.array(files_cat[N_train_cat:] + files_dog)
img_train = img_cat[:N_train_cat]
img_val = img_cat[N_train_cat:] + img_dog
img_prep_train = np.array(img_prep_cat[:N_train_cat])
img_prep_val = np.array(img_prep_cat[N_train_cat:] + img_prep_dog)

print('x_train.shape =', x_train.shape)
print('x_val.shape =', x_val.shape)
print('y_val.shape =', y_val.shape)
print('files_train.shape =', files_train.shape)
print('files_val.shape =', files_val.shape)
print('len(img_train) =', len(img_train))
print('len(img_val) =', len(img_val))
print('img_prep_train.shape =', img_prep_train.shape)
print('img_prep_val.shape =', img_prep_val.shape)


ここまでできたら、もうKNNを実践する準備はできました。
もう推論が実施できる状態ですので、いわゆる学習フェーズはもう完了していると言えます。
正常画像から抽出された特徴ベクトルをメモリ上、即ち、変数として保持をしておき、それでもってKNNによる距離測定を行う次第です。

尚、紹介の異常検知4手法の1つであるSPADEでは、この保持された特徴ベクトルのセットのことを、特徴ギャラリーと呼んでいます。
その為、ここでも特徴ギャラリーと呼ぼうと思います。


ちなみに、この特徴ギャラリーとクエリとで行うKNN計算は、結構な計算量となります。
例えば、AlexNetからピックアップされる特徴ベクトルの次元は4,096次元となります。
そして、正常画像として用意した猫画像を、今回は5,000枚としています。
評価画像として用意した猫画像/犬画像は、合わせて2,000枚としています。
そして、KNNの距離計算式は、ユークリッド距離とします。

そうすると、計算量としては、 5,000 × 2,000 × 4,096 回の引き算と、 (5,000 + 2,000) × 4,096 回の掛け算と、 (5,000 + 2,000) × (4,096 - 1) 回の足し算とが発生します。
総当たりの計算ですから、まあまあなかなかです。
シンプルに計算に時間がかかります。

また、メモリ上に大量のデータを保持する関係から、「時間と空間のトレードオフ」という問題にもぶち当たります。
簡単には、以下のようなトレードオフです。

  • for文等でメモリにデータを順次載せつつ計算をすると、メモリパンクの心配はないが、凄く時間がかかってしまう
  • 一気にメモリにデータを載せて、ブロードキャスト計算をすると、高速に計算が行われるものの、メモリ上に載せるデータが多過ぎた場合に、メモリがパンクしてしまう(プログラム実行はメモリエラーで異常終了する)


これらは結構悩ましい問題です。
特にpythonは、for文を使うと著しくパフォーマンスが落ちてしまう言語です。
そちらについては、以下の記事が分かりやすいかと思いますので、参考までに。
python で、 for 文を使ったら負け(笑)|mucun_wuxian


ただし、ここで素晴らしい救いの手があります。
それは、紹介の異常検知4手法の内の1つ、PatchCoreの実装に用いられている、Faissというライブラリになります。
Welcome to Faiss Documentation — Faiss documentation

距離計算系の高速化が、このライブラリを用いることで実現できます。
尚、計算内容にもよりますが、手元の測定ではブロードキャスト計算を行うより高速になりました。
また、かなり大量のデータでも異常終了せず、メモリ逼迫の問題についても大幅に解消されます。
それらの理由としては、以下記事曰く、内部的にBranch-and-BoundLEMPFEXIPRO等のアルゴリズムが適用されている為だそうです。
実装のレベルの高さに脱帽です。
Facebook AI Similarity Search (FAISS), Part 1

ちなみに、PatchCoreの公式実装は、amazonの研究者が行っており、Faissの実装は、facebookの研究者が行っています。
迫力のあるコラボレーション感です。w


そんな訳で、この解説でのDN2実装については、Faissを用いてみようと思います。
使い勝手は非常に軽快である為、この解説を通して、簡単に使いこなせるようになるかと思います。

尚、Faissのインストールについては、conda install -c pytorch faiss-gpu によって実施をしました。
これにて、GPUを用いたFaissの計算を行えるようなります。


それでは、KNNの実践を、Faissで行っていきましょう。
先程、特徴ギャラリーに登録するための正常データ、つまり、猫画像の特徴ベクトルは、変数 x_train にまとめておきました。
それを、特徴ギャラリーに登録してあげようと思います。
以下コードにて、それが実践されます。

import faiss

d = x_train.shape[1]

index = faiss.GpuIndexFlatL2(faiss.StandardGpuResources(), 
                             d, 
                             faiss.GpuIndexFlatConfig())

index.add(x_train)

print('index =', index)
print('index.ntotal =', index.ntotal)


非常にシンプルです。
尚、実装は、以下チュートリアルに沿っています。
Getting started · facebookresearch/faiss Wiki · GitHub

上記実装によって、ユークリッド距離によるKNNを、GPU上で実施してくれる形となります。
尚、上記コードの変数 d には、特徴ベクトルの次元数である4,096という数値が格納されています。


特徴ギャラリーの準備ができたら、後は比較対象のクエリを入力するだけです。
クエリを入力するコードは以下となります。

k = 5  # we want to see 5 nearest neighbors

D_val, I_val = index.search(x_val, k)

print('D_val.shape =', D_val.shape)
print('I_val.shape =', I_val.shape)

import matplotlib.pyplot as plt

plt.figure(figsize=(10, 6), dpi=100)

plt.subplot(1, 2, 1)
plt.imshow(D_val, aspect='auto')
plt.colorbar()
plt.title('KNN distanse')

plt.subplot(1, 2, 2)
plt.imshow(I_val, aspect='auto')
plt.colorbar()
plt.title('KNN index')

plt.show()


クエリとの距離計算実行に関しても、シンプルなコードで実現できます。
また、KNNのK、即ち、ピックアップしたいご近所さんの数を、事前に指定する必要がないことも、Kのチューニングの観点から、非常に有り難い次第です。

尚、KNNの結果としては、K個の距離と、K個のデータインデックスが帰ってきます。
それを図示したものが、上記コードより出力されます。
距離の図示を見ると、異常検知が上手くできていそうなことが、確認できるかと思います。
大まかに、猫画像との距離が小さく、犬画像との距離が大きくなっています。


さて、先程のDN2のアルゴリズム説明に沿って、このK個の距離を平均したものを算出します。
その大小によって、正常/異常を判断する形です。

その実践コードが、以下となります。

D_val = np.mean(D_val, axis=1)

print('D_val.shape =', D_val.shape)

plt.figure(figsize=(10, 8), dpi=100)

plt.subplot(2, 1, 1)
plt.scatter(np.where(y_val == 0)[0], D_val[y_val == 0], alpha=0.5)
plt.scatter(np.where(y_val == 1)[0], D_val[y_val == 1], alpha=0.5)
plt.grid()

plt.subplot(2, 1, 2)
plt.hist(D_val[y_val == 0], alpha=0.5, bins=50)
plt.hist(D_val[y_val == 1], alpha=0.5, bins=50)
plt.grid()

plt.show()


比較的、上手く猫画像と犬画像の識別が、異常検知の要領で実施できているかと思います。
或いは、予測の分布として、2者が比較的重なっているかとも思われます。
この辺りを掘り下げていきましょう。


先ず、精度がどのくらいかを算出してみましょう。
評価データである猫画像1,000枚と犬画像1,000枚とで算出した距離にて、最適な閾値、即ち、最も精度が高くなる閾値を探索してみます。

それを実践するコードが以下となります。

acc_list = []
thresh_list = np.arange(0, np.max(D_val)+1e-10, np.max(D_val)/50)

for thresh in thresh_list:

    acc = np.mean(np.concatenate([(D_val[y_val == 0] < thresh),
                                  (D_val[y_val == 1] > thresh)]))
    acc_list.append(acc)

acc_list = np.array(acc_list)

plt.figure(figsize=(10, 4), dpi=100)
plt.plot(thresh_list, acc_list, alpha=0.5, linewidth=3)
plt.scatter(thresh_list, acc_list, alpha=0.5, s=30)
plt.grid()
plt.title('np.max(acc_list) = %.3f' % np.max(acc_list))
plt.show()


精度は、評価データでの正解率76.3%が最大となりました。
しかし、先程のKNNの転移学習での精度が正解率93.1%だったのに比べると、少し物足りないかもしれません。

ここで、AlexNetではなく、WideResNet50を用いてみましょう。
それによって、精度がどのくらい向上するかを確認します。
先に伝える形になってしまいますが、WideResNet50は紹介の異常検知4手法の内の3つ、SPADE/PaDiM/PatchCoreにて、実際に特徴抽出器として用いられているモデル構造となります。

それを実践するコードは、以下となります。
尚、モデル構築部分だけをカスタマイズすれば、それは実現ができますので、その変更部分だけを記載します。

import torch
import torchvision.models as models
from torchinfo import summary

device = torch.device('cuda:0')  # torch.device('cpu')

model = models.wide_resnet50_2(weights=models.Wide_ResNet50_2_Weights.IMAGENET1K_V1)

model.eval()
model.to(device)

print('model =', model)

# set model's intermediate outputs
outputs = []

def hook(module, input, output):
    outputs.append(output)

model.avgpool.register_forward_hook(hook)

print(summary(model,input_size=(1, 3, 224, 224)))


WideResNet50を特徴抽出器をとした場合の、距離分布と算出精度は以下となります。


距離の分布の傾向が変わっていることが確認できるかと思います。
評価データでの正解率も最大90.3%まで向上しました。


ここで、更なる精度向上への施策として、Cats vs Dogsデータセットのデータを、全て使ってみます。
猫画像10,000枚を正常画像として X_{train} に、別の猫画像2,500枚と犬画像2,500枚とを評価データにしてみます。
引続き、猫画像を正常、犬画像を異常と識別できれば、精度が高くなる形です。

改めて、正常とする猫画像の枚数を変数 N_train_cat に、評価のための猫画像の枚数を変数 N_val_cat に、評価のための犬画像の枚数を変数 N_dog に、それぞれセットします。

N_train_cat = 10000
N_val_cat = 2500
N_dog = 2500


データを増やした上での、WideResNet50を特徴抽出器とした、距離の分布と、算出精度は以下となります。


評価データでの正解率が最大91.1%まで向上しました。
特徴抽出器に採用するモデルの先進性、及び、データ数の増加によって、精度が高くなる傾向が見受けられます。


ちなみに、間違っているデータにはどんなものがあるでしょうか。
大きく間違えた対象のTOP10を見てみます。
以下コードで確認できます。

flg_dog = 0
flg_asc = False

if flg_asc:
    idx_miss = np.argsort(D_val[y_val == flg_dog])
else:
    idx_miss = np.argsort(-D_val[y_val == flg_dog])

for i_miss in idx_miss[:10]:
    img = [img_val[i] for i in np.where(y_val == flg_dog)[0]][i_miss]
    img_prep = img_prep_val[y_val == flg_dog][i_miss]
    
    plt.figure(figsize=(10, 6), dpi=100, facecolor='white')
    plt.subplot(1, 2, 1)
    plt.imshow(img)
    plt.subplot(1, 2, 2)
    plt.imshow(img_prep)
    plt.title('D:%.3f' % D_val[y_val == flg_dog][i_miss])
    plt.show()

  • 猫:算出した距離の大きさTOP1

  • 猫:算出した距離の大きさTOP2

  • 猫:算出した距離の大きさTOP3

  • 猫:算出した距離の大きさTOP4

  • 猫:算出した距離の大きさTOP5

  • 猫:算出した距離の大きさTOP6

  • 猫:算出した距離の大きさTOP7

  • 猫:算出した距離の大きさTOP8

  • 猫:算出した距離の大きさTOP9

  • 猫:算出した距離の大きさTOP10


  • 犬:算出した距離の小ささTOP1

  • 犬:算出した距離の小ささTOP2

  • 犬:算出した距離の小ささTOP3

  • 犬:算出した距離の小ささTOP4

  • 犬:算出した距離の小ささTOP5

  • 犬:算出した距離の小ささTOP6

  • 犬:算出した距離の小ささTOP7

  • 犬:算出した距離の小ささTOP8

  • 犬:算出した距離の小ささTOP9

  • 犬:算出した距離の小ささTOP10



特に猫画像で異常と判断されてしまったTOP10については、かなり納得感が高いかと思います。
逆に言えば、こういったノイズの多い画像の特徴が、特徴ギャラリーに登録されていると、異常検知精度への悪影響がありそうです。

犬画像で正常圏内の距離に収まっていた対象、つまり、特徴ギャラリー特徴とのKNN距離が短かったTOP10については、逆にハッとする示唆が多いかと思います。
耳が立っている、正面向きで顔のアップである、比較的可愛らしい等の傾向を持つ犬画像が、特徴ギャラリーとのKNN距離が短くなる傾向がありそうです。


また、考察のポイントとしましては、異常検知アプローチながら、最大正解率91.1%という精度が引き出せている、というところが重要です。
転移学習でのKNN識別器は、93.5%という正解率が出ていましたが、猫か犬かの2択の識別においての精度であり、仮に馬画像のデータが入力された場合にも、それが猫か犬のどちらに近いかという、ローカルな視野でしか判断できません。
馬画像がどちらに近かろうが、正解/不正解の判断はできません。
つまり、想定外の入力に対応できない、或いは、グローバルな課題感には対応できないというのが、識別器の限界となります。
あくまで、猫画像か犬画像が入力された際に、どちらかを当てることしかできません。

一方で、異常検知というアプローチは、特徴ギャラリーに登録されたデータ達からの外れているか否かを、グローバルに判断しています。
馬画像のデータが入力されれば、それは猫画像とは外れる異常であると、判断することが可能です。
これは、実用を考えると大きな優位性/メリットとなります。


という訳で、次には、Cats vs Dogsデータセットの猫画像特徴を特徴ギャラリーに持つKNNモデルに、様々な動物のテストデータを入力してみて、その効力を確認してみましょう。
テストデータは、Cats vs Dogsデータセットと同様にKaggleが発行して下さっている、animals10というデータセットを用いようと思います。
Animals-10 | Kaggle

ダウンロードすると、 archive/raw-img というフォルダ配下に、10種の動物画像がフォルダ分けされていることを確認できます。
動物の名称は、イタリア語表記のようです。

% tree -d archive
archive
└── raw-img
    ├── cane(犬)
    ├── cavallo(馬)
    ├── elefante(ゾウ)
    ├── farfalla(チョウチョ)
    ├── gallina(ニワトリ)
    ├── gatto(猫)
    ├── mucca(牛)
    ├── pecora(羊)
    ├── ragno(蜘蛛)
    └── scoiattolo(リス)

11 directories

この raw-img フォルダを、animals10 と名称変更して、notebookプログラムが格納されているパスに配置します。
その上で、以下のコードにて、画像ファイル名称を収集します。

path = './animals10/'
files = {'cane':[], 'cavallo':[], 'elefante':[], 'farfalla':[], 'gallina':[],
         'gatto':[], 'mucca':[], 'pecora':[], 'ragno':[], 'scoiattolo':[]}

for animal in files.keys():

    files_tmp = os.listdir(os.path.join(path, animal))
    files[animal] = [os.path.join(path, animal, f) for f in files_tmp
                     if (os.path.isfile(os.path.join(path, animal, f)) &
                         (('.jpg' in f) | ('.jpeg' in f) | ('.png' in f)))]
    files[animal] = sorted(files[animal])

    print('len(files[%s]) =' % animal, len(files[animal]))
    print('files[%s][:5] =\n' % animal, files[animal][:5])
    print()


各動物種フォルダ毎に、約1,400〜4,900枚程のファイルが格納されているようです。
今回は、それぞれから1,000毎ずつの画像をピックアップしてみます。
そして、ピックアップした画像の内、gatto(猫) のみを正常、それ以外を異常と判断できるかどうか、検証をしてみます。

以下が、フォルダ内のファイルをランダムに1,000枚ずつピックアップしつつ、特徴の抽出も実施するコードとなります。

random.seed(0)

N_sample = 1000

x_test = {'cane':[], 'cavallo':[], 'elefante':[], 'farfalla':[], 'gallina':[],
          'gatto':[], 'mucca':[], 'pecora':[], 'ragno':[], 'scoiattolo':[]}

for animal in files.keys():

    outputs = []

    files_tmp = random.sample(files[animal], N_sample)
    for file in tqdm(files_tmp):

        img = cv2.imread(file)[..., ::-1]  # BGR2RGB
        img_prep = cv2.resize(img, (224, 224))

        x = img_prep
        x = x / 255
        x = x - np.array([[[0.485, 0.456, 0.406]]])
        x = x / np.array([[[0.229, 0.224, 0.225]]])
        x = torch.from_numpy(x.astype(np.float32))
        x = x.unsqueeze(0).permute(0, 3, 1, 2)
        x = x.to(device)

        with torch.no_grad():
            _ = model(x)

    x_test[animal] = torch.vstack(outputs).reshape(len(outputs), -1)
    x_test[animal] = x_test[animal].detach().cpu().numpy()

    print('x_test[%s].shape =' % animal, x_test[animal].shape)


閾値については、Cats vs Dogsでの検証で最も正解率の高かった 229.03 という数字を暫定的に用います。

thresh = thresh_list[np.argmax(acc_list)]

print(thresh)


そして、Cats vs Dogsデータにて作成した特徴ギャラリー、即ち、faissのindexを用いて、距離計算を実施します。

k = 5  # we want to see 5 nearest neighbors

D_test = []
y_test = []
y_hat_test = []

plt.figure(figsize=(10, 8), dpi=100)

for i_animal, animal in enumerate(files.keys()):

    D, I = index.search(x_test[animal], k)
    D = np.mean(D, axis=1)
    D_test.append(D)

    if (animal== 'gatto'):
        y_test.append(np.zeros([N_sample]).astype(np.int16))
    else:
        y_test.append(np.ones([N_sample]).astype(np.int16))

    plt.subplot(2, 1, 1)
    plt.scatter((np.arange(N_sample) + (i_animal * N_sample)), D,
                alpha=0.5, label=animal)

    plt.subplot(2, 1, 2)
    plt.hist(D, alpha=0.5, bins=50, label=animal)

D_test = np.hstack(D_test)
y_test = np.hstack(y_test)
y_hat_test = (D_test > thresh).astype(np.int16)

plt.subplot(2, 1, 1)
plt.grid()
plt.legend()

plt.subplot(2, 1, 2)
plt.grid()
plt.legend()
plt.title('acc = %.3f' % np.mean(y_test == y_hat_test))

plt.show()


異常検知精度は95.6%まで発揮されました。
非常に上手く異常検知ができています。

仮に、閾値の最適化を行うと、どのくらいの精度が引き出せるかを、以下コードにて確認してみます。

acc_list = []
thresh_list = np.arange(0, np.max(D_test)+1e-10, np.max(D_test)/50)

for thresh in thresh_list:

    y_hat_test = (D_test > thresh).astype(np.int16)
    acc = np.mean(y_test == y_hat_test)
    acc_list.append(acc)

acc_list = np.array(acc_list)

plt.figure(figsize=(10, 4), dpi=100)
plt.plot(thresh_list, acc_list, alpha=0.5, linewidth=3)
plt.scatter(thresh_list, acc_list, alpha=0.5, s=30)
plt.grid()
plt.title('np.max(acc_list) = %.3f, thresh_list[np.argmax(acc_list) = %.2f' %
          (np.max(acc_list), thresh_list[np.argmax(acc_list)]))
plt.show()


96.8%が最大の正解率となりました。
つまりは、Cats vs Dogsのデータ全件で閾値調整すると、animal10に対しては比較的最適に近い閾値が導出されていることが確認できました。
尚、animal10データは、猫画像以外のデータ数が多い為に、Cats vs Dogsデータにて行った同数データずつでの評価よりも、原理的に精度は出しやすそうです。

尚、accuracy基準で、最適な閾値として導出された 196.30 という数字を採用する場合、不正解率は「100.0% - 96.8%=3.2%」となるのですが、その殆どは gatto(猫) の画像のようでした。
猫画像でないものを異常として捉えることには、概ね成功していると言えます。
また、 gatto(猫) にて異常と判定された対象が、異常と判断されてもしょうがない対象であれば、より上手く異常検知ができていると判断でき、そうでなければ、特徴ギャラリーへの登録が不足していると判断できそうです。


それでは、誤検知の対象を見ていこうと思います。
先ずは、gatto(猫) からです。
gatto(猫) は示唆が深い為、TOP20まで確認しようと思います。
確認をするコードと、その結果は以下となります。

animal = 'gatto'

i = np.where(np.array([k for k in x_test.keys()]) == animal)[0][0]
i_from = i * N_sample
i_to = (i + 1) * N_sample
idx_miss = np.argsort(-np.abs(D_test[i_from:i_to]))

for i_miss in idx_miss[:20]:
    file = files[animal][i_miss]
    img = cv2.imread(file)[..., ::-1]  # BGR2RGB
    img_prep = cv2.resize(img, (224, 224))
    
    plt.figure(figsize=(10, 6), dpi=100, facecolor='white')
    plt.subplot(1, 2, 1)
    plt.imshow(img)
    plt.subplot(1, 2, 2)
    plt.imshow(img_prep)
    plt.title('D:%.3f' % D_test[i_from + i_miss])
    plt.show()
  • gatto(猫):算出した距離の大きさTOP1

  • gatto(猫):算出した距離の大きさTOP2

  • gatto(猫):算出した距離の大きさTOP3

  • gatto(猫):算出した距離の大きさTOP4

  • gatto(猫):算出した距離の大きさTOP5

  • gatto(猫):算出した距離の大きさTOP6

  • gatto(猫):算出した距離の大きさTOP7

  • gatto(猫):算出した距離の大きさTOP8

  • gatto(猫):算出した距離の大きさTOP9

  • gatto(猫):算出した距離の大きさTOP10

  • gatto(猫):算出した距離の大きさTOP11

  • gatto(猫):算出した距離の大きさTOP12

  • gatto(猫):算出した距離の大きさTOP13

  • gatto(猫):算出した距離の大きさTOP14

  • gatto(猫):算出した距離の大きさTOP15

  • gatto(猫):算出した距離の大きさTOP16

  • gatto(猫):算出した距離の大きさTOP17

  • gatto(猫):算出した距離の大きさTOP18

  • gatto(猫):算出した距離の大きさTOP19

  • gatto(猫):算出した距離の大きさTOP20


太っている猫が多いことや、背景の写り込み等が多いもの、2頭写っているもの、字が写り込んでいるもの、元画像のアスペクト比を大きく変えてしまっているもの等々、原因は色々と考えられそうです。
前処理に関連するような内容に関しては、そのロジックを工夫することで、精度向上が図れるのかもしれません。


次に、 gatto(猫) 以外で、特徴ギャラリーとのKNN距離が近かった対象を次々見てみます。
確認コードは cane(犬) の場合で以下で、`animal = 'cane' のコードを所望のanimal種に切り替えることで、その動物種におけるKNN距離が近かった対象TOP5を確認することができます。

尚、画像のタイトル部分に記載されている数字がKNN距離となります。
そして、精度を最大限引き出す閾値が 66.81 となります。
閾値を念頭に確認して頂くと、 gatto(猫) 以外の動物種について、閾値を下回る対象が殆どないことが、確認できるかと思います。

animals = ['cane', 'cavallo', 'elefante', 'farfalla', 'gallina',
           'mucca', 'pecora', 'ragno', 'scoiattolo']

for animal in animals:
    i = np.where(np.array([k for k in x_test.keys()]) == animal)[0][0]
    i_from = i * N_sample
    i_to = (i + 1) * N_sample
    idx_miss = np.argsort(np.abs(D_test[i_from:i_to]))

    for j_miss, i_miss in enumerate(idx_miss[:5]):
        print(animal, j_miss+1)
        file = files[animal][i_miss]
        img = cv2.imread(file)[..., ::-1]  # BGR2RGB
        img_prep = cv2.resize(img, (224, 224))

        plt.figure(figsize=(10, 6), dpi=100, facecolor='white')
        plt.subplot(1, 2, 1)
        plt.imshow(img)
        plt.subplot(1, 2, 2)
        plt.imshow(img_prep)
        plt.title('D:%.3f' % D_test[i_from + i_miss])
        plt.show()

  • cane(犬):算出した距離の小ささTOP1

  • cane(犬):算出した距離の小ささTOP2

  • cane(犬):算出した距離の小ささTOP3

  • cane(犬):算出した距離の小ささTOP4

  • cane(犬):算出した距離の小ささTOP5


  • cavallo(馬):算出した距離の小ささTOP1

  • cavallo(馬):算出した距離の小ささTOP2

  • cavallo(馬):算出した距離の小ささTOP3

  • cavallo(馬):算出した距離の小ささTOP4

  • cavallo(馬):算出した距離の小ささTOP5


  • elefante(ゾウ):算出した距離の小ささTOP1

  • elefante(ゾウ):算出した距離の小ささTOP2

  • elefante(ゾウ):算出した距離の小ささTOP3

  • elefante(ゾウ):算出した距離の小ささTOP4

  • elefante(ゾウ):算出した距離の小ささTOP5


  • farfalla(チョウチョ):算出した距離の小ささTOP1

  • farfalla(チョウチョ):算出した距離の小ささTOP2

  • farfalla(チョウチョ):算出した距離の小ささTOP3

  • farfalla(チョウチョ):算出した距離の小ささTOP4

  • farfalla(チョウチョ):算出した距離の小ささTOP5


  • gallina(ニワトリ):算出した距離の小ささTOP1

  • gallina(ニワトリ):算出した距離の小ささTOP2

  • gallina(ニワトリ):算出した距離の小ささTOP3

  • gallina(ニワトリ):算出した距離の小ささTOP4

  • gallina(ニワトリ):算出した距離の小ささTOP5


  • mucca(牛):算出した距離の小ささTOP1

  • mucca(牛):算出した距離の小ささTOP2

  • mucca(牛):算出した距離の小ささTOP3

  • mucca(牛):算出した距離の小ささTOP4

  • mucca(牛):算出した距離の小ささTOP5


  • pecora(羊):算出した距離の小ささTOP1

  • pecora(羊):算出した距離の小ささTOP2

  • pecora(羊):算出した距離の小ささTOP3

  • pecora(羊):算出した距離の小ささTOP4

  • pecora(羊):算出した距離の小ささTOP5


  • ragno(蜘蛛):算出した距離の小ささTOP1

  • ragno(蜘蛛):算出した距離の小ささTOP2

  • ragno(蜘蛛):算出した距離の小ささTOP3

  • ragno(蜘蛛):算出した距離の小ささTOP4

  • ragno(蜘蛛):算出した距離の小ささTOP5


  • scoiattolo(リス):算出した距離の小ささTOP1

  • scoiattolo(リス):算出した距離の小ささTOP2

  • scoiattolo(リス):算出した距離の小ささTOP3

  • scoiattolo(リス):算出した距離の小ささTOP4

  • scoiattolo(リス):算出した距離の小ささTOP5



見て頂くと、 cane(犬)scoiattolo(リス) など、猫に近い対象が、特にスコアが低いことが確認できます。

テストデータ全体では、動物10種の10,000枚を用意し、その内の3.2%、即ち、320枚しか間違っていないということなので、かなり上手くできていることが伺えます。
DN2が非常に有効な手法であることが、改めて理解頂けたかと思います。


さて、DN2解説の締めとしまして、異常検知のオープンデータとして広く用いられているMVTecデータセットでも検証を行ってみましょう。
MVTecデータセットは、以下のリンクからダウンロードが可能です。
Download Area: MVTec Software


MVTecデータがダウンロードできましたら、以下のコードにて、ファイル名を取得します。
尚、複数種類の異常検知課題を用意してくれているMVTecの中から、今回は bottle にて、精度検証を行ってみようと思います。

path_parent = './mvtec_anomaly_detection/'
type_data = 'bottle'

path_train = os.path.join(path_parent, type_data, 'train/good')
files_train = [os.path.join(path_train, f) for f in os.listdir(path_train)
               if (os.path.isfile(os.path.join(path_train, f)) &
                   ('.png' in f))]
files_train = np.array(sorted(files_train))

print('len(files_train) =', len(files_train))
print('files_train[:5] =\n', files_train[:5])
print()

types_test = os.listdir(os.path.join(path_parent, type_data, 'test'))
types_test = np.array(sorted(types_test))

files_test = {}

for type_test in types_test:

    path_test = os.path.join(path_parent, type_data, 'test', type_test)
    files_test[type_test] = [os.path.join(path_test, f)
                             for f in os.listdir(path_test)
                             if (os.path.isfile(os.path.join(path_test, f)) &
                                 ('.png' in f))]
    files_test[type_test] = np.array(sorted(files_test[type_test]))

    print('len(files_test[%s]) =' % type_test, len(files_test[type_test]))
    print('files_test[%s][:5] =\n' % type_test, files_test[type_test][:5])
    print()

plt.figure(figsize=(12, 12), dpi=100, facecolor='white')

plt.subplot(3, 2, 1)
plt.imshow(cv2.imread(files_train[0])[..., ::-1])
plt.title(files_train[0])

plt.subplot(3, 2, 3)
plt.imshow(cv2.imread(files_test[types_test[0]][0])[..., ::-1])
plt.title(files_test[types_test[0]][0])

plt.subplot(3, 2, 4)
plt.imshow(cv2.imread(files_test[types_test[1]][0])[..., ::-1])
plt.title(files_test[types_test[1]][0])

plt.subplot(3, 2, 5)
plt.imshow(cv2.imread(files_test[types_test[2]][0])[..., ::-1])
plt.title(files_test[types_test[2]][0])

plt.subplot(3, 2, 6)
plt.imshow(cv2.imread(files_test[types_test[3]][0])[..., ::-1])
plt.title(files_test[types_test[3]][0])

plt.show()


尚、MVTecの bottle のデータセットは、以下のようなフォルダ構造になっています。

$ !tree -d ./mvtec_anomaly_detection/bottle/
./mvtec_anomaly_detection/bottle/
├── ground_truth
│   ├── broken_large
│   ├── broken_small
│   └── contamination
├── test
│   ├── broken_large
│   ├── broken_small
│   ├── contamination
│   └── good
└── train
    └── good

11 directories

MCTec全般に 'bottle' 同様のフォルダ構造となっており、 traingood で学習をした上で、 testgood を正常に、 testgood 以外を異常に振り分ける、という形式になっています。
上記のコードは、 type_data の部分を所望のデータタイプに切り替えれば、上手く対応してくれる形となっています。

$ !ls -d ./mvtec_anomaly_detection/*/
./mvtec_anomaly_detection/bottle/     ./mvtec_anomaly_detection/pill/
./mvtec_anomaly_detection/cable/      ./mvtec_anomaly_detection/screw/
./mvtec_anomaly_detection/capsule/    ./mvtec_anomaly_detection/tile/
./mvtec_anomaly_detection/carpet/     ./mvtec_anomaly_detection/toothbrush/
./mvtec_anomaly_detection/grid/       ./mvtec_anomaly_detection/transistor/
./mvtec_anomaly_detection/hazelnut/   ./mvtec_anomaly_detection/wood/
./mvtec_anomaly_detection/leather/    ./mvtec_anomaly_detection/zipper/
./mvtec_anomaly_detection/metal_nut/


ファイル名が取得できたら、次に画像読込と特徴抽出を行っていきます。
先ず、学習データです。

outputs = []

for file in tqdm(files_train):
    img = cv2.imread(file)[..., ::-1]  # BGR2RGB
    img_prep = cv2.resize(img, (256, 256), interpolation=cv2.INTER_AREA)
    img_prep = img_prep[16:(256-16), 16:(256-16)]

    x = img_prep
    x = x / 255
    x = x - np.array([[[0.485, 0.456, 0.406]]])
    x = x / np.array([[[0.229, 0.224, 0.225]]])
    x = torch.from_numpy(x.astype(np.float32)).unsqueeze(0).permute(0, 3, 1, 2)
    x = x.to(device)

    with torch.no_grad():
        _ = model(x)

x_train = torch.vstack(outputs).reshape(len(outputs), -1).detach().cpu().numpy()

print('x_train.shape =', x_train.shape)


次に、テストデータです。
テストデータは、辞書型変数に上手く格納をします。

x_test = {}

for type_test in types_test:

    outputs = []

    for file in tqdm(files_test[type_test]):
        img = cv2.imread(file)[..., ::-1]  # BGR2RGB
        img_prep = cv2.resize(img, (256, 256), interpolation=cv2.INTER_AREA)
        img_prep = img_prep[16:(256-16), 16:(256-16)]

        # plt.figure(figsize=(10, 6), dpi=100, facecolor='white')
        # plt.subplot(1, 2, 1)
        # plt.imshow(cv2.resize(img, (256, 256)))
        # plt.subplot(1, 2, 2)
        # plt.imshow(cv2.resize(img, (256, 256))[16:(256-16), 16:(256-16)])
        # plt.show()

        x = img_prep
        x = x / 255
        x = x - np.array([[[0.485, 0.456, 0.406]]])
        x = x / np.array([[[0.229, 0.224, 0.225]]])
        x = torch.from_numpy(x.astype(np.float32)).unsqueeze(0).permute(0, 3, 1, 2)
        x = x.to(device)

        with torch.no_grad():
            _ = model(x)

    x_test[type_test] = torch.vstack(outputs).reshape(len(outputs), -1).detach().cpu().numpy()

    print('x_test[%s].shape =' % type_test, x_test[type_test].shape)


尚、読み込んだ画像に対して行う前処理ですが、DN2の論文中に「In all instances of DN2, we first resize the input image to 256 × 256, we take the center crop of size 224 × 224」という記載があるので、それに沿って、以下のように行っています。

# In all instances of DN2, we first resize the input image to 256 × 256,
# we take the center crop of size 224 × 224
plt.figure(figsize=(12, 4), dpi=100, facecolor='white')

img = cv2.imread(files_test[list(files_test.keys())[0]][0])[..., ::-1]  # BGR2RGB

plt.subplot(1, 3, 1)
plt.imshow(img)
plt.title(img.shape)

img = cv2.resize(img, (256, 256), interpolation=cv2.INTER_AREA)

plt.subplot(1, 3, 2)
plt.imshow(img)
plt.title(img.shape)
plt.plot([16, 16, (256-16), (256-16), 16],
         [16, (256-16), (256-16), 16, 16], alpha=0.5, linewidth=3, color='r')

img = img[16:(256-16), 16:(256-16)]

plt.subplot(1, 3, 3)
plt.imshow(img)
plt.title(img.shape)

plt.show()


ここで、画像リサイズのinterpolation指定ですが、SPADEのGithubコードにて行われているリサイズが torchvision.transforms.Resize(resize, Image.ANTIALIAS), という指定であり、これと同等のリサイズがされるように、上記のように指定しています。
後述しますが、これはSPADEの論文にて、明確にそう指定されているものとなります。

尚、上記のようにinterpolation指定をした方が、一般的に、キレイに画像縮小が行えるんだそうです。
python - openCV equivalent of a PIL resize ANTIALIAS? - Stack Overflow


ファイル名が収集できたら、次に、特徴ギャラリーを用意します。
特徴ギャラリーには、学習データとして収集した特徴ベクトル全件を登録します。

d = x_train.shape[1]

index = faiss.GpuIndexFlatL2(faiss.StandardGpuResources(), 
                             d, 
                             faiss.GpuIndexFlatConfig())

index.add(x_train)

print('index =', index)
print('index.ntotal =', index.ntotal)


ここまでできたら、後はKNNを実施して、距離を測定するのみです。
以下コードにて、正常/異常種別毎にバッチを分けて、KNNを実施します。

k = 5  # we want to see 5 nearest neighbors

y_test = []
D_test = []
N_test = 0

D, I = index.search(x_test['good'], k)
D = np.mean(D, axis=1)
D_test.append(D)
y_test.append(np.zeros([len(D)]).astype(np.int16))

plt.figure(figsize=(10, 8), dpi=100, facecolor='white')

plt.subplot(2, 1, 1)
plt.scatter((np.arange(len(D)) + N_test), D, alpha=0.5, label=type_test)

plt.subplot(2, 1, 2)
plt.hist(D, alpha=0.5, bins=15, label=type_test)

N_test += len(D)

for type_test in types_test[types_test != 'good']:

    D, I = index.search(x_test[type_test], k)
    D = np.mean(D, axis=1)
    D_test.append(D)
    y_test.append(np.ones([len(D)]).astype(np.int16))

    plt.subplot(2, 1, 1)
    plt.scatter((np.arange(len(D)) + N_test), D, alpha=0.5, label=type_test)

    plt.subplot(2, 1, 2)
    plt.hist(D, alpha=0.5, bins=15, label=type_test)

    N_test += len(D)

y_test = np.hstack(y_test)
D_test = np.hstack(D_test)

plt.subplot(2, 1, 1)
plt.grid()
plt.legend()

plt.subplot(2, 1, 2)
plt.grid()
plt.legend()

plt.show()


good とそれ以外の分布が、比較的キレイに分かれていることが確認できるかと思います。
尚、MVTecは著名なデータセットであり、その構成も秀逸なのですが、精度検証を行う際の注意点として、その件数が少ないことが挙げられます。
1つのデータを間違えるだけでも、精度へ比較的大きな影響を及ぼします。
その為、少々の精度増減に一喜一憂しない方が良かろうかと思われます。


予測の分布が得られたら、最後に精度を測定します。
尚、MVTecでの精度検証では、分かりやすい正解率ではなく、ROCカーブのAUCで評価を行うことが一般的です。
正解率は閾値によって揺れるのですが、ROCカーブのAUCは閾値によって揺れることがなく、手法間の比較としては分かりやすい為です。
つまり、ROCカーブのAUCは、閾値を振った際の精度のキープ度合いの指標であり、閾値からの影響を受けないと言いますか、それが織り込み済みとなっています。

ROC AUCについて詳しく知りたい方は、以下の記事などが分かりやすいので、参考にされると良いかと思います。
機械学習の評価指標 – ROC曲線とAUC | GMOアドパートナーズ TECH BLOG byGMO


正解率と、ROC AUCを計算するコードは以下となります。

from sklearn.metrics import roc_auc_score
from sklearn.metrics import roc_curve

acc_list = []
thresh_list = np.arange(0, np.max(D_test)+1e-10, np.max(D_test)/50)

for thresh in thresh_list:

    acc = np.mean(np.concatenate([(D_test[y_test == 0] < thresh),
                                  (D_test[y_test == 1] >= thresh)]))
    acc_list.append(acc)

acc_list = np.array(acc_list)

# https://github.com/byungjae89/SPADE-pytorch/blob/master/src/main.py#L118
# calculate image-level ROC AUC score
fpr, tpr, _ = roc_curve(y_test, D_test)
roc_auc = roc_auc_score(y_test, D_test)

plt.figure(figsize=(10, 4), dpi=100, facecolor='white')
plt.plot(thresh_list, acc_list, alpha=0.5, linewidth=3)
plt.scatter(thresh_list, acc_list, alpha=0.5, s=30)
plt.grid()
plt.title('np.max(acc_list) = %.3f, thresh_list[np.argmax(acc_list) = %.2f, roc_auc = %.3f' %
          (np.max(acc_list), thresh_list[np.argmax(acc_list)], roc_auc))
plt.show()


最大の正解率が90.4%、ROC AUCが96.6%と出ました。
これはかなり良い数字です。
先程紹介させて頂いた記事の中にもありますが、AUCは100%が最大値で、80%以上で高評価という水準です。
「DN2、恐るべし」と言えるでしょう。


以上で、DN2の解説を終えます。
シンプルな割に、非常に優秀なアルゴリズムであることが伺えたかと思います。

尚、論文にて、「ImageNet画像データと、任意課題の画像データとで、特に類似性というか、密接な関連はなくても良いことを、実験が示していると思う」というような記載がされています。
直感的には、ImageNetに出てくるような画像種への課題感適用がベターが気がしますが、論文にて「そうでもないよ」と言ってくれている次第です。
幾つかのデータセットで実験をされたようです。
しかし、この点が半分正しく、半分正しくない、というような指摘/論旨展開が、後発の論文にてされていくこととなります。

ただ、後発論文の着想はDN2を起点に始まっており、DN2が発明されなければ、これから紹介をする異常検知3手法は存在していなかったものと思われます。
その為、DN2の間隙感は、とても価値の高いものかと思われます。



おわりに

以上、今回の記事では、異常検知4手法の1つ目であるDN2について、詳細にまで踏み込んだ解説をさせて頂きました。
ここまで読んでくださった方には感謝です。🙇
次回の記事では、異常検知4手法の2つ目、SPADEについて説明をさせて頂こうと思います。
そちらもお読み頂けますと、有り難い次第です。🙇

AnyTechでは、流体のためのAI「DeepLiquid」の研究/開発/事業に携わってくれる方を募集しております。
ご興味のある方はこちらから、是非お問い合わせください。