quadra 0.0.1__py3-none-any.whl → 2.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (302) hide show
  1. hydra_plugins/quadra_searchpath_plugin.py +30 -0
  2. quadra/__init__.py +6 -0
  3. quadra/callbacks/__init__.py +0 -0
  4. quadra/callbacks/anomalib.py +289 -0
  5. quadra/callbacks/lightning.py +501 -0
  6. quadra/callbacks/mlflow.py +291 -0
  7. quadra/callbacks/scheduler.py +69 -0
  8. quadra/configs/__init__.py +0 -0
  9. quadra/configs/backbone/caformer_m36.yaml +8 -0
  10. quadra/configs/backbone/caformer_s36.yaml +8 -0
  11. quadra/configs/backbone/convnextv2_base.yaml +8 -0
  12. quadra/configs/backbone/convnextv2_femto.yaml +8 -0
  13. quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
  14. quadra/configs/backbone/dino_vitb8.yaml +12 -0
  15. quadra/configs/backbone/dino_vits8.yaml +12 -0
  16. quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
  17. quadra/configs/backbone/dinov2_vits14.yaml +12 -0
  18. quadra/configs/backbone/efficientnet_b0.yaml +8 -0
  19. quadra/configs/backbone/efficientnet_b1.yaml +8 -0
  20. quadra/configs/backbone/efficientnet_b2.yaml +8 -0
  21. quadra/configs/backbone/efficientnet_b3.yaml +8 -0
  22. quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
  23. quadra/configs/backbone/levit_128s.yaml +8 -0
  24. quadra/configs/backbone/mnasnet0_5.yaml +9 -0
  25. quadra/configs/backbone/resnet101.yaml +8 -0
  26. quadra/configs/backbone/resnet18.yaml +8 -0
  27. quadra/configs/backbone/resnet18_ssl.yaml +8 -0
  28. quadra/configs/backbone/resnet50.yaml +8 -0
  29. quadra/configs/backbone/smp.yaml +9 -0
  30. quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
  31. quadra/configs/backbone/unetr.yaml +15 -0
  32. quadra/configs/backbone/vit16_base.yaml +9 -0
  33. quadra/configs/backbone/vit16_small.yaml +9 -0
  34. quadra/configs/backbone/vit16_tiny.yaml +9 -0
  35. quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
  36. quadra/configs/callbacks/all.yaml +45 -0
  37. quadra/configs/callbacks/default.yaml +34 -0
  38. quadra/configs/callbacks/default_anomalib.yaml +64 -0
  39. quadra/configs/config.yaml +33 -0
  40. quadra/configs/core/default.yaml +11 -0
  41. quadra/configs/datamodule/base/anomaly.yaml +16 -0
  42. quadra/configs/datamodule/base/classification.yaml +21 -0
  43. quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
  44. quadra/configs/datamodule/base/segmentation.yaml +18 -0
  45. quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
  46. quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
  47. quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
  48. quadra/configs/datamodule/base/ssl.yaml +21 -0
  49. quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
  50. quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
  51. quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
  52. quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
  53. quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
  54. quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
  55. quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
  56. quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
  57. quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
  58. quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
  59. quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
  60. quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
  61. quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
  62. quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
  63. quadra/configs/experiment/base/classification/classification.yaml +73 -0
  64. quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
  65. quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
  66. quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
  67. quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
  68. quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
  69. quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
  70. quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
  71. quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
  72. quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
  73. quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
  74. quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
  75. quadra/configs/experiment/base/ssl/byol.yaml +43 -0
  76. quadra/configs/experiment/base/ssl/dino.yaml +46 -0
  77. quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
  78. quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
  79. quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
  80. quadra/configs/experiment/custom/cls.yaml +12 -0
  81. quadra/configs/experiment/default.yaml +15 -0
  82. quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
  83. quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
  84. quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
  85. quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
  86. quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
  87. quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
  88. quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
  89. quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
  90. quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
  91. quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
  92. quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
  93. quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
  94. quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
  95. quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
  96. quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
  97. quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
  98. quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
  99. quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
  100. quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
  101. quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
  102. quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
  103. quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
  104. quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
  105. quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
  106. quadra/configs/export/default.yaml +13 -0
  107. quadra/configs/hydra/anomaly_custom.yaml +15 -0
  108. quadra/configs/hydra/default.yaml +14 -0
  109. quadra/configs/inference/default.yaml +26 -0
  110. quadra/configs/logger/comet.yaml +10 -0
  111. quadra/configs/logger/csv.yaml +5 -0
  112. quadra/configs/logger/mlflow.yaml +12 -0
  113. quadra/configs/logger/tensorboard.yaml +8 -0
  114. quadra/configs/loss/asl.yaml +7 -0
  115. quadra/configs/loss/barlow.yaml +2 -0
  116. quadra/configs/loss/bce.yaml +1 -0
  117. quadra/configs/loss/byol.yaml +1 -0
  118. quadra/configs/loss/cross_entropy.yaml +1 -0
  119. quadra/configs/loss/dino.yaml +8 -0
  120. quadra/configs/loss/simclr.yaml +2 -0
  121. quadra/configs/loss/simsiam.yaml +1 -0
  122. quadra/configs/loss/smp_ce.yaml +3 -0
  123. quadra/configs/loss/smp_dice.yaml +2 -0
  124. quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
  125. quadra/configs/loss/smp_mcc.yaml +2 -0
  126. quadra/configs/loss/vicreg.yaml +5 -0
  127. quadra/configs/model/anomalib/cfa.yaml +35 -0
  128. quadra/configs/model/anomalib/cflow.yaml +30 -0
  129. quadra/configs/model/anomalib/csflow.yaml +34 -0
  130. quadra/configs/model/anomalib/dfm.yaml +19 -0
  131. quadra/configs/model/anomalib/draem.yaml +29 -0
  132. quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
  133. quadra/configs/model/anomalib/fastflow.yaml +32 -0
  134. quadra/configs/model/anomalib/padim.yaml +32 -0
  135. quadra/configs/model/anomalib/patchcore.yaml +36 -0
  136. quadra/configs/model/barlow.yaml +16 -0
  137. quadra/configs/model/byol.yaml +25 -0
  138. quadra/configs/model/classification.yaml +10 -0
  139. quadra/configs/model/dino.yaml +26 -0
  140. quadra/configs/model/logistic_regression.yaml +4 -0
  141. quadra/configs/model/multilabel_classification.yaml +9 -0
  142. quadra/configs/model/simclr.yaml +18 -0
  143. quadra/configs/model/simsiam.yaml +24 -0
  144. quadra/configs/model/smp.yaml +4 -0
  145. quadra/configs/model/smp_multiclass.yaml +4 -0
  146. quadra/configs/model/vicreg.yaml +16 -0
  147. quadra/configs/optimizer/adam.yaml +5 -0
  148. quadra/configs/optimizer/adamw.yaml +3 -0
  149. quadra/configs/optimizer/default.yaml +4 -0
  150. quadra/configs/optimizer/lars.yaml +8 -0
  151. quadra/configs/optimizer/sgd.yaml +4 -0
  152. quadra/configs/scheduler/default.yaml +5 -0
  153. quadra/configs/scheduler/rop.yaml +5 -0
  154. quadra/configs/scheduler/step.yaml +3 -0
  155. quadra/configs/scheduler/warmrestart.yaml +2 -0
  156. quadra/configs/scheduler/warmup.yaml +6 -0
  157. quadra/configs/task/anomalib/cfa.yaml +5 -0
  158. quadra/configs/task/anomalib/cflow.yaml +5 -0
  159. quadra/configs/task/anomalib/csflow.yaml +5 -0
  160. quadra/configs/task/anomalib/draem.yaml +5 -0
  161. quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
  162. quadra/configs/task/anomalib/fastflow.yaml +5 -0
  163. quadra/configs/task/anomalib/inference.yaml +3 -0
  164. quadra/configs/task/anomalib/padim.yaml +5 -0
  165. quadra/configs/task/anomalib/patchcore.yaml +5 -0
  166. quadra/configs/task/classification.yaml +6 -0
  167. quadra/configs/task/classification_evaluation.yaml +6 -0
  168. quadra/configs/task/default.yaml +1 -0
  169. quadra/configs/task/segmentation.yaml +9 -0
  170. quadra/configs/task/segmentation_evaluation.yaml +3 -0
  171. quadra/configs/task/sklearn_classification.yaml +13 -0
  172. quadra/configs/task/sklearn_classification_patch.yaml +11 -0
  173. quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
  174. quadra/configs/task/sklearn_classification_test.yaml +8 -0
  175. quadra/configs/task/ssl.yaml +2 -0
  176. quadra/configs/trainer/lightning_cpu.yaml +36 -0
  177. quadra/configs/trainer/lightning_gpu.yaml +35 -0
  178. quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
  179. quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
  180. quadra/configs/trainer/lightning_multigpu.yaml +37 -0
  181. quadra/configs/trainer/sklearn_classification.yaml +7 -0
  182. quadra/configs/transforms/byol.yaml +47 -0
  183. quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
  184. quadra/configs/transforms/default.yaml +37 -0
  185. quadra/configs/transforms/default_numpy.yaml +24 -0
  186. quadra/configs/transforms/default_resize.yaml +22 -0
  187. quadra/configs/transforms/dino.yaml +63 -0
  188. quadra/configs/transforms/linear_eval.yaml +18 -0
  189. quadra/datamodules/__init__.py +20 -0
  190. quadra/datamodules/anomaly.py +180 -0
  191. quadra/datamodules/base.py +375 -0
  192. quadra/datamodules/classification.py +1003 -0
  193. quadra/datamodules/generic/__init__.py +0 -0
  194. quadra/datamodules/generic/imagenette.py +144 -0
  195. quadra/datamodules/generic/mnist.py +81 -0
  196. quadra/datamodules/generic/mvtec.py +58 -0
  197. quadra/datamodules/generic/oxford_pet.py +163 -0
  198. quadra/datamodules/patch.py +190 -0
  199. quadra/datamodules/segmentation.py +742 -0
  200. quadra/datamodules/ssl.py +140 -0
  201. quadra/datasets/__init__.py +17 -0
  202. quadra/datasets/anomaly.py +287 -0
  203. quadra/datasets/classification.py +241 -0
  204. quadra/datasets/patch.py +138 -0
  205. quadra/datasets/segmentation.py +239 -0
  206. quadra/datasets/ssl.py +110 -0
  207. quadra/losses/__init__.py +0 -0
  208. quadra/losses/classification/__init__.py +6 -0
  209. quadra/losses/classification/asl.py +83 -0
  210. quadra/losses/classification/focal.py +320 -0
  211. quadra/losses/classification/prototypical.py +148 -0
  212. quadra/losses/ssl/__init__.py +17 -0
  213. quadra/losses/ssl/barlowtwins.py +47 -0
  214. quadra/losses/ssl/byol.py +37 -0
  215. quadra/losses/ssl/dino.py +129 -0
  216. quadra/losses/ssl/hyperspherical.py +45 -0
  217. quadra/losses/ssl/idmm.py +50 -0
  218. quadra/losses/ssl/simclr.py +67 -0
  219. quadra/losses/ssl/simsiam.py +30 -0
  220. quadra/losses/ssl/vicreg.py +76 -0
  221. quadra/main.py +49 -0
  222. quadra/metrics/__init__.py +3 -0
  223. quadra/metrics/segmentation.py +251 -0
  224. quadra/models/__init__.py +0 -0
  225. quadra/models/base.py +151 -0
  226. quadra/models/classification/__init__.py +8 -0
  227. quadra/models/classification/backbones.py +149 -0
  228. quadra/models/classification/base.py +92 -0
  229. quadra/models/evaluation.py +322 -0
  230. quadra/modules/__init__.py +0 -0
  231. quadra/modules/backbone.py +30 -0
  232. quadra/modules/base.py +312 -0
  233. quadra/modules/classification/__init__.py +3 -0
  234. quadra/modules/classification/base.py +327 -0
  235. quadra/modules/ssl/__init__.py +17 -0
  236. quadra/modules/ssl/barlowtwins.py +59 -0
  237. quadra/modules/ssl/byol.py +172 -0
  238. quadra/modules/ssl/common.py +285 -0
  239. quadra/modules/ssl/dino.py +186 -0
  240. quadra/modules/ssl/hyperspherical.py +206 -0
  241. quadra/modules/ssl/idmm.py +98 -0
  242. quadra/modules/ssl/simclr.py +73 -0
  243. quadra/modules/ssl/simsiam.py +68 -0
  244. quadra/modules/ssl/vicreg.py +67 -0
  245. quadra/optimizers/__init__.py +4 -0
  246. quadra/optimizers/lars.py +153 -0
  247. quadra/optimizers/sam.py +127 -0
  248. quadra/schedulers/__init__.py +3 -0
  249. quadra/schedulers/base.py +44 -0
  250. quadra/schedulers/warmup.py +127 -0
  251. quadra/tasks/__init__.py +24 -0
  252. quadra/tasks/anomaly.py +582 -0
  253. quadra/tasks/base.py +397 -0
  254. quadra/tasks/classification.py +1263 -0
  255. quadra/tasks/patch.py +492 -0
  256. quadra/tasks/segmentation.py +389 -0
  257. quadra/tasks/ssl.py +560 -0
  258. quadra/trainers/README.md +3 -0
  259. quadra/trainers/__init__.py +0 -0
  260. quadra/trainers/classification.py +179 -0
  261. quadra/utils/__init__.py +0 -0
  262. quadra/utils/anomaly.py +112 -0
  263. quadra/utils/classification.py +618 -0
  264. quadra/utils/deprecation.py +31 -0
  265. quadra/utils/evaluation.py +474 -0
  266. quadra/utils/export.py +585 -0
  267. quadra/utils/imaging.py +32 -0
  268. quadra/utils/logger.py +15 -0
  269. quadra/utils/mlflow.py +98 -0
  270. quadra/utils/model_manager.py +320 -0
  271. quadra/utils/models.py +523 -0
  272. quadra/utils/patch/__init__.py +15 -0
  273. quadra/utils/patch/dataset.py +1433 -0
  274. quadra/utils/patch/metrics.py +449 -0
  275. quadra/utils/patch/model.py +153 -0
  276. quadra/utils/patch/visualization.py +217 -0
  277. quadra/utils/resolver.py +42 -0
  278. quadra/utils/segmentation.py +31 -0
  279. quadra/utils/tests/__init__.py +0 -0
  280. quadra/utils/tests/fixtures/__init__.py +1 -0
  281. quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
  282. quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
  283. quadra/utils/tests/fixtures/dataset/classification.py +406 -0
  284. quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
  285. quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
  286. quadra/utils/tests/fixtures/models/__init__.py +3 -0
  287. quadra/utils/tests/fixtures/models/anomaly.py +89 -0
  288. quadra/utils/tests/fixtures/models/classification.py +45 -0
  289. quadra/utils/tests/fixtures/models/segmentation.py +33 -0
  290. quadra/utils/tests/helpers.py +70 -0
  291. quadra/utils/tests/models.py +27 -0
  292. quadra/utils/utils.py +525 -0
  293. quadra/utils/validator.py +115 -0
  294. quadra/utils/visualization.py +422 -0
  295. quadra/utils/vit_explainability.py +349 -0
  296. quadra-2.2.7.dist-info/LICENSE +201 -0
  297. quadra-2.2.7.dist-info/METADATA +381 -0
  298. quadra-2.2.7.dist-info/RECORD +300 -0
  299. {quadra-0.0.1.dist-info → quadra-2.2.7.dist-info}/WHEEL +1 -1
  300. quadra-2.2.7.dist-info/entry_points.txt +3 -0
  301. quadra-0.0.1.dist-info/METADATA +0 -14
  302. quadra-0.0.1.dist-info/RECORD +0 -4
@@ -0,0 +1,582 @@
1
+ from __future__ import annotations
2
+
3
+ import csv
4
+ import glob
5
+ import json
6
+ import os
7
+ from collections import Counter
8
+ from typing import Any, Generic, Literal, TypeVar, cast
9
+
10
+ import cv2
11
+ import hydra
12
+ import numpy as np
13
+ import torch
14
+ from anomalib.models.components.base import AnomalyModule
15
+ from anomalib.post_processing import anomaly_map_to_color_map
16
+ from anomalib.utils import plot_cumulative_histogram
17
+ from anomalib.utils.callbacks.min_max_normalization import MinMaxNormalizationCallback
18
+ from anomalib.utils.metrics.optimal_f1 import OptimalF1
19
+ from matplotlib import pyplot as plt
20
+ from omegaconf import DictConfig
21
+ from sklearn.metrics import ConfusionMatrixDisplay, f1_score
22
+ from tqdm import tqdm
23
+
24
+ from quadra.callbacks.mlflow import get_mlflow_logger
25
+ from quadra.datamodules import AnomalyDataModule
26
+ from quadra.modules.base import ModelSignatureWrapper
27
+ from quadra.tasks.base import Evaluation, LightningTask
28
+ from quadra.utils import utils
29
+ from quadra.utils.anomaly import MapOrValue, ThresholdNormalizationCallback, normalize_anomaly_score
30
+ from quadra.utils.classification import get_results
31
+ from quadra.utils.evaluation import automatic_datamodule_batch_size
32
+ from quadra.utils.export import export_model
33
+
34
# Logger used by every task defined in this module
log = utils.get_logger(__name__)

# Generic datamodule type accepted by AnomalibDetection; bound to AnomalyDataModule so any subclass can be used
AnomalyDataModuleT = TypeVar("AnomalyDataModuleT", bound=AnomalyDataModule)


class AnomalibDetection(Generic[AnomalyDataModuleT], LightningTask[AnomalyDataModuleT]):
    """Anomaly Detection Task.

    Args:
        config: The experiment configuration
        module_function: The function that instantiates the module and model
        checkpoint_path: The path to the checkpoint to load the model from.
            Defaults to None.
        run_test: Whether to run the test after training. Defaults to True.
        report: Whether to report the results. Defaults to True.
    """

    def __init__(
        self,
        config: DictConfig,
        module_function: DictConfig,
        checkpoint_path: str | None = None,
        run_test: bool = True,
        report: bool = True,
    ):
        super().__init__(
            config=config,
            checkpoint_path=checkpoint_path,
            run_test=run_test,
            report=report,
        )
        self._module: AnomalyModule
        self.module_function = module_function
        # Folder passed to export_model; model.json is (re)written inside it by `export`
        self.export_folder = "deployment_model"
        # Folder for the report artifacts; an empty string means the current working directory
        self.report_path = ""
        # Results of the last `test` call; the first entry is dumped to test_results.json by the report
        self.test_results: list[dict] | None = None

    @property
    def module(self) -> AnomalyModule:
        """Get the module."""
        return self._module

    @module.setter
    def module(self, module_config):
        """Instantiate the anomaly module from ``module_config`` using ``self.module_function``.

        If the model config declares an ``input_size`` that differs from the transform input size, a warning is
        logged and the model ``input_size`` is overwritten with the transform size before instantiation.
        """
        if hasattr(self.config.model.model, "input_size"):
            transform_height = self.config.transforms.input_height
            transform_width = self.config.transforms.input_width
            original_model_height, original_model_width = self.config.model.model.input_size

            if transform_height != original_model_height or transform_width != original_model_width:
                log.warning(
                    "Model input size %dx%d "
                    "does not match the transform size %dx%d. "
                    "The model input size will be updated to match the transform size.",
                    original_model_height,
                    original_model_width,
                    transform_height,
                    transform_width,
                )
                self.config.model.model.input_size = [transform_height, transform_width]

        _module = cast(
            AnomalyModule,
            hydra.utils.instantiate(
                self.module_function,
                module_config,
            ),
        )

        self._module = _module

    def prepare(self) -> None:
        """Prepare the task: build the module from the model config and wrap its model with a signature wrapper."""
        super().prepare()
        self.module = self.config.model
        self.module.model = ModelSignatureWrapper(self.module.model)

    def export(self) -> None:
        """Export model for production.

        Nothing is exported when ``fast_dev_run`` is enabled. Otherwise the model is exported with ``export_model``
        and, if any export path was produced, ``model.json`` in the export folder is rewritten with the image and
        pixel thresholds (rounded to 3 decimals) and the anomaly method name.
        """
        if self.config.trainer.get("fast_dev_run"):
            log.warning("Skipping export since fast_dev_run is enabled")
            return

        model = self.module.model

        input_shapes = self.config.export.input_shapes

        # Depending on the Lightning version `precision` may be an int (16) or a string ("16-mixed"),
        # so normalize it to a string before looking for "16".
        half_precision = "16" in str(self.trainer.precision)

        model_json, export_paths = export_model(
            config=self.config,
            model=model,
            export_folder=self.export_folder,
            half_precision=half_precision,
            input_shapes=input_shapes,
            idx_to_class={0: "good", 1: "defect"},
        )

        if len(export_paths) == 0:
            return

        model_json["image_threshold"] = np.round(self.module.image_threshold.value.item(), 3)
        model_json["pixel_threshold"] = np.round(self.module.pixel_threshold.value.item(), 3)
        model_json["anomaly_method"] = self.config.model.model.name

        with open(os.path.join(self.export_folder, "model.json"), "w") as f:
            json.dump(model_json, f, cls=utils.HydraEncoder)

    def test(self) -> Any:
        """Run the Lightning test and keep its results for the report."""
        self.test_results = super().test()
        return self.test_results

    def _report_file(self, file_name: str) -> str:
        """Return the path of ``file_name`` inside the report folder (the CWD when ``report_path`` is empty)."""
        return os.path.join(self.report_path, file_name)

    def _predict_test_set(self) -> dict[str, torch.Tensor | list]:
        """Run prediction on the test dataloader and merge the per-batch outputs into a single dictionary.

        Tensor entries are concatenated along the first dimension; every other entry is treated as a per-sample
        sequence and flattened into one list. Returns an empty dict when prediction produced no batches.
        """
        all_output = cast(
            list[dict], self.trainer.predict(model=self.module, dataloaders=self.datamodule.test_dataloader())
        )
        if not all_output:
            return {}

        all_output_flatten: dict[str, torch.Tensor | list] = {}
        for key, first_value in all_output[0].items():
            if isinstance(first_value, torch.Tensor):
                all_output_flatten[key] = torch.cat([batch[key] for batch in all_output])
            else:
                all_output_flatten[key] = [item for batch in all_output for item in batch[key]]
        return all_output_flatten

    @staticmethod
    def _build_class_to_idx(named_labels: list[str]) -> dict[str, int]:
        """Map class names to indices: "good" is 0, defect classes follow in sorted order, "false_defect" is last.

        Sorting makes the indices (and therefore the confusion matrix layout) reproducible across runs.
        """
        class_to_idx = {"good": 0}
        for cls in sorted(set(named_labels) - {"good"}):
            class_to_idx[cls] = len(class_to_idx)
        class_to_idx["false_defect"] = len(class_to_idx)
        return class_to_idx

    def _generate_report(self) -> None:
        """Generate a report for the task.

        Files are written into ``self.report_path`` (the current working directory when it is empty):
        ``test_results.json`` (when test results are available), ``test_predictions.csv``, the cumulative
        histogram produced by ``plot_cumulative_histogram``, ``test_confusion_matrix.png`` and
        ``avg_score_by_label.csv``.

        Raises:
            ValueError: If the predicted anomaly scores are not a tensor.
        """
        if len(self.report_path) > 0:
            os.makedirs(self.report_path, exist_ok=True)

        # Save json with test results
        if self.test_results:
            with open(self._report_file("test_results.json"), "w") as f:
                json.dump(self.test_results[0], f)

        all_output_flatten = self._predict_test_set()
        if not all_output_flatten:
            log.warning("No predictions were produced on the test set, skipping report generation")
            return

        image_paths = all_output_flatten["image_path"]
        # The ground truth class name is the name of the folder containing each image
        named_labels = [os.path.basename(os.path.dirname(str(x))) for x in image_paths]

        class_to_idx = self._build_class_to_idx(named_labels)
        idx_to_class = {v: k for k, v in class_to_idx.items()}
        # With more than one defect class, good images flagged as anomalous get the dedicated "false_defect"
        # class; with a single defect class they are counted as that class (index 1).
        has_multiple_defect_classes = class_to_idx["false_defect"] > 2

        gt_labels = [class_to_idx[x] for x in named_labels]
        pred_labels = []
        for i, named_label in enumerate(named_labels):
            pred_label = all_output_flatten["pred_labels"][i].item()

            if pred_label == 0:
                pred_labels.append(0)
            elif pred_label == 1 and gt_labels[i] == 0:
                pred_labels.append(class_to_idx["false_defect"] if has_multiple_defect_classes else 1)
            else:
                # The detector is binary, so an anomalous prediction on a defect image takes its ground truth class
                pred_labels.append(class_to_idx[named_label])

        if class_to_idx["false_defect"] not in pred_labels:
            # If there are no false defects remove the label from the confusion matrix
            class_to_idx.pop("false_defect")

        anomaly_scores = all_output_flatten["pred_scores"]
        if isinstance(anomaly_scores, torch.Tensor):
            exportable_anomaly_scores = anomaly_scores.cpu().numpy()
        else:
            exportable_anomaly_scores = anomaly_scores

        csv_file = self._report_file("test_predictions.csv")
        with open(csv_file, mode="w", newline="") as file:
            writer = csv.writer(file)
            writer.writerow(["image_path", "predicted_label", "ground_truth_label", "predicted_score"])
            writer.writerows(zip(image_paths, pred_labels, gt_labels, exportable_anomaly_scores))

        log.info("CSV file %s has been created.", csv_file)

        if not isinstance(anomaly_scores, torch.Tensor):
            raise ValueError("Anomaly scores must be a tensor")

        scores_cpu = anomaly_scores.cpu()
        # A CPU boolean mask works whether labels come as a tensor (on any device) or as a list
        binary_gt = torch.as_tensor(all_output_flatten["label"]).cpu()
        good_scores = scores_cpu[binary_gt == 0]
        defect_scores = scores_cpu[binary_gt == 1]

        # Lightning has a callback attribute but is not inside the __init__ so mypy complains
        if any(
            isinstance(x, MinMaxNormalizationCallback)
            for x in self.trainer.callbacks  # type: ignore[attr-defined]
        ):
            threshold = torch.tensor(0.5)
        elif any(
            isinstance(x, ThresholdNormalizationCallback)
            for x in self.trainer.callbacks  # type: ignore[attr-defined]
        ):
            threshold = torch.tensor(100.0)
        else:
            threshold = self.module.image_metrics.F1Score.threshold

        # The output of the prediction is a normalized score so the cumulative histogram is displayed with the
        # normalized scores
        plot_cumulative_histogram(good_scores.numpy(), defect_scores.numpy(), threshold.item(), self.report_path)

        _, pd_cm, _ = get_results(np.array(gt_labels), np.array(pred_labels), idx_to_class)
        disp = ConfusionMatrixDisplay(
            confusion_matrix=np.array(pd_cm),
            display_labels=list(class_to_idx.keys()),
        )
        disp.plot(include_values=True, cmap=plt.cm.Greens, ax=None, colorbar=False, xticks_rotation=90)
        plt.title("Confusion Matrix")
        plt.savefig(self._report_file("test_confusion_matrix.png"), bbox_inches="tight", pad_inches=0, dpi=300)
        plt.close()

        score_sums = dict.fromkeys(set(named_labels), 0.0)
        for label_name, score in zip(named_labels, scores_cpu):
            score_sums[label_name] += score.item()

        counter = Counter(named_labels)
        avg_score_dict = {k: v / counter[k] for k, v in score_sums.items()}
        avg_score_dict = dict(sorted(avg_score_dict.items(), key=lambda q: q[1]))

        with open(self._report_file("avg_score_by_label.csv"), "w") as f:
            f.write("label,avg_anomaly_score\n")
            f.writelines(f"{k},{v:.3f}\n" for k, v in avg_score_dict.items())

    def generate_report(self):
        """Generate a report for the task and try to upload artifacts."""
        self._generate_report()
        self._upload_artifacts()

    def _visualizer_output_path(self) -> str | None:
        """Return the visualizer callback output folder, or None when no visualizer callback is configured."""
        callbacks = self.config.get("callbacks")
        if callbacks is None or "visualizer" not in callbacks:
            return None
        return callbacks.visualizer.output_path

    def _upload_artifacts(self):
        """Upload report artifacts to MLflow and/or TensorBoard when those loggers are available.

        Uploads only happen when ``config.core.upload_artifacts`` is truthy.
        """
        mlflow_logger = get_mlflow_logger(trainer=self.trainer)
        tensorboard_logger = utils.get_tensorboard_logger(trainer=self.trainer)
        upload_artifacts = self.config.core.get("upload_artifacts")

        # The report files live in report_path, not necessarily in the current working directory
        report_files = [self._report_file("test_confusion_matrix.png"), self._report_file("avg_score_by_label.csv")]
        visualizer_output_path = self._visualizer_output_path()

        if mlflow_logger is not None and upload_artifacts:
            for path in report_files:
                mlflow_logger.experiment.log_artifact(run_id=mlflow_logger.run_id, local_path=path)

            if visualizer_output_path is not None:
                # Non-recursive glob: matches the entries exactly one level below the visualizer output folder
                for a in glob.glob(os.path.join(visualizer_output_path, "**", "*")):
                    if os.path.isdir(a):
                        # log_artifact expects a single file
                        continue
                    mlflow_logger.experiment.log_artifact(
                        run_id=mlflow_logger.run_id, local_path=a, artifact_path="anomaly_output"
                    )

        if tensorboard_logger is not None and upload_artifacts:
            artifacts = list(report_files)

            if visualizer_output_path is not None:
                artifacts.extend(glob.glob(os.path.join(visualizer_output_path, "**/*"), recursive=True))

            for a in artifacts:
                if os.path.isdir(a):
                    continue

                ext = os.path.splitext(a)[1].lower()

                if ext in [".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".tif", ".gif"]:
                    try:
                        img = cv2.imread(a)
                        # cv2.imread returns None on unreadable files, which makes cvtColor raise cv2.error
                        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                    except cv2.error:
                        log.info("Could not upload artifact image %s", a)
                        continue
                    output_path = os.path.sep.join(a.split(os.path.sep)[-2:])
                    tensorboard_logger.experiment.add_image(output_path, img, 0, dataformats="HWC")
                else:
                    utils.upload_file_tensorboard(a, tensorboard_logger)


+ class AnomalibEvaluation(Evaluation[AnomalyDataModule]):
331
+ """Evaluation task for Anomalib.
332
+
333
+ Args:
334
+ config: Task configuration
335
+ model_path: Path to the model folder that contains an exported model
336
+ use_training_threshold: Whether to use the training threshold for the evaluation or use the one that
337
+ maximizes the F1 score on the test set.
338
+ device: Device to use for evaluation. If None, the device is automatically determined.
339
+ """
340
+
341
+ def __init__(
342
+ self,
343
+ config: DictConfig,
344
+ model_path: str,
345
+ use_training_threshold: bool = False,
346
+ device: str | None = None,
347
+ training_threshold_type: Literal["image", "pixel"] | None = None,
348
+ ):
349
+ super().__init__(config=config, model_path=model_path, device=device)
350
+
351
+ self.use_training_threshold = use_training_threshold
352
+
353
+ if training_threshold_type is not None and training_threshold_type not in ["image", "pixel"]:
354
+ raise ValueError("Training threshold type must be either image or pixel")
355
+
356
+ if training_threshold_type is None and use_training_threshold:
357
+ log.warning("Using training threshold but no training threshold type is provided, defaulting to image")
358
+ training_threshold_type = "image"
359
+
360
+ self.training_threshold_type = training_threshold_type
361
+
362
+ def prepare(self) -> None:
363
+ """Prepare the evaluation."""
364
+ super().prepare()
365
+ self.datamodule = self.config.datamodule
366
+ # prepare_data() must be explicitly called because there is no lightning training
367
+ self.datamodule.prepare_data()
368
+ self.datamodule.setup(stage="test")
369
+
370
@automatic_datamodule_batch_size(batch_size_attribute_name="test_batch_size")
def test(self) -> None:
    """Run the deployment model on the test set and compute the reporting threshold and F1 score.

    Images whose label is -1 have no ground truth. They are still scored, but are
    excluded from the F1 computation.

    Populates ``self.metadata`` with:
        anomaly_scores: concatenated anomaly scores of every test image (CPU tensor).
        anomaly_maps: concatenated anomaly maps of every test image (CPU tensor).
        image_labels: ground-truth labels in dataloader order (-1 when unknown).
        image_paths: image paths in dataloader order.
        threshold: threshold used for reporting, as a Python number.
        optimal_f1: F1 score obtained with that threshold (0 when no ground truth is available).

    Raises:
        RuntimeError: If the test dataloader does not yield any sample.
    """
    log.info("Running test")
    test_dataloader = self.datamodule.test_dataloader()

    optimal_f1 = OptimalF1(num_classes=None, pos_label=1)  # type: ignore[arg-type]

    # Per-batch outputs, concatenated once the whole test set has been processed
    score_batches: list[torch.Tensor] = []
    map_batches: list[torch.Tensor] = []
    image_labels = []
    image_paths = []

    with torch.no_grad():
        for batch_item in tqdm(test_dataloader):
            batch_images = batch_item["image"]
            batch_labels = batch_item["label"]
            image_labels.extend(batch_labels.tolist())
            image_paths.extend(batch_item["image_path"])
            batch_images = batch_images.to(device=self.device, dtype=self.deployment_model.model_dtype)
            if self.model_data.get("anomaly_method") == "efficientad":
                # The efficientad deployment model is called with an extra positional argument (None)
                model_output = self.deployment_model(batch_images, None)
            else:
                model_output = self.deployment_model(batch_images)
            anomaly_map = model_output[0].cpu()
            anomaly_score = model_output[1].cpu()
            known_labels = torch.where(batch_labels != -1)[0]
            if len(known_labels) > 0:
                # Skip computing F1 score for images without gt
                optimal_f1.update(anomaly_score[known_labels], batch_labels[known_labels])
            score_batches.append(anomaly_score)
            map_batches.append(anomaly_map)

    # torch.cat on an empty list fails with an unhelpful message, report the real cause instead
    if not score_batches:
        raise RuntimeError("The test dataloader did not yield any sample, unable to run the evaluation")

    anomaly_scores = torch.cat(score_batches)
    anomaly_maps = torch.cat(map_batches)

    if any(label != -1 for label in image_labels):
        if self.use_training_threshold:
            # Evaluate F1 at the fixed threshold stored at training time, on labelled images only
            labels_tensor = torch.tensor(image_labels)
            threshold = torch.tensor(float(self.model_data[f"{self.training_threshold_type}_threshold"]))
            known_labels = torch.where(labels_tensor != -1)[0]

            known_image_labels = labels_tensor[known_labels]
            known_anomaly_scores = anomaly_scores[known_labels]

            pred_labels = (known_anomaly_scores >= threshold).long()

            optimal_f1_score = torch.tensor(f1_score(known_image_labels, pred_labels))
        else:
            # Threshold and score come from the OptimalF1 metric accumulated above
            optimal_f1_score = optimal_f1.compute()
            threshold = optimal_f1.threshold
    else:
        log.warning("No ground truth available during evaluation, use training image threshold for reporting")
        optimal_f1_score = torch.tensor(0)
        threshold = torch.tensor(float(self.model_data["image_threshold"]))

    log.info("Computed F1 score: %s", optimal_f1_score.item())
    self.metadata["anomaly_scores"] = anomaly_scores
    self.metadata["anomaly_maps"] = anomaly_maps
    self.metadata["image_labels"] = image_labels
    self.metadata["image_paths"] = image_paths
    self.metadata["threshold"] = threshold.item()
    self.metadata["optimal_f1"] = optimal_f1_score.item()
434
+
435
def generate_report(self) -> None:
    """Generate the evaluation report from the results stored in ``self.metadata`` by ``test``.

    Scores, maps and the reporting threshold are normalized against the training threshold
    stored in ``self.model_data``. The following outputs are written under ``self.report_path``:
        - the output of ``plot_cumulative_histogram`` for good vs defect scores,
        - ``predictions/<parent folder>/<image name>.png``: binary prediction masks,
        - ``heatmaps/<parent folder>/<image name>.png``: colored anomaly heatmaps,
        - ``anomaly_test_output.json``: per-image observations, thresholds, F1 score,
          number of overlapping scores and the confusion matrix.

    Raises:
        ValueError: If the normalized scores are not a numpy array or the normalized maps
            are not a tensor.
        FileNotFoundError: If the valid area mask or a test image cannot be read.
    """
    log.info("Generating report")
    if len(self.report_path) > 0:
        os.makedirs(self.report_path, exist_ok=True)

    # TODO: We currently don't use anomaly for segmentation, so the pixel threshold handling is not properly
    # implemented and we produce as output only a single threshold.
    # training_threshold_type may be None when use_training_threshold is False; formatting it would look up
    # the non-existent "None_threshold" key, so fall back to the image threshold (the default used elsewhere).
    threshold_type = self.training_threshold_type if self.training_threshold_type is not None else "image"
    # Coerce to float as test() does for the same model_data entries
    training_threshold = float(self.model_data[f"{threshold_type}_threshold"])
    optimal_threshold = self.metadata["threshold"]

    normalized_optimal_threshold = cast(float, normalize_anomaly_score(optimal_threshold, training_threshold))

    os.makedirs(os.path.join(self.report_path, "predictions"), exist_ok=True)
    os.makedirs(os.path.join(self.report_path, "heatmaps"), exist_ok=True)

    anomaly_scores = normalize_anomaly_score(self.metadata["anomaly_scores"].cpu().numpy(), training_threshold)

    if not isinstance(anomaly_scores, np.ndarray):
        raise ValueError("Anomaly scores must be a numpy array")

    labels_array = np.array(self.metadata["image_labels"])
    good_scores = anomaly_scores[labels_array == 0]
    defect_scores = anomaly_scores[labels_array == 1]

    # Number of scores falling in the range where good and defect scores overlap
    count_overlapping_scores = 0
    if len(good_scores) != 0 and len(defect_scores) != 0 and defect_scores.min() <= good_scores.max():
        count_overlapping_scores = int(
            np.count_nonzero((anomaly_scores >= defect_scores.min()) & (anomaly_scores <= good_scores.max()))
        )

    plot_cumulative_histogram(good_scores, defect_scores, normalized_optimal_threshold, self.report_path)

    json_output = {
        "observations": [],
        "threshold": np.round(normalized_optimal_threshold, 3),
        "unnormalized_threshold": np.round(optimal_threshold, 3),
        "f1_score": np.round(self.metadata["optimal_f1"], 3),
        "metrics": {
            "overlapping_scores": count_overlapping_scores,
        },
    }

    # Confusion matrix counters: true good, false bad, false good, true bad
    tg, fb, fg, tb = 0, 0, 0, 0

    mask_area = None
    crop_area = None

    valid_area_mask_path = getattr(self.datamodule, "valid_area_mask", None)
    if valid_area_mask_path is not None:
        mask_area = cv2.imread(valid_area_mask_path, 0)
        if mask_area is None:
            raise FileNotFoundError(f"Unable to read the valid area mask {valid_area_mask_path}")
        mask_area = (mask_area > 0).astype(np.uint8)  # type: ignore[operator]

    if getattr(self.datamodule, "crop_area", None) is not None:
        crop_area = self.datamodule.crop_area

    anomaly_maps = normalize_anomaly_score(self.metadata["anomaly_maps"], training_threshold)

    if not isinstance(anomaly_maps, torch.Tensor):
        raise ValueError("Anomaly maps must be a tensor")

    for img_path, gt_label, anomaly_score, anomaly_map in tqdm(
        zip(
            self.metadata["image_paths"],
            self.metadata["image_labels"],
            anomaly_scores,
            anomaly_maps,
        ),
        total=len(self.metadata["image_paths"]),
    ):
        img = cv2.imread(img_path, 0)
        # cv2.imread returns None instead of raising when the file cannot be read
        if img is None:
            raise FileNotFoundError(f"Unable to read image {img_path}")
        if mask_area is not None:
            img = img * mask_area  # type: ignore[operator]

        if crop_area is not None:
            img = img[crop_area[1] : crop_area[3], crop_area[0] : crop_area[2]]

        output_mask = (anomaly_map >= normalized_optimal_threshold).cpu().numpy().squeeze().astype(np.uint8)
        # Outputs are grouped by the name of the folder containing the image
        output_mask_label = os.path.basename(os.path.dirname(img_path))
        output_mask_name = os.path.splitext(os.path.basename(img_path))[0] + ".png"
        pred_label = int(anomaly_score >= normalized_optimal_threshold)

        json_output["observations"].append(
            {
                "image_path": os.path.dirname(img_path),
                "file_name": os.path.basename(img_path),
                "expectation": gt_label if gt_label != -1 else "",
                "prediction": pred_label,
                "prediction_mask": os.path.join("predictions", output_mask_label, output_mask_name),
                "prediction_heatmap": os.path.join("heatmaps", output_mask_label, output_mask_name),
                "is_correct": pred_label == gt_label if gt_label != -1 else True,
                "anomaly_score": f"{anomaly_score.item():.3f}",
            }
        )

        if gt_label == 0 and pred_label == 0:
            tg += 1
        elif gt_label == 0 and pred_label == 1:
            fb += 1
        elif gt_label == 1 and pred_label == 0:
            fg += 1
        elif gt_label == 1 and pred_label == 1:
            tb += 1

        output_mask = output_mask * 255
        output_mask = cv2.resize(output_mask, (img.shape[1], img.shape[0]))
        output_prediction_folder = os.path.join(self.report_path, "predictions", output_mask_label)
        os.makedirs(output_prediction_folder, exist_ok=True)
        cv2.imwrite(os.path.join(output_prediction_folder, output_mask_name), output_mask)

        # Normalize the map and rescale it to 0-1 range
        # In this case we are saying that the anomaly map is in the range [normalized_th - 50, normalized_th + 50]
        # This allow to have a stronger color for the anomalies and a lighter one for really normal regions
        # It's also independent from the max or min anomaly score!
        normalized_map: MapOrValue = (anomaly_map - (normalized_optimal_threshold - 50)) / 100

        if isinstance(normalized_map, torch.Tensor):
            normalized_map = normalized_map.cpu().numpy().squeeze()

        normalized_map = np.clip(normalized_map, 0, 1)
        output_heatmap = anomaly_map_to_color_map(normalized_map, normalize=False)
        output_heatmap = cv2.resize(output_heatmap, (img.shape[1], img.shape[0]))

        output_heatmap_folder = os.path.join(self.report_path, "heatmaps", output_mask_label)
        os.makedirs(output_heatmap_folder, exist_ok=True)

        cv2.imwrite(
            os.path.join(output_heatmap_folder, output_mask_name),
            cv2.cvtColor(output_heatmap, cv2.COLOR_RGB2BGR),
        )

    json_output["metrics"]["confusion_matrix"] = {
        "class_labels": ["normal", "anomaly"],
        "matrix": [
            [tg, fb],
            [fg, tb],
        ],
    }

    with open(os.path.join(self.report_path, "anomaly_test_output.json"), "w", encoding="utf-8") as f:
        json.dump(json_output, f)
576
+
577
def execute(self) -> None:
    """Execute the evaluation.

    Runs, in order: prepare, test, generate_report and finalize.
    """
    # Each step is looked up right before it is called, so the order and resolution match direct calls
    for step_name in ("prepare", "test", "generate_report", "finalize"):
        getattr(self, step_name)()