quadra 0.0.1__py3-none-any.whl → 2.1.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (302):
  1. hydra_plugins/quadra_searchpath_plugin.py +30 -0
  2. quadra/__init__.py +6 -0
  3. quadra/callbacks/__init__.py +0 -0
  4. quadra/callbacks/anomalib.py +289 -0
  5. quadra/callbacks/lightning.py +501 -0
  6. quadra/callbacks/mlflow.py +291 -0
  7. quadra/callbacks/scheduler.py +69 -0
  8. quadra/configs/__init__.py +0 -0
  9. quadra/configs/backbone/caformer_m36.yaml +8 -0
  10. quadra/configs/backbone/caformer_s36.yaml +8 -0
  11. quadra/configs/backbone/convnextv2_base.yaml +8 -0
  12. quadra/configs/backbone/convnextv2_femto.yaml +8 -0
  13. quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
  14. quadra/configs/backbone/dino_vitb8.yaml +12 -0
  15. quadra/configs/backbone/dino_vits8.yaml +12 -0
  16. quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
  17. quadra/configs/backbone/dinov2_vits14.yaml +12 -0
  18. quadra/configs/backbone/efficientnet_b0.yaml +8 -0
  19. quadra/configs/backbone/efficientnet_b1.yaml +8 -0
  20. quadra/configs/backbone/efficientnet_b2.yaml +8 -0
  21. quadra/configs/backbone/efficientnet_b3.yaml +8 -0
  22. quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
  23. quadra/configs/backbone/levit_128s.yaml +8 -0
  24. quadra/configs/backbone/mnasnet0_5.yaml +9 -0
  25. quadra/configs/backbone/resnet101.yaml +8 -0
  26. quadra/configs/backbone/resnet18.yaml +8 -0
  27. quadra/configs/backbone/resnet18_ssl.yaml +8 -0
  28. quadra/configs/backbone/resnet50.yaml +8 -0
  29. quadra/configs/backbone/smp.yaml +9 -0
  30. quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
  31. quadra/configs/backbone/unetr.yaml +15 -0
  32. quadra/configs/backbone/vit16_base.yaml +9 -0
  33. quadra/configs/backbone/vit16_small.yaml +9 -0
  34. quadra/configs/backbone/vit16_tiny.yaml +9 -0
  35. quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
  36. quadra/configs/callbacks/all.yaml +32 -0
  37. quadra/configs/callbacks/default.yaml +37 -0
  38. quadra/configs/callbacks/default_anomalib.yaml +67 -0
  39. quadra/configs/config.yaml +33 -0
  40. quadra/configs/core/default.yaml +11 -0
  41. quadra/configs/datamodule/base/anomaly.yaml +16 -0
  42. quadra/configs/datamodule/base/classification.yaml +21 -0
  43. quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
  44. quadra/configs/datamodule/base/segmentation.yaml +18 -0
  45. quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
  46. quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
  47. quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
  48. quadra/configs/datamodule/base/ssl.yaml +21 -0
  49. quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
  50. quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
  51. quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
  52. quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
  53. quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
  54. quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
  55. quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
  56. quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
  57. quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
  58. quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
  59. quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
  60. quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
  61. quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
  62. quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
  63. quadra/configs/experiment/base/classification/classification.yaml +73 -0
  64. quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
  65. quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
  66. quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
  67. quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
  68. quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
  69. quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
  70. quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
  71. quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
  72. quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
  73. quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
  74. quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
  75. quadra/configs/experiment/base/ssl/byol.yaml +43 -0
  76. quadra/configs/experiment/base/ssl/dino.yaml +46 -0
  77. quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
  78. quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
  79. quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
  80. quadra/configs/experiment/custom/cls.yaml +12 -0
  81. quadra/configs/experiment/default.yaml +15 -0
  82. quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
  83. quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
  84. quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
  85. quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
  86. quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
  87. quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
  88. quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
  89. quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
  90. quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
  91. quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
  92. quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
  93. quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
  94. quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
  95. quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
  96. quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
  97. quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
  98. quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
  99. quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
  100. quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
  101. quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
  102. quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
  103. quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
  104. quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
  105. quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
  106. quadra/configs/export/default.yaml +13 -0
  107. quadra/configs/hydra/anomaly_custom.yaml +15 -0
  108. quadra/configs/hydra/default.yaml +14 -0
  109. quadra/configs/inference/default.yaml +26 -0
  110. quadra/configs/logger/comet.yaml +10 -0
  111. quadra/configs/logger/csv.yaml +5 -0
  112. quadra/configs/logger/mlflow.yaml +12 -0
  113. quadra/configs/logger/tensorboard.yaml +8 -0
  114. quadra/configs/loss/asl.yaml +7 -0
  115. quadra/configs/loss/barlow.yaml +2 -0
  116. quadra/configs/loss/bce.yaml +1 -0
  117. quadra/configs/loss/byol.yaml +1 -0
  118. quadra/configs/loss/cross_entropy.yaml +1 -0
  119. quadra/configs/loss/dino.yaml +8 -0
  120. quadra/configs/loss/simclr.yaml +2 -0
  121. quadra/configs/loss/simsiam.yaml +1 -0
  122. quadra/configs/loss/smp_ce.yaml +3 -0
  123. quadra/configs/loss/smp_dice.yaml +2 -0
  124. quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
  125. quadra/configs/loss/smp_mcc.yaml +2 -0
  126. quadra/configs/loss/vicreg.yaml +5 -0
  127. quadra/configs/model/anomalib/cfa.yaml +35 -0
  128. quadra/configs/model/anomalib/cflow.yaml +30 -0
  129. quadra/configs/model/anomalib/csflow.yaml +34 -0
  130. quadra/configs/model/anomalib/dfm.yaml +19 -0
  131. quadra/configs/model/anomalib/draem.yaml +29 -0
  132. quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
  133. quadra/configs/model/anomalib/fastflow.yaml +32 -0
  134. quadra/configs/model/anomalib/padim.yaml +32 -0
  135. quadra/configs/model/anomalib/patchcore.yaml +36 -0
  136. quadra/configs/model/barlow.yaml +16 -0
  137. quadra/configs/model/byol.yaml +25 -0
  138. quadra/configs/model/classification.yaml +10 -0
  139. quadra/configs/model/dino.yaml +26 -0
  140. quadra/configs/model/logistic_regression.yaml +4 -0
  141. quadra/configs/model/multilabel_classification.yaml +9 -0
  142. quadra/configs/model/simclr.yaml +18 -0
  143. quadra/configs/model/simsiam.yaml +24 -0
  144. quadra/configs/model/smp.yaml +4 -0
  145. quadra/configs/model/smp_multiclass.yaml +4 -0
  146. quadra/configs/model/vicreg.yaml +16 -0
  147. quadra/configs/optimizer/adam.yaml +5 -0
  148. quadra/configs/optimizer/adamw.yaml +3 -0
  149. quadra/configs/optimizer/default.yaml +4 -0
  150. quadra/configs/optimizer/lars.yaml +8 -0
  151. quadra/configs/optimizer/sgd.yaml +4 -0
  152. quadra/configs/scheduler/default.yaml +5 -0
  153. quadra/configs/scheduler/rop.yaml +5 -0
  154. quadra/configs/scheduler/step.yaml +3 -0
  155. quadra/configs/scheduler/warmrestart.yaml +2 -0
  156. quadra/configs/scheduler/warmup.yaml +6 -0
  157. quadra/configs/task/anomalib/cfa.yaml +5 -0
  158. quadra/configs/task/anomalib/cflow.yaml +5 -0
  159. quadra/configs/task/anomalib/csflow.yaml +5 -0
  160. quadra/configs/task/anomalib/draem.yaml +5 -0
  161. quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
  162. quadra/configs/task/anomalib/fastflow.yaml +5 -0
  163. quadra/configs/task/anomalib/inference.yaml +3 -0
  164. quadra/configs/task/anomalib/padim.yaml +5 -0
  165. quadra/configs/task/anomalib/patchcore.yaml +5 -0
  166. quadra/configs/task/classification.yaml +6 -0
  167. quadra/configs/task/classification_evaluation.yaml +6 -0
  168. quadra/configs/task/default.yaml +1 -0
  169. quadra/configs/task/segmentation.yaml +9 -0
  170. quadra/configs/task/segmentation_evaluation.yaml +3 -0
  171. quadra/configs/task/sklearn_classification.yaml +13 -0
  172. quadra/configs/task/sklearn_classification_patch.yaml +11 -0
  173. quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
  174. quadra/configs/task/sklearn_classification_test.yaml +8 -0
  175. quadra/configs/task/ssl.yaml +2 -0
  176. quadra/configs/trainer/lightning_cpu.yaml +36 -0
  177. quadra/configs/trainer/lightning_gpu.yaml +35 -0
  178. quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
  179. quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
  180. quadra/configs/trainer/lightning_multigpu.yaml +37 -0
  181. quadra/configs/trainer/sklearn_classification.yaml +7 -0
  182. quadra/configs/transforms/byol.yaml +47 -0
  183. quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
  184. quadra/configs/transforms/default.yaml +37 -0
  185. quadra/configs/transforms/default_numpy.yaml +24 -0
  186. quadra/configs/transforms/default_resize.yaml +22 -0
  187. quadra/configs/transforms/dino.yaml +63 -0
  188. quadra/configs/transforms/linear_eval.yaml +18 -0
  189. quadra/datamodules/__init__.py +20 -0
  190. quadra/datamodules/anomaly.py +180 -0
  191. quadra/datamodules/base.py +375 -0
  192. quadra/datamodules/classification.py +1003 -0
  193. quadra/datamodules/generic/__init__.py +0 -0
  194. quadra/datamodules/generic/imagenette.py +144 -0
  195. quadra/datamodules/generic/mnist.py +81 -0
  196. quadra/datamodules/generic/mvtec.py +58 -0
  197. quadra/datamodules/generic/oxford_pet.py +163 -0
  198. quadra/datamodules/patch.py +190 -0
  199. quadra/datamodules/segmentation.py +742 -0
  200. quadra/datamodules/ssl.py +140 -0
  201. quadra/datasets/__init__.py +17 -0
  202. quadra/datasets/anomaly.py +287 -0
  203. quadra/datasets/classification.py +241 -0
  204. quadra/datasets/patch.py +138 -0
  205. quadra/datasets/segmentation.py +239 -0
  206. quadra/datasets/ssl.py +110 -0
  207. quadra/losses/__init__.py +0 -0
  208. quadra/losses/classification/__init__.py +6 -0
  209. quadra/losses/classification/asl.py +83 -0
  210. quadra/losses/classification/focal.py +320 -0
  211. quadra/losses/classification/prototypical.py +148 -0
  212. quadra/losses/ssl/__init__.py +17 -0
  213. quadra/losses/ssl/barlowtwins.py +47 -0
  214. quadra/losses/ssl/byol.py +37 -0
  215. quadra/losses/ssl/dino.py +129 -0
  216. quadra/losses/ssl/hyperspherical.py +45 -0
  217. quadra/losses/ssl/idmm.py +50 -0
  218. quadra/losses/ssl/simclr.py +67 -0
  219. quadra/losses/ssl/simsiam.py +30 -0
  220. quadra/losses/ssl/vicreg.py +76 -0
  221. quadra/main.py +46 -0
  222. quadra/metrics/__init__.py +3 -0
  223. quadra/metrics/segmentation.py +251 -0
  224. quadra/models/__init__.py +0 -0
  225. quadra/models/base.py +151 -0
  226. quadra/models/classification/__init__.py +8 -0
  227. quadra/models/classification/backbones.py +149 -0
  228. quadra/models/classification/base.py +92 -0
  229. quadra/models/evaluation.py +322 -0
  230. quadra/modules/__init__.py +0 -0
  231. quadra/modules/backbone.py +30 -0
  232. quadra/modules/base.py +312 -0
  233. quadra/modules/classification/__init__.py +3 -0
  234. quadra/modules/classification/base.py +331 -0
  235. quadra/modules/ssl/__init__.py +17 -0
  236. quadra/modules/ssl/barlowtwins.py +59 -0
  237. quadra/modules/ssl/byol.py +172 -0
  238. quadra/modules/ssl/common.py +285 -0
  239. quadra/modules/ssl/dino.py +186 -0
  240. quadra/modules/ssl/hyperspherical.py +206 -0
  241. quadra/modules/ssl/idmm.py +98 -0
  242. quadra/modules/ssl/simclr.py +73 -0
  243. quadra/modules/ssl/simsiam.py +68 -0
  244. quadra/modules/ssl/vicreg.py +67 -0
  245. quadra/optimizers/__init__.py +4 -0
  246. quadra/optimizers/lars.py +153 -0
  247. quadra/optimizers/sam.py +127 -0
  248. quadra/schedulers/__init__.py +3 -0
  249. quadra/schedulers/base.py +44 -0
  250. quadra/schedulers/warmup.py +127 -0
  251. quadra/tasks/__init__.py +24 -0
  252. quadra/tasks/anomaly.py +582 -0
  253. quadra/tasks/base.py +397 -0
  254. quadra/tasks/classification.py +1264 -0
  255. quadra/tasks/patch.py +492 -0
  256. quadra/tasks/segmentation.py +389 -0
  257. quadra/tasks/ssl.py +560 -0
  258. quadra/trainers/README.md +3 -0
  259. quadra/trainers/__init__.py +0 -0
  260. quadra/trainers/classification.py +179 -0
  261. quadra/utils/__init__.py +0 -0
  262. quadra/utils/anomaly.py +112 -0
  263. quadra/utils/classification.py +618 -0
  264. quadra/utils/deprecation.py +31 -0
  265. quadra/utils/evaluation.py +474 -0
  266. quadra/utils/export.py +579 -0
  267. quadra/utils/imaging.py +32 -0
  268. quadra/utils/logger.py +15 -0
  269. quadra/utils/mlflow.py +98 -0
  270. quadra/utils/model_manager.py +320 -0
  271. quadra/utils/models.py +524 -0
  272. quadra/utils/patch/__init__.py +15 -0
  273. quadra/utils/patch/dataset.py +1433 -0
  274. quadra/utils/patch/metrics.py +449 -0
  275. quadra/utils/patch/model.py +153 -0
  276. quadra/utils/patch/visualization.py +217 -0
  277. quadra/utils/resolver.py +42 -0
  278. quadra/utils/segmentation.py +31 -0
  279. quadra/utils/tests/__init__.py +0 -0
  280. quadra/utils/tests/fixtures/__init__.py +1 -0
  281. quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
  282. quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
  283. quadra/utils/tests/fixtures/dataset/classification.py +406 -0
  284. quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
  285. quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
  286. quadra/utils/tests/fixtures/models/__init__.py +3 -0
  287. quadra/utils/tests/fixtures/models/anomaly.py +89 -0
  288. quadra/utils/tests/fixtures/models/classification.py +45 -0
  289. quadra/utils/tests/fixtures/models/segmentation.py +33 -0
  290. quadra/utils/tests/helpers.py +70 -0
  291. quadra/utils/tests/models.py +27 -0
  292. quadra/utils/utils.py +525 -0
  293. quadra/utils/validator.py +115 -0
  294. quadra/utils/visualization.py +422 -0
  295. quadra/utils/vit_explainability.py +349 -0
  296. quadra-2.1.13.dist-info/LICENSE +201 -0
  297. quadra-2.1.13.dist-info/METADATA +386 -0
  298. quadra-2.1.13.dist-info/RECORD +300 -0
  299. {quadra-0.0.1.dist-info → quadra-2.1.13.dist-info}/WHEEL +1 -1
  300. quadra-2.1.13.dist-info/entry_points.txt +3 -0
  301. quadra-0.0.1.dist-info/METADATA +0 -14
  302. quadra-0.0.1.dist-info/RECORD +0 -4
@@ -0,0 +1,375 @@
1
+ from __future__ import annotations
2
+
3
+ import multiprocessing as mp
4
+ import multiprocessing.pool as mpp
5
+ import os
6
+ import pickle as pkl
7
+ import typing
8
+ from collections.abc import Callable, Iterable, Sequence
9
+ from functools import wraps
10
+ from typing import Any, Literal, Union, cast
11
+
12
+ import albumentations
13
+ import numpy as np
14
+ import pandas as pd
15
+ import torch
16
+ import xxhash
17
+ from pytorch_lightning import LightningDataModule
18
+ from tqdm import tqdm
19
+
20
+ from quadra.utils import utils
21
+
22
log = utils.get_logger(__name__)

# Train/val splits may be backed by one dataset or by a sequence of datasets.
_SingleOrManyDatasets = Union[torch.utils.data.Dataset, Sequence[torch.utils.data.Dataset]]
TrainDataset = _SingleOrManyDatasets
ValDataset = _SingleOrManyDatasets
# The test split is always a single dataset.
TestDataset = torch.utils.data.Dataset
26
+
27
+
28
def load_data_from_disk_dec(func):
    """Wrap a datamodule method so the on-disk checkpoint is restored before it runs.

    The first positional argument of the wrapped call is treated as the datamodule instance and its
    ``restore_checkpoint`` method is invoked before delegating to ``func``.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        """Restore the datamodule checkpoint, then call the wrapped method unchanged."""
        datamodule = cast(BaseDataModule, args[0])
        datamodule.restore_checkpoint()
        return func(*args, **kwargs)

    return wrapper
39
+
40
+
41
class DecorateParentMethod(type):
    """Metaclass that wraps selected methods defined in a class body with a decorator.

    Every class created with this metaclass (including subclasses) gets its own ``setup`` wrapped with
    ``load_data_from_disk_dec`` if that class body defines ``setup``.
    """

    def __new__(cls, name, bases, dct):
        """Apply the configured decorators to matching entries of the class namespace, then build the class."""
        decorators_to_apply = {
            "setup": load_data_from_disk_dec,
        }
        for attr_name, deco in decorators_to_apply.items():
            if attr_name not in dct:
                continue
            dct[attr_name] = deco(dct[attr_name])

        return super().__new__(cls, name, bases, dct)
54
+
55
+
56
def compute_file_content_hash(path: str, hash_size: Literal[32, 64, 128] = 64) -> str:
    """Get hash of a file based on its content.

    The file is read in fixed-size chunks and fed to an incremental xxhash state (seed 42), so memory
    usage stays bounded even for very large files. The digest equals the one obtained by hashing the
    whole content in one shot.

    Args:
        path: Path to the file.
        hash_size: Size of the hash. Must be one of [32, 64, 128].

    Returns:
        The hex digest of the file content.

    Raises:
        ValueError: If ``hash_size`` is not one of [32, 64, 128].
    """
    algorithm = None
    for bits, name in ((32, "xxh32"), (64, "xxh64"), (128, "xxh128")):
        if hash_size == bits:
            algorithm = name
            break
    # Validate before touching the filesystem so a bad argument is reported as such.
    if algorithm is None:
        raise ValueError(f"Invalid hash size {hash_size}. Must be one of [32, 64, 128].")

    # Resolve the constructor lazily so an xxhash build lacking e.g. xxh128 only fails when it is requested.
    hasher = getattr(xxhash, algorithm)(seed=42)
    chunk_size = 1 << 20  # 1 MiB per read
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            hasher.update(chunk)

    return hasher.hexdigest()
79
+
80
+
81
def compute_file_size_hash(path: str, hash_size: Literal[32, 64, 128] = 64) -> str:
    """Hash a file using only its size in bytes (fast, but collides for equally sized files).

    Args:
        path: Path to the file.
        hash_size: Size of the hash. Must be one of [32, 64, 128].

    Returns:
        The hex digest (seed 42) of the decimal string representation of the file size.

    Raises:
        ValueError: If ``hash_size`` is not one of [32, 64, 128].
    """
    size_as_text = str(os.path.getsize(path))

    for bits, algorithm in ((32, "xxh32"), (64, "xxh64"), (128, "xxh128")):
        if hash_size == bits:
            return getattr(xxhash, algorithm)(size_as_text, seed=42).hexdigest()

    raise ValueError(f"Invalid hash size {hash_size}. Must be one of [32, 64, 128].")
103
+
104
+
105
@typing.no_type_check
def istarmap(self, func: Callable, iterable: Iterable, chunksize: int = 1):
    # pylint: disable=all
    """Starmap-version of imap.

    Lazily applies ``func(*args)`` to every ``args`` tuple of ``iterable`` on the pool's workers. Results are
    collected through ``multiprocessing.pool.IMapIterator``, the same ordered iterator used by ``Pool.imap``.

    NOTE(review): this relies on private ``multiprocessing.pool`` internals (``_check_running``,
    ``Pool._get_tasks``, ``_taskqueue``, ``_guarded_task_generation``, ``IMapIterator``, ``starmapstar``),
    so it may break across Python versions.

    Args:
        self: The ``multiprocessing.pool.Pool`` instance (this function is attached to ``Pool`` below).
        func: Callable invoked as ``func(*args)`` for each item.
        iterable: Iterable of argument tuples.
        chunksize: Number of items sent to a worker per task; must be at least 1.

    Returns:
        A generator yielding one result per input item.

    Raises:
        ValueError: If ``chunksize`` is smaller than 1.
    """
    self._check_running()
    if chunksize < 1:
        raise ValueError(f"Chunksize must be 1+, not {chunksize:n}")

    # Split the input into chunks of `chunksize` items; each chunk is executed by `starmapstar` in a worker.
    task_batches = mpp.Pool._get_tasks(func, iterable, chunksize)
    result = mpp.IMapIterator(self)
    self._taskqueue.put((self._guarded_task_generation(result._job, mpp.starmapstar, task_batches), result._set_length))
    # Each item produced by the iterator is a list of results for one chunk; flatten them.
    return (item for chunk in result for item in chunk)


# Patch Pool class to include istarmap.
# Side effect at import time: every multiprocessing Pool in the process gains an `istarmap` method.
mpp.Pool.istarmap = istarmap  # type: ignore[attr-defined]
121
+
122
+
123
class BaseDataModule(LightningDataModule, metaclass=DecorateParentMethod):
    """Base class for all data modules.

    Args:
        data_path: Path to the data main folder.
        name: The name for the data module. Defaults to "base_datamodule".
        num_workers: Number of workers for dataloaders. Defaults to 16.
        batch_size: Batch size. Defaults to 32.
        seed: Random generator seed. Defaults to 42.
        load_aug_images: Flag stored on the instance for subclasses; not used by this base class.
            Defaults to False.
        aug_name: Stored on the instance for subclasses; not used by this base class. Defaults to None.
        n_aug_to_take: Number of augmented variants appended per sample by ``load_augmented_samples``.
            Defaults to None.
        replace_str_from: Stored on the instance for subclasses; not used directly by this base class.
            Defaults to None.
        replace_str_to: Stored on the instance for subclasses; not used directly by this base class.
            Defaults to None.
        train_transform: Transformations for train dataset.
            Defaults to None.
        val_transform: Transformations for validation dataset.
            Defaults to None.
        test_transform: Transformations for test dataset.
            Defaults to None.
        enable_hashing: Whether to enable hashing of images. Defaults to True.
        hash_size: Size of the hash. Must be one of [32, 64, 128]. Defaults to 64.
        hash_type: Type of hash to use, if content hash is used, the hash is computed on the file content, otherwise
            the hash is computed on the file size which is faster but less safe. Defaults to "content".

    Raises:
        ValueError: If ``hash_size`` is not one of [32, 64, 128].
    """

    def __init__(
        self,
        data_path: str,
        name: str = "base_datamodule",
        num_workers: int = 16,
        batch_size: int = 32,
        seed: int = 42,
        load_aug_images: bool = False,
        aug_name: str | None = None,
        n_aug_to_take: int | None = None,
        replace_str_from: str | None = None,
        replace_str_to: str | None = None,
        train_transform: albumentations.Compose | None = None,
        val_transform: albumentations.Compose | None = None,
        test_transform: albumentations.Compose | None = None,
        enable_hashing: bool = True,
        hash_size: Literal[32, 64, 128] = 64,
        hash_type: Literal["content", "size"] = "content",
    ):
        super().__init__()
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.seed = seed
        self.data_path = data_path
        self.name = name
        self.train_transform = train_transform
        self.val_transform = val_transform
        self.test_transform = test_transform
        self.enable_hashing = enable_hashing
        self.hash_size = hash_size
        self.hash_type = hash_type

        if self.hash_size not in [32, 64, 128]:
            raise ValueError(f"Invalid hash size {self.hash_size}. Must be one of [32, 64, 128].")

        self.load_aug_images = load_aug_images
        self.aug_name = aug_name
        self.n_aug_to_take = n_aug_to_take
        self.replace_str_from = replace_str_from
        self.replace_str_to = replace_str_to
        self.extra_args: dict[str, Any] = {}
        # Annotation-only declarations: these attributes do not exist until a subclass/checkpoint sets them,
        # which is what the `hasattr(self, "data")` checks below rely on.
        self.train_dataset: TrainDataset
        self.val_dataset: ValDataset
        self.test_dataset: TestDataset
        self.data: pd.DataFrame
        self.data_folder = "data"
        # Relative path: the folder is created in the current working directory at construction time.
        os.makedirs(self.data_folder, exist_ok=True)
        self.datamodule_checkpoint_file = os.path.join(self.data_folder, "datamodule.pkl")
        self.dataset_file = os.path.join(self.data_folder, "dataset.csv")

    @property
    def train_data(self) -> pd.DataFrame:
        """Get train data (rows of ``data`` whose ``split`` column equals "train")."""
        if not hasattr(self, "data"):
            raise ValueError("`data` attribute is not set. Cannot load train data.")
        return self.data[self.data["split"] == "train"]

    @property
    def val_data(self) -> pd.DataFrame:
        """Get validation data (rows of ``data`` whose ``split`` column equals "val")."""
        if not hasattr(self, "data"):
            raise ValueError("`data` attribute is not set. Cannot load val data.")
        return self.data[self.data["split"] == "val"]

    @property
    def test_data(self) -> pd.DataFrame:
        """Get test data (rows of ``data`` whose ``split`` column equals "test")."""
        if not hasattr(self, "data"):
            raise ValueError("`data` attribute is not set. Cannot load test data.")
        return self.data[self.data["split"] == "test"]

    def _dataset_available(self, dataset_name: str) -> bool:
        """Checks if the dataset is available.

        A single dataset is available when it is set and non-empty. A list/tuple of datasets is available
        when it contains at least one dataset and every dataset in it is non-empty.

        Args:
            dataset_name : Name of the dataset attribute.

        Returns:
            True if the dataset is available, False otherwise.
        """
        dataset_attr = getattr(self, dataset_name, None)
        if dataset_attr is None:
            return False
        if isinstance(dataset_attr, (list, tuple)):
            # `all([])` is True, so an empty collection must be rejected explicitly.
            return len(dataset_attr) > 0 and all(len(d) > 0 for d in dataset_attr)
        return len(dataset_attr) > 0

    @property
    def train_dataset_available(self) -> bool:
        """Checks if the train dataset is available."""
        return self._dataset_available("train_dataset")

    @property
    def val_dataset_available(self) -> bool:
        """Checks if the validation dataset is available."""
        return self._dataset_available("val_dataset")

    @property
    def test_dataset_available(self) -> bool:
        """Checks if the test dataset is available."""
        return self._dataset_available("test_dataset")

    def _prepare_data(self) -> None:
        """Prepares the data, this should have exactly the same logic as the prepare_data method
        of a LightningModule.
        """
        raise NotImplementedError(
            "This method must be implemented, it should contain all the logic that normally is "
            "contained in the prepare_data method of a LightningModule."
        )

    def hash_data(self) -> None:
        """Computes the hash of the files inside the datasets.

        Fills the ``hash`` column with one digest per entry of the ``samples`` column (computed in a process
        pool) and the ``hash_type`` column with ``self.hash_type``. Does nothing when hashing is disabled.
        """
        if not self.enable_hashing:
            return

        hash_fn = compute_file_content_hash if self.hash_type == "content" else compute_file_size_hash
        # TODO: We need to find a way to annotate the columns of data.
        paths_and_hash_length = zip(self.data["samples"], [self.hash_size] * len(self.data))
        # Leave one core free and cap at 8 workers, but never go below 1: Pool(0) raises ValueError
        # on single-CPU machines.
        n_processes = max(1, min(8, mp.cpu_count() - 1))

        with mp.Pool(n_processes) as pool:
            self.data["hash"] = list(
                tqdm(
                    pool.istarmap(  # type: ignore[attr-defined]
                        hash_fn,
                        paths_and_hash_length,
                    ),
                    total=len(self.data),
                    desc="Computing hashes",
                )
            )

        self.data["hash_type"] = self.hash_type

    def prepare_data(self) -> None:
        """Prepares the data, should be overridden by subclasses.

        Skipped when ``data`` is already set; otherwise runs ``_prepare_data``, ``hash_data`` and
        ``save_checkpoint`` in that order.
        """
        if hasattr(self, "data"):
            return

        self._prepare_data()
        self.hash_data()
        self.save_checkpoint()

    def __getstate__(self) -> dict[str, Any]:
        """This method is called when pickling the object.
        It's useful to remove attributes that shouldn't be pickled.
        """
        state = self.__dict__.copy()
        if "trainer" in state:
            # Lightning injects the trainer in the datamodule, we don't want to pickle it.
            del state["trainer"]

        return state

    def save_checkpoint(self) -> None:
        """Saves the datamodule to disk, utility function that is called from prepare_data. We are required to save
        datamodule to disk because we can't assign attributes to the datamodule in prepare_data when working with
        multiple gpus.

        Files are written only when neither the pickle checkpoint nor the dataset CSV exists yet.
        NOTE(review): if exactly one of the two files exists, nothing is written; confirm this is intended.
        """
        if not os.path.exists(self.datamodule_checkpoint_file) and not os.path.exists(self.dataset_file):
            with open(self.datamodule_checkpoint_file, "wb") as f:
                pkl.dump(self, f)

            self.data.to_csv(self.dataset_file, index=False)
            log.info("Datamodule checkpoint saved to disk.")

            # `.iloc[0]` below needs at least one row; an empty table has nothing to summarise.
            if "targets" in self.data and not self.data.empty:
                if isinstance(self.data["targets"].iloc[0], np.ndarray):
                    # If we find a numpy array target it's very likely one hot encoded,
                    # in that case we just print the number of train/val/test samples
                    grouping = ["split"]
                else:
                    grouping = ["split", "targets"]
                log.info("Dataset Info:")
                split_order = {"train": 0, "val": 1, "test": 2}
                log.info(
                    "\n%s",
                    self.data.groupby(grouping)
                    .size()
                    .to_frame()
                    .reset_index()
                    .sort_values(by=["split"], key=lambda x: x.map(split_order))
                    .rename(columns={0: "count"})
                    .to_string(index=False),
                )

    def restore_checkpoint(self) -> None:
        """Loads the data from disk, utility function that should be called from setup.

        Copies every attribute of the pickled datamodule onto ``self``. Does nothing if ``data`` is already set.

        Raises:
            ValueError: If the checkpoint file does not exist.
        """
        if hasattr(self, "data"):
            return

        if not os.path.isfile(self.datamodule_checkpoint_file):
            raise ValueError(f"Dataset file {self.datamodule_checkpoint_file} does not exist.")

        # SECURITY NOTE: pickle can execute arbitrary code; this file is expected to be written by
        # `save_checkpoint` in the same working directory. Never point it at untrusted input.
        with open(self.datamodule_checkpoint_file, "rb") as f:
            checkpoint_datamodule = pkl.load(f)
        for key, value in checkpoint_datamodule.__dict__.items():
            setattr(self, key, value)

    # TODO: Check if this function can be removed
    def load_augmented_samples(
        self,
        samples: list[str],
        targets: list[Any],
        replace_str_from: str | None = None,
        replace_str_to: str | None = None,
        shuffle: bool = False,
    ) -> tuple[list[str], list[Any]]:
        """Loads augmented samples.

        For each sample, keeps the original path and appends ``n_aug_to_take`` derived paths of the form
        ``<base>_<k><ext>`` (k starting at 1), where ``<base>``/``<ext>`` come from the sample path after the
        optional ``replace_str_from`` -> ``replace_str_to`` substitution. Each derived path reuses the original label.

        Args:
            samples: Sample paths.
            targets: Labels aligned with ``samples``.
            replace_str_from: Substring to replace in the path before deriving augmented names.
            replace_str_to: Replacement string; the substitution happens only if both are given.
            shuffle: Whether to shuffle samples and targets jointly using ``np.random``.

        Returns:
            The expanded (and optionally shuffled) samples and targets. Targets keep their original objects.

        Raises:
            ValueError: If ``n_aug_to_take`` is not set.
        """
        if self.n_aug_to_take is None:
            raise ValueError("`n_aug_to_take` is not set. Cannot load augmented samples.")
        aug_samples: list[str] = []
        aug_labels: list[Any] = []
        for sample, label in zip(samples, targets):
            aug_samples.append(sample)
            aug_labels.append(label)
            final_sample = sample
            if replace_str_from is not None and replace_str_to is not None:
                final_sample = final_sample.replace(replace_str_from, replace_str_to)
            base, ext = os.path.splitext(final_sample)
            for k in range(self.n_aug_to_take):
                aug_samples.append(base + "_" + str(k + 1) + ext)
                aug_labels.append(label)
        if shuffle:
            order = np.arange(len(aug_samples))
            np.random.shuffle(order)
            # Reorder with plain list indexing: round-tripping through np.array would coerce mixed-type
            # labels to strings and turn array labels into nested lists.
            aug_samples = [aug_samples[i] for i in order]
            aug_labels = [aug_labels[i] for i in order]
        return aug_samples, aug_labels