quadra 0.0.1__py3-none-any.whl → 2.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (302)
  1. hydra_plugins/quadra_searchpath_plugin.py +30 -0
  2. quadra/__init__.py +6 -0
  3. quadra/callbacks/__init__.py +0 -0
  4. quadra/callbacks/anomalib.py +289 -0
  5. quadra/callbacks/lightning.py +501 -0
  6. quadra/callbacks/mlflow.py +291 -0
  7. quadra/callbacks/scheduler.py +69 -0
  8. quadra/configs/__init__.py +0 -0
  9. quadra/configs/backbone/caformer_m36.yaml +8 -0
  10. quadra/configs/backbone/caformer_s36.yaml +8 -0
  11. quadra/configs/backbone/convnextv2_base.yaml +8 -0
  12. quadra/configs/backbone/convnextv2_femto.yaml +8 -0
  13. quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
  14. quadra/configs/backbone/dino_vitb8.yaml +12 -0
  15. quadra/configs/backbone/dino_vits8.yaml +12 -0
  16. quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
  17. quadra/configs/backbone/dinov2_vits14.yaml +12 -0
  18. quadra/configs/backbone/efficientnet_b0.yaml +8 -0
  19. quadra/configs/backbone/efficientnet_b1.yaml +8 -0
  20. quadra/configs/backbone/efficientnet_b2.yaml +8 -0
  21. quadra/configs/backbone/efficientnet_b3.yaml +8 -0
  22. quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
  23. quadra/configs/backbone/levit_128s.yaml +8 -0
  24. quadra/configs/backbone/mnasnet0_5.yaml +9 -0
  25. quadra/configs/backbone/resnet101.yaml +8 -0
  26. quadra/configs/backbone/resnet18.yaml +8 -0
  27. quadra/configs/backbone/resnet18_ssl.yaml +8 -0
  28. quadra/configs/backbone/resnet50.yaml +8 -0
  29. quadra/configs/backbone/smp.yaml +9 -0
  30. quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
  31. quadra/configs/backbone/unetr.yaml +15 -0
  32. quadra/configs/backbone/vit16_base.yaml +9 -0
  33. quadra/configs/backbone/vit16_small.yaml +9 -0
  34. quadra/configs/backbone/vit16_tiny.yaml +9 -0
  35. quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
  36. quadra/configs/callbacks/all.yaml +45 -0
  37. quadra/configs/callbacks/default.yaml +34 -0
  38. quadra/configs/callbacks/default_anomalib.yaml +64 -0
  39. quadra/configs/config.yaml +33 -0
  40. quadra/configs/core/default.yaml +11 -0
  41. quadra/configs/datamodule/base/anomaly.yaml +16 -0
  42. quadra/configs/datamodule/base/classification.yaml +21 -0
  43. quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
  44. quadra/configs/datamodule/base/segmentation.yaml +18 -0
  45. quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
  46. quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
  47. quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
  48. quadra/configs/datamodule/base/ssl.yaml +21 -0
  49. quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
  50. quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
  51. quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
  52. quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
  53. quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
  54. quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
  55. quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
  56. quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
  57. quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
  58. quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
  59. quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
  60. quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
  61. quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
  62. quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
  63. quadra/configs/experiment/base/classification/classification.yaml +73 -0
  64. quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
  65. quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
  66. quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
  67. quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
  68. quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
  69. quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
  70. quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
  71. quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
  72. quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
  73. quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
  74. quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
  75. quadra/configs/experiment/base/ssl/byol.yaml +43 -0
  76. quadra/configs/experiment/base/ssl/dino.yaml +46 -0
  77. quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
  78. quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
  79. quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
  80. quadra/configs/experiment/custom/cls.yaml +12 -0
  81. quadra/configs/experiment/default.yaml +15 -0
  82. quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
  83. quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
  84. quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
  85. quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
  86. quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
  87. quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
  88. quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
  89. quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
  90. quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
  91. quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
  92. quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
  93. quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
  94. quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
  95. quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
  96. quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
  97. quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
  98. quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
  99. quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
  100. quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
  101. quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
  102. quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
  103. quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
  104. quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
  105. quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
  106. quadra/configs/export/default.yaml +13 -0
  107. quadra/configs/hydra/anomaly_custom.yaml +15 -0
  108. quadra/configs/hydra/default.yaml +14 -0
  109. quadra/configs/inference/default.yaml +26 -0
  110. quadra/configs/logger/comet.yaml +10 -0
  111. quadra/configs/logger/csv.yaml +5 -0
  112. quadra/configs/logger/mlflow.yaml +12 -0
  113. quadra/configs/logger/tensorboard.yaml +8 -0
  114. quadra/configs/loss/asl.yaml +7 -0
  115. quadra/configs/loss/barlow.yaml +2 -0
  116. quadra/configs/loss/bce.yaml +1 -0
  117. quadra/configs/loss/byol.yaml +1 -0
  118. quadra/configs/loss/cross_entropy.yaml +1 -0
  119. quadra/configs/loss/dino.yaml +8 -0
  120. quadra/configs/loss/simclr.yaml +2 -0
  121. quadra/configs/loss/simsiam.yaml +1 -0
  122. quadra/configs/loss/smp_ce.yaml +3 -0
  123. quadra/configs/loss/smp_dice.yaml +2 -0
  124. quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
  125. quadra/configs/loss/smp_mcc.yaml +2 -0
  126. quadra/configs/loss/vicreg.yaml +5 -0
  127. quadra/configs/model/anomalib/cfa.yaml +35 -0
  128. quadra/configs/model/anomalib/cflow.yaml +30 -0
  129. quadra/configs/model/anomalib/csflow.yaml +34 -0
  130. quadra/configs/model/anomalib/dfm.yaml +19 -0
  131. quadra/configs/model/anomalib/draem.yaml +29 -0
  132. quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
  133. quadra/configs/model/anomalib/fastflow.yaml +32 -0
  134. quadra/configs/model/anomalib/padim.yaml +32 -0
  135. quadra/configs/model/anomalib/patchcore.yaml +36 -0
  136. quadra/configs/model/barlow.yaml +16 -0
  137. quadra/configs/model/byol.yaml +25 -0
  138. quadra/configs/model/classification.yaml +10 -0
  139. quadra/configs/model/dino.yaml +26 -0
  140. quadra/configs/model/logistic_regression.yaml +4 -0
  141. quadra/configs/model/multilabel_classification.yaml +9 -0
  142. quadra/configs/model/simclr.yaml +18 -0
  143. quadra/configs/model/simsiam.yaml +24 -0
  144. quadra/configs/model/smp.yaml +4 -0
  145. quadra/configs/model/smp_multiclass.yaml +4 -0
  146. quadra/configs/model/vicreg.yaml +16 -0
  147. quadra/configs/optimizer/adam.yaml +5 -0
  148. quadra/configs/optimizer/adamw.yaml +3 -0
  149. quadra/configs/optimizer/default.yaml +4 -0
  150. quadra/configs/optimizer/lars.yaml +8 -0
  151. quadra/configs/optimizer/sgd.yaml +4 -0
  152. quadra/configs/scheduler/default.yaml +5 -0
  153. quadra/configs/scheduler/rop.yaml +5 -0
  154. quadra/configs/scheduler/step.yaml +3 -0
  155. quadra/configs/scheduler/warmrestart.yaml +2 -0
  156. quadra/configs/scheduler/warmup.yaml +6 -0
  157. quadra/configs/task/anomalib/cfa.yaml +5 -0
  158. quadra/configs/task/anomalib/cflow.yaml +5 -0
  159. quadra/configs/task/anomalib/csflow.yaml +5 -0
  160. quadra/configs/task/anomalib/draem.yaml +5 -0
  161. quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
  162. quadra/configs/task/anomalib/fastflow.yaml +5 -0
  163. quadra/configs/task/anomalib/inference.yaml +3 -0
  164. quadra/configs/task/anomalib/padim.yaml +5 -0
  165. quadra/configs/task/anomalib/patchcore.yaml +5 -0
  166. quadra/configs/task/classification.yaml +6 -0
  167. quadra/configs/task/classification_evaluation.yaml +6 -0
  168. quadra/configs/task/default.yaml +1 -0
  169. quadra/configs/task/segmentation.yaml +9 -0
  170. quadra/configs/task/segmentation_evaluation.yaml +3 -0
  171. quadra/configs/task/sklearn_classification.yaml +13 -0
  172. quadra/configs/task/sklearn_classification_patch.yaml +11 -0
  173. quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
  174. quadra/configs/task/sklearn_classification_test.yaml +8 -0
  175. quadra/configs/task/ssl.yaml +2 -0
  176. quadra/configs/trainer/lightning_cpu.yaml +36 -0
  177. quadra/configs/trainer/lightning_gpu.yaml +35 -0
  178. quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
  179. quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
  180. quadra/configs/trainer/lightning_multigpu.yaml +37 -0
  181. quadra/configs/trainer/sklearn_classification.yaml +7 -0
  182. quadra/configs/transforms/byol.yaml +47 -0
  183. quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
  184. quadra/configs/transforms/default.yaml +37 -0
  185. quadra/configs/transforms/default_numpy.yaml +24 -0
  186. quadra/configs/transforms/default_resize.yaml +22 -0
  187. quadra/configs/transforms/dino.yaml +63 -0
  188. quadra/configs/transforms/linear_eval.yaml +18 -0
  189. quadra/datamodules/__init__.py +20 -0
  190. quadra/datamodules/anomaly.py +180 -0
  191. quadra/datamodules/base.py +375 -0
  192. quadra/datamodules/classification.py +1003 -0
  193. quadra/datamodules/generic/__init__.py +0 -0
  194. quadra/datamodules/generic/imagenette.py +144 -0
  195. quadra/datamodules/generic/mnist.py +81 -0
  196. quadra/datamodules/generic/mvtec.py +58 -0
  197. quadra/datamodules/generic/oxford_pet.py +163 -0
  198. quadra/datamodules/patch.py +190 -0
  199. quadra/datamodules/segmentation.py +742 -0
  200. quadra/datamodules/ssl.py +140 -0
  201. quadra/datasets/__init__.py +17 -0
  202. quadra/datasets/anomaly.py +287 -0
  203. quadra/datasets/classification.py +241 -0
  204. quadra/datasets/patch.py +138 -0
  205. quadra/datasets/segmentation.py +239 -0
  206. quadra/datasets/ssl.py +110 -0
  207. quadra/losses/__init__.py +0 -0
  208. quadra/losses/classification/__init__.py +6 -0
  209. quadra/losses/classification/asl.py +83 -0
  210. quadra/losses/classification/focal.py +320 -0
  211. quadra/losses/classification/prototypical.py +148 -0
  212. quadra/losses/ssl/__init__.py +17 -0
  213. quadra/losses/ssl/barlowtwins.py +47 -0
  214. quadra/losses/ssl/byol.py +37 -0
  215. quadra/losses/ssl/dino.py +129 -0
  216. quadra/losses/ssl/hyperspherical.py +45 -0
  217. quadra/losses/ssl/idmm.py +50 -0
  218. quadra/losses/ssl/simclr.py +67 -0
  219. quadra/losses/ssl/simsiam.py +30 -0
  220. quadra/losses/ssl/vicreg.py +76 -0
  221. quadra/main.py +49 -0
  222. quadra/metrics/__init__.py +3 -0
  223. quadra/metrics/segmentation.py +251 -0
  224. quadra/models/__init__.py +0 -0
  225. quadra/models/base.py +151 -0
  226. quadra/models/classification/__init__.py +8 -0
  227. quadra/models/classification/backbones.py +149 -0
  228. quadra/models/classification/base.py +92 -0
  229. quadra/models/evaluation.py +322 -0
  230. quadra/modules/__init__.py +0 -0
  231. quadra/modules/backbone.py +30 -0
  232. quadra/modules/base.py +312 -0
  233. quadra/modules/classification/__init__.py +3 -0
  234. quadra/modules/classification/base.py +327 -0
  235. quadra/modules/ssl/__init__.py +17 -0
  236. quadra/modules/ssl/barlowtwins.py +59 -0
  237. quadra/modules/ssl/byol.py +172 -0
  238. quadra/modules/ssl/common.py +285 -0
  239. quadra/modules/ssl/dino.py +186 -0
  240. quadra/modules/ssl/hyperspherical.py +206 -0
  241. quadra/modules/ssl/idmm.py +98 -0
  242. quadra/modules/ssl/simclr.py +73 -0
  243. quadra/modules/ssl/simsiam.py +68 -0
  244. quadra/modules/ssl/vicreg.py +67 -0
  245. quadra/optimizers/__init__.py +4 -0
  246. quadra/optimizers/lars.py +153 -0
  247. quadra/optimizers/sam.py +127 -0
  248. quadra/schedulers/__init__.py +3 -0
  249. quadra/schedulers/base.py +44 -0
  250. quadra/schedulers/warmup.py +127 -0
  251. quadra/tasks/__init__.py +24 -0
  252. quadra/tasks/anomaly.py +582 -0
  253. quadra/tasks/base.py +397 -0
  254. quadra/tasks/classification.py +1263 -0
  255. quadra/tasks/patch.py +492 -0
  256. quadra/tasks/segmentation.py +389 -0
  257. quadra/tasks/ssl.py +560 -0
  258. quadra/trainers/README.md +3 -0
  259. quadra/trainers/__init__.py +0 -0
  260. quadra/trainers/classification.py +179 -0
  261. quadra/utils/__init__.py +0 -0
  262. quadra/utils/anomaly.py +112 -0
  263. quadra/utils/classification.py +618 -0
  264. quadra/utils/deprecation.py +31 -0
  265. quadra/utils/evaluation.py +474 -0
  266. quadra/utils/export.py +585 -0
  267. quadra/utils/imaging.py +32 -0
  268. quadra/utils/logger.py +15 -0
  269. quadra/utils/mlflow.py +98 -0
  270. quadra/utils/model_manager.py +320 -0
  271. quadra/utils/models.py +523 -0
  272. quadra/utils/patch/__init__.py +15 -0
  273. quadra/utils/patch/dataset.py +1433 -0
  274. quadra/utils/patch/metrics.py +449 -0
  275. quadra/utils/patch/model.py +153 -0
  276. quadra/utils/patch/visualization.py +217 -0
  277. quadra/utils/resolver.py +42 -0
  278. quadra/utils/segmentation.py +31 -0
  279. quadra/utils/tests/__init__.py +0 -0
  280. quadra/utils/tests/fixtures/__init__.py +1 -0
  281. quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
  282. quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
  283. quadra/utils/tests/fixtures/dataset/classification.py +406 -0
  284. quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
  285. quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
  286. quadra/utils/tests/fixtures/models/__init__.py +3 -0
  287. quadra/utils/tests/fixtures/models/anomaly.py +89 -0
  288. quadra/utils/tests/fixtures/models/classification.py +45 -0
  289. quadra/utils/tests/fixtures/models/segmentation.py +33 -0
  290. quadra/utils/tests/helpers.py +70 -0
  291. quadra/utils/tests/models.py +27 -0
  292. quadra/utils/utils.py +525 -0
  293. quadra/utils/validator.py +115 -0
  294. quadra/utils/visualization.py +422 -0
  295. quadra/utils/vit_explainability.py +349 -0
  296. quadra-2.2.7.dist-info/LICENSE +201 -0
  297. quadra-2.2.7.dist-info/METADATA +381 -0
  298. quadra-2.2.7.dist-info/RECORD +300 -0
  299. {quadra-0.0.1.dist-info → quadra-2.2.7.dist-info}/WHEEL +1 -1
  300. quadra-2.2.7.dist-info/entry_points.txt +3 -0
  301. quadra-0.0.1.dist-info/METADATA +0 -14
  302. quadra-0.0.1.dist-info/RECORD +0 -4
quadra/utils/patch/dataset.py
@@ -0,0 +1,1433 @@
+ from __future__ import annotations
+
+ import glob
+ import itertools
+ import json
+ import math
+ import os
+ import random
+ import shutil
+ import warnings
+ from collections.abc import Callable
+ from copy import deepcopy
+ from dataclasses import dataclass
+ from functools import partial
+ from multiprocessing import Pool
+ from typing import Any
+
+ import cv2
+ import h5py
+ import numpy as np
+ from scipy import ndimage
+ from skimage.measure import label, regionprops  # pylint: disable=no-name-in-module
+ from skimage.util import view_as_windows
+ from skmultilearn.model_selection import iterative_train_test_split
+ from tqdm import tqdm
+ from tripy import earclip
+
+ from quadra.utils import utils
+
+ log = utils.get_logger(__name__)
+
+
+ @dataclass
+ class PatchDatasetFileFormat:
+     """Model representing the content of the patch dataset split_files field in the info.json file."""
+
+     image_path: str
+     mask_path: str | None = None
+
+
+ @dataclass
+ class PatchDatasetInfo:
+     """Model representing the content of the patch dataset info.json file."""
+
+     patch_size: tuple[int, int] | None
+     patch_number: tuple[int, int] | None
+     annotated_good: list[int] | None
+     overlap: float
+     train_files: list[PatchDatasetFileFormat]
+     val_files: list[PatchDatasetFileFormat]
+     test_files: list[PatchDatasetFileFormat]
+
+     @staticmethod
+     def _map_files(files: list[Any]):
+         """Convert a list of dicts to a list of PatchDatasetFileFormat."""
+         mapped_files = []
+         for file in files:
+             current_file = file
+             if isinstance(file, dict):
+                 current_file = PatchDatasetFileFormat(**current_file)
+             mapped_files.append(current_file)
+
+         return mapped_files
+
+     def __post_init__(self):
+         self.train_files = self._map_files(self.train_files)
+         self.val_files = self._map_files(self.val_files)
+         self.test_files = self._map_files(self.test_files)
+
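For orientation, a minimal sketch of how these dataclasses map onto the info.json content; the dictionary literal below is illustrative, not taken from a real dataset:

```python
from quadra.utils.patch.dataset import PatchDatasetInfo

# Hypothetical info.json content (shape matches the dataclass fields above)
info = {
    "patch_size": [224, 224],
    "patch_number": None,
    "annotated_good": None,
    "overlap": 0.0,
    "train_files": [{"image_path": "original/images/a.png", "mask_path": "original/masks/a.png"}],
    "val_files": [],
    "test_files": [],
}

dataset_info = PatchDatasetInfo(**info)
# __post_init__ converts the raw dicts into PatchDatasetFileFormat instances
assert dataset_info.train_files[0].image_path == "original/images/a.png"
```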
+ def get_image_mask_association(
+     data_folder: str,
+     mask_folder: str | None = None,
+     mask_extension: str = "",
+     warning_on_missing_mask: bool = True,
+ ) -> list[dict]:
+     """Match images with their masks from a folder or sub-folders.
+
+     Args:
+         data_folder: root data folder containing images or images and masks
+         mask_folder: Optional root directory used to search only the masks
+         mask_extension: extension used to identify the mask file, it's mandatory if mask_folder is not specified
+         warning_on_missing_mask: if set to True a warning will be raised if a mask is missing, disable if you know
+             that many images do not have a mask.
+
+     Returns:
+         List of dicts like:
+         [
+             {
+                 'base_name': '161927.tiff',
+                 'path': 'test_dataset_patch/images/161927.tiff',
+                 'mask': 'test_dataset_patch/masks/161927_mask.tiff'
+             }, ...
+         ]
+     """
+     # get all the images from the data folder
+     data_images = glob.glob(os.path.join(data_folder, "**", "*"), recursive=True)
+
+     basenames = [os.path.splitext(os.path.basename(image))[0] for image in data_images]
+
+     if len(set(basenames)) != len(basenames):
+         raise ValueError("Found multiple images with the same name and different extension, this is not supported.")
+
+     log.info("Found: %d images in %s", len(data_images), data_folder)
+     # divide images and masks if they are in the same folder;
+     # if a mask folder is specified, search for masks in that folder
+     if mask_folder:
+         masks_images = []
+         for basename in basenames:
+             mask_path = os.path.join(mask_folder, f"{basename}{mask_extension}.*")
+             mask_path_list = glob.glob(mask_path)
+
+             if len(mask_path_list) == 1:
+                 masks_images.append(mask_path_list[0])
+             elif warning_on_missing_mask:
+                 log.warning("Mask for %s not found", basename)
+     else:
+         if mask_extension == "":
+             raise ValueError("If no mask folder is provided, the mask extension is mandatory; it cannot be empty.")
+
+         masks_images = [image for image in data_images if mask_extension in image]
+         data_images = [image for image in data_images if mask_extension not in image]
+
+     # build support dictionary
+     unique_images = [{"base_name": os.path.basename(image), "path": image, "mask": None} for image in data_images]
+
+     images_stem = [os.path.splitext(str(image["base_name"]))[0] + mask_extension for image in unique_images]
+     masks_stem = [os.path.splitext(os.path.basename(mask))[0] for mask in masks_images]
+
+     # search for correspondences between files or folders
+     for i, image_stem in enumerate(images_stem):
+         if image_stem in masks_stem:
+             unique_images[i]["mask"] = masks_images[masks_stem.index(image_stem)]
+
+     log.info("Unique images with mask: %d", len([uni for uni in unique_images if uni.get("mask") is not None]))
+     log.info("Unique images with no mask: %d", len([uni for uni in unique_images if uni.get("mask") is None]))
+
+     return unique_images
+
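As a usage sketch (folder names are hypothetical), the association helper can be driven like this:

```python
from quadra.utils.patch.dataset import get_image_mask_association

# Hypothetical layout: images in one folder, masks in another, mask files end in "_mask"
pairs = get_image_mask_association(
    data_folder="test_dataset_patch/images",
    mask_folder="test_dataset_patch/masks",
    mask_extension="_mask",
)

for pair in pairs:
    print(pair["base_name"], "->", pair["mask"])  # mask is None when unmatched
```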
+
+ def compute_patch_info(
+     img_h: int,
+     img_w: int,
+     patch_num_h: int,
+     patch_num_w: int,
+     overlap: float = 0.0,
+ ) -> tuple[tuple[int, int], tuple[int, int]]:
+     """Compute the patch size and step size given the number of patches and the overlap.
+
+     Args:
+         img_h: height of the image
+         img_w: width of the image
+         patch_num_h: number of vertical patches
+         patch_num_w: number of horizontal patches
+         overlap: percentage of overlap between patches.
+
+     Returns:
+         Tuple containing:
+             patch_size: [size_y, size_x] Dimension of the patch
+             step_size: [step_y, step_x] Step size
+     """
+     patch_size_h = np.ceil(img_h / (1 + (patch_num_h - 1) - (patch_num_h - 1) * overlap)).astype(int)
+     step_h = patch_size_h - np.ceil(overlap * patch_size_h).astype(int)
+
+     patch_size_w = np.ceil(img_w / (1 + (patch_num_w - 1) - (patch_num_w - 1) * overlap)).astype(int)
+     step_w = patch_size_w - np.ceil(overlap * patch_size_w).astype(int)
+
+     # We want a combination of patch size and step such that, if the image is not divisible by the number of
+     # patches, we fit the maximum number of patches in the image plus ONLY 1 extra patch that is taken from the
+     # end of the image.
+
+     counter = 0
+     original_patch_size_h = patch_size_h
+     original_patch_size_w = patch_size_w
+     original_step_h = step_h
+     original_step_w = step_w
+
+     while (patch_num_h - 1) * step_h + patch_size_h < img_h or (patch_num_h - 2) * step_h + patch_size_h > img_h:
+         counter += 1
+         if (patch_num_h - 1) * (step_h + 1) + patch_size_h < img_h:
+             step_h += 1
+         else:
+             patch_size_h += 1
+
+         if counter == 100:
+             # We probably entered an infinite loop, restart with a smaller step size
+             step_h = original_step_h - 1
+             patch_size_h = original_patch_size_h
+             counter = 0
+
+     counter = 0
+     while (patch_num_w - 1) * step_w + patch_size_w < img_w or (patch_num_w - 2) * step_w + patch_size_w > img_w:
+         counter += 1
+         if (patch_num_w - 1) * (step_w + 1) + patch_size_w < img_w:
+             step_w += 1
+         else:
+             patch_size_w += 1
+
+         if counter == 100:
+             # We probably entered an infinite loop, restart with a smaller step size
+             step_w = original_step_w - 1
+             patch_size_w = original_patch_size_w
+             counter = 0
+
+     return (patch_size_h, patch_size_w), (step_h, step_w)
+
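A quick sanity check of the relationship between patch count, size and step; the values below follow directly from the formulas above:

```python
from quadra.utils.patch.dataset import compute_patch_info

# 100x100 image split into 2x2 patches with no overlap:
# patch_size = ceil(100 / 2) = 50, step = 50
print(compute_patch_info(100, 100, 2, 2, overlap=0.0))  # -> patch size (50, 50), step (50, 50)

# 100x100 image split into 3x3 patches: ceil(100 / 3) = 34, so the
# reconstruction (2 * 34 + 34 = 102) slightly exceeds the image and the
# last patch is taken from the image border instead of adding a new one.
print(compute_patch_info(100, 100, 3, 3, overlap=0.0))  # -> patch size (34, 34), step (34, 34)
```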
+
+ def compute_patch_info_from_patch_dim(
+     img_h: int,
+     img_w: int,
+     patch_height: int,
+     patch_width: int,
+     overlap: float = 0.0,
+ ) -> tuple[tuple[int, int], tuple[int, int]]:
+     """Compute patch info given the patch dimensions.
+
+     Args:
+         img_h: height of the image
+         img_w: width of the image
+         patch_height: patch height
+         patch_width: patch width
+         overlap: overlap percentage [0, 1].
+
+     Returns:
+         Tuple of (number of patches, step)
+     """
+     assert 1 >= overlap >= 0, f"Invalid overlap. Must be between [0, 1], received {overlap}"
+     step_h = patch_height - int(overlap * patch_height)
+     step_w = patch_width - int(overlap * patch_width)
+
+     patch_num_h = np.ceil(((img_h - patch_height) / step_h) + 1).astype(int)
+     patch_num_w = np.ceil(((img_w - patch_width) / step_w) + 1).astype(int)
+
+     # Handle the case where the last patch does not cover the full image. This is done here rather than relying on
+     # np.ceil because we don't want to add a new patch if the last one already exceeds the image!
+     if ((patch_num_h - 1) * step_h) + patch_height < img_h:
+         patch_num_h += 1
+     if ((patch_num_w - 1) * step_w) + patch_width < img_w:
+         patch_num_w += 1
+
+     return (patch_num_h, patch_num_w), (step_h, step_w)
+
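The inverse direction, again as a worked example consistent with the code above:

```python
from quadra.utils.patch.dataset import compute_patch_info_from_patch_dim

# 100x100 image, 64x64 patches, no overlap:
# step = 64, ceil((100 - 64) / 64 + 1) = 2 patches per side, and
# 1 * 64 + 64 = 128 >= 100, so no extra patch is appended.
print(compute_patch_info_from_patch_dim(100, 100, 64, 64, overlap=0.0))
# -> 2x2 patches, step (64, 64)
```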
+
+ def from_rgb_to_idx(img: np.ndarray, class_to_color: dict, class_to_idx: dict) -> np.ndarray:
+     """Convert an RGB mask into a grayscale index mask.
+
+     Args:
+         img: RGB mask in which each different color is associated with a class
+         class_to_color: Dict "key": [R, G, B]
+         class_to_idx: Dict "key": class_idx.
+
+     Returns:
+         Grayscale mask in which each class is associated with a specific index
+     """
+     img = img.astype(int)
+     # Use negative values to avoid strange behaviour in the remote eventuality
+     # of someone using a color like [1, 255, 255]
+     for classe, color in class_to_color.items():
+         img[np.all(img == color, axis=-1).astype(bool), 0] = -class_to_idx[classe]
+
+     img = np.abs(img[:, :, 0])
+
+     return img.astype(np.uint8)
+
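A tiny, self-contained check of the color-to-index conversion (class names and colors are made up):

```python
import numpy as np

from quadra.utils.patch.dataset import from_rgb_to_idx

# 2x2 RGB mask: top row background (black), bottom row "defect" (red)
rgb_mask = np.array(
    [[[0, 0, 0], [0, 0, 0]], [[255, 0, 0], [255, 0, 0]]], dtype=np.uint8
)
idx_mask = from_rgb_to_idx(
    rgb_mask,
    class_to_color={"good": [0, 0, 0], "defect": [255, 0, 0]},
    class_to_idx={"good": 0, "defect": 1},
)
print(idx_mask)  # [[0 0], [1 1]]
```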
+
+ def __save_patch_dataset(
+     image_patches: np.ndarray,
+     labelled_patches: np.ndarray | None = None,
+     mask_patches: np.ndarray | None = None,
+     labelled_mask: np.ndarray | None = None,
+     output_folder: str = "extraction_data",
+     image_name: str = "example",
+     area_threshold: float = 0.45,
+     area_defect_threshold: float = 0.2,
+     mask_extension: str = "_mask",
+     save_mask: bool = False,
+     mask_output_folder: str | None = None,
+     class_to_idx: dict | None = None,
+ ) -> None:
+     """Given patches computed with view_as_windows, together with the corresponding mask patches and labelled mask,
+     save all the images in subdirectories divided by class name and position in the grid. Ambiguous patches, i.e.
+     those that contain defects but not enough to exceed the defined thresholds, are marked with #DISCARD# and should
+     be discarded in training. Patches of images without ground truth are saved inside the None folder.
+
+     Args:
+         image_patches: [n, m, patch_w, patch_h, channel] numpy array of the image patches
+         mask_patches: [n, m, patch_w, patch_h] numpy array of mask patches
+         labelled_patches: [n, m, patch_w, patch_h] numpy array of labelled mask patches
+         labelled_mask: numpy array in which each defect in the image is labelled using connected components
+         class_to_idx: Dictionary with the mapping {"class" -> class in mask}, it must cover all indices
+             contained in the masks
+         save_mask: flag to save or ignore masks
+         output_folder: folder where to save data
+         mask_extension: postfix of the saved mask based on the image name
+         mask_output_folder: Optional folder in which to save the masks
+         image_name: name to use in order to save the data
+         area_threshold: minimum percentage of defected patch area present in the mask to classify the patch as defect
+         area_defect_threshold: minimum percentage of single defect present in the patch to classify the patch as defect
+
+     Returns:
+         None
+     """
+     if class_to_idx is not None:
+         log.debug("Classes from dict: %s", class_to_idx)
+         index_to_class = {v: k for k, v in class_to_idx.items()}
+         log.debug("Inverse class: %s", index_to_class)
+         reference_classes = index_to_class
+
+         if mask_patches is not None:
+             classes_in_mask = set(np.unique(mask_patches))
+             missing_classes = set(classes_in_mask).difference(class_to_idx.values())
+
+             assert len(missing_classes) == 0, f"Found index in mask that has no corresponding class {missing_classes}"
+     elif mask_patches is not None:
+         reference_classes = {k: str(v) for k, v in enumerate(list(np.unique(mask_patches)))}
+     else:
+         raise ValueError("If no `class_to_idx` is provided, `mask_patches` must be provided")
+
+     log.debug("Classes from mask: %s", reference_classes)
+     class_to_idx = {v: k for k, v in reference_classes.items()}
+     log.debug("Final reference classes: %s", reference_classes)
+
+     # create subdirectories for the saved data
+     for cl in reference_classes.values():
+         os.makedirs(os.path.join(output_folder, str(cl)), exist_ok=True)
+
+         if mask_output_folder is not None:
+             os.makedirs(os.path.join(output_folder, mask_output_folder, str(cl)), exist_ok=True)
+
+     if mask_output_folder is None:
+         mask_output_folder = output_folder
+     else:
+         mask_output_folder = os.path.join(output_folder, mask_output_folder)
+
+     log.debug("Mask out: %s", mask_output_folder)
+
+     if mask_patches is None:
+         os.makedirs(os.path.join(output_folder, str(None)), exist_ok=True)
+     # for [i, j] in patches location
+     for row_index in range(image_patches.shape[0]):
+         for col_index in range(image_patches.shape[1]):
+             # the default class is the one at index 0
+             output_class = reference_classes.get(0)
+             image = image_patches[row_index, col_index]
+
+             discard_in_training = True
+             if mask_patches is not None and labelled_patches is not None:
+                 discard_in_training = False
+                 max_defected_area = 0
+                 mask = mask_patches[row_index, col_index]
+                 patch_area_th = mask.shape[0] * mask.shape[1] * area_threshold
+
+                 if mask.sum() > 0:
+                     discard_in_training = True
+                     for k, v in class_to_idx.items():
+                         if v == 0:
+                             continue
+
+                         mask_patch = mask == int(v)
+                         defected_area = mask_patch.sum()
+
+                         if defected_area > 0:
+                             # If enough defected area is inside the patch
+                             if defected_area > patch_area_th:
+                                 if defected_area > max_defected_area:
+                                     output_class = k
+                                     max_defected_area = defected_area
+                                 discard_in_training = False
+                             else:
+                                 all_defects_in_patch = mask_patch * labelled_patches[row_index, col_index]
+
+                                 # For each different defect inside the area check
+                                 # if enough of it is contained in the patch
+                                 for defect_id in np.unique(all_defects_in_patch):
+                                     if defect_id == 0:
+                                         continue
+
+                                     defect_area_in_patch = (all_defects_in_patch == defect_id).sum()
+                                     defect_area_th = (labelled_mask == defect_id).sum() * area_defect_threshold
+
+                                     if defect_area_in_patch > defect_area_th:
+                                         output_class = k
+                                         if defect_area_in_patch > max_defected_area:
+                                             max_defected_area = defect_area_in_patch
+                                         discard_in_training = False
+                 else:
+                     discard_in_training = False
+
+                 if save_mask:
+                     mask_name = f"{image_name}_{row_index * image_patches.shape[1] + col_index}{mask_extension}.png"
+
+                     if discard_in_training:
+                         mask_name = "#DISCARD#" + mask_name
+                     cv2.imwrite(
+                         os.path.join(
+                             mask_output_folder,
+                             output_class,  # type: ignore[arg-type]
+                             mask_name,
+                         ),
+                         mask.astype(np.uint8),
+                     )
+             else:
+                 output_class = "None"
+
+             patch_name = f"{image_name}_{row_index * image_patches.shape[1] + col_index}.png"
+             if discard_in_training:
+                 patch_name = "#DISCARD#" + patch_name
+
+             cv2.imwrite(
+                 os.path.join(
+                     output_folder,
+                     output_class,  # type: ignore[arg-type]
+                     patch_name,
+                 ),
+                 image,
+             )
+
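The two thresholds interact as follows; a small arithmetic sketch (numbers are illustrative) of the decision made for each patch:

```python
# Hypothetical 100x100 patch with the default thresholds used above
patch_h = patch_w = 100
area_threshold = 0.45          # fraction of the patch area
area_defect_threshold = 0.2    # fraction of a single defect's total area

patch_area_th = patch_h * patch_w * area_threshold  # 4500 px

# Case 1: a defect fills 5000 px of the patch -> above patch_area_th,
# the patch is labelled with that defect class.
assert 5000 > patch_area_th

# Case 2: a defect fills only 2000 px of the patch, but the whole defect
# is 8000 px -> 2000 > 8000 * 0.2 = 1600, so enough of the defect lies in
# the patch and it is still labelled as defective.
assert 2000 > 8000 * area_defect_threshold

# Case 3: 2000 px in patch, whole defect 15000 px -> 2000 < 3000; the patch
# is ambiguous and its filename is prefixed with #DISCARD#.
assert 2000 < 15000 * area_defect_threshold
```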
+
+ def generate_patch_dataset(
+     data_dictionary: list[dict],
+     class_to_idx: dict,
+     val_size: float = 0.3,
+     test_size: float = 0.0,
+     seed: int = 42,
+     patch_number: tuple[int, int] | None = None,
+     patch_size: tuple[int, int] | None = None,
+     overlap: float = 0.0,
+     output_folder: str = "extraction_data",
+     save_original_images_and_masks: bool = True,
+     area_threshold: float = 0.45,
+     area_defect_threshold: float = 0.2,
+     mask_extension: str = "_mask",
+     mask_output_folder: str | None = None,
+     save_mask: bool = False,
+     clear_output_folder: bool = False,
+     mask_preprocessing: Callable | None = None,
+     train_filename: str = "dataset.txt",
+     repeat_good_images: int = 1,
+     balance_defects: bool = True,
+     annotated_good: list[str] | None = None,
+     num_workers: int = 1,
+ ) -> dict | None:
+     """Given a data_dictionary like:
+     >>> {
+     >>>    'base_name': '163931_1_5.jpg',
+     >>>    'path': 'extraction_data/1/163931_1_5.jpg',
+     >>>    'mask': 'extraction_data/1/163931_1_5_mask.jpg'
+     >>> }
+     this function generates one patch dataset per split (train, validation and test), respectively under
+     output_folder/train, output_folder/val and output_folder/test. The training dataset contains h5 files and a txt
+     file produced by generate_patch_sampling_dataset, while the validation and test datasets contain patches saved
+     on disk divided in subfolders per class; patch extraction is done in a sliding window fashion.
+     Original images and masks (preprocessed if mask_preprocessing is present) are also saved under
+     output_folder/original/images and output_folder/original/masks.
+     If the patch number is specified the patch size is calculated accordingly; if the image is not divisible by the
+     patch number two behaviours are possible:
+     - if the patch reconstruction is smaller than the original image, a new patch is generated containing the
+       pixels from the edge of the image (e.g. the new patch contains the last patch_size pixels of the original
+       image)
+     - if the patch reconstruction is bigger than the original image, the last patch contains the pixels from the
+       edge of the image same as above, but without adding a new patch to the count.
+
+     Args:
+         data_dictionary: Dictionary as above
+         val_size: percentage of the dictionary entries to be used for validation
+         test_size: percentage of the dictionary entries to be used for testing
+         seed: seed for rng based operations
+         clear_output_folder: flag used to delete all the data in the subfolders
+         class_to_idx: Dictionary {"defect": value in mask.. }
+         output_folder: root folder where to extract the data
+         save_original_images_and_masks: If True, images and masks will be copied inside output_folder/original/
+         area_threshold: Minimum percentage of defected patch area present in the mask to classify the patch as defect
+         area_defect_threshold: Minimum percentage of single defect present in the patch to classify the patch as defect
+         mask_extension: Extension used to assign image to mask
+         mask_output_folder: Optional folder in which to save the masks
+         save_mask: Flag to save the mask
+         patch_number: Optional number of patches for each side, required if patch_size is None
+         patch_size: Optional dimension of the patch, required if patch_number is None
+         overlap: Overlap of the patches [0, 1]
+         mask_preprocessing: Optional function applied to masks, this can be useful for example to convert an image in
+             range [0-255] to the required [0-1]
+         train_filename: Name of the file containing the mapping between h5 files and labels for training
+         repeat_good_images: Number of repetitions for images with empty or None mask
+         balance_defects: If true add one good entry for each defect extracted
+         annotated_good: List of labels that are annotated but considered as good
+         num_workers: Number of workers used for the h5 creation
+
+     Returns:
+         None if data_dictionary is empty, otherwise a dictionary containing information about the dataset
+     """
+     if len(data_dictionary) == 0:
+         warnings.warn("Input data dictionary is empty!", UserWarning, stacklevel=2)
+         return None
+
+     if val_size < 0 or test_size < 0 or (val_size + test_size) > 1:
+         raise ValueError("Validation and test size must be greater than or equal to zero and sum up to at most 1")
+     if clear_output_folder and os.path.exists(output_folder):
+         shutil.rmtree(output_folder)
+     os.makedirs(output_folder, exist_ok=True)
+     os.makedirs(os.path.join(output_folder, "original"), exist_ok=True)
+     if save_original_images_and_masks:
+         log.info("Moving original images and masks to dataset folder...")
+         os.makedirs(os.path.join(output_folder, "original", "images"), exist_ok=True)
+         os.makedirs(os.path.join(output_folder, "original", "masks"), exist_ok=True)
+
+         for i, item in enumerate(data_dictionary):
+             img_new_path = os.path.join("original", "images", item["base_name"])
+             shutil.copy(item["path"], os.path.join(output_folder, img_new_path))
+             data_dictionary[i]["path"] = img_new_path
+
+             if item["mask"] is not None:
+                 mask = cv2.imread(item["mask"])
+                 if mask_preprocessing is not None:
+                     mask = mask_preprocessing(mask).astype(np.uint8)
+                 mask_new_path = os.path.join("original", "masks", os.path.splitext(item["base_name"])[0] + ".png")
+                 cv2.imwrite(os.path.join(output_folder, mask_new_path), mask)
+                 data_dictionary[i]["mask"] = mask_new_path
+
+     shuffled_indices = np.random.default_rng(seed).permutation(len(data_dictionary))
+     data_dictionary = [data_dictionary[i] for i in shuffled_indices]
+     log.info("Performing multilabel stratification...")
+     train_data_dictionary, val_data_dictionary, test_data_dictionary = multilabel_stratification(
+         output_folder=output_folder,
+         data_dictionary=data_dictionary,
+         num_classes=len(class_to_idx.values()),
+         val_size=val_size,
+         test_size=test_size,
+     )
+
+     log.info("Train set size: %d", len(train_data_dictionary))
+     log.info("Validation set size: %d", len(val_data_dictionary))
+     log.info("Test set size: %d", len(test_data_dictionary))
+
+     idx_to_class = {v: k for (k, v) in class_to_idx.items()}
+
+     os.makedirs(output_folder, exist_ok=True)
+
+     dataset_info = {
+         "patch_size": patch_size,
+         "patch_number": patch_number,
+         "overlap": overlap,
+         "annotated_good": annotated_good,
+         "train_files": [{"image_path": x["path"], "mask_path": x["mask"]} for x in train_data_dictionary],
+         "val_files": [{"image_path": x["path"], "mask_path": x["mask"]} for x in val_data_dictionary],
+         "test_files": [{"image_path": x["path"], "mask_path": x["mask"]} for x in test_data_dictionary],
+     }
+
+     with open(os.path.join(output_folder, "info.json"), "w") as f:
+         json.dump(dataset_info, f)
+
+     if len(train_data_dictionary) > 0:
+         log.info("Generating train set")
+         generate_patch_sampling_dataset(
+             data_dictionary=train_data_dictionary,
+             patch_number=patch_number,
+             patch_size=patch_size,
+             overlap=overlap,
+             idx_to_class=idx_to_class,
+             balance_defects=balance_defects,
+             repeat_good_images=repeat_good_images,
+             output_folder=output_folder,
+             subfolder_name="train",
+             train_filename=train_filename,
+             annotated_good=annotated_good if annotated_good is None else [class_to_idx[x] for x in annotated_good],
+             num_workers=num_workers,
+         )
+
+     for phase, split_dict in zip(["val", "test"], [val_data_dictionary, test_data_dictionary]):
+         if len(split_dict) > 0:
+             log.info("Generating %s set", phase)
+             generate_patch_sliding_window_dataset(
+                 data_dictionary=split_dict,
+                 patch_number=patch_number,
+                 patch_size=patch_size,
+                 overlap=overlap,
+                 output_folder=output_folder,
+                 subfolder_name=phase,
+                 area_threshold=area_threshold,
+                 area_defect_threshold=area_defect_threshold,
+                 mask_extension=mask_extension,
+                 mask_output_folder=mask_output_folder,
+                 save_mask=save_mask,
+                 class_to_idx=class_to_idx,
+             )
+
+     log.info("All done! Datasets saved to %s", output_folder)
+
+     return dataset_info
+
+
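Putting the two entry points together, a minimal end-to-end sketch (paths and the class map are hypothetical):

```python
from quadra.utils.patch.dataset import generate_patch_dataset, get_image_mask_association

data_dictionary = get_image_mask_association(
    data_folder="raw_data/images",
    mask_folder="raw_data/masks",
    mask_extension="_mask",
)

dataset_info = generate_patch_dataset(
    data_dictionary=data_dictionary,
    class_to_idx={"good": 0, "defect": 1},
    val_size=0.2,
    test_size=0.1,
    patch_size=(224, 224),
    overlap=0.1,
    output_folder="patch_dataset",
)
# patch_dataset/info.json now matches the PatchDatasetInfo dataclass above
```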
+ def multilabel_stratification(
+     output_folder: str,
+     data_dictionary: list[dict],
+     num_classes: int,
+     val_size: float,
+     test_size: float,
+ ) -> tuple[list[dict], list[dict], list[dict]]:
+     """Split the data dictionary using multilabel-based stratification. Every sample with a None mask is placed
+     inside the test set; for all the others the labels contained in the masks are read to create one-hot
+     encoded labels.
+
+     Args:
+         output_folder: root folder of the dataset
+         data_dictionary: Data dictionary as described in generate_patch_dataset
+         num_classes: Number of classes contained in the dataset, required for one-hot encoding
+         val_size: Percentage of data to be used for validation
+         test_size: Percentage of data to be used for test
+
+     Returns:
+         Three data dictionaries, one for training, one for validation and one for test
+     """
+     if val_size + test_size == 0:
+         return data_dictionary, [], []
+     if val_size == 1:
+         return [], data_dictionary, []
+     if test_size == 1:
+         return [], [], data_dictionary
+
+     test_data_dictionary = list(filter(lambda q: q["mask"] is None, data_dictionary))
+     log.info("Number of images with no mask inserted in test_data_dictionary: %d", len(test_data_dictionary))
+     empty_test_size = len(test_data_dictionary) / len(data_dictionary)
+     data_dictionary = list(filter(lambda q: q["mask"] is not None, data_dictionary))
+
+     if len(data_dictionary) == 0:
+         # All the items in the data dictionary have a None mask, put everything in test
+         warnings.warn(
+             "All the images have None mask and the test size is not equal to 1! Put everything in test",
+             UserWarning,
+             stacklevel=2,
+         )
+         return [], [], test_data_dictionary
+
+     x = []
+     y = None
+     for item in data_dictionary:
+         one_hot = np.zeros([1, num_classes], dtype=np.int16)
+         if item["mask"] is None:
+             continue
+         # this works even if item["mask"] is already an absolute path
+         mask = cv2.imread(os.path.join(output_folder, item["mask"]), 0)
+
+         labels = np.unique(mask)
+
+         one_hot[:, labels] = 1
+         x.append(item["base_name"])
+         if y is None:
+             y = one_hot
+         else:
+             y = np.concatenate([y, one_hot])
+
+     x_test: list[Any] | np.ndarray
+
+     if empty_test_size > test_size:
+         warnings.warn(
+             (
+                 "The percentage of images with None label is greater than the test_size, the new test_size is"
+                 f" {empty_test_size}!"
+             ),
+             UserWarning,
+             stacklevel=2,
+         )
+         x_train, _, x_val, _ = iterative_train_test_split(np.expand_dims(np.array(x), 1), y, val_size)
+         x_test = [q["base_name"] for q in test_data_dictionary]
+     else:
+         test_size -= empty_test_size
+         x_train, _, x_remaining, y_remaining = iterative_train_test_split(
+             np.expand_dims(np.array(x), 1), y, val_size + test_size
+         )
+
+         if x_remaining.shape[0] == 1:
+             if test_size == 0:
+                 x_val = x_remaining
+                 x_test = np.array([])
+             elif val_size == 0:
+                 x_test = x_remaining
+                 x_val = np.array([])
+             else:
+                 log.warning("Not enough data to create the test split, only a validation set of size 1 will be created")
+                 x_val = x_remaining
+                 x_test = np.array([])
+         else:
+             x_val, _, x_test, _ = iterative_train_test_split(
+                 x_remaining, y_remaining, test_size / (val_size + test_size)
+             )
+         # Here x_test should always be a numpy array, but mypy does not recognize it
+         x_test = [q[0] for q in x_test.tolist()]  # type: ignore[union-attr]
+         x_test.extend([q["base_name"] for q in test_data_dictionary])
+
+     train_data_dictionary = list(filter(lambda q: q["base_name"] in x_train, data_dictionary))
+     val_data_dictionary = list(filter(lambda q: q["base_name"] in x_val, data_dictionary))
+     test_data_dictionary = list(filter(lambda q: q["base_name"] in x_test, data_dictionary + test_data_dictionary))
+
+     return train_data_dictionary, val_data_dictionary, test_data_dictionary
+
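The one-hot labels fed to iterative_train_test_split are built directly from the indices present in each mask; a condensed sketch of that encoding step:

```python
import numpy as np

num_classes = 3
# Hypothetical grayscale mask containing background (0) and class 2 pixels
mask = np.array([[0, 0], [2, 2]], dtype=np.uint8)

one_hot = np.zeros([1, num_classes], dtype=np.int16)
one_hot[:, np.unique(mask)] = 1
print(one_hot)  # [[1 0 1]] -> background and class 2 are present
```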
+
+ def generate_patch_sliding_window_dataset(
+     data_dictionary: list[dict],
+     subfolder_name: str,
+     patch_number: tuple[int, int] | None = None,
+     patch_size: tuple[int, int] | None = None,
+     overlap: float = 0.0,
+     output_folder: str = "extraction_data",
+     area_threshold: float = 0.45,
+     area_defect_threshold: float = 0.2,
+     mask_extension: str = "_mask",
+     mask_output_folder: str | None = None,
+     save_mask: bool = False,
+     class_to_idx: dict | None = None,
+ ) -> None:
+     """Given a data_dictionary like:
+     >>> {
+     >>>    'base_name': '163931_1_5.jpg',
+     >>>    'path': 'extraction_data/1/163931_1_5.jpg',
+     >>>    'mask': 'extraction_data/1/163931_1_5_mask.jpg'
+     >>> }
+     this function extracts the patches and saves the images and masks in subdirectories.
+
+     Args:
+         data_dictionary: Dictionary as above
+         subfolder_name: Name of the subfolder where to save the extracted patches (output_folder/subfolder_name)
+         class_to_idx: Dictionary {"defect": value in mask.. }
+         output_folder: root folder where to extract the data
+         area_threshold: minimum percentage of defected patch area present in the mask to classify the patch as defect
+         area_defect_threshold: minimum percentage of single defect present in the patch to classify the patch as defect
+         mask_extension: extension used to assign image to mask
+         mask_output_folder: Optional folder in which to save the masks
+         save_mask: flag to save the mask
+         patch_number: Optional number of patches for each side, required if patch_size is None
+         patch_size: Optional dimension of the patch, required if patch_number is None
+         overlap: overlap of the patches [0, 1].
+
+     Returns:
+         None.
+     """
+     if save_mask and len(mask_extension) == 0 and mask_output_folder is None:
+         raise InvalidParameterCombinationException(
+             "If mask output folder is not set you must specify a mask extension in order to save masks!"
+         )
+
+     if patch_number is None and patch_size is None:
+         raise InvalidParameterCombinationException("One between patch number or patch size must be specified!")
+
+     for data in tqdm(data_dictionary):
+         base_id = data.get("base_name")
+         base_path = data.get("path")
+         base_mask = data.get("mask")
+
+         assert base_id is not None, "Cannot find base id in data_dictionary"
+         assert base_path is not None, "Cannot find image in data_dictionary"
+
+         image = cv2.imread(os.path.join(output_folder, base_path))
+         h = image.shape[0]
+         w = image.shape[1]
+
+         log.debug("Processing %s with shape %s", base_id, image.shape)
+         mask = mask_patches = None
+         labelled_mask = labelled_patches = None
+
+         if base_mask is not None:
+             mask = cv2.imread(os.path.join(output_folder, base_mask), 0)
+             labelled_mask = label(mask)
+
+         if patch_size is not None:
+             [patch_height, patch_width] = patch_size
+             [patch_num_h, patch_num_w], step = compute_patch_info_from_patch_dim(
+                 h, w, patch_height, patch_width, overlap
+             )
+         elif patch_number is not None:
+             [patch_height, patch_width], step = compute_patch_info(h, w, patch_number[0], patch_number[1], overlap)
+             [patch_num_h, patch_num_w] = patch_number
+         else:
+             # mypy does not recognize that this is unreachable
+             raise InvalidParameterCombinationException("One between patch number or patch size must be specified!")
+
+         log.debug(
+             "Extracting %s patches with size %s, step %s", [patch_num_h, patch_num_w], [patch_height, patch_width], step
+         )
+         image_patches = extract_patches(image, (patch_num_h, patch_num_w), (patch_height, patch_width), step, overlap)
+
+         if mask is not None:
+             if labelled_mask is None:
+                 raise ValueError("Labelled mask cannot be None!")
+             mask_patches = extract_patches(mask, (patch_num_h, patch_num_w), (patch_height, patch_width), step, overlap)
+             labelled_patches = extract_patches(
+                 labelled_mask, (patch_num_h, patch_num_w), (patch_height, patch_width), step, overlap
+             )
+             assert image_patches.shape[:-1] == mask_patches.shape, "Image patches and mask patches mismatch!"
+
+         log.debug("Image patches shape: %s", image_patches.shape)
+         __save_patch_dataset(
+             image_patches=image_patches,
+             mask_patches=mask_patches,
+             labelled_patches=labelled_patches,
+             labelled_mask=labelled_mask,
+             image_name=os.path.splitext(base_id)[0],
+             output_folder=os.path.join(output_folder, subfolder_name),
+             area_threshold=area_threshold,
+             area_defect_threshold=area_defect_threshold,
+             mask_extension=mask_extension,
+             save_mask=save_mask,
+             mask_output_folder=mask_output_folder,
+             class_to_idx=class_to_idx,
+         )
+
+
+ def extract_patches(
+     image: np.ndarray,
+     patch_number: tuple[int, ...],
+     patch_size: tuple[int, ...],
+     step: tuple[int, ...],
+     overlap: float,
+ ) -> np.ndarray:
+     """From an image extract N x M patches of shape [h, w]; if the image is not perfectly divided by the number of
+     patches of the given dimension, the last patch will contain a replica of the original image taken in range
+     [-img_h:, :] or [:, -img_w:].
+
+     Args:
+         image: Numpy array of the image
+         patch_number: number of patches to be extracted
+         patch_size: dimension of the patch
+         step: step of the patch extraction
+         overlap: horizontal and vertical patch overlap in range [0, 1]
+
+     Returns:
+         Patches [N, M, 1, image_w, image_h, image_c]
+     """
+     assert 1.0 >= overlap >= 0.0, f"Overlap must be between 0 and 1. Received {overlap}"
+     (patch_num_h, patch_num_w) = patch_number
+     (patch_height, patch_width) = patch_size
+
+     pad_h = (patch_num_h - 1) * step[0] + patch_size[0] - image.shape[0]
+     pad_w = (patch_num_w - 1) * step[1] + patch_size[1] - image.shape[1]
+     # if the image has 3 channels adapt the window and step dimensions
+     if len(image.shape) == 3:
+         patch_size = (patch_size[0], patch_size[1], image.shape[2])
+         step = (step[0], step[1], image.shape[2])
+
+     # If this is not true there's some strange case I didn't take into account
+     if pad_h < 0 or pad_w < 0:
+         raise ValueError("Something went wrong with the patch extraction, expected positive padding values")
+
+     if pad_h > 0 or pad_w > 0:
+         # We work with copies as view_as_windows returns a view of the original image
+         crop_img = deepcopy(image)
+
+         if pad_h:
+             crop_img = crop_img[0 : (patch_num_h - 2) * step[0] + patch_height, :]
+
+         if pad_w:
+             crop_img = crop_img[:, 0 : (patch_num_w - 2) * step[1] + patch_width]
+
+         # Extract safe patches inside the image
+         patches = view_as_windows(crop_img, patch_size, step=step)
+     else:
+         patches = view_as_windows(image, patch_size, step=step)
+
+     extra_patches_h = None
+     extra_patches_w = None
+
+     if pad_h > 0:
+         # Append extra patches taken from the edge of the image
+         extra_patches_h = view_as_windows(image[-patch_height:, :], patch_size, step=step)
+
+     if pad_w > 0:
+         extra_patches_w = view_as_windows(image[:, -patch_width:], patch_size, step=step)
+
+     if extra_patches_h is not None and extra_patches_w is not None:
+         # Add an extra column and set its content to the bottom right patch area of the original image if both
+         # dimensions require extra patches
+         if extra_patches_h.ndim == 6:
+             # RGB
+             extra_patches_h = np.concatenate(
+                 [
+                     extra_patches_h,
+                     (np.zeros([1, 1, 1, patch_size[0], patch_size[1], extra_patches_h.shape[-1]], dtype=np.uint8)),
+                 ],
+                 axis=1,
+             )
+         else:
+             extra_patches_h = np.concatenate(
+                 [extra_patches_h, (np.zeros([1, 1, patch_size[0], patch_size[1]], dtype=np.uint8))], axis=1
+             )
+
+         if extra_patches_h is None:
+             # Required by mypy as it cannot infer that extra_patches_h cannot be None
+             raise ValueError("Extra patch h cannot be None!")
+
+         extra_patches_h[:, -1, :] = image[-patch_height:, -patch_width:]
+
+     if patches.ndim == 6:
+         # With RGB images there's an extra dimension; axis 2 is important, don't use plain squeeze or it breaks if
+         # the number of patches is set to 1!
+         patches = patches.squeeze(axis=2)
+
+     if extra_patches_w is not None:
+         if extra_patches_w.ndim == 6:
+             # RGB
+             patches = np.concatenate([patches, extra_patches_w.squeeze(2)], axis=1)
+         else:
+             patches = np.concatenate([patches, extra_patches_w], axis=1)
+
+     if extra_patches_h is not None:
+         if extra_patches_h.ndim == 6:
+             # RGB
+             patches = np.concatenate([patches, extra_patches_h.squeeze(2)], axis=0)
+         else:
+             patches = np.concatenate([patches, extra_patches_h], axis=0)
+
+     # If this is not true there's some strange case I didn't take into account
+     assert (
+         patches.shape[0] == patch_num_h and patches.shape[1] == patch_num_w
+     ), f"Patch shape {patches.shape} does not match the expected shape {patch_number}"
+
+     return patches
+
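To make the padding behaviour concrete, a small shape check consistent with the code path above: a 100x100 RGB image with 64x64 patches needs one extra patch per side, taken from the image border.

```python
import numpy as np

from quadra.utils.patch.dataset import compute_patch_info_from_patch_dim, extract_patches

image = np.zeros((100, 100, 3), dtype=np.uint8)
patch_num, step = compute_patch_info_from_patch_dim(100, 100, 64, 64, overlap=0.0)

patches = extract_patches(image, patch_num, (64, 64), step, overlap=0.0)
print(patches.shape)  # (2, 2, 64, 64, 3): a 2x2 grid of 64x64 RGB patches
```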
917
+
918
+ def generate_patch_sampling_dataset(
919
+ data_dictionary: list[dict[Any, Any]],
920
+ output_folder: str,
921
+ idx_to_class: dict,
922
+ overlap: float,
923
+ repeat_good_images: int = 1,
924
+ balance_defects: bool = True,
925
+ patch_number: tuple[int, int] | None = None,
926
+ patch_size: tuple[int, int] | None = None,
927
+ subfolder_name: str = "train",
928
+ train_filename: str = "dataset.txt",
929
+ annotated_good: list[int] | None = None,
930
+ num_workers: int = 1,
931
+ ) -> None:
932
+ """Generate a dataset of patches.
933
+
934
+ Args:
935
+ data_dictionary: Dictionary containing image and mask mapping
936
+ output_folder: root folder
937
+ idx_to_class: Dict mapping an index to the corresponding class name
938
+ repeat_good_images: Number of repetition for images with emtpy or None mask
939
+ balance_defects: If true add one good entry for each defect extracted
940
+ patch_number: Optional number of patches for each side, required if patch_size is None
941
+ patch_size: Optional dimension of the patch, required if patch_number is None
942
+ overlap: Percentage of overlap between patches
943
+ subfolder_name: name of the subfolder where to store h5 files for defected images and dataset txt
944
+ train_filename: Name of the file in which to store the mappings between h5 files and labels
945
+ annotated_good: List of class indices that are considered good other than the background
946
+ num_workers: Number of processes used to create h5 files.
947
+
948
+ Returns:
949
+ Create a txt file containing tuples path,label where path is a pointer to the generated h5 file and label is the
950
+ corresponding label
951
+
952
+ Each generated h5 file contains five fields:
953
+ img_path: Pointer to the location of the original image
954
+ mask_path: Optional pointer to the mask file, is missing if the mask is completely empty or is
955
+ not present
956
+ patch_size: dimension of the patches on the interested image
957
+ triangles: List of triangles that covers the defect
958
+ triangles_weights: Which weight should be given to each triangle for sampling
959
+
960
+ """
961
+ if patch_number is None and patch_size is None:
962
+ raise InvalidParameterCombinationException("Either patch_number or patch_size must be specified!")
963
+
964
+ sampling_dataset_folder = os.path.join(output_folder, subfolder_name)
965
+
966
+ os.makedirs(sampling_dataset_folder, exist_ok=True)
967
+ labelled_masks_path = os.path.join(output_folder, "original", "labelled_masks")
968
+ os.makedirs(labelled_masks_path, exist_ok=True)
969
+
970
+ with open(os.path.join(sampling_dataset_folder, train_filename), "w") as output_file:
971
+ if num_workers < 1:
972
+ raise InvalidNumWorkersNumberException("Workers must be >= 1")
973
+
974
+ if num_workers > 1:
975
+ log.info("Executing generate_patch_sampling_dataset w/ more than 1 worker!")
976
+
977
+ split_data_dictionary = np.array_split(np.asarray(data_dictionary), num_workers)
978
+
979
+ with Pool(num_workers) as pool:
980
+ res_list = pool.map(
981
+ partial(
982
+ create_h5,
983
+ patch_size=patch_size,
984
+ patch_number=patch_number,
985
+ idx_to_class=idx_to_class,
986
+ overlap=overlap,
987
+ repeat_good_images=repeat_good_images,
988
+ balance_defects=balance_defects,
989
+ annotated_good=annotated_good,
990
+ output_folder=output_folder,
991
+ labelled_masks_path=labelled_masks_path,
992
+ sampling_dataset_folder=sampling_dataset_folder,
993
+ ),
994
+ split_data_dictionary,
995
+ )
996
+
997
+ res = list(itertools.chain(*res_list))
998
+ else:
999
+ res = create_h5(
1000
+ data_dictionary=data_dictionary,
1001
+ patch_size=patch_size,
1002
+ patch_number=patch_number,
1003
+ idx_to_class=idx_to_class,
1004
+ overlap=overlap,
1005
+ repeat_good_images=repeat_good_images,
1006
+ balance_defects=balance_defects,
1007
+ annotated_good=annotated_good,
1008
+ output_folder=output_folder,
1009
+ labelled_masks_path=labelled_masks_path,
1010
+ sampling_dataset_folder=sampling_dataset_folder,
1011
+ )
1012
+
1013
+ for line in res:
1014
+ output_file.write(line)
1015
+
1016
+
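A hypothetical usage sketch of generate_patch_sampling_dataset; paths, class names, and parameter values below are placeholders, and the data_dictionary keys match those read by create_h5:

    # Paths are resolved relative to output_folder (absolute paths also work).
    data_dictionary = [
        {"base_name": "img_000.png", "path": "original/images/img_000.png",
         "mask": "original/masks/img_000.png"},
        {"base_name": "img_001.png", "path": "original/images/img_001.png",
         "mask": None},  # a good image without a mask
    ]
    generate_patch_sampling_dataset(
        data_dictionary=data_dictionary,
        output_folder="dataset_root",
        idx_to_class={0: "good", 1: "defect"},
        overlap=0.5,
        patch_size=(128, 128),  # alternatively pass patch_number=(4, 4)
        num_workers=2,          # >1 splits the work across a process pool
    )
    # dataset_root/train/dataset.txt now maps each generated h5 file to its label.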
1017
+ def create_h5(
1018
+ data_dictionary: list[dict[Any, Any]],
1019
+ idx_to_class: dict,
1020
+ overlap: float,
1021
+ repeat_good_images: int,
1022
+ balance_defects: bool,
1023
+ output_folder: str,
1024
+ labelled_masks_path: str,
1025
+ sampling_dataset_folder: str,
1026
+ annotated_good: list[int] | None = None,
1027
+ patch_size: tuple[int, int] | None = None,
1028
+ patch_number: tuple[int, int] | None = None,
1029
+ ) -> list[str]:
1030
+ """Create h5 files for each image in the dataset.
1031
+
1032
+ Args:
1033
+ data_dictionary: List of dictionaries containing image and mask mappings
1034
+ idx_to_class: Dict mapping an index to the corresponding class name
1035
+ overlap: Percentage of overlap between patches
1036
+ repeat_good_images: Number of repetitions for images with an empty or None mask
1037
+ balance_defects: If true add one good entry for each defect extracted
1038
+ output_folder: Root folder of the dataset
1040
+ annotated_good: List of class indices that are considered good other than the background
1041
+ labelled_masks_path: Folder where labelled masks are stored
1042
+ sampling_dataset_folder: Folder where the sampling dataset is stored
1043
+ patch_size: Dimension of the patch, required if patch_number is None
1044
+ patch_number: Number of patches for each side, required if patch_size is None.
1045
+
1046
+ Returns:
1047
+ output_list: List of "name,label" lines, one entry per generated h5 file
1048
+
1049
+ """
1050
+ if patch_number is None and patch_size is None:
1051
+ raise InvalidParameterCombinationException("Either patch_number or patch_size must be specified!")
1052
+
1053
+ output_list = []
1054
+ for item in tqdm(data_dictionary):
1055
+ log.debug("Processing %s", item["base_name"])
1056
+ # this works even if item["path"] is already an absolute path
1057
+ img = cv2.imread(os.path.join(output_folder, item["path"]))
1058
+
1059
+ h = img.shape[0]
1060
+ w = img.shape[1]
1061
+
1062
+ if item["mask"] is None:
1063
+ mask = np.zeros([h, w])
1064
+ else:
1065
+ # this works even if item["mask"] is already an absolute path
1066
+ mask = cv2.imread(os.path.join(output_folder, item["mask"]), 0) # type: ignore[assignment]
1067
+
1068
+ if patch_size is not None:
1069
+ patch_height = patch_size[1]
1070
+ patch_width = patch_size[0]
1071
+ else:
1072
+ # Mypy complains because patch_number is Optional, but we already checked that it is not None.
1073
+ [patch_height, patch_width], _ = compute_patch_info(
1074
+ h,
1075
+ w,
1076
+ patch_number[0], # type: ignore[index]
1077
+ patch_number[1], # type: ignore[index]
1078
+ overlap,
1079
+ )
1080
+
1081
+ h5_file_name_good = os.path.join(sampling_dataset_folder, f"{os.path.splitext(item['base_name'])[0]}_good.h5")
1082
+
1083
+ disable_good = False
1084
+
1085
+ with h5py.File(h5_file_name_good, "w") as f:
1086
+ f.create_dataset("img_path", data=item["path"])
1087
+ f.create_dataset("patch_size", data=np.array([patch_height, patch_width]))
1088
+
1089
+ target = idx_to_class[0]
1090
+
1091
+ if mask.sum() == 0:
1092
+ f.create_dataset("triangles", data=np.array([], dtype=np.uint8), dtype=np.uint8)
1093
+ f.create_dataset("triangles_weights", data=np.array([], dtype=np.uint8), dtype=np.uint8)
1094
+
1095
+ for _ in range(repeat_good_images):
1096
+ output_list.append(f"{os.path.basename(h5_file_name_good)},{target}\n")
1097
+
1098
+ continue
1099
+
1100
+ binary_mask = (mask > 0).astype(np.uint8)
1101
+
1102
+ # Dilate the defects and take the background
1103
+ binary_mask = np.logical_not(cv2.dilate(binary_mask, np.ones([patch_height, patch_width]))).astype(np.uint8)
1104
+
1105
+ temp_binary_mask = deepcopy(binary_mask)
1106
+ # Remove the edges of the image as they are unsafe for sampling without padding
1107
+ temp_binary_mask[0 : patch_height // 2, :] = 0
1108
+ temp_binary_mask[:, 0 : patch_width // 2] = 0
1109
+ temp_binary_mask[-patch_height // 2 :, :] = 0
1110
+ temp_binary_mask[:, -patch_width // 2 :] = 0
1111
+
1112
+ if temp_binary_mask.sum() != 0:
1113
+ # If the mask without the edges is not empty use it; otherwise keep the original mask, as it is not
1114
+ # possible to sample a patch that does not exceed the edges. That case must be taken care of by the
1115
+ # patch sampler used during training
1116
+ binary_mask = temp_binary_mask
1117
+
1118
+ # With Hx1 or 1xW patch grids we must make sure that the sampling row or column lies entirely in the
1119
+ # good area; if it doesn't, remove it from the possible sampling area
1120
+ if patch_height == img.shape[0]:
1121
+ must_clear_indices = np.where(binary_mask.sum(axis=0) != img.shape[0])[0]
1122
+ binary_mask[:, must_clear_indices] = 0
1123
+
1124
+ if patch_width == img.shape[1]:
1125
+ must_clear_indices = np.where(binary_mask.sum(axis=1) != img.shape[1])[0]
1126
+ binary_mask[must_clear_indices, :] = 0
1127
+
1128
+ # If there's no room left to sample good patches, disable the good entry
1129
+ if binary_mask.sum() == 0:
1130
+ disable_good = True
1131
+ else:
1132
+ triangles, weights = triangulate_region(binary_mask)
1133
+ if triangles is None:
1134
+ disable_good = True
1135
+ else:
1136
+ log.debug(
1137
+ "Saving %s triangles for %s with label %s",
1138
+ triangles.shape[0],
1139
+ os.path.basename(h5_file_name_good),
1140
+ target,
1141
+ )
1142
+
1143
+ f.create_dataset("mask_path", data=item["mask"])
1144
+ # Points from extracted triangles should be sufficiently far from all the defects, allowing us to sample
1145
+ # good patches almost all the time
1146
+ f.create_dataset("triangles", data=triangles, dtype=np.int32)
1147
+ f.create_dataset("triangles_weights", data=weights, dtype=np.float64)
1148
+
1149
+ # Don't add the good entry here when balancing, otherwise there would be one more good entry than the
1150
+ # number of defects
1151
+ if not balance_defects:
1152
+ output_list.append(f"{os.path.basename(h5_file_name_good)},{target}\n")
1153
+
1154
+ if disable_good:
1155
+ os.remove(h5_file_name_good)
1156
+
1157
+ labelled_mask = label(mask)
1158
+ cv2.imwrite(os.path.join(labelled_masks_path, f"{os.path.splitext(item['base_name'])[0]}.png"), labelled_mask)
1159
+
1160
+ real_defects_mask = None
1161
+
1162
+ if annotated_good is not None:
1163
+ # Remove the truly defective areas from the good labelled mask
1164
+ # To be even more restrictive we could also include the background, as we don't know for sure
1165
+ # that it contains no defects
1166
+ real_defects_mask = (~np.isin(mask, [0] + annotated_good)).astype(np.uint8)
1167
+ real_defects_mask = cv2.dilate(real_defects_mask, np.ones([patch_height, patch_width])).astype(bool)
1168
+
1169
+ for i in np.unique(labelled_mask):
1170
+ if i == 0:
1171
+ continue
1172
+
1173
+ current_mask = (labelled_mask == i).astype(np.uint8)
1174
+ target_idx = (mask * current_mask).max()
1175
+
1176
+ # When we have good annotations we want to avoid sampling patches containing true defects; to do so we
1177
+ # reduce the extraction area based on the area covered by the other defects
1178
+ if annotated_good is not None and real_defects_mask is not None and target_idx in annotated_good:
1179
+ # a - b = a & ~b
1180
+ # pylint: disable=invalid-unary-operand-type
1181
+ current_mask = np.bitwise_and(current_mask.astype(bool), ~real_defects_mask).astype(np.uint8)
1182
+ else:
1183
+ # When dealing with small defects the set of samplable points is limited, so patches would mostly be
1184
+ # centered around the defect. To overcome this, enlarge the defect bounding box by 50% of the
1185
+ # difference between the patch_size and the defect bounding box size. We don't do this on good labels,
1186
+ # to avoid invalidating the reduction applied above.
1187
+ props = regionprops(current_mask)[0]
1188
+ bbox_size = [props.bbox[2] - props.bbox[0], props.bbox[3] - props.bbox[1]]
1189
+ diff_bbox = np.array([max(0, patch_height - bbox_size[0]), max(0, patch_width - bbox_size[1])])
1190
+
1191
+ if diff_bbox[0] != 0:
1192
+ current_mask = cv2.dilate(current_mask, np.ones([diff_bbox[0] // 2, 1]))
1193
+ if diff_bbox[1] != 0:
1194
+ current_mask = cv2.dilate(current_mask, np.ones([1, diff_bbox[1] // 2]))
1195
+
1196
+ if current_mask.sum() == 0:
1197
+ # If the remaining area is empty there's nothing to sample for this labelled region
1198
+ continue
1199
+
1200
+ temp_current_mask = deepcopy(current_mask)
1201
+ # Remove the edges of the image as they are unsafe for sampling without padding
1202
+ temp_current_mask[0 : patch_height // 2, :] = 0
1203
+ temp_current_mask[:, 0 : patch_width // 2] = 0
1204
+ temp_current_mask[-patch_height // 2 :, :] = 0
1205
+ temp_current_mask[:, -patch_width // 2 :] = 0
1206
+
1207
+ if temp_current_mask.sum() != 0:
1208
+ # If the mask without the edges is not empty use it; otherwise keep the original mask, as it is not
1209
+ # possible to sample a patch that does not exceed the edges. That case must be taken care of by the
1210
+ # patch sampler used during training
1211
+ current_mask = temp_current_mask
1212
+
1213
+ triangles, weights = triangulate_region(current_mask)
1214
+
1215
+ if triangles is not None:
1216
+ h5_file_name = os.path.join(sampling_dataset_folder, f"{os.path.splitext(item['base_name'])[0]}_{i}.h5")
1217
+
1218
+ target = idx_to_class[target_idx]
1219
+
1220
+ log.debug(
1221
+ "Saving %s triangles for %s with label %s",
1222
+ triangles.shape[0],
1223
+ os.path.basename(h5_file_name),
1224
+ target,
1225
+ )
1226
+
1227
+ with h5py.File(h5_file_name, "w") as f:
1228
+ f.create_dataset("img_path", data=item["path"])
1229
+ f.create_dataset("mask_path", data=item["mask"])
1230
+ f.create_dataset("patch_size", data=np.array([patch_height, patch_width]))
1231
+ f.create_dataset("triangles", data=triangles, dtype=np.int32)
1232
+ f.create_dataset("triangles_weights", data=weights, dtype=np.float64)
1233
+ f.create_dataset("labelled_index", data=i, dtype=np.int32)
1234
+
1235
+ if annotated_good is not None and target_idx in annotated_good:
1236
+ # Annotated good regions are treated exactly the same as the background
1237
+ for _ in range(repeat_good_images):
1238
+ output_list.append(f"{os.path.basename(h5_file_name)},{target}\n")
1239
+ else:
1240
+ output_list.append(f"{os.path.basename(h5_file_name)},{target}\n")
1241
+
1242
+ if balance_defects:
1243
+ if not disable_good:
1244
+ output_list.append(f"{os.path.basename(h5_file_name_good)},{idx_to_class[0]}\n")
1245
+ else:
1246
+ log.debug(
1247
+ "Unable to add a good defect for %s, since there's no way to sample good patches",
1248
+ h5_file_name,
1249
+ )
1250
+ return output_list
1251
+
1252
+
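To see how the stored fields are meant to be consumed, here is a sketch of sampling one patch center from a generated h5 file; the file name and image size are hypothetical, and the field names match the datasets written above:

    import h5py
    import numpy as np

    with h5py.File("img_000_1.h5", "r") as f:  # hypothetical file name
        triangles = f["triangles"][()]         # (N, 3, 2) vertices as (y, x)
        weights = f["triangles_weights"][()]   # normalized areas, sum to 1
        patch_h, patch_w = f["patch_size"][()]

    # Pick a triangle proportionally to its area, then a point inside it.
    tri = triangles[np.random.choice(len(triangles), p=weights)]
    y, x = trisample(tri)
    # Clamp the patch window inside a (hypothetical) 1024x1024 image.
    up, down = compute_safe_patch_range(y, patch_h, 1024)
    left, right = compute_safe_patch_range(x, patch_w, 1024)
    # patch = image[y - up : y + down, x - left : x + right]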
1253
+ def triangle_area(triangle: np.ndarray) -> float:
1254
+ """Compute the area of a triangle defined by 3 points.
1255
+
1256
+ Args:
1257
+ triangle: Array of shape 3x2 containing the coordinates of a triangle.
1258
+
1259
+ Returns:
1260
+ The area of the triangle
1261
+
1262
+ """
1263
+ [y1, x1], [y2, x2], [y3, x3] = triangle
1264
+ return abs(0.5 * (((x2 - x1) * (y3 - y1)) - ((x3 - x1) * (y2 - y1))))
1265
+
1266
+
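A quick sanity check of the shoelace formula above: a right triangle with legs of length 3 and 4 (vertices given as (y, x) pairs) has area 6:

    import numpy as np

    assert triangle_area(np.array([[0, 0], [0, 4], [3, 0]])) == 6.0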
1267
+ def triangulate_region(mask: np.ndarray) -> tuple[np.ndarray | None, np.ndarray | None]:
1268
+ """Extract from a binary image containing a single roi (with or without holes) a list of triangles
1269
+ (and their normalized areas) that completely subdivide an approximated polygon defined around mask contours.
1270
+ The output can be used to easily sample points uniformly that are almost guaranteed to lie inside the roi.
1271
+
1272
+ Args:
1273
+ mask: Binary image defining a region of interest
1274
+
1275
+ Returns:
1276
+ Tuple containing:
1277
+ triangles: a numpy array containing a list of list of vertices (y, x) of the triangles defined over a
1278
+ polygon that contains the entire region
1279
+ weights: areas of each triangle rescaled (area_i / sum(areas))
1280
+
1281
+ """
1282
+ polygon_points, hier = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_TC89_L1)
1283
+
1284
+ if not np.all(hier[:, :, 3] == -1): # there are holes
1285
+ holes = ndimage.binary_fill_holes(mask).astype(np.uint8)
1286
+ holes -= mask
1287
+ holes = (holes > 0).astype(np.uint8)
1288
+ if holes.sum() > 0: # there are holes
1289
+ for hole in regionprops(label(holes)):
1290
+ y_hole_center = int(hole.centroid[0])
1291
+ mask[y_hole_center] = 0  # clear the row through the hole's centroid so the hole opens to the outside
1292
+
1293
+ polygon_points, _ = cv2.findContours(mask, cv2.RETR_LIST, cv2.CHAIN_APPROX_TC89_L1)
1294
+
1295
+ final_approx = []
1296
+
1297
+ # Extract a simpler approximation of the contour
1298
+ for cnt in polygon_points:
1299
+ epsilon = 0.01 * cv2.arcLength(cnt, True)
1300
+ approx = cv2.approxPolyDP(cnt, epsilon, True)
1301
+ final_approx.append(approx)
1302
+
1303
+ triangles = None
1304
+
1305
+ for approx in final_approx:
1306
+ contours_tripy = [x[0] for x in approx]
1307
+ current_triangles = earclip(contours_tripy)
1308
+
1309
+ if len(current_triangles) == 0:
1310
+ # This can only happen if a defect is only one pixel wide...
1311
+ continue
1312
+
1313
+ current_triangles = np.array([list(x) for x in current_triangles])
1314
+
1315
+ triangles = current_triangles if triangles is None else np.concatenate([triangles, current_triangles])
1316
+
1317
+ if triangles is None:
1318
+ return None, None
1319
+
1320
+ # cv2 contours are (x, y); swap to (y, x) to match the rest of the code
1321
+ triangles = triangles[..., ::-1]
1322
+
1323
+ weights = np.array([triangle_area(x) for x in triangles])
1324
+ weights = weights / weights.sum()
1325
+
1326
+ return triangles, weights
1327
+
1328
+
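A minimal sketch of triangulate_region on a synthetic mask, assuming the module's imports (cv2, tripy's earclip, regionprops) are available; a rectangular ROI is approximated by a quadrilateral, so the triangulation typically yields two triangles whose normalized areas sum to 1:

    import numpy as np

    mask = np.zeros((100, 100), dtype=np.uint8)
    mask[20:80, 10:60] = 1  # a single rectangular ROI, no holes

    triangles, weights = triangulate_region(mask)
    assert triangles is not None and triangles.shape[1:] == (3, 2)
    assert abs(weights.sum() - 1.0) < 1e-9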
1329
+ class InvalidParameterCombinationException(Exception):
1330
+ """Exception raised when an invalid combination of parameters is passed to a function."""
1331
+
1332
+
1333
+ class InvalidNumWorkersNumberException(Exception):
1334
+ """Exception raised when an invalid number of workers is passed to a function."""
1335
+
1336
+
1337
+ def load_train_file(
1338
+ train_file_path: str,
1339
+ include_filter: list[str] | None = None,
1340
+ exclude_filter: list[str] | None = None,
1341
+ class_to_skip: list | None = None,
1342
+ ) -> tuple[list[str], list[str]]:
1343
+ """Load a train file and return a list of samples and a list of targets. It is expected that train files will be in
1344
+ the same location as the train_file_path.
1345
+
1346
+ Args:
1347
+ train_file_path: Training file location
1348
+ include_filter: Include only samples that contain one of the elements of this list
1349
+ exclude_filter: Exclude all samples that contain one of the elements of this list
1350
+ class_to_skip: If not None, exclude all the samples with labels present in this list.
1351
+
1352
+ Returns:
1353
+ List of samples and list of targets
1354
+
1355
+ """
1356
+ samples = []
1357
+ targets = []
1358
+
1359
+ with open(train_file_path) as f:
1360
+ lines = f.read().splitlines()
1361
+ for line in lines:
1362
+ sample, target = line.split(",")
1363
+ if class_to_skip is not None and target in class_to_skip:
1364
+ continue
1365
+ samples.append(sample)
1366
+ targets.append(target)
1367
+
1368
+ include_filter = [] if include_filter is None else include_filter
1369
+ exclude_filter = [] if exclude_filter is None else exclude_filter
1370
+
1371
+ valid_samples_indices = [
1372
+ i
1373
+ for (i, x) in enumerate(samples)
1374
+ if (len(include_filter) == 0 or any(f in x for f in include_filter))
1375
+ and (len(exclude_filter) == 0 or not any(f in x for f in exclude_filter))
1376
+ ]
1377
+
1378
+ samples = [samples[i] for i in valid_samples_indices]
1379
+ targets = [targets[i] for i in valid_samples_indices]
1380
+
1381
+ train_folder = os.path.dirname(train_file_path)
1382
+ samples = [os.path.join(train_folder, x) for x in samples]
1383
+
1384
+ return samples, targets
1385
+
1386
+
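A hypothetical example of the filters: given a train/dataset.txt with the content shown in the comment below, substring filters select samples by name and the returned paths are joined to the file's folder:

    # train/dataset.txt (hypothetical content):
    #   img_000_good.h5,good
    #   img_000_1.h5,defect
    #   img_001_good.h5,good
    samples, targets = load_train_file(
        "train/dataset.txt",
        exclude_filter=["img_001"],  # substring match on the sample name
    )
    # samples == ["train/img_000_good.h5", "train/img_000_1.h5"]
    # targets == ["good", "defect"]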
1387
+ def compute_safe_patch_range(sampled_point: int, patch_size: int, image_size: int) -> tuple[int, int]:
1388
+ """Computes the safe patch size for the given image size.
1389
+
1390
+ Args:
1391
+ sampled_point: The coordinate of the sampled point along the axis being considered
1392
+ patch_size: the size of the patch
1393
+ image_size: the size of the image.
1394
+
1395
+ Returns:
1396
+ Tuple containing the safe patch range [left, right] such that
1397
+ [sampled_point - left : sampled_point + right] will be within the image size.
1398
+ """
1399
+ left = patch_size // 2
1400
+ right = patch_size // 2
1401
+
1402
+ if sampled_point + right > image_size:
1403
+ right = image_size - sampled_point
1404
+ left = patch_size - right
1405
+
1406
+ if sampled_point - left < 0:
1407
+ left = sampled_point
1408
+ right = patch_size - left
1409
+
1410
+ return left, right
1411
+
1412
+
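Worked example: a point sampled near the right edge gets its window shifted left so that it stays inside the image:

    left, right = compute_safe_patch_range(sampled_point=90, patch_size=32, image_size=100)
    # Initially left = right = 16, but 90 + 16 > 100, so right becomes
    # 100 - 90 = 10 and left becomes 32 - 10 = 22.
    assert (left, right) == (22, 10)  # window [68:100] has width 32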
1413
+ def trisample(triangle: np.ndarray) -> tuple[int, int]:
1414
+ """Sample a point uniformly in a triangle.
1415
+
1416
+ Args:
1417
+ triangle: Array of shape 3x2 containing the coordinates of a triangle.
1418
+
1419
+ Returns:
1420
+ A point (y, x) sampled uniformly inside the triangle
1421
+
1422
+ """
1423
+ [y1, x1], [y2, x2], [y3, x3] = triangle
1424
+
1425
+ r1 = random.random()
1426
+ r2 = random.random()
1427
+
1428
+ s1 = math.sqrt(r1)
1429
+
1430
+ x = x1 * (1.0 - s1) + x2 * (1.0 - r2) * s1 + x3 * r2 * s1
1431
+ y = y1 * (1.0 - s1) + y2 * (1.0 - r2) * s1 + y3 * r2 * s1
1432
+
1433
+ return int(y), int(x)
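The sqrt(r1) term is what makes the sampling uniform over the triangle's area rather than clustered toward the first vertex. A quick empirical check (a sketch using the function above): the mean of uniform samples converges to the triangle's centroid, i.e. the mean of its vertices:

    import numpy as np

    tri = np.array([[10, 10], [10, 110], [90, 60]])  # (y, x) vertices
    pts = np.array([trisample(tri) for _ in range(10_000)])
    # Tolerance of 2 pixels absorbs the int() truncation in trisample.
    assert np.allclose(pts.mean(axis=0), tri.mean(axis=0), atol=2)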