quadra 0.0.1__py3-none-any.whl → 2.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hydra_plugins/quadra_searchpath_plugin.py +30 -0
- quadra/__init__.py +6 -0
- quadra/callbacks/__init__.py +0 -0
- quadra/callbacks/anomalib.py +289 -0
- quadra/callbacks/lightning.py +501 -0
- quadra/callbacks/mlflow.py +291 -0
- quadra/callbacks/scheduler.py +69 -0
- quadra/configs/__init__.py +0 -0
- quadra/configs/backbone/caformer_m36.yaml +8 -0
- quadra/configs/backbone/caformer_s36.yaml +8 -0
- quadra/configs/backbone/convnextv2_base.yaml +8 -0
- quadra/configs/backbone/convnextv2_femto.yaml +8 -0
- quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
- quadra/configs/backbone/dino_vitb8.yaml +12 -0
- quadra/configs/backbone/dino_vits8.yaml +12 -0
- quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
- quadra/configs/backbone/dinov2_vits14.yaml +12 -0
- quadra/configs/backbone/efficientnet_b0.yaml +8 -0
- quadra/configs/backbone/efficientnet_b1.yaml +8 -0
- quadra/configs/backbone/efficientnet_b2.yaml +8 -0
- quadra/configs/backbone/efficientnet_b3.yaml +8 -0
- quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
- quadra/configs/backbone/levit_128s.yaml +8 -0
- quadra/configs/backbone/mnasnet0_5.yaml +9 -0
- quadra/configs/backbone/resnet101.yaml +8 -0
- quadra/configs/backbone/resnet18.yaml +8 -0
- quadra/configs/backbone/resnet18_ssl.yaml +8 -0
- quadra/configs/backbone/resnet50.yaml +8 -0
- quadra/configs/backbone/smp.yaml +9 -0
- quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
- quadra/configs/backbone/unetr.yaml +15 -0
- quadra/configs/backbone/vit16_base.yaml +9 -0
- quadra/configs/backbone/vit16_small.yaml +9 -0
- quadra/configs/backbone/vit16_tiny.yaml +9 -0
- quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
- quadra/configs/callbacks/all.yaml +45 -0
- quadra/configs/callbacks/default.yaml +34 -0
- quadra/configs/callbacks/default_anomalib.yaml +64 -0
- quadra/configs/config.yaml +33 -0
- quadra/configs/core/default.yaml +11 -0
- quadra/configs/datamodule/base/anomaly.yaml +16 -0
- quadra/configs/datamodule/base/classification.yaml +21 -0
- quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
- quadra/configs/datamodule/base/segmentation.yaml +18 -0
- quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
- quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
- quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
- quadra/configs/datamodule/base/ssl.yaml +21 -0
- quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
- quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
- quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
- quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
- quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
- quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
- quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
- quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
- quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
- quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
- quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/base/classification/classification.yaml +73 -0
- quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
- quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
- quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
- quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
- quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
- quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
- quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
- quadra/configs/experiment/base/ssl/byol.yaml +43 -0
- quadra/configs/experiment/base/ssl/dino.yaml +46 -0
- quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
- quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
- quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
- quadra/configs/experiment/custom/cls.yaml +12 -0
- quadra/configs/experiment/default.yaml +15 -0
- quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
- quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
- quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
- quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
- quadra/configs/export/default.yaml +13 -0
- quadra/configs/hydra/anomaly_custom.yaml +15 -0
- quadra/configs/hydra/default.yaml +14 -0
- quadra/configs/inference/default.yaml +26 -0
- quadra/configs/logger/comet.yaml +10 -0
- quadra/configs/logger/csv.yaml +5 -0
- quadra/configs/logger/mlflow.yaml +12 -0
- quadra/configs/logger/tensorboard.yaml +8 -0
- quadra/configs/loss/asl.yaml +7 -0
- quadra/configs/loss/barlow.yaml +2 -0
- quadra/configs/loss/bce.yaml +1 -0
- quadra/configs/loss/byol.yaml +1 -0
- quadra/configs/loss/cross_entropy.yaml +1 -0
- quadra/configs/loss/dino.yaml +8 -0
- quadra/configs/loss/simclr.yaml +2 -0
- quadra/configs/loss/simsiam.yaml +1 -0
- quadra/configs/loss/smp_ce.yaml +3 -0
- quadra/configs/loss/smp_dice.yaml +2 -0
- quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
- quadra/configs/loss/smp_mcc.yaml +2 -0
- quadra/configs/loss/vicreg.yaml +5 -0
- quadra/configs/model/anomalib/cfa.yaml +35 -0
- quadra/configs/model/anomalib/cflow.yaml +30 -0
- quadra/configs/model/anomalib/csflow.yaml +34 -0
- quadra/configs/model/anomalib/dfm.yaml +19 -0
- quadra/configs/model/anomalib/draem.yaml +29 -0
- quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
- quadra/configs/model/anomalib/fastflow.yaml +32 -0
- quadra/configs/model/anomalib/padim.yaml +32 -0
- quadra/configs/model/anomalib/patchcore.yaml +36 -0
- quadra/configs/model/barlow.yaml +16 -0
- quadra/configs/model/byol.yaml +25 -0
- quadra/configs/model/classification.yaml +10 -0
- quadra/configs/model/dino.yaml +26 -0
- quadra/configs/model/logistic_regression.yaml +4 -0
- quadra/configs/model/multilabel_classification.yaml +9 -0
- quadra/configs/model/simclr.yaml +18 -0
- quadra/configs/model/simsiam.yaml +24 -0
- quadra/configs/model/smp.yaml +4 -0
- quadra/configs/model/smp_multiclass.yaml +4 -0
- quadra/configs/model/vicreg.yaml +16 -0
- quadra/configs/optimizer/adam.yaml +5 -0
- quadra/configs/optimizer/adamw.yaml +3 -0
- quadra/configs/optimizer/default.yaml +4 -0
- quadra/configs/optimizer/lars.yaml +8 -0
- quadra/configs/optimizer/sgd.yaml +4 -0
- quadra/configs/scheduler/default.yaml +5 -0
- quadra/configs/scheduler/rop.yaml +5 -0
- quadra/configs/scheduler/step.yaml +3 -0
- quadra/configs/scheduler/warmrestart.yaml +2 -0
- quadra/configs/scheduler/warmup.yaml +6 -0
- quadra/configs/task/anomalib/cfa.yaml +5 -0
- quadra/configs/task/anomalib/cflow.yaml +5 -0
- quadra/configs/task/anomalib/csflow.yaml +5 -0
- quadra/configs/task/anomalib/draem.yaml +5 -0
- quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
- quadra/configs/task/anomalib/fastflow.yaml +5 -0
- quadra/configs/task/anomalib/inference.yaml +3 -0
- quadra/configs/task/anomalib/padim.yaml +5 -0
- quadra/configs/task/anomalib/patchcore.yaml +5 -0
- quadra/configs/task/classification.yaml +6 -0
- quadra/configs/task/classification_evaluation.yaml +6 -0
- quadra/configs/task/default.yaml +1 -0
- quadra/configs/task/segmentation.yaml +9 -0
- quadra/configs/task/segmentation_evaluation.yaml +3 -0
- quadra/configs/task/sklearn_classification.yaml +13 -0
- quadra/configs/task/sklearn_classification_patch.yaml +11 -0
- quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
- quadra/configs/task/sklearn_classification_test.yaml +8 -0
- quadra/configs/task/ssl.yaml +2 -0
- quadra/configs/trainer/lightning_cpu.yaml +36 -0
- quadra/configs/trainer/lightning_gpu.yaml +35 -0
- quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
- quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
- quadra/configs/trainer/lightning_multigpu.yaml +37 -0
- quadra/configs/trainer/sklearn_classification.yaml +7 -0
- quadra/configs/transforms/byol.yaml +47 -0
- quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
- quadra/configs/transforms/default.yaml +37 -0
- quadra/configs/transforms/default_numpy.yaml +24 -0
- quadra/configs/transforms/default_resize.yaml +22 -0
- quadra/configs/transforms/dino.yaml +63 -0
- quadra/configs/transforms/linear_eval.yaml +18 -0
- quadra/datamodules/__init__.py +20 -0
- quadra/datamodules/anomaly.py +180 -0
- quadra/datamodules/base.py +375 -0
- quadra/datamodules/classification.py +1003 -0
- quadra/datamodules/generic/__init__.py +0 -0
- quadra/datamodules/generic/imagenette.py +144 -0
- quadra/datamodules/generic/mnist.py +81 -0
- quadra/datamodules/generic/mvtec.py +58 -0
- quadra/datamodules/generic/oxford_pet.py +163 -0
- quadra/datamodules/patch.py +190 -0
- quadra/datamodules/segmentation.py +742 -0
- quadra/datamodules/ssl.py +140 -0
- quadra/datasets/__init__.py +17 -0
- quadra/datasets/anomaly.py +287 -0
- quadra/datasets/classification.py +241 -0
- quadra/datasets/patch.py +138 -0
- quadra/datasets/segmentation.py +239 -0
- quadra/datasets/ssl.py +110 -0
- quadra/losses/__init__.py +0 -0
- quadra/losses/classification/__init__.py +6 -0
- quadra/losses/classification/asl.py +83 -0
- quadra/losses/classification/focal.py +320 -0
- quadra/losses/classification/prototypical.py +148 -0
- quadra/losses/ssl/__init__.py +17 -0
- quadra/losses/ssl/barlowtwins.py +47 -0
- quadra/losses/ssl/byol.py +37 -0
- quadra/losses/ssl/dino.py +129 -0
- quadra/losses/ssl/hyperspherical.py +45 -0
- quadra/losses/ssl/idmm.py +50 -0
- quadra/losses/ssl/simclr.py +67 -0
- quadra/losses/ssl/simsiam.py +30 -0
- quadra/losses/ssl/vicreg.py +76 -0
- quadra/main.py +49 -0
- quadra/metrics/__init__.py +3 -0
- quadra/metrics/segmentation.py +251 -0
- quadra/models/__init__.py +0 -0
- quadra/models/base.py +151 -0
- quadra/models/classification/__init__.py +8 -0
- quadra/models/classification/backbones.py +149 -0
- quadra/models/classification/base.py +92 -0
- quadra/models/evaluation.py +322 -0
- quadra/modules/__init__.py +0 -0
- quadra/modules/backbone.py +30 -0
- quadra/modules/base.py +312 -0
- quadra/modules/classification/__init__.py +3 -0
- quadra/modules/classification/base.py +327 -0
- quadra/modules/ssl/__init__.py +17 -0
- quadra/modules/ssl/barlowtwins.py +59 -0
- quadra/modules/ssl/byol.py +172 -0
- quadra/modules/ssl/common.py +285 -0
- quadra/modules/ssl/dino.py +186 -0
- quadra/modules/ssl/hyperspherical.py +206 -0
- quadra/modules/ssl/idmm.py +98 -0
- quadra/modules/ssl/simclr.py +73 -0
- quadra/modules/ssl/simsiam.py +68 -0
- quadra/modules/ssl/vicreg.py +67 -0
- quadra/optimizers/__init__.py +4 -0
- quadra/optimizers/lars.py +153 -0
- quadra/optimizers/sam.py +127 -0
- quadra/schedulers/__init__.py +3 -0
- quadra/schedulers/base.py +44 -0
- quadra/schedulers/warmup.py +127 -0
- quadra/tasks/__init__.py +24 -0
- quadra/tasks/anomaly.py +582 -0
- quadra/tasks/base.py +397 -0
- quadra/tasks/classification.py +1263 -0
- quadra/tasks/patch.py +492 -0
- quadra/tasks/segmentation.py +389 -0
- quadra/tasks/ssl.py +560 -0
- quadra/trainers/README.md +3 -0
- quadra/trainers/__init__.py +0 -0
- quadra/trainers/classification.py +179 -0
- quadra/utils/__init__.py +0 -0
- quadra/utils/anomaly.py +112 -0
- quadra/utils/classification.py +618 -0
- quadra/utils/deprecation.py +31 -0
- quadra/utils/evaluation.py +474 -0
- quadra/utils/export.py +585 -0
- quadra/utils/imaging.py +32 -0
- quadra/utils/logger.py +15 -0
- quadra/utils/mlflow.py +98 -0
- quadra/utils/model_manager.py +320 -0
- quadra/utils/models.py +523 -0
- quadra/utils/patch/__init__.py +15 -0
- quadra/utils/patch/dataset.py +1433 -0
- quadra/utils/patch/metrics.py +449 -0
- quadra/utils/patch/model.py +153 -0
- quadra/utils/patch/visualization.py +217 -0
- quadra/utils/resolver.py +42 -0
- quadra/utils/segmentation.py +31 -0
- quadra/utils/tests/__init__.py +0 -0
- quadra/utils/tests/fixtures/__init__.py +1 -0
- quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
- quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
- quadra/utils/tests/fixtures/dataset/classification.py +406 -0
- quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
- quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
- quadra/utils/tests/fixtures/models/__init__.py +3 -0
- quadra/utils/tests/fixtures/models/anomaly.py +89 -0
- quadra/utils/tests/fixtures/models/classification.py +45 -0
- quadra/utils/tests/fixtures/models/segmentation.py +33 -0
- quadra/utils/tests/helpers.py +70 -0
- quadra/utils/tests/models.py +27 -0
- quadra/utils/utils.py +525 -0
- quadra/utils/validator.py +115 -0
- quadra/utils/visualization.py +422 -0
- quadra/utils/vit_explainability.py +349 -0
- quadra-2.2.7.dist-info/LICENSE +201 -0
- quadra-2.2.7.dist-info/METADATA +381 -0
- quadra-2.2.7.dist-info/RECORD +300 -0
- {quadra-0.0.1.dist-info → quadra-2.2.7.dist-info}/WHEEL +1 -1
- quadra-2.2.7.dist-info/entry_points.txt +3 -0
- quadra-0.0.1.dist-info/METADATA +0 -14
- quadra-0.0.1.dist-info/RECORD +0 -4
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
import dotenv
|
|
4
|
+
from hydra.core.config_search_path import ConfigSearchPath
|
|
5
|
+
from hydra.plugins.search_path_plugin import SearchPathPlugin
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class QuadraSearchPathPlugin(SearchPathPlugin):
    """Generic Search Path Plugin class.

    On construction, environment variables are loaded from a ``.env`` file located in the current working
    directory (if it exists), overriding variables that are already set. The extra config search paths are then
    read from the ``QUADRA_SEARCH_PATH`` environment variable in :meth:`manipulate_search_path`.
    """

    def __init__(self):
        try:
            cwd = os.getcwd()
        except FileNotFoundError:
            # The working directory may not exist anymore; this may happen when running tests
            return

        dotenv_path = os.path.join(cwd, ".env")
        if os.path.exists(dotenv_path):
            # override=True: values from the .env file take precedence over the current environment
            dotenv.load_dotenv(dotenv_path=dotenv_path, override=True)

    def manipulate_search_path(self, search_path: ConfigSearchPath) -> None:
        """Plugin used to add custom config to searchpath to be discovered by quadra.

        The ``QUADRA_SEARCH_PATH`` environment variable (global or taken from the ``.env`` file) must contain a list
        of hydra search paths separated by ";", e.g. ``pkg://package1.configs;file:///path/to/configs``.
        Each non-empty entry is appended to ``search_path`` under a unique provider name.

        Args:
            search_path: Hydra config search path to extend.
        """
        quadra_search_path = os.environ.get("QUADRA_SEARCH_PATH", None)
        if quadra_search_path is None:
            return

        # Skip empty entries (e.g. a trailing ";" or an empty variable) and surrounding whitespace,
        # otherwise they would be registered as bogus search paths
        paths = [path.strip() for path in quadra_search_path.split(";") if path.strip()]
        for i, path in enumerate(paths):
            search_path.append(provider=f"quadra-searchpath-plugin-{i}", path=path)
|
quadra/__init__.py
CHANGED
|
File without changes
|
|
@@ -0,0 +1,289 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
import cv2
|
|
7
|
+
import matplotlib
|
|
8
|
+
import matplotlib.pyplot as plt
|
|
9
|
+
import numpy as np
|
|
10
|
+
import pytorch_lightning as pl
|
|
11
|
+
from anomalib.models.components.base import AnomalyModule
|
|
12
|
+
from anomalib.post_processing import (
|
|
13
|
+
add_anomalous_label,
|
|
14
|
+
add_normal_label,
|
|
15
|
+
compute_mask,
|
|
16
|
+
superimpose_anomaly_map,
|
|
17
|
+
)
|
|
18
|
+
from anomalib.pre_processing.transforms import Denormalize
|
|
19
|
+
from anomalib.utils.loggers import AnomalibWandbLogger
|
|
20
|
+
from pytorch_lightning import Callback
|
|
21
|
+
from pytorch_lightning.utilities.types import STEP_OUTPUT
|
|
22
|
+
from skimage.segmentation import mark_boundaries
|
|
23
|
+
from tqdm import tqdm
|
|
24
|
+
|
|
25
|
+
from quadra.utils.anomaly import MapOrValue
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class Visualizer:
    """Anomaly Visualization.

    The visualizer object is responsible for collating all the images passed to it into a single image. This can then
    either be logged by accessing the `figure` attribute or can be saved directly by calling `save()` method.

    Example:
        >>> visualizer = Visualizer()
        >>> visualizer.add_image(image=image, title="Image")
        >>> visualizer.generate()
        >>> visualizer.save(Path("output.png"))
        >>> visualizer.close()
    """

    def __init__(self) -> None:
        # Each entry is a dict with the keys "image", "title" and "color_map" (see `add_image`)
        self.images: list[dict] = []

        # Only declared here; they are assigned by `generate()`
        self.figure: matplotlib.figure.Figure
        self.axis: np.ndarray

    def add_image(self, image: np.ndarray, title: str, color_map: str | None = None):
        """Add image to figure.

        Args:
            image: Image which should be added to the figure.
            title: Image title shown on the plot.
            color_map: Name of matplotlib color map used to map scalar data to colours. Defaults to None.
        """
        image_data = {"image": image, "title": title, "color_map": color_map}
        self.images.append(image_data)

    def generate(self):
        """Generate the figure, with one subplot per added image, using the non-interactive Agg backend.

        The previously active pyplot backend is restored afterwards, even if plotting fails.

        Raises:
            ValueError: If no image has been added with `add_image`.
        """
        if not self.images:
            raise ValueError("No images to visualize, call `add_image` before `generate`")

        default_plt_backend = plt.get_backend()
        plt.switch_backend("Agg")
        try:
            num_cols = len(self.images)
            figure_size = (num_cols * 3, 3)
            self.figure, self.axis = plt.subplots(1, num_cols, figsize=figure_size)
            self.figure.subplots_adjust(right=0.9)

            # With a single column plt.subplots returns a bare Axes instead of an array of Axes
            axes = self.axis if num_cols > 1 else [self.axis]
            for axis, image_dict in zip(axes, self.images):
                axis.axes.xaxis.set_visible(False)
                axis.axes.yaxis.set_visible(False)
                axis.imshow(image_dict["image"], cmap=image_dict["color_map"], vmin=0, vmax=255)
                axis.title.set_text(image_dict["title"])
        finally:
            plt.switch_backend(default_plt_backend)

    def show(self):
        """Show image on a matplotlib figure."""
        self.figure.show()

    def save(self, filename: Path):
        """Save image.

        The parent directories of `filename` are created if they do not exist.

        Args:
            filename: Filename to save image
        """
        filename.parent.mkdir(parents=True, exist_ok=True)
        self.figure.savefig(filename, dpi=100)

    def close(self):
        """Close figure."""
        plt.close(self.figure)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
# TODO: This is a lot different from the 0.3.7 anomalib one
|
|
93
|
+
class VisualizerCallback(Callback):
    """Callback that visualizes the inference results of a model.

    The callback generates a figure showing the original image, the ground truth segmentation mask,
    the predicted error heat map, and the predicted segmentation mask.
    Figures are saved to the local filesystem under
    ``<output_path>/images/<ok|wrong>/<name of the image parent folder>/``.

    Args:
        task: either 'segmentation' or 'classification'
        output_path: location where the images will be saved.
        inputs_are_normalized: whether the input images are normalized (like when using MinMax or Threshold callback).
        threshold_type: Either 'pixel' or 'image'. If 'pixel', the threshold of the pixel-level F1 metric is used,
            otherwise the threshold of the image-level F1 metric is used.
        disable: whether to disable the callback.
        plot_only_wrong: whether to plot only the images that are not correctly predicted.
        plot_raw_outputs: Saves the raw images of the segmentation and heatmap output.
    """

    def __init__(
        self,
        task: str = "segmentation",
        output_path: str = "anomaly_output",
        inputs_are_normalized: bool = True,
        threshold_type: str = "pixel",
        disable: bool = False,
        plot_only_wrong: bool = False,
        plot_raw_outputs: bool = False,
    ) -> None:
        self.inputs_are_normalized = inputs_are_normalized
        self.output_path = output_path
        self.threshold_type = threshold_type
        self.disable = disable
        self.task = task
        self.plot_only_wrong = plot_only_wrong
        self.plot_raw_outputs = plot_raw_outputs

    def _add_images(self, visualizer: Visualizer, filename: Path, output_label_folder: str):
        """Save the figure of `visualizer` to local storage.

        The figure is written to
        ``<output_path>/images/<output_label_folder>/<parent folder name of filename>/<filename stem>.png``.

        Args:
            visualizer: Visualizer object from which the `figure` is saved.
            filename: Path of the input image. This name is used as name for the generated image.
            output_label_folder: ok if the image is correctly predicted or wrong if it is not
        """
        visualizer.save(
            Path(self.output_path)
            / "images"
            / output_label_folder
            / filename.parent.name
            / Path(filename.stem + ".png")
        )

    def _get_threshold(self, pl_module: AnomalyModule) -> Any:
        """Return the threshold of the pixel-level or image-level F1 metric, depending on `threshold_type`.

        Raises:
            AttributeError: If the selected F1Score metric has no ``threshold`` attribute.
        """
        metrics = pl_module.pixel_metrics if self.threshold_type == "pixel" else pl_module.image_metrics
        if not hasattr(metrics.F1Score, "threshold"):
            raise AttributeError("Metric has no threshold attribute")
        return metrics.F1Score.threshold

    def _save_raw_outputs(
        self, heatmap: np.ndarray, vis_img: np.ndarray, path_filename: Path, output_label_folder: str
    ) -> None:
        """Write the raw heatmap and segmentation overlay as BGR png files in a ``raw_outputs`` folder."""
        raw_folder = (
            Path(self.output_path) / "images" / output_label_folder / path_filename.parent.name / "raw_outputs"
        )
        raw_folder.mkdir(parents=True, exist_ok=True)
        for raw_output, raw_name in zip([heatmap, vis_img], ["heatmap", "segmentation"]):
            current_raw_output = raw_output
            if raw_name == "segmentation":
                # The segmentation overlay is scaled by 255 and cast to uint8 (assumes values in [0, 1])
                current_raw_output = (raw_output * 255).astype(np.uint8)
            # cv2 expects BGR channel order
            current_raw_output = cv2.cvtColor(current_raw_output, cv2.COLOR_RGB2BGR)
            raw_filename = raw_folder / Path(path_filename.stem + f"_{raw_name}.png")
            cv2.imwrite(str(raw_filename), current_raw_output)

    def on_test_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: AnomalyModule,
        outputs: STEP_OUTPUT | None,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Save a visualization for every image of the current test batch.

        Nothing is done if the callback is disabled or if `outputs` lacks any of the keys needed to build the
        visualization (e.g. in a classification scenario without masks).

        Args:
            trainer: Pytorch lightning trainer object (unused).
            pl_module: Lightning modules derived from BaseAnomalyLightning object as
                currently only they support logging images.
            outputs: Outputs of the current test step.
            batch: Input batch of the current test step (unused).
            batch_idx: Index of the current test batch (unused).
            dataloader_idx: Index of the dataloader that yielded the current batch (unused).

        Raises:
            AttributeError: If the selected F1Score metric has no threshold attribute.
            TypeError: If the threshold is not a float.
        """
        if self.disable:
            return

        assert outputs is not None and isinstance(outputs, dict)

        # All of these keys are read below, so all of them must be present
        required_keys = ("image_path", "image", "mask", "anomaly_maps", "label", "pred_labels", "pred_scores")
        if any(key not in outputs for key in required_keys):
            # I'm probably in the classification scenario so I can't use the visualizer
            return

        threshold = self._get_threshold(pl_module)

        batch_items = zip(
            outputs["image_path"],
            outputs["image"],
            outputs["mask"],
            outputs["anomaly_maps"],
            outputs["label"],
            outputs["pred_labels"],
            outputs["pred_scores"],
        )
        for (
            filename,
            image,
            true_mask,
            anomaly_map,
            gt_label,
            pred_label,
            anomaly_score,
        ) in tqdm(batch_items, total=len(outputs["image_path"])):
            denormalized_image = Denormalize()(image.cpu())
            current_true_mask = true_mask.cpu().numpy()
            current_anomaly_map = anomaly_map.cpu().numpy()
            # Normalize the map and rescale it to 0-1 range
            # In this case we are saying that the anomaly map is in the range [normalized_th - 50, normalized_th + 50]
            # This allow to have a stronger color for the anomalies and a lighter one for really normal regions
            # It's also independent from the max or min anomaly score!
            normalized_map: MapOrValue = (current_anomaly_map - (threshold - 50)) / 100
            normalized_map = np.clip(normalized_map, 0, 1)

            output_label_folder = "ok" if pred_label == gt_label else "wrong"

            if self.plot_only_wrong and output_label_folder == "ok":
                continue

            heatmap = superimpose_anomaly_map(
                normalized_map, denormalized_image, normalize=not self.inputs_are_normalized
            )

            if not isinstance(threshold, float):
                raise TypeError("Threshold should be float")
            pred_mask = compute_mask(current_anomaly_map, threshold)
            vis_img = mark_boundaries(denormalized_image, pred_mask, color=(1, 0, 0), mode="thick")
            visualizer = Visualizer()

            if self.task == "segmentation":
                visualizer.add_image(image=denormalized_image, title="Image")
                # "mask" is guaranteed to be in outputs by the required keys check above
                current_true_mask = current_true_mask * 255
                visualizer.add_image(image=current_true_mask, color_map="gray", title="Ground Truth")
                visualizer.add_image(image=heatmap, title="Predicted Heat Map")
                visualizer.add_image(image=pred_mask, color_map="gray", title="Predicted Mask")
                visualizer.add_image(image=vis_img, title="Segmentation Result")
            elif self.task == "classification":
                gt_im = add_anomalous_label(denormalized_image) if gt_label else add_normal_label(denormalized_image)
                visualizer.add_image(gt_im, title="Image/True label")
                if anomaly_score >= threshold:
                    image_classified = add_anomalous_label(heatmap, anomaly_score)
                else:
                    image_classified = add_normal_label(heatmap, 1 - anomaly_score)
                visualizer.add_image(image=image_classified, title="Prediction")

            visualizer.generate()
            path_filename = Path(filename)
            try:
                visualizer.figure.suptitle(
                    f"F1 threshold: {threshold}, Mask_max: {current_anomaly_map.max():.3f}, "
                    f"Anomaly_score: {anomaly_score:.3f}"
                )
                self._add_images(visualizer, path_filename, output_label_folder)
            finally:
                # Release the matplotlib figure even if saving fails, to avoid accumulating open figures
                visualizer.close()

            if self.plot_raw_outputs:
                self._save_raw_outputs(heatmap, vis_img, path_filename, output_label_folder)

    def on_test_end(self, _trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        """Sync logs.

        Currently only ``AnomalibWandbLogger`` is called from this method. This is because logging as a single batch
        ensures that all images appear as part of the same step.

        Args:
            _trainer: Pytorch Lightning trainer (unused)
            pl_module: Anomaly module
        """
        if self.disable:
            return

        # isinstance is False for None, so no separate None check is needed
        if isinstance(pl_module.logger, AnomalibWandbLogger):
            pl_module.logger.save()
|