quadra 0.0.1__py3-none-any.whl → 2.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (302)
  1. hydra_plugins/quadra_searchpath_plugin.py +30 -0
  2. quadra/__init__.py +6 -0
  3. quadra/callbacks/__init__.py +0 -0
  4. quadra/callbacks/anomalib.py +289 -0
  5. quadra/callbacks/lightning.py +501 -0
  6. quadra/callbacks/mlflow.py +291 -0
  7. quadra/callbacks/scheduler.py +69 -0
  8. quadra/configs/__init__.py +0 -0
  9. quadra/configs/backbone/caformer_m36.yaml +8 -0
  10. quadra/configs/backbone/caformer_s36.yaml +8 -0
  11. quadra/configs/backbone/convnextv2_base.yaml +8 -0
  12. quadra/configs/backbone/convnextv2_femto.yaml +8 -0
  13. quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
  14. quadra/configs/backbone/dino_vitb8.yaml +12 -0
  15. quadra/configs/backbone/dino_vits8.yaml +12 -0
  16. quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
  17. quadra/configs/backbone/dinov2_vits14.yaml +12 -0
  18. quadra/configs/backbone/efficientnet_b0.yaml +8 -0
  19. quadra/configs/backbone/efficientnet_b1.yaml +8 -0
  20. quadra/configs/backbone/efficientnet_b2.yaml +8 -0
  21. quadra/configs/backbone/efficientnet_b3.yaml +8 -0
  22. quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
  23. quadra/configs/backbone/levit_128s.yaml +8 -0
  24. quadra/configs/backbone/mnasnet0_5.yaml +9 -0
  25. quadra/configs/backbone/resnet101.yaml +8 -0
  26. quadra/configs/backbone/resnet18.yaml +8 -0
  27. quadra/configs/backbone/resnet18_ssl.yaml +8 -0
  28. quadra/configs/backbone/resnet50.yaml +8 -0
  29. quadra/configs/backbone/smp.yaml +9 -0
  30. quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
  31. quadra/configs/backbone/unetr.yaml +15 -0
  32. quadra/configs/backbone/vit16_base.yaml +9 -0
  33. quadra/configs/backbone/vit16_small.yaml +9 -0
  34. quadra/configs/backbone/vit16_tiny.yaml +9 -0
  35. quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
  36. quadra/configs/callbacks/all.yaml +45 -0
  37. quadra/configs/callbacks/default.yaml +34 -0
  38. quadra/configs/callbacks/default_anomalib.yaml +64 -0
  39. quadra/configs/config.yaml +33 -0
  40. quadra/configs/core/default.yaml +11 -0
  41. quadra/configs/datamodule/base/anomaly.yaml +16 -0
  42. quadra/configs/datamodule/base/classification.yaml +21 -0
  43. quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
  44. quadra/configs/datamodule/base/segmentation.yaml +18 -0
  45. quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
  46. quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
  47. quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
  48. quadra/configs/datamodule/base/ssl.yaml +21 -0
  49. quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
  50. quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
  51. quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
  52. quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
  53. quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
  54. quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
  55. quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
  56. quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
  57. quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
  58. quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
  59. quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
  60. quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
  61. quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
  62. quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
  63. quadra/configs/experiment/base/classification/classification.yaml +73 -0
  64. quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
  65. quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
  66. quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
  67. quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
  68. quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
  69. quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
  70. quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
  71. quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
  72. quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
  73. quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
  74. quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
  75. quadra/configs/experiment/base/ssl/byol.yaml +43 -0
  76. quadra/configs/experiment/base/ssl/dino.yaml +46 -0
  77. quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
  78. quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
  79. quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
  80. quadra/configs/experiment/custom/cls.yaml +12 -0
  81. quadra/configs/experiment/default.yaml +15 -0
  82. quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
  83. quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
  84. quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
  85. quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
  86. quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
  87. quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
  88. quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
  89. quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
  90. quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
  91. quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
  92. quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
  93. quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
  94. quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
  95. quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
  96. quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
  97. quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
  98. quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
  99. quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
  100. quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
  101. quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
  102. quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
  103. quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
  104. quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
  105. quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
  106. quadra/configs/export/default.yaml +13 -0
  107. quadra/configs/hydra/anomaly_custom.yaml +15 -0
  108. quadra/configs/hydra/default.yaml +14 -0
  109. quadra/configs/inference/default.yaml +26 -0
  110. quadra/configs/logger/comet.yaml +10 -0
  111. quadra/configs/logger/csv.yaml +5 -0
  112. quadra/configs/logger/mlflow.yaml +12 -0
  113. quadra/configs/logger/tensorboard.yaml +8 -0
  114. quadra/configs/loss/asl.yaml +7 -0
  115. quadra/configs/loss/barlow.yaml +2 -0
  116. quadra/configs/loss/bce.yaml +1 -0
  117. quadra/configs/loss/byol.yaml +1 -0
  118. quadra/configs/loss/cross_entropy.yaml +1 -0
  119. quadra/configs/loss/dino.yaml +8 -0
  120. quadra/configs/loss/simclr.yaml +2 -0
  121. quadra/configs/loss/simsiam.yaml +1 -0
  122. quadra/configs/loss/smp_ce.yaml +3 -0
  123. quadra/configs/loss/smp_dice.yaml +2 -0
  124. quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
  125. quadra/configs/loss/smp_mcc.yaml +2 -0
  126. quadra/configs/loss/vicreg.yaml +5 -0
  127. quadra/configs/model/anomalib/cfa.yaml +35 -0
  128. quadra/configs/model/anomalib/cflow.yaml +30 -0
  129. quadra/configs/model/anomalib/csflow.yaml +34 -0
  130. quadra/configs/model/anomalib/dfm.yaml +19 -0
  131. quadra/configs/model/anomalib/draem.yaml +29 -0
  132. quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
  133. quadra/configs/model/anomalib/fastflow.yaml +32 -0
  134. quadra/configs/model/anomalib/padim.yaml +32 -0
  135. quadra/configs/model/anomalib/patchcore.yaml +36 -0
  136. quadra/configs/model/barlow.yaml +16 -0
  137. quadra/configs/model/byol.yaml +25 -0
  138. quadra/configs/model/classification.yaml +10 -0
  139. quadra/configs/model/dino.yaml +26 -0
  140. quadra/configs/model/logistic_regression.yaml +4 -0
  141. quadra/configs/model/multilabel_classification.yaml +9 -0
  142. quadra/configs/model/simclr.yaml +18 -0
  143. quadra/configs/model/simsiam.yaml +24 -0
  144. quadra/configs/model/smp.yaml +4 -0
  145. quadra/configs/model/smp_multiclass.yaml +4 -0
  146. quadra/configs/model/vicreg.yaml +16 -0
  147. quadra/configs/optimizer/adam.yaml +5 -0
  148. quadra/configs/optimizer/adamw.yaml +3 -0
  149. quadra/configs/optimizer/default.yaml +4 -0
  150. quadra/configs/optimizer/lars.yaml +8 -0
  151. quadra/configs/optimizer/sgd.yaml +4 -0
  152. quadra/configs/scheduler/default.yaml +5 -0
  153. quadra/configs/scheduler/rop.yaml +5 -0
  154. quadra/configs/scheduler/step.yaml +3 -0
  155. quadra/configs/scheduler/warmrestart.yaml +2 -0
  156. quadra/configs/scheduler/warmup.yaml +6 -0
  157. quadra/configs/task/anomalib/cfa.yaml +5 -0
  158. quadra/configs/task/anomalib/cflow.yaml +5 -0
  159. quadra/configs/task/anomalib/csflow.yaml +5 -0
  160. quadra/configs/task/anomalib/draem.yaml +5 -0
  161. quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
  162. quadra/configs/task/anomalib/fastflow.yaml +5 -0
  163. quadra/configs/task/anomalib/inference.yaml +3 -0
  164. quadra/configs/task/anomalib/padim.yaml +5 -0
  165. quadra/configs/task/anomalib/patchcore.yaml +5 -0
  166. quadra/configs/task/classification.yaml +6 -0
  167. quadra/configs/task/classification_evaluation.yaml +6 -0
  168. quadra/configs/task/default.yaml +1 -0
  169. quadra/configs/task/segmentation.yaml +9 -0
  170. quadra/configs/task/segmentation_evaluation.yaml +3 -0
  171. quadra/configs/task/sklearn_classification.yaml +13 -0
  172. quadra/configs/task/sklearn_classification_patch.yaml +11 -0
  173. quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
  174. quadra/configs/task/sklearn_classification_test.yaml +8 -0
  175. quadra/configs/task/ssl.yaml +2 -0
  176. quadra/configs/trainer/lightning_cpu.yaml +36 -0
  177. quadra/configs/trainer/lightning_gpu.yaml +35 -0
  178. quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
  179. quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
  180. quadra/configs/trainer/lightning_multigpu.yaml +37 -0
  181. quadra/configs/trainer/sklearn_classification.yaml +7 -0
  182. quadra/configs/transforms/byol.yaml +47 -0
  183. quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
  184. quadra/configs/transforms/default.yaml +37 -0
  185. quadra/configs/transforms/default_numpy.yaml +24 -0
  186. quadra/configs/transforms/default_resize.yaml +22 -0
  187. quadra/configs/transforms/dino.yaml +63 -0
  188. quadra/configs/transforms/linear_eval.yaml +18 -0
  189. quadra/datamodules/__init__.py +20 -0
  190. quadra/datamodules/anomaly.py +180 -0
  191. quadra/datamodules/base.py +375 -0
  192. quadra/datamodules/classification.py +1003 -0
  193. quadra/datamodules/generic/__init__.py +0 -0
  194. quadra/datamodules/generic/imagenette.py +144 -0
  195. quadra/datamodules/generic/mnist.py +81 -0
  196. quadra/datamodules/generic/mvtec.py +58 -0
  197. quadra/datamodules/generic/oxford_pet.py +163 -0
  198. quadra/datamodules/patch.py +190 -0
  199. quadra/datamodules/segmentation.py +742 -0
  200. quadra/datamodules/ssl.py +140 -0
  201. quadra/datasets/__init__.py +17 -0
  202. quadra/datasets/anomaly.py +287 -0
  203. quadra/datasets/classification.py +241 -0
  204. quadra/datasets/patch.py +138 -0
  205. quadra/datasets/segmentation.py +239 -0
  206. quadra/datasets/ssl.py +110 -0
  207. quadra/losses/__init__.py +0 -0
  208. quadra/losses/classification/__init__.py +6 -0
  209. quadra/losses/classification/asl.py +83 -0
  210. quadra/losses/classification/focal.py +320 -0
  211. quadra/losses/classification/prototypical.py +148 -0
  212. quadra/losses/ssl/__init__.py +17 -0
  213. quadra/losses/ssl/barlowtwins.py +47 -0
  214. quadra/losses/ssl/byol.py +37 -0
  215. quadra/losses/ssl/dino.py +129 -0
  216. quadra/losses/ssl/hyperspherical.py +45 -0
  217. quadra/losses/ssl/idmm.py +50 -0
  218. quadra/losses/ssl/simclr.py +67 -0
  219. quadra/losses/ssl/simsiam.py +30 -0
  220. quadra/losses/ssl/vicreg.py +76 -0
  221. quadra/main.py +49 -0
  222. quadra/metrics/__init__.py +3 -0
  223. quadra/metrics/segmentation.py +251 -0
  224. quadra/models/__init__.py +0 -0
  225. quadra/models/base.py +151 -0
  226. quadra/models/classification/__init__.py +8 -0
  227. quadra/models/classification/backbones.py +149 -0
  228. quadra/models/classification/base.py +92 -0
  229. quadra/models/evaluation.py +322 -0
  230. quadra/modules/__init__.py +0 -0
  231. quadra/modules/backbone.py +30 -0
  232. quadra/modules/base.py +312 -0
  233. quadra/modules/classification/__init__.py +3 -0
  234. quadra/modules/classification/base.py +327 -0
  235. quadra/modules/ssl/__init__.py +17 -0
  236. quadra/modules/ssl/barlowtwins.py +59 -0
  237. quadra/modules/ssl/byol.py +172 -0
  238. quadra/modules/ssl/common.py +285 -0
  239. quadra/modules/ssl/dino.py +186 -0
  240. quadra/modules/ssl/hyperspherical.py +206 -0
  241. quadra/modules/ssl/idmm.py +98 -0
  242. quadra/modules/ssl/simclr.py +73 -0
  243. quadra/modules/ssl/simsiam.py +68 -0
  244. quadra/modules/ssl/vicreg.py +67 -0
  245. quadra/optimizers/__init__.py +4 -0
  246. quadra/optimizers/lars.py +153 -0
  247. quadra/optimizers/sam.py +127 -0
  248. quadra/schedulers/__init__.py +3 -0
  249. quadra/schedulers/base.py +44 -0
  250. quadra/schedulers/warmup.py +127 -0
  251. quadra/tasks/__init__.py +24 -0
  252. quadra/tasks/anomaly.py +582 -0
  253. quadra/tasks/base.py +397 -0
  254. quadra/tasks/classification.py +1263 -0
  255. quadra/tasks/patch.py +492 -0
  256. quadra/tasks/segmentation.py +389 -0
  257. quadra/tasks/ssl.py +560 -0
  258. quadra/trainers/README.md +3 -0
  259. quadra/trainers/__init__.py +0 -0
  260. quadra/trainers/classification.py +179 -0
  261. quadra/utils/__init__.py +0 -0
  262. quadra/utils/anomaly.py +112 -0
  263. quadra/utils/classification.py +618 -0
  264. quadra/utils/deprecation.py +31 -0
  265. quadra/utils/evaluation.py +474 -0
  266. quadra/utils/export.py +585 -0
  267. quadra/utils/imaging.py +32 -0
  268. quadra/utils/logger.py +15 -0
  269. quadra/utils/mlflow.py +98 -0
  270. quadra/utils/model_manager.py +320 -0
  271. quadra/utils/models.py +523 -0
  272. quadra/utils/patch/__init__.py +15 -0
  273. quadra/utils/patch/dataset.py +1433 -0
  274. quadra/utils/patch/metrics.py +449 -0
  275. quadra/utils/patch/model.py +153 -0
  276. quadra/utils/patch/visualization.py +217 -0
  277. quadra/utils/resolver.py +42 -0
  278. quadra/utils/segmentation.py +31 -0
  279. quadra/utils/tests/__init__.py +0 -0
  280. quadra/utils/tests/fixtures/__init__.py +1 -0
  281. quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
  282. quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
  283. quadra/utils/tests/fixtures/dataset/classification.py +406 -0
  284. quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
  285. quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
  286. quadra/utils/tests/fixtures/models/__init__.py +3 -0
  287. quadra/utils/tests/fixtures/models/anomaly.py +89 -0
  288. quadra/utils/tests/fixtures/models/classification.py +45 -0
  289. quadra/utils/tests/fixtures/models/segmentation.py +33 -0
  290. quadra/utils/tests/helpers.py +70 -0
  291. quadra/utils/tests/models.py +27 -0
  292. quadra/utils/utils.py +525 -0
  293. quadra/utils/validator.py +115 -0
  294. quadra/utils/visualization.py +422 -0
  295. quadra/utils/vit_explainability.py +349 -0
  296. quadra-2.2.7.dist-info/LICENSE +201 -0
  297. quadra-2.2.7.dist-info/METADATA +381 -0
  298. quadra-2.2.7.dist-info/RECORD +300 -0
  299. {quadra-0.0.1.dist-info → quadra-2.2.7.dist-info}/WHEEL +1 -1
  300. quadra-2.2.7.dist-info/entry_points.txt +3 -0
  301. quadra-0.0.1.dist-info/METADATA +0 -14
  302. quadra-0.0.1.dist-info/RECORD +0 -4
@@ -0,0 +1,474 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from ast import literal_eval
5
+ from collections.abc import Callable
6
+ from functools import wraps
7
+
8
+ import matplotlib.pyplot as plt
9
+ import numpy as np
10
+ import numpy.typing as npt
11
+ import pandas as pd
12
+ import seaborn as sns
13
+ import segmentation_models_pytorch as smp
14
+ import torch
15
+ import yaml
16
+ from segmentation_models_pytorch.losses import DiceLoss
17
+ from segmentation_models_pytorch.losses.constants import BINARY_MODE, MULTICLASS_MODE
18
+ from skimage.measure import label, regionprops # pylint: disable=no-name-in-module
19
+
20
+ from quadra.utils.logger import get_logger
21
+ from quadra.utils.visualization import UnNormalize, create_grid_figure
22
+
23
+ try:
24
+ from onnxruntime.capi.onnxruntime_pybind11_state import RuntimeException # noqa
25
+
26
+ ONNX_AVAILABLE = True
27
+ except ImportError:
28
+ ONNX_AVAILABLE = False
29
+
30
+ log = get_logger(__name__)
31
+
32
+
33
+ def dice(
34
+ input_tensor: torch.Tensor,
35
+ target: torch.Tensor,
36
+ smooth: float = 1.0,
37
+ eps: float = 1e-8,
38
+ reduction: str | None = "mean",
39
+ ) -> torch.Tensor:
40
+ """Dice loss computation function.
41
+
42
+ Args:
43
+ input_tensor: input tensor coming from a model
44
+ target: target tensor to compare with
45
+ smooth: smoothing factor
46
+ eps: epsilon to avoid zero division
47
+ reduction: reduction method, one of "mean", "sum", "none"
48
+
49
+ Returns:
50
+ The computed loss
51
+ """
52
+ bs = input_tensor.size(0)
53
+ iflat = input_tensor.contiguous().view(bs, -1)
54
+ tflat = target.contiguous().view(bs, -1)
55
+ intersection = (iflat * tflat).sum(-1)
56
+ loss = 1 - (2.0 * intersection + smooth) / (iflat.sum(-1) + tflat.sum(-1) + smooth + eps)
57
+
58
+ if reduction == "mean":
59
+ loss = loss.mean()
60
+ return loss
61
+
62
+
63
def score_dice(
    y_pred,
    y_true,
    reduction=None,
) -> torch.Tensor:
    """Compute the dice score as the complement of the ``dice`` loss.

    Args:
        y_pred: predicted tensor, forwarded to ``dice`` as ``input_tensor``.
        y_true: target tensor, forwarded to ``dice`` as ``target``.
        reduction: reduction forwarded to ``dice``; the default ``None`` keeps one score per sample.

    Returns:
        ``1 - dice(y_pred, y_true, reduction=reduction)``.
    """
    loss_value = dice(y_pred, y_true, reduction=reduction)
    score = 1 - loss_value
    return score
70
+
71
+
72
def score_dice_smp(y_pred: torch.Tensor, y_true: torch.Tensor, mode: str = "binary") -> torch.Tensor:
    """Return the dice score computed with ``segmentation_models_pytorch``'s ``DiceLoss``.

    Works for both the binary and the multiclass scenario.

    Args:
        y_pred: 1xCxHxW predictions, one channel for each class. They must already be activated
            (the loss is built with ``from_logits=False``).
        y_true: 1x1xHxW ground-truth mask whose values are class indices in [0, n_classes - 1].
        mode: either "binary" or "multiclass".

    Raises:
        ValueError: if ``mode`` is neither of the supported modes.

    Returns:
        The dice score, i.e. ``1 - dice_loss``.
    """
    supported_modes = (BINARY_MODE, MULTICLASS_MODE)
    if mode in supported_modes:
        dice_loss_fn = DiceLoss(mode=mode, from_logits=False)
        return 1 - dice_loss_fn(y_pred, y_true)
    raise ValueError(f"Mode {mode} not valid.")
89
+
90
+
91
def _defect_area_bucket(area_percentage: float) -> str:
    """Map the percentage of the image covered by a defect region to its report bucket label."""
    if area_percentage <= 1:
        return "Very Small <1%"
    if area_percentage <= 10:
        return "Small <10%"
    if area_percentage <= 25:
        return "Medium <25%"
    return "Large >25%"


def _append_report_sample(
    samples: dict[str, list[np.ndarray]],
    image: np.ndarray,
    mask: np.ndarray,
    pred: np.ndarray,
    thresh_pred: np.ndarray,
    show_orj_predictions: bool,
) -> None:
    """Store one sample (image, mask, optional raw prediction, thresholded prediction) in ``samples``."""
    samples["image"].append(image)
    samples["mask"].append(mask)
    if show_orj_predictions:
        samples["pred"].append(pred)
    samples["thresh_pred"].append(thresh_pred)


def calculate_mask_based_metrics(
    images: np.ndarray,
    th_masks: torch.Tensor,
    th_preds: torch.Tensor,
    threshold: float = 0.5,
    show_orj_predictions: bool = False,
    metric: Callable = score_dice,
    multilabel: bool = False,
    n_classes: int | None = None,
) -> tuple[
    dict[str, float],
    dict[str, list[np.ndarray]],
    dict[str, list[np.ndarray]],
    dict[str, list[str | float]],
]:
    """Calculate metrics based on masks and predictions.

    Args:
        images: Images, one per sample. The first two dimensions of each image are used as the
            image area (assumed channel-last H x W x C - TODO confirm with callers).
        th_masks: masks are tensors.
        th_preds: predictions are tensors (class indices when ``multilabel`` is True).
        threshold: Threshold to apply. Defaults to 0.5.
        show_orj_predictions: Flag to also store the original predictions in the FN/FP samples.
            Defaults to False.
        metric: Metric to use comparison, called as ``metric(preds, masks, reduction=None)``.
            Defaults to `score_dice`.
        multilabel: True if segmentation is multiclass.
        n_classes: Number of classes. If multilabel is False, this should be None.

    Raises:
        ValueError: if ``multilabel`` is True and ``n_classes`` is None.

    Returns:
        A tuple with:
            - dictionary with pixel/image level metrics and dice statistics,
            - false negative samples (defective images without any predicted pixel),
            - false positive samples (good images with at least one predicted pixel),
            - per defect region accuracy grouped by defect area bucket.
    """
    if multilabel and n_classes is None:
        raise ValueError("n_classes arg shouldn't be None when multilabel is True")

    # Move everything to CPU once: the metric and the smp calls below combine masks and predictions,
    # which fails if they live on different devices (e.g. GPU masks with CPU predictions).
    th_masks = th_masks.cpu()
    th_preds = th_preds.cpu()

    masks = th_masks.numpy()
    preds = th_preds.squeeze(0).numpy()
    th_thresh_preds = (th_preds > threshold).float()
    thresh_preds = th_thresh_preds.squeeze(0).numpy()
    dice_scores = metric(th_thresh_preds, th_masks, reduction=None).numpy()

    if multilabel:
        preds_multilabel = (
            torch.nn.functional.one_hot(th_preds.to(torch.int64), num_classes=n_classes).squeeze(1).permute(0, 3, 1, 2)
        )
        masks_multilabel = (
            torch.nn.functional.one_hot(th_masks.to(torch.int64), num_classes=n_classes).squeeze(1).permute(0, 3, 1, 2)
        )
        # get_stats multiclass, not considering background channel (index 0)
        tp, fp, fn, tn = smp.metrics.get_stats(
            preds_multilabel[:, 1:, :, :].long(), masks_multilabel[:, 1:, :, :].long(), mode="multilabel"
        )
    else:
        tp, fp, fn, tn = smp.metrics.get_stats(th_thresh_preds.long(), th_masks.long(), mode="binary")

    per_image_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro-imagewise")
    dataset_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")
    result = {}
    result["F1_image"] = round(float(smp.metrics.f1_score(tp, fp, fn, tn, reduction="micro-imagewise").item()), 4)
    result["F1_pixel"] = round(float(smp.metrics.f1_score(tp, fp, fn, tn, reduction="micro").item()), 4)
    result["image_iou"] = round(float(per_image_iou.item()), 4) if not per_image_iou.isnan() else np.nan
    result["dataset_iou"] = round(float(dataset_iou.item()), 4) if not dataset_iou.isnan() else np.nan
    result["TP_pixel"] = tp.sum().item()
    result["FP_pixel"] = fp.sum().item()
    result["FN_pixel"] = fn.sum().item()
    result["TN_pixel"] = tn.sum().item()
    for counter_name in ("TP_image", "FP_image", "FN_image", "TN_image", "num_good_image", "num_bad_image"):
        result[counter_name] = 0

    bad_dice: list[float] = []
    good_dice: list[float] = []
    fg: dict[str, list[np.ndarray]] = {"image": [], "mask": [], "thresh_pred": []}
    fb: dict[str, list[np.ndarray]] = {"image": [], "mask": [], "thresh_pred": []}
    if show_orj_predictions:
        fg["pred"] = []
        fb["pred"] = []

    area_graph: dict[str, list[str | float]] = {
        "Defect Area Percentage": [],
        "Accuracy": [],
    }
    for idx, (image, pred, mask, thresh_pred, dice_score) in enumerate(
        zip(images, preds, masks, thresh_preds, dice_scores)
    ):
        if mask.sum() == 0:
            # Good image: the ground truth has no defect.
            good_dice.append(dice_score)
            result["num_good_image"] += 1
            if thresh_pred.sum() > 0:
                result["FP_image"] += 1
                _append_report_sample(fb, image, mask, pred, thresh_pred, show_orj_predictions)
            else:
                result["TN_image"] += 1
            continue

        # Bad image: the ground truth contains at least one defect pixel.
        bad_dice.append(dice_score)
        result["num_bad_image"] += 1
        if thresh_pred.sum() == 0:
            result["FN_image"] += 1
            _append_report_sample(fg, image, mask, pred, thresh_pred, show_orj_predictions)
        else:
            result["TP_image"] += 1

        # For every connected defect region, measure the pixel accuracy inside its bounding box and
        # group it by how much of the image the defect covers.
        for region in regionprops(label(mask[0])):
            min_row, min_col, max_row, max_col = region.bbox
            mask_partial = th_masks[idx, :, min_row:max_row, min_col:max_col]
            pred_partial = th_thresh_preds[idx, :, min_row:max_row, min_col:max_col]
            region_tp, region_fp, region_fn, region_tn = smp.metrics.get_stats(
                pred_partial.long(), mask_partial.long(), mode="binary"
            )
            area = region_tp + region_fn
            area_percentage = area.sum().item() * 100 / (image.shape[0] * image.shape[1])
            defect_acc = smp.metrics.accuracy(region_tp, region_fp, region_fn, region_tn, reduction="micro")
            area_graph["Accuracy"].append(defect_acc.item() * 100)
            area_graph["Defect Area Percentage"].append(_defect_area_bucket(area_percentage))

    # Builtin floats keep the result serializable; "null" marks a missing group.
    result["bad_dice_score_mean"] = float(np.mean(bad_dice)) if bad_dice else "null"
    result["bad_dice_score_std"] = float(np.std(bad_dice)) if bad_dice else "null"
    result["good_dice_score_mean"] = float(np.mean(good_dice)) if good_dice else "null"
    result["good_dice_score_std"] = float(np.std(good_dice)) if good_dice else "null"
    return result, fg, fb, area_graph
221
+
222
+
223
def _save_analysis_grid(
    samples: dict[str, list[np.ndarray]],
    file_path: str,
    nb_samples: int,
    show_orj_predictions: bool,
    row_names: list[str],
    bounds: list[tuple[float, float]],
    fig_h: int,
) -> None:
    """Save at most ``nb_samples`` analysis samples (false negatives / positives) as a grid figure.

    Rows are built explicitly in the order used by ``row_names``: input, mask, optional raw
    prediction, thresholded prediction.
    """
    nb_selected = min(len(samples["image"]), nb_samples)
    rows = [samples["image"][:nb_selected], samples["mask"][:nb_selected]]
    if show_orj_predictions:
        rows.append(samples["pred"][:nb_selected])
    rows.append(samples["thresh_pred"][:nb_selected])
    create_grid_figure(
        rows,
        nrows=len(row_names),
        ncols=nb_selected,
        row_names=row_names,
        file_path=file_path,
        fig_size=(int(nb_selected * 2), fig_h),
        bounds=bounds,
    )


def create_mask_report(
    stage: str,
    output: dict[str, torch.Tensor],
    mean: npt.ArrayLike,
    std: npt.ArrayLike,
    report_path: str,
    nb_samples: int = 6,
    analysis: bool = False,
    apply_sigmoid: bool = True,
    show_all: bool = False,
    threshold: float = 0.5,
    metric: Callable = score_dice,
    show_orj_predictions: bool = False,
) -> list[str]:
    """Create report for segmentation experiment.

    Args:
        stage: stage name. Train, validation or test
        output: data produced by model; must contain the "image", "mask", "mask_pred" and "label" tensors
        mean: mean values used to un-normalize the images
        std: std values used to un-normalize the images
        report_path: directory where the report files are written (created if missing)
        nb_samples: maximum number of samples shown in each figure
        analysis: if True, an analysis yaml file plus false negative / false positive and
            defect-area figures are created
        apply_sigmoid: if True, an activation is applied to the predictions (sigmoid for a single
            channel, softmax for several channels); otherwise predictions are used as they are
        show_all: if True, samples are taken in their original order (at most ``nb_samples``)
            instead of the best / worst / random selection based on the metric
        threshold: threshold for single-channel predictions (multichannel predictions use argmax)
        metric: metric function, called as ``metric(preds, masks, reduction=None)``
        show_orj_predictions: if True, original predictions will be shown (single channel only).

    Returns:
        list of paths to created files (figures and, when ``analysis`` is True, the yaml report).
    """
    os.makedirs(report_path, exist_ok=True)

    th_images = output["image"]
    th_masks = output["mask"]
    th_preds = output["mask_pred"]
    th_labels = output["label"]
    n_classes = th_preds.shape[1]

    # TODO: `apply_sigmoid` is a misleading name, multichannel predictions go through a softmax.
    if n_classes == 1:
        if apply_sigmoid:
            th_preds = torch.nn.Sigmoid()(th_preds)
        th_thresh_preds = (th_preds > threshold).float()
    else:
        if apply_sigmoid:
            th_preds = torch.nn.Softmax(dim=1)(th_preds)
        # argmax gives the same class indices with or without the softmax
        th_thresh_preds = torch.argmax(th_preds, dim=1).float().unsqueeze(1)
        # Compute labels from the given masks since by default they are all 0
        th_labels = th_masks.max(dim=2)[0].max(dim=2)[0].squeeze(dim=1)
        # Raw multichannel predictions are never shown
        show_orj_predictions = False

    unnormalize = UnNormalize(np.asarray(mean), np.asarray(std))
    images = np.array(
        [(unnormalize(image).cpu().numpy().transpose((1, 2, 0)) * 255).astype(np.uint8) for image in th_images]
    )
    masks = th_masks.cpu().numpy()
    preds = th_preds.squeeze(0).cpu().numpy()
    thresh_preds = th_thresh_preds.squeeze(0).cpu().numpy()
    dice_scores = metric(th_thresh_preds.cpu(), th_masks.cpu(), reduction=None).numpy()

    # Samples labelled 0 are "good" (defect free), all the others are "bad".
    is_good = th_labels.cpu().numpy() == 0

    row_names = ["Input", "Mask", "Pred", f"Pred>{threshold}"]
    bounds = [(0, 255), (0.0, float(n_classes - 1)), (0.0, 1.0), (0.0, float(n_classes - 1))]
    if not show_orj_predictions:
        row_names.pop(2)
        bounds.pop(2)
    fig_h = int(len(row_names) * 2)

    # Sorting by ascending score puts the worst samples first and the best ones last.
    sorted_idx = np.arange(len(dice_scores)) if show_all else np.argsort(dice_scores)
    is_good_sorted = is_good[sorted_idx]
    good_idx = sorted_idx[is_good_sorted]
    bad_idx = sorted_idx[~is_good_sorted]

    file_paths: list[str] = []
    for name, current_score_idx in (("good", good_idx), ("bad", bad_idx)):
        if len(current_score_idx) == 0:
            continue

        nb_selected_samples = min(nb_samples, len(current_score_idx))
        fig_w = int(nb_selected_samples * 2)
        if show_all:
            indexes = {"all": current_score_idx[:nb_selected_samples].tolist()}
        else:
            worst_idx = current_score_idx[:nb_selected_samples].tolist()
            best_idx = current_score_idx[-nb_selected_samples:].tolist()
            random_idx = np.random.choice(current_score_idx, nb_selected_samples, replace=False).tolist()
            indexes = {"best": best_idx, "worst": worst_idx, "random": random_idx}

        for selection_name, selected in indexes.items():
            file_path = os.path.join(report_path, f"{stage}_{name}_{selection_name}_results.png")
            images_to_show = [images[selected], masks[selected]]
            if show_orj_predictions:
                images_to_show.append(preds[selected])
            images_to_show.append(thresh_preds[selected])
            create_grid_figure(
                images_to_show,
                nrows=len(row_names),
                ncols=nb_selected_samples,
                row_names=row_names,
                file_path=file_path,
                fig_size=(fig_w, fig_h),
                bounds=bounds,
            )
            file_paths.append(file_path)

    if analysis:
        analysis_file_path = os.path.join(report_path, f"{stage}_analysis.yaml")
        result, fg, fb, area_graph = calculate_mask_based_metrics(
            images=images,
            th_masks=th_masks,
            th_preds=th_thresh_preds,
            threshold=threshold,
            show_orj_predictions=show_orj_predictions,
            metric=metric,
            multilabel=bool(n_classes > 1),
            n_classes=n_classes,
        )

        # fg holds false negatives, fb holds false positives.
        for samples, suffix in ((fg, "fn"), (fb, "fp")):
            if len(samples["image"]) == 0:
                continue
            grid_file_path = os.path.join(report_path, f"{stage}_{suffix}_results.png")
            _save_analysis_grid(
                samples,
                file_path=grid_file_path,
                nb_samples=nb_samples,
                show_orj_predictions=show_orj_predictions,
                row_names=row_names,
                bounds=bounds,
                fig_h=fig_h,
            )
            file_paths.append(grid_file_path)

        if area_graph["Defect Area Percentage"]:
            fn_area_path = os.path.join(report_path, f"{stage}_acc_area.png")
            # Draw on a dedicated figure so the plot never lands on an unrelated, still-open figure.
            fig, ax = plt.subplots()
            sns.boxplot(
                x="Defect Area Percentage",
                y="Accuracy",
                data=pd.DataFrame(area_graph),
                order=["Very Small <1%", "Small <10%", "Medium <25%", "Large >25%"],
                ax=ax,
            )
            ax.set_facecolor("white")
            fig.savefig(fn_area_path)
            plt.close(fig)
            file_paths.append(fn_area_path)

        # Convert numpy scalars to builtin Python types so yaml can serialize them. A str()/literal_eval
        # round-trip would fail on NaN values and on numpy>=2 scalar reprs such as "np.float64(0.5)".
        serializable_result = {
            key: value.item() if isinstance(value, np.generic) else value for key, value in result.items()
        }
        with open(analysis_file_path, "w") as file:
            yaml.dump(serializable_result, file, default_flow_style=False)
        file_paths.append(analysis_file_path)

    return file_paths
405
+
406
+
407
def automatic_datamodule_batch_size(batch_size_attribute_name: str = "batch_size"):
    """Automatically scale the datamodule batch size if the given function goes out of memory.

    The decorated method must belong to an object exposing a ``datamodule`` attribute. If the object
    also has an ``automatic_batch_size`` attribute (with ``disable`` and ``starting_batch_size``
    fields), and scaling has not completed yet, then the datamodule batch size is first set to
    ``starting_batch_size``. Every ``RuntimeError`` raised by the method is treated as an out of
    memory failure; when onnxruntime is installed, its ``RuntimeException`` is treated the same way.
    On each such failure the batch size is halved and the method is retried.

    Args:
        batch_size_attribute_name: The name of the attribute to modify in the datamodule

    Raises:
        ValueError: (from the wrapper) if ``automatic_batch_size`` lacks ``disable`` or ``starting_batch_size``.
        RuntimeError: (from the wrapper) if the method still fails once the batch size reaches 0.
    """

    def decorator(func: Callable):
        """Decorator function."""

        @wraps(func)
        def wrapper(self, *args, **kwargs):
            """Run ``func``, halving the datamodule batch size after every caught failure.

            Returns:
                Whatever ``func`` returns on its first successful run.
            """
            valid_exceptions: tuple = (RuntimeError,)
            if ONNX_AVAILABLE:
                valid_exceptions += (RuntimeException,)

            automatic_batch_size_completed = getattr(self, "automatic_batch_size_completed", False)
            starting_batch_size = None

            if hasattr(self, "automatic_batch_size"):
                if not hasattr(self.automatic_batch_size, "disable") or not hasattr(
                    self.automatic_batch_size, "starting_batch_size"
                ):
                    raise ValueError(
                        "The automatic_batch_size attribute should have the disable and starting_batch_size attributes"
                    )
                starting_batch_size = (
                    self.automatic_batch_size.starting_batch_size if not self.automatic_batch_size.disable else None
                )

            # Once a previous call completed, keep the (possibly reduced) batch size found back then.
            if starting_batch_size is not None and not automatic_batch_size_completed:
                log.info("Performing automatic batch size scaling from %d", starting_batch_size)
                setattr(self.datamodule, batch_size_attribute_name, starting_batch_size)

            while True:
                try:
                    func_result = func(self, *args, **kwargs)
                except valid_exceptions as e:
                    current_batch_size = getattr(self.datamodule, batch_size_attribute_name)
                    new_batch_size = current_batch_size // 2
                    setattr(self.datamodule, batch_size_attribute_name, new_batch_size)
                    log.warning(
                        "The function %s went out of memory, trying to reduce the batch size to %d",
                        func.__name__,
                        new_batch_size,
                    )

                    if new_batch_size == 0:
                        raise RuntimeError(
                            f"Unable to run {func.__name__} with batch size 1, the program will exit"
                        ) from e

                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                else:
                    self.automatic_batch_size_completed = True
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                    return func_result

        return wrapper

    return decorator