quadra 0.0.1__py3-none-any.whl → 2.1.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (302)
  1. hydra_plugins/quadra_searchpath_plugin.py +30 -0
  2. quadra/__init__.py +6 -0
  3. quadra/callbacks/__init__.py +0 -0
  4. quadra/callbacks/anomalib.py +289 -0
  5. quadra/callbacks/lightning.py +501 -0
  6. quadra/callbacks/mlflow.py +291 -0
  7. quadra/callbacks/scheduler.py +69 -0
  8. quadra/configs/__init__.py +0 -0
  9. quadra/configs/backbone/caformer_m36.yaml +8 -0
  10. quadra/configs/backbone/caformer_s36.yaml +8 -0
  11. quadra/configs/backbone/convnextv2_base.yaml +8 -0
  12. quadra/configs/backbone/convnextv2_femto.yaml +8 -0
  13. quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
  14. quadra/configs/backbone/dino_vitb8.yaml +12 -0
  15. quadra/configs/backbone/dino_vits8.yaml +12 -0
  16. quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
  17. quadra/configs/backbone/dinov2_vits14.yaml +12 -0
  18. quadra/configs/backbone/efficientnet_b0.yaml +8 -0
  19. quadra/configs/backbone/efficientnet_b1.yaml +8 -0
  20. quadra/configs/backbone/efficientnet_b2.yaml +8 -0
  21. quadra/configs/backbone/efficientnet_b3.yaml +8 -0
  22. quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
  23. quadra/configs/backbone/levit_128s.yaml +8 -0
  24. quadra/configs/backbone/mnasnet0_5.yaml +9 -0
  25. quadra/configs/backbone/resnet101.yaml +8 -0
  26. quadra/configs/backbone/resnet18.yaml +8 -0
  27. quadra/configs/backbone/resnet18_ssl.yaml +8 -0
  28. quadra/configs/backbone/resnet50.yaml +8 -0
  29. quadra/configs/backbone/smp.yaml +9 -0
  30. quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
  31. quadra/configs/backbone/unetr.yaml +15 -0
  32. quadra/configs/backbone/vit16_base.yaml +9 -0
  33. quadra/configs/backbone/vit16_small.yaml +9 -0
  34. quadra/configs/backbone/vit16_tiny.yaml +9 -0
  35. quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
  36. quadra/configs/callbacks/all.yaml +32 -0
  37. quadra/configs/callbacks/default.yaml +37 -0
  38. quadra/configs/callbacks/default_anomalib.yaml +67 -0
  39. quadra/configs/config.yaml +33 -0
  40. quadra/configs/core/default.yaml +11 -0
  41. quadra/configs/datamodule/base/anomaly.yaml +16 -0
  42. quadra/configs/datamodule/base/classification.yaml +21 -0
  43. quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
  44. quadra/configs/datamodule/base/segmentation.yaml +18 -0
  45. quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
  46. quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
  47. quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
  48. quadra/configs/datamodule/base/ssl.yaml +21 -0
  49. quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
  50. quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
  51. quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
  52. quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
  53. quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
  54. quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
  55. quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
  56. quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
  57. quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
  58. quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
  59. quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
  60. quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
  61. quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
  62. quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
  63. quadra/configs/experiment/base/classification/classification.yaml +73 -0
  64. quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
  65. quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
  66. quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
  67. quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
  68. quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
  69. quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
  70. quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
  71. quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
  72. quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
  73. quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
  74. quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
  75. quadra/configs/experiment/base/ssl/byol.yaml +43 -0
  76. quadra/configs/experiment/base/ssl/dino.yaml +46 -0
  77. quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
  78. quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
  79. quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
  80. quadra/configs/experiment/custom/cls.yaml +12 -0
  81. quadra/configs/experiment/default.yaml +15 -0
  82. quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
  83. quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
  84. quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
  85. quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
  86. quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
  87. quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
  88. quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
  89. quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
  90. quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
  91. quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
  92. quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
  93. quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
  94. quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
  95. quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
  96. quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
  97. quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
  98. quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
  99. quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
  100. quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
  101. quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
  102. quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
  103. quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
  104. quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
  105. quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
  106. quadra/configs/export/default.yaml +13 -0
  107. quadra/configs/hydra/anomaly_custom.yaml +15 -0
  108. quadra/configs/hydra/default.yaml +14 -0
  109. quadra/configs/inference/default.yaml +26 -0
  110. quadra/configs/logger/comet.yaml +10 -0
  111. quadra/configs/logger/csv.yaml +5 -0
  112. quadra/configs/logger/mlflow.yaml +12 -0
  113. quadra/configs/logger/tensorboard.yaml +8 -0
  114. quadra/configs/loss/asl.yaml +7 -0
  115. quadra/configs/loss/barlow.yaml +2 -0
  116. quadra/configs/loss/bce.yaml +1 -0
  117. quadra/configs/loss/byol.yaml +1 -0
  118. quadra/configs/loss/cross_entropy.yaml +1 -0
  119. quadra/configs/loss/dino.yaml +8 -0
  120. quadra/configs/loss/simclr.yaml +2 -0
  121. quadra/configs/loss/simsiam.yaml +1 -0
  122. quadra/configs/loss/smp_ce.yaml +3 -0
  123. quadra/configs/loss/smp_dice.yaml +2 -0
  124. quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
  125. quadra/configs/loss/smp_mcc.yaml +2 -0
  126. quadra/configs/loss/vicreg.yaml +5 -0
  127. quadra/configs/model/anomalib/cfa.yaml +35 -0
  128. quadra/configs/model/anomalib/cflow.yaml +30 -0
  129. quadra/configs/model/anomalib/csflow.yaml +34 -0
  130. quadra/configs/model/anomalib/dfm.yaml +19 -0
  131. quadra/configs/model/anomalib/draem.yaml +29 -0
  132. quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
  133. quadra/configs/model/anomalib/fastflow.yaml +32 -0
  134. quadra/configs/model/anomalib/padim.yaml +32 -0
  135. quadra/configs/model/anomalib/patchcore.yaml +36 -0
  136. quadra/configs/model/barlow.yaml +16 -0
  137. quadra/configs/model/byol.yaml +25 -0
  138. quadra/configs/model/classification.yaml +10 -0
  139. quadra/configs/model/dino.yaml +26 -0
  140. quadra/configs/model/logistic_regression.yaml +4 -0
  141. quadra/configs/model/multilabel_classification.yaml +9 -0
  142. quadra/configs/model/simclr.yaml +18 -0
  143. quadra/configs/model/simsiam.yaml +24 -0
  144. quadra/configs/model/smp.yaml +4 -0
  145. quadra/configs/model/smp_multiclass.yaml +4 -0
  146. quadra/configs/model/vicreg.yaml +16 -0
  147. quadra/configs/optimizer/adam.yaml +5 -0
  148. quadra/configs/optimizer/adamw.yaml +3 -0
  149. quadra/configs/optimizer/default.yaml +4 -0
  150. quadra/configs/optimizer/lars.yaml +8 -0
  151. quadra/configs/optimizer/sgd.yaml +4 -0
  152. quadra/configs/scheduler/default.yaml +5 -0
  153. quadra/configs/scheduler/rop.yaml +5 -0
  154. quadra/configs/scheduler/step.yaml +3 -0
  155. quadra/configs/scheduler/warmrestart.yaml +2 -0
  156. quadra/configs/scheduler/warmup.yaml +6 -0
  157. quadra/configs/task/anomalib/cfa.yaml +5 -0
  158. quadra/configs/task/anomalib/cflow.yaml +5 -0
  159. quadra/configs/task/anomalib/csflow.yaml +5 -0
  160. quadra/configs/task/anomalib/draem.yaml +5 -0
  161. quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
  162. quadra/configs/task/anomalib/fastflow.yaml +5 -0
  163. quadra/configs/task/anomalib/inference.yaml +3 -0
  164. quadra/configs/task/anomalib/padim.yaml +5 -0
  165. quadra/configs/task/anomalib/patchcore.yaml +5 -0
  166. quadra/configs/task/classification.yaml +6 -0
  167. quadra/configs/task/classification_evaluation.yaml +6 -0
  168. quadra/configs/task/default.yaml +1 -0
  169. quadra/configs/task/segmentation.yaml +9 -0
  170. quadra/configs/task/segmentation_evaluation.yaml +3 -0
  171. quadra/configs/task/sklearn_classification.yaml +13 -0
  172. quadra/configs/task/sklearn_classification_patch.yaml +11 -0
  173. quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
  174. quadra/configs/task/sklearn_classification_test.yaml +8 -0
  175. quadra/configs/task/ssl.yaml +2 -0
  176. quadra/configs/trainer/lightning_cpu.yaml +36 -0
  177. quadra/configs/trainer/lightning_gpu.yaml +35 -0
  178. quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
  179. quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
  180. quadra/configs/trainer/lightning_multigpu.yaml +37 -0
  181. quadra/configs/trainer/sklearn_classification.yaml +7 -0
  182. quadra/configs/transforms/byol.yaml +47 -0
  183. quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
  184. quadra/configs/transforms/default.yaml +37 -0
  185. quadra/configs/transforms/default_numpy.yaml +24 -0
  186. quadra/configs/transforms/default_resize.yaml +22 -0
  187. quadra/configs/transforms/dino.yaml +63 -0
  188. quadra/configs/transforms/linear_eval.yaml +18 -0
  189. quadra/datamodules/__init__.py +20 -0
  190. quadra/datamodules/anomaly.py +180 -0
  191. quadra/datamodules/base.py +375 -0
  192. quadra/datamodules/classification.py +1003 -0
  193. quadra/datamodules/generic/__init__.py +0 -0
  194. quadra/datamodules/generic/imagenette.py +144 -0
  195. quadra/datamodules/generic/mnist.py +81 -0
  196. quadra/datamodules/generic/mvtec.py +58 -0
  197. quadra/datamodules/generic/oxford_pet.py +163 -0
  198. quadra/datamodules/patch.py +190 -0
  199. quadra/datamodules/segmentation.py +742 -0
  200. quadra/datamodules/ssl.py +140 -0
  201. quadra/datasets/__init__.py +17 -0
  202. quadra/datasets/anomaly.py +287 -0
  203. quadra/datasets/classification.py +241 -0
  204. quadra/datasets/patch.py +138 -0
  205. quadra/datasets/segmentation.py +239 -0
  206. quadra/datasets/ssl.py +110 -0
  207. quadra/losses/__init__.py +0 -0
  208. quadra/losses/classification/__init__.py +6 -0
  209. quadra/losses/classification/asl.py +83 -0
  210. quadra/losses/classification/focal.py +320 -0
  211. quadra/losses/classification/prototypical.py +148 -0
  212. quadra/losses/ssl/__init__.py +17 -0
  213. quadra/losses/ssl/barlowtwins.py +47 -0
  214. quadra/losses/ssl/byol.py +37 -0
  215. quadra/losses/ssl/dino.py +129 -0
  216. quadra/losses/ssl/hyperspherical.py +45 -0
  217. quadra/losses/ssl/idmm.py +50 -0
  218. quadra/losses/ssl/simclr.py +67 -0
  219. quadra/losses/ssl/simsiam.py +30 -0
  220. quadra/losses/ssl/vicreg.py +76 -0
  221. quadra/main.py +46 -0
  222. quadra/metrics/__init__.py +3 -0
  223. quadra/metrics/segmentation.py +251 -0
  224. quadra/models/__init__.py +0 -0
  225. quadra/models/base.py +151 -0
  226. quadra/models/classification/__init__.py +8 -0
  227. quadra/models/classification/backbones.py +149 -0
  228. quadra/models/classification/base.py +92 -0
  229. quadra/models/evaluation.py +322 -0
  230. quadra/modules/__init__.py +0 -0
  231. quadra/modules/backbone.py +30 -0
  232. quadra/modules/base.py +312 -0
  233. quadra/modules/classification/__init__.py +3 -0
  234. quadra/modules/classification/base.py +331 -0
  235. quadra/modules/ssl/__init__.py +17 -0
  236. quadra/modules/ssl/barlowtwins.py +59 -0
  237. quadra/modules/ssl/byol.py +172 -0
  238. quadra/modules/ssl/common.py +285 -0
  239. quadra/modules/ssl/dino.py +186 -0
  240. quadra/modules/ssl/hyperspherical.py +206 -0
  241. quadra/modules/ssl/idmm.py +98 -0
  242. quadra/modules/ssl/simclr.py +73 -0
  243. quadra/modules/ssl/simsiam.py +68 -0
  244. quadra/modules/ssl/vicreg.py +67 -0
  245. quadra/optimizers/__init__.py +4 -0
  246. quadra/optimizers/lars.py +153 -0
  247. quadra/optimizers/sam.py +127 -0
  248. quadra/schedulers/__init__.py +3 -0
  249. quadra/schedulers/base.py +44 -0
  250. quadra/schedulers/warmup.py +127 -0
  251. quadra/tasks/__init__.py +24 -0
  252. quadra/tasks/anomaly.py +582 -0
  253. quadra/tasks/base.py +397 -0
  254. quadra/tasks/classification.py +1264 -0
  255. quadra/tasks/patch.py +492 -0
  256. quadra/tasks/segmentation.py +389 -0
  257. quadra/tasks/ssl.py +560 -0
  258. quadra/trainers/README.md +3 -0
  259. quadra/trainers/__init__.py +0 -0
  260. quadra/trainers/classification.py +179 -0
  261. quadra/utils/__init__.py +0 -0
  262. quadra/utils/anomaly.py +112 -0
  263. quadra/utils/classification.py +618 -0
  264. quadra/utils/deprecation.py +31 -0
  265. quadra/utils/evaluation.py +474 -0
  266. quadra/utils/export.py +579 -0
  267. quadra/utils/imaging.py +32 -0
  268. quadra/utils/logger.py +15 -0
  269. quadra/utils/mlflow.py +98 -0
  270. quadra/utils/model_manager.py +320 -0
  271. quadra/utils/models.py +524 -0
  272. quadra/utils/patch/__init__.py +15 -0
  273. quadra/utils/patch/dataset.py +1433 -0
  274. quadra/utils/patch/metrics.py +449 -0
  275. quadra/utils/patch/model.py +153 -0
  276. quadra/utils/patch/visualization.py +217 -0
  277. quadra/utils/resolver.py +42 -0
  278. quadra/utils/segmentation.py +31 -0
  279. quadra/utils/tests/__init__.py +0 -0
  280. quadra/utils/tests/fixtures/__init__.py +1 -0
  281. quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
  282. quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
  283. quadra/utils/tests/fixtures/dataset/classification.py +406 -0
  284. quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
  285. quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
  286. quadra/utils/tests/fixtures/models/__init__.py +3 -0
  287. quadra/utils/tests/fixtures/models/anomaly.py +89 -0
  288. quadra/utils/tests/fixtures/models/classification.py +45 -0
  289. quadra/utils/tests/fixtures/models/segmentation.py +33 -0
  290. quadra/utils/tests/helpers.py +70 -0
  291. quadra/utils/tests/models.py +27 -0
  292. quadra/utils/utils.py +525 -0
  293. quadra/utils/validator.py +115 -0
  294. quadra/utils/visualization.py +422 -0
  295. quadra/utils/vit_explainability.py +349 -0
  296. quadra-2.1.13.dist-info/LICENSE +201 -0
  297. quadra-2.1.13.dist-info/METADATA +386 -0
  298. quadra-2.1.13.dist-info/RECORD +300 -0
  299. {quadra-0.0.1.dist-info → quadra-2.1.13.dist-info}/WHEEL +1 -1
  300. quadra-2.1.13.dist-info/entry_points.txt +3 -0
  301. quadra-0.0.1.dist-info/METADATA +0 -14
  302. quadra-0.0.1.dist-info/RECORD +0 -4
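The diff hunk below shows the content of the new file quadra/datamodules/classification.py (+1003 lines):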
@@ -0,0 +1,1003 @@
1
+ # pylint: disable=unsupported-assignment-operation,unsubscriptable-object
2
+ from __future__ import annotations
3
+
4
+ import os
5
+ import random
6
+ from collections.abc import Callable
7
+ from typing import Any
8
+
9
+ import albumentations
10
+ import numpy as np
11
+ import pandas as pd
12
+ import torch
13
+ from sklearn.model_selection import train_test_split
14
+ from skmultilearn.model_selection import iterative_train_test_split
15
+ from timm.data.readers.reader_image_folder import find_images_and_targets
16
+ from torch.utils.data import DataLoader
17
+
18
+ from quadra.datamodules.base import BaseDataModule
19
+ from quadra.datasets import ImageClassificationListDataset
20
+ from quadra.datasets.classification import MultilabelClassificationDataset
21
+ from quadra.utils import utils
22
+ from quadra.utils.classification import find_test_image, get_split, group_labels, natural_key
23
+
24
+ log = utils.get_logger(__name__)
25
+
26
+
27
+ class ClassificationDataModule(BaseDataModule):
28
+ """Base class single folder based classification datamodules. If there is no nested folders, use this class.
29
+
30
+ Args:
31
+ data_path: Path to the data main folder.
32
+ name: The name for the data module. Defaults to "classification_datamodule".
33
+ num_workers: Number of workers for dataloaders. Defaults to 8.
34
+ batch_size: Batch size. Defaults to 32.
35
+ seed: Random generator seed. Defaults to 42.
36
+ dataset: Dataset class.
37
+ val_size: The validation split. Defaults to 0.2.
38
+ test_size: The test split. Defaults to 0.2.
39
+ exclude_filter: Filter to exclude samples whose path contains any of these strings. Defaults to None.
40
+ include_filter: Filter to keep only samples whose path contains at least one of these strings. Defaults to None.
41
+ label_map: The mapping for labels. Defaults to None.
42
+ num_data_class: The number of samples per class. Defaults to None.
43
+ train_transform: Transformations for train dataset.
44
+ Defaults to None.
45
+ val_transform: Transformations for validation dataset.
46
+ Defaults to None.
47
+ test_transform: Transformations for test dataset.
48
+ Defaults to None.
49
+ train_split_file: The file with train split. Defaults to None.
50
+ val_split_file: The file with validation split. Defaults to None.
51
+ test_split_file: The file with test split. Defaults to None.
52
+ class_to_idx: The mapping from class name to index. Defaults to None.
53
+ **kwargs: Additional arguments for BaseDataModule.
54
+ """
55
+
56
+ def __init__(
57
+ self,
58
+ data_path: str,
59
+ dataset: type[ImageClassificationListDataset] = ImageClassificationListDataset,
60
+ name: str = "classification_datamodule",
61
+ num_workers: int = 8,
62
+ batch_size: int = 32,
63
+ seed: int = 42,
64
+ val_size: float | None = 0.2,
65
+ test_size: float = 0.2,
66
+ num_data_class: int | None = None,
67
+ exclude_filter: list[str] | None = None,
68
+ include_filter: list[str] | None = None,
69
+ label_map: dict[str, Any] | None = None,
70
+ load_aug_images: bool = False,
71
+ aug_name: str | None = None,
72
+ n_aug_to_take: int | None = 4,
73
+ replace_str_from: str | None = None,
74
+ replace_str_to: str | None = None,
75
+ train_transform: albumentations.Compose | None = None,
76
+ val_transform: albumentations.Compose | None = None,
77
+ test_transform: albumentations.Compose | None = None,
78
+ train_split_file: str | None = None,
79
+ test_split_file: str | None = None,
80
+ val_split_file: str | None = None,
81
+ class_to_idx: dict[str, int] | None = None,
82
+ **kwargs: Any,
83
+ ):
84
+ super().__init__(
85
+ data_path=data_path,
86
+ name=name,
87
+ seed=seed,
88
+ batch_size=batch_size,
89
+ num_workers=num_workers,
90
+ train_transform=train_transform,
91
+ val_transform=val_transform,
92
+ test_transform=test_transform,
93
+ load_aug_images=load_aug_images,
94
+ aug_name=aug_name,
95
+ n_aug_to_take=n_aug_to_take,
96
+ replace_str_from=replace_str_from,
97
+ replace_str_to=replace_str_to,
98
+ **kwargs,
99
+ )
100
+ self.replace_str = None
101
+ self.exclude_filter = exclude_filter
102
+ self.include_filter = include_filter
103
+ self.val_size = val_size
104
+ self.test_size = test_size
105
+ self.label_map = label_map
106
+ self.num_data_class = num_data_class
107
+ self.dataset = dataset
108
+ self.train_split_file = train_split_file
109
+ self.test_split_file = test_split_file
110
+ self.val_split_file = val_split_file
111
+ self.class_to_idx: dict[str, int] | None
112
+
113
+ if class_to_idx is not None:
114
+ self.class_to_idx = class_to_idx
115
+ self.num_classes = len(self.class_to_idx)
116
+ else:
117
+ self.class_to_idx = self._find_classes_from_data_path(self.data_path)
118
+ if self.class_to_idx is None:
119
+ log.warning("Could not build a class_to_idx from the data_path subdirectories")
120
+ self.num_classes = 0
121
+ else:
122
+ self.num_classes = len(self.class_to_idx)
123
+
124
+ def _read_split(self, split_file: str) -> tuple[list[str], list[str]]:
125
+ """Reads split file.
126
+
127
+ Args:
128
+ split_file: Path to the split file.
129
+
130
+ Returns:
131
+ Tuple of (image paths, targets).
132
+ """
133
+ samples, targets = [], []
134
+ with open(split_file) as f:
135
+ split = f.readlines()
136
+ for row in split:
137
+ csv_values = row.split(",")
138
+ sample = str(",".join(csv_values[:-1])).strip()
139
+ target = csv_values[-1].strip()
140
+ sample_path = os.path.join(self.data_path, sample)
141
+ if os.path.exists(sample_path):
142
+ samples.append(sample_path)
143
+ targets.append(target)
144
+ else:
145
+ continue
146
+ # log.warning(f"{sample_path} does not exist")
147
+ return samples, targets
148
+
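+ # A split file, as parsed above, holds one "<relative/path/to/image>,<label>" entry per line;
+ # only the last comma-separated field is treated as the target, so image paths containing
+ # commas still work. Hypothetical example:
+ #     cats/img_001.png,cat
+ #     dogs/img_002,v2.png,dog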
149
+ def _find_classes_from_data_path(self, data_path: str) -> dict[str, int] | None:
150
+ """Given a data_path, build a random class_to_idx from the subdirectories.
151
+
152
+ Args:
153
+ data_path: Path to the data main folder.
154
+
155
+ Returns:
156
+ class_to_idx dictionary.
157
+ """
158
+ subdirectories = []
159
+
160
+ # Check if the directory exists
161
+ if os.path.exists(data_path) and os.path.isdir(data_path):
162
+ # Iterate through the items in the directory
163
+ for item in os.listdir(data_path):
164
+ item_path = os.path.join(data_path, item)
165
+
166
+ # Check if it's a directory and not starting with "."
167
+ if (
168
+ os.path.isdir(item_path)
169
+ and not item.startswith(".")
170
+ # Check if there's at least one image file in the subdirectory
171
+ and any(
172
+ os.path.splitext(file)[1].lower().endswith(tuple(utils.IMAGE_EXTENSIONS))
173
+ for file in os.listdir(item_path)
174
+ )
175
+ ):
176
+ subdirectories.append(item)
177
+
178
+ if len(subdirectories) > 0:
179
+ return {cl: idx for idx, cl in enumerate(sorted(subdirectories))}
180
+ return None
181
+
182
+ return None
183
+
184
+ @staticmethod
185
+ def _find_images_and_targets(
186
+ root_folder: str, class_to_idx: dict[str, int] | None = None
187
+ ) -> tuple[list[tuple[str, int]], dict[str, int]]:
188
+ """Collects the samples from item folders."""
189
+ images_and_targets, class_to_idx = find_images_and_targets(
190
+ folder=root_folder, types=utils.IMAGE_EXTENSIONS, class_to_idx=class_to_idx
191
+ )
192
+ return images_and_targets, class_to_idx
193
+
194
+ def _filter_images_and_targets(
195
+ self, images_and_targets: list[tuple[str, int]], class_to_idx: dict[str, int]
196
+ ) -> tuple[list[str], list[str]]:
197
+ """Filters the images and targets."""
198
+ samples: list[str] = []
199
+ targets: list[str] = []
200
+ idx_to_class = {v: k for k, v in class_to_idx.items()}
201
+ images_and_targets = [(str(image_path), target) for image_path, target in images_and_targets]
202
+ for image_path, target in images_and_targets:
203
+ target_class = idx_to_class[target]
204
+ if self.exclude_filter is not None and any(
205
+ exclude_filter in image_path for exclude_filter in self.exclude_filter
206
+ ):
207
+ continue
208
+ if self.include_filter is not None:
209
+ if any(include_filter in image_path for include_filter in self.include_filter):
210
+ samples.append(str(image_path))
211
+ targets.append(target_class)
212
+ else:
213
+ samples.append(str(image_path))
214
+ targets.append(target_class)
215
+ return (
216
+ samples,
217
+ targets,
218
+ )
219
+
220
+ def _prepare_data(self) -> None:
221
+ """Prepares Classification data for the data module."""
222
+ images_and_targets, class_to_idx = self._find_images_and_targets(self.data_path, self.class_to_idx)
223
+ all_samples, all_targets = self._filter_images_and_targets(images_and_targets, class_to_idx)
224
+ if self.label_map is not None:
225
+ all_targets, _ = group_labels(all_targets, self.label_map)
226
+
227
+ samples_train: list[str] = []
228
+ targets_train: list[str] = []
229
+ samples_test: list[str] = []
230
+ targets_test: list[str] = []
231
+ samples_val: list[str] = []
232
+ targets_val: list[str] = []
233
+
234
+ if self.test_size < 1.0:
235
+ samples_train, samples_test, targets_train, targets_test = train_test_split(
236
+ all_samples,
237
+ all_targets,
238
+ test_size=self.test_size,
239
+ random_state=self.seed,
240
+ stratify=all_targets,
241
+ )
242
+ if self.test_split_file:
243
+ samples_test, targets_test = self._read_split(self.test_split_file)
244
+ if not self.train_split_file:
245
+ samples_train, targets_train = [], []
246
+ for sample, target in zip(all_samples, all_targets):
247
+ if sample not in samples_test:
248
+ samples_train.append(sample)
249
+ targets_train.append(target)
250
+ if self.train_split_file:
251
+ samples_train, targets_train = self._read_split(self.train_split_file)
252
+ if not self.test_split_file:
253
+ samples_test, targets_test = [], []
254
+ for sample, target in zip(all_samples, all_targets):
255
+ if sample not in samples_train:
256
+ samples_test.append(sample)
257
+ targets_test.append(target)
258
+ if self.val_split_file:
259
+ samples_val, targets_val = self._read_split(self.val_split_file)
260
+ if not self.test_split_file or not self.train_split_file:
261
+ raise ValueError("Validation split file is specified but no train or test split file is specified.")
262
+ else:
263
+ samples_train, samples_val, targets_train, targets_val = train_test_split(
264
+ samples_train,
265
+ targets_train,
266
+ test_size=self.val_size,
267
+ random_state=self.seed,
268
+ stratify=targets_train,
269
+ )
270
+
271
+ if self.num_data_class is not None:
272
+ samples_train_topick = []
273
+ targets_train_topick = []
274
+ for cl in np.unique(targets_train):
275
+ idx = np.where(np.array(targets_train) == cl)[0]
276
+ random.seed(self.seed)
277
+ random.shuffle(idx) # type: ignore[arg-type]
278
+ to_pick = idx[: self.num_data_class]
279
+ for i in to_pick:
280
+ samples_train_topick.append(samples_train[i])
281
+ targets_train_topick.append(cl)
282
+
283
+ samples_train = samples_train_topick
284
+ targets_train = targets_train_topick
285
+ else:
286
+ log.info("Test size is set to 1.0: all samples will be put in test-set")
287
+ samples_test = all_samples
288
+ targets_test = all_targets
289
+ train_df = pd.DataFrame({"samples": samples_train, "targets": targets_train})
290
+ train_df["split"] = "train"
291
+ val_df = pd.DataFrame({"samples": samples_val, "targets": targets_val})
292
+ val_df["split"] = "val"
293
+ test_df = pd.DataFrame({"samples": samples_test, "targets": targets_test})
294
+ test_df["split"] = "test"
295
+ self.data = pd.concat([train_df, val_df, test_df], axis=0)
296
+
297
+ # if self.load_aug_images:
298
+ # samples_train, targets_train = self.load_augmented_samples(
299
+ # samples_train, targets_train, self.replace_str, shuffle=True
300
+ # )
301
+ # samples_val, targets_val = self.load_augmented_samples(
302
+ # samples_val, targets_val , self.replace_str, shuffle=True
303
+ # )
304
+ unique_targets = [str(t) for t in np.unique(targets_train)]
305
+ if self.class_to_idx is None:
306
+ sorted_targets = sorted(unique_targets, key=natural_key)
307
+ class_to_idx = {c: idx for idx, c in enumerate(sorted_targets)}
308
+ self.class_to_idx = class_to_idx
309
+ log.info("Class_to_idx not provided in config, building it from targets: %s", class_to_idx)
310
+
311
+ if len(unique_targets) == 0:
312
+ log.warning("Unique_targets length is 0, training set is empty")
313
+ else:
314
+ if len(self.class_to_idx.keys()) != len(unique_targets):
315
+ raise ValueError(
316
+ "The number of classes in the class_to_idx dictionary does not match the number of unique targets."
317
+ f" `class_to_idx`: {self.class_to_idx}, `unique_targets`: {unique_targets}"
318
+ )
319
+ if not all(c in unique_targets for c in self.class_to_idx):
320
+ raise ValueError(
321
+ "The classes in the class_to_idx dictionary do not match the available unique targets in the"
322
+ " datasset. `class_to_idx`: {self.class_to_idx}, `unique_targets`: {unique_targets}"
323
+ )
324
+
325
+ def setup(self, stage: str | None = None) -> None:
326
+ """Setup data module based on stages of training."""
327
+ if stage in ["train", "fit"]:
328
+ self.train_dataset = self.dataset(
329
+ samples=self.data[self.data["split"] == "train"]["samples"].tolist(),
330
+ targets=self.data[self.data["split"] == "train"]["targets"].tolist(),
331
+ transform=self.train_transform,
332
+ class_to_idx=self.class_to_idx,
333
+ )
334
+ self.val_dataset = self.dataset(
335
+ samples=self.data[self.data["split"] == "val"]["samples"].tolist(),
336
+ targets=self.data[self.data["split"] == "val"]["targets"].tolist(),
337
+ transform=self.val_transform,
338
+ class_to_idx=self.class_to_idx,
339
+ )
340
+ if stage in ["test", "predict"]:
341
+ self.test_dataset = self.dataset(
342
+ samples=self.data[self.data["split"] == "test"]["samples"].tolist(),
343
+ targets=self.data[self.data["split"] == "test"]["targets"].tolist(),
344
+ transform=self.test_transform,
345
+ class_to_idx=self.class_to_idx,
346
+ )
347
+
348
+ def train_dataloader(self) -> DataLoader:
349
+ """Returns the train dataloader.
350
+
351
+ Raises:
352
+ ValueError: If train dataset is not initialized.
353
+
354
+ Returns:
355
+ Train dataloader.
356
+ """
357
+ if not self.train_dataset_available:
358
+ raise ValueError("Train dataset is not initialized")
359
+ if not isinstance(self.train_dataset, torch.utils.data.Dataset):
360
+ raise ValueError("Train dataset has to be single `torch.utils.data.Dataset` instance.")
361
+ return DataLoader(
362
+ self.train_dataset,
363
+ batch_size=self.batch_size,
364
+ shuffle=True,
365
+ num_workers=self.num_workers,
366
+ drop_last=False,
367
+ pin_memory=True,
368
+ persistent_workers=self.num_workers > 0,
369
+ )
370
+
371
+ def val_dataloader(self) -> DataLoader:
372
+ """Returns the validation dataloader.
373
+
374
+ Raises:
375
+ ValueError: If validation dataset is not initialized.
376
+
377
+ Returns:
378
+ val dataloader.
379
+ """
380
+ if not self.val_dataset_available:
381
+ raise ValueError("Validation dataset is not initialized")
382
+ if not isinstance(self.val_dataset, torch.utils.data.Dataset):
383
+ raise ValueError("Validation dataset has to be single `torch.utils.data.Dataset` instance.")
384
+ return DataLoader(
385
+ self.val_dataset,
386
+ batch_size=self.batch_size,
387
+ shuffle=False,
388
+ num_workers=self.num_workers,
389
+ drop_last=False,
390
+ pin_memory=True,
391
+ persistent_workers=self.num_workers > 0,
392
+ )
393
+
394
+ def test_dataloader(self) -> DataLoader:
395
+ """Returns the test dataloader.
396
+
397
+ Raises:
398
+ ValueError: If test dataset is not initialized.
399
+
400
+
401
+ Returns:
402
+ test dataloader.
403
+ """
404
+ if not self.test_dataset_available:
405
+ raise ValueError("Test dataset is not initialized")
406
+
407
+ loader = DataLoader(
408
+ self.test_dataset,
409
+ batch_size=self.batch_size,
410
+ shuffle=False,
411
+ num_workers=self.num_workers,
412
+ drop_last=False,
413
+ pin_memory=True,
414
+ persistent_workers=self.num_workers > 0,
415
+ )
416
+ return loader
417
+
418
+ def predict_dataloader(self) -> DataLoader:
419
+ """Returns a dataloader used for predictions."""
420
+ return self.test_dataloader()
421
+
422
+
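+ # A minimal usage sketch for ClassificationDataModule. The paths and transform values below are
+ # illustrative assumptions, and it assumes BaseDataModule wires _prepare_data into Lightning's
+ # prepare_data hook so that self.data is populated before setup runs:
+ #
+ # import albumentations as A
+ # from albumentations.pytorch import ToTensorV2
+ #
+ # transform = A.Compose([A.Resize(224, 224), A.Normalize(), ToTensorV2()])
+ # datamodule = ClassificationDataModule(
+ #     data_path="data/my_dataset",  # one subfolder per class; class_to_idx is inferred if not given
+ #     batch_size=32,
+ #     train_transform=transform,
+ #     val_transform=transform,
+ #     test_transform=transform,
+ # )
+ # datamodule.prepare_data()
+ # datamodule.setup(stage="fit")
+ # train_loader = datamodule.train_dataloader()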
423
+ class SklearnClassificationDataModule(BaseDataModule):
424
+ """A generic Data Module for classification with frozen torch backbone and sklearn classifier.
425
+
426
+ It can also handle k-fold cross validation.
427
+
428
+ Args:
429
+ name: The name for the data module. Defaults to "sklearn_classification_datamodule".
430
+ data_path: Path to the images main folder.
431
+ exclude_filter: List of string filters used to exclude images. If None, no filter will be applied.
432
+ include_filter: List of string filters used to include images. Only images that satisfy at least one of
433
+ the filters will be included.
434
+ val_size: The validation split. Defaults to 0.2.
435
+ class_to_idx: Dictionary mapping folder names to indices. Only files whose label is in the dictionary key
436
+ list will be considered. If None, all files are considered and a custom mapping is created.
437
+ seed: Fixed seed for random operations.
438
+ batch_size: Batch size for the dataloaders.
439
+ num_workers: Number of workers for the dataloaders.
440
+ train_transform: Albumentations transformations for the training set.
441
+ val_transform: Albumentations transformations for the validation set.
442
+ test_transform: Albumentations transformations for the test set.
443
+ roi: Optional cropping region.
444
+ n_splits: Number of dataset subdivisions (default 1 -> train/test). Use a value >= 2 for cross-validation.
445
+ phase: Either train or test.
446
+ cache: If True, disable shuffling in all dataloaders to enable feature caching.
447
+ limit_training_data: If defined, each class will be downsampled to this number. It must be >= 2 to allow
448
+ splitting.
449
+ label_map: Dictionary mapping folder names to labels.
450
+ train_split_file: Optional path to a csv file containing the train split samples.
451
+ test_split_file: Optional path to a csv file containing the test split samples.
452
+ **kwargs: Additional arguments for BaseDataModule
453
+ """
454
+
455
+ def __init__(
456
+ self,
457
+ data_path: str,
458
+ exclude_filter: list[str] | None = None,
459
+ include_filter: list[str] | None = None,
460
+ val_size: float = 0.2,
461
+ class_to_idx: dict[str, int] | None = None,
462
+ label_map: dict[str, Any] | None = None,
463
+ seed: int = 42,
464
+ batch_size: int = 32,
465
+ num_workers: int = 6,
466
+ train_transform: albumentations.Compose | None = None,
467
+ val_transform: albumentations.Compose | None = None,
468
+ test_transform: albumentations.Compose | None = None,
469
+ roi: tuple[int, int, int, int] | None = None,
470
+ n_splits: int = 1,
471
+ phase: str = "train",
472
+ cache: bool = False,
473
+ limit_training_data: int | None = None,
474
+ train_split_file: str | None = None,
475
+ test_split_file: str | None = None,
476
+ name: str = "sklearn_classification_datamodule",
477
+ dataset: type[ImageClassificationListDataset] = ImageClassificationListDataset,
478
+ **kwargs: Any,
479
+ ):
480
+ super().__init__(
481
+ data_path=data_path,
482
+ name=name,
483
+ seed=seed,
484
+ batch_size=batch_size,
485
+ num_workers=num_workers,
486
+ train_transform=train_transform,
487
+ val_transform=val_transform,
488
+ test_transform=test_transform,
489
+ **kwargs,
490
+ )
491
+
492
+ self.class_to_idx = class_to_idx
493
+ self.roi = roi
494
+ self.cache = cache
495
+ self.limit_training_data = limit_training_data
496
+
497
+ self.dataset = dataset
498
+ self.phase = phase
499
+ self.n_splits = n_splits
500
+ self.train_split_file = train_split_file
501
+ self.test_split_file = test_split_file
502
+ self.exclude_filter = exclude_filter
503
+ self.include_filter = include_filter
504
+ self.val_size = val_size
505
+ self.label_map = label_map
506
+ self.full_dataset: ImageClassificationListDataset
507
+ self.train_dataset: list[ImageClassificationListDataset]
508
+ self.val_dataset: list[ImageClassificationListDataset]
509
+
510
+ def _prepare_data(self) -> None:
511
+ """Prepares the data for the data module."""
512
+ assert os.path.isdir(self.data_path), f"Folder {self.data_path} does not exist."
513
+
514
+ list_df = []
515
+ if self.phase == "train":
516
+ samples, targets, split_generator, self.class_to_idx = get_split(
517
+ image_dir=self.data_path,
518
+ exclude_filter=self.exclude_filter,
519
+ include_filter=self.include_filter,
520
+ test_size=self.val_size,
521
+ random_state=self.seed,
522
+ class_to_idx=self.class_to_idx,
523
+ n_splits=self.n_splits,
524
+ limit_training_data=self.limit_training_data,
525
+ train_split_file=self.train_split_file,
526
+ label_map=self.label_map,
527
+ )
528
+
529
+ for cv_idx, split in enumerate(split_generator):
530
+ train_idx, val_idx = split
531
+ train_val_df = pd.DataFrame({"samples": samples, "targets": targets})
532
+ train_val_df["cv"] = 0
533
+ train_val_df["split"] = "train"
534
+ train_val_df.loc[val_idx, "split"] = "val"
535
+ train_val_df.loc[train_idx, "cv"] = cv_idx
536
+ train_val_df.loc[val_idx, "cv"] = cv_idx
537
+ list_df.append(train_val_df)
538
+
539
+ test_samples, test_targets = find_test_image(
540
+ folder=self.data_path,
541
+ exclude_filter=self.exclude_filter,
542
+ include_filter=self.include_filter,
543
+ test_split_file=self.test_split_file,
544
+ )
545
+ if self.label_map is not None:
546
+ test_targets, _ = group_labels(test_targets, self.label_map)
547
+ test_df = pd.DataFrame({"samples": test_samples, "targets": test_targets})
548
+ test_df["split"] = "test"
549
+ test_df["cv"] = np.nan
550
+
551
+ list_df.append(test_df)
552
+ self.data = pd.concat(list_df, axis=0)
553
+
554
+ def setup(self, stage: str) -> None:
555
+ """Setup data module based on stages of training."""
556
+ if stage == "fit":
557
+ self.train_dataset = []
558
+ self.val_dataset = []
559
+
560
+ for cv_idx in range(self.n_splits):
561
+ cv_df = self.data[self.data["cv"] == cv_idx]
562
+ train_samples = cv_df[cv_df["split"] == "train"]["samples"].tolist()
563
+ train_targets = cv_df[cv_df["split"] == "train"]["targets"].tolist()
564
+ val_samples = cv_df[cv_df["split"] == "val"]["samples"].tolist()
565
+ val_targets = cv_df[cv_df["split"] == "val"]["targets"].tolist()
566
+ self.train_dataset.append(
567
+ self.dataset(
568
+ class_to_idx=self.class_to_idx,
569
+ samples=train_samples,
570
+ targets=train_targets,
571
+ transform=self.train_transform,
572
+ roi=self.roi,
573
+ )
574
+ )
575
+ self.val_dataset.append(
576
+ self.dataset(
577
+ class_to_idx=self.class_to_idx,
578
+ samples=val_samples,
579
+ targets=val_targets,
580
+ transform=self.val_transform,
581
+ roi=self.roi,
582
+ )
583
+ )
584
+ all_samples = self.data[self.data["cv"] == 0]["samples"].tolist()
585
+ all_targets = self.data[self.data["cv"] == 0]["targets"].tolist()
586
+ self.full_dataset = self.dataset(
587
+ class_to_idx=self.class_to_idx,
588
+ samples=all_samples,
589
+ targets=all_targets,
590
+ transform=self.train_transform,
591
+ roi=self.roi,
592
+ )
593
+ if stage == "test":
594
+ test_samples = self.data[self.data["split"] == "test"]["samples"].tolist()
595
+ test_targets = self.data[self.data["split"] == "test"]["targets"]
596
+ self.test_dataset = self.dataset(
597
+ class_to_idx=self.class_to_idx,
598
+ samples=test_samples,
599
+ targets=test_targets.tolist(),
600
+ transform=self.test_transform,
601
+ roi=self.roi,
602
+ allow_missing_label=True,
603
+ )
604
+
605
+ def predict_dataloader(self) -> DataLoader:
606
+ """Returns a dataloader used for predictions."""
607
+ return self.test_dataloader()
608
+
609
+ def train_dataloader(self) -> list[DataLoader]:
610
+ """Returns a list of train dataloader.
611
+
612
+ Raises:
613
+ ValueError: If train dataset is not initialized.
614
+
615
+ Returns:
616
+ List of train dataloaders.
617
+ """
618
+ if not self.train_dataset_available:
619
+ raise ValueError("Train dataset is not initialized")
620
+
621
+ loader = []
622
+ for dataset in self.train_dataset:
623
+ loader.append(
624
+ DataLoader(
625
+ dataset,
626
+ batch_size=self.batch_size,
627
+ shuffle=not self.cache,
628
+ num_workers=self.num_workers,
629
+ drop_last=False,
630
+ pin_memory=True,
631
+ )
632
+ )
633
+ return loader
634
+
635
+ def val_dataloader(self) -> list[DataLoader]:
636
+ """Returns a list of validation dataloader.
637
+
638
+ Raises:
639
+ ValueError: If validation dataset is not initialized.
640
+
641
+ Returns:
642
+ List of validation dataloaders.
643
+ """
644
+ if not self.val_dataset_available:
645
+ raise ValueError("Validation dataset is not initialized")
646
+
647
+ loader = []
648
+ for dataset in self.val_dataset:
649
+ loader.append(
650
+ DataLoader(
651
+ dataset,
652
+ batch_size=self.batch_size,
653
+ shuffle=False,
654
+ num_workers=self.num_workers,
655
+ drop_last=False,
656
+ pin_memory=True,
657
+ )
658
+ )
659
+
660
+ return loader
661
+
662
+ def test_dataloader(self) -> DataLoader:
663
+ """Returns the test dataloader.
664
+
665
+ Raises:
666
+ ValueError: If test dataset is not initialized.
667
+
668
+
669
+ Returns:
670
+ test dataloader.
671
+ """
672
+ if not self.test_dataset_available:
673
+ raise ValueError("Test dataset is not initialized")
674
+
675
+ loader = DataLoader(
676
+ self.test_dataset,
677
+ batch_size=self.batch_size,
678
+ shuffle=False,
679
+ num_workers=self.num_workers,
680
+ drop_last=False,
681
+ pin_memory=True,
682
+ persistent_workers=self.num_workers > 0,
683
+ )
684
+ return loader
685
+
686
+ def full_dataloader(self) -> DataLoader:
687
+ """Return a dataloader to perform training on the entire dataset.
688
+
689
+ Returns:
690
+ Dataloader for training on the entire dataset. This is useful for a final
691
+ training pass on all data after the cross-validation phase.
692
+
693
+ """
694
+ if self.full_dataset is None:
695
+ raise ValueError("Full dataset is not initialized")
696
+
697
+ return DataLoader(
698
+ self.full_dataset,
699
+ batch_size=self.batch_size,
700
+ shuffle=not self.cache,
701
+ num_workers=self.num_workers,
702
+ drop_last=False,
703
+ pin_memory=True,
704
+ )
705
+
706
+
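+ # Cross-validation sketch: with n_splits >= 2, train_dataloader() and val_dataloader() return one
+ # loader per fold. A hedged sketch (the path is an assumption; the feature-extraction and
+ # classifier-fitting step is a placeholder handled elsewhere by the sklearn classification task):
+ #
+ # datamodule = SklearnClassificationDataModule(data_path="data/my_dataset", n_splits=5, phase="train")
+ # datamodule.prepare_data()
+ # datamodule.setup(stage="fit")
+ # folds = zip(datamodule.train_dataloader(), datamodule.val_dataloader())
+ # for fold_index, (train_loader, val_loader) in enumerate(folds):
+ #     ...  # extract features with the frozen backbone, then fit and score the sklearn classifier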
707
+ class MultilabelClassificationDataModule(BaseDataModule):
708
+ """Base class for all multi-label modules.
709
+
710
+ Args:
711
+ data_path: Path to the data main folder.
712
+ images_and_labels_file: A path to a txt file containing the relative (to `data_path`) paths
713
+ of images together with their labels, comma-separated.
714
+ E.g.:
715
+
716
+ * path1,l1,l2,l3
717
+ * path2,l4,l5
718
+ * ...
719
+
720
+ Either `images_and_labels_file` or both `train_split_file` and `test_split_file` must be set.
721
+ Defaults to None.
722
+ name: The name for the data module. Defaults to "multilabel_datamodule".
723
+ dataset: a callable returning a torch.utils.data.Dataset class.
724
+ num_classes: the number of classes in the dataset. This is used to create one-hot encoded
725
+ targets. Defaults to None.
726
+ num_workers: Number of workers for dataloaders. Defaults to 16.
727
+ batch_size: Training batch size. Defaults to 64.
728
+ test_batch_size: Testing batch size. Defaults to 64.
729
+ seed: Random generator seed. Defaults to 42.
730
+ val_size: The validation split. Defaults to 0.2.
731
+ test_size: The test split. Defaults to 0.2.
732
+ train_transform: Transformations for train dataset.
733
+ Defaults to None.
734
+ val_transform: Transformations for validation dataset.
735
+ Defaults to None.
736
+ test_transform: Transformations for test dataset.
737
+ Defaults to None.
738
+ train_split_file: The file with train split. Defaults to None.
739
+ val_split_file: The file with validation split. Defaults to None.
740
+ test_split_file: The file with test split. Defaults to None.
741
+ class_to_idx: A class to idx dictionary. Defaults to None.
742
+ """
743
+
744
+ def __init__(
745
+ self,
746
+ data_path: str,
747
+ images_and_labels_file: str | None = None,
748
+ train_split_file: str | None = None,
749
+ test_split_file: str | None = None,
750
+ val_split_file: str | None = None,
751
+ name: str = "multilabel_datamodule",
752
+ dataset: Callable = MultilabelClassificationDataset,
753
+ num_classes: int | None = None,
754
+ num_workers: int = 16,
755
+ batch_size: int = 64,
756
+ test_batch_size: int = 64,
757
+ seed: int = 42,
758
+ val_size: float | None = 0.2,
759
+ test_size: float | None = 0.2,
760
+ train_transform: albumentations.Compose | None = None,
761
+ val_transform: albumentations.Compose | None = None,
762
+ test_transform: albumentations.Compose | None = None,
763
+ class_to_idx: dict[str, int] | None = None,
764
+ **kwargs,
765
+ ):
766
+ super().__init__(
767
+ data_path=data_path,
768
+ name=name,
769
+ num_workers=num_workers,
770
+ batch_size=batch_size,
771
+ seed=seed,
772
+ train_transform=train_transform,
773
+ val_transform=val_transform,
774
+ test_transform=test_transform,
775
+ **kwargs,
776
+ )
777
+ if images_and_labels_file is None and (train_split_file is None or test_split_file is None):
778
+ raise ValueError(
779
+ "Either `images_and_labels_file` or both `train_split_file` and `test_split_file` must be set"
780
+ )
781
+ self.images_and_labels_file = images_and_labels_file
782
+ self.dataset = dataset
783
+ self.num_classes = num_classes
784
+ self.train_batch_size = batch_size
785
+ self.test_batch_size = test_batch_size
786
+ self.val_size = val_size
787
+ self.test_size = test_size
788
+ self.train_split_file = train_split_file
789
+ self.test_split_file = test_split_file
790
+ self.val_split_file = val_split_file
791
+ self.class_to_idx = class_to_idx
792
+ self.train_dataset: MultilabelClassificationDataset
793
+ self.val_dataset: MultilabelClassificationDataset
794
+ self.test_dataset: MultilabelClassificationDataset
795
+
796
+ def _read_split(self, split_file: str) -> tuple[list[str], list[list[str]]]:
797
+ """Reads split file.
798
+
799
+ Args:
800
+ split_file: Path to the split file.
801
+
802
+ Returns:
803
+ Tuple containing list of paths to images and list of labels.
804
+ """
805
+ all_samples, all_targets = [], []
806
+ with open(split_file) as f:
807
+ for line in f.readlines():
808
+ split_line = line.split(",")
809
+ sample = os.path.join(self.data_path, split_line[0])
810
+ targets = [t.strip() for t in split_line[1:]]
811
+ if len(targets) == 0:
812
+ continue
813
+ all_samples.append(sample)
814
+ all_targets.append(targets)
815
+ return all_samples, all_targets
816
+
817
+ def _prepare_data(self) -> None:
818
+ """Prepares the data for the data module."""
819
+ if self.images_and_labels_file is not None:
820
+ # Read all images and targets
821
+ all_samples, all_targets = self._read_split(self.images_and_labels_file)
822
+ all_samples = np.array(all_samples).reshape(-1, 1)
823
+
824
+ # Targets to idx
825
+ unique_targets = set(utils.flatten_list(all_targets))
826
+ if self.class_to_idx is None:
827
+ self.class_to_idx = {c: i for i, c in enumerate(unique_targets)}
828
+
829
+ all_targets = [[self.class_to_idx[t] for t in targets] for targets in all_targets]
830
+
831
+ # Transform targets to one-hot
832
+ if self.num_classes is None:
833
+ self.num_classes = len(unique_targets)
834
+ all_targets = np.array([[i in targets for i in range(self.num_classes)] for targets in all_targets]).astype(
835
+ int
836
+ )
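+ # E.g. with class_to_idx = {"red": 0, "green": 1, "blue": 2}, the labels ["red", "blue"] were
+ # mapped above to the indices [0, 2] and become the one-hot row [1, 0, 1] here.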
837
+
838
+ # Create splits
839
+ samples_train, targets_train, samples_test, targets_test = iterative_train_test_split(
840
+ all_samples, all_targets, test_size=self.test_size
841
+ )
842
+ elif self.train_split_file is not None and self.test_split_file is not None:
843
+ # Both train_split_file and test_split_file are set
844
+ samples_train, targets_train = self._read_split(self.train_split_file)
845
+ samples_test, targets_test = self._read_split(self.test_split_file)
846
+
847
+ # Create class_to_idx from all targets
848
+ unique_targets = set(utils.flatten_list(targets_test + targets_train))
849
+ if self.class_to_idx is None:
850
+ self.class_to_idx = {c: i for i, c in enumerate(unique_targets)}
851
+
852
+ # Transform targets to one-hot
853
+ if self.num_classes is None:
854
+ self.num_classes = len(unique_targets)
855
+ targets_test = [[self.class_to_idx[t] for t in targets] for targets in targets_test]
856
+ targets_test = np.array(
857
+ [[i in targets for i in range(self.num_classes)] for targets in targets_test]
858
+ ).astype(int)
859
+ targets_train = [[self.class_to_idx[t] for t in targets] for targets in targets_train]
860
+ targets_train = np.array(
861
+ [[i in targets for i in range(self.num_classes)] for targets in targets_train]
862
+ ).astype(int)
863
+ else:
864
+ raise ValueError(
865
+ "Either `images_and_labels_file` or both `train_split_file` and `test_split_file` must be set"
866
+ )
867
+
868
+ if self.val_split_file:
869
+ if not self.test_split_file or not self.train_split_file:
870
+ raise ValueError("Validation split file is specified but no train or test split file is specified.")
871
+ samples_val, targets_val = self._read_split(self.val_split_file)
872
+ targets_val = [[self.class_to_idx[t] for t in targets] for targets in targets_val]
873
+ targets_val = np.array([[i in targets for i in range(self.num_classes)] for targets in targets_val]).astype(
874
+ int
875
+ )
876
+ else:
877
+ samples_train = np.array(samples_train).reshape(-1, 1)
878
+ targets_train = np.array(targets_train).reshape(-1, self.num_classes)
879
+ samples_train, targets_train, samples_val, targets_val = iterative_train_test_split(
880
+ samples_train, targets_train, test_size=self.val_size
881
+ )
882
+
883
+ if isinstance(samples_train, np.ndarray):
884
+ samples_train = samples_train.flatten().tolist()
885
+ if isinstance(samples_val, np.ndarray):
886
+ samples_val = samples_val.flatten().tolist()
887
+ if isinstance(samples_test, np.ndarray):
888
+ samples_test = samples_test.flatten().tolist()
889
+
890
+ if isinstance(targets_train, np.ndarray):
891
+ targets_train = list(targets_train)
892
+ if isinstance(targets_val, np.ndarray):
893
+ targets_val = list(targets_val) # type: ignore[assignment]
894
+ if isinstance(targets_test, np.ndarray):
895
+ targets_test = list(targets_test)
896
+
897
+ # Create data
898
+ train_df = pd.DataFrame({"samples": samples_train, "targets": targets_train})
899
+ train_df["split"] = "train"
900
+ val_df = pd.DataFrame({"samples": samples_val, "targets": targets_val})
901
+ val_df["split"] = "val"
902
+ test_df = pd.DataFrame({"samples": samples_test, "targets": targets_test})
903
+ test_df["split"] = "test"
904
+ self.data = pd.concat([train_df, val_df, test_df], axis=0)
905
+
906
+ def setup(self, stage: str | None = None) -> None:
907
+ """Setup data module based on stages of training."""
908
+ if stage in ["train", "fit"]:
909
+ train_samples = self.data[self.data["split"] == "train"]["samples"].tolist()
910
+ train_targets = self.data[self.data["split"] == "train"]["targets"].tolist()
911
+ val_samples = self.data[self.data["split"] == "val"]["samples"].tolist()
912
+ val_targets = self.data[self.data["split"] == "val"]["targets"].tolist()
913
+ self.train_dataset = self.dataset(
914
+ samples=train_samples,
915
+ targets=train_targets,
916
+ transform=self.train_transform,
917
+ class_to_idx=self.class_to_idx,
918
+ )
919
+ self.val_dataset = self.dataset(
920
+ samples=val_samples,
921
+ targets=val_targets,
922
+ transform=self.val_transform,
923
+ class_to_idx=self.class_to_idx,
924
+ )
925
+ if stage == "test":
926
+ test_samples = self.data[self.data["split"] == "test"]["samples"].tolist()
927
+ test_targets = self.data[self.data["split"] == "test"]["targets"].tolist()
928
+ self.test_dataset = self.dataset(
929
+ samples=test_samples,
930
+ targets=test_targets,
931
+ transform=self.test_transform,
932
+ class_to_idx=self.class_to_idx,
933
+ )
934
+
935
+ def train_dataloader(self) -> DataLoader:
936
+ """Returns the train dataloader.
937
+
938
+ Raises:
939
+ ValueError: If train dataset is not initialized.
940
+
941
+ Returns:
942
+ Train dataloader.
943
+ """
944
+ if not self.train_dataset_available:
945
+ raise ValueError("Train dataset is not initialized")
946
+ return DataLoader(
947
+ self.train_dataset,
948
+ batch_size=self.batch_size,
949
+ shuffle=True,
950
+ num_workers=self.num_workers,
951
+ drop_last=False,
952
+ pin_memory=True,
953
+ persistent_workers=self.num_workers > 0,
954
+ )
955
+
956
+ def val_dataloader(self) -> DataLoader:
957
+ """Returns the validation dataloader.
958
+
959
+ Raises:
960
+ ValueError: If validation dataset is not initialized.
961
+
962
+ Returns:
963
+ val dataloader.
964
+ """
965
+ if not self.val_dataset_available:
966
+ raise ValueError("Validation dataset is not initialized")
967
+ return DataLoader(
968
+ self.val_dataset,
969
+ batch_size=self.batch_size,
970
+ shuffle=False,
971
+ num_workers=self.num_workers,
972
+ drop_last=False,
973
+ pin_memory=True,
974
+ persistent_workers=self.num_workers > 0,
975
+ )
976
+
977
+ def test_dataloader(self) -> DataLoader:
978
+ """Returns the test dataloader.
979
+
980
+ Raises:
981
+ ValueError: If test dataset is not initialized.
982
+
983
+
984
+ Returns:
985
+ test dataloader.
986
+ """
987
+ if not self.test_dataset_available:
988
+ raise ValueError("Test dataset is not initialized")
989
+
990
+ loader = DataLoader(
991
+ self.test_dataset,
992
+ batch_size=self.batch_size,
993
+ shuffle=False,
994
+ num_workers=self.num_workers,
995
+ drop_last=False,
996
+ pin_memory=True,
997
+ persistent_workers=self.num_workers > 0,
998
+ )
999
+ return loader
1000
+
1001
+ def predict_dataloader(self) -> DataLoader:
1002
+ """Returns a dataloader used for predictions."""
1003
+ return self.test_dataloader()
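
A minimal usage sketch of the multilabel datamodule that closes the file above; the paths are illustrative assumptions, the labels file follows the "path,l1,l2,..." format described in the class docstring, and it assumes BaseDataModule wires _prepare_data into Lightning's prepare_data hook:

    datamodule = MultilabelClassificationDataModule(
        data_path="data/my_dataset",
        images_and_labels_file="data/my_dataset/images_and_labels.txt",
        val_size=0.2,
        test_size=0.2,
    )
    datamodule.prepare_data()
    datamodule.setup(stage="fit")
    train_loader = datamodule.train_dataloader()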