quadra-0.0.1-py3-none-any.whl → quadra-2.2.7-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (302)
  1. hydra_plugins/quadra_searchpath_plugin.py +30 -0
  2. quadra/__init__.py +6 -0
  3. quadra/callbacks/__init__.py +0 -0
  4. quadra/callbacks/anomalib.py +289 -0
  5. quadra/callbacks/lightning.py +501 -0
  6. quadra/callbacks/mlflow.py +291 -0
  7. quadra/callbacks/scheduler.py +69 -0
  8. quadra/configs/__init__.py +0 -0
  9. quadra/configs/backbone/caformer_m36.yaml +8 -0
  10. quadra/configs/backbone/caformer_s36.yaml +8 -0
  11. quadra/configs/backbone/convnextv2_base.yaml +8 -0
  12. quadra/configs/backbone/convnextv2_femto.yaml +8 -0
  13. quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
  14. quadra/configs/backbone/dino_vitb8.yaml +12 -0
  15. quadra/configs/backbone/dino_vits8.yaml +12 -0
  16. quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
  17. quadra/configs/backbone/dinov2_vits14.yaml +12 -0
  18. quadra/configs/backbone/efficientnet_b0.yaml +8 -0
  19. quadra/configs/backbone/efficientnet_b1.yaml +8 -0
  20. quadra/configs/backbone/efficientnet_b2.yaml +8 -0
  21. quadra/configs/backbone/efficientnet_b3.yaml +8 -0
  22. quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
  23. quadra/configs/backbone/levit_128s.yaml +8 -0
  24. quadra/configs/backbone/mnasnet0_5.yaml +9 -0
  25. quadra/configs/backbone/resnet101.yaml +8 -0
  26. quadra/configs/backbone/resnet18.yaml +8 -0
  27. quadra/configs/backbone/resnet18_ssl.yaml +8 -0
  28. quadra/configs/backbone/resnet50.yaml +8 -0
  29. quadra/configs/backbone/smp.yaml +9 -0
  30. quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
  31. quadra/configs/backbone/unetr.yaml +15 -0
  32. quadra/configs/backbone/vit16_base.yaml +9 -0
  33. quadra/configs/backbone/vit16_small.yaml +9 -0
  34. quadra/configs/backbone/vit16_tiny.yaml +9 -0
  35. quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
  36. quadra/configs/callbacks/all.yaml +45 -0
  37. quadra/configs/callbacks/default.yaml +34 -0
  38. quadra/configs/callbacks/default_anomalib.yaml +64 -0
  39. quadra/configs/config.yaml +33 -0
  40. quadra/configs/core/default.yaml +11 -0
  41. quadra/configs/datamodule/base/anomaly.yaml +16 -0
  42. quadra/configs/datamodule/base/classification.yaml +21 -0
  43. quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
  44. quadra/configs/datamodule/base/segmentation.yaml +18 -0
  45. quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
  46. quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
  47. quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
  48. quadra/configs/datamodule/base/ssl.yaml +21 -0
  49. quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
  50. quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
  51. quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
  52. quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
  53. quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
  54. quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
  55. quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
  56. quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
  57. quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
  58. quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
  59. quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
  60. quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
  61. quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
  62. quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
  63. quadra/configs/experiment/base/classification/classification.yaml +73 -0
  64. quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
  65. quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
  66. quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
  67. quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
  68. quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
  69. quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
  70. quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
  71. quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
  72. quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
  73. quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
  74. quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
  75. quadra/configs/experiment/base/ssl/byol.yaml +43 -0
  76. quadra/configs/experiment/base/ssl/dino.yaml +46 -0
  77. quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
  78. quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
  79. quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
  80. quadra/configs/experiment/custom/cls.yaml +12 -0
  81. quadra/configs/experiment/default.yaml +15 -0
  82. quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
  83. quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
  84. quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
  85. quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
  86. quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
  87. quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
  88. quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
  89. quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
  90. quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
  91. quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
  92. quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
  93. quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
  94. quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
  95. quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
  96. quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
  97. quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
  98. quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
  99. quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
  100. quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
  101. quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
  102. quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
  103. quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
  104. quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
  105. quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
  106. quadra/configs/export/default.yaml +13 -0
  107. quadra/configs/hydra/anomaly_custom.yaml +15 -0
  108. quadra/configs/hydra/default.yaml +14 -0
  109. quadra/configs/inference/default.yaml +26 -0
  110. quadra/configs/logger/comet.yaml +10 -0
  111. quadra/configs/logger/csv.yaml +5 -0
  112. quadra/configs/logger/mlflow.yaml +12 -0
  113. quadra/configs/logger/tensorboard.yaml +8 -0
  114. quadra/configs/loss/asl.yaml +7 -0
  115. quadra/configs/loss/barlow.yaml +2 -0
  116. quadra/configs/loss/bce.yaml +1 -0
  117. quadra/configs/loss/byol.yaml +1 -0
  118. quadra/configs/loss/cross_entropy.yaml +1 -0
  119. quadra/configs/loss/dino.yaml +8 -0
  120. quadra/configs/loss/simclr.yaml +2 -0
  121. quadra/configs/loss/simsiam.yaml +1 -0
  122. quadra/configs/loss/smp_ce.yaml +3 -0
  123. quadra/configs/loss/smp_dice.yaml +2 -0
  124. quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
  125. quadra/configs/loss/smp_mcc.yaml +2 -0
  126. quadra/configs/loss/vicreg.yaml +5 -0
  127. quadra/configs/model/anomalib/cfa.yaml +35 -0
  128. quadra/configs/model/anomalib/cflow.yaml +30 -0
  129. quadra/configs/model/anomalib/csflow.yaml +34 -0
  130. quadra/configs/model/anomalib/dfm.yaml +19 -0
  131. quadra/configs/model/anomalib/draem.yaml +29 -0
  132. quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
  133. quadra/configs/model/anomalib/fastflow.yaml +32 -0
  134. quadra/configs/model/anomalib/padim.yaml +32 -0
  135. quadra/configs/model/anomalib/patchcore.yaml +36 -0
  136. quadra/configs/model/barlow.yaml +16 -0
  137. quadra/configs/model/byol.yaml +25 -0
  138. quadra/configs/model/classification.yaml +10 -0
  139. quadra/configs/model/dino.yaml +26 -0
  140. quadra/configs/model/logistic_regression.yaml +4 -0
  141. quadra/configs/model/multilabel_classification.yaml +9 -0
  142. quadra/configs/model/simclr.yaml +18 -0
  143. quadra/configs/model/simsiam.yaml +24 -0
  144. quadra/configs/model/smp.yaml +4 -0
  145. quadra/configs/model/smp_multiclass.yaml +4 -0
  146. quadra/configs/model/vicreg.yaml +16 -0
  147. quadra/configs/optimizer/adam.yaml +5 -0
  148. quadra/configs/optimizer/adamw.yaml +3 -0
  149. quadra/configs/optimizer/default.yaml +4 -0
  150. quadra/configs/optimizer/lars.yaml +8 -0
  151. quadra/configs/optimizer/sgd.yaml +4 -0
  152. quadra/configs/scheduler/default.yaml +5 -0
  153. quadra/configs/scheduler/rop.yaml +5 -0
  154. quadra/configs/scheduler/step.yaml +3 -0
  155. quadra/configs/scheduler/warmrestart.yaml +2 -0
  156. quadra/configs/scheduler/warmup.yaml +6 -0
  157. quadra/configs/task/anomalib/cfa.yaml +5 -0
  158. quadra/configs/task/anomalib/cflow.yaml +5 -0
  159. quadra/configs/task/anomalib/csflow.yaml +5 -0
  160. quadra/configs/task/anomalib/draem.yaml +5 -0
  161. quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
  162. quadra/configs/task/anomalib/fastflow.yaml +5 -0
  163. quadra/configs/task/anomalib/inference.yaml +3 -0
  164. quadra/configs/task/anomalib/padim.yaml +5 -0
  165. quadra/configs/task/anomalib/patchcore.yaml +5 -0
  166. quadra/configs/task/classification.yaml +6 -0
  167. quadra/configs/task/classification_evaluation.yaml +6 -0
  168. quadra/configs/task/default.yaml +1 -0
  169. quadra/configs/task/segmentation.yaml +9 -0
  170. quadra/configs/task/segmentation_evaluation.yaml +3 -0
  171. quadra/configs/task/sklearn_classification.yaml +13 -0
  172. quadra/configs/task/sklearn_classification_patch.yaml +11 -0
  173. quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
  174. quadra/configs/task/sklearn_classification_test.yaml +8 -0
  175. quadra/configs/task/ssl.yaml +2 -0
  176. quadra/configs/trainer/lightning_cpu.yaml +36 -0
  177. quadra/configs/trainer/lightning_gpu.yaml +35 -0
  178. quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
  179. quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
  180. quadra/configs/trainer/lightning_multigpu.yaml +37 -0
  181. quadra/configs/trainer/sklearn_classification.yaml +7 -0
  182. quadra/configs/transforms/byol.yaml +47 -0
  183. quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
  184. quadra/configs/transforms/default.yaml +37 -0
  185. quadra/configs/transforms/default_numpy.yaml +24 -0
  186. quadra/configs/transforms/default_resize.yaml +22 -0
  187. quadra/configs/transforms/dino.yaml +63 -0
  188. quadra/configs/transforms/linear_eval.yaml +18 -0
  189. quadra/datamodules/__init__.py +20 -0
  190. quadra/datamodules/anomaly.py +180 -0
  191. quadra/datamodules/base.py +375 -0
  192. quadra/datamodules/classification.py +1003 -0
  193. quadra/datamodules/generic/__init__.py +0 -0
  194. quadra/datamodules/generic/imagenette.py +144 -0
  195. quadra/datamodules/generic/mnist.py +81 -0
  196. quadra/datamodules/generic/mvtec.py +58 -0
  197. quadra/datamodules/generic/oxford_pet.py +163 -0
  198. quadra/datamodules/patch.py +190 -0
  199. quadra/datamodules/segmentation.py +742 -0
  200. quadra/datamodules/ssl.py +140 -0
  201. quadra/datasets/__init__.py +17 -0
  202. quadra/datasets/anomaly.py +287 -0
  203. quadra/datasets/classification.py +241 -0
  204. quadra/datasets/patch.py +138 -0
  205. quadra/datasets/segmentation.py +239 -0
  206. quadra/datasets/ssl.py +110 -0
  207. quadra/losses/__init__.py +0 -0
  208. quadra/losses/classification/__init__.py +6 -0
  209. quadra/losses/classification/asl.py +83 -0
  210. quadra/losses/classification/focal.py +320 -0
  211. quadra/losses/classification/prototypical.py +148 -0
  212. quadra/losses/ssl/__init__.py +17 -0
  213. quadra/losses/ssl/barlowtwins.py +47 -0
  214. quadra/losses/ssl/byol.py +37 -0
  215. quadra/losses/ssl/dino.py +129 -0
  216. quadra/losses/ssl/hyperspherical.py +45 -0
  217. quadra/losses/ssl/idmm.py +50 -0
  218. quadra/losses/ssl/simclr.py +67 -0
  219. quadra/losses/ssl/simsiam.py +30 -0
  220. quadra/losses/ssl/vicreg.py +76 -0
  221. quadra/main.py +49 -0
  222. quadra/metrics/__init__.py +3 -0
  223. quadra/metrics/segmentation.py +251 -0
  224. quadra/models/__init__.py +0 -0
  225. quadra/models/base.py +151 -0
  226. quadra/models/classification/__init__.py +8 -0
  227. quadra/models/classification/backbones.py +149 -0
  228. quadra/models/classification/base.py +92 -0
  229. quadra/models/evaluation.py +322 -0
  230. quadra/modules/__init__.py +0 -0
  231. quadra/modules/backbone.py +30 -0
  232. quadra/modules/base.py +312 -0
  233. quadra/modules/classification/__init__.py +3 -0
  234. quadra/modules/classification/base.py +327 -0
  235. quadra/modules/ssl/__init__.py +17 -0
  236. quadra/modules/ssl/barlowtwins.py +59 -0
  237. quadra/modules/ssl/byol.py +172 -0
  238. quadra/modules/ssl/common.py +285 -0
  239. quadra/modules/ssl/dino.py +186 -0
  240. quadra/modules/ssl/hyperspherical.py +206 -0
  241. quadra/modules/ssl/idmm.py +98 -0
  242. quadra/modules/ssl/simclr.py +73 -0
  243. quadra/modules/ssl/simsiam.py +68 -0
  244. quadra/modules/ssl/vicreg.py +67 -0
  245. quadra/optimizers/__init__.py +4 -0
  246. quadra/optimizers/lars.py +153 -0
  247. quadra/optimizers/sam.py +127 -0
  248. quadra/schedulers/__init__.py +3 -0
  249. quadra/schedulers/base.py +44 -0
  250. quadra/schedulers/warmup.py +127 -0
  251. quadra/tasks/__init__.py +24 -0
  252. quadra/tasks/anomaly.py +582 -0
  253. quadra/tasks/base.py +397 -0
  254. quadra/tasks/classification.py +1263 -0
  255. quadra/tasks/patch.py +492 -0
  256. quadra/tasks/segmentation.py +389 -0
  257. quadra/tasks/ssl.py +560 -0
  258. quadra/trainers/README.md +3 -0
  259. quadra/trainers/__init__.py +0 -0
  260. quadra/trainers/classification.py +179 -0
  261. quadra/utils/__init__.py +0 -0
  262. quadra/utils/anomaly.py +112 -0
  263. quadra/utils/classification.py +618 -0
  264. quadra/utils/deprecation.py +31 -0
  265. quadra/utils/evaluation.py +474 -0
  266. quadra/utils/export.py +585 -0
  267. quadra/utils/imaging.py +32 -0
  268. quadra/utils/logger.py +15 -0
  269. quadra/utils/mlflow.py +98 -0
  270. quadra/utils/model_manager.py +320 -0
  271. quadra/utils/models.py +523 -0
  272. quadra/utils/patch/__init__.py +15 -0
  273. quadra/utils/patch/dataset.py +1433 -0
  274. quadra/utils/patch/metrics.py +449 -0
  275. quadra/utils/patch/model.py +153 -0
  276. quadra/utils/patch/visualization.py +217 -0
  277. quadra/utils/resolver.py +42 -0
  278. quadra/utils/segmentation.py +31 -0
  279. quadra/utils/tests/__init__.py +0 -0
  280. quadra/utils/tests/fixtures/__init__.py +1 -0
  281. quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
  282. quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
  283. quadra/utils/tests/fixtures/dataset/classification.py +406 -0
  284. quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
  285. quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
  286. quadra/utils/tests/fixtures/models/__init__.py +3 -0
  287. quadra/utils/tests/fixtures/models/anomaly.py +89 -0
  288. quadra/utils/tests/fixtures/models/classification.py +45 -0
  289. quadra/utils/tests/fixtures/models/segmentation.py +33 -0
  290. quadra/utils/tests/helpers.py +70 -0
  291. quadra/utils/tests/models.py +27 -0
  292. quadra/utils/utils.py +525 -0
  293. quadra/utils/validator.py +115 -0
  294. quadra/utils/visualization.py +422 -0
  295. quadra/utils/vit_explainability.py +349 -0
  296. quadra-2.2.7.dist-info/LICENSE +201 -0
  297. quadra-2.2.7.dist-info/METADATA +381 -0
  298. quadra-2.2.7.dist-info/RECORD +300 -0
  299. {quadra-0.0.1.dist-info → quadra-2.2.7.dist-info}/WHEEL +1 -1
  300. quadra-2.2.7.dist-info/entry_points.txt +3 -0
  301. quadra-0.0.1.dist-info/METADATA +0 -14
  302. quadra-0.0.1.dist-info/RECORD +0 -4
quadra/utils/patch/metrics.py
@@ -0,0 +1,449 @@
+ from __future__ import annotations
+
+ import os
+ import warnings
+
+ import cv2
+ import numpy as np
+ import pandas as pd
+ from scipy import ndimage
+ from skimage.measure import label, regionprops  # pylint: disable=no-name-in-module
+ from tqdm import tqdm
+
+ from quadra.utils import utils
+ from quadra.utils.patch.dataset import PatchDatasetFileFormat, compute_patch_info, compute_patch_info_from_patch_dim
+
+ log = utils.get_logger(__name__)
+
+
+ def get_sorted_patches_by_image(test_results: pd.DataFrame, img_name: str) -> pd.DataFrame:
+     """Gets the patches of a given image sorted by patch number.
+
+     Args:
+         test_results: Pandas dataframe containing test results like the one produced by SklearnClassificationTrainer
+         img_name: name of the image used to filter the results.
+
+     Returns:
+         test results filtered by image name and sorted by patch number
+     """
+     img_patches = test_results[test_results["filename"] == os.path.splitext(img_name)[0]]
+     patches_idx = np.array(
+         [int(os.path.basename(x).split("_")[-1].replace(".png", "")) for x in img_patches["sample"].tolist()]
+     )
+     patches_idx = np.argsort(patches_idx).tolist()
+     img_patches = img_patches.iloc[patches_idx]
+
+     return img_patches
+
+
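For illustration only, a minimal sketch of the expected input (the paths and labels below are made up; the helper assumes per-patch samples named `<image>_<patch_index>.png` and a `filename` column such as the one added by `compute_patch_metrics` further down):

    import pandas as pd
    # Assuming this file is importable as quadra.utils.patch.metrics
    from quadra.utils.patch.metrics import get_sorted_patches_by_image

    # Hypothetical per-patch results: "sample" ends with the patch index,
    # "filename" holds the source image stem
    results = pd.DataFrame(
        {
            "sample": ["crops/img1_2.png", "crops/img1_0.png", "crops/img1_1.png"],
            "filename": ["img1", "img1", "img1"],
            "pred_label": [0, 1, 0],
        }
    )
    ordered = get_sorted_patches_by_image(results, "img1.png")
    # Rows now come back in patch order 0, 1, 2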
+ def compute_patch_metrics(
+     test_img_info: list[PatchDatasetFileFormat],
+     test_results: pd.DataFrame,
+     overlap: float,
+     idx_to_class: dict,
+     patch_num_h: int | None = None,
+     patch_num_w: int | None = None,
+     patch_w: int | None = None,
+     patch_h: int | None = None,
+     return_polygon: bool = False,
+     patch_reconstruction_method: str = "priority",
+     annotated_good: list[int] | None = None,
+ ) -> tuple[int, int, int, list[dict]]:
+     """Compute the metrics of a patch dataset.
+
+     Args:
+         test_img_info: List of observation paths and mask paths
+         test_results: Pandas dataframe containing the results of an SklearnClassificationTrainer utility
+         patch_num_h: Number of vertical patches (required if patch_w and patch_h are None)
+         patch_num_w: Number of horizontal patches (required if patch_w and patch_h are None)
+         patch_h: Patch height (required if patch_num_h and patch_num_w are None)
+         patch_w: Patch width (required if patch_num_h and patch_num_w are None)
+         overlap: Percentage of overlap between the patches
+         idx_to_class: Dict mapping an index to the corresponding class name
+         return_polygon: if set to true convert the reconstructed mask into polygons, otherwise return the mask
+         patch_reconstruction_method: How to compute the label of overlapping patches, can either be:
+             priority: Assign the top priority label (i.e. the one with the greater index) to overlapping regions
+             major_voting: Assign the most present label among the patch labels overlapping a pixel
+         annotated_good: List of indices of annotations to be treated as good.
+
+     Returns:
+         Tuple containing:
+             false_region_bad: Number of false bad regions detected in the dataset
+             false_region_good: Number of missed defects
+             true_region_bad: Number of correctly identified defects
+             reconstructions: If return_polygon is true this is a list of dict containing
+                 {
+                     "file_path": image_path,
+                     "mask_path": mask_path,
+                     "file_name": observation_name,
+                     "prediction": [{
+                         "label": predicted_label,
+                         "points": List of dict coordinates "x" and "y" representing the points of a polygon that
+                             surrounds an image area covered by patches of label = predicted_label
+                     }]
+                 }
+                 else it's a list of dict containing
+                 {
+                     "file_path": image_path,
+                     "mask_path": mask_path,
+                     "file_name": observation_name,
+                     "prediction": numpy array containing the reconstructed mask
+                 }
+     """
+     assert patch_reconstruction_method in [
+         "priority",
+         "major_voting",
+     ], "Patch reconstruction method not recognized, valid values are priority, major_voting"
+
+     if (patch_h is not None and patch_w is not None) and (patch_num_h is not None and patch_num_w is not None):
+         raise ValueError("Either number of patches or patch size is required for reconstruction")
+
+     assert (patch_h is not None and patch_w is not None) or (
+         patch_num_h is not None and patch_num_w is not None
+     ), "Either number of patches or patch size is required for reconstruction"
+
+     if patch_h is not None and patch_w is not None and patch_num_h is not None and patch_num_w is not None:
+         warnings.warn(
+             "Both number of patches and patch dimension are specified, using number of patches by default",
+             UserWarning,
+             stacklevel=2,
+         )
+
+     log.info("Computing patch metrics!")
+
+     false_region_bad = 0
+     false_region_good = 0
+     true_region_bad = 0
+     reconstructions = []
+     test_results["filename"] = test_results["sample"].apply(
+         lambda x: "_".join(os.path.basename(x).replace("#DISCARD#", "").split("_")[0:-1])
+     )
+
+     for info in tqdm(test_img_info):
+         img_path = info.image_path
+         mask_path = info.mask_path
+
+         img_json_entry = {
+             "image_path": img_path,
+             "mask_path": mask_path,
+             "file_name": os.path.basename(img_path),
+             "prediction": None,
+         }
+
+         test_img = cv2.imread(img_path)
+
+         img_name = os.path.basename(img_path)
+
+         h = test_img.shape[0]
+         w = test_img.shape[1]
+
+         gt_img = None
+
+         if mask_path is not None and os.path.exists(mask_path):
+             gt_img = cv2.imread(mask_path, 0)
+             if test_img.shape[0:2] != gt_img.shape:
+                 # Ensure that the mask has the same size as the image by padding it with zeros
+                 log.warning("Found mask with different size than the image, padding it with zeros!")
+                 gt_img = np.pad(
+                     gt_img, ((0, test_img.shape[0] - gt_img.shape[0]), (0, test_img.shape[1] - gt_img.shape[1]))
+                 )
+         if patch_num_h is not None and patch_num_w is not None:
+             patch_size, step = compute_patch_info(h, w, patch_num_h, patch_num_w, overlap)
+         elif patch_h is not None and patch_w is not None:
+             [patch_num_h, patch_num_w], step = compute_patch_info_from_patch_dim(h, w, patch_h, patch_w, overlap)
+             patch_size = (patch_h, patch_w)
+         else:
+             raise ValueError(
+                 "Either number of patches or patch size is required for reconstruction, this should not happen"
+                 " at this stage"
+             )
+
+         img_patches = get_sorted_patches_by_image(test_results, img_name)
+         pred = img_patches["pred_label"].to_numpy().reshape(patch_num_h, patch_num_w)
+
+         # Treat annotated good predictions as background, this is an optimistic assumption that assumes that the
+         # remaining background is good, but it is not always true so maybe on non annotated areas we are missing
+         # defects and it would be necessary to handle this in a different way.
+         if annotated_good is not None:
+             pred[np.isin(pred, annotated_good)] = 0
+         if patch_num_h is not None and patch_num_w is not None:
+             output_mask, predicted_defect = reconstruct_patch(
+                 input_img_shape=test_img.shape,
+                 patch_size=patch_size,
+                 pred=pred,
+                 patch_num_h=patch_num_h,
+                 patch_num_w=patch_num_w,
+                 idx_to_class=idx_to_class,
+                 step=step,
+                 return_polygon=return_polygon,
+                 method=patch_reconstruction_method,
+             )
+         else:
+             raise ValueError("`patch_num_h` and `patch_num_w` cannot be None at this point")
+
+         if return_polygon:
+             img_json_entry["prediction"] = predicted_defect
+         else:
+             img_json_entry["prediction"] = output_mask
+
+         reconstructions.append(img_json_entry)
+         if gt_img is not None:
+             if annotated_good is not None:
+                 gt_img[np.isin(gt_img, annotated_good)] = 0
+
+             gt_img_binary = (gt_img > 0).astype(bool)  # type: ignore[operator]
+             regions_pred = label(output_mask).astype(np.uint8)
+
+             for k in range(1, regions_pred.max() + 1):
+                 region = (regions_pred == k).astype(bool)
+                 # If there's no overlap with the gt
+                 if np.sum(np.bitwise_and(region, gt_img_binary)) == 0:
+                     false_region_bad += 1
+
+             output_mask = (output_mask > 0).astype(np.uint8)
+             gt_img = label(gt_img)
+
+             for i in range(1, gt_img.max() + 1):  # type: ignore[union-attr]
+                 region = (gt_img == i).astype(bool)  # type: ignore[union-attr]
+                 if np.sum(np.bitwise_and(region, output_mask)) == 0:
+                     false_region_good += 1
+                 else:
+                     true_region_bad += 1
+
+     return false_region_bad, false_region_good, true_region_bad, reconstructions
+
+
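The three counters returned above lend themselves to simple region-level summaries; a small sketch with made-up numbers (not part of the package):

    # Hypothetical counters as returned by compute_patch_metrics
    false_region_bad, false_region_good, true_region_bad = 3, 1, 9

    # Fraction of annotated defect regions hit by at least one predicted region
    defect_recall = true_region_bad / (true_region_bad + false_region_good)  # 0.9

    # Fraction of predicted regions that actually overlap an annotated defect
    region_precision = true_region_bad / (true_region_bad + false_region_bad)  # 0.75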
+ def reconstruct_patch(
+     input_img_shape: tuple[int, ...],
+     patch_size: tuple[int, int],
+     pred: np.ndarray,
+     patch_num_h: int,
+     patch_num_w: int,
+     idx_to_class: dict,
+     step: tuple[int, int],
+     return_polygon: bool = True,
+     method: str = "priority",
+ ) -> tuple[np.ndarray, list[dict]]:
+     """Reconstructs the prediction image from the patches.
+
+     Args:
+         input_img_shape: The size of the reconstructed image
+         patch_size: Array defining the patch size
+         pred: Numpy array containing reconstructed prediction (patch_num_h x patch_num_w)
+         patch_num_h: Number of vertical patches
+         patch_num_w: Number of horizontal patches
+         idx_to_class: Dictionary mapping indices to labels
+         step: Array defining the step size to be used for reconstruction
+         return_polygon: If true compute predicted polygons. Defaults to True.
+         method: Reconstruction method to be used. Currently supported: "priority" and "major_voting"
+
+     Returns:
+         (reconstructed_prediction_image, predictions) where predictions is an array of objects
+             [{
+                 "label": Predicted_label,
+                 "points": List of dict coordinates "x" and "y" representing the points of a polygon that
+                     surrounds an image area covered by patches of label = predicted_label
+             }]
+     """
+     if method == "priority":
+         return _reconstruct_patch_priority(
+             input_img_shape,
+             patch_size,
+             pred,
+             patch_num_h,
+             patch_num_w,
+             idx_to_class,
+             step,
+             return_polygon,
+         )
+     if method == "major_voting":
+         return _reconstruct_patch_major_voting(
+             input_img_shape,
+             patch_size,
+             pred,
+             patch_num_h,
+             patch_num_w,
+             idx_to_class,
+             step,
+             return_polygon,
+         )
+
+     raise ValueError(f"Invalid reconstruction method {method}")
+
+
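A minimal call sketch, with invented shapes and class names, assuming a 2x3 grid of 100x100 patches with no overlap:

    import numpy as np

    pred = np.array([[0, 1, 0],
                     [0, 0, 2]])  # per-patch class indices (0 = background/good)
    mask, polygons = reconstruct_patch(
        input_img_shape=(200, 300, 3),
        patch_size=(100, 100),
        pred=pred,
        patch_num_h=2,
        patch_num_w=3,
        idx_to_class={1: "scratch", 2: "dent"},  # hypothetical label names
        step=(100, 100),
        return_polygon=True,
        method="priority",
    )
    # mask is a 200x300 uint8 label image, polygons a list of {"label", "points"} dicts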
+ def _reconstruct_patch_priority(
+     input_img_shape: tuple[int, ...],
+     patch_size: tuple[int, int],
+     pred: np.ndarray,
+     patch_num_h: int,
+     patch_num_w: int,
+     idx_to_class: dict,
+     step: tuple[int, int],
+     return_polygon: bool = True,
+ ) -> tuple[np.ndarray, list[dict]]:
+     """Reconstruct patch polygons using the priority method."""
+     final_mask = np.zeros([input_img_shape[0], input_img_shape[1]], dtype=np.uint8)
+     predicted_defect = []
+
+     for i in range(1, pred.max() + 1):
+         white_patch = np.full((patch_size[0], patch_size[1]), i, dtype=np.uint8)
+         masked_pred = (pred == i).astype(np.uint8)
+
+         if masked_pred.sum() == 0:
+             continue
+
+         mask_img = np.zeros([input_img_shape[0], input_img_shape[1]], dtype=np.uint8)
+
+         for h in range(patch_num_h):
+             for w in range(patch_num_w):
+                 if masked_pred[h, w] == 1:
+                     patch_location_h = step[0] * h
+                     patch_location_w = step[1] * w
+
+                     # Move replicated patches prediction in the correct position of the original image if needed
+                     if patch_location_h + patch_size[0] > mask_img.shape[0]:
+                         patch_location_h = mask_img.shape[0] - patch_size[0]
+
+                     if patch_location_w + patch_size[1] > mask_img.shape[1]:
+                         patch_location_w = mask_img.shape[1] - patch_size[1]
+
+                     mask_img[
+                         patch_location_h : patch_location_h + patch_size[0],
+                         patch_location_w : patch_location_w + patch_size[1],
+                     ] = white_patch
+
+         mask_img = mask_img[0 : input_img_shape[0], 0 : input_img_shape[1]]
+
+         # Priority is given by the index of the class, the larger, the more important
+         final_mask = np.maximum(mask_img, final_mask)
+
+     if final_mask.sum() != 0 and return_polygon:
+         for lab in np.unique(final_mask):
+             if lab == 0:
+                 continue
+
+             polygon = from_mask_to_polygon((final_mask == lab).astype(np.uint8))
+
+             for pol in polygon:
+                 class_entry = {
+                     "label": idx_to_class.get(lab),
+                     "points": pol,
+                 }
+
+                 predicted_defect.append(class_entry)
+
+     return final_mask, predicted_defect
+
+
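The priority rule above boils down to `np.maximum` over per-class masks painted with their class index, so the higher index wins wherever patches of different classes overlap; a tiny numeric sketch:

    import numpy as np

    mask_class_1 = np.array([[1, 1, 0]], dtype=np.uint8)  # pixels covered by class 1 patches
    mask_class_3 = np.array([[0, 3, 3]], dtype=np.uint8)  # pixels covered by class 3 patches
    final = np.maximum(mask_class_1, mask_class_3)        # -> [[1, 3, 3]]: class 3 wins the overlap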
+ def _reconstruct_patch_major_voting(
+     input_img_shape: tuple[int, ...],
+     patch_size: tuple[int, int],
+     pred: np.ndarray,
+     patch_num_h: int,
+     patch_num_w: int,
+     idx_to_class: dict,
+     step: tuple[int, int],
+     return_polygon: bool = True,
+ ):
+     """Reconstruct patch polygons using the major voting method."""
+     predicted_defect = []
+
+     final_mask = np.zeros([input_img_shape[0], input_img_shape[1], np.max(pred) + 1], dtype=np.uint8)
+     white_patch = np.ones((patch_size[0], patch_size[1]), dtype=np.uint8)
+
+     for i in range(1, pred.max() + 1):
+         masked_pred = (pred == i).astype(np.uint8)
+
+         if masked_pred.sum() == 0:
+             continue
+
+         mask_img = np.zeros([input_img_shape[0], input_img_shape[1]], dtype=np.uint8)
+
+         for h in range(patch_num_h):
+             for w in range(patch_num_w):
+                 if masked_pred[h, w] == 1:
+                     patch_location_h = step[0] * h
+                     patch_location_w = step[1] * w
+
+                     # Move replicated patches prediction in the correct position of the original image if needed
+                     if patch_location_h + patch_size[0] > mask_img.shape[0]:
+                         patch_location_h = mask_img.shape[0] - patch_size[0]
+
+                     if patch_location_w + patch_size[1] > mask_img.shape[1]:
+                         patch_location_w = mask_img.shape[1] - patch_size[1]
+
+                     mask_img[
+                         patch_location_h : patch_location_h + patch_size[0],
+                         patch_location_w : patch_location_w + patch_size[1],
+                     ] += white_patch
+
+         mask_img = mask_img[0 : input_img_shape[0], 0 : input_img_shape[1]]
+         final_mask[:, :, i] = mask_img
+
+     # Since argmax returns first element on ties and the priority is defined from 0 to n_classes,
+     # I needed a way to get the last element on ties, this code achieves that
+     final_mask = ((final_mask.shape[-1] - 1) - np.argmax(final_mask[..., ::-1], axis=-1)) * np.invert(
+         np.all(final_mask == 0, axis=-1)
+     )
+
+     if final_mask.sum() != 0 and return_polygon:
+         for lab in np.unique(final_mask):
+             if lab == 0:
+                 continue
+
+             polygon = from_mask_to_polygon((final_mask == lab).astype(np.uint8))
+
+             for pol in polygon:
+                 class_entry = {
+                     "label": idx_to_class.get(lab),
+                     "points": pol,
+                 }
+
+                 predicted_defect.append(class_entry)
+
+     return final_mask, predicted_defect
+
+
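The reversed-argmax expression above is worth unpacking: on ties it picks the last (highest-index) class instead of the first, and the `np.invert(np.all(... == 0))` factor zeroes pixels that received no votes at all. A tiny numeric sketch:

    import numpy as np

    votes = np.array([[[0, 2, 2]]], dtype=np.uint8)  # per-pixel vote counts for classes 0..2

    np.argmax(votes, axis=-1)                                     # -> 1 (first max wins on ties)
    (votes.shape[-1] - 1) - np.argmax(votes[..., ::-1], axis=-1)  # -> 2 (last max wins on ties)
    # Multiplying by np.invert(np.all(votes == 0, axis=-1)) keeps unvoted pixels at 0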
+ def from_mask_to_polygon(mask_img: np.ndarray) -> list:
+     """Convert a mask of pattern to a list of polygon vertices.
+
+     Args:
+         mask_img: masked patch reconstruction image
+     Returns:
+         a list of lists containing the coordinates of the polygons containing each region of the mask:
+             [
+                 [
+                     {
+                         "x": 1.1,
+                         "y": 2.2
+                     },
+                     {
+                         "x": 2.1,
+                         "y": 3.2
+                     }
+                 ], ...
+             ].
+     """
+     points_dict = []
+     # find vertices of polygon: points -> list of array of dim n_vertex, 1, 2(x,y)
+     polygon_points, hier = cv2.findContours(mask_img, cv2.RETR_TREE, cv2.CHAIN_APPROX_TC89_L1)
+
+     if not hier[:, :, 2:].all(-1).all():  # there are holes
+         holes = ndimage.binary_fill_holes(mask_img).astype(int)
+         holes -= mask_img
+         holes = (holes > 0).astype(np.uint8)
+         if holes.sum() > 0:  # there are holes
+             for hole in regionprops(label(holes)):
+                 a, _, _, _d = hole.bbox
+                 mask_img[a] = 0
+
+             polygon_points, hier = cv2.findContours(mask_img, cv2.RETR_LIST, cv2.CHAIN_APPROX_TC89_L1)
+
+     for pol in polygon_points:
+         # pol: n_vertex, 1, 2
+         current_poly = []
+         for point in pol:
+             current_poly.append({"x": int(point[0, 0]), "y": int(point[0, 1])})
+         points_dict.append(current_poly)
+
+     return points_dict
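For illustration, a minimal usage sketch of the helper above (a single filled square; the exact number and order of vertices depends on OpenCV's contour approximation):

    import numpy as np

    mask = np.zeros((10, 10), dtype=np.uint8)
    mask[2:6, 3:7] = 1  # one 4x4 region

    polygons = from_mask_to_polygon(mask)
    # -> roughly [[{"x": 3, "y": 2}, {"x": 3, "y": 5}, {"x": 6, "y": 5}, {"x": 6, "y": 2}]]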
quadra/utils/patch/model.py
@@ -0,0 +1,153 @@
+ from __future__ import annotations
+
+ import json
+ import os
+ from typing import Any
+
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import pandas as pd
+ from label_studio_converter.brush import mask2rle
+ from omegaconf import DictConfig
+ from sklearn.metrics import ConfusionMatrixDisplay
+ from torch.utils.data import DataLoader
+
+ from quadra.utils import utils
+ from quadra.utils.patch.visualization import plot_patch_reconstruction
+ from quadra.utils.visualization import UnNormalize, plot_classification_results
+
+ log = utils.get_logger(__name__)
+
+
+ def save_classification_result(
+     results: pd.DataFrame,
+     output_folder: str,
+     confusion_matrix: pd.DataFrame | None,
+     accuracy: float,
+     test_dataloader: DataLoader,
+     reconstructions: list[dict],
+     config: DictConfig,
+     output: DictConfig,
+     ignore_classes: list[int] | None = None,
+ ):
+     """Save classification results.
+
+     Args:
+         results: Dataframe containing the classification results
+         output_folder: Folder where to save the results
+         confusion_matrix: Confusion matrix
+         accuracy: Accuracy of the model
+         test_dataloader: Dataloader used for testing
+         reconstructions: List of dictionaries containing polygons or masks
+         config: Experiment configuration
+         output: Output configuration
+         ignore_classes: Eventual classes to ignore during reconstruction plot. Defaults to None.
+     """
+     # Save csv
+     results.to_csv(os.path.join(output_folder, "test_results.csv"), index_label="index")
+
+     if confusion_matrix is not None:
+         # Save confusion matrix
+         disp = ConfusionMatrixDisplay(
+             confusion_matrix=np.array(confusion_matrix),
+             display_labels=[x.replace("pred:", "") for x in confusion_matrix.columns.to_list()],
+         )
+         disp.plot(include_values=True, cmap=plt.cm.Greens, ax=None, colorbar=False, xticks_rotation=90)
+         plt.title(f"Confusion Matrix (Accuracy: {(accuracy * 100):.2f}%)")
+         plt.savefig(
+             os.path.join(output_folder, "test_confusion_matrix.png"),
+             bbox_inches="tight",
+             pad_inches=0,
+             dpi=300,
+         )
+         plt.close()
+
+     if output.example:
+         if not hasattr(test_dataloader.dataset, "idx_to_class"):
+             raise ValueError("The provided dataset does not have an attribute 'idx_to_class'")
+
+         idx_to_class = test_dataloader.dataset.idx_to_class
+
+         # Get misclassified samples
+         example_folder = os.path.join(output_folder, "example")
+         if not os.path.isdir(example_folder):
+             os.makedirs(example_folder)
+
+         # Skip if no ground truth is available
+         if not all(results["real_label"] == -1):
+             for v in np.unique([results["real_label"], results["pred_label"]]):
+                 if v == -1:
+                     continue
+
+                 k = idx_to_class[v]
+
+                 if ignore_classes is not None and v in ignore_classes:
+                     continue
+
+                 plot_classification_results(
+                     test_dataloader.dataset,
+                     unorm=UnNormalize(mean=config.transforms.mean, std=config.transforms.std),
+                     pred_labels=results["pred_label"].to_numpy(),
+                     test_labels=results["real_label"].to_numpy(),
+                     class_name=k,
+                     original_folder=example_folder,
+                     idx_to_class=idx_to_class,
+                     pred_class_to_plot=v,
+                     what="con",
+                     rows=output.get("rows", 3),
+                     cols=output.get("cols", 2),
+                     figsize=output.get("figsize", (20, 20)),
+                 )
+
+                 plot_classification_results(
+                     test_dataloader.dataset,
+                     unorm=UnNormalize(mean=config.transforms.mean, std=config.transforms.std),
+                     pred_labels=results["pred_label"].to_numpy(),
+                     test_labels=results["real_label"].to_numpy(),
+                     class_name=k,
+                     original_folder=example_folder,
+                     idx_to_class=idx_to_class,
+                     pred_class_to_plot=v,
+                     what="dis",
+                     rows=output.get("rows", 3),
+                     cols=output.get("cols", 2),
+                     figsize=output.get("figsize", (20, 20)),
+                 )
+
+         for counter, reconstruction in enumerate(reconstructions):
+             is_polygon = True
+             if isinstance(reconstruction["prediction"], np.ndarray):
+                 is_polygon = False
+
+             if is_polygon:
+                 if len(reconstruction["prediction"]) == 0:
+                     continue
+             elif reconstruction["prediction"].sum() == 0:
+                 continue
+
+             if counter > 5:
+                 break
+
+             to_plot = plot_patch_reconstruction(
+                 reconstruction,
+                 idx_to_class,
+                 class_to_idx=test_dataloader.dataset.class_to_idx,  # type: ignore[attr-defined]
+                 ignore_classes=ignore_classes,
+                 is_polygon=is_polygon,
+             )
+
+             if to_plot:
+                 output_name = f"reconstruction_{os.path.splitext(os.path.basename(reconstruction['file_name']))[0]}.png"
+                 plt.savefig(os.path.join(example_folder, output_name), bbox_inches="tight", pad_inches=0)
+
+             plt.close()
+
+
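The `output` section consumed above is a small `DictConfig`; a hypothetical example of the keys it reads (`example` gates the plots, while `rows`, `cols` and `figsize` are optional and fall back to the defaults shown in the code):

    from omegaconf import OmegaConf

    output = OmegaConf.create(
        {
            "example": True,      # generate example plots and patch reconstructions
            "rows": 3,            # optional, default 3
            "cols": 2,            # optional, default 2
            "figsize": [20, 20],  # optional, default (20, 20)
        }
    )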
+ class RleEncoder(json.JSONEncoder):
+     """Custom encoder to convert numpy arrays to RLE."""
+
+     def default(self, o: Any):
+         """Customize standard encoder behaviour to convert numpy arrays to RLE."""
+         if isinstance(o, np.ndarray):
+             return mask2rle(o)
+         return json.JSONEncoder.default(self, o)
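A short usage sketch of the encoder above (the payload content is made up): any numpy array in the payload is serialized through `mask2rle`, everything else falls back to the standard encoder.

    import json

    import numpy as np

    payload = {
        "file_name": "img1.png",
        "prediction": np.zeros((8, 8), dtype=np.uint8),  # e.g. a reconstructed mask
    }
    encoded = json.dumps(payload, cls=RleEncoder)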