quadra 0.0.1__py3-none-any.whl → 2.1.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (302)
  1. hydra_plugins/quadra_searchpath_plugin.py +30 -0
  2. quadra/__init__.py +6 -0
  3. quadra/callbacks/__init__.py +0 -0
  4. quadra/callbacks/anomalib.py +289 -0
  5. quadra/callbacks/lightning.py +501 -0
  6. quadra/callbacks/mlflow.py +291 -0
  7. quadra/callbacks/scheduler.py +69 -0
  8. quadra/configs/__init__.py +0 -0
  9. quadra/configs/backbone/caformer_m36.yaml +8 -0
  10. quadra/configs/backbone/caformer_s36.yaml +8 -0
  11. quadra/configs/backbone/convnextv2_base.yaml +8 -0
  12. quadra/configs/backbone/convnextv2_femto.yaml +8 -0
  13. quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
  14. quadra/configs/backbone/dino_vitb8.yaml +12 -0
  15. quadra/configs/backbone/dino_vits8.yaml +12 -0
  16. quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
  17. quadra/configs/backbone/dinov2_vits14.yaml +12 -0
  18. quadra/configs/backbone/efficientnet_b0.yaml +8 -0
  19. quadra/configs/backbone/efficientnet_b1.yaml +8 -0
  20. quadra/configs/backbone/efficientnet_b2.yaml +8 -0
  21. quadra/configs/backbone/efficientnet_b3.yaml +8 -0
  22. quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
  23. quadra/configs/backbone/levit_128s.yaml +8 -0
  24. quadra/configs/backbone/mnasnet0_5.yaml +9 -0
  25. quadra/configs/backbone/resnet101.yaml +8 -0
  26. quadra/configs/backbone/resnet18.yaml +8 -0
  27. quadra/configs/backbone/resnet18_ssl.yaml +8 -0
  28. quadra/configs/backbone/resnet50.yaml +8 -0
  29. quadra/configs/backbone/smp.yaml +9 -0
  30. quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
  31. quadra/configs/backbone/unetr.yaml +15 -0
  32. quadra/configs/backbone/vit16_base.yaml +9 -0
  33. quadra/configs/backbone/vit16_small.yaml +9 -0
  34. quadra/configs/backbone/vit16_tiny.yaml +9 -0
  35. quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
  36. quadra/configs/callbacks/all.yaml +32 -0
  37. quadra/configs/callbacks/default.yaml +37 -0
  38. quadra/configs/callbacks/default_anomalib.yaml +67 -0
  39. quadra/configs/config.yaml +33 -0
  40. quadra/configs/core/default.yaml +11 -0
  41. quadra/configs/datamodule/base/anomaly.yaml +16 -0
  42. quadra/configs/datamodule/base/classification.yaml +21 -0
  43. quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
  44. quadra/configs/datamodule/base/segmentation.yaml +18 -0
  45. quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
  46. quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
  47. quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
  48. quadra/configs/datamodule/base/ssl.yaml +21 -0
  49. quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
  50. quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
  51. quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
  52. quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
  53. quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
  54. quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
  55. quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
  56. quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
  57. quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
  58. quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
  59. quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
  60. quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
  61. quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
  62. quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
  63. quadra/configs/experiment/base/classification/classification.yaml +73 -0
  64. quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
  65. quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
  66. quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
  67. quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
  68. quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
  69. quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
  70. quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
  71. quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
  72. quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
  73. quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
  74. quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
  75. quadra/configs/experiment/base/ssl/byol.yaml +43 -0
  76. quadra/configs/experiment/base/ssl/dino.yaml +46 -0
  77. quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
  78. quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
  79. quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
  80. quadra/configs/experiment/custom/cls.yaml +12 -0
  81. quadra/configs/experiment/default.yaml +15 -0
  82. quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
  83. quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
  84. quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
  85. quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
  86. quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
  87. quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
  88. quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
  89. quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
  90. quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
  91. quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
  92. quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
  93. quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
  94. quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
  95. quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
  96. quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
  97. quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
  98. quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
  99. quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
  100. quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
  101. quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
  102. quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
  103. quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
  104. quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
  105. quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
  106. quadra/configs/export/default.yaml +13 -0
  107. quadra/configs/hydra/anomaly_custom.yaml +15 -0
  108. quadra/configs/hydra/default.yaml +14 -0
  109. quadra/configs/inference/default.yaml +26 -0
  110. quadra/configs/logger/comet.yaml +10 -0
  111. quadra/configs/logger/csv.yaml +5 -0
  112. quadra/configs/logger/mlflow.yaml +12 -0
  113. quadra/configs/logger/tensorboard.yaml +8 -0
  114. quadra/configs/loss/asl.yaml +7 -0
  115. quadra/configs/loss/barlow.yaml +2 -0
  116. quadra/configs/loss/bce.yaml +1 -0
  117. quadra/configs/loss/byol.yaml +1 -0
  118. quadra/configs/loss/cross_entropy.yaml +1 -0
  119. quadra/configs/loss/dino.yaml +8 -0
  120. quadra/configs/loss/simclr.yaml +2 -0
  121. quadra/configs/loss/simsiam.yaml +1 -0
  122. quadra/configs/loss/smp_ce.yaml +3 -0
  123. quadra/configs/loss/smp_dice.yaml +2 -0
  124. quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
  125. quadra/configs/loss/smp_mcc.yaml +2 -0
  126. quadra/configs/loss/vicreg.yaml +5 -0
  127. quadra/configs/model/anomalib/cfa.yaml +35 -0
  128. quadra/configs/model/anomalib/cflow.yaml +30 -0
  129. quadra/configs/model/anomalib/csflow.yaml +34 -0
  130. quadra/configs/model/anomalib/dfm.yaml +19 -0
  131. quadra/configs/model/anomalib/draem.yaml +29 -0
  132. quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
  133. quadra/configs/model/anomalib/fastflow.yaml +32 -0
  134. quadra/configs/model/anomalib/padim.yaml +32 -0
  135. quadra/configs/model/anomalib/patchcore.yaml +36 -0
  136. quadra/configs/model/barlow.yaml +16 -0
  137. quadra/configs/model/byol.yaml +25 -0
  138. quadra/configs/model/classification.yaml +10 -0
  139. quadra/configs/model/dino.yaml +26 -0
  140. quadra/configs/model/logistic_regression.yaml +4 -0
  141. quadra/configs/model/multilabel_classification.yaml +9 -0
  142. quadra/configs/model/simclr.yaml +18 -0
  143. quadra/configs/model/simsiam.yaml +24 -0
  144. quadra/configs/model/smp.yaml +4 -0
  145. quadra/configs/model/smp_multiclass.yaml +4 -0
  146. quadra/configs/model/vicreg.yaml +16 -0
  147. quadra/configs/optimizer/adam.yaml +5 -0
  148. quadra/configs/optimizer/adamw.yaml +3 -0
  149. quadra/configs/optimizer/default.yaml +4 -0
  150. quadra/configs/optimizer/lars.yaml +8 -0
  151. quadra/configs/optimizer/sgd.yaml +4 -0
  152. quadra/configs/scheduler/default.yaml +5 -0
  153. quadra/configs/scheduler/rop.yaml +5 -0
  154. quadra/configs/scheduler/step.yaml +3 -0
  155. quadra/configs/scheduler/warmrestart.yaml +2 -0
  156. quadra/configs/scheduler/warmup.yaml +6 -0
  157. quadra/configs/task/anomalib/cfa.yaml +5 -0
  158. quadra/configs/task/anomalib/cflow.yaml +5 -0
  159. quadra/configs/task/anomalib/csflow.yaml +5 -0
  160. quadra/configs/task/anomalib/draem.yaml +5 -0
  161. quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
  162. quadra/configs/task/anomalib/fastflow.yaml +5 -0
  163. quadra/configs/task/anomalib/inference.yaml +3 -0
  164. quadra/configs/task/anomalib/padim.yaml +5 -0
  165. quadra/configs/task/anomalib/patchcore.yaml +5 -0
  166. quadra/configs/task/classification.yaml +6 -0
  167. quadra/configs/task/classification_evaluation.yaml +6 -0
  168. quadra/configs/task/default.yaml +1 -0
  169. quadra/configs/task/segmentation.yaml +9 -0
  170. quadra/configs/task/segmentation_evaluation.yaml +3 -0
  171. quadra/configs/task/sklearn_classification.yaml +13 -0
  172. quadra/configs/task/sklearn_classification_patch.yaml +11 -0
  173. quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
  174. quadra/configs/task/sklearn_classification_test.yaml +8 -0
  175. quadra/configs/task/ssl.yaml +2 -0
  176. quadra/configs/trainer/lightning_cpu.yaml +36 -0
  177. quadra/configs/trainer/lightning_gpu.yaml +35 -0
  178. quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
  179. quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
  180. quadra/configs/trainer/lightning_multigpu.yaml +37 -0
  181. quadra/configs/trainer/sklearn_classification.yaml +7 -0
  182. quadra/configs/transforms/byol.yaml +47 -0
  183. quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
  184. quadra/configs/transforms/default.yaml +37 -0
  185. quadra/configs/transforms/default_numpy.yaml +24 -0
  186. quadra/configs/transforms/default_resize.yaml +22 -0
  187. quadra/configs/transforms/dino.yaml +63 -0
  188. quadra/configs/transforms/linear_eval.yaml +18 -0
  189. quadra/datamodules/__init__.py +20 -0
  190. quadra/datamodules/anomaly.py +180 -0
  191. quadra/datamodules/base.py +375 -0
  192. quadra/datamodules/classification.py +1003 -0
  193. quadra/datamodules/generic/__init__.py +0 -0
  194. quadra/datamodules/generic/imagenette.py +144 -0
  195. quadra/datamodules/generic/mnist.py +81 -0
  196. quadra/datamodules/generic/mvtec.py +58 -0
  197. quadra/datamodules/generic/oxford_pet.py +163 -0
  198. quadra/datamodules/patch.py +190 -0
  199. quadra/datamodules/segmentation.py +742 -0
  200. quadra/datamodules/ssl.py +140 -0
  201. quadra/datasets/__init__.py +17 -0
  202. quadra/datasets/anomaly.py +287 -0
  203. quadra/datasets/classification.py +241 -0
  204. quadra/datasets/patch.py +138 -0
  205. quadra/datasets/segmentation.py +239 -0
  206. quadra/datasets/ssl.py +110 -0
  207. quadra/losses/__init__.py +0 -0
  208. quadra/losses/classification/__init__.py +6 -0
  209. quadra/losses/classification/asl.py +83 -0
  210. quadra/losses/classification/focal.py +320 -0
  211. quadra/losses/classification/prototypical.py +148 -0
  212. quadra/losses/ssl/__init__.py +17 -0
  213. quadra/losses/ssl/barlowtwins.py +47 -0
  214. quadra/losses/ssl/byol.py +37 -0
  215. quadra/losses/ssl/dino.py +129 -0
  216. quadra/losses/ssl/hyperspherical.py +45 -0
  217. quadra/losses/ssl/idmm.py +50 -0
  218. quadra/losses/ssl/simclr.py +67 -0
  219. quadra/losses/ssl/simsiam.py +30 -0
  220. quadra/losses/ssl/vicreg.py +76 -0
  221. quadra/main.py +46 -0
  222. quadra/metrics/__init__.py +3 -0
  223. quadra/metrics/segmentation.py +251 -0
  224. quadra/models/__init__.py +0 -0
  225. quadra/models/base.py +151 -0
  226. quadra/models/classification/__init__.py +8 -0
  227. quadra/models/classification/backbones.py +149 -0
  228. quadra/models/classification/base.py +92 -0
  229. quadra/models/evaluation.py +322 -0
  230. quadra/modules/__init__.py +0 -0
  231. quadra/modules/backbone.py +30 -0
  232. quadra/modules/base.py +312 -0
  233. quadra/modules/classification/__init__.py +3 -0
  234. quadra/modules/classification/base.py +331 -0
  235. quadra/modules/ssl/__init__.py +17 -0
  236. quadra/modules/ssl/barlowtwins.py +59 -0
  237. quadra/modules/ssl/byol.py +172 -0
  238. quadra/modules/ssl/common.py +285 -0
  239. quadra/modules/ssl/dino.py +186 -0
  240. quadra/modules/ssl/hyperspherical.py +206 -0
  241. quadra/modules/ssl/idmm.py +98 -0
  242. quadra/modules/ssl/simclr.py +73 -0
  243. quadra/modules/ssl/simsiam.py +68 -0
  244. quadra/modules/ssl/vicreg.py +67 -0
  245. quadra/optimizers/__init__.py +4 -0
  246. quadra/optimizers/lars.py +153 -0
  247. quadra/optimizers/sam.py +127 -0
  248. quadra/schedulers/__init__.py +3 -0
  249. quadra/schedulers/base.py +44 -0
  250. quadra/schedulers/warmup.py +127 -0
  251. quadra/tasks/__init__.py +24 -0
  252. quadra/tasks/anomaly.py +582 -0
  253. quadra/tasks/base.py +397 -0
  254. quadra/tasks/classification.py +1264 -0
  255. quadra/tasks/patch.py +492 -0
  256. quadra/tasks/segmentation.py +389 -0
  257. quadra/tasks/ssl.py +560 -0
  258. quadra/trainers/README.md +3 -0
  259. quadra/trainers/__init__.py +0 -0
  260. quadra/trainers/classification.py +179 -0
  261. quadra/utils/__init__.py +0 -0
  262. quadra/utils/anomaly.py +112 -0
  263. quadra/utils/classification.py +618 -0
  264. quadra/utils/deprecation.py +31 -0
  265. quadra/utils/evaluation.py +474 -0
  266. quadra/utils/export.py +579 -0
  267. quadra/utils/imaging.py +32 -0
  268. quadra/utils/logger.py +15 -0
  269. quadra/utils/mlflow.py +98 -0
  270. quadra/utils/model_manager.py +320 -0
  271. quadra/utils/models.py +524 -0
  272. quadra/utils/patch/__init__.py +15 -0
  273. quadra/utils/patch/dataset.py +1433 -0
  274. quadra/utils/patch/metrics.py +449 -0
  275. quadra/utils/patch/model.py +153 -0
  276. quadra/utils/patch/visualization.py +217 -0
  277. quadra/utils/resolver.py +42 -0
  278. quadra/utils/segmentation.py +31 -0
  279. quadra/utils/tests/__init__.py +0 -0
  280. quadra/utils/tests/fixtures/__init__.py +1 -0
  281. quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
  282. quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
  283. quadra/utils/tests/fixtures/dataset/classification.py +406 -0
  284. quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
  285. quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
  286. quadra/utils/tests/fixtures/models/__init__.py +3 -0
  287. quadra/utils/tests/fixtures/models/anomaly.py +89 -0
  288. quadra/utils/tests/fixtures/models/classification.py +45 -0
  289. quadra/utils/tests/fixtures/models/segmentation.py +33 -0
  290. quadra/utils/tests/helpers.py +70 -0
  291. quadra/utils/tests/models.py +27 -0
  292. quadra/utils/utils.py +525 -0
  293. quadra/utils/validator.py +115 -0
  294. quadra/utils/visualization.py +422 -0
  295. quadra/utils/vit_explainability.py +349 -0
  296. quadra-2.1.13.dist-info/LICENSE +201 -0
  297. quadra-2.1.13.dist-info/METADATA +386 -0
  298. quadra-2.1.13.dist-info/RECORD +300 -0
  299. {quadra-0.0.1.dist-info → quadra-2.1.13.dist-info}/WHEEL +1 -1
  300. quadra-2.1.13.dist-info/entry_points.txt +3 -0
  301. quadra-0.0.1.dist-info/METADATA +0 -14
  302. quadra-0.0.1.dist-info/RECORD +0 -4
@@ -0,0 +1,406 @@
1
+ from __future__ import annotations
2
+
3
+ import glob
4
+ import os
5
+ import random
6
+ import shutil
7
+ from dataclasses import dataclass
8
+ from pathlib import Path
9
+ from typing import Any
10
+
11
+ import cv2
12
+ import numpy as np
13
+ import pytest
14
+
15
+ from quadra.utils.patch import generate_patch_dataset, get_image_mask_association
16
+ from quadra.utils.tests.helpers import _random_image
17
+
18
+
19
@dataclass
class ClassificationDatasetArguments:
    """Description of a synthetic single-label classification dataset.

    Args:
        samples: how many images to generate for each class, one entry per class
        classes: optional class names; when provided it must have the same length as ``samples``
        val_size: fraction of the images that goes into the validation split
        test_size: fraction of the images that goes into the test split
    """

    # One image count per class; the list position identifies the class.
    samples: list[int]
    # Class names; when left as None the dataset builders use integer indices instead.
    classes: list[str] | None = None
    # Split fractions. The builders treat None as 0, and write no split files when both are None.
    val_size: float | None = None
    test_size: float | None = None
34
+
35
+
36
@dataclass
class ClassificationMultilabelDatasetArguments:
    """Multilabel classification dataset arguments.

    Args:
        samples: number of samples per class
        classes: class names, if set it must be the same length as samples
        val_size: validation set size (fraction of all generated samples)
        test_size: test set size (fraction of all generated samples)
        percentage_other_classes: probability of adding each other class to the labels of each sample
    """

    samples: list[int]
    classes: list[str] | None = None
    val_size: float | None = None
    test_size: float | None = None
    # Independent per-class probability; 0.0 (the default) yields single-label samples only.
    percentage_other_classes: float | None = 0.0
53
+
54
+
55
@dataclass
class ClassificationPatchDatasetArguments:
    """Description of a synthetic dataset that is turned into a patch classification dataset.

    Args:
        samples: how many images to generate for each class, one entry per class
        overlap: overlap between neighbouring patches
        patch_size: size of each patch
        patch_number: number of patches
        classes: optional class names; when provided it must have the same length as ``samples``
        val_size: fraction of the data used for validation
        test_size: fraction of the data used for testing
        annotated_good: class names whose annotations are regarded as "good" (e.g. ["good"])
    """

    # Required: per-class image counts and patch overlap.
    samples: list[int]
    overlap: float
    # Patch geometry, forwarded as-is to generate_patch_dataset (presumably only one of the two is set).
    patch_size: tuple[int, int] | None = None
    patch_number: tuple[int, int] | None = None
    # Optional class names; integer indices are used by the builders when None.
    classes: list[str] | None = None
    # Split fractions, forwarded to generate_patch_dataset.
    val_size: float | None = 0.0
    test_size: float | None = 0.0
    # Forwarded to generate_patch_dataset.
    annotated_good: list[str] | None = None
78
+
79
+
80
def _build_classification_dataset(
    tmp_path: Path, dataset_arguments: ClassificationDatasetArguments
) -> tuple[str, ClassificationDatasetArguments]:
    """Create a folder-per-class classification dataset below ``tmp_path``.

    Each class gets its own sub-folder (``classification_dataset/<class>``) filled with PNG images produced by
    ``_random_image()``. When at least one of ``val_size`` / ``test_size`` is set, the relative image paths
    (``<class>/<file>.png``) are shuffled with ``np.random.permutation`` and written to ``train.txt``,
    ``val.txt`` and ``test.txt`` in the dataset root; a split size left as None counts as 0.

    Args:
        tmp_path: directory in which the ``classification_dataset`` folder is created
        dataset_arguments: description of the dataset to build

    Returns:
        The dataset root as a string, together with the unchanged ``dataset_arguments``.
    """
    root = tmp_path / "classification_dataset"
    root.mkdir()

    # Without explicit names the classes are identified by their integer index.
    class_names = dataset_arguments.classes or range(len(dataset_arguments.samples))

    for class_name, n_images in zip(class_names, dataset_arguments.samples):
        class_dir = root / str(class_name)
        class_dir.mkdir()
        for index in range(n_images):
            picture = _random_image()
            cv2.imwrite(str(class_dir / f"{class_name}_{index}.png"), picture)

    requested_val = dataset_arguments.val_size
    requested_test = dataset_arguments.test_size
    if requested_val is None and requested_test is None:
        # No split requested: only the class folders are produced.
        return str(root), dataset_arguments

    # "**" without recursive=True behaves like "*", i.e. it matches exactly one folder level (the class dirs).
    found = glob.glob(os.path.join(str(root), "**", "*.png"))
    relative_paths = [f"{os.path.basename(os.path.dirname(path))}/{os.path.basename(path)}" for path in found]

    val_size = 0 if requested_val is None else requested_val
    test_size = 0 if requested_test is None else requested_test
    train_size = 1 - val_size - test_size

    total = len(relative_paths)
    boundaries = [int(train_size * total), int((train_size + val_size) * total)]
    # pylint: disable=unbalanced-tuple-unpacking
    splits = np.split(np.random.permutation(relative_paths), boundaries)

    for file_name, entries in zip(("train.txt", "val.txt", "test.txt"), splits):
        with open(root / file_name, "w") as handle:
            handle.write("\n".join(entries))

    return str(root), dataset_arguments
129
+
130
+
131
@pytest.fixture
def classification_dataset(
    tmp_path: Path, dataset_arguments: ClassificationDatasetArguments
) -> tuple[str, ClassificationDatasetArguments]:
    """Fixture that builds a folder-per-class classification dataset from ``dataset_arguments``.

    ``dataset_arguments`` must be resolvable by pytest for the requesting test (presumably through
    parametrization). ``train.txt``, ``val.txt`` and ``test.txt`` are written only when ``val_size`` or
    ``test_size`` is set. ``tmp_path`` is deleted once the test is done.

    Args:
        tmp_path: path to temporary directory
        dataset_arguments: dataset arguments

    Yields:
        Path of the generated dataset and the dataset arguments used to build it
    """
    dataset = _build_classification_dataset(tmp_path, dataset_arguments)
    yield dataset
    # Teardown: remove everything the builder created.
    if tmp_path.exists():
        shutil.rmtree(tmp_path)
148
+
149
+
150
@pytest.fixture(
    params=[
        ClassificationDatasetArguments(
            samples=[10, 10],
            classes=["class_1", "class_2"],
            val_size=0.1,
            test_size=0.1,
        )
    ]
)
def base_classification_dataset(tmp_path: Path, request: Any) -> tuple[str, ClassificationDatasetArguments]:
    """Fixture building the default classification dataset.

    Settings:
        - 2 classes named class_1 and class_2
        - 10 images per class (images come from ``_random_image()`` with its default arguments)
        - 10% of the images in the validation split and 10% in the test split

    Args:
        tmp_path: path to temporary directory
        request: pytest request carrying the dataset arguments in ``request.param``

    Yields:
        Path of the generated dataset and the dataset arguments used to build it
    """
    dataset = _build_classification_dataset(tmp_path, request.param)
    yield dataset
    # Teardown: remove everything the builder created.
    if tmp_path.exists():
        shutil.rmtree(tmp_path)
173
+
174
+
175
+ def _build_multilabel_classification_dataset(
176
+ tmp_path: Path, dataset_arguments: ClassificationMultilabelDatasetArguments
177
+ ) -> tuple[str, ClassificationMultilabelDatasetArguments]:
178
+ """Generate a multilabel classification dataset.
179
+ Generates a samples.txt file in the dataset directory containing the path to the image and the corresponding
180
+ classes. If val_size or test_size are set, it will generate a train.txt, val.txt and test.txt file in the
181
+ dataset directory. By default generated images are 10x10 pixels.
182
+
183
+ Args:
184
+ tmp_path: path to temporary directory
185
+ dataset_arguments: dataset arguments
186
+
187
+ Returns:
188
+ Tuple containing path to created dataset and dataset arguments
189
+ """
190
+ classification_dataset_path = tmp_path / "multilabel_classification_dataset"
191
+ images_path = classification_dataset_path / "images"
192
+ classification_dataset_path.mkdir()
193
+ images_path.mkdir()
194
+
195
+ classes = dataset_arguments.classes if dataset_arguments.classes else range(len(dataset_arguments.samples))
196
+ percentage_other_classes = dataset_arguments.percentage_other_classes
197
+
198
+ generated_samples = []
199
+ counter = 0
200
+ for class_name, samples in zip(classes, dataset_arguments.samples):
201
+ for _ in range(samples):
202
+ image = _random_image()
203
+ image_path = images_path / f"{counter}.png"
204
+ counter += 1
205
+ cv2.imwrite(str(image_path), image)
206
+ targets = [class_name]
207
+ targets = targets + [
208
+ cl_name for cl_name in classes if cl_name != class_name and random.random() < percentage_other_classes
209
+ ]
210
+ generated_samples.append(f"images/{image_path.name},{','.join(targets)}")
211
+
212
+ with open(classification_dataset_path / "samples.txt", "w") as f:
213
+ f.write("\n".join(generated_samples))
214
+
215
+ if dataset_arguments.val_size is not None or dataset_arguments.test_size is not None:
216
+ val_size = dataset_arguments.val_size if dataset_arguments.val_size is not None else 0
217
+ test_size = dataset_arguments.test_size if dataset_arguments.test_size is not None else 0
218
+ train_size = 1 - val_size - test_size
219
+
220
+ # pylint: disable=unbalanced-tuple-unpacking
221
+ train_images, val_images, test_images = np.split(
222
+ np.random.permutation(generated_samples),
223
+ [int(train_size * len(generated_samples)), int((train_size + val_size) * len(generated_samples))],
224
+ )
225
+
226
+ with open(classification_dataset_path / "train.txt", "w") as f:
227
+ f.write("\n".join(train_images))
228
+
229
+ with open(classification_dataset_path / "val.txt", "w") as f:
230
+ f.write("\n".join(val_images))
231
+
232
+ with open(classification_dataset_path / "test.txt", "w") as f:
233
+ f.write("\n".join(test_images))
234
+
235
+ return str(classification_dataset_path), dataset_arguments
236
+
237
+
238
@pytest.fixture
def multilabel_classification_dataset(
    tmp_path: Path, dataset_arguments: ClassificationMultilabelDatasetArguments
) -> tuple[str, ClassificationMultilabelDatasetArguments]:
    """Fixture to dynamically generate a multilabel classification dataset.

    Generates a samples.txt file in the dataset directory containing the path to each image followed by its
    comma-separated classes. If val_size or test_size are set, it also generates train.txt, val.txt and test.txt
    files in the dataset directory. Images come from ``_random_image()`` with its default arguments.
    ``tmp_path`` is removed after the test finishes.

    Args:
        tmp_path: path to temporary directory
        dataset_arguments: dataset arguments

    Yields:
        Tuple containing path to created dataset and dataset arguments
    """
    yield _build_multilabel_classification_dataset(tmp_path, dataset_arguments)
    if tmp_path.exists():
        shutil.rmtree(tmp_path)
257
+
258
+
259
@pytest.fixture(
    params=[
        ClassificationMultilabelDatasetArguments(
            samples=[10, 10, 10],
            classes=["class_1", "class_2", "class_3"],
            val_size=0.1,
            test_size=0.1,
            percentage_other_classes=0.3,
        )
    ]
)
def base_multilabel_classification_dataset(
    tmp_path: Path, request: Any
) -> tuple[str, ClassificationMultilabelDatasetArguments]:
    """Fixture building the default multilabel classification dataset.

    Settings:
        - 3 classes named class_1, class_2 and class_3
        - 10 images per class (images come from ``_random_image()`` with its default arguments)
        - 10% of the samples in the validation split and 10% in the test split
        - every other class has a 30% chance of being added to each sample's labels

    Args:
        tmp_path: path to temporary directory
        request: pytest request carrying the dataset arguments in ``request.param``

    Yields:
        Path of the generated dataset and the dataset arguments used to build it
    """
    dataset = _build_multilabel_classification_dataset(tmp_path, request.param)
    yield dataset
    # Teardown: remove everything the builder created.
    if tmp_path.exists():
        shutil.rmtree(tmp_path)
293
+
294
+
295
def _build_classification_patch_dataset(
    tmp_path: Path, dataset_arguments: ClassificationPatchDatasetArguments
) -> tuple[str, ClassificationPatchDatasetArguments, dict[str, int]]:
    """Generate a classification patch dataset.

    First an initial dataset is written to ``tmp_path/initial_dataset/{images,masks}``. It holds one 224x224
    image per sample, produced by ``_random_image(size=(224, 224))``. Each image has a matching uint8 mask that
    is zero everywhere except a 50x50 square (rows and columns 100:150). That square is filled with the index of
    the image's class. Because each mask holds a single class, images with multiple annotations cannot be
    generated at the current stage. Note that the class with index 0 therefore gets an all-zero mask.

    The initial dataset is then converted into a patch dataset in ``tmp_path/patch_dataset`` by
    ``generate_patch_dataset``. The split sizes, patch number/size, overlap and ``annotated_good`` come from
    ``dataset_arguments``; every other parameter of ``generate_patch_dataset`` keeps its default.

    Args:
        tmp_path: path to temporary directory
        dataset_arguments: patch dataset arguments

    Returns:
        Tuple containing path to created patch dataset, dataset arguments and class to index mapping
        (keys are integer indices when ``dataset_arguments.classes`` is not set)
    """
    initial_dataset_path = tmp_path / "initial_dataset"
    initial_dataset_path.mkdir()

    images_path = initial_dataset_path / "images"
    masks_path = initial_dataset_path / "masks"
    images_path.mkdir()
    masks_path.mkdir()

    classes = dataset_arguments.classes if dataset_arguments.classes else range(len(dataset_arguments.samples))

    class_to_idx = {class_name: i for i, class_name in enumerate(classes)}

    for class_name, samples in zip(classes, dataset_arguments.samples):
        for i in range(samples):
            image = _random_image(size=(224, 224))
            mask = np.zeros((224, 224), dtype=np.uint8)
            # Single annotated square whose pixel value is the class index.
            mask[100:150, 100:150] = class_to_idx[class_name]
            image_path = images_path / f"{class_name}_{i}.png"
            mask_path = masks_path / f"{class_name}_{i}.png"
            cv2.imwrite(str(image_path), image)
            cv2.imwrite(str(mask_path), mask)

    patch_dataset_path = tmp_path / "patch_dataset"
    patch_dataset_path.mkdir()

    data_dictionary = get_image_mask_association(data_folder=str(images_path), mask_folder=str(masks_path))

    _ = generate_patch_dataset(
        data_dictionary=data_dictionary,
        class_to_idx=class_to_idx,
        val_size=dataset_arguments.val_size,
        test_size=dataset_arguments.test_size,
        patch_number=dataset_arguments.patch_number,
        patch_size=dataset_arguments.patch_size,
        overlap=dataset_arguments.overlap,
        output_folder=str(patch_dataset_path),
        annotated_good=dataset_arguments.annotated_good,
    )

    return str(patch_dataset_path), dataset_arguments, class_to_idx
350
+
351
+
352
@pytest.fixture
def classification_patch_dataset(
    tmp_path: Path, dataset_arguments: ClassificationPatchDatasetArguments
) -> tuple[str, ClassificationPatchDatasetArguments, dict[str, int]]:
    """Fixture to dynamically generate a classification patch dataset.

    The initial images are 224x224 pixels. Each associated mask contains a 50x50 pixel square filled with the
    index of the image's class, so at the current stage it is not possible to have images with multiple
    annotations. The patch dataset is generated by ``generate_patch_dataset`` using the split, patch and
    overlap settings from ``dataset_arguments``. ``tmp_path`` is removed after the test finishes.

    Args:
        tmp_path: path to temporary directory
        dataset_arguments: patch dataset arguments

    Yields:
        Tuple containing path to created dataset, dataset arguments and class to index mapping
    """
    yield _build_classification_patch_dataset(tmp_path, dataset_arguments)
    if tmp_path.exists():
        shutil.rmtree(tmp_path)
373
+
374
+
375
@pytest.fixture(
    params=[
        ClassificationPatchDatasetArguments(
            **{
                "samples": [5, 5, 5],
                "classes": ["bg", "a", "b"],
                "patch_number": [2, 2],
                "overlap": 0,
                "val_size": 0.1,
                "test_size": 0.1,
            }
        )
    ]
)
def base_patch_classification_dataset(
    tmp_path: Path, request: Any
) -> tuple[str, ClassificationPatchDatasetArguments, dict[str, int]]:
    """Generate a classification patch dataset with the following parameters:
    - 3 classes named bg, a and b
    - 5, 5 and 5 samples for each class
    - 2 horizontal patches and 2 vertical patches
    - 0% overlap
    - 10% validation set
    - 10% test set.

    ``tmp_path`` is removed after the test finishes.

    Args:
        tmp_path: path to temporary directory
        request: pytest SubRequest object carrying the dataset arguments in ``request.param``

    Yields:
        Tuple containing path to created dataset, dataset arguments and class to index mapping
    """
    yield _build_classification_patch_dataset(tmp_path, request.param)
    if tmp_path.exists():
        shutil.rmtree(tmp_path)
@@ -0,0 +1,53 @@
1
+ import shutil
2
+ from pathlib import Path
3
+
4
+ import cv2
5
+ import pytest
6
+
7
+ from quadra.utils.tests.helpers import _random_image
8
+
9
+
10
def _build_imagenette_dataset(tmp_path: Path, classes: int, class_samples: int) -> Path:
    """Generate imagenette dataset in the format required by efficient_ad model.

    The resulting layout is ``imagenette2/{train,val}/class_<i>/fake_<j>.png``. Every image is
    produced by ``_random_image`` called with its default arguments.

    Args:
        tmp_path: Path to temporary directory
        classes: Number of mock imagenette classes
        class_samples: Number of samples for each mock imagenette class

    Returns:
        Path (a ``pathlib.Path``) to the ``imagenette2`` dataset folder
    """
    parent_path = tmp_path / "imagenette2"
    parent_path.mkdir()

    # Both splits share the exact same structure: one folder per class with `class_samples` images.
    for split_name in ("train", "val"):
        split_path = parent_path / split_name
        split_path.mkdir()
        for i in range(classes):
            cl_path = split_path / f"class_{i}"
            cl_path.mkdir()
            for j in range(class_samples):
                image = _random_image()
                image_path = cl_path / f"fake_{j}.png"
                cv2.imwrite(str(image_path), image)

    return parent_path
38
+
39
+
40
@pytest.fixture
def imagenette_dataset(tmp_path: Path) -> Path:
    """Generate a mock imagenette dataset to test efficient_ad model.

    The dataset contains 3 classes with 3 images each, in both the train and val splits.

    Args:
        tmp_path: Path to temporary directory

    Yields:
        Path to imagenette dataset folder
    """
    yield _build_imagenette_dataset(tmp_path, classes=3, class_samples=3)

    # Teardown: remove the generated dataset once the test is done.
    if tmp_path.exists():
        shutil.rmtree(tmp_path)
@@ -0,0 +1,161 @@
1
+ from __future__ import annotations
2
+
3
+ import shutil
4
+ from dataclasses import dataclass
5
+ from pathlib import Path
6
+ from typing import Any
7
+
8
+ import cv2
9
+ import numpy as np
10
+ import pytest
11
+
12
+ from quadra.utils.tests.helpers import _random_image
13
+
14
+
15
@dataclass
class SegmentationDatasetArguments:
    """Parameters describing a synthetic segmentation dataset.

    Attributes:
        train_samples: Number of samples per class in the train split; the entry at index 0 is the
            number of good samples.
        val_samples: Per-class sample counts for the validation split, laid out like ``train_samples``.
        test_samples: Per-class sample counts for the test split, laid out like ``train_samples``.
        classes: Optional list of class names; its length must be ``len(train_samples) - 1``.
    """

    # Required: per-class counts for the train split (good samples first).
    train_samples: list[int]
    # Optional splits; None means the split is not described.
    val_samples: list[int] | None = None
    test_samples: list[int] | None = None
    # Optional names for the defect classes.
    classes: list[str] | None = None
30
+
31
+
32
def _build_segmentation_dataset(
    tmp_path: Path, dataset_arguments: SegmentationDatasetArguments
) -> tuple[str, SegmentationDatasetArguments, dict[str, int]]:
    """Generate segmentation dataset.

    The dataset is written to ``tmp_path / "segmentation_dataset"`` with an ``images`` and a ``masks``
    folder holding files with the same name. Every image is a random 224x224 picture. The mask of a good
    sample is all zeros. The mask of a defective sample has a 50x50 square (rows/columns 100-149) set to
    the class index. For every split whose sample list is not ``None``, a ``<split>.txt`` file listing
    ``images/<file name>`` entries is written.

    Args:
        tmp_path: path to temporary directory
        dataset_arguments: dataset arguments

    Returns:
        Tuple containing path to dataset, dataset arguments and class to index mapping. When
        ``dataset_arguments.classes`` is empty or not set, the class names default to the integers
        ``1 .. len(train_samples) - 1``.
    """
    image_size = 224
    # A single fixed square per mask: each image carries at most one annotation.
    defect_area = (slice(100, 150), slice(100, 150))

    train_samples = dataset_arguments.train_samples
    val_samples = dataset_arguments.val_samples
    test_samples = dataset_arguments.test_samples
    classes = (
        dataset_arguments.classes if dataset_arguments.classes else list(range(1, len(dataset_arguments.train_samples)))
    )

    segmentation_dataset_path = tmp_path / "segmentation_dataset"
    segmentation_dataset_path.mkdir()
    images_path = segmentation_dataset_path / "images"
    masks_path = segmentation_dataset_path / "masks"
    images_path.mkdir(parents=True)
    masks_path.mkdir(parents=True)
    class_to_idx = {class_name: i + 1 for i, class_name in enumerate(classes)}
    # Index 0 of each per-split sample list refers to good samples, labelled 0 (a new list, the
    # caller's `classes` is not mutated).
    classes = [0] + classes

    # Shared across splits so that generated file names never collide.
    counter = 0
    for split_name, split_samples in zip(["train", "val", "test"], [train_samples, val_samples, test_samples]):
        if split_samples is None:
            continue

        # Explicit encoding: the split file content must not depend on the platform locale,
        # since user-provided class names end up in the file names.
        with open(segmentation_dataset_path / f"{split_name}.txt", "w", encoding="utf-8") as split_file:
            for class_name, samples in zip(classes, split_samples):
                for _ in range(samples):
                    image = _random_image(size=(image_size, image_size))
                    mask = np.zeros((image_size, image_size), dtype=np.uint8)
                    if class_name != 0:
                        mask[defect_area] = class_to_idx[class_name]
                    image_path = images_path / f"{class_name}_{counter}.png"
                    mask_path = masks_path / f"{class_name}_{counter}.png"
                    cv2.imwrite(str(image_path), image)
                    cv2.imwrite(str(mask_path), mask)
                    split_file.write(f"images/{image_path.name}\n")
                    counter += 1

    return str(segmentation_dataset_path), dataset_arguments, class_to_idx
80
+
81
+
82
@pytest.fixture
def segmentation_dataset(
    tmp_path: Path, dataset_arguments: SegmentationDatasetArguments
) -> tuple[str, SegmentationDatasetArguments, dict[str, int]]:
    """Fixture to dynamically generate a segmentation dataset.

    By default the generated images are 224x224 pixels and each associated mask contains a 50x50 pixel
    square filled with the corresponding image class, so at the current stage it is not possible to have
    images with multiple annotations. Split files are saved as train.txt, val.txt and test.txt.

    Args:
        tmp_path: path to temporary directory
        dataset_arguments: dataset arguments

    Yields:
        Tuple containing path to dataset, dataset arguments and class to index mapping
    """
    generated = _build_segmentation_dataset(tmp_path, dataset_arguments)
    yield generated

    # Teardown: remove everything written under tmp_path, if it is still there.
    if not tmp_path.exists():
        return
    shutil.rmtree(tmp_path)
101
+
102
+
103
@pytest.fixture(
    params=[
        SegmentationDatasetArguments(
            **{"train_samples": [3, 2], "val_samples": [2, 2], "test_samples": [1, 1], "classes": ["bad"]}
        )
    ]
)
def base_binary_segmentation_dataset(
    tmp_path: Path, request: Any
) -> tuple[str, SegmentationDatasetArguments, dict[str, int]]:
    """Generate a base binary segmentation dataset with the following structure:
    - 3 good and 2 bad samples in train set
    - 2 good and 2 bad samples in validation set
    - 1 good and 1 bad sample in test set
    - 2 classes: good (mask value 0) and bad (mask value 1).

    Args:
        tmp_path: path to temporary directory
        request: pytest request; ``request.param`` holds the ``SegmentationDatasetArguments`` declared above

    Yields:
        Tuple containing path to dataset, dataset arguments and class to index mapping
    """
    yield _build_segmentation_dataset(tmp_path, request.param)
    # Teardown: remove the generated dataset once the test is done.
    if tmp_path.exists():
        shutil.rmtree(tmp_path)
129
+
130
+
131
@pytest.fixture(
    params=[
        SegmentationDatasetArguments(
            **{
                "train_samples": [2, 2, 2],
                "val_samples": [2, 2, 2],
                "test_samples": [1, 1, 1],
                "classes": ["defect_1", "defect_2"],
            }
        )
    ]
)
def base_multiclass_segmentation_dataset(
    tmp_path: Path, request: Any
) -> tuple[str, SegmentationDatasetArguments, dict[str, int]]:
    """Generate a base multiclass segmentation dataset with the following structure:
    - 2 good, 2 defect_1 and 2 defect_2 samples in train set
    - 2 good, 2 defect_1 and 2 defect_2 samples in validation set
    - 1 good, 1 defect_1 and 1 defect_2 sample in test set
    - 3 classes: good (mask value 0), defect_1 (mask value 1) and defect_2 (mask value 2).

    Args:
        tmp_path: path to temporary directory
        request: pytest request; ``request.param`` holds the ``SegmentationDatasetArguments`` declared above

    Yields:
        Tuple containing path to dataset, dataset arguments and class to index mapping
    """
    yield _build_segmentation_dataset(tmp_path, request.param)
    # Teardown: remove the generated dataset once the test is done.
    if tmp_path.exists():
        shutil.rmtree(tmp_path)
@@ -0,0 +1,3 @@
1
+ from .anomaly import * # noqa: F403
2
+ from .classification import * # noqa: F403
3
+ from .segmentation import * # noqa: F403