quadra 0.0.1__py3-none-any.whl → 2.2.7__py3-none-any.whl
This diff shows the content changes between publicly released versions of this package as they appear in their public registry. It is provided for informational purposes only.
- hydra_plugins/quadra_searchpath_plugin.py +30 -0
- quadra/__init__.py +6 -0
- quadra/callbacks/__init__.py +0 -0
- quadra/callbacks/anomalib.py +289 -0
- quadra/callbacks/lightning.py +501 -0
- quadra/callbacks/mlflow.py +291 -0
- quadra/callbacks/scheduler.py +69 -0
- quadra/configs/__init__.py +0 -0
- quadra/configs/backbone/caformer_m36.yaml +8 -0
- quadra/configs/backbone/caformer_s36.yaml +8 -0
- quadra/configs/backbone/convnextv2_base.yaml +8 -0
- quadra/configs/backbone/convnextv2_femto.yaml +8 -0
- quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
- quadra/configs/backbone/dino_vitb8.yaml +12 -0
- quadra/configs/backbone/dino_vits8.yaml +12 -0
- quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
- quadra/configs/backbone/dinov2_vits14.yaml +12 -0
- quadra/configs/backbone/efficientnet_b0.yaml +8 -0
- quadra/configs/backbone/efficientnet_b1.yaml +8 -0
- quadra/configs/backbone/efficientnet_b2.yaml +8 -0
- quadra/configs/backbone/efficientnet_b3.yaml +8 -0
- quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
- quadra/configs/backbone/levit_128s.yaml +8 -0
- quadra/configs/backbone/mnasnet0_5.yaml +9 -0
- quadra/configs/backbone/resnet101.yaml +8 -0
- quadra/configs/backbone/resnet18.yaml +8 -0
- quadra/configs/backbone/resnet18_ssl.yaml +8 -0
- quadra/configs/backbone/resnet50.yaml +8 -0
- quadra/configs/backbone/smp.yaml +9 -0
- quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
- quadra/configs/backbone/unetr.yaml +15 -0
- quadra/configs/backbone/vit16_base.yaml +9 -0
- quadra/configs/backbone/vit16_small.yaml +9 -0
- quadra/configs/backbone/vit16_tiny.yaml +9 -0
- quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
- quadra/configs/callbacks/all.yaml +45 -0
- quadra/configs/callbacks/default.yaml +34 -0
- quadra/configs/callbacks/default_anomalib.yaml +64 -0
- quadra/configs/config.yaml +33 -0
- quadra/configs/core/default.yaml +11 -0
- quadra/configs/datamodule/base/anomaly.yaml +16 -0
- quadra/configs/datamodule/base/classification.yaml +21 -0
- quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
- quadra/configs/datamodule/base/segmentation.yaml +18 -0
- quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
- quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
- quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
- quadra/configs/datamodule/base/ssl.yaml +21 -0
- quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
- quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
- quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
- quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
- quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
- quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
- quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
- quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
- quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
- quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
- quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/base/classification/classification.yaml +73 -0
- quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
- quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
- quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
- quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
- quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
- quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
- quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
- quadra/configs/experiment/base/ssl/byol.yaml +43 -0
- quadra/configs/experiment/base/ssl/dino.yaml +46 -0
- quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
- quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
- quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
- quadra/configs/experiment/custom/cls.yaml +12 -0
- quadra/configs/experiment/default.yaml +15 -0
- quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
- quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
- quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
- quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
- quadra/configs/export/default.yaml +13 -0
- quadra/configs/hydra/anomaly_custom.yaml +15 -0
- quadra/configs/hydra/default.yaml +14 -0
- quadra/configs/inference/default.yaml +26 -0
- quadra/configs/logger/comet.yaml +10 -0
- quadra/configs/logger/csv.yaml +5 -0
- quadra/configs/logger/mlflow.yaml +12 -0
- quadra/configs/logger/tensorboard.yaml +8 -0
- quadra/configs/loss/asl.yaml +7 -0
- quadra/configs/loss/barlow.yaml +2 -0
- quadra/configs/loss/bce.yaml +1 -0
- quadra/configs/loss/byol.yaml +1 -0
- quadra/configs/loss/cross_entropy.yaml +1 -0
- quadra/configs/loss/dino.yaml +8 -0
- quadra/configs/loss/simclr.yaml +2 -0
- quadra/configs/loss/simsiam.yaml +1 -0
- quadra/configs/loss/smp_ce.yaml +3 -0
- quadra/configs/loss/smp_dice.yaml +2 -0
- quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
- quadra/configs/loss/smp_mcc.yaml +2 -0
- quadra/configs/loss/vicreg.yaml +5 -0
- quadra/configs/model/anomalib/cfa.yaml +35 -0
- quadra/configs/model/anomalib/cflow.yaml +30 -0
- quadra/configs/model/anomalib/csflow.yaml +34 -0
- quadra/configs/model/anomalib/dfm.yaml +19 -0
- quadra/configs/model/anomalib/draem.yaml +29 -0
- quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
- quadra/configs/model/anomalib/fastflow.yaml +32 -0
- quadra/configs/model/anomalib/padim.yaml +32 -0
- quadra/configs/model/anomalib/patchcore.yaml +36 -0
- quadra/configs/model/barlow.yaml +16 -0
- quadra/configs/model/byol.yaml +25 -0
- quadra/configs/model/classification.yaml +10 -0
- quadra/configs/model/dino.yaml +26 -0
- quadra/configs/model/logistic_regression.yaml +4 -0
- quadra/configs/model/multilabel_classification.yaml +9 -0
- quadra/configs/model/simclr.yaml +18 -0
- quadra/configs/model/simsiam.yaml +24 -0
- quadra/configs/model/smp.yaml +4 -0
- quadra/configs/model/smp_multiclass.yaml +4 -0
- quadra/configs/model/vicreg.yaml +16 -0
- quadra/configs/optimizer/adam.yaml +5 -0
- quadra/configs/optimizer/adamw.yaml +3 -0
- quadra/configs/optimizer/default.yaml +4 -0
- quadra/configs/optimizer/lars.yaml +8 -0
- quadra/configs/optimizer/sgd.yaml +4 -0
- quadra/configs/scheduler/default.yaml +5 -0
- quadra/configs/scheduler/rop.yaml +5 -0
- quadra/configs/scheduler/step.yaml +3 -0
- quadra/configs/scheduler/warmrestart.yaml +2 -0
- quadra/configs/scheduler/warmup.yaml +6 -0
- quadra/configs/task/anomalib/cfa.yaml +5 -0
- quadra/configs/task/anomalib/cflow.yaml +5 -0
- quadra/configs/task/anomalib/csflow.yaml +5 -0
- quadra/configs/task/anomalib/draem.yaml +5 -0
- quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
- quadra/configs/task/anomalib/fastflow.yaml +5 -0
- quadra/configs/task/anomalib/inference.yaml +3 -0
- quadra/configs/task/anomalib/padim.yaml +5 -0
- quadra/configs/task/anomalib/patchcore.yaml +5 -0
- quadra/configs/task/classification.yaml +6 -0
- quadra/configs/task/classification_evaluation.yaml +6 -0
- quadra/configs/task/default.yaml +1 -0
- quadra/configs/task/segmentation.yaml +9 -0
- quadra/configs/task/segmentation_evaluation.yaml +3 -0
- quadra/configs/task/sklearn_classification.yaml +13 -0
- quadra/configs/task/sklearn_classification_patch.yaml +11 -0
- quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
- quadra/configs/task/sklearn_classification_test.yaml +8 -0
- quadra/configs/task/ssl.yaml +2 -0
- quadra/configs/trainer/lightning_cpu.yaml +36 -0
- quadra/configs/trainer/lightning_gpu.yaml +35 -0
- quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
- quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
- quadra/configs/trainer/lightning_multigpu.yaml +37 -0
- quadra/configs/trainer/sklearn_classification.yaml +7 -0
- quadra/configs/transforms/byol.yaml +47 -0
- quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
- quadra/configs/transforms/default.yaml +37 -0
- quadra/configs/transforms/default_numpy.yaml +24 -0
- quadra/configs/transforms/default_resize.yaml +22 -0
- quadra/configs/transforms/dino.yaml +63 -0
- quadra/configs/transforms/linear_eval.yaml +18 -0
- quadra/datamodules/__init__.py +20 -0
- quadra/datamodules/anomaly.py +180 -0
- quadra/datamodules/base.py +375 -0
- quadra/datamodules/classification.py +1003 -0
- quadra/datamodules/generic/__init__.py +0 -0
- quadra/datamodules/generic/imagenette.py +144 -0
- quadra/datamodules/generic/mnist.py +81 -0
- quadra/datamodules/generic/mvtec.py +58 -0
- quadra/datamodules/generic/oxford_pet.py +163 -0
- quadra/datamodules/patch.py +190 -0
- quadra/datamodules/segmentation.py +742 -0
- quadra/datamodules/ssl.py +140 -0
- quadra/datasets/__init__.py +17 -0
- quadra/datasets/anomaly.py +287 -0
- quadra/datasets/classification.py +241 -0
- quadra/datasets/patch.py +138 -0
- quadra/datasets/segmentation.py +239 -0
- quadra/datasets/ssl.py +110 -0
- quadra/losses/__init__.py +0 -0
- quadra/losses/classification/__init__.py +6 -0
- quadra/losses/classification/asl.py +83 -0
- quadra/losses/classification/focal.py +320 -0
- quadra/losses/classification/prototypical.py +148 -0
- quadra/losses/ssl/__init__.py +17 -0
- quadra/losses/ssl/barlowtwins.py +47 -0
- quadra/losses/ssl/byol.py +37 -0
- quadra/losses/ssl/dino.py +129 -0
- quadra/losses/ssl/hyperspherical.py +45 -0
- quadra/losses/ssl/idmm.py +50 -0
- quadra/losses/ssl/simclr.py +67 -0
- quadra/losses/ssl/simsiam.py +30 -0
- quadra/losses/ssl/vicreg.py +76 -0
- quadra/main.py +49 -0
- quadra/metrics/__init__.py +3 -0
- quadra/metrics/segmentation.py +251 -0
- quadra/models/__init__.py +0 -0
- quadra/models/base.py +151 -0
- quadra/models/classification/__init__.py +8 -0
- quadra/models/classification/backbones.py +149 -0
- quadra/models/classification/base.py +92 -0
- quadra/models/evaluation.py +322 -0
- quadra/modules/__init__.py +0 -0
- quadra/modules/backbone.py +30 -0
- quadra/modules/base.py +312 -0
- quadra/modules/classification/__init__.py +3 -0
- quadra/modules/classification/base.py +327 -0
- quadra/modules/ssl/__init__.py +17 -0
- quadra/modules/ssl/barlowtwins.py +59 -0
- quadra/modules/ssl/byol.py +172 -0
- quadra/modules/ssl/common.py +285 -0
- quadra/modules/ssl/dino.py +186 -0
- quadra/modules/ssl/hyperspherical.py +206 -0
- quadra/modules/ssl/idmm.py +98 -0
- quadra/modules/ssl/simclr.py +73 -0
- quadra/modules/ssl/simsiam.py +68 -0
- quadra/modules/ssl/vicreg.py +67 -0
- quadra/optimizers/__init__.py +4 -0
- quadra/optimizers/lars.py +153 -0
- quadra/optimizers/sam.py +127 -0
- quadra/schedulers/__init__.py +3 -0
- quadra/schedulers/base.py +44 -0
- quadra/schedulers/warmup.py +127 -0
- quadra/tasks/__init__.py +24 -0
- quadra/tasks/anomaly.py +582 -0
- quadra/tasks/base.py +397 -0
- quadra/tasks/classification.py +1263 -0
- quadra/tasks/patch.py +492 -0
- quadra/tasks/segmentation.py +389 -0
- quadra/tasks/ssl.py +560 -0
- quadra/trainers/README.md +3 -0
- quadra/trainers/__init__.py +0 -0
- quadra/trainers/classification.py +179 -0
- quadra/utils/__init__.py +0 -0
- quadra/utils/anomaly.py +112 -0
- quadra/utils/classification.py +618 -0
- quadra/utils/deprecation.py +31 -0
- quadra/utils/evaluation.py +474 -0
- quadra/utils/export.py +585 -0
- quadra/utils/imaging.py +32 -0
- quadra/utils/logger.py +15 -0
- quadra/utils/mlflow.py +98 -0
- quadra/utils/model_manager.py +320 -0
- quadra/utils/models.py +523 -0
- quadra/utils/patch/__init__.py +15 -0
- quadra/utils/patch/dataset.py +1433 -0
- quadra/utils/patch/metrics.py +449 -0
- quadra/utils/patch/model.py +153 -0
- quadra/utils/patch/visualization.py +217 -0
- quadra/utils/resolver.py +42 -0
- quadra/utils/segmentation.py +31 -0
- quadra/utils/tests/__init__.py +0 -0
- quadra/utils/tests/fixtures/__init__.py +1 -0
- quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
- quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
- quadra/utils/tests/fixtures/dataset/classification.py +406 -0
- quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
- quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
- quadra/utils/tests/fixtures/models/__init__.py +3 -0
- quadra/utils/tests/fixtures/models/anomaly.py +89 -0
- quadra/utils/tests/fixtures/models/classification.py +45 -0
- quadra/utils/tests/fixtures/models/segmentation.py +33 -0
- quadra/utils/tests/helpers.py +70 -0
- quadra/utils/tests/models.py +27 -0
- quadra/utils/utils.py +525 -0
- quadra/utils/validator.py +115 -0
- quadra/utils/visualization.py +422 -0
- quadra/utils/vit_explainability.py +349 -0
- quadra-2.2.7.dist-info/LICENSE +201 -0
- quadra-2.2.7.dist-info/METADATA +381 -0
- quadra-2.2.7.dist-info/RECORD +300 -0
- {quadra-0.0.1.dist-info → quadra-2.2.7.dist-info}/WHEEL +1 -1
- quadra-2.2.7.dist-info/entry_points.txt +3 -0
- quadra-0.0.1.dist-info/METADATA +0 -14
- quadra-0.0.1.dist-info/RECORD +0 -4
quadra/tasks/anomaly.py
ADDED
@@ -0,0 +1,582 @@
from __future__ import annotations

import csv
import glob
import json
import os
from collections import Counter
from typing import Any, Generic, Literal, TypeVar, cast

import cv2
import hydra
import numpy as np
import torch
from anomalib.models.components.base import AnomalyModule
from anomalib.post_processing import anomaly_map_to_color_map
from anomalib.utils import plot_cumulative_histogram
from anomalib.utils.callbacks.min_max_normalization import MinMaxNormalizationCallback
from anomalib.utils.metrics.optimal_f1 import OptimalF1
from matplotlib import pyplot as plt
from omegaconf import DictConfig
from sklearn.metrics import ConfusionMatrixDisplay, f1_score
from tqdm import tqdm

from quadra.callbacks.mlflow import get_mlflow_logger
from quadra.datamodules import AnomalyDataModule
from quadra.modules.base import ModelSignatureWrapper
from quadra.tasks.base import Evaluation, LightningTask
from quadra.utils import utils
from quadra.utils.anomaly import MapOrValue, ThresholdNormalizationCallback, normalize_anomaly_score
from quadra.utils.classification import get_results
from quadra.utils.evaluation import automatic_datamodule_batch_size
from quadra.utils.export import export_model

log = utils.get_logger(__name__)

AnomalyDataModuleT = TypeVar("AnomalyDataModuleT", bound=AnomalyDataModule)


class AnomalibDetection(Generic[AnomalyDataModuleT], LightningTask[AnomalyDataModuleT]):
    """Anomaly Detection Task.

    Args:
        config: The experiment configuration
        module_function: The function that instantiates the module and model
        checkpoint_path: The path to the checkpoint to load the model from.
            Defaults to None.
        run_test: Whether to run the test after training. Defaults to False.
        report: Whether to report the results. Defaults to False.
    """

    def __init__(
        self,
        config: DictConfig,
        module_function: DictConfig,
        checkpoint_path: str | None = None,
        run_test: bool = True,
        report: bool = True,
    ):
        super().__init__(
            config=config,
            checkpoint_path=checkpoint_path,
            run_test=run_test,
            report=report,
        )
        self._module: AnomalyModule
        self.module_function = module_function
        self.export_folder = "deployment_model"
        self.report_path = ""
        self.test_results: list[dict] | None = None

    @property
    def module(self) -> AnomalyModule:
        """Get the module."""
        return self._module

    @module.setter
    def module(self, module_config):
        """Set the module."""
        if hasattr(self.config.model.model, "input_size"):
            transform_height = self.config.transforms.input_height
            transform_width = self.config.transforms.input_width
            original_model_height, original_model_width = self.config.model.model.input_size

            if transform_height != original_model_height or transform_width != original_model_width:
                log.warning(
                    "Model input size %dx%d "
                    "does not match the transform size %dx%d. "
                    "The model input size will be updated to match the transform size.",
                    original_model_height,
                    original_model_width,
                    transform_height,
                    transform_width,
                )
                self.config.model.model.input_size = [transform_height, transform_width]

        _module = cast(
            AnomalyModule,
            hydra.utils.instantiate(
                self.module_function,
                module_config,
            ),
        )

        self._module = _module

    def prepare(self) -> None:
        """Prepare the task."""
        super().prepare()
        self.module = self.config.model
        self.module.model = ModelSignatureWrapper(self.module.model)

    def export(self) -> None:
        """Export model for production."""
        if self.config.trainer.get("fast_dev_run"):
            log.warning("Skipping export since fast_dev_run is enabled")
            return

        model = self.module.model

        input_shapes = self.config.export.input_shapes

        half_precision = "16" in self.trainer.precision

        model_json, export_paths = export_model(
            config=self.config,
            model=model,
            export_folder=self.export_folder,
            half_precision=half_precision,
            input_shapes=input_shapes,
            idx_to_class={0: "good", 1: "defect"},
        )

        if len(export_paths) == 0:
            return

        model_json["image_threshold"] = np.round(self.module.image_threshold.value.item(), 3)
        model_json["pixel_threshold"] = np.round(self.module.pixel_threshold.value.item(), 3)
        model_json["anomaly_method"] = self.config.model.model.name

        with open(os.path.join(self.export_folder, "model.json"), "w") as f:
            json.dump(model_json, f, cls=utils.HydraEncoder)

    def test(self) -> Any:
        """Lightning test."""
        self.test_results = super().test()
        return self.test_results

    def _generate_report(self) -> None:
        """Generate a report for the task."""
        if len(self.report_path) > 0:
            os.makedirs(self.report_path, exist_ok=True)

        # Save json with test results
        if self.test_results is not None:
            with open(os.path.join(self.report_path, "test_results.json"), "w") as f:
                json.dump(self.test_results[0], f)

        all_output = cast(
            list[dict], self.trainer.predict(model=self.module, dataloaders=self.datamodule.test_dataloader())
        )
        all_output_flatten: dict[str, torch.Tensor | list] = {}

        for key in all_output[0]:
            if type(all_output[0][key]) == torch.Tensor:
                tensor_gatherer = torch.cat([x[key] for x in all_output])
                all_output_flatten[key] = tensor_gatherer
            else:
                list_gatherer = []
                for x in all_output:
                    list_gatherer.extend(x[key])
                all_output_flatten[key] = list_gatherer

        image_paths = all_output_flatten["image_path"]
        named_labels = [x.split("/")[-2] for x in all_output_flatten["image_path"]]

        class_to_idx = {"good": 0}
        idx = 1
        for cls in set(named_labels):
            if cls == "good":
                continue

            class_to_idx[cls] = idx
            idx += 1

        class_to_idx["false_defect"] = idx
        idx_to_class = {v: k for k, v in class_to_idx.items()}

        gt_labels = [class_to_idx[x] for x in named_labels]
        pred_labels = []
        for i, _ in enumerate(named_labels):
            pred_label = all_output_flatten["pred_labels"][i].item()

            if pred_label == 0:
                pred_labels.append(0)
            elif pred_label == 1 and gt_labels[i] == 0:
                if idx > 2:
                    pred_labels.append(class_to_idx["false_defect"])
                else:
                    pred_labels.append(1)
            else:
                pred_labels.append(class_to_idx[named_labels[i]])

        if class_to_idx["false_defect"] not in pred_labels:
            # If there are no false defects remove the label from the confusion matrix
            class_to_idx.pop("false_defect")

        anomaly_scores = all_output_flatten["pred_scores"]
        if isinstance(anomaly_scores, torch.Tensor):
            exportable_anomaly_scores = anomaly_scores.cpu().numpy()
        else:
            exportable_anomaly_scores = anomaly_scores

        # Zip the lists together to create rows for the CSV file
        rows = zip(image_paths, pred_labels, gt_labels, exportable_anomaly_scores)
        # Specify the CSV file name
        csv_file = "test_predictions.csv"
        # Write the data to the CSV file
        with open(csv_file, mode="w", newline="") as file:
            writer = csv.writer(file)
            # Write the header if needed
            writer.writerow(["image_path", "predicted_label", "ground_truth_label", "predicted_score"])
            # Write the rows
            writer.writerows(rows)

        log.info("CSV file %s has been created.", csv_file)

        if not isinstance(anomaly_scores, torch.Tensor):
            raise ValueError("Anomaly scores must be a tensor")

        good_scores = anomaly_scores[np.where(all_output_flatten["label"] == 0)]
        defect_scores = anomaly_scores[np.where(all_output_flatten["label"] == 1)]

        # Lightning has a callback attribute but is not inside the __init__ so mypy complains
        if any(
            isinstance(x, MinMaxNormalizationCallback)
            for x in self.trainer.callbacks  # type: ignore[attr-defined]
        ):
            threshold = torch.tensor(0.5)
        elif any(
            isinstance(x, ThresholdNormalizationCallback)
            for x in self.trainer.callbacks  # type: ignore[attr-defined]
        ):
            threshold = torch.tensor(100.0)
        else:
            threshold = self.module.image_metrics.F1Score.threshold

        # The output of the prediction is a normalized score so the cumulative histogram is displayed with the
        # normalized scores
        plot_cumulative_histogram(
            good_scores.cpu().numpy(), defect_scores.cpu().numpy(), threshold.item(), self.report_path
        )

        _, pd_cm, _ = get_results(np.array(gt_labels), np.array(pred_labels), idx_to_class)
        np_cm = np.array(pd_cm)
        disp = ConfusionMatrixDisplay(
            confusion_matrix=np_cm,
            display_labels=class_to_idx.keys(),
        )
        disp.plot(include_values=True, cmap=plt.cm.Greens, ax=None, colorbar=False, xticks_rotation=90)
        plt.title("Confusion Matrix")
        plt.savefig(
            os.path.join(self.report_path, "test_confusion_matrix.png"), bbox_inches="tight", pad_inches=0, dpi=300
        )
        plt.close()

        avg_score_dict = {k: 0.0 for k in set(named_labels)}

        for i, item in enumerate(named_labels):
            avg_score_dict[item] += all_output_flatten["pred_scores"][i].item()

        counter = Counter(named_labels)
        avg_score_dict = {k: v / counter[k] for k, v in avg_score_dict.items()}
        avg_score_dict = dict(sorted(avg_score_dict.items(), key=lambda q: q[1]))

        with open(os.path.join(self.report_path, "avg_score_by_label.csv"), "w") as f:
            f.write("label,avg_anomaly_score\n")
            for k, v in avg_score_dict.items():
                f.write(f"{k},{v:.3f}\n")

    def generate_report(self):
        """Generate a report for the task and try to upload artifacts."""
        self._generate_report()
        self._upload_artifacts()

    def _upload_artifacts(self):
        """If MLflow is available upload artifacts to the artifact repository."""
        mflow_logger = get_mlflow_logger(trainer=self.trainer)
        tensorboard_logger = utils.get_tensorboard_logger(trainer=self.trainer)

        if mflow_logger is not None and self.config.core.get("upload_artifacts"):
            mflow_logger.experiment.log_artifact(run_id=mflow_logger.run_id, local_path="test_confusion_matrix.png")
            mflow_logger.experiment.log_artifact(run_id=mflow_logger.run_id, local_path="avg_score_by_label.csv")

            if "visualizer" in self.config.callbacks:
                artifacts = glob.glob(os.path.join(self.config.callbacks.visualizer.output_path, "**", "*"))
                for a in artifacts:
                    mflow_logger.experiment.log_artifact(
                        run_id=mflow_logger.run_id, local_path=a, artifact_path="anomaly_output"
                    )

        if tensorboard_logger is not None and self.config.core.get("upload_artifacts"):
            artifacts = []
            artifacts.append("test_confusion_matrix.png")
            artifacts.append("avg_score_by_label.csv")

            if "visualizer" in self.config.callbacks:
                artifacts.extend(
                    glob.glob(os.path.join(self.config.callbacks.visualizer.output_path, "**/*"), recursive=True)
                )

            for a in artifacts:
                if os.path.isdir(a):
                    continue

                ext = os.path.splitext(a)[1].lower()

                if ext in [".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".tif", ".gif"]:
                    try:
                        img = cv2.imread(a)
                        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                    except cv2.error:
                        log.info("Could not upload artifact image %s", a)
                        continue
                    output_path = os.path.sep.join(a.split(os.path.sep)[-2:])
                    tensorboard_logger.experiment.add_image(output_path, img, 0, dataformats="HWC")
                else:
                    utils.upload_file_tensorboard(a, tensorboard_logger)


class AnomalibEvaluation(Evaluation[AnomalyDataModule]):
    """Evaluation task for Anomalib.

    Args:
        config: Task configuration
        model_path: Path to the model folder that contains an exported model
        use_training_threshold: Whether to use the training threshold for the evaluation or use the one that
            maximizes the F1 score on the test set.
        device: Device to use for evaluation. If None, the device is automatically determined.
    """

    def __init__(
        self,
        config: DictConfig,
        model_path: str,
        use_training_threshold: bool = False,
        device: str | None = None,
        training_threshold_type: Literal["image", "pixel"] | None = None,
    ):
        super().__init__(config=config, model_path=model_path, device=device)

        self.use_training_threshold = use_training_threshold

        if training_threshold_type is not None and training_threshold_type not in ["image", "pixel"]:
            raise ValueError("Training threshold type must be either image or pixel")

        if training_threshold_type is None and use_training_threshold:
            log.warning("Using training threshold but no training threshold type is provided, defaulting to image")
            training_threshold_type = "image"

        self.training_threshold_type = training_threshold_type

    def prepare(self) -> None:
        """Prepare the evaluation."""
        super().prepare()
        self.datamodule = self.config.datamodule
        # prepare_data() must be explicitly called because there is no lightning training
        self.datamodule.prepare_data()
        self.datamodule.setup(stage="test")

    @automatic_datamodule_batch_size(batch_size_attribute_name="test_batch_size")
    def test(self) -> None:
        """Perform test."""
        log.info("Running test")
        test_dataloader = self.datamodule.test_dataloader()

        optimal_f1 = OptimalF1(num_classes=None, pos_label=1)  # type: ignore[arg-type]

        anomaly_scores = []
        anomaly_maps = []
        image_labels = []
        image_paths = []

        with torch.no_grad():
            for batch_item in tqdm(test_dataloader):
                batch_images = batch_item["image"]
                batch_labels = batch_item["label"]
                image_labels.extend(batch_labels.tolist())
                image_paths.extend(batch_item["image_path"])
                batch_images = batch_images.to(device=self.device, dtype=self.deployment_model.model_dtype)
                if self.model_data.get("anomaly_method") == "efficientad":
                    model_output = self.deployment_model(batch_images, None)
                else:
                    model_output = self.deployment_model(batch_images)
                anomaly_map, anomaly_score = model_output[0], model_output[1]
                anomaly_map = anomaly_map.cpu()
                anomaly_score = anomaly_score.cpu()
                known_labels = torch.where(batch_labels != -1)[0]
                if len(known_labels) > 0:
                    # Skip computing F1 score for images without gt
                    optimal_f1.update(anomaly_score[known_labels], batch_labels[known_labels])
                anomaly_scores.append(anomaly_score)
                anomaly_maps.append(anomaly_map)

        anomaly_scores = torch.cat(anomaly_scores)
        anomaly_maps = torch.cat(anomaly_maps)

        if any(x != -1 for x in image_labels):
            if self.use_training_threshold:
                _image_labels = torch.tensor(image_labels)
                threshold = torch.tensor(float(self.model_data[f"{self.training_threshold_type}_threshold"]))
                known_labels = torch.where(_image_labels != -1)[0]

                _image_labels = _image_labels[known_labels]
                _anomaly_scores = anomaly_scores[known_labels]

                pred_labels = (_anomaly_scores >= threshold).long()

                optimal_f1_score = torch.tensor(f1_score(_image_labels, pred_labels))
            else:
                optimal_f1_score = optimal_f1.compute()
                threshold = optimal_f1.threshold
        else:
            log.warning("No ground truth available during evaluation, use training image threshold for reporting")
            optimal_f1_score = torch.tensor(0)
            threshold = torch.tensor(float(self.model_data["image_threshold"]))

        log.info("Computed F1 score: %s", optimal_f1_score.item())
        self.metadata["anomaly_scores"] = anomaly_scores
        self.metadata["anomaly_maps"] = anomaly_maps
        self.metadata["image_labels"] = image_labels
        self.metadata["image_paths"] = image_paths
        self.metadata["threshold"] = threshold.item()
        self.metadata["optimal_f1"] = optimal_f1_score.item()

    def generate_report(self) -> None:
        """Generate report."""
        log.info("Generating report")
        if len(self.report_path) > 0:
            os.makedirs(self.report_path, exist_ok=True)

        # TODO: We currently don't use anomaly for segmentation, so the pixel threshold handling is not properly
        # implemented and we produce as output only a single threshold.
        training_threshold = self.model_data[f"{self.training_threshold_type}_threshold"]
        optimal_threshold = self.metadata["threshold"]

        normalized_optimal_threshold = cast(float, normalize_anomaly_score(optimal_threshold, training_threshold))

        os.makedirs(os.path.join(self.report_path, "predictions"), exist_ok=True)
        os.makedirs(os.path.join(self.report_path, "heatmaps"), exist_ok=True)

        anomaly_scores = self.metadata["anomaly_scores"].cpu().numpy()
        anomaly_scores = normalize_anomaly_score(anomaly_scores, training_threshold)

        if not isinstance(anomaly_scores, np.ndarray):
            raise ValueError("Anomaly scores must be a numpy array")

        good_scores = anomaly_scores[np.where(np.array(self.metadata["image_labels"]) == 0)]
        defect_scores = anomaly_scores[np.where(np.array(self.metadata["image_labels"]) == 1)]

        count_overlapping_scores = 0

        if len(good_scores) != 0 and len(defect_scores) != 0 and defect_scores.min() <= good_scores.max():
            count_overlapping_scores = len(
                np.where((anomaly_scores >= defect_scores.min()) & (anomaly_scores <= good_scores.max()))[0]
            )

        plot_cumulative_histogram(good_scores, defect_scores, normalized_optimal_threshold, self.report_path)

        json_output = {
            "observations": [],
            "threshold": np.round(normalized_optimal_threshold, 3),
            "unnormalized_threshold": np.round(optimal_threshold, 3),
            "f1_score": np.round(self.metadata["optimal_f1"], 3),
            "metrics": {
                "overlapping_scores": count_overlapping_scores,
            },
        }

        tg, fb, fg, tb = 0, 0, 0, 0

        mask_area = None
        crop_area = None

        if hasattr(self.datamodule, "valid_area_mask") and self.datamodule.valid_area_mask is not None:
            mask_area = cv2.imread(self.datamodule.valid_area_mask, 0)
            mask_area = (mask_area > 0).astype(np.uint8)  # type: ignore[operator]

        if hasattr(self.datamodule, "crop_area") and self.datamodule.crop_area is not None:
            crop_area = self.datamodule.crop_area

        anomaly_maps = normalize_anomaly_score(self.metadata["anomaly_maps"], training_threshold)

        if not isinstance(anomaly_maps, torch.Tensor):
            raise ValueError("Anomaly maps must be a tensor")

        for img_path, gt_label, anomaly_score, anomaly_map in tqdm(
            zip(
                self.metadata["image_paths"],
                self.metadata["image_labels"],
                anomaly_scores,
                anomaly_maps,
            ),
            total=len(self.metadata["image_paths"]),
        ):
            img = cv2.imread(img_path, 0)
            if mask_area is not None:
                img = img * mask_area  # type: ignore[operator]

            if crop_area is not None:
                img = img[crop_area[1] : crop_area[3], crop_area[0] : crop_area[2]]

            output_mask = (anomaly_map >= normalized_optimal_threshold).cpu().numpy().squeeze().astype(np.uint8)
            output_mask_label = os.path.basename(os.path.dirname(img_path))
            output_mask_name = os.path.splitext(os.path.basename(img_path))[0] + ".png"
            pred_label = int(anomaly_score >= normalized_optimal_threshold)

            json_output["observations"].append(
                {
                    "image_path": os.path.dirname(img_path),
                    "file_name": os.path.basename(img_path),
                    "expectation": gt_label if gt_label != -1 else "",
                    "prediction": pred_label,
                    "prediction_mask": os.path.join("predictions", output_mask_label, output_mask_name),
                    "prediction_heatmap": os.path.join("heatmaps", output_mask_label, output_mask_name),
                    "is_correct": pred_label == gt_label if gt_label != -1 else True,
                    "anomaly_score": f"{anomaly_score.item():.3f}",
                }
            )

            if gt_label == 0 and pred_label == 0:
                tg += 1
            elif gt_label == 0 and pred_label == 1:
                fb += 1
            elif gt_label == 1 and pred_label == 0:
                fg += 1
            elif gt_label == 1 and pred_label == 1:
                tb += 1

            output_mask = output_mask * 255
            output_mask = cv2.resize(output_mask, (img.shape[1], img.shape[0]))
            output_prediction_folder = os.path.join(self.report_path, "predictions", output_mask_label)
            os.makedirs(output_prediction_folder, exist_ok=True)
            cv2.imwrite(os.path.join(output_prediction_folder, output_mask_name), output_mask)

            # Normalize the map and rescale it to 0-1 range
            # In this case we are saying that the anomaly map is in the range [normalized_th - 50, normalized_th + 50]
            # This allow to have a stronger color for the anomalies and a lighter one for really normal regions
            # It's also independent from the max or min anomaly score!
            normalized_map: MapOrValue = (anomaly_map - (normalized_optimal_threshold - 50)) / 100

            if isinstance(normalized_map, torch.Tensor):
                normalized_map = normalized_map.cpu().numpy().squeeze()

            normalized_map = np.clip(normalized_map, 0, 1)
            output_heatmap = anomaly_map_to_color_map(normalized_map, normalize=False)
            output_heatmap = cv2.resize(output_heatmap, (img.shape[1], img.shape[0]))

            output_heatmap_folder = os.path.join(self.report_path, "heatmaps", output_mask_label)
            os.makedirs(output_heatmap_folder, exist_ok=True)

            cv2.imwrite(
                os.path.join(output_heatmap_folder, output_mask_name),
                cv2.cvtColor(output_heatmap, cv2.COLOR_RGB2BGR),
            )

        json_output["metrics"]["confusion_matrix"] = {
            "class_labels": ["normal", "anomaly"],
            "matrix": [
                [tg, fb],
                [fg, tb],
            ],
        }

        with open(os.path.join(self.report_path, "anomaly_test_output.json"), "w") as f:
            json.dump(json_output, f)

    def execute(self) -> None:
        """Execute the evaluation."""
        self.prepare()
        self.test()
        self.generate_report()
        self.finalize()