quadra 0.0.1__py3-none-any.whl → 2.1.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (302) hide show
  1. hydra_plugins/quadra_searchpath_plugin.py +30 -0
  2. quadra/__init__.py +6 -0
  3. quadra/callbacks/__init__.py +0 -0
  4. quadra/callbacks/anomalib.py +289 -0
  5. quadra/callbacks/lightning.py +501 -0
  6. quadra/callbacks/mlflow.py +291 -0
  7. quadra/callbacks/scheduler.py +69 -0
  8. quadra/configs/__init__.py +0 -0
  9. quadra/configs/backbone/caformer_m36.yaml +8 -0
  10. quadra/configs/backbone/caformer_s36.yaml +8 -0
  11. quadra/configs/backbone/convnextv2_base.yaml +8 -0
  12. quadra/configs/backbone/convnextv2_femto.yaml +8 -0
  13. quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
  14. quadra/configs/backbone/dino_vitb8.yaml +12 -0
  15. quadra/configs/backbone/dino_vits8.yaml +12 -0
  16. quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
  17. quadra/configs/backbone/dinov2_vits14.yaml +12 -0
  18. quadra/configs/backbone/efficientnet_b0.yaml +8 -0
  19. quadra/configs/backbone/efficientnet_b1.yaml +8 -0
  20. quadra/configs/backbone/efficientnet_b2.yaml +8 -0
  21. quadra/configs/backbone/efficientnet_b3.yaml +8 -0
  22. quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
  23. quadra/configs/backbone/levit_128s.yaml +8 -0
  24. quadra/configs/backbone/mnasnet0_5.yaml +9 -0
  25. quadra/configs/backbone/resnet101.yaml +8 -0
  26. quadra/configs/backbone/resnet18.yaml +8 -0
  27. quadra/configs/backbone/resnet18_ssl.yaml +8 -0
  28. quadra/configs/backbone/resnet50.yaml +8 -0
  29. quadra/configs/backbone/smp.yaml +9 -0
  30. quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
  31. quadra/configs/backbone/unetr.yaml +15 -0
  32. quadra/configs/backbone/vit16_base.yaml +9 -0
  33. quadra/configs/backbone/vit16_small.yaml +9 -0
  34. quadra/configs/backbone/vit16_tiny.yaml +9 -0
  35. quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
  36. quadra/configs/callbacks/all.yaml +32 -0
  37. quadra/configs/callbacks/default.yaml +37 -0
  38. quadra/configs/callbacks/default_anomalib.yaml +67 -0
  39. quadra/configs/config.yaml +33 -0
  40. quadra/configs/core/default.yaml +11 -0
  41. quadra/configs/datamodule/base/anomaly.yaml +16 -0
  42. quadra/configs/datamodule/base/classification.yaml +21 -0
  43. quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
  44. quadra/configs/datamodule/base/segmentation.yaml +18 -0
  45. quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
  46. quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
  47. quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
  48. quadra/configs/datamodule/base/ssl.yaml +21 -0
  49. quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
  50. quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
  51. quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
  52. quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
  53. quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
  54. quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
  55. quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
  56. quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
  57. quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
  58. quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
  59. quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
  60. quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
  61. quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
  62. quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
  63. quadra/configs/experiment/base/classification/classification.yaml +73 -0
  64. quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
  65. quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
  66. quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
  67. quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
  68. quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
  69. quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
  70. quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
  71. quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
  72. quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
  73. quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
  74. quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
  75. quadra/configs/experiment/base/ssl/byol.yaml +43 -0
  76. quadra/configs/experiment/base/ssl/dino.yaml +46 -0
  77. quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
  78. quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
  79. quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
  80. quadra/configs/experiment/custom/cls.yaml +12 -0
  81. quadra/configs/experiment/default.yaml +15 -0
  82. quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
  83. quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
  84. quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
  85. quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
  86. quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
  87. quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
  88. quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
  89. quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
  90. quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
  91. quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
  92. quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
  93. quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
  94. quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
  95. quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
  96. quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
  97. quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
  98. quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
  99. quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
  100. quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
  101. quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
  102. quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
  103. quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
  104. quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
  105. quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
  106. quadra/configs/export/default.yaml +13 -0
  107. quadra/configs/hydra/anomaly_custom.yaml +15 -0
  108. quadra/configs/hydra/default.yaml +14 -0
  109. quadra/configs/inference/default.yaml +26 -0
  110. quadra/configs/logger/comet.yaml +10 -0
  111. quadra/configs/logger/csv.yaml +5 -0
  112. quadra/configs/logger/mlflow.yaml +12 -0
  113. quadra/configs/logger/tensorboard.yaml +8 -0
  114. quadra/configs/loss/asl.yaml +7 -0
  115. quadra/configs/loss/barlow.yaml +2 -0
  116. quadra/configs/loss/bce.yaml +1 -0
  117. quadra/configs/loss/byol.yaml +1 -0
  118. quadra/configs/loss/cross_entropy.yaml +1 -0
  119. quadra/configs/loss/dino.yaml +8 -0
  120. quadra/configs/loss/simclr.yaml +2 -0
  121. quadra/configs/loss/simsiam.yaml +1 -0
  122. quadra/configs/loss/smp_ce.yaml +3 -0
  123. quadra/configs/loss/smp_dice.yaml +2 -0
  124. quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
  125. quadra/configs/loss/smp_mcc.yaml +2 -0
  126. quadra/configs/loss/vicreg.yaml +5 -0
  127. quadra/configs/model/anomalib/cfa.yaml +35 -0
  128. quadra/configs/model/anomalib/cflow.yaml +30 -0
  129. quadra/configs/model/anomalib/csflow.yaml +34 -0
  130. quadra/configs/model/anomalib/dfm.yaml +19 -0
  131. quadra/configs/model/anomalib/draem.yaml +29 -0
  132. quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
  133. quadra/configs/model/anomalib/fastflow.yaml +32 -0
  134. quadra/configs/model/anomalib/padim.yaml +32 -0
  135. quadra/configs/model/anomalib/patchcore.yaml +36 -0
  136. quadra/configs/model/barlow.yaml +16 -0
  137. quadra/configs/model/byol.yaml +25 -0
  138. quadra/configs/model/classification.yaml +10 -0
  139. quadra/configs/model/dino.yaml +26 -0
  140. quadra/configs/model/logistic_regression.yaml +4 -0
  141. quadra/configs/model/multilabel_classification.yaml +9 -0
  142. quadra/configs/model/simclr.yaml +18 -0
  143. quadra/configs/model/simsiam.yaml +24 -0
  144. quadra/configs/model/smp.yaml +4 -0
  145. quadra/configs/model/smp_multiclass.yaml +4 -0
  146. quadra/configs/model/vicreg.yaml +16 -0
  147. quadra/configs/optimizer/adam.yaml +5 -0
  148. quadra/configs/optimizer/adamw.yaml +3 -0
  149. quadra/configs/optimizer/default.yaml +4 -0
  150. quadra/configs/optimizer/lars.yaml +8 -0
  151. quadra/configs/optimizer/sgd.yaml +4 -0
  152. quadra/configs/scheduler/default.yaml +5 -0
  153. quadra/configs/scheduler/rop.yaml +5 -0
  154. quadra/configs/scheduler/step.yaml +3 -0
  155. quadra/configs/scheduler/warmrestart.yaml +2 -0
  156. quadra/configs/scheduler/warmup.yaml +6 -0
  157. quadra/configs/task/anomalib/cfa.yaml +5 -0
  158. quadra/configs/task/anomalib/cflow.yaml +5 -0
  159. quadra/configs/task/anomalib/csflow.yaml +5 -0
  160. quadra/configs/task/anomalib/draem.yaml +5 -0
  161. quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
  162. quadra/configs/task/anomalib/fastflow.yaml +5 -0
  163. quadra/configs/task/anomalib/inference.yaml +3 -0
  164. quadra/configs/task/anomalib/padim.yaml +5 -0
  165. quadra/configs/task/anomalib/patchcore.yaml +5 -0
  166. quadra/configs/task/classification.yaml +6 -0
  167. quadra/configs/task/classification_evaluation.yaml +6 -0
  168. quadra/configs/task/default.yaml +1 -0
  169. quadra/configs/task/segmentation.yaml +9 -0
  170. quadra/configs/task/segmentation_evaluation.yaml +3 -0
  171. quadra/configs/task/sklearn_classification.yaml +13 -0
  172. quadra/configs/task/sklearn_classification_patch.yaml +11 -0
  173. quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
  174. quadra/configs/task/sklearn_classification_test.yaml +8 -0
  175. quadra/configs/task/ssl.yaml +2 -0
  176. quadra/configs/trainer/lightning_cpu.yaml +36 -0
  177. quadra/configs/trainer/lightning_gpu.yaml +35 -0
  178. quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
  179. quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
  180. quadra/configs/trainer/lightning_multigpu.yaml +37 -0
  181. quadra/configs/trainer/sklearn_classification.yaml +7 -0
  182. quadra/configs/transforms/byol.yaml +47 -0
  183. quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
  184. quadra/configs/transforms/default.yaml +37 -0
  185. quadra/configs/transforms/default_numpy.yaml +24 -0
  186. quadra/configs/transforms/default_resize.yaml +22 -0
  187. quadra/configs/transforms/dino.yaml +63 -0
  188. quadra/configs/transforms/linear_eval.yaml +18 -0
  189. quadra/datamodules/__init__.py +20 -0
  190. quadra/datamodules/anomaly.py +180 -0
  191. quadra/datamodules/base.py +375 -0
  192. quadra/datamodules/classification.py +1003 -0
  193. quadra/datamodules/generic/__init__.py +0 -0
  194. quadra/datamodules/generic/imagenette.py +144 -0
  195. quadra/datamodules/generic/mnist.py +81 -0
  196. quadra/datamodules/generic/mvtec.py +58 -0
  197. quadra/datamodules/generic/oxford_pet.py +163 -0
  198. quadra/datamodules/patch.py +190 -0
  199. quadra/datamodules/segmentation.py +742 -0
  200. quadra/datamodules/ssl.py +140 -0
  201. quadra/datasets/__init__.py +17 -0
  202. quadra/datasets/anomaly.py +287 -0
  203. quadra/datasets/classification.py +241 -0
  204. quadra/datasets/patch.py +138 -0
  205. quadra/datasets/segmentation.py +239 -0
  206. quadra/datasets/ssl.py +110 -0
  207. quadra/losses/__init__.py +0 -0
  208. quadra/losses/classification/__init__.py +6 -0
  209. quadra/losses/classification/asl.py +83 -0
  210. quadra/losses/classification/focal.py +320 -0
  211. quadra/losses/classification/prototypical.py +148 -0
  212. quadra/losses/ssl/__init__.py +17 -0
  213. quadra/losses/ssl/barlowtwins.py +47 -0
  214. quadra/losses/ssl/byol.py +37 -0
  215. quadra/losses/ssl/dino.py +129 -0
  216. quadra/losses/ssl/hyperspherical.py +45 -0
  217. quadra/losses/ssl/idmm.py +50 -0
  218. quadra/losses/ssl/simclr.py +67 -0
  219. quadra/losses/ssl/simsiam.py +30 -0
  220. quadra/losses/ssl/vicreg.py +76 -0
  221. quadra/main.py +46 -0
  222. quadra/metrics/__init__.py +3 -0
  223. quadra/metrics/segmentation.py +251 -0
  224. quadra/models/__init__.py +0 -0
  225. quadra/models/base.py +151 -0
  226. quadra/models/classification/__init__.py +8 -0
  227. quadra/models/classification/backbones.py +149 -0
  228. quadra/models/classification/base.py +92 -0
  229. quadra/models/evaluation.py +322 -0
  230. quadra/modules/__init__.py +0 -0
  231. quadra/modules/backbone.py +30 -0
  232. quadra/modules/base.py +312 -0
  233. quadra/modules/classification/__init__.py +3 -0
  234. quadra/modules/classification/base.py +331 -0
  235. quadra/modules/ssl/__init__.py +17 -0
  236. quadra/modules/ssl/barlowtwins.py +59 -0
  237. quadra/modules/ssl/byol.py +172 -0
  238. quadra/modules/ssl/common.py +285 -0
  239. quadra/modules/ssl/dino.py +186 -0
  240. quadra/modules/ssl/hyperspherical.py +206 -0
  241. quadra/modules/ssl/idmm.py +98 -0
  242. quadra/modules/ssl/simclr.py +73 -0
  243. quadra/modules/ssl/simsiam.py +68 -0
  244. quadra/modules/ssl/vicreg.py +67 -0
  245. quadra/optimizers/__init__.py +4 -0
  246. quadra/optimizers/lars.py +153 -0
  247. quadra/optimizers/sam.py +127 -0
  248. quadra/schedulers/__init__.py +3 -0
  249. quadra/schedulers/base.py +44 -0
  250. quadra/schedulers/warmup.py +127 -0
  251. quadra/tasks/__init__.py +24 -0
  252. quadra/tasks/anomaly.py +582 -0
  253. quadra/tasks/base.py +397 -0
  254. quadra/tasks/classification.py +1264 -0
  255. quadra/tasks/patch.py +492 -0
  256. quadra/tasks/segmentation.py +389 -0
  257. quadra/tasks/ssl.py +560 -0
  258. quadra/trainers/README.md +3 -0
  259. quadra/trainers/__init__.py +0 -0
  260. quadra/trainers/classification.py +179 -0
  261. quadra/utils/__init__.py +0 -0
  262. quadra/utils/anomaly.py +112 -0
  263. quadra/utils/classification.py +618 -0
  264. quadra/utils/deprecation.py +31 -0
  265. quadra/utils/evaluation.py +474 -0
  266. quadra/utils/export.py +579 -0
  267. quadra/utils/imaging.py +32 -0
  268. quadra/utils/logger.py +15 -0
  269. quadra/utils/mlflow.py +98 -0
  270. quadra/utils/model_manager.py +320 -0
  271. quadra/utils/models.py +524 -0
  272. quadra/utils/patch/__init__.py +15 -0
  273. quadra/utils/patch/dataset.py +1433 -0
  274. quadra/utils/patch/metrics.py +449 -0
  275. quadra/utils/patch/model.py +153 -0
  276. quadra/utils/patch/visualization.py +217 -0
  277. quadra/utils/resolver.py +42 -0
  278. quadra/utils/segmentation.py +31 -0
  279. quadra/utils/tests/__init__.py +0 -0
  280. quadra/utils/tests/fixtures/__init__.py +1 -0
  281. quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
  282. quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
  283. quadra/utils/tests/fixtures/dataset/classification.py +406 -0
  284. quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
  285. quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
  286. quadra/utils/tests/fixtures/models/__init__.py +3 -0
  287. quadra/utils/tests/fixtures/models/anomaly.py +89 -0
  288. quadra/utils/tests/fixtures/models/classification.py +45 -0
  289. quadra/utils/tests/fixtures/models/segmentation.py +33 -0
  290. quadra/utils/tests/helpers.py +70 -0
  291. quadra/utils/tests/models.py +27 -0
  292. quadra/utils/utils.py +525 -0
  293. quadra/utils/validator.py +115 -0
  294. quadra/utils/visualization.py +422 -0
  295. quadra/utils/vit_explainability.py +349 -0
  296. quadra-2.1.13.dist-info/LICENSE +201 -0
  297. quadra-2.1.13.dist-info/METADATA +386 -0
  298. quadra-2.1.13.dist-info/RECORD +300 -0
  299. {quadra-0.0.1.dist-info → quadra-2.1.13.dist-info}/WHEEL +1 -1
  300. quadra-2.1.13.dist-info/entry_points.txt +3 -0
  301. quadra-0.0.1.dist-info/METADATA +0 -14
  302. quadra-0.0.1.dist-info/RECORD +0 -4
@@ -0,0 +1,30 @@
1
+ import os
2
+
3
+ import dotenv
4
+ from hydra.core.config_search_path import ConfigSearchPath
5
+ from hydra.plugins.search_path_plugin import SearchPathPlugin
6
+
7
+
8
class QuadraSearchPathPlugin(SearchPathPlugin):
    """Hydra search path plugin that registers extra config locations for quadra.

    When the plugin is instantiated, the `.env` file found in the current working directory (if any) is loaded
    into the process environment, overriding variables that are already set. The extra locations are read from
    the `QUADRA_SEARCH_PATH` environment variable, which can therefore be defined globally or in that `.env`.
    """

    def __init__(self):
        try:
            working_dir = os.getcwd()
        except FileNotFoundError:
            # This may happen when running tests
            return

        dotenv_file = os.path.join(working_dir, ".env")
        if os.path.exists(dotenv_file):
            dotenv.load_dotenv(dotenv_path=dotenv_file, override=True)

    def manipulate_search_path(self, search_path: ConfigSearchPath) -> None:
        """Plugin used to add custom config to searchpath to be discovered by quadra.

        Args:
            search_path: Hydra search path to which the entries of `QUADRA_SEARCH_PATH` are appended.
        """
        # This can be global or taken from the .env
        quadra_search_path = os.environ.get("QUADRA_SEARCH_PATH")
        if quadra_search_path is None:
            return

        # Path should be specified as a list of hydra path separated by ";"
        # E.g pkg://package1.configs;file:///path/to/configs
        for i, path in enumerate(quadra_search_path.split(";")):
            path = path.strip()
            if not path:
                # An empty variable, a trailing ";" or ";;" would otherwise register an empty search path.
                # The index still counts the skipped segment so provider names of real entries do not change.
                continue
            search_path.append(provider=f"quadra-searchpath-plugin-{i}", path=path)
quadra/__init__.py CHANGED
@@ -0,0 +1,6 @@
1
__version__ = "2.1.13"


def get_version() -> str:
    """Return the version string of the quadra package (the module-level `__version__`)."""
    return __version__
File without changes
@@ -0,0 +1,289 @@
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ from typing import Any
5
+
6
+ import cv2
7
+ import matplotlib
8
+ import matplotlib.pyplot as plt
9
+ import numpy as np
10
+ import pytorch_lightning as pl
11
+ from anomalib.models.components.base import AnomalyModule
12
+ from anomalib.post_processing import (
13
+ add_anomalous_label,
14
+ add_normal_label,
15
+ compute_mask,
16
+ superimpose_anomaly_map,
17
+ )
18
+ from anomalib.pre_processing.transforms import Denormalize
19
+ from anomalib.utils.loggers import AnomalibWandbLogger
20
+ from pytorch_lightning import Callback
21
+ from pytorch_lightning.utilities.types import STEP_OUTPUT
22
+ from skimage.segmentation import mark_boundaries
23
+ from tqdm import tqdm
24
+
25
+ from quadra.utils.anomaly import MapOrValue
26
+
27
+
28
class Visualizer:
    """Anomaly Visualization.

    The visualizer collects the images passed to `add_image()` and, when `generate()` is called, lays them out
    side by side in a single matplotlib figure. The figure can then be logged by accessing the `figure`
    attribute, shown with `show()`, saved with `save()` and released with `close()`.

    Example:
        >>> visualizer = Visualizer()
        >>> visualizer.add_image(image=image, title="Image")
        >>> visualizer.generate()
        >>> visualizer.close()
    """

    def __init__(self) -> None:
        self.images: list[dict] = []

        # Declared only: both attributes are assigned by `generate()`.
        self.figure: matplotlib.figure.Figure
        self.axis: np.ndarray

    def add_image(self, image: np.ndarray, title: str, color_map: str | None = None):
        """Add image to figure.

        Args:
            image: Image which should be added to the figure.
            title: Image title shown on the plot.
            color_map: Name of matplotlib color map used to map scalar data to colours. Defaults to None.
        """
        image_data = {"image": image, "title": title, "color_map": color_map}
        self.images.append(image_data)

    def generate(self):
        """Generate the image.

        Builds a one-row figure with one column per queued image and stores it in `figure` (the axes go in
        `axis`). The figure is created under the "Agg" backend and the previously active backend is restored
        afterwards, even if plotting fails.
        """
        default_plt_backend = plt.get_backend()
        plt.switch_backend("Agg")
        try:
            num_cols = len(self.images)
            figure_size = (num_cols * 3, 3)
            self.figure, self.axis = plt.subplots(1, num_cols, figsize=figure_size)
            self.figure.subplots_adjust(right=0.9)

            # With a single column `plt.subplots` returns a bare Axes instead of an array of them
            axes = self.axis if num_cols > 1 else [self.axis]
            for axis, image_dict in zip(axes, self.images):
                axis.axes.xaxis.set_visible(False)
                axis.axes.yaxis.set_visible(False)
                axis.imshow(image_dict["image"], image_dict["color_map"], vmin=0, vmax=255)
                axis.title.set_text(image_dict["title"])
        finally:
            # Do not leave the process on the "Agg" backend when one of the calls above raises
            plt.switch_backend(default_plt_backend)

    def show(self):
        """Show image on a matplotlib figure."""
        self.figure.show()

    def save(self, filename: Path):
        """Save image.

        The parent directory of `filename` is created when missing.

        Args:
            filename: Filename to save image
        """
        filename.parent.mkdir(parents=True, exist_ok=True)
        self.figure.savefig(filename, dpi=100)

    def close(self):
        """Close figure."""
        plt.close(self.figure)
90
+
91
+
92
+ # TODO: This is a lot different from the 0.3.7 anomalib one
93
class VisualizerCallback(Callback):
    """Callback that renders the inference results of every test batch as images.

    For each test sample a figure is saved to
    `<output_path>/images/<ok|wrong>/<parent folder of the input image>/<input image stem>.png`, where `ok` or
    `wrong` tells whether the predicted label matches the ground truth label. With the 'segmentation' task the
    figure shows the original image, the ground truth mask, the predicted heat map, the predicted mask and the
    segmentation result; with the 'classification' task it shows the labelled image and the labelled prediction.

    Args:
        task: either 'segmentation' or 'classification'
        output_path: location where the images will be saved.
        inputs_are_normalized: whether the input images are normalized (like when using MinMax or Threshold
            callback). Its negation is passed as `normalize` when the heat map is superimposed on the image.
        threshold_type: 'pixel' reads the threshold from the pixel-level F1 metric, any other value (e.g. 'image')
            reads it from the image-level F1 metric.
        disable: whether to disable the callback.
        plot_only_wrong: whether to plot only the images that are not correctly predicted.
        plot_raw_outputs: Saves the raw images of the segmentation and heatmap output.
    """

    def __init__(
        self,
        task: str = "segmentation",
        output_path: str = "anomaly_output",
        inputs_are_normalized: bool = True,
        threshold_type: str = "pixel",
        disable: bool = False,
        plot_only_wrong: bool = False,
        plot_raw_outputs: bool = False,
    ) -> None:
        self.task = task
        self.output_path = output_path
        self.inputs_are_normalized = inputs_are_normalized
        self.threshold_type = threshold_type
        self.disable = disable
        self.plot_only_wrong = plot_only_wrong
        self.plot_raw_outputs = plot_raw_outputs

    def _add_images(self, visualizer: Visualizer, filename: Path, output_label_folder: str):
        """Save the figure held by the visualizer to local storage.

        Args:
            visualizer: Visualizer object from which the `figure` is saved.
            filename: Path of the input image. Its parent folder name and stem are reused for the generated image.
            output_label_folder: ok if the image is correctly predicted or wrong if it is not
        """
        destination = (
            Path(self.output_path) / "images" / output_label_folder / filename.parent.name / f"{filename.stem}.png"
        )
        visualizer.save(destination)

    def on_test_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: AnomalyModule,
        outputs: STEP_OUTPUT | None,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Save one figure per sample at the end of every test batch.

        Args:
            trainer: Pytorch lightning trainer object (unused).
            pl_module: Lightning module exposing `pixel_metrics` and `image_metrics`, from which the F1 threshold
                is read.
            outputs: Outputs of the current test step.
            batch: Input batch of the current test step (unused).
            batch_idx: Index of the current test batch (unused).
            dataloader_idx: Index of the dataloader that yielded the current batch (unused).

        Raises:
            AttributeError: If the selected F1 metric has no `threshold` attribute.
            TypeError: If the threshold is not a float.
        """
        if self.disable:
            return

        assert isinstance(outputs, dict)

        required_keys = ("image_path", "image", "mask", "anomaly_maps", "label")
        if not all(key in outputs for key in required_keys):
            # I'm probably in the classification scenario so I can't use the visualizer
            return

        metrics = pl_module.pixel_metrics if self.threshold_type == "pixel" else pl_module.image_metrics
        f1_metric = metrics.F1Score
        if not hasattr(f1_metric, "threshold"):
            raise AttributeError("Metric has no threshold attribute")
        threshold = f1_metric.threshold

        samples = zip(
            outputs["image_path"],
            outputs["image"],
            outputs["mask"],
            outputs["anomaly_maps"],
            outputs["label"],
            outputs["pred_labels"],
            outputs["pred_scores"],
        )
        for filename, image, true_mask, anomaly_map, gt_label, pred_label, anomaly_score in tqdm(samples):
            denormalized_image = Denormalize()(image.cpu())
            current_true_mask = true_mask.cpu().numpy()
            current_anomaly_map = anomaly_map.cpu().numpy()
            # Map the window [threshold - 50, threshold + 50] onto [0, 1] and clip everything outside of it.
            # Anomalies get a stronger color and really normal regions a lighter one, independently from the
            # max or min anomaly score.
            normalized_map = np.clip((current_anomaly_map - (threshold - 50)) / 100, 0, 1)

            output_label_folder = "ok" if pred_label == gt_label else "wrong"
            if self.plot_only_wrong and output_label_folder == "ok":
                continue

            heatmap = superimpose_anomaly_map(
                normalized_map, denormalized_image, normalize=not self.inputs_are_normalized
            )

            if not isinstance(threshold, float):
                raise TypeError("Threshold should be float")
            pred_mask = compute_mask(current_anomaly_map, threshold)

            vis_img = mark_boundaries(denormalized_image, pred_mask, color=(1, 0, 0), mode="thick")
            visualizer = Visualizer()

            if self.task == "segmentation":
                visualizer.add_image(image=denormalized_image, title="Image")
                # Always true at this point, the guard above already returned when "mask" is missing
                if "mask" in outputs:
                    current_true_mask = current_true_mask * 255
                    visualizer.add_image(image=current_true_mask, color_map="gray", title="Ground Truth")
                visualizer.add_image(image=heatmap, title="Predicted Heat Map")
                visualizer.add_image(image=pred_mask, color_map="gray", title="Predicted Mask")
                visualizer.add_image(image=vis_img, title="Segmentation Result")
            elif self.task == "classification":
                if gt_label:
                    gt_im = add_anomalous_label(denormalized_image)
                else:
                    gt_im = add_normal_label(denormalized_image)
                visualizer.add_image(gt_im, title="Image/True label")

                if anomaly_score >= threshold:
                    image_classified = add_anomalous_label(heatmap, anomaly_score)
                else:
                    image_classified = add_normal_label(heatmap, 1 - anomaly_score)
                visualizer.add_image(image=image_classified, title="Prediction")

            visualizer.generate()
            visualizer.figure.suptitle(
                f"F1 threshold: {threshold}, Mask_max: {current_anomaly_map.max():.3f}, "
                f"Anomaly_score: {anomaly_score:.3f}"
            )
            path_filename = Path(filename)
            self._add_images(visualizer, path_filename, output_label_folder)
            visualizer.close()

            if not self.plot_raw_outputs:
                continue

            raw_dir = (
                Path(self.output_path) / "images" / output_label_folder / path_filename.parent.name / "raw_outputs"
            )
            for raw_name, raw_output in (("heatmap", heatmap), ("segmentation", vis_img)):
                # The segmentation overlay is rescaled to uint8 before the color conversion
                # (presumably it is a float image in [0, 1], verify against skimage's mark_boundaries)
                rgb_output = (raw_output * 255).astype(np.uint8) if raw_name == "segmentation" else raw_output
                bgr_output = cv2.cvtColor(rgb_output, cv2.COLOR_RGB2BGR)
                raw_filename = raw_dir / f"{path_filename.stem}_{raw_name}.png"
                raw_filename.parent.mkdir(parents=True, exist_ok=True)
                cv2.imwrite(str(raw_filename), bgr_output)

    def on_test_end(self, _trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        """Sync logs.

        Only an ``AnomalibWandbLogger`` is saved from this method; any other logger (or no logger) is left
        untouched.

        Args:
            _trainer: Pytorch Lightning trainer (unused)
            pl_module: Anomaly module
        """
        if self.disable:
            return

        logger = pl_module.logger
        if isinstance(logger, AnomalibWandbLogger):
            logger.save()