quadra 0.0.1__py3-none-any.whl → 2.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (302) hide show
  1. hydra_plugins/quadra_searchpath_plugin.py +30 -0
  2. quadra/__init__.py +6 -0
  3. quadra/callbacks/__init__.py +0 -0
  4. quadra/callbacks/anomalib.py +289 -0
  5. quadra/callbacks/lightning.py +501 -0
  6. quadra/callbacks/mlflow.py +291 -0
  7. quadra/callbacks/scheduler.py +69 -0
  8. quadra/configs/__init__.py +0 -0
  9. quadra/configs/backbone/caformer_m36.yaml +8 -0
  10. quadra/configs/backbone/caformer_s36.yaml +8 -0
  11. quadra/configs/backbone/convnextv2_base.yaml +8 -0
  12. quadra/configs/backbone/convnextv2_femto.yaml +8 -0
  13. quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
  14. quadra/configs/backbone/dino_vitb8.yaml +12 -0
  15. quadra/configs/backbone/dino_vits8.yaml +12 -0
  16. quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
  17. quadra/configs/backbone/dinov2_vits14.yaml +12 -0
  18. quadra/configs/backbone/efficientnet_b0.yaml +8 -0
  19. quadra/configs/backbone/efficientnet_b1.yaml +8 -0
  20. quadra/configs/backbone/efficientnet_b2.yaml +8 -0
  21. quadra/configs/backbone/efficientnet_b3.yaml +8 -0
  22. quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
  23. quadra/configs/backbone/levit_128s.yaml +8 -0
  24. quadra/configs/backbone/mnasnet0_5.yaml +9 -0
  25. quadra/configs/backbone/resnet101.yaml +8 -0
  26. quadra/configs/backbone/resnet18.yaml +8 -0
  27. quadra/configs/backbone/resnet18_ssl.yaml +8 -0
  28. quadra/configs/backbone/resnet50.yaml +8 -0
  29. quadra/configs/backbone/smp.yaml +9 -0
  30. quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
  31. quadra/configs/backbone/unetr.yaml +15 -0
  32. quadra/configs/backbone/vit16_base.yaml +9 -0
  33. quadra/configs/backbone/vit16_small.yaml +9 -0
  34. quadra/configs/backbone/vit16_tiny.yaml +9 -0
  35. quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
  36. quadra/configs/callbacks/all.yaml +45 -0
  37. quadra/configs/callbacks/default.yaml +34 -0
  38. quadra/configs/callbacks/default_anomalib.yaml +64 -0
  39. quadra/configs/config.yaml +33 -0
  40. quadra/configs/core/default.yaml +11 -0
  41. quadra/configs/datamodule/base/anomaly.yaml +16 -0
  42. quadra/configs/datamodule/base/classification.yaml +21 -0
  43. quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
  44. quadra/configs/datamodule/base/segmentation.yaml +18 -0
  45. quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
  46. quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
  47. quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
  48. quadra/configs/datamodule/base/ssl.yaml +21 -0
  49. quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
  50. quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
  51. quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
  52. quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
  53. quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
  54. quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
  55. quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
  56. quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
  57. quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
  58. quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
  59. quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
  60. quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
  61. quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
  62. quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
  63. quadra/configs/experiment/base/classification/classification.yaml +73 -0
  64. quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
  65. quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
  66. quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
  67. quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
  68. quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
  69. quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
  70. quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
  71. quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
  72. quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
  73. quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
  74. quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
  75. quadra/configs/experiment/base/ssl/byol.yaml +43 -0
  76. quadra/configs/experiment/base/ssl/dino.yaml +46 -0
  77. quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
  78. quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
  79. quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
  80. quadra/configs/experiment/custom/cls.yaml +12 -0
  81. quadra/configs/experiment/default.yaml +15 -0
  82. quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
  83. quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
  84. quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
  85. quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
  86. quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
  87. quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
  88. quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
  89. quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
  90. quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
  91. quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
  92. quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
  93. quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
  94. quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
  95. quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
  96. quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
  97. quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
  98. quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
  99. quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
  100. quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
  101. quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
  102. quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
  103. quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
  104. quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
  105. quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
  106. quadra/configs/export/default.yaml +13 -0
  107. quadra/configs/hydra/anomaly_custom.yaml +15 -0
  108. quadra/configs/hydra/default.yaml +14 -0
  109. quadra/configs/inference/default.yaml +26 -0
  110. quadra/configs/logger/comet.yaml +10 -0
  111. quadra/configs/logger/csv.yaml +5 -0
  112. quadra/configs/logger/mlflow.yaml +12 -0
  113. quadra/configs/logger/tensorboard.yaml +8 -0
  114. quadra/configs/loss/asl.yaml +7 -0
  115. quadra/configs/loss/barlow.yaml +2 -0
  116. quadra/configs/loss/bce.yaml +1 -0
  117. quadra/configs/loss/byol.yaml +1 -0
  118. quadra/configs/loss/cross_entropy.yaml +1 -0
  119. quadra/configs/loss/dino.yaml +8 -0
  120. quadra/configs/loss/simclr.yaml +2 -0
  121. quadra/configs/loss/simsiam.yaml +1 -0
  122. quadra/configs/loss/smp_ce.yaml +3 -0
  123. quadra/configs/loss/smp_dice.yaml +2 -0
  124. quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
  125. quadra/configs/loss/smp_mcc.yaml +2 -0
  126. quadra/configs/loss/vicreg.yaml +5 -0
  127. quadra/configs/model/anomalib/cfa.yaml +35 -0
  128. quadra/configs/model/anomalib/cflow.yaml +30 -0
  129. quadra/configs/model/anomalib/csflow.yaml +34 -0
  130. quadra/configs/model/anomalib/dfm.yaml +19 -0
  131. quadra/configs/model/anomalib/draem.yaml +29 -0
  132. quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
  133. quadra/configs/model/anomalib/fastflow.yaml +32 -0
  134. quadra/configs/model/anomalib/padim.yaml +32 -0
  135. quadra/configs/model/anomalib/patchcore.yaml +36 -0
  136. quadra/configs/model/barlow.yaml +16 -0
  137. quadra/configs/model/byol.yaml +25 -0
  138. quadra/configs/model/classification.yaml +10 -0
  139. quadra/configs/model/dino.yaml +26 -0
  140. quadra/configs/model/logistic_regression.yaml +4 -0
  141. quadra/configs/model/multilabel_classification.yaml +9 -0
  142. quadra/configs/model/simclr.yaml +18 -0
  143. quadra/configs/model/simsiam.yaml +24 -0
  144. quadra/configs/model/smp.yaml +4 -0
  145. quadra/configs/model/smp_multiclass.yaml +4 -0
  146. quadra/configs/model/vicreg.yaml +16 -0
  147. quadra/configs/optimizer/adam.yaml +5 -0
  148. quadra/configs/optimizer/adamw.yaml +3 -0
  149. quadra/configs/optimizer/default.yaml +4 -0
  150. quadra/configs/optimizer/lars.yaml +8 -0
  151. quadra/configs/optimizer/sgd.yaml +4 -0
  152. quadra/configs/scheduler/default.yaml +5 -0
  153. quadra/configs/scheduler/rop.yaml +5 -0
  154. quadra/configs/scheduler/step.yaml +3 -0
  155. quadra/configs/scheduler/warmrestart.yaml +2 -0
  156. quadra/configs/scheduler/warmup.yaml +6 -0
  157. quadra/configs/task/anomalib/cfa.yaml +5 -0
  158. quadra/configs/task/anomalib/cflow.yaml +5 -0
  159. quadra/configs/task/anomalib/csflow.yaml +5 -0
  160. quadra/configs/task/anomalib/draem.yaml +5 -0
  161. quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
  162. quadra/configs/task/anomalib/fastflow.yaml +5 -0
  163. quadra/configs/task/anomalib/inference.yaml +3 -0
  164. quadra/configs/task/anomalib/padim.yaml +5 -0
  165. quadra/configs/task/anomalib/patchcore.yaml +5 -0
  166. quadra/configs/task/classification.yaml +6 -0
  167. quadra/configs/task/classification_evaluation.yaml +6 -0
  168. quadra/configs/task/default.yaml +1 -0
  169. quadra/configs/task/segmentation.yaml +9 -0
  170. quadra/configs/task/segmentation_evaluation.yaml +3 -0
  171. quadra/configs/task/sklearn_classification.yaml +13 -0
  172. quadra/configs/task/sklearn_classification_patch.yaml +11 -0
  173. quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
  174. quadra/configs/task/sklearn_classification_test.yaml +8 -0
  175. quadra/configs/task/ssl.yaml +2 -0
  176. quadra/configs/trainer/lightning_cpu.yaml +36 -0
  177. quadra/configs/trainer/lightning_gpu.yaml +35 -0
  178. quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
  179. quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
  180. quadra/configs/trainer/lightning_multigpu.yaml +37 -0
  181. quadra/configs/trainer/sklearn_classification.yaml +7 -0
  182. quadra/configs/transforms/byol.yaml +47 -0
  183. quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
  184. quadra/configs/transforms/default.yaml +37 -0
  185. quadra/configs/transforms/default_numpy.yaml +24 -0
  186. quadra/configs/transforms/default_resize.yaml +22 -0
  187. quadra/configs/transforms/dino.yaml +63 -0
  188. quadra/configs/transforms/linear_eval.yaml +18 -0
  189. quadra/datamodules/__init__.py +20 -0
  190. quadra/datamodules/anomaly.py +180 -0
  191. quadra/datamodules/base.py +375 -0
  192. quadra/datamodules/classification.py +1003 -0
  193. quadra/datamodules/generic/__init__.py +0 -0
  194. quadra/datamodules/generic/imagenette.py +144 -0
  195. quadra/datamodules/generic/mnist.py +81 -0
  196. quadra/datamodules/generic/mvtec.py +58 -0
  197. quadra/datamodules/generic/oxford_pet.py +163 -0
  198. quadra/datamodules/patch.py +190 -0
  199. quadra/datamodules/segmentation.py +742 -0
  200. quadra/datamodules/ssl.py +140 -0
  201. quadra/datasets/__init__.py +17 -0
  202. quadra/datasets/anomaly.py +287 -0
  203. quadra/datasets/classification.py +241 -0
  204. quadra/datasets/patch.py +138 -0
  205. quadra/datasets/segmentation.py +239 -0
  206. quadra/datasets/ssl.py +110 -0
  207. quadra/losses/__init__.py +0 -0
  208. quadra/losses/classification/__init__.py +6 -0
  209. quadra/losses/classification/asl.py +83 -0
  210. quadra/losses/classification/focal.py +320 -0
  211. quadra/losses/classification/prototypical.py +148 -0
  212. quadra/losses/ssl/__init__.py +17 -0
  213. quadra/losses/ssl/barlowtwins.py +47 -0
  214. quadra/losses/ssl/byol.py +37 -0
  215. quadra/losses/ssl/dino.py +129 -0
  216. quadra/losses/ssl/hyperspherical.py +45 -0
  217. quadra/losses/ssl/idmm.py +50 -0
  218. quadra/losses/ssl/simclr.py +67 -0
  219. quadra/losses/ssl/simsiam.py +30 -0
  220. quadra/losses/ssl/vicreg.py +76 -0
  221. quadra/main.py +49 -0
  222. quadra/metrics/__init__.py +3 -0
  223. quadra/metrics/segmentation.py +251 -0
  224. quadra/models/__init__.py +0 -0
  225. quadra/models/base.py +151 -0
  226. quadra/models/classification/__init__.py +8 -0
  227. quadra/models/classification/backbones.py +149 -0
  228. quadra/models/classification/base.py +92 -0
  229. quadra/models/evaluation.py +322 -0
  230. quadra/modules/__init__.py +0 -0
  231. quadra/modules/backbone.py +30 -0
  232. quadra/modules/base.py +312 -0
  233. quadra/modules/classification/__init__.py +3 -0
  234. quadra/modules/classification/base.py +327 -0
  235. quadra/modules/ssl/__init__.py +17 -0
  236. quadra/modules/ssl/barlowtwins.py +59 -0
  237. quadra/modules/ssl/byol.py +172 -0
  238. quadra/modules/ssl/common.py +285 -0
  239. quadra/modules/ssl/dino.py +186 -0
  240. quadra/modules/ssl/hyperspherical.py +206 -0
  241. quadra/modules/ssl/idmm.py +98 -0
  242. quadra/modules/ssl/simclr.py +73 -0
  243. quadra/modules/ssl/simsiam.py +68 -0
  244. quadra/modules/ssl/vicreg.py +67 -0
  245. quadra/optimizers/__init__.py +4 -0
  246. quadra/optimizers/lars.py +153 -0
  247. quadra/optimizers/sam.py +127 -0
  248. quadra/schedulers/__init__.py +3 -0
  249. quadra/schedulers/base.py +44 -0
  250. quadra/schedulers/warmup.py +127 -0
  251. quadra/tasks/__init__.py +24 -0
  252. quadra/tasks/anomaly.py +582 -0
  253. quadra/tasks/base.py +397 -0
  254. quadra/tasks/classification.py +1263 -0
  255. quadra/tasks/patch.py +492 -0
  256. quadra/tasks/segmentation.py +389 -0
  257. quadra/tasks/ssl.py +560 -0
  258. quadra/trainers/README.md +3 -0
  259. quadra/trainers/__init__.py +0 -0
  260. quadra/trainers/classification.py +179 -0
  261. quadra/utils/__init__.py +0 -0
  262. quadra/utils/anomaly.py +112 -0
  263. quadra/utils/classification.py +618 -0
  264. quadra/utils/deprecation.py +31 -0
  265. quadra/utils/evaluation.py +474 -0
  266. quadra/utils/export.py +585 -0
  267. quadra/utils/imaging.py +32 -0
  268. quadra/utils/logger.py +15 -0
  269. quadra/utils/mlflow.py +98 -0
  270. quadra/utils/model_manager.py +320 -0
  271. quadra/utils/models.py +523 -0
  272. quadra/utils/patch/__init__.py +15 -0
  273. quadra/utils/patch/dataset.py +1433 -0
  274. quadra/utils/patch/metrics.py +449 -0
  275. quadra/utils/patch/model.py +153 -0
  276. quadra/utils/patch/visualization.py +217 -0
  277. quadra/utils/resolver.py +42 -0
  278. quadra/utils/segmentation.py +31 -0
  279. quadra/utils/tests/__init__.py +0 -0
  280. quadra/utils/tests/fixtures/__init__.py +1 -0
  281. quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
  282. quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
  283. quadra/utils/tests/fixtures/dataset/classification.py +406 -0
  284. quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
  285. quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
  286. quadra/utils/tests/fixtures/models/__init__.py +3 -0
  287. quadra/utils/tests/fixtures/models/anomaly.py +89 -0
  288. quadra/utils/tests/fixtures/models/classification.py +45 -0
  289. quadra/utils/tests/fixtures/models/segmentation.py +33 -0
  290. quadra/utils/tests/helpers.py +70 -0
  291. quadra/utils/tests/models.py +27 -0
  292. quadra/utils/utils.py +525 -0
  293. quadra/utils/validator.py +115 -0
  294. quadra/utils/visualization.py +422 -0
  295. quadra/utils/vit_explainability.py +349 -0
  296. quadra-2.2.7.dist-info/LICENSE +201 -0
  297. quadra-2.2.7.dist-info/METADATA +381 -0
  298. quadra-2.2.7.dist-info/RECORD +300 -0
  299. {quadra-0.0.1.dist-info → quadra-2.2.7.dist-info}/WHEEL +1 -1
  300. quadra-2.2.7.dist-info/entry_points.txt +3 -0
  301. quadra-0.0.1.dist-info/METADATA +0 -14
  302. quadra-0.0.1.dist-info/RECORD +0 -4
@@ -0,0 +1,179 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import cast
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+ import torch
8
+ from numpy import ndarray
9
+ from pandas import DataFrame
10
+ from sklearn.linear_model import LogisticRegression
11
+ from sklearn.linear_model._base import ClassifierMixin
12
+ from torch.utils.data import DataLoader
13
+
14
+ from quadra.utils import utils
15
+ from quadra.utils.classification import get_results
16
+ from quadra.utils.models import get_feature
17
+
18
# Module-level logger created through quadra's logging helper.
log = utils.get_logger(__name__)
19
+
20
+
21
class SklearnClassificationTrainer:
    """Class to configure and run a classification using torch for feature extraction and sklearn to fit a classifier.

    Args:
        input_shape: [H, W, C]
        backbone: the feature extractor
        random_state: seed to fix randomness
        classifier: classification model. It is either an estimator class, which is instantiated with
            ``max_iter`` and ``random_state``, or an already-built estimator, which is used as provided.
        iteration_over_training: the number of iteration over training during feature extraction
    """

    def __init__(
        self,
        input_shape: list,
        backbone: torch.nn.Module,
        random_state: int = 42,
        classifier: ClassifierMixin | type[ClassifierMixin] = LogisticRegression,
        iteration_over_training: int = 1,
    ) -> None:
        super().__init__()

        try:
            # `max_iter` must be an integer: recent sklearn versions reject floats (e.g. 1e4)
            # during parameter validation when `fit` is called.
            self.classifier = classifier(max_iter=10_000, random_state=random_state)
        except TypeError:
            # `classifier` is an already-built estimator (not callable) or does not accept
            # these keyword arguments: keep it as provided.
            self.classifier = classifier

        self.input_shape = input_shape
        self.random_state = random_state
        self.iteration_over_training = iteration_over_training
        self.backbone = backbone

    def change_backbone(self, backbone: torch.nn.Module):
        """Update feature extractor and switch it to evaluation mode."""
        self.backbone = backbone
        self.backbone.eval()

    def change_classifier(self, classifier: ClassifierMixin):
        """Update classifier."""
        self.classifier = classifier

    def fit(
        self,
        train_dataloader: DataLoader | None = None,
        train_features: ndarray | None = None,
        train_labels: ndarray | None = None,
    ):
        """Fit classifier on training set.

        Args:
            train_dataloader: Dataloader used to extract training features with the backbone. When None,
                the cached ``train_features`` and ``train_labels`` are used instead.
            train_features: Cached training features, used only when ``train_dataloader`` is None.
            train_labels: Cached training labels, used only when ``train_dataloader`` is None.

        Raises:
            AssertionError: If no backbone is set, or if cached data is requested but features or labels are missing.
        """
        if self.backbone is None:
            raise AssertionError("You must set a model before running execution")

        if train_dataloader is not None:
            log.info("Extracting features from training set")
            train_features, train_labels, _ = get_feature(
                feature_extractor=self.backbone,
                dl=train_dataloader,
                iteration_over_training=self.iteration_over_training,
                gradcam=False,
            )
        else:
            log.info("Using cached features for training set")
            # With the current implementation cached features are not sorted
            # Even though it doesn't seem to change anything
            if train_features is None or train_labels is None:
                raise AssertionError("Train features and labels must be provided when using cached data")
            # Shuffle deterministically (seeded by `random_state`) so the order of cached data is reproducible
            permuted_indices = np.random.RandomState(seed=self.random_state).permutation(train_features.shape[0])
            train_features = train_features[permuted_indices]
            train_labels = train_labels[permuted_indices]

        log.info("Fitting classifier on %d features", len(train_features))  # type: ignore[arg-type]
        self.classifier.fit(train_features, train_labels)

    def test(
        self,
        test_dataloader: DataLoader,
        test_labels: ndarray | None = None,
        test_features: ndarray | None = None,
        class_to_keep: list[str] | None = None,
        idx_to_class: dict[int, str] | None = None,
        predict_proba: bool = True,
        gradcam: bool = False,
    ) -> (
        tuple[str | dict, DataFrame, float, DataFrame, np.ndarray | None]
        | tuple[None, None, None, DataFrame, np.ndarray | None]
    ):
        """Test classifier on test set.

        Args:
            test_dataloader: Test dataloader; its dataset must expose an ``x`` attribute with the samples
            test_labels: test labels, required when ``test_features`` is provided
            test_features: Optional test features used when cache data is available
            class_to_keep: list of class names to keep; samples whose class (looked up through
                ``idx_to_class``) is not listed are excluded from the metrics
            idx_to_class: dictionary mapping class index to class name
            predict_proba: if True, predict also probability for each test image
            gradcam: Whether to compute gradcam

        Returns:
            cl_rep: Classification report
            pd_cm: Confusion matrix dataframe
            accuracy: Test accuracy
            res: Test results
            cams: Gradcams

            The first three items are None when no sample survives the ``class_to_keep`` filter.

        Raises:
            ValueError: If cached features are given without labels, if ``class_to_keep`` is used without
                ``idx_to_class``, or if the test dataset has no ``x`` attribute.
        """
        cams = None
        if test_features is None:
            log.info("Extracting features from test set")
            test_features, final_test_labels, cams = get_feature(
                feature_extractor=self.backbone,
                dl=test_dataloader,
                gradcam=gradcam,
                classifier=self.classifier,
                # input_shape is stored as [H, W, C]; feature extraction expects (C, H, W)
                input_shape=(self.input_shape[2], self.input_shape[0], self.input_shape[1]),
            )
        else:
            if test_labels is None:
                raise ValueError("Test labels must be provided when using cached data")
            log.info("Using cached features for test set")
            final_test_labels = test_labels

        log.info("Predict classifier on test set")
        test_prediction_label = self.classifier.predict(test_features)
        test_probability = None
        if predict_proba:
            # Keep only the probability of the most likely class for each sample
            test_probability = self.classifier.predict_proba(test_features).max(axis=1)

        if class_to_keep is not None:
            if idx_to_class is None:
                raise ValueError("You must provide `idx_to_class` and `test_labels` when using `class_to_keep`")
            # Labels of discarded classes are replaced with -1 and excluded from the metrics below
            filtered_test_labels = [int(x) if idx_to_class[x] in class_to_keep else -1 for x in final_test_labels]
        else:
            filtered_test_labels = cast(list[int], final_test_labels.tolist())

        if not hasattr(test_dataloader.dataset, "x"):
            raise ValueError("Current dataset doesn't provide an `x` attribute")

        res = pd.DataFrame(
            {
                "sample": list(test_dataloader.dataset.x),
                "real_label": final_test_labels,
                "pred_label": test_prediction_label,
            }
        )
        # Attach probabilities whenever they were requested, even if no sample survives the filter below
        if predict_proba:
            res["probability"] = test_probability

        if all(t == -1 for t in filtered_test_labels):
            return None, None, None, res, cams

        real_labels = np.array(filtered_test_labels)
        keep_mask = real_labels != -1
        if cams is not None:
            cams = cams[keep_mask]  # TODO: Is class_to_keep still used?
        pred_labels_cm = np.array(test_prediction_label)[keep_mask]
        test_real_label_cm = real_labels[keep_mask].astype(pred_labels_cm.dtype)
        cl_rep, pd_cm, accuracy = get_results(test_real_label_cm, pred_labels_cm, idx_to_class)

        return cl_rep, pd_cm, accuracy, res, cams
File without changes
@@ -0,0 +1,112 @@
1
+ """Anomaly Score Normalization Callback that uses min-max normalization."""
2
+
3
+ # Copyright (C) 2022 Intel Corporation
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+ from __future__ import annotations
7
+
8
+ try:
9
+ from typing import Any, TypeAlias
10
+ except ImportError:
11
+ from typing import Any
12
+
13
+ from typing_extensions import TypeAlias # noqa
14
+
15
+
16
+ # MyPy wants TypeAlias, but pylint has problems dealing with it
17
+ import numpy as np # pylint: disable=unused-import
18
+ import pytorch_lightning as pl
19
+ import torch # pylint: disable=unused-import
20
+ from anomalib.models.components import AnomalyModule
21
+ from pytorch_lightning import Callback
22
+ from pytorch_lightning.utilities.types import STEP_OUTPUT
23
+
24
# https://github.com/python/cpython/issues/90015#issuecomment-1172996118
# Union of the score types handled by the normalization below (scalar, torch tensor or numpy array).
# It is kept as a string literal, presumably so the `|` union is not evaluated at import time
# (see the issue linked above) — NOTE(review): confirm against the linked discussion.
MapOrValue: TypeAlias = "float | torch.Tensor | np.ndarray"
26
+
27
+
28
def normalize_anomaly_score(raw_score: MapOrValue, threshold: float) -> MapOrValue:
    """Rescale an anomaly score value or map relative to a detection threshold.

    A score equal to the threshold is mapped to 100. The result is then clipped to [0, 1000].

    Args:
        raw_score: Raw anomaly score value or map.
        threshold: Threshold for anomaly detection.

    Returns:
        Normalized anomaly score value or map clipped between 0 and 1000. A torch tensor input
        stays a tensor; any other input goes through ``np.clip``.
    """
    lower_bound, upper_bound = 0.0, 1000.0

    # Branch order matters: a NaN threshold must fall through to the last branch.
    if threshold > 0:
        scaled = (raw_score / threshold) * 100.0
    elif threshold == 0:
        # Degenerate threshold: shift instead of dividing by zero.
        # TODO: Is this the best way to handle this case?
        scaled = (raw_score + 1) * 100.0
    else:
        # Negative threshold: mirror around 100 so scores above the threshold still exceed 100.
        scaled = 200.0 - ((raw_score / threshold) * 100.0)

    clip_fn = torch.clamp if isinstance(scaled, torch.Tensor) else np.clip
    return clip_fn(scaled, lower_bound, upper_bound)
50
+
51
+
52
class ThresholdNormalizationCallback(Callback):
    """Callback that normalizes the image-level and pixel-level anomaly scores dividing by the threshold value.

    Scores are rescaled with ``normalize_anomaly_score``, which maps the threshold to 100.
    The test metrics are therefore thresholded at 100.

    Args:
        threshold_type: Threshold used to normalize pixel level anomaly scores, either "image" or "pixel" (default)

    Raises:
        ValueError: If ``threshold_type`` is neither "image" nor "pixel".
    """

    _VALID_THRESHOLD_TYPES = ("image", "pixel")

    def __init__(self, threshold_type: str = "pixel"):
        super().__init__()
        if threshold_type not in self._VALID_THRESHOLD_TYPES:
            raise ValueError(
                f"threshold_type must be one of {self._VALID_THRESHOLD_TYPES}, got {threshold_type!r}"
            )
        self.threshold_type = threshold_type

    def on_test_start(self, trainer: pl.Trainer, pl_module: AnomalyModule) -> None:
        """Called when the test begins; sets the metric thresholds to the normalized threshold value (100)."""
        del trainer  # `trainer` variable is not used.

        for metric in (pl_module.image_metrics, pl_module.pixel_metrics):
            if metric is not None:
                metric.set_threshold(100.0)

    def on_test_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: AnomalyModule,
        outputs: STEP_OUTPUT | None,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Called when the test batch ends, normalizes the predicted scores and anomaly maps."""
        del trainer, batch, batch_idx, dataloader_idx  # These variables are not used.

        self._normalize_batch(outputs, pl_module)

    def on_predict_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: AnomalyModule,
        outputs: Any,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Called when the predict batch ends, normalizes the predicted scores and anomaly maps."""
        del trainer, batch, batch_idx, dataloader_idx  # These variables are not used.

        self._normalize_batch(outputs, pl_module)

    def _normalize_batch(self, outputs, pl_module):
        """Normalize a batch of predictions in place.

        Args:
            outputs: Step outputs. ``pred_scores`` is required. ``anomaly_maps`` and ``box_scores``
                are normalized only when present. ``None`` outputs are left untouched.
            pl_module: Anomaly module exposing ``image_threshold`` and ``pixel_threshold``.
        """
        if outputs is None:
            return

        image_threshold = pl_module.image_threshold.value.cpu().item()
        outputs["pred_scores"] = normalize_anomaly_score(outputs["pred_scores"], image_threshold)

        # Maps and boxes use the threshold selected by `threshold_type`
        if self.threshold_type == "pixel":
            map_threshold = pl_module.pixel_threshold.value.cpu().item()
        else:
            map_threshold = image_threshold

        if "anomaly_maps" in outputs:
            outputs["anomaly_maps"] = normalize_anomaly_score(outputs["anomaly_maps"], map_threshold)

        if "box_scores" in outputs:
            outputs["box_scores"] = [normalize_anomaly_score(scores, map_threshold) for scores in outputs["box_scores"]]