quadra 0.0.1__py3-none-any.whl → 2.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (302)
  1. hydra_plugins/quadra_searchpath_plugin.py +30 -0
  2. quadra/__init__.py +6 -0
  3. quadra/callbacks/__init__.py +0 -0
  4. quadra/callbacks/anomalib.py +289 -0
  5. quadra/callbacks/lightning.py +501 -0
  6. quadra/callbacks/mlflow.py +291 -0
  7. quadra/callbacks/scheduler.py +69 -0
  8. quadra/configs/__init__.py +0 -0
  9. quadra/configs/backbone/caformer_m36.yaml +8 -0
  10. quadra/configs/backbone/caformer_s36.yaml +8 -0
  11. quadra/configs/backbone/convnextv2_base.yaml +8 -0
  12. quadra/configs/backbone/convnextv2_femto.yaml +8 -0
  13. quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
  14. quadra/configs/backbone/dino_vitb8.yaml +12 -0
  15. quadra/configs/backbone/dino_vits8.yaml +12 -0
  16. quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
  17. quadra/configs/backbone/dinov2_vits14.yaml +12 -0
  18. quadra/configs/backbone/efficientnet_b0.yaml +8 -0
  19. quadra/configs/backbone/efficientnet_b1.yaml +8 -0
  20. quadra/configs/backbone/efficientnet_b2.yaml +8 -0
  21. quadra/configs/backbone/efficientnet_b3.yaml +8 -0
  22. quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
  23. quadra/configs/backbone/levit_128s.yaml +8 -0
  24. quadra/configs/backbone/mnasnet0_5.yaml +9 -0
  25. quadra/configs/backbone/resnet101.yaml +8 -0
  26. quadra/configs/backbone/resnet18.yaml +8 -0
  27. quadra/configs/backbone/resnet18_ssl.yaml +8 -0
  28. quadra/configs/backbone/resnet50.yaml +8 -0
  29. quadra/configs/backbone/smp.yaml +9 -0
  30. quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
  31. quadra/configs/backbone/unetr.yaml +15 -0
  32. quadra/configs/backbone/vit16_base.yaml +9 -0
  33. quadra/configs/backbone/vit16_small.yaml +9 -0
  34. quadra/configs/backbone/vit16_tiny.yaml +9 -0
  35. quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
  36. quadra/configs/callbacks/all.yaml +45 -0
  37. quadra/configs/callbacks/default.yaml +34 -0
  38. quadra/configs/callbacks/default_anomalib.yaml +64 -0
  39. quadra/configs/config.yaml +33 -0
  40. quadra/configs/core/default.yaml +11 -0
  41. quadra/configs/datamodule/base/anomaly.yaml +16 -0
  42. quadra/configs/datamodule/base/classification.yaml +21 -0
  43. quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
  44. quadra/configs/datamodule/base/segmentation.yaml +18 -0
  45. quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
  46. quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
  47. quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
  48. quadra/configs/datamodule/base/ssl.yaml +21 -0
  49. quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
  50. quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
  51. quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
  52. quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
  53. quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
  54. quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
  55. quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
  56. quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
  57. quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
  58. quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
  59. quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
  60. quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
  61. quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
  62. quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
  63. quadra/configs/experiment/base/classification/classification.yaml +73 -0
  64. quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
  65. quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
  66. quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
  67. quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
  68. quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
  69. quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
  70. quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
  71. quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
  72. quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
  73. quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
  74. quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
  75. quadra/configs/experiment/base/ssl/byol.yaml +43 -0
  76. quadra/configs/experiment/base/ssl/dino.yaml +46 -0
  77. quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
  78. quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
  79. quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
  80. quadra/configs/experiment/custom/cls.yaml +12 -0
  81. quadra/configs/experiment/default.yaml +15 -0
  82. quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
  83. quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
  84. quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
  85. quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
  86. quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
  87. quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
  88. quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
  89. quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
  90. quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
  91. quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
  92. quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
  93. quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
  94. quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
  95. quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
  96. quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
  97. quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
  98. quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
  99. quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
  100. quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
  101. quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
  102. quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
  103. quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
  104. quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
  105. quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
  106. quadra/configs/export/default.yaml +13 -0
  107. quadra/configs/hydra/anomaly_custom.yaml +15 -0
  108. quadra/configs/hydra/default.yaml +14 -0
  109. quadra/configs/inference/default.yaml +26 -0
  110. quadra/configs/logger/comet.yaml +10 -0
  111. quadra/configs/logger/csv.yaml +5 -0
  112. quadra/configs/logger/mlflow.yaml +12 -0
  113. quadra/configs/logger/tensorboard.yaml +8 -0
  114. quadra/configs/loss/asl.yaml +7 -0
  115. quadra/configs/loss/barlow.yaml +2 -0
  116. quadra/configs/loss/bce.yaml +1 -0
  117. quadra/configs/loss/byol.yaml +1 -0
  118. quadra/configs/loss/cross_entropy.yaml +1 -0
  119. quadra/configs/loss/dino.yaml +8 -0
  120. quadra/configs/loss/simclr.yaml +2 -0
  121. quadra/configs/loss/simsiam.yaml +1 -0
  122. quadra/configs/loss/smp_ce.yaml +3 -0
  123. quadra/configs/loss/smp_dice.yaml +2 -0
  124. quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
  125. quadra/configs/loss/smp_mcc.yaml +2 -0
  126. quadra/configs/loss/vicreg.yaml +5 -0
  127. quadra/configs/model/anomalib/cfa.yaml +35 -0
  128. quadra/configs/model/anomalib/cflow.yaml +30 -0
  129. quadra/configs/model/anomalib/csflow.yaml +34 -0
  130. quadra/configs/model/anomalib/dfm.yaml +19 -0
  131. quadra/configs/model/anomalib/draem.yaml +29 -0
  132. quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
  133. quadra/configs/model/anomalib/fastflow.yaml +32 -0
  134. quadra/configs/model/anomalib/padim.yaml +32 -0
  135. quadra/configs/model/anomalib/patchcore.yaml +36 -0
  136. quadra/configs/model/barlow.yaml +16 -0
  137. quadra/configs/model/byol.yaml +25 -0
  138. quadra/configs/model/classification.yaml +10 -0
  139. quadra/configs/model/dino.yaml +26 -0
  140. quadra/configs/model/logistic_regression.yaml +4 -0
  141. quadra/configs/model/multilabel_classification.yaml +9 -0
  142. quadra/configs/model/simclr.yaml +18 -0
  143. quadra/configs/model/simsiam.yaml +24 -0
  144. quadra/configs/model/smp.yaml +4 -0
  145. quadra/configs/model/smp_multiclass.yaml +4 -0
  146. quadra/configs/model/vicreg.yaml +16 -0
  147. quadra/configs/optimizer/adam.yaml +5 -0
  148. quadra/configs/optimizer/adamw.yaml +3 -0
  149. quadra/configs/optimizer/default.yaml +4 -0
  150. quadra/configs/optimizer/lars.yaml +8 -0
  151. quadra/configs/optimizer/sgd.yaml +4 -0
  152. quadra/configs/scheduler/default.yaml +5 -0
  153. quadra/configs/scheduler/rop.yaml +5 -0
  154. quadra/configs/scheduler/step.yaml +3 -0
  155. quadra/configs/scheduler/warmrestart.yaml +2 -0
  156. quadra/configs/scheduler/warmup.yaml +6 -0
  157. quadra/configs/task/anomalib/cfa.yaml +5 -0
  158. quadra/configs/task/anomalib/cflow.yaml +5 -0
  159. quadra/configs/task/anomalib/csflow.yaml +5 -0
  160. quadra/configs/task/anomalib/draem.yaml +5 -0
  161. quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
  162. quadra/configs/task/anomalib/fastflow.yaml +5 -0
  163. quadra/configs/task/anomalib/inference.yaml +3 -0
  164. quadra/configs/task/anomalib/padim.yaml +5 -0
  165. quadra/configs/task/anomalib/patchcore.yaml +5 -0
  166. quadra/configs/task/classification.yaml +6 -0
  167. quadra/configs/task/classification_evaluation.yaml +6 -0
  168. quadra/configs/task/default.yaml +1 -0
  169. quadra/configs/task/segmentation.yaml +9 -0
  170. quadra/configs/task/segmentation_evaluation.yaml +3 -0
  171. quadra/configs/task/sklearn_classification.yaml +13 -0
  172. quadra/configs/task/sklearn_classification_patch.yaml +11 -0
  173. quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
  174. quadra/configs/task/sklearn_classification_test.yaml +8 -0
  175. quadra/configs/task/ssl.yaml +2 -0
  176. quadra/configs/trainer/lightning_cpu.yaml +36 -0
  177. quadra/configs/trainer/lightning_gpu.yaml +35 -0
  178. quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
  179. quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
  180. quadra/configs/trainer/lightning_multigpu.yaml +37 -0
  181. quadra/configs/trainer/sklearn_classification.yaml +7 -0
  182. quadra/configs/transforms/byol.yaml +47 -0
  183. quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
  184. quadra/configs/transforms/default.yaml +37 -0
  185. quadra/configs/transforms/default_numpy.yaml +24 -0
  186. quadra/configs/transforms/default_resize.yaml +22 -0
  187. quadra/configs/transforms/dino.yaml +63 -0
  188. quadra/configs/transforms/linear_eval.yaml +18 -0
  189. quadra/datamodules/__init__.py +20 -0
  190. quadra/datamodules/anomaly.py +180 -0
  191. quadra/datamodules/base.py +375 -0
  192. quadra/datamodules/classification.py +1003 -0
  193. quadra/datamodules/generic/__init__.py +0 -0
  194. quadra/datamodules/generic/imagenette.py +144 -0
  195. quadra/datamodules/generic/mnist.py +81 -0
  196. quadra/datamodules/generic/mvtec.py +58 -0
  197. quadra/datamodules/generic/oxford_pet.py +163 -0
  198. quadra/datamodules/patch.py +190 -0
  199. quadra/datamodules/segmentation.py +742 -0
  200. quadra/datamodules/ssl.py +140 -0
  201. quadra/datasets/__init__.py +17 -0
  202. quadra/datasets/anomaly.py +287 -0
  203. quadra/datasets/classification.py +241 -0
  204. quadra/datasets/patch.py +138 -0
  205. quadra/datasets/segmentation.py +239 -0
  206. quadra/datasets/ssl.py +110 -0
  207. quadra/losses/__init__.py +0 -0
  208. quadra/losses/classification/__init__.py +6 -0
  209. quadra/losses/classification/asl.py +83 -0
  210. quadra/losses/classification/focal.py +320 -0
  211. quadra/losses/classification/prototypical.py +148 -0
  212. quadra/losses/ssl/__init__.py +17 -0
  213. quadra/losses/ssl/barlowtwins.py +47 -0
  214. quadra/losses/ssl/byol.py +37 -0
  215. quadra/losses/ssl/dino.py +129 -0
  216. quadra/losses/ssl/hyperspherical.py +45 -0
  217. quadra/losses/ssl/idmm.py +50 -0
  218. quadra/losses/ssl/simclr.py +67 -0
  219. quadra/losses/ssl/simsiam.py +30 -0
  220. quadra/losses/ssl/vicreg.py +76 -0
  221. quadra/main.py +49 -0
  222. quadra/metrics/__init__.py +3 -0
  223. quadra/metrics/segmentation.py +251 -0
  224. quadra/models/__init__.py +0 -0
  225. quadra/models/base.py +151 -0
  226. quadra/models/classification/__init__.py +8 -0
  227. quadra/models/classification/backbones.py +149 -0
  228. quadra/models/classification/base.py +92 -0
  229. quadra/models/evaluation.py +322 -0
  230. quadra/modules/__init__.py +0 -0
  231. quadra/modules/backbone.py +30 -0
  232. quadra/modules/base.py +312 -0
  233. quadra/modules/classification/__init__.py +3 -0
  234. quadra/modules/classification/base.py +327 -0
  235. quadra/modules/ssl/__init__.py +17 -0
  236. quadra/modules/ssl/barlowtwins.py +59 -0
  237. quadra/modules/ssl/byol.py +172 -0
  238. quadra/modules/ssl/common.py +285 -0
  239. quadra/modules/ssl/dino.py +186 -0
  240. quadra/modules/ssl/hyperspherical.py +206 -0
  241. quadra/modules/ssl/idmm.py +98 -0
  242. quadra/modules/ssl/simclr.py +73 -0
  243. quadra/modules/ssl/simsiam.py +68 -0
  244. quadra/modules/ssl/vicreg.py +67 -0
  245. quadra/optimizers/__init__.py +4 -0
  246. quadra/optimizers/lars.py +153 -0
  247. quadra/optimizers/sam.py +127 -0
  248. quadra/schedulers/__init__.py +3 -0
  249. quadra/schedulers/base.py +44 -0
  250. quadra/schedulers/warmup.py +127 -0
  251. quadra/tasks/__init__.py +24 -0
  252. quadra/tasks/anomaly.py +582 -0
  253. quadra/tasks/base.py +397 -0
  254. quadra/tasks/classification.py +1263 -0
  255. quadra/tasks/patch.py +492 -0
  256. quadra/tasks/segmentation.py +389 -0
  257. quadra/tasks/ssl.py +560 -0
  258. quadra/trainers/README.md +3 -0
  259. quadra/trainers/__init__.py +0 -0
  260. quadra/trainers/classification.py +179 -0
  261. quadra/utils/__init__.py +0 -0
  262. quadra/utils/anomaly.py +112 -0
  263. quadra/utils/classification.py +618 -0
  264. quadra/utils/deprecation.py +31 -0
  265. quadra/utils/evaluation.py +474 -0
  266. quadra/utils/export.py +585 -0
  267. quadra/utils/imaging.py +32 -0
  268. quadra/utils/logger.py +15 -0
  269. quadra/utils/mlflow.py +98 -0
  270. quadra/utils/model_manager.py +320 -0
  271. quadra/utils/models.py +523 -0
  272. quadra/utils/patch/__init__.py +15 -0
  273. quadra/utils/patch/dataset.py +1433 -0
  274. quadra/utils/patch/metrics.py +449 -0
  275. quadra/utils/patch/model.py +153 -0
  276. quadra/utils/patch/visualization.py +217 -0
  277. quadra/utils/resolver.py +42 -0
  278. quadra/utils/segmentation.py +31 -0
  279. quadra/utils/tests/__init__.py +0 -0
  280. quadra/utils/tests/fixtures/__init__.py +1 -0
  281. quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
  282. quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
  283. quadra/utils/tests/fixtures/dataset/classification.py +406 -0
  284. quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
  285. quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
  286. quadra/utils/tests/fixtures/models/__init__.py +3 -0
  287. quadra/utils/tests/fixtures/models/anomaly.py +89 -0
  288. quadra/utils/tests/fixtures/models/classification.py +45 -0
  289. quadra/utils/tests/fixtures/models/segmentation.py +33 -0
  290. quadra/utils/tests/helpers.py +70 -0
  291. quadra/utils/tests/models.py +27 -0
  292. quadra/utils/utils.py +525 -0
  293. quadra/utils/validator.py +115 -0
  294. quadra/utils/visualization.py +422 -0
  295. quadra/utils/vit_explainability.py +349 -0
  296. quadra-2.2.7.dist-info/LICENSE +201 -0
  297. quadra-2.2.7.dist-info/METADATA +381 -0
  298. quadra-2.2.7.dist-info/RECORD +300 -0
  299. {quadra-0.0.1.dist-info → quadra-2.2.7.dist-info}/WHEEL +1 -1
  300. quadra-2.2.7.dist-info/entry_points.txt +3 -0
  301. quadra-0.0.1.dist-info/METADATA +0 -14
  302. quadra-0.0.1.dist-info/RECORD +0 -4
quadra/datamodules/segmentation.py (new file)
@@ -0,0 +1,742 @@
+ # pylint: disable=unsubscriptable-object,unsupported-assignment-operation,unsupported-membership-test
+ from __future__ import annotations
+
+ import glob
+ import os
+ import random
+ from typing import Any
+
+ import albumentations
+ import cv2
+ import numpy as np
+ import pandas as pd
+ from sklearn.model_selection import train_test_split
+ from skmultilearn.model_selection import iterative_train_test_split
+ from torch.utils.data import DataLoader
+
+ from quadra.datamodules.base import BaseDataModule
+ from quadra.datasets.segmentation import SegmentationDataset, SegmentationDatasetMulticlass
+ from quadra.utils import utils
+
+ log = utils.get_logger(__name__)
+
+
+ class SegmentationDataModule(BaseDataModule):
+     """Base class for segmentation datasets.
+
+     Args:
+         data_path: Path to the data main folder.
+         name: The name for the data module. Defaults to "segmentation_datamodule".
+         test_size: The test split. Defaults to 0.3.
+         val_size: The validation split. Defaults to 0.3.
+         seed: Random generator seed. Defaults to 42.
+         dataset: Dataset class.
+         batch_size: Batch size. Defaults to 32.
+         num_workers: Number of workers for dataloaders. Defaults to 6.
+         train_transform: Transformations for train dataset. Defaults to None.
+         test_transform: Transformations for test dataset. Defaults to None.
+         val_transform: Transformations for validation dataset. Defaults to None.
+         train_split_file: Path to a txt file with the training samples list. Defaults to None.
+         test_split_file: Path to a txt file with the test samples list. Defaults to None.
+         val_split_file: Path to a txt file with the validation samples list. Defaults to None.
+         num_data_class: The number of samples per class. Defaults to None.
+         exclude_good: If True, exclude good samples from the dataset. Defaults to False.
+     """
+
+     def __init__(
+         self,
+         data_path: str,
+         name: str = "segmentation_datamodule",
+         test_size: float = 0.3,
+         val_size: float = 0.3,
+         seed: int = 42,
+         dataset: type[SegmentationDataset] = SegmentationDataset,
+         batch_size: int = 32,
+         num_workers: int = 6,
+         train_transform: albumentations.Compose | None = None,
+         test_transform: albumentations.Compose | None = None,
+         val_transform: albumentations.Compose | None = None,
+         train_split_file: str | None = None,
+         test_split_file: str | None = None,
+         val_split_file: str | None = None,
+         num_data_class: int | None = None,
+         exclude_good: bool = False,
+         **kwargs: Any,
+     ):
+         super().__init__(
+             data_path=data_path,
+             name=name,
+             seed=seed,
+             batch_size=batch_size,
+             num_workers=num_workers,
+             train_transform=train_transform,
+             val_transform=val_transform,
+             test_transform=test_transform,
+             **kwargs,
+         )
+         self.test_size = test_size
+         self.val_size = val_size
+         self.num_data_class = num_data_class
+         self.exclude_good = exclude_good
+         self.train_split_file = train_split_file
+         self.test_split_file = test_split_file
+         self.val_split_file = val_split_file
+         self.dataset = dataset
+         self.train_dataset: SegmentationDataset
+         self.val_dataset: SegmentationDataset
+         self.test_dataset: SegmentationDataset
+
+     def _preprocess_mask(self, mask: np.ndarray) -> np.ndarray:
+         """Binarize mask using 0 as threshold."""
+         mask = (mask > 0).astype(np.uint8)
+         return mask
+
+     @staticmethod
+     def _resolve_label(path: str) -> int:
+         """Resolve label from mask.
+
+         Args:
+             path: Path to the mask.
+
+         Returns:
+             0 if the mask is empty, 1 otherwise.
+         """
+         if cv2.imread(path).sum() == 0:
+             return 0
+
+         return 1
+
+     def _read_folder(self, data_path: str) -> tuple[list[str], list[int], list[str]]:
+         """Read a folder containing images and masks subfolders.
+
+         Args:
+             data_path: Path to the data folder.
+
+         Returns:
+             List of paths to the images, associated binary targets and list of paths to the masks.
+         """
+         samples = []
+         targets = []
+         masks = []
+
+         for im in glob.glob(os.path.join(data_path, "images", "*")):
+             if os.path.basename(im).startswith("."):
+                 # Skip hidden files
+                 continue
+
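+             # Look up the mask by swapping the images/ folder for masks/, accepting any file extension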
+             mask_path = glob.glob(os.path.splitext(im.replace("images", "masks"))[0] + ".*")
+
+             if len(mask_path) == 0:
+                 log.debug("Mask not found: %s", os.path.basename(im))
+                 continue
+
+             if len(mask_path) > 1:
+                 raise ValueError(f"Multiple masks found for image: {os.path.basename(im)}, this is not supported")
+
+             target = self._resolve_label(mask_path[0])
+             samples.append(im)
+             targets.append(target)
+             masks.append(mask_path[0])
+
+         return samples, targets, masks
+
+     def _read_split(self, split_file: str) -> tuple[list[str], list[int], list[str]]:
+         """Read a split file.
+
+         Args:
+             split_file: Path to the split file.
+
+         Returns:
+             List of paths to the images, associated binary targets and list of paths to the masks.
+         """
+         samples, targets, masks = [], [], []
+         with open(split_file) as f:
+             split = f.read().splitlines()
+         for sample in split:
+             sample_path = os.path.join(self.data_path, sample)
+             mask_path = glob.glob(os.path.splitext(sample_path.replace("images", "masks"))[0] + ".*")
+
+             if len(mask_path) == 0:
+                 log.debug("Mask not found: %s", os.path.basename(sample_path))
+                 continue
+
+             if len(mask_path) > 1:
+                 raise ValueError(
+                     f"Multiple masks found for image: {os.path.basename(sample_path)}, this is not supported"
+                 )
+
+             target = self._resolve_label(mask_path[0])
+             samples.append(sample_path)
+             targets.append(target)
+             masks.append(mask_path[0])
+
+         return samples, targets, masks
+
+     def _prepare_data(self) -> None:
+         """Prepare data for training and testing."""
+         if not (self.test_split_file and self.train_split_file and self.val_split_file):
+             all_samples, all_targets, all_masks = self._read_folder(self.data_path)
+             samples_train, samples_test, targets_train, targets_test, masks_train, masks_test = train_test_split(
+                 all_samples,
+                 all_targets,
+                 all_masks,
+                 test_size=self.test_size,
+                 random_state=self.seed,
+                 stratify=all_targets,
+             )
+         if self.test_split_file:
+             samples_test, targets_test, masks_test = self._read_split(self.test_split_file)
+             if not self.train_split_file:
+                 samples_train, targets_train, masks_train = [], [], []
+                 for sample, target, mask in zip(all_samples, all_targets, all_masks):
+                     if sample not in samples_test:
+                         samples_train.append(sample)
+                         targets_train.append(target)
+                         masks_train.append(mask)
+
+         if self.train_split_file:
+             samples_train, targets_train, masks_train = self._read_split(self.train_split_file)
+             if not self.test_split_file:
+                 samples_test, targets_test, masks_test = [], [], []
+                 for sample, target, mask in zip(all_samples, all_targets, all_masks):
+                     if sample not in samples_train:
+                         samples_test.append(sample)
+                         targets_test.append(target)
+                         masks_test.append(mask)
+
+         if self.val_split_file:
+             if not self.test_split_file or not self.train_split_file:
+                 raise ValueError("Validation split file is specified but train or test split file is missing.")
+             samples_val, targets_val, masks_val = self._read_split(self.val_split_file)
+         else:
+             samples_train, samples_val, targets_train, targets_val, masks_train, masks_val = train_test_split(
+                 samples_train,
+                 targets_train,
+                 masks_train,
+                 test_size=self.val_size,
+                 random_state=self.seed,
+                 stratify=targets_train,
+             )
+
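+         # With exclude_good, only defect samples (target != 0) are kept in the training split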
+         if self.exclude_good:
+             samples_train = list(np.array(samples_train)[np.array(targets_train) != 0])
+             masks_train = list(np.array(masks_train)[np.array(targets_train) != 0])
+             targets_train = list(np.array(targets_train)[np.array(targets_train) != 0])
+
+         if self.num_data_class is not None:
+             samples_train_topick = []
+             targets_train_topick = []
+             masks_train_topick = []
+
+             for cl in np.unique(targets_train):
+                 idx = np.where(np.array(targets_train) == cl)[0].tolist()
+                 random.seed(self.seed)
+                 random.shuffle(idx)
+                 to_pick = idx[: self.num_data_class]
+                 for i in to_pick:
+                     samples_train_topick.append(samples_train[i])
+                     targets_train_topick.append(cl)
+                     masks_train_topick.append(masks_train[i])
+
+             samples_train = samples_train_topick
+             targets_train = targets_train_topick
+             masks_train = masks_train_topick
+
+         df_list = []
+         for split_name, samples, targets, masks in [
+             ("train", samples_train, targets_train, masks_train),
+             ("val", samples_val, targets_val, masks_val),
+             ("test", samples_test, targets_test, masks_test),
+         ]:
+             df = pd.DataFrame({"samples": samples, "targets": targets, "masks": masks})
+             df["split"] = split_name
+             df_list.append(df)
+
+         self.data = pd.concat(df_list, axis=0)
+
+     def setup(self, stage=None):
+         """Setup data module based on stages of training."""
+         if stage in ["fit", "train"]:
+             self.train_dataset = self.dataset(
+                 image_paths=self.data[self.data["split"] == "train"]["samples"].tolist(),
+                 mask_paths=self.data[self.data["split"] == "train"]["masks"].tolist(),
+                 mask_preprocess=self._preprocess_mask,
+                 labels=self.data[self.data["split"] == "train"]["targets"].tolist(),
+                 object_masks=None,
+                 transform=self.train_transform,
+                 batch_size=None,
+                 defect_transform=None,
+                 resize=None,
+             )
+             self.val_dataset = self.dataset(
+                 image_paths=self.data[self.data["split"] == "val"]["samples"].tolist(),
+                 mask_paths=self.data[self.data["split"] == "val"]["masks"].tolist(),
+                 defect_transform=None,
+                 labels=self.data[self.data["split"] == "val"]["targets"].tolist(),
+                 object_masks=None,
+                 batch_size=None,
+                 mask_preprocess=self._preprocess_mask,
+                 transform=self.test_transform,
+                 resize=None,
+             )
+         elif stage == "test":
+             self.test_dataset = self.dataset(
+                 image_paths=self.data[self.data["split"] == "test"]["samples"].tolist(),
+                 mask_paths=self.data[self.data["split"] == "test"]["masks"].tolist(),
+                 labels=self.data[self.data["split"] == "test"]["targets"].tolist(),
+                 object_masks=None,
+                 batch_size=None,
+                 mask_preprocess=self._preprocess_mask,
+                 transform=self.test_transform,
+                 resize=None,
+             )
+         elif stage == "predict":
+             pass
+         else:
+             raise ValueError(f"Unknown stage {stage}")
+
+     def train_dataloader(self) -> DataLoader:
+         """Returns the train dataloader.
+
+         Raises:
+             ValueError: If train dataset is not initialized.
+
+         Returns:
+             Train dataloader.
+         """
+         if not self.train_dataset_available:
+             raise ValueError("Train dataset is not initialized")
+
+         return DataLoader(
+             self.train_dataset,
+             batch_size=self.batch_size,
+             shuffle=True,
+             num_workers=self.num_workers,
+             drop_last=False,
+             pin_memory=True,
+             persistent_workers=self.num_workers > 0,
+         )
+
+     def val_dataloader(self) -> DataLoader:
+         """Returns the validation dataloader.
+
+         Raises:
+             ValueError: If validation dataset is not initialized.
+
+         Returns:
+             Validation dataloader.
+         """
+         if not self.val_dataset_available:
+             raise ValueError("Validation dataset is not initialized")
+
+         return DataLoader(
+             self.val_dataset,
+             batch_size=self.batch_size,
+             shuffle=False,
+             num_workers=self.num_workers,
+             drop_last=False,
+             pin_memory=True,
+             persistent_workers=self.num_workers > 0,
+         )
+
+     def test_dataloader(self) -> DataLoader:
+         """Returns the test dataloader.
+
+         Raises:
+             ValueError: If test dataset is not initialized.
+
+         Returns:
+             Test dataloader.
+         """
+         if not self.test_dataset_available:
+             raise ValueError("Test dataset is not initialized")
+
+         return DataLoader(
+             self.test_dataset,
+             batch_size=self.batch_size,
+             shuffle=False,
+             num_workers=self.num_workers,
+             drop_last=False,
+             pin_memory=True,
+             persistent_workers=self.num_workers > 0,
+         )
+
+     def predict_dataloader(self) -> DataLoader:
+         """Returns a dataloader used for predictions."""
+         return self.test_dataloader()
+
+
+ class SegmentationMulticlassDataModule(BaseDataModule):
+     """Base class for segmentation datasets with multiple classes.
+
+     Args:
+         data_path: Path to the data main folder.
+         idx_to_class: Dict mapping mask indexes to class names, i.e. {1: class_1, 2: class_2, ..., N: class_N},
+             excluding the background class, which is 0.
+         name: The name for the data module. Defaults to "multiclass_segmentation_datamodule".
+         dataset: Dataset class.
+         batch_size: Batch size. Defaults to 32.
+         test_size: The test split. Defaults to 0.3.
+         val_size: The validation split. Defaults to 0.3.
+         seed: Random generator seed. Defaults to 42.
+         num_workers: Number of workers for dataloaders. Defaults to 6.
+         train_transform: Transformations for train dataset. Defaults to None.
+         test_transform: Transformations for test dataset. Defaults to None.
+         val_transform: Transformations for validation dataset. Defaults to None.
+         train_split_file: Path to a txt file with the training samples list. Defaults to None.
+         test_split_file: Path to a txt file with the test samples list. Defaults to None.
+         val_split_file: Path to a txt file with the validation samples list. Defaults to None.
+         exclude_good: If True, exclude good samples from the dataset. Defaults to False.
+         num_data_train: Number of samples to use in the train split (shuffle the samples and pick
+             the first num_data_train). Defaults to None.
+         one_hot_encoding: If True, the labels are one-hot encoded to N channels, where N is the number of classes.
+             If False, masks are a single channel containing class indexes as values. Defaults to False.
+     """
+
+     def __init__(
+         self,
+         data_path: str,
+         idx_to_class: dict,
+         name: str = "multiclass_segmentation_datamodule",
+         dataset: type[SegmentationDatasetMulticlass] = SegmentationDatasetMulticlass,
+         batch_size: int = 32,
+         test_size: float = 0.3,
+         val_size: float = 0.3,
+         seed: int = 42,
+         num_workers: int = 6,
+         train_transform: albumentations.Compose | None = None,
+         test_transform: albumentations.Compose | None = None,
+         val_transform: albumentations.Compose | None = None,
+         train_split_file: str | None = None,
+         test_split_file: str | None = None,
+         val_split_file: str | None = None,
+         exclude_good: bool = False,
+         num_data_train: int | None = None,
+         one_hot_encoding: bool = False,
+         **kwargs: Any,
+     ):
+         super().__init__(
+             data_path=data_path,
+             name=name,
+             seed=seed,
+             batch_size=batch_size,
+             num_workers=num_workers,
+             train_transform=train_transform,
+             val_transform=val_transform,
+             test_transform=test_transform,
+             **kwargs,
+         )
+         self.test_size = test_size
+         self.val_size = val_size
+         self.exclude_good = exclude_good
+         self.train_split_file = train_split_file
+         self.test_split_file = test_split_file
+         self.val_split_file = val_split_file
+         self.dataset = dataset
+         self.idx_to_class = idx_to_class
+         self.num_data_train = num_data_train
+         self.one_hot_encoding = one_hot_encoding
+         self.train_dataset: SegmentationDatasetMulticlass
+         self.val_dataset: SegmentationDatasetMulticlass
+         self.test_dataset: SegmentationDatasetMulticlass
+
+     def _preprocess_mask(self, mask: np.ndarray) -> np.ndarray:
+         """Preprocess the mask.
+
+         Args:
+             mask: A numpy array of shape HxW whose values are 0 (background) or class indexes
+                 from self.idx_to_class.
+
+         Returns:
+             A binary numpy array of shape (len(self.idx_to_class) + 1) x H x W.
+         """
+         # For each class we must have a channel
+         multilayer_mask = np.zeros((len(self.idx_to_class) + 1, *mask.shape[:2]))
+         for idx in self.idx_to_class:
+             multilayer_mask[int(idx)] = (mask == int(idx)).astype(np.uint8)
+
+         return multilayer_mask
+
+     def _resolve_label(self, path: str) -> np.ndarray:
+         """Return a binary array of length len(self.idx_to_class) + 1 whose entries are 1 if the
+         corresponding class is present in the mask."""
+         one_hot = np.zeros([len(self.idx_to_class) + 1], np.uint8)  # add class 0
+         mask = cv2.imread(path, 0)
+         if mask.sum() == 0:
+             one_hot[0] = 1
+         else:
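+             # Mask pixel values double as class indexes, so np.unique yields exactly the classes present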
+             indices = np.unique(mask)
+             one_hot[indices] = 1
+             one_hot[0] = 0
+
+         return one_hot
+
+     def _read_folder(self, data_path: str) -> tuple[list[str], list[np.ndarray], list[str]]:
+         """Read a folder containing images and masks subfolders.
+
+         Args:
+             data_path: Path to the data folder.
+
+         Returns:
+             List of paths to the images, list of associated one-hot encoded targets and list of mask paths.
+         """
+         samples = []
+         targets = []
+         masks = []
+
+         for im in glob.glob(os.path.join(data_path, "images", "*")):
+             if os.path.basename(im).startswith("."):
+                 # Skip hidden files
+                 continue
+
+             mask_path = glob.glob(os.path.splitext(im.replace("images", "masks"))[0] + ".*")
+
+             if len(mask_path) == 0:
+                 log.debug("Mask not found: %s", os.path.basename(im))
+                 continue
+
+             if len(mask_path) > 1:
+                 raise ValueError(f"Multiple masks found for image: {os.path.basename(im)}, this is not supported")
+
+             target = self._resolve_label(mask_path[0])
+             samples.append(im)
+             targets.append(target)
+             masks.append(mask_path[0])
+
+         return samples, targets, masks
+
+     def _read_split(self, split_file: str) -> tuple[list[str], list[np.ndarray], list[str]]:
+         """Read a split file.
+
+         Args:
+             split_file: Path to the split file.
+
+         Returns:
+             List of paths to the images, list of one-hot encoded targets and list of mask paths.
+         """
+         samples, targets, masks = [], [], []
+         with open(split_file) as f:
+             split = f.read().splitlines()
+         for sample in split:
+             sample_path = os.path.join(self.data_path, sample)
+             mask_path = glob.glob(os.path.splitext(sample_path.replace("images", "masks"))[0] + ".*")
+
+             if len(mask_path) == 0:
+                 log.debug("Mask not found: %s", os.path.basename(sample_path))
+                 continue
+
+             if len(mask_path) > 1:
+                 raise ValueError(
+                     f"Multiple masks found for image: {os.path.basename(sample_path)}, this is not supported"
+                 )
+
+             target = self._resolve_label(mask_path[0])
+             samples.append(sample_path)
+             targets.append(target)
+             masks.append(mask_path[0])
+
+         return samples, targets, masks
+
+     def _prepare_data(self) -> None:
+         """Prepare data for training and testing."""
+         if not (self.train_split_file and self.test_split_file and self.val_split_file):
+             all_samples, all_targets, all_masks = self._read_folder(self.data_path)
+
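+             # Pack the (image, mask) path pairs into an (N, 1, 2) array so both stay aligned through
+             # the multilabel-stratified split below; they are unpacked again right after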
+             (
+                 samples_and_masks_train,
+                 targets_train,
+                 samples_and_masks_test,
+                 targets_test,
+             ) = iterative_train_test_split(
+                 np.expand_dims(np.array(list(zip(all_samples, all_masks))), 1),
+                 np.array(all_targets),
+                 test_size=self.test_size,
+             )
+
+             samples_train, samples_test = samples_and_masks_train[:, 0, 0], samples_and_masks_test[:, 0, 0]
+             masks_train, masks_test = samples_and_masks_train[:, 0, 1], samples_and_masks_test[:, 0, 1]
+
+         if self.test_split_file:
+             samples_test, targets_test, masks_test = self._read_split(self.test_split_file)
+             if not self.train_split_file:
+                 samples_train, targets_train, masks_train = [], [], []
+                 for sample, target, mask in zip(all_samples, all_targets, all_masks):
+                     if sample not in samples_test:
+                         samples_train.append(sample)
+                         targets_train.append(target)
+                         masks_train.append(mask)
+
+         if self.train_split_file:
+             samples_train, targets_train, masks_train = self._read_split(self.train_split_file)
+             if not self.test_split_file:
+                 samples_test, targets_test, masks_test = [], [], []
+                 for sample, target, mask in zip(all_samples, all_targets, all_masks):
+                     if sample not in samples_train:
+                         samples_test.append(sample)
+                         targets_test.append(target)
+                         masks_test.append(mask)
+
+         if self.val_split_file:
+             if not self.test_split_file or not self.train_split_file:
+                 raise ValueError("Validation split file is specified but train or test split file is missing.")
+             samples_val, targets_val, masks_val = self._read_split(self.val_split_file)
+         else:
+             samples_and_masks_train, targets_train, samples_and_masks_val, targets_val = iterative_train_test_split(
+                 np.expand_dims(np.array(list(zip(samples_train, masks_train))), 1),
+                 np.array(targets_train),
+                 test_size=self.val_size,
+             )
+             samples_train = samples_and_masks_train[:, 0, 0]
+             samples_val = samples_and_masks_val[:, 0, 0]
+             masks_train = samples_and_masks_train[:, 0, 1]
+             masks_val = samples_and_masks_val[:, 0, 1]
+
+         # Pre-ordering train and val samples for determinism
+         # They will be shuffled (with a seed) during training
+         sorting_indices_train = np.argsort(list(samples_train))
+         samples_train = [samples_train[i] for i in sorting_indices_train]
+         targets_train = [targets_train[i] for i in sorting_indices_train]
+         masks_train = [masks_train[i] for i in sorting_indices_train]
+
+         sorting_indices_val = np.argsort(samples_val)
+         samples_val = [samples_val[i] for i in sorting_indices_val]
+         targets_val = [targets_val[i] for i in sorting_indices_val]
+         masks_val = [masks_val[i] for i in sorting_indices_val]
+
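+         # targets[:, 0] == 1 marks images whose mask is empty ("good"), so keep only rows where
+         # the background flag is 0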
+         if self.exclude_good:
+             samples_train = list(np.array(samples_train)[np.array(targets_train)[:, 0] == 0])
+             masks_train = list(np.array(masks_train)[np.array(targets_train)[:, 0] == 0])
+             targets_train = list(np.array(targets_train)[np.array(targets_train)[:, 0] == 0])
+
+         if self.num_data_train is not None:
+             # Generate a random permutation
+             random_permutation = list(range(len(samples_train)))
+             random.seed(self.seed)
+             random.shuffle(random_permutation)
+
+             # Shuffle samples_train, targets_train and masks_train using the same permutation
+             samples_train = [samples_train[i] for i in random_permutation]
+             targets_train = [targets_train[i] for i in random_permutation]
+             masks_train = [masks_train[i] for i in random_permutation]
+
+             samples_train = np.array(samples_train)[: self.num_data_train]
+             targets_train = np.array(targets_train)[: self.num_data_train]
+             masks_train = np.array(masks_train)[: self.num_data_train]
+
+         df_list = []
+         for split_name, samples, targets, masks in [
+             ("train", samples_train, targets_train, masks_train),
+             ("val", samples_val, targets_val, masks_val),
+             ("test", samples_test, targets_test, masks_test),
+         ]:
+             df = pd.DataFrame({"samples": samples, "targets": list(targets), "masks": masks})
+             df["split"] = split_name
+             df_list.append(df)
+
+         self.data = pd.concat(df_list, axis=0)
+
+     def setup(self, stage=None):
+         """Setup data module based on stages of training."""
+         if stage in ["fit", "train"]:
+             train_data = self.data[self.data["split"] == "train"]
+             val_data = self.data[self.data["split"] == "val"]
+
+             self.train_dataset = self.dataset(
+                 image_paths=train_data["samples"].tolist(),
+                 mask_paths=train_data["masks"].tolist(),
+                 idx_to_class=self.idx_to_class,
+                 transform=self.train_transform,
+                 one_hot=self.one_hot_encoding,
+             )
+             self.val_dataset = self.dataset(
+                 image_paths=val_data["samples"].tolist(),
+                 mask_paths=val_data["masks"].tolist(),
+                 transform=self.val_transform,
+                 idx_to_class=self.idx_to_class,
+                 one_hot=self.one_hot_encoding,
+             )
+         elif stage == "test":
+             self.test_dataset = self.dataset(
+                 image_paths=self.data[self.data["split"] == "test"]["samples"].tolist(),
+                 mask_paths=self.data[self.data["split"] == "test"]["masks"].tolist(),
+                 transform=self.test_transform,
+                 idx_to_class=self.idx_to_class,
+                 one_hot=self.one_hot_encoding,
+             )
+         elif stage == "predict":
+             pass
+         else:
+             raise ValueError(f"Unknown stage {stage}")
+
+     def train_dataloader(self) -> DataLoader:
+         """Returns the train dataloader.
+
+         Raises:
+             ValueError: If train dataset is not initialized.
+
+         Returns:
+             Train dataloader.
+         """
+         if not self.train_dataset_available:
+             raise ValueError("Train dataset is not initialized")
+
+         return DataLoader(
+             self.train_dataset,
+             batch_size=self.batch_size,
+             shuffle=True,
+             num_workers=self.num_workers,
+             drop_last=False,
+             pin_memory=True,
+             persistent_workers=self.num_workers > 0,
+         )
+
+     def val_dataloader(self) -> DataLoader:
+         """Returns the validation dataloader.
+
+         Raises:
+             ValueError: If validation dataset is not initialized.
+
+         Returns:
+             Validation dataloader.
+         """
+         if not self.val_dataset_available:
+             raise ValueError("Validation dataset is not initialized")
+
+         return DataLoader(
+             self.val_dataset,
+             batch_size=self.batch_size,
+             shuffle=False,
+             num_workers=self.num_workers,
+             drop_last=False,
+             pin_memory=True,
+             persistent_workers=self.num_workers > 0,
+         )
+
+     def test_dataloader(self) -> DataLoader:
+         """Returns the test dataloader.
+
+         Raises:
+             ValueError: If test dataset is not initialized.
+
+         Returns:
+             Test dataloader.
+         """
+         if not self.test_dataset_available:
+             raise ValueError("Test dataset is not initialized")
+
+         return DataLoader(
+             self.test_dataset,
+             batch_size=self.batch_size,
+             shuffle=False,
+             num_workers=self.num_workers,
+             drop_last=False,
+             pin_memory=True,
+             persistent_workers=self.num_workers > 0,
+         )
+
+     def predict_dataloader(self) -> DataLoader:
+         """Returns a dataloader used for predictions."""
+         return self.test_dataloader()
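
For reference, a minimal usage sketch of the two data modules added above. The paths and the idx_to_class mapping are illustrative placeholders; the layout assumed is the images/ and masks/ subfolder structure that _read_folder expects:

    from quadra.datamodules.segmentation import (
        SegmentationDataModule,
        SegmentationMulticlassDataModule,
    )

    # Binary segmentation: per-image labels are derived from the masks (0 = empty mask, 1 = defect)
    datamodule = SegmentationDataModule(data_path="/path/to/dataset", batch_size=8)

    # Multiclass segmentation: index 0 is the background, indexes 1..N map to class names
    datamodule_mc = SegmentationMulticlassDataModule(
        data_path="/path/to/dataset",
        idx_to_class={1: "scratch", 2: "dent"},
    )

    # Under PyTorch Lightning the Trainer drives the prepare_data()/setup() hooks, after which
    # train_dataloader(), val_dataloader() and test_dataloader() return the corresponding loaders.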