quadra 0.0.1__py3-none-any.whl → 2.1.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (302)
  1. hydra_plugins/quadra_searchpath_plugin.py +30 -0
  2. quadra/__init__.py +6 -0
  3. quadra/callbacks/__init__.py +0 -0
  4. quadra/callbacks/anomalib.py +289 -0
  5. quadra/callbacks/lightning.py +501 -0
  6. quadra/callbacks/mlflow.py +291 -0
  7. quadra/callbacks/scheduler.py +69 -0
  8. quadra/configs/__init__.py +0 -0
  9. quadra/configs/backbone/caformer_m36.yaml +8 -0
  10. quadra/configs/backbone/caformer_s36.yaml +8 -0
  11. quadra/configs/backbone/convnextv2_base.yaml +8 -0
  12. quadra/configs/backbone/convnextv2_femto.yaml +8 -0
  13. quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
  14. quadra/configs/backbone/dino_vitb8.yaml +12 -0
  15. quadra/configs/backbone/dino_vits8.yaml +12 -0
  16. quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
  17. quadra/configs/backbone/dinov2_vits14.yaml +12 -0
  18. quadra/configs/backbone/efficientnet_b0.yaml +8 -0
  19. quadra/configs/backbone/efficientnet_b1.yaml +8 -0
  20. quadra/configs/backbone/efficientnet_b2.yaml +8 -0
  21. quadra/configs/backbone/efficientnet_b3.yaml +8 -0
  22. quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
  23. quadra/configs/backbone/levit_128s.yaml +8 -0
  24. quadra/configs/backbone/mnasnet0_5.yaml +9 -0
  25. quadra/configs/backbone/resnet101.yaml +8 -0
  26. quadra/configs/backbone/resnet18.yaml +8 -0
  27. quadra/configs/backbone/resnet18_ssl.yaml +8 -0
  28. quadra/configs/backbone/resnet50.yaml +8 -0
  29. quadra/configs/backbone/smp.yaml +9 -0
  30. quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
  31. quadra/configs/backbone/unetr.yaml +15 -0
  32. quadra/configs/backbone/vit16_base.yaml +9 -0
  33. quadra/configs/backbone/vit16_small.yaml +9 -0
  34. quadra/configs/backbone/vit16_tiny.yaml +9 -0
  35. quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
  36. quadra/configs/callbacks/all.yaml +32 -0
  37. quadra/configs/callbacks/default.yaml +37 -0
  38. quadra/configs/callbacks/default_anomalib.yaml +67 -0
  39. quadra/configs/config.yaml +33 -0
  40. quadra/configs/core/default.yaml +11 -0
  41. quadra/configs/datamodule/base/anomaly.yaml +16 -0
  42. quadra/configs/datamodule/base/classification.yaml +21 -0
  43. quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
  44. quadra/configs/datamodule/base/segmentation.yaml +18 -0
  45. quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
  46. quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
  47. quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
  48. quadra/configs/datamodule/base/ssl.yaml +21 -0
  49. quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
  50. quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
  51. quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
  52. quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
  53. quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
  54. quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
  55. quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
  56. quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
  57. quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
  58. quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
  59. quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
  60. quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
  61. quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
  62. quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
  63. quadra/configs/experiment/base/classification/classification.yaml +73 -0
  64. quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
  65. quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
  66. quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
  67. quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
  68. quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
  69. quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
  70. quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
  71. quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
  72. quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
  73. quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
  74. quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
  75. quadra/configs/experiment/base/ssl/byol.yaml +43 -0
  76. quadra/configs/experiment/base/ssl/dino.yaml +46 -0
  77. quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
  78. quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
  79. quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
  80. quadra/configs/experiment/custom/cls.yaml +12 -0
  81. quadra/configs/experiment/default.yaml +15 -0
  82. quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
  83. quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
  84. quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
  85. quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
  86. quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
  87. quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
  88. quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
  89. quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
  90. quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
  91. quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
  92. quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
  93. quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
  94. quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
  95. quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
  96. quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
  97. quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
  98. quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
  99. quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
  100. quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
  101. quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
  102. quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
  103. quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
  104. quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
  105. quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
  106. quadra/configs/export/default.yaml +13 -0
  107. quadra/configs/hydra/anomaly_custom.yaml +15 -0
  108. quadra/configs/hydra/default.yaml +14 -0
  109. quadra/configs/inference/default.yaml +26 -0
  110. quadra/configs/logger/comet.yaml +10 -0
  111. quadra/configs/logger/csv.yaml +5 -0
  112. quadra/configs/logger/mlflow.yaml +12 -0
  113. quadra/configs/logger/tensorboard.yaml +8 -0
  114. quadra/configs/loss/asl.yaml +7 -0
  115. quadra/configs/loss/barlow.yaml +2 -0
  116. quadra/configs/loss/bce.yaml +1 -0
  117. quadra/configs/loss/byol.yaml +1 -0
  118. quadra/configs/loss/cross_entropy.yaml +1 -0
  119. quadra/configs/loss/dino.yaml +8 -0
  120. quadra/configs/loss/simclr.yaml +2 -0
  121. quadra/configs/loss/simsiam.yaml +1 -0
  122. quadra/configs/loss/smp_ce.yaml +3 -0
  123. quadra/configs/loss/smp_dice.yaml +2 -0
  124. quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
  125. quadra/configs/loss/smp_mcc.yaml +2 -0
  126. quadra/configs/loss/vicreg.yaml +5 -0
  127. quadra/configs/model/anomalib/cfa.yaml +35 -0
  128. quadra/configs/model/anomalib/cflow.yaml +30 -0
  129. quadra/configs/model/anomalib/csflow.yaml +34 -0
  130. quadra/configs/model/anomalib/dfm.yaml +19 -0
  131. quadra/configs/model/anomalib/draem.yaml +29 -0
  132. quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
  133. quadra/configs/model/anomalib/fastflow.yaml +32 -0
  134. quadra/configs/model/anomalib/padim.yaml +32 -0
  135. quadra/configs/model/anomalib/patchcore.yaml +36 -0
  136. quadra/configs/model/barlow.yaml +16 -0
  137. quadra/configs/model/byol.yaml +25 -0
  138. quadra/configs/model/classification.yaml +10 -0
  139. quadra/configs/model/dino.yaml +26 -0
  140. quadra/configs/model/logistic_regression.yaml +4 -0
  141. quadra/configs/model/multilabel_classification.yaml +9 -0
  142. quadra/configs/model/simclr.yaml +18 -0
  143. quadra/configs/model/simsiam.yaml +24 -0
  144. quadra/configs/model/smp.yaml +4 -0
  145. quadra/configs/model/smp_multiclass.yaml +4 -0
  146. quadra/configs/model/vicreg.yaml +16 -0
  147. quadra/configs/optimizer/adam.yaml +5 -0
  148. quadra/configs/optimizer/adamw.yaml +3 -0
  149. quadra/configs/optimizer/default.yaml +4 -0
  150. quadra/configs/optimizer/lars.yaml +8 -0
  151. quadra/configs/optimizer/sgd.yaml +4 -0
  152. quadra/configs/scheduler/default.yaml +5 -0
  153. quadra/configs/scheduler/rop.yaml +5 -0
  154. quadra/configs/scheduler/step.yaml +3 -0
  155. quadra/configs/scheduler/warmrestart.yaml +2 -0
  156. quadra/configs/scheduler/warmup.yaml +6 -0
  157. quadra/configs/task/anomalib/cfa.yaml +5 -0
  158. quadra/configs/task/anomalib/cflow.yaml +5 -0
  159. quadra/configs/task/anomalib/csflow.yaml +5 -0
  160. quadra/configs/task/anomalib/draem.yaml +5 -0
  161. quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
  162. quadra/configs/task/anomalib/fastflow.yaml +5 -0
  163. quadra/configs/task/anomalib/inference.yaml +3 -0
  164. quadra/configs/task/anomalib/padim.yaml +5 -0
  165. quadra/configs/task/anomalib/patchcore.yaml +5 -0
  166. quadra/configs/task/classification.yaml +6 -0
  167. quadra/configs/task/classification_evaluation.yaml +6 -0
  168. quadra/configs/task/default.yaml +1 -0
  169. quadra/configs/task/segmentation.yaml +9 -0
  170. quadra/configs/task/segmentation_evaluation.yaml +3 -0
  171. quadra/configs/task/sklearn_classification.yaml +13 -0
  172. quadra/configs/task/sklearn_classification_patch.yaml +11 -0
  173. quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
  174. quadra/configs/task/sklearn_classification_test.yaml +8 -0
  175. quadra/configs/task/ssl.yaml +2 -0
  176. quadra/configs/trainer/lightning_cpu.yaml +36 -0
  177. quadra/configs/trainer/lightning_gpu.yaml +35 -0
  178. quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
  179. quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
  180. quadra/configs/trainer/lightning_multigpu.yaml +37 -0
  181. quadra/configs/trainer/sklearn_classification.yaml +7 -0
  182. quadra/configs/transforms/byol.yaml +47 -0
  183. quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
  184. quadra/configs/transforms/default.yaml +37 -0
  185. quadra/configs/transforms/default_numpy.yaml +24 -0
  186. quadra/configs/transforms/default_resize.yaml +22 -0
  187. quadra/configs/transforms/dino.yaml +63 -0
  188. quadra/configs/transforms/linear_eval.yaml +18 -0
  189. quadra/datamodules/__init__.py +20 -0
  190. quadra/datamodules/anomaly.py +180 -0
  191. quadra/datamodules/base.py +375 -0
  192. quadra/datamodules/classification.py +1003 -0
  193. quadra/datamodules/generic/__init__.py +0 -0
  194. quadra/datamodules/generic/imagenette.py +144 -0
  195. quadra/datamodules/generic/mnist.py +81 -0
  196. quadra/datamodules/generic/mvtec.py +58 -0
  197. quadra/datamodules/generic/oxford_pet.py +163 -0
  198. quadra/datamodules/patch.py +190 -0
  199. quadra/datamodules/segmentation.py +742 -0
  200. quadra/datamodules/ssl.py +140 -0
  201. quadra/datasets/__init__.py +17 -0
  202. quadra/datasets/anomaly.py +287 -0
  203. quadra/datasets/classification.py +241 -0
  204. quadra/datasets/patch.py +138 -0
  205. quadra/datasets/segmentation.py +239 -0
  206. quadra/datasets/ssl.py +110 -0
  207. quadra/losses/__init__.py +0 -0
  208. quadra/losses/classification/__init__.py +6 -0
  209. quadra/losses/classification/asl.py +83 -0
  210. quadra/losses/classification/focal.py +320 -0
  211. quadra/losses/classification/prototypical.py +148 -0
  212. quadra/losses/ssl/__init__.py +17 -0
  213. quadra/losses/ssl/barlowtwins.py +47 -0
  214. quadra/losses/ssl/byol.py +37 -0
  215. quadra/losses/ssl/dino.py +129 -0
  216. quadra/losses/ssl/hyperspherical.py +45 -0
  217. quadra/losses/ssl/idmm.py +50 -0
  218. quadra/losses/ssl/simclr.py +67 -0
  219. quadra/losses/ssl/simsiam.py +30 -0
  220. quadra/losses/ssl/vicreg.py +76 -0
  221. quadra/main.py +46 -0
  222. quadra/metrics/__init__.py +3 -0
  223. quadra/metrics/segmentation.py +251 -0
  224. quadra/models/__init__.py +0 -0
  225. quadra/models/base.py +151 -0
  226. quadra/models/classification/__init__.py +8 -0
  227. quadra/models/classification/backbones.py +149 -0
  228. quadra/models/classification/base.py +92 -0
  229. quadra/models/evaluation.py +322 -0
  230. quadra/modules/__init__.py +0 -0
  231. quadra/modules/backbone.py +30 -0
  232. quadra/modules/base.py +312 -0
  233. quadra/modules/classification/__init__.py +3 -0
  234. quadra/modules/classification/base.py +331 -0
  235. quadra/modules/ssl/__init__.py +17 -0
  236. quadra/modules/ssl/barlowtwins.py +59 -0
  237. quadra/modules/ssl/byol.py +172 -0
  238. quadra/modules/ssl/common.py +285 -0
  239. quadra/modules/ssl/dino.py +186 -0
  240. quadra/modules/ssl/hyperspherical.py +206 -0
  241. quadra/modules/ssl/idmm.py +98 -0
  242. quadra/modules/ssl/simclr.py +73 -0
  243. quadra/modules/ssl/simsiam.py +68 -0
  244. quadra/modules/ssl/vicreg.py +67 -0
  245. quadra/optimizers/__init__.py +4 -0
  246. quadra/optimizers/lars.py +153 -0
  247. quadra/optimizers/sam.py +127 -0
  248. quadra/schedulers/__init__.py +3 -0
  249. quadra/schedulers/base.py +44 -0
  250. quadra/schedulers/warmup.py +127 -0
  251. quadra/tasks/__init__.py +24 -0
  252. quadra/tasks/anomaly.py +582 -0
  253. quadra/tasks/base.py +397 -0
  254. quadra/tasks/classification.py +1264 -0
  255. quadra/tasks/patch.py +492 -0
  256. quadra/tasks/segmentation.py +389 -0
  257. quadra/tasks/ssl.py +560 -0
  258. quadra/trainers/README.md +3 -0
  259. quadra/trainers/__init__.py +0 -0
  260. quadra/trainers/classification.py +179 -0
  261. quadra/utils/__init__.py +0 -0
  262. quadra/utils/anomaly.py +112 -0
  263. quadra/utils/classification.py +618 -0
  264. quadra/utils/deprecation.py +31 -0
  265. quadra/utils/evaluation.py +474 -0
  266. quadra/utils/export.py +579 -0
  267. quadra/utils/imaging.py +32 -0
  268. quadra/utils/logger.py +15 -0
  269. quadra/utils/mlflow.py +98 -0
  270. quadra/utils/model_manager.py +320 -0
  271. quadra/utils/models.py +524 -0
  272. quadra/utils/patch/__init__.py +15 -0
  273. quadra/utils/patch/dataset.py +1433 -0
  274. quadra/utils/patch/metrics.py +449 -0
  275. quadra/utils/patch/model.py +153 -0
  276. quadra/utils/patch/visualization.py +217 -0
  277. quadra/utils/resolver.py +42 -0
  278. quadra/utils/segmentation.py +31 -0
  279. quadra/utils/tests/__init__.py +0 -0
  280. quadra/utils/tests/fixtures/__init__.py +1 -0
  281. quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
  282. quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
  283. quadra/utils/tests/fixtures/dataset/classification.py +406 -0
  284. quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
  285. quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
  286. quadra/utils/tests/fixtures/models/__init__.py +3 -0
  287. quadra/utils/tests/fixtures/models/anomaly.py +89 -0
  288. quadra/utils/tests/fixtures/models/classification.py +45 -0
  289. quadra/utils/tests/fixtures/models/segmentation.py +33 -0
  290. quadra/utils/tests/helpers.py +70 -0
  291. quadra/utils/tests/models.py +27 -0
  292. quadra/utils/utils.py +525 -0
  293. quadra/utils/validator.py +115 -0
  294. quadra/utils/visualization.py +422 -0
  295. quadra/utils/vit_explainability.py +349 -0
  296. quadra-2.1.13.dist-info/LICENSE +201 -0
  297. quadra-2.1.13.dist-info/METADATA +386 -0
  298. quadra-2.1.13.dist-info/RECORD +300 -0
  299. {quadra-0.0.1.dist-info → quadra-2.1.13.dist-info}/WHEEL +1 -1
  300. quadra-2.1.13.dist-info/entry_points.txt +3 -0
  301. quadra-0.0.1.dist-info/METADATA +0 -14
  302. quadra-0.0.1.dist-info/RECORD +0 -4
quadra/datamodules/ssl.py
@@ -0,0 +1,140 @@
+ # pylint: disable=unsubscriptable-object
+ from __future__ import annotations
+
+ from typing import Any
+
+ import numpy as np
+ import torch
+ from sklearn.model_selection import train_test_split
+ from torch.utils.data import DataLoader
+
+ from quadra.datamodules.classification import ClassificationDataModule
+ from quadra.datasets import TwoAugmentationDataset, TwoSetAugmentationDataset
+ from quadra.utils import utils
+
+ log = utils.get_logger(__name__)
+
+
+ class SSLDataModule(ClassificationDataModule):
+     """Base class for self-supervised learning data modules.
+
+     Args:
+         data_path: Path to the main data folder.
+         augmentation_dataset: Augmentation dataset
+             for the training dataset.
+         name: The name for the data module. Defaults to "ssl_datamodule".
+         split_validation: Whether to split the validation set to build a training set for the classifier. Defaults to True.
+         **kwargs: Keyword arguments forwarded to the classification data module.
+     """
+
+     def __init__(
+         self,
+         data_path: str,
+         augmentation_dataset: TwoAugmentationDataset | TwoSetAugmentationDataset,
+         name: str = "ssl_datamodule",
+         split_validation: bool = True,
+         **kwargs: Any,
+     ):
+         super().__init__(
+             data_path=data_path,
+             name=name,
+             **kwargs,
+         )
+         self.augmentation_dataset = augmentation_dataset
+         self.classifier_train_dataset: torch.utils.data.Dataset | None = None
+         self.split_validation = split_validation
+
+     def setup(self, stage: str | None = None) -> None:
+         """Set up the data module based on the stage of training."""
+         if stage == "fit":
+             self.train_dataset = self.dataset(
+                 samples=self.train_data["samples"].tolist(),
+                 targets=self.train_data["targets"].tolist(),
+                 transform=self.train_transform,
+             )
+
+             if np.unique(self.train_data["targets"]).shape[0] > 1 and not self.split_validation:
+                 self.classifier_train_dataset = self.dataset(
+                     samples=self.train_data["samples"].tolist(),
+                     targets=self.train_data["targets"].tolist(),
+                     transform=self.val_transform,
+                 )
+                 self.val_dataset = self.dataset(
+                     samples=self.val_data["samples"].tolist(),
+                     targets=self.val_data["targets"].tolist(),
+                     transform=self.val_transform,
+                 )
+             else:
+                 train_classifier_samples, val_samples, train_classifier_targets, val_targets = train_test_split(
+                     self.val_data["samples"],
+                     self.val_data["targets"],
+                     test_size=0.3,
+                     random_state=self.seed,
+                     stratify=self.val_data["targets"],
+                 )
+
+                 self.classifier_train_dataset = self.dataset(
+                     samples=train_classifier_samples,
+                     targets=train_classifier_targets,
+                     transform=self.test_transform,
+                 )
+
+                 self.val_dataset = self.dataset(
+                     samples=val_samples,
+                     targets=val_targets,
+                     transform=self.val_transform,
+                 )
+
+                 log.warning(
+                     "The training set contains only one class and cannot be used to train a classifier. To overcome "
+                     "this issue 70% of the validation set is used to train the classifier. The remaining will be used "
+                     "as standard validation. To disable this behaviour set the `split_validation` parameter to False."
+                 )
+             self._check_train_dataset_config()
+         if stage == "test":
+             self.test_dataset = self.dataset(
+                 samples=self.test_data["samples"].tolist(),
+                 targets=self.test_data["targets"].tolist(),
+                 transform=self.test_transform,
+             )
+
+     def _check_train_dataset_config(self):
+         """Check if the train dataset is configured correctly."""
+         if self.train_dataset is None:
+             raise ValueError("Train dataset is not initialized")
+         if self.augmentation_dataset is None:
+             raise ValueError("Augmentation dataset is not initialized")
+         if self.train_dataset.transform is not None:
+             log.warning("Train dataset transform is not None. It will be applied before SSL augmentations")
+
+     def train_dataloader(self) -> DataLoader:
+         """Returns the train dataloader."""
+         if not isinstance(self.train_dataset, torch.utils.data.Dataset):
+             raise ValueError("Train dataset is not a subclass of `torch.utils.data.Dataset`")
+         self.augmentation_dataset.dataset = self.train_dataset
+         loader = DataLoader(
+             self.augmentation_dataset,
+             batch_size=self.batch_size,
+             shuffle=True,
+             num_workers=self.num_workers,
+             drop_last=False,
+             pin_memory=True,
+             persistent_workers=self.num_workers > 0,
+         )
+         return loader
+
+     def classifier_train_dataloader(self) -> DataLoader:
+         """Returns the classifier train dataloader."""
+         if self.classifier_train_dataset is None:
+             raise ValueError("Classifier train dataset is not initialized")
+
+         loader = DataLoader(
+             self.classifier_train_dataset,
+             batch_size=self.batch_size,
+             shuffle=True,
+             num_workers=self.num_workers,
+             drop_last=False,
+             pin_memory=True,
+             persistent_workers=self.num_workers > 0,
+         )
+         return loader
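
A minimal usage sketch for the SSLDataModule added above (not taken from the package itself): the TwoAugmentationDataset constructor arguments, the data folder path, and the batch_size/num_workers keyword arguments forwarded to ClassificationDataModule are assumptions for illustration only.

import albumentations as alb
from albumentations.pytorch import ToTensorV2

from quadra.datamodules.ssl import SSLDataModule
from quadra.datasets import TwoAugmentationDataset

# SSL view transform, applied to every training image (illustrative choice)
ssl_transform = alb.Compose([alb.Resize(224, 224), alb.HorizontalFlip(), alb.Normalize(), ToTensorV2()])

# Assumed wrapper signature: SSLDataModule.setup() later assigns the underlying
# train dataset to `augmentation_dataset.dataset` before building the dataloader.
augmentation_dataset = TwoAugmentationDataset(dataset=None, transform=ssl_transform)

datamodule = SSLDataModule(
    data_path="data/my_unlabeled_images",  # hypothetical folder
    augmentation_dataset=augmentation_dataset,
    batch_size=32,      # assumed to be accepted by ClassificationDataModule via **kwargs
    num_workers=4,
)
datamodule.setup(stage="fit")
train_loader = datamodule.train_dataloader()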
quadra/datasets/__init__.py
@@ -0,0 +1,17 @@
+ from .anomaly import AnomalyDataset
+ from .classification import ClassificationDataset, ImageClassificationListDataset, MultilabelClassificationDataset
+ from .patch import PatchSklearnClassificationTrainDataset
+ from .segmentation import SegmentationDataset, SegmentationDatasetMulticlass
+ from .ssl import TwoAugmentationDataset, TwoSetAugmentationDataset
+
+ __all__ = [
+     "ImageClassificationListDataset",
+     "ClassificationDataset",
+     "SegmentationDataset",
+     "SegmentationDatasetMulticlass",
+     "PatchSklearnClassificationTrainDataset",
+     "MultilabelClassificationDataset",
+     "AnomalyDataset",
+     "TwoAugmentationDataset",
+     "TwoSetAugmentationDataset",
+ ]
quadra/datasets/anomaly.py
@@ -0,0 +1,287 @@
+ from __future__ import annotations
+
+ import os
+ import random
+ from pathlib import Path
+
+ import albumentations as alb
+ import cv2
+ import numpy as np
+ import pandas as pd
+ from pandas import DataFrame
+ from torch import Tensor
+ from torch.utils.data import Dataset
+
+ from quadra.utils.utils import IMAGE_EXTENSIONS
+
+
+ def create_validation_set_from_test_set(samples: DataFrame, seed: int = 0) -> DataFrame:
+     """Create a validation set from the test set.
+
+     This function creates a validation set from the test set by splitting both
+     normal and abnormal samples in two.
+
+     Args:
+         samples: Dataframe containing dataset info such as filenames, splits etc.
+         seed: Random seed to ensure reproducibility. Defaults to 0.
+     """
+     if seed > 0:
+         random.seed(seed)
+
+     # Split normal images.
+     normal_test_image_indices = samples.index[(samples.split == "test") & (samples.targets == "good")].to_list()
+     num_normal_valid_images = len(normal_test_image_indices) // 2
+
+     indices_to_sample = random.sample(population=normal_test_image_indices, k=num_normal_valid_images)
+     samples.loc[indices_to_sample, "split"] = "val"
+
+     # Split abnormal images.
+     abnormal_test_image_indices = samples.index[(samples.split == "test") & (samples.targets != "good")].to_list()
+     num_abnormal_valid_images = len(abnormal_test_image_indices) // 2
+
+     indices_to_sample = random.sample(population=abnormal_test_image_indices, k=num_abnormal_valid_images)
+     samples.loc[indices_to_sample, "split"] = "val"
+
+     return samples
+
+
+ def split_normal_images_in_train_set(samples: DataFrame, split_ratio: float = 0.1, seed: int = 0) -> DataFrame:
+     """Split normal images in the train set.
+
+     This function splits the normal images in the training set and assigns the
+     selected samples to the test set. This is particularly useful when the
+     test set does not contain any normal images.
+
+     This is important because when the test set doesn't have any normal images,
+     AUC computation fails due to having a single class.
+
+     Args:
+         samples: Dataframe containing dataset info such as filenames, splits etc.
+         split_ratio: Train-Test normal image split ratio. Defaults to 0.1.
+         seed: Random seed to ensure reproducibility. Defaults to 0.
+
+     Returns:
+         Output dataframe where part of the training set is assigned to the test set.
+     """
+     if seed > 0:
+         random.seed(seed)
+
+     normal_train_image_indices = samples.index[(samples.split == "train") & (samples.targets == "good")].to_list()
+     num_normal_train_images = len(normal_train_image_indices)
+     num_normal_valid_images = int(num_normal_train_images * split_ratio)
+
+     indices_to_split_from_train_set = random.sample(population=normal_train_image_indices, k=num_normal_valid_images)
+     samples.loc[indices_to_split_from_train_set, "split"] = "test"
+
+     return samples
+
+
+ def make_anomaly_dataset(
+     path: Path,
+     split: str | None = None,
+     split_ratio: float = 0.1,
+     seed: int = 0,
+     mask_suffix: str | None = None,
+     create_test_set_if_empty: bool = True,
+ ) -> DataFrame:
+     """Create a dataframe by parsing a folder following the MVTec data file structure.
+
+     The files are expected to follow the structure:
+         path/to/dataset/split/label/image_filename.xyz
+         path/to/dataset/ground_truth/label/mask_filename.png
+
+     Masks MUST be png images, no other format is allowed.
+     Split can be either train/val/test.
+
+     This function creates a dataframe to store the parsed information based on the following format:
+     |---|---------------|-------|---------|--------------|-----------------------------------------------|-------------|
+     |   | path          | split | targets | samples      | mask_path                                     | label_index |
+     |---|---------------|-------|---------|--------------|-----------------------------------------------|-------------|
+     | 0 | datasets/name | test  | defect  | filename.xyz | ground_truth/defect/filename{mask_suffix}.png | 1           |
+     |---|---------------|-------|---------|--------------|-----------------------------------------------|-------------|
+
+     Args:
+         path: Path to the dataset.
+         split: Dataset split (i.e., either train or test). Defaults to None.
+         split_ratio: Ratio to split normal training images and add to the
+             test set in case the test set doesn't contain any normal images.
+             Defaults to 0.1.
+         seed: Random seed to ensure reproducibility when splitting. Defaults to 0.
+         mask_suffix: String to append to the base filename to get the mask name. By default MVTec dataset masks
+             are saved as imagename_mask.png, in which case the parameter should be set to "_mask".
+         create_test_set_if_empty: If True, create a test set if the test set is empty.
+
+
+     Example:
+         The following example shows how to get training samples from the MVTec bottle category:
+
+         >>> root = Path('./MVTec')
+         >>> category = 'bottle'
+         >>> path = root / category
+         >>> path
+         PosixPath('MVTec/bottle')
+
+         >>> samples = make_anomaly_dataset(path, split='train', split_ratio=0.1, seed=0)
+         >>> samples.head()
+            path          split  label  image_path                       mask_path                                     label_index
+         0  MVTec/bottle  train  good   MVTec/bottle/train/good/105.png  MVTec/bottle/ground_truth/good/105_mask.png  0
+         1  MVTec/bottle  train  good   MVTec/bottle/train/good/017.png  MVTec/bottle/ground_truth/good/017_mask.png  0
+         2  MVTec/bottle  train  good   MVTec/bottle/train/good/137.png  MVTec/bottle/ground_truth/good/137_mask.png  0
+         3  MVTec/bottle  train  good   MVTec/bottle/train/good/152.png  MVTec/bottle/ground_truth/good/152_mask.png  0
+         4  MVTec/bottle  train  good   MVTec/bottle/train/good/109.png  MVTec/bottle/ground_truth/good/109_mask.png  0
+
+     Returns:
+         An output dataframe containing samples for the requested split (i.e., train or test).
+     """
+     samples_list = [
+         (str(path),) + filename.parts[-3:]
+         for filename in path.glob("**/*")
+         if filename.is_file()
+         and os.path.splitext(filename)[-1].lower() in IMAGE_EXTENSIONS
+         and ".ipynb_checkpoints" not in str(filename)
+     ]
+
+     if len(samples_list) == 0:
+         raise RuntimeError(f"Found 0 images in {path}")
+
+     samples_list.sort()
+
+     data = pd.DataFrame(samples_list, columns=["path", "split", "targets", "samples"])
+     data = data[data.split != "ground_truth"]
+
+     # Create mask_path column, masks MUST have png extension
+     data["mask_path"] = (
+         data.path
+         + "/ground_truth/"
+         + data.targets
+         + "/"
+         + data.samples.apply(lambda x: os.path.splitext(os.path.basename(x))[0])
+         + (f"{mask_suffix}.png" if mask_suffix is not None else ".png")
+     )
+
+     # Modify image_path column by converting to absolute path
+     data["samples"] = data.path + "/" + data.split + "/" + data.targets + "/" + data.samples
+
+     # Split the normal images in training set if test set doesn't
+     # contain any normal images. This is needed because AUC score
+     # cannot be computed based on 1-class
+     if sum((data.split == "test") & (data.targets == "good")) == 0 and create_test_set_if_empty:
+         data = split_normal_images_in_train_set(data, split_ratio, seed)
+
+     # Good images don't have mask
+     data.loc[(data.split == "test") & (data.targets == "good"), "mask_path"] = ""
+
+     # Create label index for normal (0), anomalous (1) and unknown (-1) images.
+     data.loc[data.targets == "good", "label_index"] = 0
+     data.loc[~data.targets.isin(["good", "unknown"]), "label_index"] = 1
+     data.loc[data.targets == "unknown", "label_index"] = -1
+     data.label_index = data.label_index.astype(int)
+
+     # Get the data frame for the split.
+     if split is not None and split in ["train", "val", "test"]:
+         data = data[data.split == split]
+         data = data.reset_index(drop=True)
+
+     return data
+
+
+ class AnomalyDataset(Dataset):
+     """Anomaly Dataset.
+
+     Args:
+         transform: Albumentations compose.
+         task: ``classification`` or ``segmentation``
+         samples: Pandas dataframe containing samples following the structure created by make_anomaly_dataset
+         valid_area_mask: Optional path to the mask used to keep only the valid area of the image. If None, the
+             whole image is considered valid.
+         crop_area: Optional tuple of 4 integers (x1, y1, x2, y2) to crop the image to the specified area. If None, the
+             whole image is used.
+     """
+
+     def __init__(
+         self,
+         transform: alb.Compose,
+         samples: DataFrame,
+         task: str = "segmentation",
+         valid_area_mask: str | None = None,
+         crop_area: tuple[int, int, int, int] | None = None,
+     ) -> None:
+         self.task = task
+         self.transform = transform
+
+         self.samples = samples
+         self.samples = self.samples.reset_index(drop=True)
+         self.split = self.samples.split.unique()[0]
+
+         self.crop_area = crop_area
+         self.valid_area_mask: np.ndarray | None = None
+
+         if valid_area_mask is not None:
+             if not os.path.exists(valid_area_mask):
+                 raise RuntimeError(f"Valid area mask {valid_area_mask} does not exist.")
+
+             self.valid_area_mask = cv2.imread(valid_area_mask, 0) > 0  # type: ignore[operator]
+
+     def __len__(self) -> int:
+         """Get length of the dataset."""
+         return len(self.samples)
+
+     def __getitem__(self, index: int) -> dict[str, str | Tensor]:
+         """Get dataset item for the index ``index``.
+
+         Args:
+             index: Index to get the item.
+
+         Returns:
+             During training, a dict with the image tensor and label.
+             Otherwise, a dict that also contains the image path and, for segmentation, the mask and mask path.
+         """
+         item: dict[str, str | Tensor] = {}
+
+         image_path = self.samples.samples.iloc[index]
+         image = cv2.imread(image_path)
+         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+         original_image_shape = image.shape
+         if self.valid_area_mask is not None:
+             image = image * self.valid_area_mask[:, :, np.newaxis]
+
+         if self.crop_area is not None:
+             image = image[self.crop_area[1] : self.crop_area[3], self.crop_area[0] : self.crop_area[2]]
+
+         label_index = self.samples.label_index[index]
+
+         if self.split == "train":
+             pre_processed = self.transform(image=image)
+             item = {"image": pre_processed["image"], "label": label_index}
+         elif self.split in ["val", "test"]:
+             item["image_path"] = image_path
+             item["label"] = label_index
+
+             if self.task == "segmentation":
+                 mask_path = self.samples.mask_path[index]
+
+                 # If good images have no associated mask create an empty one
+                 if label_index == 0:
+                     mask = np.zeros(shape=original_image_shape[:2])
+                 elif os.path.isfile(mask_path):
+                     mask = cv2.imread(mask_path, flags=0) / 255.0  # type: ignore[operator]
+                 else:
+                     # We need ones in the mask to compute correctly at least image level f1 score
+                     mask = np.ones(shape=original_image_shape[:2])
+
+                 if self.valid_area_mask is not None:
+                     mask = mask * self.valid_area_mask
+
+                 if self.crop_area is not None:
+                     mask = mask[self.crop_area[1] : self.crop_area[3], self.crop_area[0] : self.crop_area[2]]
+
+                 pre_processed = self.transform(image=image, mask=mask)
+
+                 item["mask_path"] = mask_path
+                 item["mask"] = pre_processed["mask"]
+             else:
+                 pre_processed = self.transform(image=image)
+
+             item["image"] = pre_processed["image"]
+         return item
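
A short, hypothetical example of the two pieces added above working together: parse an MVTec-style folder with make_anomaly_dataset and wrap the resulting dataframe in AnomalyDataset. The dataset path and the transform are placeholders, not values taken from this package.

from pathlib import Path

import albumentations as alb
from albumentations.pytorch import ToTensorV2

from quadra.datasets import AnomalyDataset
from quadra.datasets.anomaly import make_anomaly_dataset

# Folder laid out as path/split/label/image.png with masks under ground_truth/
samples = make_anomaly_dataset(
    Path("datasets/MVTec/bottle"),  # hypothetical location
    split="test",
    mask_suffix="_mask",  # MVTec masks are stored as <image_name>_mask.png
)

transform = alb.Compose([alb.Resize(256, 256), alb.Normalize(), ToTensorV2()])
dataset = AnomalyDataset(transform=transform, samples=samples, task="segmentation")

# For val/test samples the item also carries "image_path", "mask" and "mask_path"
item = dataset[0]
image, label = item["image"], item["label"]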
quadra/datasets/classification.py
@@ -0,0 +1,241 @@
+ from __future__ import annotations
+
+ import warnings
+ from collections.abc import Callable
+
+ import cv2
+ import numpy as np
+ import torch
+ from torch.utils.data import Dataset
+
+ from quadra.utils.imaging import crop_image, keep_aspect_ratio_resize
+
+
+ class ImageClassificationListDataset(Dataset):
+     """Standard classification dataset.
+
+     Args:
+         samples: List of paths to images to be read
+         targets: List of labels, one for every image
+             in samples
+         class_to_idx: Mapping from classes
+             to unique indexes.
+             Defaults to None.
+         resize: Integer specifying the size of
+             a first optional resize keeping the aspect ratio: the smaller side
+             of the image will be resized to `resize`, while the longer will be
+             resized keeping the aspect ratio.
+             Defaults to None.
+         roi: Optional ROI, with
+             (x_upper_left, y_upper_left, x_bottom_right, y_bottom_right).
+             Defaults to None.
+         transform: Optional Albumentations
+             transform.
+             Defaults to None.
+         rgb: If False, the image will be converted to grayscale
+         channel: 1 or 3. If rgb is True, then channel will be set to 3.
+         allow_missing_label: If set to False, warn the user if the dataset contains missing labels
+     """
+
+     def __init__(
+         self,
+         samples: list[str],
+         targets: list[str | int],
+         class_to_idx: dict | None = None,
+         resize: int | None = None,
+         roi: tuple[int, int, int, int] | None = None,
+         transform: Callable | None = None,
+         rgb: bool = True,
+         channel: int = 3,
+         allow_missing_label: bool | None = False,
+     ):
+         super().__init__()
+         assert len(samples) == len(
+             targets
+         ), f"Samples ({len(samples)}) and targets ({len(targets)}) must have the same length"
+         # Setting the ROI
+         self.roi = roi
+
+         # Keep-Aspect-Ratio resize
+         self.resize = resize
+
+         if not allow_missing_label and None in targets:
+             warnings.warn(
+                 (
+                     "Dataset contains empty targets but allow_missing_label is set to False, "
+                     "be careful because None labels will not work inside Dataloaders"
+                 ),
+                 UserWarning,
+                 stacklevel=2,
+             )
+
+         targets = [-1 if target is None else target for target in targets]
+         # Data
+         self.x = np.array(samples)
+         self.y = np.array(targets)
+
+         if class_to_idx is None:
+             unique_targets = np.unique(targets)
+             class_to_idx = {c: i for i, c in enumerate(unique_targets)}
+
+         self.class_to_idx = class_to_idx
+         self.idx_to_class = {v: k for k, v in class_to_idx.items()}
+         self.samples = [
+             (path, self.class_to_idx[self.y[i]] if (self.y[i] != -1 and self.y[i] != "-1") else -1)
+             for i, path in enumerate(self.x)
+         ]
+
+         self.rgb = rgb
+         self.channel = 3 if rgb else channel
+
+         self.transform = transform
+
+     def __getitem__(self, idx) -> tuple[np.ndarray, np.ndarray]:
+         path, y = self.samples[idx]
+
+         # Load image
+         x = cv2.imread(str(path))
+         if self.rgb:
+             x = cv2.cvtColor(x, cv2.COLOR_BGR2RGB)
+         else:
+             x = cv2.cvtColor(x, cv2.COLOR_BGR2GRAY)
+             x = cv2.cvtColor(x, cv2.COLOR_GRAY2RGB)
+
+         if self.channel == 1:
+             x = x[:, :, 0]
+
+         # Crop with ROI
+         if self.roi:
+             x = crop_image(x, self.roi)
+
+         # Resize keeping aspect ratio
+         if self.resize:
+             x = keep_aspect_ratio_resize(x, self.resize)
+
+         if self.transform:
+             aug = self.transform(image=x)
+             x = aug["image"]
+
+         return x, y
+
+     def __len__(self):
+         return len(self.samples)
+
+
+ class ClassificationDataset(ImageClassificationListDataset):
+     """Custom Classification Dataset.
+
+     Args:
+         samples: List of paths to images
+         targets: List of targets
+         class_to_idx: Defaults to None.
+         resize: Resize image to this size. Defaults to None.
+         roi: Region of interest. Defaults to None.
+         transform: Transform function. Defaults to None.
+         rgb: Use RGB space
+         channel: Number of channels. Defaults to 3.
+         random_padding: Random padding. Defaults to False.
+         circular_crop: Circular crop. Defaults to False.
+     """
+
+     def __init__(
+         self,
+         samples: list[str],
+         targets: list[str | int],
+         class_to_idx: dict | None = None,
+         resize: int | None = None,
+         roi: tuple[int, int, int, int] | None = None,
+         transform: Callable | None = None,
+         rgb: bool = True,
+         channel: int = 3,
+         random_padding: bool = False,
+         circular_crop: bool = False,
+     ):
+         super().__init__(samples, targets, class_to_idx, resize, roi, transform, rgb, channel)
+         if transform is None:
+             self.transform = None
+
+         self.random_padding = random_padding
+         self.circular_crop = circular_crop
+
+     def __getitem__(self, idx):
+         path, y = self.samples[idx]
+         path = str(path)
+
+         # Load image
+         x = cv2.imread(path)
+         if self.rgb:
+             x = cv2.cvtColor(x, cv2.COLOR_BGR2RGB)
+         else:
+             x = cv2.cvtColor(x, cv2.COLOR_BGR2GRAY)
+             x = cv2.cvtColor(x, cv2.COLOR_GRAY2RGB)
+
+         if self.transform is not None:
+             aug = self.transform(image=x)
+             x = aug["image"]
+
+         if self.channel == 1:
+             x = x[:1]
+
+         return x, y
+
+
+ class MultilabelClassificationDataset(torch.utils.data.Dataset):
+     """Custom MultilabelClassification Dataset.
+
+     Args:
+         samples: List of paths to images.
+         targets: Array of multiple targets per sample. The array must be a one-hot encoding.
+             It must have a shape of (n_samples, n_targets).
+         class_to_idx: Defaults to None.
+         transform: Transform function. Defaults to None.
+         rgb: Use RGB space
+     """
+
+     def __init__(
+         self,
+         samples: list[str],
+         targets: np.ndarray,
+         class_to_idx: dict | None = None,
+         transform: Callable | None = None,
+         rgb: bool = True,
+     ):
+         super().__init__()
+         assert len(samples) == len(
+             targets
+         ), f"Samples ({len(samples)}) and targets ({len(targets)}) must have the same length"
+
+         # Data
+         self.x = samples
+         self.y = targets
+
+         # Class to idx and the other way around
+         if class_to_idx is None:
+             unique_targets = targets.shape[1]
+             class_to_idx = {c: i for i, c in enumerate(range(unique_targets))}
+         self.class_to_idx = class_to_idx
+         self.idx_to_class = {v: k for k, v in class_to_idx.items()}
+         self.samples = list(zip(self.x, self.y))
+         self.rgb = rgb
+         self.transform = transform
+
+     def __getitem__(self, idx):
+         path, y = self.samples[idx]
+         path = str(path)
+
+         # Load image
+         x = cv2.imread(path)
+         if self.rgb:
+             x = cv2.cvtColor(x, cv2.COLOR_BGR2RGB)
+         else:
+             x = cv2.cvtColor(x, cv2.COLOR_BGR2GRAY)
+             x = cv2.cvtColor(x, cv2.COLOR_GRAY2RGB)
+
+         if self.transform is not None:
+             aug = self.transform(image=x)
+             x = aug["image"]
+
+         return x, torch.from_numpy(y).float()
+
+     def __len__(self):
+         return len(self.samples)
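
To round off the classification datasets above, a small illustrative snippet showing ImageClassificationListDataset feeding a standard PyTorch dataloader; the image paths and labels are placeholders, and the transform is an arbitrary choice.

import albumentations as alb
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader

from quadra.datasets import ImageClassificationListDataset

transform = alb.Compose([alb.Resize(224, 224), alb.Normalize(), ToTensorV2()])

dataset = ImageClassificationListDataset(
    samples=["images/cat_001.png", "images/dog_001.png"],  # hypothetical files
    targets=["cat", "dog"],
    transform=transform,
)
# class_to_idx is derived from the unique targets when not supplied
print(dataset.class_to_idx)  # e.g. {'cat': 0, 'dog': 1}

loader = DataLoader(dataset, batch_size=2, shuffle=True)
images, labels = next(iter(loader))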