quadra 0.0.1__py3-none-any.whl → 2.1.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (302) hide show
  1. hydra_plugins/quadra_searchpath_plugin.py +30 -0
  2. quadra/__init__.py +6 -0
  3. quadra/callbacks/__init__.py +0 -0
  4. quadra/callbacks/anomalib.py +289 -0
  5. quadra/callbacks/lightning.py +501 -0
  6. quadra/callbacks/mlflow.py +291 -0
  7. quadra/callbacks/scheduler.py +69 -0
  8. quadra/configs/__init__.py +0 -0
  9. quadra/configs/backbone/caformer_m36.yaml +8 -0
  10. quadra/configs/backbone/caformer_s36.yaml +8 -0
  11. quadra/configs/backbone/convnextv2_base.yaml +8 -0
  12. quadra/configs/backbone/convnextv2_femto.yaml +8 -0
  13. quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
  14. quadra/configs/backbone/dino_vitb8.yaml +12 -0
  15. quadra/configs/backbone/dino_vits8.yaml +12 -0
  16. quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
  17. quadra/configs/backbone/dinov2_vits14.yaml +12 -0
  18. quadra/configs/backbone/efficientnet_b0.yaml +8 -0
  19. quadra/configs/backbone/efficientnet_b1.yaml +8 -0
  20. quadra/configs/backbone/efficientnet_b2.yaml +8 -0
  21. quadra/configs/backbone/efficientnet_b3.yaml +8 -0
  22. quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
  23. quadra/configs/backbone/levit_128s.yaml +8 -0
  24. quadra/configs/backbone/mnasnet0_5.yaml +9 -0
  25. quadra/configs/backbone/resnet101.yaml +8 -0
  26. quadra/configs/backbone/resnet18.yaml +8 -0
  27. quadra/configs/backbone/resnet18_ssl.yaml +8 -0
  28. quadra/configs/backbone/resnet50.yaml +8 -0
  29. quadra/configs/backbone/smp.yaml +9 -0
  30. quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
  31. quadra/configs/backbone/unetr.yaml +15 -0
  32. quadra/configs/backbone/vit16_base.yaml +9 -0
  33. quadra/configs/backbone/vit16_small.yaml +9 -0
  34. quadra/configs/backbone/vit16_tiny.yaml +9 -0
  35. quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
  36. quadra/configs/callbacks/all.yaml +32 -0
  37. quadra/configs/callbacks/default.yaml +37 -0
  38. quadra/configs/callbacks/default_anomalib.yaml +67 -0
  39. quadra/configs/config.yaml +33 -0
  40. quadra/configs/core/default.yaml +11 -0
  41. quadra/configs/datamodule/base/anomaly.yaml +16 -0
  42. quadra/configs/datamodule/base/classification.yaml +21 -0
  43. quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
  44. quadra/configs/datamodule/base/segmentation.yaml +18 -0
  45. quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
  46. quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
  47. quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
  48. quadra/configs/datamodule/base/ssl.yaml +21 -0
  49. quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
  50. quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
  51. quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
  52. quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
  53. quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
  54. quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
  55. quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
  56. quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
  57. quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
  58. quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
  59. quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
  60. quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
  61. quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
  62. quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
  63. quadra/configs/experiment/base/classification/classification.yaml +73 -0
  64. quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
  65. quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
  66. quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
  67. quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
  68. quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
  69. quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
  70. quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
  71. quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
  72. quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
  73. quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
  74. quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
  75. quadra/configs/experiment/base/ssl/byol.yaml +43 -0
  76. quadra/configs/experiment/base/ssl/dino.yaml +46 -0
  77. quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
  78. quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
  79. quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
  80. quadra/configs/experiment/custom/cls.yaml +12 -0
  81. quadra/configs/experiment/default.yaml +15 -0
  82. quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
  83. quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
  84. quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
  85. quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
  86. quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
  87. quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
  88. quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
  89. quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
  90. quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
  91. quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
  92. quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
  93. quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
  94. quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
  95. quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
  96. quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
  97. quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
  98. quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
  99. quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
  100. quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
  101. quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
  102. quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
  103. quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
  104. quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
  105. quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
  106. quadra/configs/export/default.yaml +13 -0
  107. quadra/configs/hydra/anomaly_custom.yaml +15 -0
  108. quadra/configs/hydra/default.yaml +14 -0
  109. quadra/configs/inference/default.yaml +26 -0
  110. quadra/configs/logger/comet.yaml +10 -0
  111. quadra/configs/logger/csv.yaml +5 -0
  112. quadra/configs/logger/mlflow.yaml +12 -0
  113. quadra/configs/logger/tensorboard.yaml +8 -0
  114. quadra/configs/loss/asl.yaml +7 -0
  115. quadra/configs/loss/barlow.yaml +2 -0
  116. quadra/configs/loss/bce.yaml +1 -0
  117. quadra/configs/loss/byol.yaml +1 -0
  118. quadra/configs/loss/cross_entropy.yaml +1 -0
  119. quadra/configs/loss/dino.yaml +8 -0
  120. quadra/configs/loss/simclr.yaml +2 -0
  121. quadra/configs/loss/simsiam.yaml +1 -0
  122. quadra/configs/loss/smp_ce.yaml +3 -0
  123. quadra/configs/loss/smp_dice.yaml +2 -0
  124. quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
  125. quadra/configs/loss/smp_mcc.yaml +2 -0
  126. quadra/configs/loss/vicreg.yaml +5 -0
  127. quadra/configs/model/anomalib/cfa.yaml +35 -0
  128. quadra/configs/model/anomalib/cflow.yaml +30 -0
  129. quadra/configs/model/anomalib/csflow.yaml +34 -0
  130. quadra/configs/model/anomalib/dfm.yaml +19 -0
  131. quadra/configs/model/anomalib/draem.yaml +29 -0
  132. quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
  133. quadra/configs/model/anomalib/fastflow.yaml +32 -0
  134. quadra/configs/model/anomalib/padim.yaml +32 -0
  135. quadra/configs/model/anomalib/patchcore.yaml +36 -0
  136. quadra/configs/model/barlow.yaml +16 -0
  137. quadra/configs/model/byol.yaml +25 -0
  138. quadra/configs/model/classification.yaml +10 -0
  139. quadra/configs/model/dino.yaml +26 -0
  140. quadra/configs/model/logistic_regression.yaml +4 -0
  141. quadra/configs/model/multilabel_classification.yaml +9 -0
  142. quadra/configs/model/simclr.yaml +18 -0
  143. quadra/configs/model/simsiam.yaml +24 -0
  144. quadra/configs/model/smp.yaml +4 -0
  145. quadra/configs/model/smp_multiclass.yaml +4 -0
  146. quadra/configs/model/vicreg.yaml +16 -0
  147. quadra/configs/optimizer/adam.yaml +5 -0
  148. quadra/configs/optimizer/adamw.yaml +3 -0
  149. quadra/configs/optimizer/default.yaml +4 -0
  150. quadra/configs/optimizer/lars.yaml +8 -0
  151. quadra/configs/optimizer/sgd.yaml +4 -0
  152. quadra/configs/scheduler/default.yaml +5 -0
  153. quadra/configs/scheduler/rop.yaml +5 -0
  154. quadra/configs/scheduler/step.yaml +3 -0
  155. quadra/configs/scheduler/warmrestart.yaml +2 -0
  156. quadra/configs/scheduler/warmup.yaml +6 -0
  157. quadra/configs/task/anomalib/cfa.yaml +5 -0
  158. quadra/configs/task/anomalib/cflow.yaml +5 -0
  159. quadra/configs/task/anomalib/csflow.yaml +5 -0
  160. quadra/configs/task/anomalib/draem.yaml +5 -0
  161. quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
  162. quadra/configs/task/anomalib/fastflow.yaml +5 -0
  163. quadra/configs/task/anomalib/inference.yaml +3 -0
  164. quadra/configs/task/anomalib/padim.yaml +5 -0
  165. quadra/configs/task/anomalib/patchcore.yaml +5 -0
  166. quadra/configs/task/classification.yaml +6 -0
  167. quadra/configs/task/classification_evaluation.yaml +6 -0
  168. quadra/configs/task/default.yaml +1 -0
  169. quadra/configs/task/segmentation.yaml +9 -0
  170. quadra/configs/task/segmentation_evaluation.yaml +3 -0
  171. quadra/configs/task/sklearn_classification.yaml +13 -0
  172. quadra/configs/task/sklearn_classification_patch.yaml +11 -0
  173. quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
  174. quadra/configs/task/sklearn_classification_test.yaml +8 -0
  175. quadra/configs/task/ssl.yaml +2 -0
  176. quadra/configs/trainer/lightning_cpu.yaml +36 -0
  177. quadra/configs/trainer/lightning_gpu.yaml +35 -0
  178. quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
  179. quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
  180. quadra/configs/trainer/lightning_multigpu.yaml +37 -0
  181. quadra/configs/trainer/sklearn_classification.yaml +7 -0
  182. quadra/configs/transforms/byol.yaml +47 -0
  183. quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
  184. quadra/configs/transforms/default.yaml +37 -0
  185. quadra/configs/transforms/default_numpy.yaml +24 -0
  186. quadra/configs/transforms/default_resize.yaml +22 -0
  187. quadra/configs/transforms/dino.yaml +63 -0
  188. quadra/configs/transforms/linear_eval.yaml +18 -0
  189. quadra/datamodules/__init__.py +20 -0
  190. quadra/datamodules/anomaly.py +180 -0
  191. quadra/datamodules/base.py +375 -0
  192. quadra/datamodules/classification.py +1003 -0
  193. quadra/datamodules/generic/__init__.py +0 -0
  194. quadra/datamodules/generic/imagenette.py +144 -0
  195. quadra/datamodules/generic/mnist.py +81 -0
  196. quadra/datamodules/generic/mvtec.py +58 -0
  197. quadra/datamodules/generic/oxford_pet.py +163 -0
  198. quadra/datamodules/patch.py +190 -0
  199. quadra/datamodules/segmentation.py +742 -0
  200. quadra/datamodules/ssl.py +140 -0
  201. quadra/datasets/__init__.py +17 -0
  202. quadra/datasets/anomaly.py +287 -0
  203. quadra/datasets/classification.py +241 -0
  204. quadra/datasets/patch.py +138 -0
  205. quadra/datasets/segmentation.py +239 -0
  206. quadra/datasets/ssl.py +110 -0
  207. quadra/losses/__init__.py +0 -0
  208. quadra/losses/classification/__init__.py +6 -0
  209. quadra/losses/classification/asl.py +83 -0
  210. quadra/losses/classification/focal.py +320 -0
  211. quadra/losses/classification/prototypical.py +148 -0
  212. quadra/losses/ssl/__init__.py +17 -0
  213. quadra/losses/ssl/barlowtwins.py +47 -0
  214. quadra/losses/ssl/byol.py +37 -0
  215. quadra/losses/ssl/dino.py +129 -0
  216. quadra/losses/ssl/hyperspherical.py +45 -0
  217. quadra/losses/ssl/idmm.py +50 -0
  218. quadra/losses/ssl/simclr.py +67 -0
  219. quadra/losses/ssl/simsiam.py +30 -0
  220. quadra/losses/ssl/vicreg.py +76 -0
  221. quadra/main.py +46 -0
  222. quadra/metrics/__init__.py +3 -0
  223. quadra/metrics/segmentation.py +251 -0
  224. quadra/models/__init__.py +0 -0
  225. quadra/models/base.py +151 -0
  226. quadra/models/classification/__init__.py +8 -0
  227. quadra/models/classification/backbones.py +149 -0
  228. quadra/models/classification/base.py +92 -0
  229. quadra/models/evaluation.py +322 -0
  230. quadra/modules/__init__.py +0 -0
  231. quadra/modules/backbone.py +30 -0
  232. quadra/modules/base.py +312 -0
  233. quadra/modules/classification/__init__.py +3 -0
  234. quadra/modules/classification/base.py +331 -0
  235. quadra/modules/ssl/__init__.py +17 -0
  236. quadra/modules/ssl/barlowtwins.py +59 -0
  237. quadra/modules/ssl/byol.py +172 -0
  238. quadra/modules/ssl/common.py +285 -0
  239. quadra/modules/ssl/dino.py +186 -0
  240. quadra/modules/ssl/hyperspherical.py +206 -0
  241. quadra/modules/ssl/idmm.py +98 -0
  242. quadra/modules/ssl/simclr.py +73 -0
  243. quadra/modules/ssl/simsiam.py +68 -0
  244. quadra/modules/ssl/vicreg.py +67 -0
  245. quadra/optimizers/__init__.py +4 -0
  246. quadra/optimizers/lars.py +153 -0
  247. quadra/optimizers/sam.py +127 -0
  248. quadra/schedulers/__init__.py +3 -0
  249. quadra/schedulers/base.py +44 -0
  250. quadra/schedulers/warmup.py +127 -0
  251. quadra/tasks/__init__.py +24 -0
  252. quadra/tasks/anomaly.py +582 -0
  253. quadra/tasks/base.py +397 -0
  254. quadra/tasks/classification.py +1264 -0
  255. quadra/tasks/patch.py +492 -0
  256. quadra/tasks/segmentation.py +389 -0
  257. quadra/tasks/ssl.py +560 -0
  258. quadra/trainers/README.md +3 -0
  259. quadra/trainers/__init__.py +0 -0
  260. quadra/trainers/classification.py +179 -0
  261. quadra/utils/__init__.py +0 -0
  262. quadra/utils/anomaly.py +112 -0
  263. quadra/utils/classification.py +618 -0
  264. quadra/utils/deprecation.py +31 -0
  265. quadra/utils/evaluation.py +474 -0
  266. quadra/utils/export.py +579 -0
  267. quadra/utils/imaging.py +32 -0
  268. quadra/utils/logger.py +15 -0
  269. quadra/utils/mlflow.py +98 -0
  270. quadra/utils/model_manager.py +320 -0
  271. quadra/utils/models.py +524 -0
  272. quadra/utils/patch/__init__.py +15 -0
  273. quadra/utils/patch/dataset.py +1433 -0
  274. quadra/utils/patch/metrics.py +449 -0
  275. quadra/utils/patch/model.py +153 -0
  276. quadra/utils/patch/visualization.py +217 -0
  277. quadra/utils/resolver.py +42 -0
  278. quadra/utils/segmentation.py +31 -0
  279. quadra/utils/tests/__init__.py +0 -0
  280. quadra/utils/tests/fixtures/__init__.py +1 -0
  281. quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
  282. quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
  283. quadra/utils/tests/fixtures/dataset/classification.py +406 -0
  284. quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
  285. quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
  286. quadra/utils/tests/fixtures/models/__init__.py +3 -0
  287. quadra/utils/tests/fixtures/models/anomaly.py +89 -0
  288. quadra/utils/tests/fixtures/models/classification.py +45 -0
  289. quadra/utils/tests/fixtures/models/segmentation.py +33 -0
  290. quadra/utils/tests/helpers.py +70 -0
  291. quadra/utils/tests/models.py +27 -0
  292. quadra/utils/utils.py +525 -0
  293. quadra/utils/validator.py +115 -0
  294. quadra/utils/visualization.py +422 -0
  295. quadra/utils/vit_explainability.py +349 -0
  296. quadra-2.1.13.dist-info/LICENSE +201 -0
  297. quadra-2.1.13.dist-info/METADATA +386 -0
  298. quadra-2.1.13.dist-info/RECORD +300 -0
  299. {quadra-0.0.1.dist-info → quadra-2.1.13.dist-info}/WHEEL +1 -1
  300. quadra-2.1.13.dist-info/entry_points.txt +3 -0
  301. quadra-0.0.1.dist-info/METADATA +0 -14
  302. quadra-0.0.1.dist-info/RECORD +0 -4
@@ -0,0 +1,217 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+
5
+ import cv2
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ from matplotlib.cm import get_cmap
9
+ from matplotlib.colors import Colormap
10
+ from matplotlib.lines import Line2D
11
+ from matplotlib.pyplot import Figure
12
+
13
+ from quadra.utils import utils
14
+
15
+ log = utils.get_logger(__name__)
16
+
17
+
18
def plot_patch_reconstruction(
    reconstruction: dict,
    idx_to_class: dict[int, str],
    class_to_idx: dict[str, int],
    ignore_classes: list[int] | None = None,
    is_polygon: bool = True,
) -> Figure:
    """Helper function for plotting the patch reconstruction.

    Args:
        reconstruction: Dict following this structure
            {
                "image_path": str,
                "mask_path": str | None,
                "prediction": [
                    {
                        "label": str,
                        "points": [{"x": int, "y": int}]
                    }
                ]
            } if is_polygon else
            {
                "image_path": str,
                "mask_path": str | None,
                "prediction": np.ndarray
            }
        idx_to_class: Dictionary mapping indices to label names, only its size is used to pick the colormap
        class_to_idx: Dictionary mapping class names to indices
        ignore_classes: Eventually the classes to not plot
        is_polygon: Boolean indicating if the prediction is a list of polygons or a mask.

    Returns:
        Matplotlib plot showing predicted patch regions and eventually gt

    Raises:
        ValueError: If the image at `reconstruction["image_path"]` cannot be read.
    """
    # tab10 is used for up to 10 classes + good, tab20 when there are more
    cmap_name = "tab20" if len(idx_to_class) > 11 else "tab10"
    cmap = get_cmap(cmap_name)

    image_path = reconstruction["image_path"]
    test_img = cv2.imread(image_path)
    if test_img is None:
        # cv2.imread signals failure by returning None, fail here with a readable message instead of
        # letting cvtColor raise an opaque OpenCV assertion
        raise ValueError(f"Unable to read image from {image_path}")
    test_img = cv2.cvtColor(test_img, cv2.COLOR_BGR2RGB)

    gt_img = None
    mask_path = reconstruction["mask_path"]
    if mask_path is not None and os.path.isfile(mask_path):
        # Flag 0 loads the mask as a single channel image
        gt_img = cv2.imread(mask_path, 0)

    out = np.zeros((test_img.shape[0], test_img.shape[1]), dtype=np.uint8)

    if is_polygon:
        for region in reconstruction["prediction"]:
            points = [[item["x"], item["y"]] for item in region["points"]]
            c_label = region["label"]

            # Fill the polygon with the index of its class
            out = cv2.drawContours(
                out,
                np.array([points], np.int32),
                -1,
                class_to_idx[c_label],
                thickness=cv2.FILLED,
            )  # type: ignore[call-overload]
    else:
        out = reconstruction["prediction"]

    fig = plot_patch_results(
        image=test_img,
        prediction_image=out,
        ground_truth_image=gt_img,
        plot_original=True,
        ignore_classes=ignore_classes,
        save_path=None,
        class_to_idx=class_to_idx,
        cmap=cmap,
    )

    return fig
94
+
95
+
96
def show_mask_on_image(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Plot mask on top of the original image.

    Both inputs are scaled by 1/255, summed, and the sum is normalized by its maximum before being converted
    back to uint8.

    Args:
        image: The image used as background
        mask: The mask to overlay, it must be broadcastable with the image

    Returns:
        The blended uint8 image
    """
    blended = mask.astype(np.float32) / 255 + image.astype(np.float32) / 255
    peak = np.max(blended)
    # A fully black image with an empty mask has a peak of 0: dividing by it would fill the output with NaN
    # and make the uint8 conversion undefined, so the normalization is skipped in that case
    if peak != 0:
        blended = blended / peak
    return np.uint8(255 * blended)
103
+
104
+
105
def create_rgb_mask(
    mask: np.ndarray,
    color_map: dict,
    ignore_classes: list[int] | None = None,
    ground_truth_mask: np.ndarray | None = None,
):
    """Turn a mask of class indices into a 3 channel color mask.

    Args:
        mask: Two dimensional mask containing class indices
        color_map: Dictionary mapping the string version of a class index to its color
        ignore_classes: Class indices that are left black in the output
        ground_truth_mask: If given together with ignore_classes, every pixel whose ground truth belongs to an
            ignored class is set to black as well

    Returns:
        Float array of shape (height, width, 3) filled with the class colors
    """
    skipped = [] if ignore_classes is None else ignore_classes
    height, width = mask.shape[0], mask.shape[1]
    rgb = np.zeros([height, width, 3])

    for class_idx in np.unique(mask):
        if class_idx not in skipped:
            rgb[mask == class_idx] = color_map[str(class_idx)]

    if ignore_classes is not None and ground_truth_mask is not None:
        # Black out what the ground truth marks as an ignored class
        rgb[np.isin(ground_truth_mask, ignore_classes)] = [0, 0, 0]

    return rgb
122
+
123
+
124
def plot_patch_results(
    image: np.ndarray,
    prediction_image: np.ndarray,
    ground_truth_image: np.ndarray | None,
    class_to_idx: dict[str, int],
    plot_original: bool = True,
    ignore_classes: list[int] | None = None,
    image_height: int = 10,
    save_path: str | None = None,
    cmap: Colormap | None = None,
) -> Figure:
    """Function used to plot the image predicted.

    Args:
        image: The original image to plot, it is cropped to the size of the prediction
        prediction_image: The prediction image
        ground_truth_image: The ground truth image
        class_to_idx: Dictionary mapping class names to indices
        plot_original: Boolean to plot the original image
        ignore_classes: The classes to ignore, default is 0
        image_height: The height of the output figure
        save_path: The path to save the figure, when given the figure is also closed after saving
        cmap: The colormap to use. If None, tab20 is used

    Returns:
        The matplotlib figure
    """
    if ignore_classes is None:
        ignore_classes = [0]

    if cmap is None:
        cmap = get_cmap("tab20")

    image = image[0 : prediction_image.shape[0], 0 : prediction_image.shape[1], :]
    # Built before filtering so that ignored classes can still be named in the plot title
    idx_to_class = {v: k for k, v in class_to_idx.items()}
    class_to_idx = {k: v for k, v in class_to_idx.items() if v not in ignore_classes}
    class_idxs = list(class_to_idx.values())

    # Map every plotted class index (as string) to an RGB tuple in the 0-255 range, the alpha channel is dropped
    color_map = {str(c): tuple(int(i * 255) for i in cmap(c / len(class_idxs))[:-1]) for c in class_idxs}
    output_images = []
    titles = []

    if plot_original:
        output_images.append(image)
        titles.append("Original Image")

    if ground_truth_image is not None:
        ground_truth_image = ground_truth_image[0 : prediction_image.shape[0], 0 : prediction_image.shape[1]]
        ground_truth_mask = create_rgb_mask(ground_truth_image, color_map, ignore_classes=ignore_classes)
        output_images.append(ground_truth_mask)
        titles.append("Ground Truth Mask")

    prediction_mask = create_rgb_mask(
        prediction_image,
        color_map,
        ignore_classes=ignore_classes,
    )
    output_images.append(prediction_mask)
    titles.append("Prediction Mask")

    if ground_truth_image is not None:
        # Extra panel where predictions falling on ignored ground truth regions are blacked out
        masked_prediction = create_rgb_mask(
            prediction_image, color_map, ignore_classes=ignore_classes, ground_truth_mask=ground_truth_image
        )

        ignored_classes_str = [idx_to_class[c] for c in ignore_classes]
        prediction_title = f"Prediction Mask \n (Ignoring Ground Truth Class: {ignored_classes_str})"
        output_images.append(masked_prediction)
        titles.append(prediction_title)

    fig, axs = plt.subplots(
        ncols=len(output_images),
        nrows=1,
        figsize=(len(output_images) * image_height, image_height),
        squeeze=False,
        facecolor="white",
    )

    for i, output_image in enumerate(output_images):
        axs[0, i].imshow(show_mask_on_image(image, output_image))
        axs[0, i].set_title(titles[i])
        axs[0, i].axis("off")

    custom_lines = [Line2D([0], [0], color=tuple(i / 255.0 for i in color_map[str(c)]), lw=4) for c in class_idxs]
    custom_labels = list(class_to_idx.keys())
    axs[0, -1].legend(custom_lines, custom_labels, loc="center left", bbox_to_anchor=(1.01, 0.81), borderaxespad=0)

    if save_path is not None:
        # Act on this figure explicitly instead of relying on pyplot's implicit "current figure"
        fig.savefig(save_path, bbox_inches="tight")
        plt.close(fig)

    return fig
@@ -0,0 +1,42 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
5
+ from hydra.core.hydra_config import HydraConfig
6
+ from omegaconf import OmegaConf
7
+
8
+
9
def multirun_subdir_beautify(subdir: str) -> str:
    """Change the subdir name to be more readable and usable, this function will replace / with | to avoid creating
    undesired subdirectories and remove the left part of the equals sign to avoid having too long names.

    The name is returned unchanged when hydra is not in multirun mode (mode is None or RUN).

    Args:
        subdir: The subdir name.

    Returns:
        The beautified subdir name.

    Examples:
        >>> multirun_subdir_beautify("experiment=pippo/anomaly/padim,trainer.batch_size=32")
        "pippo|anomaly|padim,32"
    """
    hydra_cfg = HydraConfig.get()
    if hydra_cfg.mode is None or hydra_cfg.mode.name == "RUN":
        return subdir
    # Remove slashes to avoid creating multiple subdirs
    # TODO: if right side of the equals sign has `,` this will not work.
    overrides = subdir.replace("/", "|").split(",")
    # Split on the first `=` only: values containing `=` are kept whole and an override without `=`
    # is kept as it is instead of raising an IndexError
    return ",".join(override.split("=", 1)[-1].replace(" ", "") for override in overrides)
32
+
33
+
34
def as_tuple(*args: Any) -> tuple[Any, ...]:
    """Return all the positional arguments packed in a tuple."""
    # Variadic positional arguments already arrive as a tuple
    return args
37
+
38
+
39
def register_resolvers() -> None:
    """Register the custom OmegaConf resolvers `multirun_subdir_beautify` and `as_tuple`.

    The function can be called more than once in the same process.
    """
    resolvers = {
        "multirun_subdir_beautify": multirun_subdir_beautify,
        "as_tuple": as_tuple,
    }
    for name, resolver in resolvers.items():
        # Without replace=True a second registration of the same name raises a ValueError
        OmegaConf.register_new_resolver(name, resolver, replace=True)
@@ -0,0 +1,31 @@
1
+ import numpy as np
2
+ import skimage
3
+ from skimage.morphology import medial_axis
4
+
5
+ from quadra.utils import utils
6
+
7
+ log = utils.get_logger(__name__)
8
+
9
+
10
def smooth_mask(mask: np.ndarray) -> np.ndarray:
    """Smooths for segmentation.

    Every connected component of the mask is replaced by its distance transform raised to 1 / 2.2 and
    normalized between 0 and 1, the result is then multiplied by the input mask.

    Args:
        mask: Input mask

    Returns:
        Smoothed mask
    """
    labeled_mask = skimage.measure.label(mask)
    output_mask = np.zeros_like(mask).astype(np.float32)
    # Label 0 is the background: it is zeroed by the final multiplication with the mask, and when the mask
    # has no background at all its empty component would produce a 0 / 0 normalization filling the output
    # with NaN, so the loop starts from the first real component
    for label in range(1, int(np.max(labeled_mask)) + 1):
        component_mask = labeled_mask == label
        _, distance = medial_axis(component_mask, return_distance=True)
        component_mask_norm = distance ** (1 / 2.2)
        lowest = np.min(component_mask_norm)
        value_range = np.max(component_mask_norm) - lowest
        if value_range > 0:
            output_mask += (component_mask_norm - lowest) / value_range
        else:
            # Constant distance map (e.g. a tiny mask completely filled by the component): there is nothing
            # to normalize, give full weight to the component instead of dividing by zero
            output_mask += component_mask
    output_mask = output_mask * mask
    return output_mask
File without changes
@@ -0,0 +1 @@
1
+ from .dataset import * # noqa: F403
@@ -0,0 +1,39 @@
1
+ from .anomaly import AnomalyDatasetArguments, anomaly_dataset, base_anomaly_dataset
2
+ from .classification import (
3
+ ClassificationDatasetArguments,
4
+ ClassificationMultilabelDatasetArguments,
5
+ ClassificationPatchDatasetArguments,
6
+ base_classification_dataset,
7
+ base_multilabel_classification_dataset,
8
+ base_patch_classification_dataset,
9
+ classification_dataset,
10
+ classification_patch_dataset,
11
+ multilabel_classification_dataset,
12
+ )
13
+ from .imagenette import imagenette_dataset
14
+ from .segmentation import (
15
+ SegmentationDatasetArguments,
16
+ base_binary_segmentation_dataset,
17
+ base_multiclass_segmentation_dataset,
18
+ segmentation_dataset,
19
+ )
20
+
21
# Public names re-exported by this package, grouped by the module they come from
__all__ = [
    # anomaly
    "AnomalyDatasetArguments",
    "anomaly_dataset",
    "base_anomaly_dataset",
    # classification
    "ClassificationDatasetArguments",
    "ClassificationMultilabelDatasetArguments",
    "ClassificationPatchDatasetArguments",
    "base_classification_dataset",
    "base_multilabel_classification_dataset",
    "base_patch_classification_dataset",
    "classification_dataset",
    "classification_patch_dataset",
    "multilabel_classification_dataset",
    # imagenette
    "imagenette_dataset",
    # segmentation
    "SegmentationDatasetArguments",
    "base_binary_segmentation_dataset",
    "base_multiclass_segmentation_dataset",
    "segmentation_dataset",
]
@@ -0,0 +1,124 @@
1
+ from __future__ import annotations
2
+
3
+ import shutil
4
+ from dataclasses import dataclass
5
+ from pathlib import Path
6
+ from typing import Any
7
+
8
+ import cv2
9
+ import pytest
10
+
11
+ from quadra.utils.tests.helpers import _random_image
12
+
13
+
14
@dataclass
class AnomalyDatasetArguments:
    """Sample counts describing an anomaly dataset.

    Args:
        train_samples: how many train samples the dataset has
        val_samples: how many validation samples the dataset has, as (good, bad)
        test_samples: how many test samples the dataset has, as (good, bad)
    """

    train_samples: int
    val_samples: tuple[int, int]
    test_samples: tuple[int, int]
27
+
28
+
29
def _build_anomaly_dataset(
    tmp_path: Path, dataset_arguments: AnomalyDatasetArguments
) -> tuple[str, AnomalyDatasetArguments]:
    """Create an anomaly dataset on disk following the standard mvtec folder layout.

    The dataset is written under `tmp_path / "anomaly_dataset"` with the folders train/good, val/good,
    val/bad, test/good and test/bad, each one filled with png images produced by `_random_image`.

    Args:
        tmp_path: path to temporary directory
        dataset_arguments: dataset arguments

    Returns:
        path to anomaly dataset and the dataset arguments it was built from
    """
    n_train = dataset_arguments.train_samples
    n_val = dataset_arguments.val_samples
    n_test = dataset_arguments.test_samples

    dataset_root = tmp_path / "anomaly_dataset"
    dataset_root.mkdir()

    # All the folders are created before any image is written
    folders = {
        (split, label): dataset_root / split / label
        for split, label in (("train", "good"), ("val", "good"), ("val", "bad"), ("test", "good"), ("test", "bad"))
    }
    for folder in folders.values():
        folder.mkdir(parents=True)

    def _fill(split: str, label: str, count: int) -> None:
        """Write `count` random images named `<split>_<index>.png` in the folder of the given split and label."""
        for index in range(count):
            image = _random_image()
            cv2.imwrite(str(folders[(split, label)] / f"{split}_{index}.png"), image)

    _fill("train", "good", n_train)
    _fill("val", "good", n_val[0])
    _fill("val", "bad", n_val[1])
    _fill("test", "good", n_test[0])
    _fill("test", "bad", n_test[1])

    return str(dataset_root), dataset_arguments
88
+
89
+
90
@pytest.fixture
def anomaly_dataset(tmp_path: Path, dataset_arguments: AnomalyDatasetArguments) -> tuple[str, AnomalyDatasetArguments]:
    """Fixture used to dynamically generate an anomaly dataset from the given arguments.

    Args:
        tmp_path: path to temporary directory
        dataset_arguments: dataset arguments

    Yields:
        path to anomaly dataset and dataset arguments
    """
    generated = _build_anomaly_dataset(tmp_path, dataset_arguments)
    yield generated

    # Teardown: remove everything written in the temporary directory
    if not tmp_path.exists():
        return
    shutil.rmtree(tmp_path)
104
+
105
+
106
@pytest.fixture(params=[AnomalyDatasetArguments(train_samples=10, val_samples=(1, 1), test_samples=(1, 1))])
def base_anomaly_dataset(tmp_path: Path, request: Any) -> tuple[str, AnomalyDatasetArguments]:
    """Generate base anomaly dataset with the following parameters:
    - train_samples: 10
    - val_samples: (1, 1)
    - test_samples: (1, 1).

    Args:
        tmp_path: Path to temporary directory
        request: Pytest SubRequest object

    Yields:
        Path to anomaly dataset and dataset arguments
    """
    yield _build_anomaly_dataset(tmp_path, request.param)
    # Teardown: remove everything written in the temporary directory
    if tmp_path.exists():
        shutil.rmtree(tmp_path)