quadra 0.0.1__py3-none-any.whl → 2.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (302)
  1. hydra_plugins/quadra_searchpath_plugin.py +30 -0
  2. quadra/__init__.py +6 -0
  3. quadra/callbacks/__init__.py +0 -0
  4. quadra/callbacks/anomalib.py +289 -0
  5. quadra/callbacks/lightning.py +501 -0
  6. quadra/callbacks/mlflow.py +291 -0
  7. quadra/callbacks/scheduler.py +69 -0
  8. quadra/configs/__init__.py +0 -0
  9. quadra/configs/backbone/caformer_m36.yaml +8 -0
  10. quadra/configs/backbone/caformer_s36.yaml +8 -0
  11. quadra/configs/backbone/convnextv2_base.yaml +8 -0
  12. quadra/configs/backbone/convnextv2_femto.yaml +8 -0
  13. quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
  14. quadra/configs/backbone/dino_vitb8.yaml +12 -0
  15. quadra/configs/backbone/dino_vits8.yaml +12 -0
  16. quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
  17. quadra/configs/backbone/dinov2_vits14.yaml +12 -0
  18. quadra/configs/backbone/efficientnet_b0.yaml +8 -0
  19. quadra/configs/backbone/efficientnet_b1.yaml +8 -0
  20. quadra/configs/backbone/efficientnet_b2.yaml +8 -0
  21. quadra/configs/backbone/efficientnet_b3.yaml +8 -0
  22. quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
  23. quadra/configs/backbone/levit_128s.yaml +8 -0
  24. quadra/configs/backbone/mnasnet0_5.yaml +9 -0
  25. quadra/configs/backbone/resnet101.yaml +8 -0
  26. quadra/configs/backbone/resnet18.yaml +8 -0
  27. quadra/configs/backbone/resnet18_ssl.yaml +8 -0
  28. quadra/configs/backbone/resnet50.yaml +8 -0
  29. quadra/configs/backbone/smp.yaml +9 -0
  30. quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
  31. quadra/configs/backbone/unetr.yaml +15 -0
  32. quadra/configs/backbone/vit16_base.yaml +9 -0
  33. quadra/configs/backbone/vit16_small.yaml +9 -0
  34. quadra/configs/backbone/vit16_tiny.yaml +9 -0
  35. quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
  36. quadra/configs/callbacks/all.yaml +45 -0
  37. quadra/configs/callbacks/default.yaml +34 -0
  38. quadra/configs/callbacks/default_anomalib.yaml +64 -0
  39. quadra/configs/config.yaml +33 -0
  40. quadra/configs/core/default.yaml +11 -0
  41. quadra/configs/datamodule/base/anomaly.yaml +16 -0
  42. quadra/configs/datamodule/base/classification.yaml +21 -0
  43. quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
  44. quadra/configs/datamodule/base/segmentation.yaml +18 -0
  45. quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
  46. quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
  47. quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
  48. quadra/configs/datamodule/base/ssl.yaml +21 -0
  49. quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
  50. quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
  51. quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
  52. quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
  53. quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
  54. quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
  55. quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
  56. quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
  57. quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
  58. quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
  59. quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
  60. quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
  61. quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
  62. quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
  63. quadra/configs/experiment/base/classification/classification.yaml +73 -0
  64. quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
  65. quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
  66. quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
  67. quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
  68. quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
  69. quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
  70. quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
  71. quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
  72. quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
  73. quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
  74. quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
  75. quadra/configs/experiment/base/ssl/byol.yaml +43 -0
  76. quadra/configs/experiment/base/ssl/dino.yaml +46 -0
  77. quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
  78. quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
  79. quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
  80. quadra/configs/experiment/custom/cls.yaml +12 -0
  81. quadra/configs/experiment/default.yaml +15 -0
  82. quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
  83. quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
  84. quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
  85. quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
  86. quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
  87. quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
  88. quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
  89. quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
  90. quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
  91. quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
  92. quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
  93. quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
  94. quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
  95. quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
  96. quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
  97. quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
  98. quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
  99. quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
  100. quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
  101. quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
  102. quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
  103. quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
  104. quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
  105. quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
  106. quadra/configs/export/default.yaml +13 -0
  107. quadra/configs/hydra/anomaly_custom.yaml +15 -0
  108. quadra/configs/hydra/default.yaml +14 -0
  109. quadra/configs/inference/default.yaml +26 -0
  110. quadra/configs/logger/comet.yaml +10 -0
  111. quadra/configs/logger/csv.yaml +5 -0
  112. quadra/configs/logger/mlflow.yaml +12 -0
  113. quadra/configs/logger/tensorboard.yaml +8 -0
  114. quadra/configs/loss/asl.yaml +7 -0
  115. quadra/configs/loss/barlow.yaml +2 -0
  116. quadra/configs/loss/bce.yaml +1 -0
  117. quadra/configs/loss/byol.yaml +1 -0
  118. quadra/configs/loss/cross_entropy.yaml +1 -0
  119. quadra/configs/loss/dino.yaml +8 -0
  120. quadra/configs/loss/simclr.yaml +2 -0
  121. quadra/configs/loss/simsiam.yaml +1 -0
  122. quadra/configs/loss/smp_ce.yaml +3 -0
  123. quadra/configs/loss/smp_dice.yaml +2 -0
  124. quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
  125. quadra/configs/loss/smp_mcc.yaml +2 -0
  126. quadra/configs/loss/vicreg.yaml +5 -0
  127. quadra/configs/model/anomalib/cfa.yaml +35 -0
  128. quadra/configs/model/anomalib/cflow.yaml +30 -0
  129. quadra/configs/model/anomalib/csflow.yaml +34 -0
  130. quadra/configs/model/anomalib/dfm.yaml +19 -0
  131. quadra/configs/model/anomalib/draem.yaml +29 -0
  132. quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
  133. quadra/configs/model/anomalib/fastflow.yaml +32 -0
  134. quadra/configs/model/anomalib/padim.yaml +32 -0
  135. quadra/configs/model/anomalib/patchcore.yaml +36 -0
  136. quadra/configs/model/barlow.yaml +16 -0
  137. quadra/configs/model/byol.yaml +25 -0
  138. quadra/configs/model/classification.yaml +10 -0
  139. quadra/configs/model/dino.yaml +26 -0
  140. quadra/configs/model/logistic_regression.yaml +4 -0
  141. quadra/configs/model/multilabel_classification.yaml +9 -0
  142. quadra/configs/model/simclr.yaml +18 -0
  143. quadra/configs/model/simsiam.yaml +24 -0
  144. quadra/configs/model/smp.yaml +4 -0
  145. quadra/configs/model/smp_multiclass.yaml +4 -0
  146. quadra/configs/model/vicreg.yaml +16 -0
  147. quadra/configs/optimizer/adam.yaml +5 -0
  148. quadra/configs/optimizer/adamw.yaml +3 -0
  149. quadra/configs/optimizer/default.yaml +4 -0
  150. quadra/configs/optimizer/lars.yaml +8 -0
  151. quadra/configs/optimizer/sgd.yaml +4 -0
  152. quadra/configs/scheduler/default.yaml +5 -0
  153. quadra/configs/scheduler/rop.yaml +5 -0
  154. quadra/configs/scheduler/step.yaml +3 -0
  155. quadra/configs/scheduler/warmrestart.yaml +2 -0
  156. quadra/configs/scheduler/warmup.yaml +6 -0
  157. quadra/configs/task/anomalib/cfa.yaml +5 -0
  158. quadra/configs/task/anomalib/cflow.yaml +5 -0
  159. quadra/configs/task/anomalib/csflow.yaml +5 -0
  160. quadra/configs/task/anomalib/draem.yaml +5 -0
  161. quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
  162. quadra/configs/task/anomalib/fastflow.yaml +5 -0
  163. quadra/configs/task/anomalib/inference.yaml +3 -0
  164. quadra/configs/task/anomalib/padim.yaml +5 -0
  165. quadra/configs/task/anomalib/patchcore.yaml +5 -0
  166. quadra/configs/task/classification.yaml +6 -0
  167. quadra/configs/task/classification_evaluation.yaml +6 -0
  168. quadra/configs/task/default.yaml +1 -0
  169. quadra/configs/task/segmentation.yaml +9 -0
  170. quadra/configs/task/segmentation_evaluation.yaml +3 -0
  171. quadra/configs/task/sklearn_classification.yaml +13 -0
  172. quadra/configs/task/sklearn_classification_patch.yaml +11 -0
  173. quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
  174. quadra/configs/task/sklearn_classification_test.yaml +8 -0
  175. quadra/configs/task/ssl.yaml +2 -0
  176. quadra/configs/trainer/lightning_cpu.yaml +36 -0
  177. quadra/configs/trainer/lightning_gpu.yaml +35 -0
  178. quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
  179. quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
  180. quadra/configs/trainer/lightning_multigpu.yaml +37 -0
  181. quadra/configs/trainer/sklearn_classification.yaml +7 -0
  182. quadra/configs/transforms/byol.yaml +47 -0
  183. quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
  184. quadra/configs/transforms/default.yaml +37 -0
  185. quadra/configs/transforms/default_numpy.yaml +24 -0
  186. quadra/configs/transforms/default_resize.yaml +22 -0
  187. quadra/configs/transforms/dino.yaml +63 -0
  188. quadra/configs/transforms/linear_eval.yaml +18 -0
  189. quadra/datamodules/__init__.py +20 -0
  190. quadra/datamodules/anomaly.py +180 -0
  191. quadra/datamodules/base.py +375 -0
  192. quadra/datamodules/classification.py +1003 -0
  193. quadra/datamodules/generic/__init__.py +0 -0
  194. quadra/datamodules/generic/imagenette.py +144 -0
  195. quadra/datamodules/generic/mnist.py +81 -0
  196. quadra/datamodules/generic/mvtec.py +58 -0
  197. quadra/datamodules/generic/oxford_pet.py +163 -0
  198. quadra/datamodules/patch.py +190 -0
  199. quadra/datamodules/segmentation.py +742 -0
  200. quadra/datamodules/ssl.py +140 -0
  201. quadra/datasets/__init__.py +17 -0
  202. quadra/datasets/anomaly.py +287 -0
  203. quadra/datasets/classification.py +241 -0
  204. quadra/datasets/patch.py +138 -0
  205. quadra/datasets/segmentation.py +239 -0
  206. quadra/datasets/ssl.py +110 -0
  207. quadra/losses/__init__.py +0 -0
  208. quadra/losses/classification/__init__.py +6 -0
  209. quadra/losses/classification/asl.py +83 -0
  210. quadra/losses/classification/focal.py +320 -0
  211. quadra/losses/classification/prototypical.py +148 -0
  212. quadra/losses/ssl/__init__.py +17 -0
  213. quadra/losses/ssl/barlowtwins.py +47 -0
  214. quadra/losses/ssl/byol.py +37 -0
  215. quadra/losses/ssl/dino.py +129 -0
  216. quadra/losses/ssl/hyperspherical.py +45 -0
  217. quadra/losses/ssl/idmm.py +50 -0
  218. quadra/losses/ssl/simclr.py +67 -0
  219. quadra/losses/ssl/simsiam.py +30 -0
  220. quadra/losses/ssl/vicreg.py +76 -0
  221. quadra/main.py +49 -0
  222. quadra/metrics/__init__.py +3 -0
  223. quadra/metrics/segmentation.py +251 -0
  224. quadra/models/__init__.py +0 -0
  225. quadra/models/base.py +151 -0
  226. quadra/models/classification/__init__.py +8 -0
  227. quadra/models/classification/backbones.py +149 -0
  228. quadra/models/classification/base.py +92 -0
  229. quadra/models/evaluation.py +322 -0
  230. quadra/modules/__init__.py +0 -0
  231. quadra/modules/backbone.py +30 -0
  232. quadra/modules/base.py +312 -0
  233. quadra/modules/classification/__init__.py +3 -0
  234. quadra/modules/classification/base.py +327 -0
  235. quadra/modules/ssl/__init__.py +17 -0
  236. quadra/modules/ssl/barlowtwins.py +59 -0
  237. quadra/modules/ssl/byol.py +172 -0
  238. quadra/modules/ssl/common.py +285 -0
  239. quadra/modules/ssl/dino.py +186 -0
  240. quadra/modules/ssl/hyperspherical.py +206 -0
  241. quadra/modules/ssl/idmm.py +98 -0
  242. quadra/modules/ssl/simclr.py +73 -0
  243. quadra/modules/ssl/simsiam.py +68 -0
  244. quadra/modules/ssl/vicreg.py +67 -0
  245. quadra/optimizers/__init__.py +4 -0
  246. quadra/optimizers/lars.py +153 -0
  247. quadra/optimizers/sam.py +127 -0
  248. quadra/schedulers/__init__.py +3 -0
  249. quadra/schedulers/base.py +44 -0
  250. quadra/schedulers/warmup.py +127 -0
  251. quadra/tasks/__init__.py +24 -0
  252. quadra/tasks/anomaly.py +582 -0
  253. quadra/tasks/base.py +397 -0
  254. quadra/tasks/classification.py +1263 -0
  255. quadra/tasks/patch.py +492 -0
  256. quadra/tasks/segmentation.py +389 -0
  257. quadra/tasks/ssl.py +560 -0
  258. quadra/trainers/README.md +3 -0
  259. quadra/trainers/__init__.py +0 -0
  260. quadra/trainers/classification.py +179 -0
  261. quadra/utils/__init__.py +0 -0
  262. quadra/utils/anomaly.py +112 -0
  263. quadra/utils/classification.py +618 -0
  264. quadra/utils/deprecation.py +31 -0
  265. quadra/utils/evaluation.py +474 -0
  266. quadra/utils/export.py +585 -0
  267. quadra/utils/imaging.py +32 -0
  268. quadra/utils/logger.py +15 -0
  269. quadra/utils/mlflow.py +98 -0
  270. quadra/utils/model_manager.py +320 -0
  271. quadra/utils/models.py +523 -0
  272. quadra/utils/patch/__init__.py +15 -0
  273. quadra/utils/patch/dataset.py +1433 -0
  274. quadra/utils/patch/metrics.py +449 -0
  275. quadra/utils/patch/model.py +153 -0
  276. quadra/utils/patch/visualization.py +217 -0
  277. quadra/utils/resolver.py +42 -0
  278. quadra/utils/segmentation.py +31 -0
  279. quadra/utils/tests/__init__.py +0 -0
  280. quadra/utils/tests/fixtures/__init__.py +1 -0
  281. quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
  282. quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
  283. quadra/utils/tests/fixtures/dataset/classification.py +406 -0
  284. quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
  285. quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
  286. quadra/utils/tests/fixtures/models/__init__.py +3 -0
  287. quadra/utils/tests/fixtures/models/anomaly.py +89 -0
  288. quadra/utils/tests/fixtures/models/classification.py +45 -0
  289. quadra/utils/tests/fixtures/models/segmentation.py +33 -0
  290. quadra/utils/tests/helpers.py +70 -0
  291. quadra/utils/tests/models.py +27 -0
  292. quadra/utils/utils.py +525 -0
  293. quadra/utils/validator.py +115 -0
  294. quadra/utils/visualization.py +422 -0
  295. quadra/utils/vit_explainability.py +349 -0
  296. quadra-2.2.7.dist-info/LICENSE +201 -0
  297. quadra-2.2.7.dist-info/METADATA +381 -0
  298. quadra-2.2.7.dist-info/RECORD +300 -0
  299. {quadra-0.0.1.dist-info → quadra-2.2.7.dist-info}/WHEEL +1 -1
  300. quadra-2.2.7.dist-info/entry_points.txt +3 -0
  301. quadra-0.0.1.dist-info/METADATA +0 -14
  302. quadra-0.0.1.dist-info/RECORD +0 -4
@@ -0,0 +1,618 @@
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ import os
5
+ import random
6
+ import re
7
+ from collections.abc import Generator, Sequence
8
+ from typing import TYPE_CHECKING, Any
9
+
10
+ import matplotlib.pyplot as plt
11
+ import numpy as np
12
+ import pandas as pd
13
+ import torch
14
+ from omegaconf import DictConfig
15
+ from sklearn.metrics import ConfusionMatrixDisplay, accuracy_score, classification_report, confusion_matrix
16
+ from sklearn.model_selection import StratifiedKFold, StratifiedShuffleSplit
17
+ from torch.utils.data import DataLoader
18
+
19
+ from quadra.models.base import ModelSignatureWrapper
20
+ from quadra.utils import utils
21
+ from quadra.utils.models import get_feature
22
+ from quadra.utils.visualization import UnNormalize, plot_classification_results
23
+
24
+ if TYPE_CHECKING:
25
+ from quadra.datamodules.classification import SklearnClassificationDataModule
26
+ from quadra.datamodules.patch import PatchSklearnClassificationDataModule
27
+
28
+ log = utils.get_logger(__name__)
29
+
30
+
31
def get_file_condition(
    file_name: str, root: str, exclude_filter: list[str] | None = None, include_filter: list[str] | None = None
):
    """Decide whether a file passes the configured exclude/include filters.

    A file is rejected when any exclude token appears in either its name or its
    root directory, or when include filters are given and none of them appears
    in the name nor in the root.

    Args:
        file_name: Name of the file
        root: Root directory of the file
        exclude_filter: List of string filter to be used to exclude images. If None no filter will be applied.
        include_filter: List of string filter to be used to include images. If None no filter will be applied.
    """
    if exclude_filter is not None:
        # Reject on a match against either the file name or its directory.
        for haystack in (file_name, root):
            if any(token in haystack for token in exclude_filter):
                return False

    if include_filter is not None:
        name_hit = any(token in file_name for token in include_filter)
        root_hit = any(token in root for token in include_filter)
        if not name_hit and not root_hit:
            return False

    return True
55
+
56
+
57
def natural_key(string_):
    """Natural-order sort key: digit runs compare numerically, so e.g. "img2" sorts
    before "img10". See http://www.codinghorror.com/blog/archives/001018.html.
    """
    chunks = re.split(r"(\d+)", string_.lower())
    return [int(chunk) if chunk.isdigit() else chunk for chunk in chunks]
60
+
61
+
62
def find_images_and_targets(
    folder: str,
    types: list | None = None,
    class_to_idx: dict[str, int] | None = None,
    leaf_name_only: bool = True,
    sort: bool = True,
    exclude_filter: list | None = None,
    include_filter: list | None = None,
    label_map: dict[str, Any] | None = None,
) -> tuple[np.ndarray, np.ndarray, dict]:
    """Given a folder, extract the absolute path of all the files with a valid extension.
    Then assign a label based on subfolder name.

    Args:
        folder: path to main folder
        types: valid file extentions
        class_to_idx: dictionary of conversion btw folder name and index.
            Only file whose label is in dictionary key list will be considered. If None all files will
            be considered and a custom conversion is created.
        leaf_name_only: if True use only the leaf folder name as label, otherwise use the full path
        sort: if True sort the images and labels based on the image name
        exclude_filter: list of string filter to be used to exclude images.
            If None no filter will be applied.
        include_filter: list of string filder to be used to include images.
            Only images that satisfied at list one of the filter will be included.
        label_map: dictionary of conversion btw folder name and label.

    Returns:
        Arrays of file paths and labels plus the class_to_idx mapping. When no
        image survives the filters, two empty arrays are returned instead of
        raising (previously this crashed with an IndexError on the slice).
    """
    if types is None:
        types = [".png", ".jpg", ".jpeg", ".bmp"]
    labels = []
    filenames = []

    for root, _, files in os.walk(folder, topdown=False, followlinks=True):
        rel_path = os.path.relpath(root, folder) if root != folder else ""

        if leaf_name_only:
            label = os.path.basename(rel_path)
        else:
            # Build the label from at most the last two path components
            # (a single nesting level keeps only the leaf name).
            components = rel_path.split(os.path.sep)
            components = components[-1:] if len(components) == 2 else components[-2:]
            label = "_".join(components)

        for f in files:
            if not get_file_condition(
                file_name=f, root=root, exclude_filter=exclude_filter, include_filter=include_filter
            ):
                continue

            # Skip hidden files and jupyter checkpoint artifacts.
            if f.startswith(".") or "checkpoint" in f:
                continue
            _, ext = os.path.splitext(f)
            if ext.lower() in types:
                filenames.append(os.path.join(root, f))
                labels.append(label)

    if label_map is not None:
        labels, _ = group_labels(labels, label_map)

    if class_to_idx is None:
        # Build the class index from the labels found on disk, in natural order.
        sorted_labels = sorted(set(labels), key=natural_key)
        class_to_idx = {str(c): idx for idx, c in enumerate(sorted_labels)}

    # Drop any file whose label is not part of the class index.
    images_and_targets = [(f, label) for f, label in zip(filenames, labels) if label in class_to_idx]

    if sort:
        images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0]))

    # Fix: an empty selection used to raise IndexError on np.array([])[:, 0].
    if not images_and_targets:
        return np.array([]), np.array([]), class_to_idx

    pairs = np.array(images_and_targets)
    return pairs[:, 0], pairs[:, 1], class_to_idx
139
+
140
+
141
def find_test_image(
    folder: str,
    types: list[str] | None = None,
    exclude_filter: list[str] | None = None,
    include_filter: list[str] | None = None,
    include_none_class: bool = True,
    test_split_file: str | None = None,
    label_map: dict[str, Any] | None = None,
) -> tuple[list[str], list[str | None]]:
    """Given a path extract images and labels with filters, labels are based on the parent folder name of the images
    Args:
        folder: root directory containing the images
        types: only choose images with the extensions specified, if None use default extensions
        exclude_filter: list of string filter to be used to exclude images. If None no filter will be applied.
        include_filter: list of string filter to be used to include images. If None no filter will be applied.
        include_none_class: if set to True convert all 'None' labels to None, otherwise ignore the image
        test_split_file: if defined use the split defined inside the file
        label_map: optional mapping used to group folder labels via group_labels
    Returns:
        Two lists, one containing the images path and the other one containing the labels. Labels can be None.
    """
    if types is None:
        types = [".png", ".jpg", ".jpeg", ".bmp"]

    labels = []
    filenames = []

    # Walk the tree; each image's label is the basename of its parent folder.
    for root, _, files in os.walk(folder, topdown=False, followlinks=True):
        rel_path = os.path.relpath(root, folder) if root != folder else ""
        label: str | None = os.path.basename(rel_path)
        for f in files:
            if not get_file_condition(
                file_name=f, root=root, exclude_filter=exclude_filter, include_filter=include_filter
            ):
                continue
            # Skip hidden files and jupyter checkpoint artifacts.
            if f.startswith(".") or "checkpoint" in f:
                continue
            _, ext = os.path.splitext(f)
            if ext.lower() in types:
                if label == "None":
                    if include_none_class:
                        # NOTE: once set, label stays None for the remaining files
                        # of this root — harmless, they share the same folder name.
                        label = None
                    else:
                        continue
                filenames.append(os.path.join(root, f))
                labels.append(label)

    if test_split_file is not None:
        # A relative split file path is resolved against the dataset folder.
        if not os.path.isabs(test_split_file):
            log.info(
                "test_split_file is not an absolute path. Trying to using folder argument %s as parent folder", folder
            )
            test_split_file = os.path.join(folder, test_split_file)

        if not os.path.exists(test_split_file):
            raise FileNotFoundError(f"test_split_file {test_split_file} does not exist")

        with open(test_split_file) as test_file:
            test_split = test_file.read().splitlines()

        file_samples = []
        for row in test_split:
            csv_values = row.split(",")
            if len(csv_values) == 1:
                # ensuring backward compatibility with old split file format
                # old_format: sample, new_format: sample,class
                sample_path = os.path.join(folder, csv_values[0])
            else:
                # The sample path itself may contain commas; only the last
                # field is treated as the class column.
                sample_path = os.path.join(folder, ",".join(csv_values[:-1]))

            file_samples.append(sample_path)

        # NOTE(review): this rebuilt test_split list is never read afterwards —
        # the filtering below compares against file_samples instead.
        test_split = [os.path.join(folder, sample.strip()) for sample in file_samples]
        # Keep only the images (and their labels) listed in the split file.
        labels = [t for s, t in zip(filenames, labels) if s in file_samples]
        filenames = [s for s in filenames if s in file_samples]
        log.info("Selected %d images using test_split_file for the test", len(filenames))
        if len(filenames) != len(file_samples):
            log.warning(
                "test_split_file contains %d images but only %d images were found in the folder."
                "This may be due to duplicate lines in the test_split_file.",
                len(file_samples),
                len(filenames),
            )
    else:
        log.info("No test_split_file. Selected all %s images for the test", folder)

    if label_map is not None:
        labels, _ = group_labels(labels, label_map)

    return filenames, labels
230
+
231
+
232
def group_labels(labels: Sequence[str | None], class_mapping: dict[str, str | None | list[str]]) -> tuple[list, dict]:
    """Group labels based on class_mapping.

    A mapping value may be a single label name (str), a list of label names, or
    None. At most one key may map to None: it acts as the catch-all group for
    every label not matched by the other keys.

    Raises:
        ValueError: if a label is not in class_mapping
        ValueError: if a label is in class_mapping but has no corresponding value

    Returns:
        List of labels and a dictionary of labels and their corresponding group

    Example:
        ```python
        grouped_labels, class_to_idx = group_labels(labels, class_mapping={"Good": "A", "Bad": None})
        assert grouped_labels.count("Good") == labels.count("A")
        assert len(class_to_idx.keys()) == 2

        grouped_labels, class_to_idx = group_labels(labels, class_mapping={"Good": "A", "Defect": "B", "Bad": None})
        assert grouped_labels.count("Bad") == labels.count("C") + labels.count("D")
        assert len(class_to_idx.keys()) == 3

        grouped_labels, class_to_idx = group_labels(labels, class_mapping={"Good": "A", "Bad": ["B", "C", "D"]})
        assert grouped_labels.count("Bad") == labels.count("B") + labels.count("C") + labels.count("D")
        assert len(class_to_idx.keys()) == 2
        ```
    """
    specified_targets = [k for k, v in class_mapping.items() if v is not None]
    non_specified_targets = [k for k, v in class_mapping.items() if v is None]
    if len(non_specified_targets) > 1:
        raise ValueError(f"More than one non specified target: {non_specified_targets}")

    grouped_labels = []
    for label in labels:
        matched_target = None
        for target in specified_targets:
            mapping_value = class_mapping[target]
            # Fix: compare against the whole label name. The previous code
            # iterated the characters of a str mapping value, so mappings with
            # multi-character label names (e.g. {"Good": "good_images"}) could
            # never match. None entries inside list values are still ignored.
            if isinstance(mapping_value, str):
                matched = label == mapping_value
            else:
                matched = any(label == candidate for candidate in mapping_value if candidate is not None)
            if matched:
                matched_target = target
                break

        if matched_target is not None:
            grouped_labels.append(matched_target)
        elif non_specified_targets:
            # Unmatched labels fall into the single catch-all group.
            grouped_labels.append(non_specified_targets[0])
        else:
            raise ValueError(f"No target found for label: {label}")

    class_to_idx = {k: i for i, k in enumerate(class_mapping.keys())}
    return grouped_labels, class_to_idx
280
+
281
+
282
def filter_with_file(list_of_full_paths: list[str], file_path: str, root_path: str) -> tuple[list[str], list[bool]]:
    """Filter a list of items using a file containing the items to keep. Paths inside file
    should be relative to root_path not absolute to avoid user related issues.

    Args:
        list_of_full_paths: list of items to filter
        file_path: path to the file containing the items to keep
        root_path: root path of the dataset

    Returns:
        list of items to keep
        the mask list to apply different lists later.
    """
    kept_paths: list[str] = []
    keep_mask: list[bool] = []

    with open(file_path) as filter_file:
        relative_paths = filter_file.read().splitlines()

    # One mask entry per line of the filter file, True when the resolved
    # path is present in the input list.
    for relative_path in relative_paths:
        candidate = os.path.join(root_path, relative_path)
        is_present = candidate in list_of_full_paths
        keep_mask.append(is_present)
        if is_present:
            kept_paths.append(candidate)

    return kept_paths, keep_mask
308
+
309
+
310
def get_split(
    image_dir: str,
    exclude_filter: list[str] | None = None,
    include_filter: list[str] | None = None,
    test_size: float = 0.3,
    random_state: int = 42,
    class_to_idx: dict[str, int] | None = None,
    label_map: dict | None = None,
    n_splits: int = 1,
    include_none_class: bool = False,
    limit_training_data: int | None = None,
    train_split_file: str | None = None,
) -> tuple[np.ndarray, np.ndarray, Generator[list, None, None], dict]:
    """Given a folder, extract the absolute path of all the files with a valid extension and name
    and split them into train/test.

    Args:
        image_dir: Path to the folder containing the images
        exclude_filter: List of file name filter to be excluded: If None no filter will be applied
        include_filter: List of file name filter to be included: If None no filter will be applied
        test_size: Percentage of data to be used for test
        random_state: Random state to be used for reproducibility
        class_to_idx: Dictionary of conversion btw folder name and index.
            Only file whose label is in dictionary key list will be considered.
            If None all files will be considered and a custom conversion is created.
        label_map: Dictionary of conversion btw folder name and label.
        n_splits: Number of dataset subdivision (default 1 -> train/test)
        include_none_class: If set to True convert all 'None' labels to None
        limit_training_data: If set to a value, limit the number of training samples to this value
        train_split_file: If set to a path, use the file to split the dataset
    """
    # TODO: Why is include_none_class not used?
    # pylint: disable=unused-argument
    # NOTE(review): assert is stripped under `python -O`; an explicit raise
    # would validate input more robustly.
    assert os.path.isdir(image_dir), f"Folder {image_dir} does not exist."
    # Get samples and target
    samples, targets, class_to_idx = find_images_and_targets(
        folder=image_dir,
        exclude_filter=exclude_filter,
        include_filter=include_filter,
        class_to_idx=class_to_idx,
        label_map=label_map,
        # include_none_class=include_none_class,
    )

    cl, counts = np.unique(targets, return_counts=True)

    # Classes with a single sample cannot be stratified: drop both the sample
    # and the class entry.
    for num, _cl in zip(counts, cl):
        if num == 1:
            to_remove = np.where(np.array(targets) == _cl)[0][0]
            samples = np.delete(np.array(samples), to_remove)
            targets = np.delete(np.array(targets), to_remove)
            class_to_idx.pop(_cl)

    if train_split_file is not None:
        with open(train_split_file) as f:
            train_split = f.read().splitlines()

        file_samples = []
        for row in train_split:
            csv_values = row.split(",")

            if len(csv_values) == 1:
                # ensuring backward compatibility with the old split file format
                # old_format: sample, new_format: sample,class
                sample_path = os.path.join(image_dir, csv_values[0])
            else:
                # The path may itself contain commas; only the last field is
                # treated as the class column.
                sample_path = os.path.join(image_dir, ",".join(csv_values[:-1]))

            file_samples.append(sample_path)

        # NOTE(review): this rebuilt train_split list is never read afterwards —
        # the filtering below compares against file_samples instead.
        train_split = [os.path.join(image_dir, sample.strip()) for sample in file_samples]
        targets = np.array([t for s, t in zip(samples, targets) if s in file_samples])
        samples = np.array([s for s in samples if s in file_samples])

    if limit_training_data is not None:
        # Keep at most limit_training_data samples per class, chosen by a
        # seeded shuffle for reproducibility.
        # NOTE(review): `cl` shadows the earlier np.unique result here.
        idx_to_keep = []
        for cl in np.unique(targets):
            cl_idx = np.where(np.array(targets) == cl)[0].tolist()
            random.seed(random_state)
            random.shuffle(cl_idx)
            idx_to_keep.extend(cl_idx[:limit_training_data])

        samples = np.asarray([samples[i] for i in idx_to_keep])
        targets = np.asarray([targets[i] for i in idx_to_keep])

    # NOTE(review): this recomputed counts value is unused.
    _, counts = np.unique(targets, return_counts=True)

    # Single split -> stratified shuffle; multiple splits -> stratified k-fold.
    if n_splits == 1:
        split_technique = StratifiedShuffleSplit(n_splits=1, test_size=test_size, random_state=random_state)
    else:
        split_technique = StratifiedKFold(n_splits=n_splits, random_state=random_state, shuffle=True)

    # `split` is a lazy generator of (train_indices, test_indices) pairs.
    split = split_technique.split(samples, targets)

    return np.array(samples), np.array(targets), split, class_to_idx
405
+
406
+
407
def save_classification_result(
    results: pd.DataFrame,
    output_folder: str,
    test_dataloader: DataLoader,
    config: DictConfig,
    output: DictConfig,
    accuracy: float | None = None,
    confmat: pd.DataFrame | None = None,
    grayscale_cams: np.ndarray | None = None,
):
    """Save csv results, confusion matrix and example images.

    Args:
        results: Dataframe containing the results
        output_folder: Path to the output folder
        confmat: Confusion matrix in a pandas dataframe, may be None if all test labels are unknown
        accuracy: Accuracy of the model, is None if all test labels are unknown
        test_dataloader: Dataloader used for testing
        config: Configuration file
        output: Output configuration
        grayscale_cams: List of grayscale grad_cam outputs ordered as the results
    """
    # Save csv
    results.to_csv(os.path.join(output_folder, "test_results.csv"), index_label="index")
    if grayscale_cams is None:
        log.info("Plotting only original examples, set gradcam = true in config file to also plot gradcam examples")

        save_gradcams = False
    else:
        log.info("Plotting original and gradcam examples")
        save_gradcams = True

    # Both confmat and accuracy are None when every test label is unknown, in which
    # case there is nothing meaningful to plot.
    if confmat is not None and accuracy is not None:
        # Save confusion matrix
        disp = ConfusionMatrixDisplay(
            confusion_matrix=np.array(confmat),
            # Column headers carry a "pred:" prefix (see get_results); strip it for display.
            display_labels=[x.replace("pred:", "") for x in confmat.columns.to_list()],
        )
        disp.plot(include_values=True, cmap=plt.cm.Greens, ax=None, colorbar=False, xticks_rotation=90)
        plt.title(f"Confusion Matrix (Accuracy: {(accuracy * 100):.2f}%)")
        plt.savefig(
            os.path.join(output_folder, "test_confusion_matrix.png"),
            bbox_inches="tight",
            pad_inches=0,
            dpi=300,
        )
        plt.close()

    if output is not None and output.example:
        log.info("Saving discordant/concordant examples in test folder")
        idx_to_class = test_dataloader.dataset.idx_to_class  # type: ignore[attr-defined]

        # Get misclassified samples
        images_folder = os.path.join(output_folder, "example")
        if not os.path.isdir(images_folder):
            os.makedirs(images_folder)
        original_images_folder = os.path.join(images_folder, "original")
        if not os.path.isdir(original_images_folder):
            os.makedirs(original_images_folder)

        gradcam_folder = os.path.join(images_folder, "gradcam")
        if save_gradcams and not os.path.isdir(gradcam_folder):
            os.makedirs(gradcam_folder)

        # Iterate over every label index that appears in either ground truth or predictions,
        # plotting concordant ("con") and discordant ("dis") examples for each class.
        for v in np.unique([results["real_label"], results["pred_label"]]):
            # Skip the "unknown" markers (NaN or -1 label indices).
            if np.isnan(v) or v == -1:
                continue

            k = idx_to_class[v]
            plot_classification_results(
                test_dataloader.dataset,
                unorm=UnNormalize(mean=config.transforms.mean, std=config.transforms.std),
                pred_labels=results["pred_label"].to_numpy(),
                test_labels=results["real_label"].to_numpy(),
                grayscale_cams=grayscale_cams,
                class_name=k,
                original_folder=original_images_folder,
                gradcam_folder=gradcam_folder,
                idx_to_class=idx_to_class,
                pred_class_to_plot=v,
                what="con",
                rows=output.get("rows", 3),
                cols=output.get("cols", 2),
                figsize=output.get("figsize", (20, 20)),
                gradcam=save_gradcams,
            )

            plot_classification_results(
                test_dataloader.dataset,
                unorm=UnNormalize(mean=config.transforms.mean, std=config.transforms.std),
                pred_labels=results["pred_label"].to_numpy(),
                test_labels=results["real_label"].to_numpy(),
                grayscale_cams=grayscale_cams,
                class_name=k,
                original_folder=original_images_folder,
                gradcam_folder=gradcam_folder,
                idx_to_class=idx_to_class,
                pred_class_to_plot=v,
                what="dis",
                rows=output.get("rows", 3),
                cols=output.get("cols", 2),
                figsize=output.get("figsize", (20, 20)),
                gradcam=save_gradcams,
            )

    else:
        log.info("Not generating discordant/concordant examples. Check task.output.example in config file")
514
+
515
+
516
def get_results(
    test_labels: np.ndarray | list[int],
    pred_labels: np.ndarray | list[int],
    idx_to_labels: dict | None = None,
    cl_rep_digits: int = 3,
) -> tuple[str | dict, pd.DataFrame, float]:
    """Compute classification metrics from predicted and ground-truth labels.

    Args:
        test_labels: test labels
        pred_labels: predicted labels
        idx_to_labels: optional mapping from label index to human-readable name
        cl_rep_digits: number of digits used in the classification report. Default: 3

    Returns:
        A tuple with the classification report (string or dict), the confusion matrix
        as a pd.DataFrame and the computed accuracy.
    """
    # Restrict all metrics to the labels actually present in either array.
    labels_found = np.unique([test_labels, pred_labels])

    report = classification_report(
        y_true=test_labels,
        y_pred=pred_labels,
        labels=labels_found,
        digits=cl_rep_digits,
        zero_division=0,
    )
    raw_cm = confusion_matrix(y_true=test_labels, y_pred=pred_labels, labels=labels_found)
    accuracy = accuracy_score(y_true=test_labels, y_pred=pred_labels)

    # Prefix row/column headers so true vs predicted axes are unambiguous in the CSV/plot.
    if idx_to_labels:
        row_names = [f"true:{idx_to_labels[label]}" for label in labels_found]
        col_names = [f"pred:{idx_to_labels[label]}" for label in labels_found]
    else:
        row_names = [f"true:{label}" for label in labels_found]
        col_names = [f"pred:{label}" for label in labels_found]

    cm_df = pd.DataFrame(raw_cm, index=row_names, columns=col_names)

    return report, cm_df, accuracy
560
+
561
+
562
def automatic_batch_size_computation(
    datamodule: SklearnClassificationDataModule | PatchSklearnClassificationDataModule,
    backbone: ModelSignatureWrapper,
    starting_batch_size: int,
) -> int:
    """Find the optimal batch size for feature extraction. This algorithm works from the largest batch size possible
    and divide by 2 until it finds the largest batch size that fits in memory.

    Args:
        datamodule: Datamodule used for feature extraction
        backbone: Backbone used for feature extraction
        starting_batch_size: Starting batch size to use for the search

    Returns:
        Optimal batch size
    """
    log.info("Finding optimal batch size...")
    optimal = False
    batch_size = starting_batch_size

    # Each iteration tries the current candidate batch size; a RuntimeError
    # (presumably OOM — TODO confirm no other RuntimeError can leak through)
    # halves the candidate and retries.
    while not optimal:
        datamodule.batch_size = batch_size
        base_dataloader = datamodule.train_dataloader()

        # Some datamodules return multiple dataloaders; probe only the first.
        if isinstance(base_dataloader, Sequence):
            base_dataloader = base_dataloader[0]

        if len(base_dataloader) == 1:
            # If it fits in memory this is the largest batch size possible
            # If it crashes restart with the previous batch size // 2
            datamodule.batch_size = len(base_dataloader.dataset)  # type: ignore[arg-type]
            # New restarting batch size is the largest closest power of 2 to the dataset size, it will be divided by 2
            batch_size = 2 ** math.ceil(math.log2(datamodule.batch_size))
            base_dataloader = datamodule.train_dataloader()
            if isinstance(base_dataloader, Sequence):
                base_dataloader = base_dataloader[0]
            optimal = True

        try:
            log.info("Trying batch size: %d", datamodule.batch_size)
            # limit_batches=1: a single forward pass is enough to trigger an OOM.
            _ = get_feature(feature_extractor=backbone, dl=base_dataloader, iteration_over_training=1, limit_batches=1)
        except RuntimeError as e:
            if batch_size > 1:
                batch_size = batch_size // 2
                optimal = False
                continue

            # Even a single sample does not fit: nothing left to halve, give up.
            log.error("Unable to run the model with batch size 1")
            raise e

        log.info("Found optimal batch size: %d", datamodule.batch_size)
        optimal = True

    # Free the memory used by the probing forward passes before real extraction.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return datamodule.batch_size
@@ -0,0 +1,31 @@
1
import functools
import warnings
from collections.abc import Callable

from quadra.utils.utils import get_logger
5
+
6
+ logger = get_logger(__name__)
7
+
8
+
9
def deprecated(message: str) -> Callable:
    """Decorator to mark a function or class as deprecated.

    Args:
        message: Message to be displayed when the function is called, e.g. what to use instead.

    Returns:
        Decoratored function.
    """

    def deprecated_decorator(func_or_class: Callable) -> Callable:
        """Decorator to mark a function as deprecated."""

        @functools.wraps(func_or_class)
        def wrapper(*args, **kwargs):
            """Emit the deprecation notice, then delegate to the wrapped callable."""
            warning_msg = f"{func_or_class.__name__} is deprecated. {message}"
            # Log for runtime visibility and additionally raise a real DeprecationWarning
            # so standard tooling (pytest -W, python -W error, linters) can catch or
            # escalate it; stacklevel=2 points the warning at the caller's line.
            logger.warning(warning_msg)
            warnings.warn(warning_msg, DeprecationWarning, stacklevel=2)
            return func_or_class(*args, **kwargs)

        return wrapper

    return deprecated_decorator