quadra 0.0.1__py3-none-any.whl → 2.1.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (302) hide show
  1. hydra_plugins/quadra_searchpath_plugin.py +30 -0
  2. quadra/__init__.py +6 -0
  3. quadra/callbacks/__init__.py +0 -0
  4. quadra/callbacks/anomalib.py +289 -0
  5. quadra/callbacks/lightning.py +501 -0
  6. quadra/callbacks/mlflow.py +291 -0
  7. quadra/callbacks/scheduler.py +69 -0
  8. quadra/configs/__init__.py +0 -0
  9. quadra/configs/backbone/caformer_m36.yaml +8 -0
  10. quadra/configs/backbone/caformer_s36.yaml +8 -0
  11. quadra/configs/backbone/convnextv2_base.yaml +8 -0
  12. quadra/configs/backbone/convnextv2_femto.yaml +8 -0
  13. quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
  14. quadra/configs/backbone/dino_vitb8.yaml +12 -0
  15. quadra/configs/backbone/dino_vits8.yaml +12 -0
  16. quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
  17. quadra/configs/backbone/dinov2_vits14.yaml +12 -0
  18. quadra/configs/backbone/efficientnet_b0.yaml +8 -0
  19. quadra/configs/backbone/efficientnet_b1.yaml +8 -0
  20. quadra/configs/backbone/efficientnet_b2.yaml +8 -0
  21. quadra/configs/backbone/efficientnet_b3.yaml +8 -0
  22. quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
  23. quadra/configs/backbone/levit_128s.yaml +8 -0
  24. quadra/configs/backbone/mnasnet0_5.yaml +9 -0
  25. quadra/configs/backbone/resnet101.yaml +8 -0
  26. quadra/configs/backbone/resnet18.yaml +8 -0
  27. quadra/configs/backbone/resnet18_ssl.yaml +8 -0
  28. quadra/configs/backbone/resnet50.yaml +8 -0
  29. quadra/configs/backbone/smp.yaml +9 -0
  30. quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
  31. quadra/configs/backbone/unetr.yaml +15 -0
  32. quadra/configs/backbone/vit16_base.yaml +9 -0
  33. quadra/configs/backbone/vit16_small.yaml +9 -0
  34. quadra/configs/backbone/vit16_tiny.yaml +9 -0
  35. quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
  36. quadra/configs/callbacks/all.yaml +32 -0
  37. quadra/configs/callbacks/default.yaml +37 -0
  38. quadra/configs/callbacks/default_anomalib.yaml +67 -0
  39. quadra/configs/config.yaml +33 -0
  40. quadra/configs/core/default.yaml +11 -0
  41. quadra/configs/datamodule/base/anomaly.yaml +16 -0
  42. quadra/configs/datamodule/base/classification.yaml +21 -0
  43. quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
  44. quadra/configs/datamodule/base/segmentation.yaml +18 -0
  45. quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
  46. quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
  47. quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
  48. quadra/configs/datamodule/base/ssl.yaml +21 -0
  49. quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
  50. quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
  51. quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
  52. quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
  53. quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
  54. quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
  55. quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
  56. quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
  57. quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
  58. quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
  59. quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
  60. quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
  61. quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
  62. quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
  63. quadra/configs/experiment/base/classification/classification.yaml +73 -0
  64. quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
  65. quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
  66. quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
  67. quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
  68. quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
  69. quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
  70. quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
  71. quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
  72. quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
  73. quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
  74. quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
  75. quadra/configs/experiment/base/ssl/byol.yaml +43 -0
  76. quadra/configs/experiment/base/ssl/dino.yaml +46 -0
  77. quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
  78. quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
  79. quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
  80. quadra/configs/experiment/custom/cls.yaml +12 -0
  81. quadra/configs/experiment/default.yaml +15 -0
  82. quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
  83. quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
  84. quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
  85. quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
  86. quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
  87. quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
  88. quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
  89. quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
  90. quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
  91. quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
  92. quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
  93. quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
  94. quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
  95. quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
  96. quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
  97. quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
  98. quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
  99. quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
  100. quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
  101. quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
  102. quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
  103. quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
  104. quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
  105. quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
  106. quadra/configs/export/default.yaml +13 -0
  107. quadra/configs/hydra/anomaly_custom.yaml +15 -0
  108. quadra/configs/hydra/default.yaml +14 -0
  109. quadra/configs/inference/default.yaml +26 -0
  110. quadra/configs/logger/comet.yaml +10 -0
  111. quadra/configs/logger/csv.yaml +5 -0
  112. quadra/configs/logger/mlflow.yaml +12 -0
  113. quadra/configs/logger/tensorboard.yaml +8 -0
  114. quadra/configs/loss/asl.yaml +7 -0
  115. quadra/configs/loss/barlow.yaml +2 -0
  116. quadra/configs/loss/bce.yaml +1 -0
  117. quadra/configs/loss/byol.yaml +1 -0
  118. quadra/configs/loss/cross_entropy.yaml +1 -0
  119. quadra/configs/loss/dino.yaml +8 -0
  120. quadra/configs/loss/simclr.yaml +2 -0
  121. quadra/configs/loss/simsiam.yaml +1 -0
  122. quadra/configs/loss/smp_ce.yaml +3 -0
  123. quadra/configs/loss/smp_dice.yaml +2 -0
  124. quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
  125. quadra/configs/loss/smp_mcc.yaml +2 -0
  126. quadra/configs/loss/vicreg.yaml +5 -0
  127. quadra/configs/model/anomalib/cfa.yaml +35 -0
  128. quadra/configs/model/anomalib/cflow.yaml +30 -0
  129. quadra/configs/model/anomalib/csflow.yaml +34 -0
  130. quadra/configs/model/anomalib/dfm.yaml +19 -0
  131. quadra/configs/model/anomalib/draem.yaml +29 -0
  132. quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
  133. quadra/configs/model/anomalib/fastflow.yaml +32 -0
  134. quadra/configs/model/anomalib/padim.yaml +32 -0
  135. quadra/configs/model/anomalib/patchcore.yaml +36 -0
  136. quadra/configs/model/barlow.yaml +16 -0
  137. quadra/configs/model/byol.yaml +25 -0
  138. quadra/configs/model/classification.yaml +10 -0
  139. quadra/configs/model/dino.yaml +26 -0
  140. quadra/configs/model/logistic_regression.yaml +4 -0
  141. quadra/configs/model/multilabel_classification.yaml +9 -0
  142. quadra/configs/model/simclr.yaml +18 -0
  143. quadra/configs/model/simsiam.yaml +24 -0
  144. quadra/configs/model/smp.yaml +4 -0
  145. quadra/configs/model/smp_multiclass.yaml +4 -0
  146. quadra/configs/model/vicreg.yaml +16 -0
  147. quadra/configs/optimizer/adam.yaml +5 -0
  148. quadra/configs/optimizer/adamw.yaml +3 -0
  149. quadra/configs/optimizer/default.yaml +4 -0
  150. quadra/configs/optimizer/lars.yaml +8 -0
  151. quadra/configs/optimizer/sgd.yaml +4 -0
  152. quadra/configs/scheduler/default.yaml +5 -0
  153. quadra/configs/scheduler/rop.yaml +5 -0
  154. quadra/configs/scheduler/step.yaml +3 -0
  155. quadra/configs/scheduler/warmrestart.yaml +2 -0
  156. quadra/configs/scheduler/warmup.yaml +6 -0
  157. quadra/configs/task/anomalib/cfa.yaml +5 -0
  158. quadra/configs/task/anomalib/cflow.yaml +5 -0
  159. quadra/configs/task/anomalib/csflow.yaml +5 -0
  160. quadra/configs/task/anomalib/draem.yaml +5 -0
  161. quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
  162. quadra/configs/task/anomalib/fastflow.yaml +5 -0
  163. quadra/configs/task/anomalib/inference.yaml +3 -0
  164. quadra/configs/task/anomalib/padim.yaml +5 -0
  165. quadra/configs/task/anomalib/patchcore.yaml +5 -0
  166. quadra/configs/task/classification.yaml +6 -0
  167. quadra/configs/task/classification_evaluation.yaml +6 -0
  168. quadra/configs/task/default.yaml +1 -0
  169. quadra/configs/task/segmentation.yaml +9 -0
  170. quadra/configs/task/segmentation_evaluation.yaml +3 -0
  171. quadra/configs/task/sklearn_classification.yaml +13 -0
  172. quadra/configs/task/sklearn_classification_patch.yaml +11 -0
  173. quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
  174. quadra/configs/task/sklearn_classification_test.yaml +8 -0
  175. quadra/configs/task/ssl.yaml +2 -0
  176. quadra/configs/trainer/lightning_cpu.yaml +36 -0
  177. quadra/configs/trainer/lightning_gpu.yaml +35 -0
  178. quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
  179. quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
  180. quadra/configs/trainer/lightning_multigpu.yaml +37 -0
  181. quadra/configs/trainer/sklearn_classification.yaml +7 -0
  182. quadra/configs/transforms/byol.yaml +47 -0
  183. quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
  184. quadra/configs/transforms/default.yaml +37 -0
  185. quadra/configs/transforms/default_numpy.yaml +24 -0
  186. quadra/configs/transforms/default_resize.yaml +22 -0
  187. quadra/configs/transforms/dino.yaml +63 -0
  188. quadra/configs/transforms/linear_eval.yaml +18 -0
  189. quadra/datamodules/__init__.py +20 -0
  190. quadra/datamodules/anomaly.py +180 -0
  191. quadra/datamodules/base.py +375 -0
  192. quadra/datamodules/classification.py +1003 -0
  193. quadra/datamodules/generic/__init__.py +0 -0
  194. quadra/datamodules/generic/imagenette.py +144 -0
  195. quadra/datamodules/generic/mnist.py +81 -0
  196. quadra/datamodules/generic/mvtec.py +58 -0
  197. quadra/datamodules/generic/oxford_pet.py +163 -0
  198. quadra/datamodules/patch.py +190 -0
  199. quadra/datamodules/segmentation.py +742 -0
  200. quadra/datamodules/ssl.py +140 -0
  201. quadra/datasets/__init__.py +17 -0
  202. quadra/datasets/anomaly.py +287 -0
  203. quadra/datasets/classification.py +241 -0
  204. quadra/datasets/patch.py +138 -0
  205. quadra/datasets/segmentation.py +239 -0
  206. quadra/datasets/ssl.py +110 -0
  207. quadra/losses/__init__.py +0 -0
  208. quadra/losses/classification/__init__.py +6 -0
  209. quadra/losses/classification/asl.py +83 -0
  210. quadra/losses/classification/focal.py +320 -0
  211. quadra/losses/classification/prototypical.py +148 -0
  212. quadra/losses/ssl/__init__.py +17 -0
  213. quadra/losses/ssl/barlowtwins.py +47 -0
  214. quadra/losses/ssl/byol.py +37 -0
  215. quadra/losses/ssl/dino.py +129 -0
  216. quadra/losses/ssl/hyperspherical.py +45 -0
  217. quadra/losses/ssl/idmm.py +50 -0
  218. quadra/losses/ssl/simclr.py +67 -0
  219. quadra/losses/ssl/simsiam.py +30 -0
  220. quadra/losses/ssl/vicreg.py +76 -0
  221. quadra/main.py +46 -0
  222. quadra/metrics/__init__.py +3 -0
  223. quadra/metrics/segmentation.py +251 -0
  224. quadra/models/__init__.py +0 -0
  225. quadra/models/base.py +151 -0
  226. quadra/models/classification/__init__.py +8 -0
  227. quadra/models/classification/backbones.py +149 -0
  228. quadra/models/classification/base.py +92 -0
  229. quadra/models/evaluation.py +322 -0
  230. quadra/modules/__init__.py +0 -0
  231. quadra/modules/backbone.py +30 -0
  232. quadra/modules/base.py +312 -0
  233. quadra/modules/classification/__init__.py +3 -0
  234. quadra/modules/classification/base.py +331 -0
  235. quadra/modules/ssl/__init__.py +17 -0
  236. quadra/modules/ssl/barlowtwins.py +59 -0
  237. quadra/modules/ssl/byol.py +172 -0
  238. quadra/modules/ssl/common.py +285 -0
  239. quadra/modules/ssl/dino.py +186 -0
  240. quadra/modules/ssl/hyperspherical.py +206 -0
  241. quadra/modules/ssl/idmm.py +98 -0
  242. quadra/modules/ssl/simclr.py +73 -0
  243. quadra/modules/ssl/simsiam.py +68 -0
  244. quadra/modules/ssl/vicreg.py +67 -0
  245. quadra/optimizers/__init__.py +4 -0
  246. quadra/optimizers/lars.py +153 -0
  247. quadra/optimizers/sam.py +127 -0
  248. quadra/schedulers/__init__.py +3 -0
  249. quadra/schedulers/base.py +44 -0
  250. quadra/schedulers/warmup.py +127 -0
  251. quadra/tasks/__init__.py +24 -0
  252. quadra/tasks/anomaly.py +582 -0
  253. quadra/tasks/base.py +397 -0
  254. quadra/tasks/classification.py +1264 -0
  255. quadra/tasks/patch.py +492 -0
  256. quadra/tasks/segmentation.py +389 -0
  257. quadra/tasks/ssl.py +560 -0
  258. quadra/trainers/README.md +3 -0
  259. quadra/trainers/__init__.py +0 -0
  260. quadra/trainers/classification.py +179 -0
  261. quadra/utils/__init__.py +0 -0
  262. quadra/utils/anomaly.py +112 -0
  263. quadra/utils/classification.py +618 -0
  264. quadra/utils/deprecation.py +31 -0
  265. quadra/utils/evaluation.py +474 -0
  266. quadra/utils/export.py +579 -0
  267. quadra/utils/imaging.py +32 -0
  268. quadra/utils/logger.py +15 -0
  269. quadra/utils/mlflow.py +98 -0
  270. quadra/utils/model_manager.py +320 -0
  271. quadra/utils/models.py +524 -0
  272. quadra/utils/patch/__init__.py +15 -0
  273. quadra/utils/patch/dataset.py +1433 -0
  274. quadra/utils/patch/metrics.py +449 -0
  275. quadra/utils/patch/model.py +153 -0
  276. quadra/utils/patch/visualization.py +217 -0
  277. quadra/utils/resolver.py +42 -0
  278. quadra/utils/segmentation.py +31 -0
  279. quadra/utils/tests/__init__.py +0 -0
  280. quadra/utils/tests/fixtures/__init__.py +1 -0
  281. quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
  282. quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
  283. quadra/utils/tests/fixtures/dataset/classification.py +406 -0
  284. quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
  285. quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
  286. quadra/utils/tests/fixtures/models/__init__.py +3 -0
  287. quadra/utils/tests/fixtures/models/anomaly.py +89 -0
  288. quadra/utils/tests/fixtures/models/classification.py +45 -0
  289. quadra/utils/tests/fixtures/models/segmentation.py +33 -0
  290. quadra/utils/tests/helpers.py +70 -0
  291. quadra/utils/tests/models.py +27 -0
  292. quadra/utils/utils.py +525 -0
  293. quadra/utils/validator.py +115 -0
  294. quadra/utils/visualization.py +422 -0
  295. quadra/utils/vit_explainability.py +349 -0
  296. quadra-2.1.13.dist-info/LICENSE +201 -0
  297. quadra-2.1.13.dist-info/METADATA +386 -0
  298. quadra-2.1.13.dist-info/RECORD +300 -0
  299. {quadra-0.0.1.dist-info → quadra-2.1.13.dist-info}/WHEEL +1 -1
  300. quadra-2.1.13.dist-info/entry_points.txt +3 -0
  301. quadra-0.0.1.dist-info/METADATA +0 -14
  302. quadra-0.0.1.dist-info/RECORD +0 -4
File without changes
@@ -0,0 +1,144 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import shutil
5
+ from typing import Any
6
+
7
+ import pandas as pd
8
+ from sklearn.model_selection import train_test_split
9
+ from torchvision.datasets.utils import download_and_extract_archive
10
+
11
+ from quadra.datamodules import ClassificationDataModule, SSLDataModule
12
+ from quadra.utils.utils import get_logger
13
+
14
# Maps Imagenette WordNet synset folder names (as found in the extracted
# archive) to human-readable class names.
IMAGENETTE_LABEL_MAPPER = {
    "n01440764": "tench",
    "n02102040": "english_springer",
    "n02979186": "cassette_player",
    "n03000684": "chain_saw",
    "n03028079": "church",
    "n03394916": "french_horn",
    "n03417042": "garbage_truck",
    "n03425413": "gas_pump",
    "n03445777": "golf_ball",
    "n03888257": "parachute",
}

# Default class-name -> index mapping, indices assigned alphabetically over the
# human-readable class names.
DEFAULT_CLASS_TO_IDX = {cl: idx for idx, cl in enumerate(sorted(IMAGENETTE_LABEL_MAPPER.values()))}

# Module-level logger for this datamodule.
log = get_logger(__name__)
30
+
31
+
32
class ImagenetteClassificationDataModule(ClassificationDataModule):
    """Classification data module for the Imagenette dataset.

    Args:
        data_path: Path to the dataset.
        name: Name of the dataset.
        imagenette_version: Version of the Imagenette dataset. Can be 320 or 160 or full.
        force_download: If True, the dataset will be downloaded even if the data_path already exists. The data_path
            will be deleted and recreated.
        class_to_idx: Dictionary mapping class names to class indices.
        **kwargs: Keyword arguments for the ClassificationDataModule.
    """

    def __init__(
        self,
        data_path: str,
        name: str = "imagenette_classification_datamodule",
        imagenette_version: str = "320",
        force_download: bool = False,
        class_to_idx: dict[str, int] | None = None,
        **kwargs: Any,
    ):
        if imagenette_version not in ("320", "160", "full"):
            raise ValueError(f"imagenette_version must be one of 320, 160 or full. Got {imagenette_version} instead.")

        # The "full" archive has no suffix; the sized variants are suffixed with a dash.
        imagenette_version = "" if imagenette_version == "full" else f"-{imagenette_version}"

        self.download_url = f"https://s3.amazonaws.com/fast-ai-imageclas/imagenette2{imagenette_version}.tgz"
        self.force_download = force_download
        self.imagenette_version = imagenette_version

        super().__init__(
            data_path=data_path,
            name=name,
            test_split_file=None,
            train_split_file=None,
            val_size=None,
            class_to_idx=DEFAULT_CLASS_TO_IDX if class_to_idx is None else class_to_idx,
            **kwargs,
        )

    def download_data(self, download_url: str, force_download: bool = False) -> None:
        """Download the Imagenette dataset.

        Args:
            download_url: Dataset download url.
            force_download: If True, the dataset will be downloaded even if the data_path already exists. The data_path
                will be removed.
        """
        already_present = os.path.exists(self.data_path)
        if already_present and not force_download:
            log.info("The path %s already exists. Skipping download.", self.data_path)
            return
        if already_present:
            log.info("The path %s already exists. Removing it and downloading the dataset again.", self.data_path)
            shutil.rmtree(self.data_path)

        log.info("Downloading and extracting Imagenette dataset to %s", self.data_path)
        download_and_extract_archive(download_url, self.data_path, remove_finished=True)

    def _prepare_data(self) -> None:
        """Prepare the dataframe consumed by the data module."""
        self.download_data(download_url=self.download_url, force_download=self.force_download)
        self.data_path = os.path.join(self.data_path, f"imagenette2{self.imagenette_version}")

        train_images_and_targets, class_to_idx = self._find_images_and_targets(os.path.join(self.data_path, "train"))
        # Replace the synset folder names with human-readable class names.
        self.class_to_idx = {IMAGENETTE_LABEL_MAPPER[k]: v for k, v in class_to_idx.items()}
        idx_to_class = {index: name for name, index in self.class_to_idx.items()}

        samples_train = [sample for sample, _ in train_images_and_targets]
        targets_train = [idx_to_class[target] for _, target in train_images_and_targets]

        # Stratified split of the original train folder into train/val.
        samples_train, samples_val, targets_train, targets_val = train_test_split(
            samples_train,
            targets_train,
            test_size=self.val_size,
            random_state=self.seed,
            stratify=targets_train,
        )

        # The archive's "val" folder is used as the test split.
        test_images_and_targets, _ = self._find_images_and_targets(os.path.join(self.data_path, "val"))
        samples_test = [sample for sample, _ in test_images_and_targets]
        targets_test = [idx_to_class[target] for _, target in test_images_and_targets]

        frames = []
        for split_name, split_samples, split_targets in (
            ("train", samples_train, targets_train),
            ("val", samples_val, targets_val),
            ("test", samples_test, targets_test),
        ):
            frame = pd.DataFrame({"samples": split_samples, "targets": split_targets})
            frame["split"] = split_name
            frames.append(frame)
        self.data = pd.concat(frames, axis=0)
135
class ImagenetteSSLDataModule(ImagenetteClassificationDataModule, SSLDataModule):
    """SSL data module for the Imagenette dataset.

    Identical to the classification variant except for the default name.
    """

    def __init__(self, *args: Any, name="imagenette_ssl", **kwargs: Any):
        super().__init__(*args, name=name, **kwargs)  # type: ignore[misc]
@@ -0,0 +1,81 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import shutil
5
+ from typing import Any
6
+
7
+ import cv2
8
+ from torchvision.datasets.mnist import MNIST
9
+
10
+ from quadra.datamodules import AnomalyDataModule
11
+ from quadra.utils.utils import get_logger
12
+
13
# Module-level logger for this datamodule.
log = get_logger(__name__)
14
+
15
+
16
class MNISTAnomalyDataModule(AnomalyDataModule):
    """Standard anomaly datamodule with automatic download of the MNIST dataset."""

    def __init__(
        self, data_path: str, good_number: int, limit_data: int = 100, category: str | None = None, **kwargs: Any
    ):
        """Initialize the MNIST anomaly datamodule.

        Args:
            data_path: Path to the dataset
            good_number: Which number to use as a good class, all other numbers are considered anomalies.
            category: The category of the dataset. For mnist this is always None.
            limit_data: Limit the number of images to use for training and testing. Defaults to 100.
            **kwargs: Additional arguments to pass to the AnomalyDataModule.
        """
        super().__init__(data_path=data_path, category=None, **kwargs)
        self.good_number = good_number
        self.limit_data = limit_data

    def download_data(self) -> None:
        """Download the MNIST dataset and move images in the right folders."""
        log.info("Generating MNIST anomaly dataset for good number %s", self.good_number)

        train_set = MNIST(root=self.data_path, train=True, download=True)
        test_set = MNIST(root=self.data_path, train=False, download=True)

        # The generated folder layout lives in a dedicated subdirectory, rebuilt from scratch.
        self.data_path = os.path.join(self.data_path, "quadra_mnist_anomaly")
        if os.path.exists(self.data_path):
            shutil.rmtree(self.data_path)

        train_good_folder = os.path.join(self.data_path, "train", "good")
        test_good_folder = os.path.join(self.data_path, "test", "good")
        os.makedirs(train_good_folder, exist_ok=True)
        os.makedirs(test_good_folder, exist_ok=True)

        # Training images: at most limit_data samples of the good digit only.
        good_digits = train_set.data[train_set.targets == self.good_number].numpy()
        for index, digit in enumerate(good_digits[: self.limit_data]):
            cv2.imwrite(os.path.join(train_good_folder, f"{index}.png"), digit)

        for number in range(10):
            if number == self.good_number:
                # Good test images come from the MNIST test split.
                digits = test_set.data[test_set.targets == number].numpy()
                target_folder = test_good_folder
            else:
                # NOTE(review): anomalous test digits are taken from the MNIST *train*
                # split — confirm this is intended (training only sees the good digit,
                # so there is no leakage either way).
                target_folder = os.path.join(self.data_path, "test", str(number))
                os.makedirs(target_folder, exist_ok=True)
                digits = train_set.data[train_set.targets == number].numpy()

            for index, digit in enumerate(digits[: self.limit_data]):
                cv2.imwrite(os.path.join(target_folder, f"{number}_{index}.png"), digit)

    def _prepare_data(self) -> None:
        """Prepare the MNIST dataset."""
        self.download_data()
        return super()._prepare_data()
@@ -0,0 +1,58 @@
1
+ import os
2
+ from pathlib import Path
3
+
4
+ from torchvision.datasets.utils import download_and_extract_archive
5
+
6
+ from quadra.datamodules import AnomalyDataModule
7
+ from quadra.utils.utils import get_logger
8
+
9
# Module-level logger for this datamodule.
log = get_logger(__name__)


# Base URL hosting the official per-category MVTec AD archives.
DATASET_BASE_URL = "https://www.mydrive.ch/shares/38536/3830184030e49fe74747669442f0f282/download/"

# Download URL of the tar.xz archive for each supported MVTec category.
DATASET_URL = {
    "bottle": DATASET_BASE_URL + "420937370-1629951468/bottle.tar.xz",
    "capsule": DATASET_BASE_URL + "420937454-1629951595/capsule.tar.xz",
    "carpet": DATASET_BASE_URL + "420937484-1629951672/carpet.tar.xz",
    "grid": DATASET_BASE_URL + "420937487-1629951814/grid.tar.xz",
    "hazelnut": DATASET_BASE_URL + "420937545-1629951845/hazelnut.tar.xz",
    "leather": DATASET_BASE_URL + "420937607-1629951964/leather.tar.xz",
    "metal_nut": DATASET_BASE_URL + "420937637-1629952063/metal_nut.tar.xz",
    "pill": DATASET_BASE_URL + "420938129-1629953099/pill.tar.xz",
    "screw": DATASET_BASE_URL + "420938130-1629953152/screw.tar.xz",
    "tile": DATASET_BASE_URL + "420938133-1629953189/tile.tar.xz",
    "toothbrush": DATASET_BASE_URL + "420938134-1629953256/toothbrush.tar.xz",
    "transistor": DATASET_BASE_URL + "420938166-1629953277/transistor.tar.xz",
    "wood": DATASET_BASE_URL + "420938383-1629953354/wood.tar.xz",
    "zipper": DATASET_BASE_URL + "420938385-1629953449/zipper.tar.xz",
}
30
+
31
+
32
class MVTecDataModule(AnomalyDataModule):
    """Standard anomaly datamodule with automatic download of the MVTec dataset."""

    def __init__(self, data_path: str, category: str, **kwargs):
        """Initialize the MVTec anomaly datamodule.

        Args:
            data_path: Path to the category folder of the dataset.
            category: MVTec category to use, must be a key of DATASET_URL.
            **kwargs: Additional arguments to pass to the AnomalyDataModule.

        Raises:
            ValueError: If the category is not one of the supported MVTec categories.
        """
        if category not in DATASET_URL:
            raise ValueError(f"Unknown category {category}. Available categories are {list(DATASET_URL.keys())}")

        super().__init__(data_path=data_path, category=category, **kwargs)

    def download_data(self) -> None:
        """Download the MVTec dataset for the configured category.

        Raises:
            ValueError: If the category is None.
        """
        if self.category is None:
            raise ValueError("Category must be specified for MVTec dataset.")

        if os.path.exists(self.data_path):
            # self.data_path already points at the category folder, so log it directly.
            # (Previously the category name was joined onto it again, logging a path
            # like .../category/category that does not exist.)
            log.info("The path %s already exists. Skipping download.", self.data_path)
            return

        log.info("Downloading and extracting MVTec dataset for category %s to %s", self.category, self.data_path)
        # The archive extracts a folder named after the category, so extraction must
        # happen in the parent directory of self.data_path.
        data_path_no_category = str(Path(self.data_path).parent)
        download_and_extract_archive(DATASET_URL[self.category], data_path_no_category, remove_finished=True)

    def _prepare_data(self) -> None:
        """Prepare the MVTec dataset."""
        self.download_data()
        return super()._prepare_data()
@@ -0,0 +1,163 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from typing import Any
5
+
6
+ import albumentations
7
+ import cv2
8
+ import numpy as np
9
+ import pandas as pd
10
+ from torchvision.datasets.utils import download_and_extract_archive
11
+
12
+ from quadra.datamodules import SegmentationMulticlassDataModule
13
+ from quadra.datasets.segmentation import SegmentationDatasetMulticlass
14
+ from quadra.utils import utils
15
+
16
# Module-level logger for this datamodule.
log = utils.get_logger(__name__)
17
+
18
+
19
class OxfordPetSegmentationDataModule(SegmentationMulticlassDataModule):
    """OxfordPetSegmentationDataModule.

    Args:
        data_path: path to the oxford pet dataset
        idx_to_class: dict with correspondence between mask index and classes: {1: class_1, 2: class_2, ..., N: class_N}
            except background class which is 0.
        name: Defaults to "oxford_pet_segmentation_datamodule".
        dataset: Defaults to SegmentationDataset.
        batch_size: batch size for training. Defaults to 32.
        test_size: Defaults to 0.3.
        val_size: Defaults to 0.3.
        seed: Defaults to 42.
        num_workers: number of workers for data loading. Defaults to 6.
        train_transform: Train transform. Defaults to None.
        test_transform: Test transform. Defaults to None.
        val_transform: Validation transform. Defaults to None.
    """

    def __init__(
        self,
        data_path: str,
        idx_to_class: dict,
        name: str = "oxford_pet_segmentation_datamodule",
        dataset: type[SegmentationDatasetMulticlass] = SegmentationDatasetMulticlass,
        batch_size: int = 32,
        test_size: float = 0.3,
        val_size: float = 0.3,
        seed: int = 42,
        num_workers: int = 6,
        train_transform: albumentations.Compose | None = None,
        test_transform: albumentations.Compose | None = None,
        val_transform: albumentations.Compose | None = None,
        **kwargs: Any,
    ):
        # All arguments are forwarded unchanged to SegmentationMulticlassDataModule;
        # this subclass only adds download and dataframe-preparation logic.
        super().__init__(
            data_path=data_path,
            idx_to_class=idx_to_class,
            name=name,
            dataset=dataset,
            batch_size=batch_size,
            test_size=test_size,
            val_size=val_size,
            seed=seed,
            num_workers=num_workers,
            train_transform=train_transform,
            test_transform=test_transform,
            val_transform=val_transform,
            **kwargs,
        )

    # (url, md5) pairs for the official Oxford-IIIT Pet images and annotations archives.
    _RESOURCES = (
        ("https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz", "5c4f3ee8e5d25df40f4fd59a7f44e54c"),
        ("https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz", "95a8c909bbe2e81eed6a22bccdf3f68f"),
    )

    def _preprocess_mask(self, mask: np.ndarray) -> np.ndarray:
        """Preprocess mask function that is adapted from
        https://albumentations.ai/docs/examples/pytorch_semantic_segmentation/.

        Args:
            mask: mask to be preprocessed

        Returns:
            binarized mask
        """
        mask = mask.astype(np.float32)
        # Trimap value 2 is background; 1 (pet) and 3 (border) are foreground.
        mask[mask == 2.0] = 0.0
        mask[(mask == 1.0) | (mask == 3.0)] = 1.0
        mask = (mask > 0).astype(np.uint8)
        return mask

    def _check_exists(self, image_folder: str, annotation_folder: str) -> bool:
        """Check if the dataset is already downloaded."""
        return all(os.path.exists(folder) and os.path.isdir(folder) for folder in (image_folder, annotation_folder))

    def download_data(self) -> None:
        """Download the dataset if it is not already downloaded."""
        image_folder = os.path.join(self.data_path, "images")
        annotation_folder = os.path.join(self.data_path, "annotations")
        if not self._check_exists(image_folder, annotation_folder):
            for url, md5 in self._RESOURCES:
                download_and_extract_archive(url, download_root=self.data_path, md5=md5, remove_finished=True)
            # Some images in the archive are unreadable or have empty masks; drop those
            # pairs and re-encode the remaining images as .jpg so that _prepare_data can
            # rely on the .jpg extension when filtering the split files.
            log.info("Fixing corrupted files...")
            images_filenames = sorted(os.listdir(image_folder))
            for filename in images_filenames:
                file_wo_ext = os.path.splitext(os.path.basename(filename))[0]
                try:
                    mask = cv2.imread(os.path.join(annotation_folder, "trimaps", file_wo_ext + ".png"))
                    # cv2.imread returns None for a missing/corrupted file, in which
                    # case astype raises and the pair is removed in the except branch.
                    mask = self._preprocess_mask(mask)
                    if np.sum(mask) == 0:
                        # Mask carries no foreground: useless for segmentation, remove the pair.
                        os.remove(os.path.join(image_folder, filename))
                        os.remove(os.path.join(annotation_folder, "trimaps", file_wo_ext + ".png"))
                        log.info("Removed %s", filename)
                    else:
                        img = cv2.imread(os.path.join(image_folder, filename))
                        cv2.imwrite(os.path.join(image_folder, file_wo_ext + ".jpg"), img)
                except Exception:
                    # Unreadable image or mask: remove whichever of the pair exists.
                    ip = os.path.join(image_folder, filename)
                    mp = os.path.join(annotation_folder, "trimaps", file_wo_ext + ".png")
                    if os.path.exists(ip):
                        os.remove(ip)
                    if os.path.exists(mp):
                        os.remove(mp)
                    log.info("Removed %s", filename)

    def _prepare_data(self) -> None:
        """Prepare the data to be used by the DataModule."""
        self.download_data()

        # Official trainval split: keep only entries whose image survived the cleanup above.
        trainval_split_filepath = os.path.join(self.data_path, "annotations", "trainval.txt")
        with open(trainval_split_filepath) as f:
            split_data = f.read().strip("\n").split("\n")
        trainval_filenames = [
            x.split(" ")[0]
            for x in split_data
            if os.path.exists(os.path.join(self.data_path, "images", x.split(" ")[0] + ".jpg"))
        ]
        # Deterministic 90/10 train/val split: every 10th entry goes to validation.
        train_filenames = [x for i, x in enumerate(trainval_filenames) if i % 10 != 0]
        val_filenames = [x for i, x in enumerate(trainval_filenames) if i % 10 == 0]

        test_split_filepath = os.path.join(self.data_path, "annotations", "test.txt")
        with open(test_split_filepath) as f:
            split_data = f.read().strip("\n").split("\n")
        test_filenames = [
            x.split(" ")[0]
            for x in split_data
            if os.path.exists(os.path.join(self.data_path, "images", x.split(" ")[0] + ".jpg"))
        ]

        df_list = []
        for split_name, filenames in [
            ("train", train_filenames),
            ("val", val_filenames),
            ("test", test_filenames),
        ]:
            samples = [os.path.join(self.data_path, "images", f + ".jpg") for f in filenames]
            masks = [os.path.join(self.data_path, "annotations", "trimaps", f + ".png") for f in filenames]
            # Single-class dataset: every sample gets target 1 (foreground present).
            targets = [1] * len(filenames)

            df = pd.DataFrame({"samples": samples, "masks": masks, "targets": targets})
            df["split"] = split_name
            df_list.append(df)

        self.data = pd.concat(df_list, axis=0)
@@ -0,0 +1,190 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import os
5
+
6
+ import albumentations
7
+ import pandas as pd
8
+ from torch.utils.data import DataLoader
9
+
10
+ from quadra.datamodules.base import BaseDataModule
11
+ from quadra.datasets import ImageClassificationListDataset, PatchSklearnClassificationTrainDataset
12
+ from quadra.utils.classification import find_test_image
13
+ from quadra.utils.patch.dataset import PatchDatasetInfo, load_train_file
14
+
15
+
16
class PatchSklearnClassificationDataModule(BaseDataModule):
    """DataModule for patch classification.

    Args:
        data_path: Location of the dataset
        name: Name of the datamodule
        train_filename: Name of the file containing the list of training samples
        exclude_filter: Filter to exclude samples from the dataset
        include_filter: Filter to include samples from the dataset
        class_to_idx: Dictionary mapping class names to indices
        seed: Random seed
        batch_size: Batch size
        num_workers: Number of workers
        train_transform: Transform to apply to the training samples
        val_transform: Transform to apply to the validation samples
        test_transform: Transform to apply to the test samples
        balance_classes: If True repeat low represented classes
        class_to_skip_training: List of classes skipped during training.
    """

    def __init__(
        self,
        data_path: str,
        class_to_idx: dict,
        name: str = "patch_classification_datamodule",
        train_filename: str = "dataset.txt",
        exclude_filter: list[str] | None = None,
        include_filter: list[str] | None = None,
        seed: int = 42,
        batch_size: int = 32,
        num_workers: int = 6,
        train_transform: albumentations.Compose | None = None,
        val_transform: albumentations.Compose | None = None,
        test_transform: albumentations.Compose | None = None,
        balance_classes: bool = False,
        class_to_skip_training: list | None = None,
        **kwargs,
    ):
        super().__init__(
            data_path=data_path,
            name=name,
            seed=seed,
            num_workers=num_workers,
            batch_size=batch_size,
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            **kwargs,
        )
        self.class_to_idx = class_to_idx
        self.balance_classes = balance_classes
        self.train_filename = train_filename
        self.include_filter = include_filter
        self.exclude_filter = exclude_filter
        self.class_to_skip_training = class_to_skip_training

        # Expected on-disk layout: <data_path>/{train,val,test}.
        self.train_folder = os.path.join(self.data_path, "train")
        self.val_folder = os.path.join(self.data_path, "val")
        self.test_folder = os.path.join(self.data_path, "test")
        self.info: PatchDatasetInfo
        self.train_dataset: PatchSklearnClassificationTrainDataset
        self.val_dataset: ImageClassificationListDataset
        self.test_dataset: ImageClassificationListDataset

    @staticmethod
    def _as_split_df(samples: list, targets: list, split_name: str) -> pd.DataFrame:
        """Build a samples/targets dataframe tagged with a constant `split` column."""
        df = pd.DataFrame({"samples": samples, "targets": targets})
        df["split"] = split_name
        return df

    def _prepare_data(self):
        """Prepare data function.

        Reads the mandatory `info.json` metadata, then collects samples and
        labels for whichever of the train/val/test splits exist on disk and
        stores them in ``self.data``.

        Raises:
            FileNotFoundError: If `info.json` is missing from the dataset folder.
            ValueError: If none of the split folders/files contain data.
        """
        info_path = os.path.join(self.data_path, "info.json")
        if not os.path.isfile(info_path):
            raise FileNotFoundError("No `info.json` file found in the dataset folder")
        with open(info_path, encoding="utf-8") as f:
            self.info = PatchDatasetInfo(**json.load(f))

        split_df_list: list[pd.DataFrame] = []
        train_file_path = os.path.join(self.train_folder, self.train_filename)
        if os.path.isfile(train_file_path):
            train_samples, train_labels = load_train_file(
                train_file_path=train_file_path,
                include_filter=self.include_filter,
                exclude_filter=self.exclude_filter,
                class_to_skip=self.class_to_skip_training,
            )
            split_df_list.append(self._as_split_df(train_samples, train_labels, "train"))
        if os.path.isdir(self.val_folder):
            val_samples, val_labels = find_test_image(
                folder=self.val_folder,
                exclude_filter=self.exclude_filter,
                include_filter=self.include_filter,
                include_none_class=False,
            )
            split_df_list.append(self._as_split_df(val_samples, val_labels, "val"))
        if os.path.isdir(self.test_folder):
            # Unlike validation, test images may be unlabelled (include_none_class=True).
            test_samples, test_labels = find_test_image(
                folder=self.test_folder,
                exclude_filter=self.exclude_filter,
                include_filter=self.include_filter,
                include_none_class=True,
            )
            split_df_list.append(self._as_split_df(test_samples, test_labels, "test"))
        if len(split_df_list) == 0:
            raise ValueError("No data found in all split folders")
        self.data = pd.concat(split_df_list, axis=0)

    def _split_values(self, split: str, column: str) -> list:
        """Return the values of `column` for the rows belonging to `split`."""
        return self.data[self.data["split"] == split][column].tolist()

    def setup(self, stage: str | None = None) -> None:
        """Setup function.

        Builds the train/val datasets on "fit" and the test dataset on
        "test"/"predict"; other stages are a no-op.
        """
        if stage == "fit":
            self.train_dataset = PatchSklearnClassificationTrainDataset(
                data_path=self.data_path,
                class_to_idx=self.class_to_idx,
                samples=self._split_values("train", "samples"),
                targets=self._split_values("train", "targets"),
                transform=self.train_transform,
                balance_classes=self.balance_classes,
            )

            self.val_dataset = ImageClassificationListDataset(
                class_to_idx=self.class_to_idx,
                samples=self._split_values("val", "samples"),
                targets=self._split_values("val", "targets"),
                transform=self.val_transform,
                allow_missing_label=False,
            )

        elif stage in ["test", "predict"]:
            self.test_dataset = ImageClassificationListDataset(
                class_to_idx=self.class_to_idx,
                samples=self._split_values("test", "samples"),
                targets=self._split_values("test", "targets"),
                transform=self.test_transform,
                # Test samples are allowed to come without a ground-truth label.
                allow_missing_label=True,
            )

    def _build_dataloader(self, dataset, shuffle: bool) -> DataLoader:
        """Build a DataLoader with the datamodule's shared settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            drop_last=False,
            pin_memory=True,
        )

    def train_dataloader(self) -> DataLoader:
        """Return the train dataloader.

        Raises:
            ValueError: If no training sample is available.
        """
        if not self.train_dataset_available:
            raise ValueError("No training sample is available")
        return self._build_dataloader(self.train_dataset, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        """Return the validation dataloader.

        Raises:
            ValueError: If no validation dataset is available.
        """
        if not self.val_dataset_available:
            raise ValueError("No validation dataset is available")
        return self._build_dataloader(self.val_dataset, shuffle=False)

    def test_dataloader(self) -> DataLoader:
        """Return the test dataloader.

        Raises:
            ValueError: If no test dataset is available.
        """
        if not self.test_dataset_available:
            raise ValueError("No test dataset is available")
        return self._build_dataloader(self.test_dataset, shuffle=False)