quadra 0.0.1__py3-none-any.whl → 2.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (302) hide show
  1. hydra_plugins/quadra_searchpath_plugin.py +30 -0
  2. quadra/__init__.py +6 -0
  3. quadra/callbacks/__init__.py +0 -0
  4. quadra/callbacks/anomalib.py +289 -0
  5. quadra/callbacks/lightning.py +501 -0
  6. quadra/callbacks/mlflow.py +291 -0
  7. quadra/callbacks/scheduler.py +69 -0
  8. quadra/configs/__init__.py +0 -0
  9. quadra/configs/backbone/caformer_m36.yaml +8 -0
  10. quadra/configs/backbone/caformer_s36.yaml +8 -0
  11. quadra/configs/backbone/convnextv2_base.yaml +8 -0
  12. quadra/configs/backbone/convnextv2_femto.yaml +8 -0
  13. quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
  14. quadra/configs/backbone/dino_vitb8.yaml +12 -0
  15. quadra/configs/backbone/dino_vits8.yaml +12 -0
  16. quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
  17. quadra/configs/backbone/dinov2_vits14.yaml +12 -0
  18. quadra/configs/backbone/efficientnet_b0.yaml +8 -0
  19. quadra/configs/backbone/efficientnet_b1.yaml +8 -0
  20. quadra/configs/backbone/efficientnet_b2.yaml +8 -0
  21. quadra/configs/backbone/efficientnet_b3.yaml +8 -0
  22. quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
  23. quadra/configs/backbone/levit_128s.yaml +8 -0
  24. quadra/configs/backbone/mnasnet0_5.yaml +9 -0
  25. quadra/configs/backbone/resnet101.yaml +8 -0
  26. quadra/configs/backbone/resnet18.yaml +8 -0
  27. quadra/configs/backbone/resnet18_ssl.yaml +8 -0
  28. quadra/configs/backbone/resnet50.yaml +8 -0
  29. quadra/configs/backbone/smp.yaml +9 -0
  30. quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
  31. quadra/configs/backbone/unetr.yaml +15 -0
  32. quadra/configs/backbone/vit16_base.yaml +9 -0
  33. quadra/configs/backbone/vit16_small.yaml +9 -0
  34. quadra/configs/backbone/vit16_tiny.yaml +9 -0
  35. quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
  36. quadra/configs/callbacks/all.yaml +45 -0
  37. quadra/configs/callbacks/default.yaml +34 -0
  38. quadra/configs/callbacks/default_anomalib.yaml +64 -0
  39. quadra/configs/config.yaml +33 -0
  40. quadra/configs/core/default.yaml +11 -0
  41. quadra/configs/datamodule/base/anomaly.yaml +16 -0
  42. quadra/configs/datamodule/base/classification.yaml +21 -0
  43. quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
  44. quadra/configs/datamodule/base/segmentation.yaml +18 -0
  45. quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
  46. quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
  47. quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
  48. quadra/configs/datamodule/base/ssl.yaml +21 -0
  49. quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
  50. quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
  51. quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
  52. quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
  53. quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
  54. quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
  55. quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
  56. quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
  57. quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
  58. quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
  59. quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
  60. quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
  61. quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
  62. quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
  63. quadra/configs/experiment/base/classification/classification.yaml +73 -0
  64. quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
  65. quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
  66. quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
  67. quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
  68. quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
  69. quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
  70. quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
  71. quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
  72. quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
  73. quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
  74. quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
  75. quadra/configs/experiment/base/ssl/byol.yaml +43 -0
  76. quadra/configs/experiment/base/ssl/dino.yaml +46 -0
  77. quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
  78. quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
  79. quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
  80. quadra/configs/experiment/custom/cls.yaml +12 -0
  81. quadra/configs/experiment/default.yaml +15 -0
  82. quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
  83. quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
  84. quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
  85. quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
  86. quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
  87. quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
  88. quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
  89. quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
  90. quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
  91. quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
  92. quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
  93. quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
  94. quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
  95. quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
  96. quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
  97. quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
  98. quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
  99. quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
  100. quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
  101. quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
  102. quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
  103. quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
  104. quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
  105. quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
  106. quadra/configs/export/default.yaml +13 -0
  107. quadra/configs/hydra/anomaly_custom.yaml +15 -0
  108. quadra/configs/hydra/default.yaml +14 -0
  109. quadra/configs/inference/default.yaml +26 -0
  110. quadra/configs/logger/comet.yaml +10 -0
  111. quadra/configs/logger/csv.yaml +5 -0
  112. quadra/configs/logger/mlflow.yaml +12 -0
  113. quadra/configs/logger/tensorboard.yaml +8 -0
  114. quadra/configs/loss/asl.yaml +7 -0
  115. quadra/configs/loss/barlow.yaml +2 -0
  116. quadra/configs/loss/bce.yaml +1 -0
  117. quadra/configs/loss/byol.yaml +1 -0
  118. quadra/configs/loss/cross_entropy.yaml +1 -0
  119. quadra/configs/loss/dino.yaml +8 -0
  120. quadra/configs/loss/simclr.yaml +2 -0
  121. quadra/configs/loss/simsiam.yaml +1 -0
  122. quadra/configs/loss/smp_ce.yaml +3 -0
  123. quadra/configs/loss/smp_dice.yaml +2 -0
  124. quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
  125. quadra/configs/loss/smp_mcc.yaml +2 -0
  126. quadra/configs/loss/vicreg.yaml +5 -0
  127. quadra/configs/model/anomalib/cfa.yaml +35 -0
  128. quadra/configs/model/anomalib/cflow.yaml +30 -0
  129. quadra/configs/model/anomalib/csflow.yaml +34 -0
  130. quadra/configs/model/anomalib/dfm.yaml +19 -0
  131. quadra/configs/model/anomalib/draem.yaml +29 -0
  132. quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
  133. quadra/configs/model/anomalib/fastflow.yaml +32 -0
  134. quadra/configs/model/anomalib/padim.yaml +32 -0
  135. quadra/configs/model/anomalib/patchcore.yaml +36 -0
  136. quadra/configs/model/barlow.yaml +16 -0
  137. quadra/configs/model/byol.yaml +25 -0
  138. quadra/configs/model/classification.yaml +10 -0
  139. quadra/configs/model/dino.yaml +26 -0
  140. quadra/configs/model/logistic_regression.yaml +4 -0
  141. quadra/configs/model/multilabel_classification.yaml +9 -0
  142. quadra/configs/model/simclr.yaml +18 -0
  143. quadra/configs/model/simsiam.yaml +24 -0
  144. quadra/configs/model/smp.yaml +4 -0
  145. quadra/configs/model/smp_multiclass.yaml +4 -0
  146. quadra/configs/model/vicreg.yaml +16 -0
  147. quadra/configs/optimizer/adam.yaml +5 -0
  148. quadra/configs/optimizer/adamw.yaml +3 -0
  149. quadra/configs/optimizer/default.yaml +4 -0
  150. quadra/configs/optimizer/lars.yaml +8 -0
  151. quadra/configs/optimizer/sgd.yaml +4 -0
  152. quadra/configs/scheduler/default.yaml +5 -0
  153. quadra/configs/scheduler/rop.yaml +5 -0
  154. quadra/configs/scheduler/step.yaml +3 -0
  155. quadra/configs/scheduler/warmrestart.yaml +2 -0
  156. quadra/configs/scheduler/warmup.yaml +6 -0
  157. quadra/configs/task/anomalib/cfa.yaml +5 -0
  158. quadra/configs/task/anomalib/cflow.yaml +5 -0
  159. quadra/configs/task/anomalib/csflow.yaml +5 -0
  160. quadra/configs/task/anomalib/draem.yaml +5 -0
  161. quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
  162. quadra/configs/task/anomalib/fastflow.yaml +5 -0
  163. quadra/configs/task/anomalib/inference.yaml +3 -0
  164. quadra/configs/task/anomalib/padim.yaml +5 -0
  165. quadra/configs/task/anomalib/patchcore.yaml +5 -0
  166. quadra/configs/task/classification.yaml +6 -0
  167. quadra/configs/task/classification_evaluation.yaml +6 -0
  168. quadra/configs/task/default.yaml +1 -0
  169. quadra/configs/task/segmentation.yaml +9 -0
  170. quadra/configs/task/segmentation_evaluation.yaml +3 -0
  171. quadra/configs/task/sklearn_classification.yaml +13 -0
  172. quadra/configs/task/sklearn_classification_patch.yaml +11 -0
  173. quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
  174. quadra/configs/task/sklearn_classification_test.yaml +8 -0
  175. quadra/configs/task/ssl.yaml +2 -0
  176. quadra/configs/trainer/lightning_cpu.yaml +36 -0
  177. quadra/configs/trainer/lightning_gpu.yaml +35 -0
  178. quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
  179. quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
  180. quadra/configs/trainer/lightning_multigpu.yaml +37 -0
  181. quadra/configs/trainer/sklearn_classification.yaml +7 -0
  182. quadra/configs/transforms/byol.yaml +47 -0
  183. quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
  184. quadra/configs/transforms/default.yaml +37 -0
  185. quadra/configs/transforms/default_numpy.yaml +24 -0
  186. quadra/configs/transforms/default_resize.yaml +22 -0
  187. quadra/configs/transforms/dino.yaml +63 -0
  188. quadra/configs/transforms/linear_eval.yaml +18 -0
  189. quadra/datamodules/__init__.py +20 -0
  190. quadra/datamodules/anomaly.py +180 -0
  191. quadra/datamodules/base.py +375 -0
  192. quadra/datamodules/classification.py +1003 -0
  193. quadra/datamodules/generic/__init__.py +0 -0
  194. quadra/datamodules/generic/imagenette.py +144 -0
  195. quadra/datamodules/generic/mnist.py +81 -0
  196. quadra/datamodules/generic/mvtec.py +58 -0
  197. quadra/datamodules/generic/oxford_pet.py +163 -0
  198. quadra/datamodules/patch.py +190 -0
  199. quadra/datamodules/segmentation.py +742 -0
  200. quadra/datamodules/ssl.py +140 -0
  201. quadra/datasets/__init__.py +17 -0
  202. quadra/datasets/anomaly.py +287 -0
  203. quadra/datasets/classification.py +241 -0
  204. quadra/datasets/patch.py +138 -0
  205. quadra/datasets/segmentation.py +239 -0
  206. quadra/datasets/ssl.py +110 -0
  207. quadra/losses/__init__.py +0 -0
  208. quadra/losses/classification/__init__.py +6 -0
  209. quadra/losses/classification/asl.py +83 -0
  210. quadra/losses/classification/focal.py +320 -0
  211. quadra/losses/classification/prototypical.py +148 -0
  212. quadra/losses/ssl/__init__.py +17 -0
  213. quadra/losses/ssl/barlowtwins.py +47 -0
  214. quadra/losses/ssl/byol.py +37 -0
  215. quadra/losses/ssl/dino.py +129 -0
  216. quadra/losses/ssl/hyperspherical.py +45 -0
  217. quadra/losses/ssl/idmm.py +50 -0
  218. quadra/losses/ssl/simclr.py +67 -0
  219. quadra/losses/ssl/simsiam.py +30 -0
  220. quadra/losses/ssl/vicreg.py +76 -0
  221. quadra/main.py +49 -0
  222. quadra/metrics/__init__.py +3 -0
  223. quadra/metrics/segmentation.py +251 -0
  224. quadra/models/__init__.py +0 -0
  225. quadra/models/base.py +151 -0
  226. quadra/models/classification/__init__.py +8 -0
  227. quadra/models/classification/backbones.py +149 -0
  228. quadra/models/classification/base.py +92 -0
  229. quadra/models/evaluation.py +322 -0
  230. quadra/modules/__init__.py +0 -0
  231. quadra/modules/backbone.py +30 -0
  232. quadra/modules/base.py +312 -0
  233. quadra/modules/classification/__init__.py +3 -0
  234. quadra/modules/classification/base.py +327 -0
  235. quadra/modules/ssl/__init__.py +17 -0
  236. quadra/modules/ssl/barlowtwins.py +59 -0
  237. quadra/modules/ssl/byol.py +172 -0
  238. quadra/modules/ssl/common.py +285 -0
  239. quadra/modules/ssl/dino.py +186 -0
  240. quadra/modules/ssl/hyperspherical.py +206 -0
  241. quadra/modules/ssl/idmm.py +98 -0
  242. quadra/modules/ssl/simclr.py +73 -0
  243. quadra/modules/ssl/simsiam.py +68 -0
  244. quadra/modules/ssl/vicreg.py +67 -0
  245. quadra/optimizers/__init__.py +4 -0
  246. quadra/optimizers/lars.py +153 -0
  247. quadra/optimizers/sam.py +127 -0
  248. quadra/schedulers/__init__.py +3 -0
  249. quadra/schedulers/base.py +44 -0
  250. quadra/schedulers/warmup.py +127 -0
  251. quadra/tasks/__init__.py +24 -0
  252. quadra/tasks/anomaly.py +582 -0
  253. quadra/tasks/base.py +397 -0
  254. quadra/tasks/classification.py +1263 -0
  255. quadra/tasks/patch.py +492 -0
  256. quadra/tasks/segmentation.py +389 -0
  257. quadra/tasks/ssl.py +560 -0
  258. quadra/trainers/README.md +3 -0
  259. quadra/trainers/__init__.py +0 -0
  260. quadra/trainers/classification.py +179 -0
  261. quadra/utils/__init__.py +0 -0
  262. quadra/utils/anomaly.py +112 -0
  263. quadra/utils/classification.py +618 -0
  264. quadra/utils/deprecation.py +31 -0
  265. quadra/utils/evaluation.py +474 -0
  266. quadra/utils/export.py +585 -0
  267. quadra/utils/imaging.py +32 -0
  268. quadra/utils/logger.py +15 -0
  269. quadra/utils/mlflow.py +98 -0
  270. quadra/utils/model_manager.py +320 -0
  271. quadra/utils/models.py +523 -0
  272. quadra/utils/patch/__init__.py +15 -0
  273. quadra/utils/patch/dataset.py +1433 -0
  274. quadra/utils/patch/metrics.py +449 -0
  275. quadra/utils/patch/model.py +153 -0
  276. quadra/utils/patch/visualization.py +217 -0
  277. quadra/utils/resolver.py +42 -0
  278. quadra/utils/segmentation.py +31 -0
  279. quadra/utils/tests/__init__.py +0 -0
  280. quadra/utils/tests/fixtures/__init__.py +1 -0
  281. quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
  282. quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
  283. quadra/utils/tests/fixtures/dataset/classification.py +406 -0
  284. quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
  285. quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
  286. quadra/utils/tests/fixtures/models/__init__.py +3 -0
  287. quadra/utils/tests/fixtures/models/anomaly.py +89 -0
  288. quadra/utils/tests/fixtures/models/classification.py +45 -0
  289. quadra/utils/tests/fixtures/models/segmentation.py +33 -0
  290. quadra/utils/tests/helpers.py +70 -0
  291. quadra/utils/tests/models.py +27 -0
  292. quadra/utils/utils.py +525 -0
  293. quadra/utils/validator.py +115 -0
  294. quadra/utils/visualization.py +422 -0
  295. quadra/utils/vit_explainability.py +349 -0
  296. quadra-2.2.7.dist-info/LICENSE +201 -0
  297. quadra-2.2.7.dist-info/METADATA +381 -0
  298. quadra-2.2.7.dist-info/RECORD +300 -0
  299. {quadra-0.0.1.dist-info → quadra-2.2.7.dist-info}/WHEEL +1 -1
  300. quadra-2.2.7.dist-info/entry_points.txt +3 -0
  301. quadra-0.0.1.dist-info/METADATA +0 -14
  302. quadra-0.0.1.dist-info/RECORD +0 -4
@@ -0,0 +1,138 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import random
5
+ from collections.abc import Callable
6
+
7
+ import cv2
8
+ import h5py
9
+ import numpy as np
10
+ from torch.utils.data import Dataset
11
+
12
+ from quadra.utils import utils
13
+ from quadra.utils.imaging import keep_aspect_ratio_resize
14
+ from quadra.utils.patch.dataset import compute_safe_patch_range, trisample
15
+
16
+ log = utils.get_logger(__name__)
17
+
18
+
19
class PatchSklearnClassificationTrainDataset(Dataset):
    """Dataset used for patch sampling, it expects samples to be paths to h5 files containing all the required
    information for patch sampling from images.

    Each h5 file is read for the keys ``img_path`` (utf-8 encoded image path, relative to ``data_path``),
    ``triangles_weights``, ``triangles`` and ``patch_size``.

    Args:
        data_path: base path to the dataset
        samples: Paths to h5 files
        targets: Labels associated with each sample
        class_to_idx: Mapping between class and corresponding index. If None, it is built by enumerating the
            sorted unique targets.
        resize: Whether to perform an aspect ratio resize of the patch before the transformations
        transform: Optional function applied to the image
        rgb: if False, image will be converted in grayscale
        channel: 1 or 3. If rgb is True, then channel will be set at 3.
        balance_classes: if True, every class with fewer samples than the most frequent one is oversampled
            (random duplicates drawn with replacement) until it reaches the size of the most frequent class
    """

    def __init__(
        self,
        data_path: str,
        samples: list[str],
        targets: list[str | int],
        class_to_idx: dict | None = None,
        resize: int | None = None,
        transform: Callable | None = None,
        rgb: bool = True,
        channel: int = 3,
        balance_classes: bool = False,
    ):
        super().__init__()

        # Keep-Aspect-Ratio resize
        self.resize = resize
        self.data_path = data_path

        # Balancing an empty dataset is a no-op (np.max would fail on an empty counts array)
        if balance_classes and len(targets) > 0:
            samples_array = np.array(samples)
            targets_array = np.array(targets)
            samples_to_use: list[str] = []
            targets_to_use: list[str | int] = []

            cls, counts = np.unique(targets_array, return_counts=True)
            max_count = np.max(counts)
            for cl, count in zip(cls, counts):
                idx_to_pick = list(np.where(targets_array == cl)[0])

                if count < max_count:
                    # Oversample with replacement until this class matches the most frequent one
                    idx_to_pick += random.choices(idx_to_pick, k=max_count - count)

                samples_to_use.extend(samples_array[idx_to_pick])
                targets_to_use.extend(targets_array[idx_to_pick])
        else:
            samples_to_use = samples
            targets_to_use = targets

        # Data
        self.x = np.array(samples_to_use)
        self.y = np.array(targets_to_use)

        if class_to_idx is None:
            unique_targets = np.unique(targets_to_use)
            class_to_idx = {c: i for i, c in enumerate(unique_targets)}

        self.class_to_idx = class_to_idx
        self.idx_to_class = {v: k for k, v in class_to_idx.items()}

        # (h5 path, class index) pairs; a None target is kept as a None label
        self.samples = [
            (path, self.class_to_idx[self.y[i]] if self.y[i] is not None else None) for i, path in enumerate(self.x)
        ]

        self.rgb = rgb
        self.channel = 3 if rgb else channel

        self.transform = transform

    def __getitem__(self, idx) -> tuple[np.ndarray, np.ndarray]:
        """Sample a patch from the image referenced by the idx-th h5 file.

        Returns:
            A tuple ``(patch, label)`` where ``patch`` is the (optionally resized and transformed) image crop and
            ``label`` is the class index (or None).

        Raises:
            FileNotFoundError: if the image referenced by the h5 file cannot be read.
        """
        path, y = self.samples[idx]

        # Use a context manager (read-only) so the h5 handle is always released, even if reading the image or
        # sampling the patch center raises.
        with h5py.File(path, "r") as h5_file:
            img_path = os.path.join(self.data_path, h5_file["img_path"][()].decode("utf-8"))
            x = cv2.imread(img_path)
            if x is None:
                raise FileNotFoundError(f"Unable to read image {img_path} referenced by {path}")

            weights = h5_file["triangles_weights"][()]
            patch_size = h5_file["patch_size"][()]

            if weights.shape[0] == 0:  # pylint: disable=no-member
                # If the image is completely good sample a point anywhere
                patch_y = np.random.randint(0, x.shape[0] + 1)
                patch_x = np.random.randint(0, x.shape[1] + 1)
            else:
                # Pick a triangle with probability given by its weight, then sample the point from it via trisample
                random_triangle = np.random.choice(weights.shape[0], p=weights)
                [patch_y, patch_x] = trisample(h5_file["triangles"][random_triangle])

        # If the patch is outside the image reduce the exceeding area by taking more patch from the inner area
        [y_left, y_right] = compute_safe_patch_range(patch_y, patch_size[0], x.shape[0])
        [x_left, x_right] = compute_safe_patch_range(patch_x, patch_size[1], x.shape[1])

        x = x[patch_y - y_left : patch_y + y_right, patch_x - x_left : patch_x + x_right]

        if self.rgb:
            x = cv2.cvtColor(x, cv2.COLOR_BGR2RGB)
        else:
            # Grayscale replicated on 3 channels; reduced to a single channel below if requested
            x = cv2.cvtColor(x, cv2.COLOR_BGR2GRAY)
            x = cv2.cvtColor(x, cv2.COLOR_GRAY2RGB)

        if self.channel == 1:
            x = x[:, :, 0]

        # Resize keeping aspect ratio
        if self.resize:
            x = keep_aspect_ratio_resize(x, self.resize)

        if self.transform:
            aug = self.transform(image=x)
            x = aug["image"]

        return x, y

    def __len__(self):
        """Return the number of (possibly oversampled) samples."""
        return len(self.samples)
@@ -0,0 +1,239 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from collections.abc import Callable
5
+ from typing import Any
6
+
7
+ import albumentations
8
+ import cv2
9
+ import numpy as np
10
+ import torch
11
+
12
+ from quadra.utils.deprecation import deprecated
13
+ from quadra.utils.imaging import keep_aspect_ratio_resize
14
+ from quadra.utils.segmentation import smooth_mask
15
+
16
+
17
# DEPRECATED -> we can use SegmentationDatasetMulticlass also for one class segmentation
@deprecated("Use SegmentationDatasetMulticlass instead")
class SegmentationDataset(torch.utils.data.Dataset):
    """Custom SegmentationDataset class for loading images and masks.

    Each item is a tuple ``(image, mask, label)``. A mask path that is ``None``, NaN or not an existing file yields an
    all-zero mask with the image height and width. The returned label is made consistent with the final mask: 0 when
    the mask is empty, 1 when the mask is non-empty and the label was ``None`` or 0.

    Args:
        image_paths: List of paths to images.
        mask_paths: List of paths to masks.
        batch_size: Batch size. When given, the dataset length becomes ``max(batch_size, len(image_paths))`` and
            indices wrap around the available images.
        object_masks: List of paths to object masks.
        resize: Resize image to this size.
        mask_preprocess: Preprocess mask.
        labels: List of labels.
        transform: Transformations to apply to images and masks.
        mask_smoothing: Smooth mask.
        defect_transform: Transformations applied to image and mask of samples labelled 1 whose mask is empty.
    """

    def __init__(
        self,
        image_paths: list[str],
        mask_paths: list[str],
        batch_size: int | None = None,
        object_masks: list[np.ndarray | Any] | None = None,
        resize: int = 224,
        mask_preprocess: Callable | None = None,
        labels: list[str] | None = None,
        transform: albumentations.Compose | None = None,
        mask_smoothing: bool = False,
        defect_transform: albumentations.Compose | None = None,
    ):
        self.transform = transform
        self.defect_transform = defect_transform
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.labels = labels
        self.mask_preprocess = mask_preprocess
        self.resize = resize
        self.object_masks = object_masks
        self.data_len = len(self.image_paths)
        self.batch_size = None if batch_size is None else max(batch_size, self.data_len)
        self.smooth_mask = mask_smoothing

    @staticmethod
    def _is_existing_file(path: Any) -> bool:
        """Return True if ``path`` points to an existing file, False for None/NaN placeholders or missing files."""
        if path is None:
            return False
        # Missing entries coming from e.g. pandas are float NaN objects that are not necessarily the ``np.nan``
        # singleton (so an ``is np.nan`` check misses them), and ``os.path.isfile`` raises TypeError on floats.
        if isinstance(path, float) and np.isnan(path):
            return False
        return os.path.isfile(path)

    def __getitem__(self, index):
        # This is required to avoid infinite loop when running the dataset outside of a dataloader
        if self.batch_size is not None and self.batch_size == index:
            raise StopIteration

        if self.batch_size is None and self.data_len == index:
            raise StopIteration

        index = index % self.data_len
        image_path = self.image_paths[index]

        image = cv2.imread(str(image_path))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        object_mask_path = self.object_masks[index] if self.object_masks is not None else None
        object_mask = cv2.imread(str(object_mask_path), 0) if self._is_existing_file(object_mask_path) else None

        label = self.labels[index] if self.labels is not None else None

        mask_path = self.mask_paths[index]
        if self._is_existing_file(mask_path):
            mask = cv2.imread(str(mask_path), 0)
        else:
            mask = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)

        if self.defect_transform is not None and label == 1 and np.sum(mask) == 0:
            if object_mask is not None:
                object_mask *= 255
            aug = self.defect_transform(image=image, mask=mask, object_mask=object_mask, label=label)
            image = aug["image"]
            mask = aug["mask"]
            label = aug["label"]
        if self.mask_preprocess:
            mask = self.mask_preprocess(mask)
            if object_mask is not None:
                object_mask = self.mask_preprocess(object_mask)
        if self.resize:
            image = keep_aspect_ratio_resize(image, self.resize)
            mask = keep_aspect_ratio_resize(mask, self.resize)
            if object_mask is not None:
                object_mask = keep_aspect_ratio_resize(object_mask, self.resize)

        if self.transform is not None:
            aug = self.transform(image=image, mask=mask)
            image = aug["image"]
            mask = aug["mask"]

        if isinstance(mask, np.ndarray):
            mask_sum = np.sum(mask)
        elif isinstance(mask, torch.Tensor):
            mask_sum = torch.sum(mask)
        else:
            raise ValueError("Unsupported type for mask")
        # Keep the label consistent with the final mask content
        if mask_sum > 0 and (label is None or label == 0):
            label = 1
        if mask_sum == 0:
            label = 0

        # Binarize the mask and add a leading channel dimension. Dispatch on the mask's own type (not the image's)
        # so a transform returning different types for image and mask does not crash here.
        if isinstance(mask, np.ndarray):
            mask = (mask > 0).astype(np.uint8)
            if self.smooth_mask:
                mask = smooth_mask(mask)
            mask = np.expand_dims(mask, axis=0)
        else:
            mask = (mask > 0).int()
            if self.smooth_mask:
                mask = torch.from_numpy(smooth_mask(mask.numpy()))
            mask = mask.unsqueeze(0)

        return image, mask, label

    def __len__(self):
        """Return the dataset length (at least ``batch_size`` when a batch size is set)."""
        if self.batch_size is None:
            return self.data_len

        return max(self.data_len, self.batch_size)
139
+
140
+
141
class SegmentationDatasetMulticlass(torch.utils.data.Dataset):
    """Custom SegmentationDataset class for loading images and multilabel masks.

    Args:
        image_paths: List of paths to images.
        mask_paths: List of paths to masks. Entries that are ``None``, NaN or not an existing file are treated as an
            all-background mask.
        idx_to_class: dict with the correspondence between mask index and classes:
            {1: class_1, 2: class_2, ..., N: class_N}
        batch_size: Batch size. When given, the dataset length becomes ``max(batch_size, len(image_paths))`` and
            indices wrap around the available images.
        transform: Transformations to apply to images and masks.
        one_hot: if True return a binary mask (n_class x H x W), otherwise the labelled mask H x W. SMP loss requires
            the second format.
    """

    def __init__(
        self,
        image_paths: list[str],
        mask_paths: list[str],
        idx_to_class: dict,
        batch_size: int | None = None,
        transform: albumentations.Compose | None = None,
        one_hot: bool = False,
    ):
        self.transform = transform
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.idx_to_class = idx_to_class
        self.data_len = len(self.image_paths)
        self.batch_size = None if batch_size is None else max(batch_size, self.data_len)
        self.one_hot = one_hot

    @staticmethod
    def _is_existing_file(path: Any) -> bool:
        """Return True if ``path`` points to an existing file, False for None/NaN placeholders or missing files."""
        if path is None:
            return False
        # Missing entries coming from e.g. pandas are float NaN objects that are not necessarily the ``np.nan``
        # singleton (so an ``is np.nan`` check misses them), and ``os.path.isfile`` raises TypeError on floats.
        if isinstance(path, float) and np.isnan(path):
            return False
        return os.path.isfile(path)

    def _preprocess_mask(self, mask: np.ndarray) -> np.ndarray:
        """Convert a labelled mask into a stack of binary masks (needed for albumentations).

        Args:
            mask: a numpy array of dimension HxW with values in [0] + the keys of self.idx_to_class.

        Returns:
            A float numpy array of shape (len(self.idx_to_class) + 1) x H x W where channel ``i`` is 1 where
            ``mask == i``; channel 0 holds the background.
        """
        # NOTE: assumes the keys of idx_to_class are the integers 1..N; other keys would index out of range.
        multilayer_mask = np.zeros((len(self.idx_to_class) + 1, *mask.shape[:2]))
        # provide background information for completeness, the single channel mask does not use it anyway.
        multilayer_mask[0] = (mask == 0).astype(np.uint8)
        for idx in self.idx_to_class:
            multilayer_mask[int(idx)] = (mask == int(idx)).astype(np.uint8)

        return multilayer_mask

    def __getitem__(self, index):
        """Get image and mask.

        Returns:
            A tuple ``(image, mask, 0)``; the mask is H x W (labelled) or C x H x W (one hot), see ``one_hot``.
        """
        # This is required to avoid infinite loop when running the dataset outside of a dataloader
        if self.batch_size is not None and self.batch_size == index:
            raise StopIteration
        if self.batch_size is None and self.data_len == index:
            raise StopIteration

        index = index % self.data_len
        image_path = self.image_paths[index]

        image = cv2.imread(str(image_path))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask_path = self.mask_paths[index]
        if self._is_existing_file(mask_path):
            mask = cv2.imread(str(mask_path), 0)
        else:
            mask = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)

        # we go back to binary masks to avoid transformation errors
        mask = self._preprocess_mask(mask)

        if self.transform is not None:
            aug = self.transform(image=image, masks=list(mask))
            image = aug["image"]
            mask = np.stack(aug["masks"])  # C x H x W

        if not self.one_hot:  # one hot is done by smp dice loss
            # Rebuild the single channel labelled mask (H x W); zero is the background
            mask_out = np.zeros(mask.shape[1:])
            for i in range(1, mask.shape[0]):
                mask_out[mask[i] == 1] = i
        else:
            # C x H x W where C is the number of classes (background included)
            mask_out = mask

        return image, mask_out.astype(int), 0

    def __len__(self):
        """Returns the dataset length."""
        if self.batch_size is None:
            return self.data_len

        return max(self.data_len, self.batch_size)
quadra/datasets/ssl.py ADDED
@@ -0,0 +1,110 @@
1
+ from __future__ import annotations
2
+
3
+ import random
4
+ from collections.abc import Iterable
5
+ from enum import Enum
6
+
7
+ import albumentations as A
8
+ import numpy as np
9
+ from torch.utils.data import Dataset
10
+
11
+
12
class AugmentationStrategy(Enum):
    """Augmentation Strategy for TwoAugmentationDataset.

    Attributes:
        SAME_IMAGE: both views are generated from the same sample.
        SAME_CLASS: the second view is generated from a randomly chosen sample with the same target.
    """

    SAME_IMAGE = 1
    SAME_CLASS = 2


class TwoAugmentationDataset(Dataset):
    """Two Image Augmentation Dataset for using in self-supervised learning.

    Args:
        dataset: A torch Dataset object returning ``(image, target)`` pairs. The ``SAME_CLASS`` strategy also
            requires it to expose a ``y`` array with the targets of all samples.
        transform: albumentation transformations for each image.
            If you use a single transformation, it will be applied to both images.
            If you use a tuple (or list) of two transformations, the first one is applied to the first image and the
            second one to the second image.
        strategy: Defaults to AugmentationStrategy.SAME_IMAGE.

    Raises:
        ValueError: if ``transform`` is a tuple/list whose length is not 2.
    """

    def __init__(
        self,
        dataset: Dataset,
        transform: A.Compose | tuple[A.Compose, A.Compose],
        strategy: AugmentationStrategy = AugmentationStrategy.SAME_IMAGE,
    ):
        # Only explicit tuples/lists are treated as a pair of per-view transforms: a single ``A.Compose`` is itself
        # iterable (over its inner transforms), so a generic ``Iterable`` check would misclassify it. The length
        # check (instead of ``len(set(...))``) also allows passing the same transform object twice.
        if isinstance(transform, (tuple, list)) and len(transform) != 2:
            raise ValueError("transform must be an Iterable of length 2")
        self.dataset = dataset
        self.transform = transform
        self.strategy = strategy

    @property
    def stategy(self) -> AugmentationStrategy:
        """Misspelled alias of ``strategy``, kept for backward compatibility."""
        return self.strategy

    @stategy.setter
    def stategy(self, value: AugmentationStrategy) -> None:
        self.strategy = value

    def __getitem__(self, index):
        """Return ``([view1, view2], target)`` for the sample at ``index``."""
        image1, target = self.dataset[index]

        if self.strategy == AugmentationStrategy.SAME_IMAGE:
            image2 = image1
        elif self.strategy == AugmentationStrategy.SAME_CLASS:
            positive_pair_idx = random.choice(np.where(self.dataset.y == target)[0])
            image2, _ = self.dataset[positive_pair_idx]
        else:
            raise ValueError("Unknown strategy")

        if isinstance(self.transform, (tuple, list)):
            first_transform, second_transform = self.transform
        else:
            first_transform = second_transform = self.transform

        image1 = first_transform(image=image1)["image"]
        image2 = second_transform(image=image2)["image"]

        return [image1, image2], target

    def __len__(self):
        """Return the length of the wrapped dataset."""
        return len(self.dataset)
64
+
65
+
66
class TwoSetAugmentationDataset(Dataset):
    """Two Set Augmentation Dataset for using in self-supervised learning (DINO).

    Args:
        dataset: Base dataset returning ``(image, target)`` pairs.
        global_transforms: Global transformations, each applied once to the original image.
        local_transform: Local transformation, applied ``num_local_transforms`` times to the original image.
        num_local_transforms: Number of local transformations to apply (must be at least 1). In total you will have
            ``len(global_transforms) + num_local_transforms`` views for each image. No view is the untransformed
            original image: the first element is ``global_transforms[0]`` applied to it.

    Raises:
        ValueError: if ``num_local_transforms`` is smaller than 1.

    Example:
        >>> images[0] = global_transform[0](original_image)
        >>> images[1] = global_transform[1](original_image)
        >>> images[2:] = local_transform(s)(original_image)
    """

    def __init__(
        self,
        dataset: Dataset,
        global_transforms: tuple[A.Compose, A.Compose],
        local_transform: A.Compose,
        num_local_transforms: int,
    ):
        if num_local_transforms < 1:
            raise ValueError("num_local_transforms must be greater than 0")

        self.dataset = dataset
        self.global_transforms = global_transforms
        self.local_transform = local_transform
        self.num_local_transforms = num_local_transforms

    def __getitem__(self, index):
        """Return ``(views, target)`` where the global views come first, followed by the local ones."""
        original_image, target = self.dataset[index]
        global_outputs = [transform(image=original_image)["image"] for transform in self.global_transforms]
        local_outputs = [
            self.local_transform(image=original_image)["image"] for _ in range(self.num_local_transforms)
        ]
        return global_outputs + local_outputs, target

    def __len__(self):
        """Return the length of the wrapped dataset."""
        return len(self.dataset)
File without changes
@@ -0,0 +1,6 @@
1
+ from .asl import AsymmetricLoss
2
+ from .focal import FocalLoss
3
+
4
+ # TODO: Implement prototypical loss as a module
5
+
6
+ __all__ = ["AsymmetricLoss", "FocalLoss"]
@@ -0,0 +1,83 @@
1
+ import torch
2
+
3
+
4
class AsymmetricLoss(torch.nn.Module):
    """Asymmetric loss for multi-label classification.

    Several intermediate operations are done in place to limit temporary
    allocations. The intermediate tensors of the most recent ``forward`` call
    (``targets``, ``anti_targets``, ``xs_pos``, ``xs_neg``, ``asymmetric_w``,
    ``loss``) are kept as attributes until the next call.

    Args:
        gamma_neg: focusing exponent applied to negative targets
        gamma_pos: focusing exponent applied to positive targets
        m: margin added to the negative probability ``1 - p`` before clipping it at 1.
            ``None`` or a non-positive value disables the margin.
        eps: lower bound applied to probabilities before taking the logarithm
        disable_torch_grad_focal_loss: if True, the focusing weights are computed without
            gradient tracking, so no gradient flows through them
        apply_sigmoid: if True, applies sigmoid to input before computing loss
    """

    def __init__(
        self,
        gamma_neg: float = 4,
        gamma_pos: float = 0,
        m: float = 0.05,
        eps: float = 1e-8,
        disable_torch_grad_focal_loss: bool = False,
        apply_sigmoid: bool = True,
    ):
        super().__init__()

        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.m = m
        self.eps = eps
        self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
        self.apply_sigmoid = apply_sigmoid

        # Intermediate tensors of the latest forward pass. They are only declared here
        # (for type checkers) and assigned in ``forward``.
        self.targets: torch.Tensor
        self.anti_targets: torch.Tensor
        self.xs_pos: torch.Tensor
        self.xs_neg: torch.Tensor
        self.asymmetric_w: torch.Tensor
        self.loss: torch.Tensor

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Compute the asymmetric loss.

        Args:
            x: model outputs. These are raw logits when ``apply_sigmoid`` is True;
                otherwise they are used directly as probabilities.
            y: multi-label binarized targets, combined element-wise with ``x``

        Returns:
            Scalar tensor: the negated sum of the (optionally focus-weighted)
            per-element log-likelihood terms.
        """
        self.targets = y
        self.anti_targets = 1 - y

        # Calculating probabilities
        self.xs_pos = torch.sigmoid(x) if self.apply_sigmoid else x
        self.xs_neg = 1.0 - self.xs_pos

        # Asymmetric clipping. In-place is safe: xs_neg is a fresh tensor, never the caller's input.
        if self.m is not None and self.m > 0:
            self.xs_neg.add_(self.m).clamp_(max=1)

        # Basic CE calculation
        self.loss = self.targets * torch.log(self.xs_pos.clamp(min=self.eps))
        self.loss.add_(self.anti_targets * torch.log(self.xs_neg.clamp(min=self.eps)))

        # Asymmetric focusing
        if self.gamma_neg > 0 or self.gamma_pos > 0:
            # Only ever *disable* tracking here, and let the context manager restore the
            # caller's grad mode on exit, even on exception. Re-enabling grad unconditionally
            # would turn autograd back on globally when called under ``torch.no_grad()``.
            track_grad = torch.is_grad_enabled() and not self.disable_torch_grad_focal_loss
            with torch.set_grad_enabled(track_grad):
                self.xs_pos = self.xs_pos * self.targets
                self.xs_neg = self.xs_neg * self.anti_targets
                self.asymmetric_w = torch.pow(
                    1 - self.xs_pos - self.xs_neg,
                    self.gamma_pos * self.targets + self.gamma_neg * self.anti_targets,
                )
            self.loss *= self.asymmetric_w

        return -self.loss.sum()