quadra 0.0.1__py3-none-any.whl → 2.1.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (302)
  1. hydra_plugins/quadra_searchpath_plugin.py +30 -0
  2. quadra/__init__.py +6 -0
  3. quadra/callbacks/__init__.py +0 -0
  4. quadra/callbacks/anomalib.py +289 -0
  5. quadra/callbacks/lightning.py +501 -0
  6. quadra/callbacks/mlflow.py +291 -0
  7. quadra/callbacks/scheduler.py +69 -0
  8. quadra/configs/__init__.py +0 -0
  9. quadra/configs/backbone/caformer_m36.yaml +8 -0
  10. quadra/configs/backbone/caformer_s36.yaml +8 -0
  11. quadra/configs/backbone/convnextv2_base.yaml +8 -0
  12. quadra/configs/backbone/convnextv2_femto.yaml +8 -0
  13. quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
  14. quadra/configs/backbone/dino_vitb8.yaml +12 -0
  15. quadra/configs/backbone/dino_vits8.yaml +12 -0
  16. quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
  17. quadra/configs/backbone/dinov2_vits14.yaml +12 -0
  18. quadra/configs/backbone/efficientnet_b0.yaml +8 -0
  19. quadra/configs/backbone/efficientnet_b1.yaml +8 -0
  20. quadra/configs/backbone/efficientnet_b2.yaml +8 -0
  21. quadra/configs/backbone/efficientnet_b3.yaml +8 -0
  22. quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
  23. quadra/configs/backbone/levit_128s.yaml +8 -0
  24. quadra/configs/backbone/mnasnet0_5.yaml +9 -0
  25. quadra/configs/backbone/resnet101.yaml +8 -0
  26. quadra/configs/backbone/resnet18.yaml +8 -0
  27. quadra/configs/backbone/resnet18_ssl.yaml +8 -0
  28. quadra/configs/backbone/resnet50.yaml +8 -0
  29. quadra/configs/backbone/smp.yaml +9 -0
  30. quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
  31. quadra/configs/backbone/unetr.yaml +15 -0
  32. quadra/configs/backbone/vit16_base.yaml +9 -0
  33. quadra/configs/backbone/vit16_small.yaml +9 -0
  34. quadra/configs/backbone/vit16_tiny.yaml +9 -0
  35. quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
  36. quadra/configs/callbacks/all.yaml +32 -0
  37. quadra/configs/callbacks/default.yaml +37 -0
  38. quadra/configs/callbacks/default_anomalib.yaml +67 -0
  39. quadra/configs/config.yaml +33 -0
  40. quadra/configs/core/default.yaml +11 -0
  41. quadra/configs/datamodule/base/anomaly.yaml +16 -0
  42. quadra/configs/datamodule/base/classification.yaml +21 -0
  43. quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
  44. quadra/configs/datamodule/base/segmentation.yaml +18 -0
  45. quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
  46. quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
  47. quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
  48. quadra/configs/datamodule/base/ssl.yaml +21 -0
  49. quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
  50. quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
  51. quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
  52. quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
  53. quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
  54. quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
  55. quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
  56. quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
  57. quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
  58. quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
  59. quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
  60. quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
  61. quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
  62. quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
  63. quadra/configs/experiment/base/classification/classification.yaml +73 -0
  64. quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
  65. quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
  66. quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
  67. quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
  68. quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
  69. quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
  70. quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
  71. quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
  72. quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
  73. quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
  74. quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
  75. quadra/configs/experiment/base/ssl/byol.yaml +43 -0
  76. quadra/configs/experiment/base/ssl/dino.yaml +46 -0
  77. quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
  78. quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
  79. quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
  80. quadra/configs/experiment/custom/cls.yaml +12 -0
  81. quadra/configs/experiment/default.yaml +15 -0
  82. quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
  83. quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
  84. quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
  85. quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
  86. quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
  87. quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
  88. quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
  89. quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
  90. quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
  91. quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
  92. quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
  93. quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
  94. quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
  95. quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
  96. quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
  97. quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
  98. quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
  99. quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
  100. quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
  101. quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
  102. quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
  103. quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
  104. quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
  105. quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
  106. quadra/configs/export/default.yaml +13 -0
  107. quadra/configs/hydra/anomaly_custom.yaml +15 -0
  108. quadra/configs/hydra/default.yaml +14 -0
  109. quadra/configs/inference/default.yaml +26 -0
  110. quadra/configs/logger/comet.yaml +10 -0
  111. quadra/configs/logger/csv.yaml +5 -0
  112. quadra/configs/logger/mlflow.yaml +12 -0
  113. quadra/configs/logger/tensorboard.yaml +8 -0
  114. quadra/configs/loss/asl.yaml +7 -0
  115. quadra/configs/loss/barlow.yaml +2 -0
  116. quadra/configs/loss/bce.yaml +1 -0
  117. quadra/configs/loss/byol.yaml +1 -0
  118. quadra/configs/loss/cross_entropy.yaml +1 -0
  119. quadra/configs/loss/dino.yaml +8 -0
  120. quadra/configs/loss/simclr.yaml +2 -0
  121. quadra/configs/loss/simsiam.yaml +1 -0
  122. quadra/configs/loss/smp_ce.yaml +3 -0
  123. quadra/configs/loss/smp_dice.yaml +2 -0
  124. quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
  125. quadra/configs/loss/smp_mcc.yaml +2 -0
  126. quadra/configs/loss/vicreg.yaml +5 -0
  127. quadra/configs/model/anomalib/cfa.yaml +35 -0
  128. quadra/configs/model/anomalib/cflow.yaml +30 -0
  129. quadra/configs/model/anomalib/csflow.yaml +34 -0
  130. quadra/configs/model/anomalib/dfm.yaml +19 -0
  131. quadra/configs/model/anomalib/draem.yaml +29 -0
  132. quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
  133. quadra/configs/model/anomalib/fastflow.yaml +32 -0
  134. quadra/configs/model/anomalib/padim.yaml +32 -0
  135. quadra/configs/model/anomalib/patchcore.yaml +36 -0
  136. quadra/configs/model/barlow.yaml +16 -0
  137. quadra/configs/model/byol.yaml +25 -0
  138. quadra/configs/model/classification.yaml +10 -0
  139. quadra/configs/model/dino.yaml +26 -0
  140. quadra/configs/model/logistic_regression.yaml +4 -0
  141. quadra/configs/model/multilabel_classification.yaml +9 -0
  142. quadra/configs/model/simclr.yaml +18 -0
  143. quadra/configs/model/simsiam.yaml +24 -0
  144. quadra/configs/model/smp.yaml +4 -0
  145. quadra/configs/model/smp_multiclass.yaml +4 -0
  146. quadra/configs/model/vicreg.yaml +16 -0
  147. quadra/configs/optimizer/adam.yaml +5 -0
  148. quadra/configs/optimizer/adamw.yaml +3 -0
  149. quadra/configs/optimizer/default.yaml +4 -0
  150. quadra/configs/optimizer/lars.yaml +8 -0
  151. quadra/configs/optimizer/sgd.yaml +4 -0
  152. quadra/configs/scheduler/default.yaml +5 -0
  153. quadra/configs/scheduler/rop.yaml +5 -0
  154. quadra/configs/scheduler/step.yaml +3 -0
  155. quadra/configs/scheduler/warmrestart.yaml +2 -0
  156. quadra/configs/scheduler/warmup.yaml +6 -0
  157. quadra/configs/task/anomalib/cfa.yaml +5 -0
  158. quadra/configs/task/anomalib/cflow.yaml +5 -0
  159. quadra/configs/task/anomalib/csflow.yaml +5 -0
  160. quadra/configs/task/anomalib/draem.yaml +5 -0
  161. quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
  162. quadra/configs/task/anomalib/fastflow.yaml +5 -0
  163. quadra/configs/task/anomalib/inference.yaml +3 -0
  164. quadra/configs/task/anomalib/padim.yaml +5 -0
  165. quadra/configs/task/anomalib/patchcore.yaml +5 -0
  166. quadra/configs/task/classification.yaml +6 -0
  167. quadra/configs/task/classification_evaluation.yaml +6 -0
  168. quadra/configs/task/default.yaml +1 -0
  169. quadra/configs/task/segmentation.yaml +9 -0
  170. quadra/configs/task/segmentation_evaluation.yaml +3 -0
  171. quadra/configs/task/sklearn_classification.yaml +13 -0
  172. quadra/configs/task/sklearn_classification_patch.yaml +11 -0
  173. quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
  174. quadra/configs/task/sklearn_classification_test.yaml +8 -0
  175. quadra/configs/task/ssl.yaml +2 -0
  176. quadra/configs/trainer/lightning_cpu.yaml +36 -0
  177. quadra/configs/trainer/lightning_gpu.yaml +35 -0
  178. quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
  179. quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
  180. quadra/configs/trainer/lightning_multigpu.yaml +37 -0
  181. quadra/configs/trainer/sklearn_classification.yaml +7 -0
  182. quadra/configs/transforms/byol.yaml +47 -0
  183. quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
  184. quadra/configs/transforms/default.yaml +37 -0
  185. quadra/configs/transforms/default_numpy.yaml +24 -0
  186. quadra/configs/transforms/default_resize.yaml +22 -0
  187. quadra/configs/transforms/dino.yaml +63 -0
  188. quadra/configs/transforms/linear_eval.yaml +18 -0
  189. quadra/datamodules/__init__.py +20 -0
  190. quadra/datamodules/anomaly.py +180 -0
  191. quadra/datamodules/base.py +375 -0
  192. quadra/datamodules/classification.py +1003 -0
  193. quadra/datamodules/generic/__init__.py +0 -0
  194. quadra/datamodules/generic/imagenette.py +144 -0
  195. quadra/datamodules/generic/mnist.py +81 -0
  196. quadra/datamodules/generic/mvtec.py +58 -0
  197. quadra/datamodules/generic/oxford_pet.py +163 -0
  198. quadra/datamodules/patch.py +190 -0
  199. quadra/datamodules/segmentation.py +742 -0
  200. quadra/datamodules/ssl.py +140 -0
  201. quadra/datasets/__init__.py +17 -0
  202. quadra/datasets/anomaly.py +287 -0
  203. quadra/datasets/classification.py +241 -0
  204. quadra/datasets/patch.py +138 -0
  205. quadra/datasets/segmentation.py +239 -0
  206. quadra/datasets/ssl.py +110 -0
  207. quadra/losses/__init__.py +0 -0
  208. quadra/losses/classification/__init__.py +6 -0
  209. quadra/losses/classification/asl.py +83 -0
  210. quadra/losses/classification/focal.py +320 -0
  211. quadra/losses/classification/prototypical.py +148 -0
  212. quadra/losses/ssl/__init__.py +17 -0
  213. quadra/losses/ssl/barlowtwins.py +47 -0
  214. quadra/losses/ssl/byol.py +37 -0
  215. quadra/losses/ssl/dino.py +129 -0
  216. quadra/losses/ssl/hyperspherical.py +45 -0
  217. quadra/losses/ssl/idmm.py +50 -0
  218. quadra/losses/ssl/simclr.py +67 -0
  219. quadra/losses/ssl/simsiam.py +30 -0
  220. quadra/losses/ssl/vicreg.py +76 -0
  221. quadra/main.py +46 -0
  222. quadra/metrics/__init__.py +3 -0
  223. quadra/metrics/segmentation.py +251 -0
  224. quadra/models/__init__.py +0 -0
  225. quadra/models/base.py +151 -0
  226. quadra/models/classification/__init__.py +8 -0
  227. quadra/models/classification/backbones.py +149 -0
  228. quadra/models/classification/base.py +92 -0
  229. quadra/models/evaluation.py +322 -0
  230. quadra/modules/__init__.py +0 -0
  231. quadra/modules/backbone.py +30 -0
  232. quadra/modules/base.py +312 -0
  233. quadra/modules/classification/__init__.py +3 -0
  234. quadra/modules/classification/base.py +331 -0
  235. quadra/modules/ssl/__init__.py +17 -0
  236. quadra/modules/ssl/barlowtwins.py +59 -0
  237. quadra/modules/ssl/byol.py +172 -0
  238. quadra/modules/ssl/common.py +285 -0
  239. quadra/modules/ssl/dino.py +186 -0
  240. quadra/modules/ssl/hyperspherical.py +206 -0
  241. quadra/modules/ssl/idmm.py +98 -0
  242. quadra/modules/ssl/simclr.py +73 -0
  243. quadra/modules/ssl/simsiam.py +68 -0
  244. quadra/modules/ssl/vicreg.py +67 -0
  245. quadra/optimizers/__init__.py +4 -0
  246. quadra/optimizers/lars.py +153 -0
  247. quadra/optimizers/sam.py +127 -0
  248. quadra/schedulers/__init__.py +3 -0
  249. quadra/schedulers/base.py +44 -0
  250. quadra/schedulers/warmup.py +127 -0
  251. quadra/tasks/__init__.py +24 -0
  252. quadra/tasks/anomaly.py +582 -0
  253. quadra/tasks/base.py +397 -0
  254. quadra/tasks/classification.py +1264 -0
  255. quadra/tasks/patch.py +492 -0
  256. quadra/tasks/segmentation.py +389 -0
  257. quadra/tasks/ssl.py +560 -0
  258. quadra/trainers/README.md +3 -0
  259. quadra/trainers/__init__.py +0 -0
  260. quadra/trainers/classification.py +179 -0
  261. quadra/utils/__init__.py +0 -0
  262. quadra/utils/anomaly.py +112 -0
  263. quadra/utils/classification.py +618 -0
  264. quadra/utils/deprecation.py +31 -0
  265. quadra/utils/evaluation.py +474 -0
  266. quadra/utils/export.py +579 -0
  267. quadra/utils/imaging.py +32 -0
  268. quadra/utils/logger.py +15 -0
  269. quadra/utils/mlflow.py +98 -0
  270. quadra/utils/model_manager.py +320 -0
  271. quadra/utils/models.py +524 -0
  272. quadra/utils/patch/__init__.py +15 -0
  273. quadra/utils/patch/dataset.py +1433 -0
  274. quadra/utils/patch/metrics.py +449 -0
  275. quadra/utils/patch/model.py +153 -0
  276. quadra/utils/patch/visualization.py +217 -0
  277. quadra/utils/resolver.py +42 -0
  278. quadra/utils/segmentation.py +31 -0
  279. quadra/utils/tests/__init__.py +0 -0
  280. quadra/utils/tests/fixtures/__init__.py +1 -0
  281. quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
  282. quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
  283. quadra/utils/tests/fixtures/dataset/classification.py +406 -0
  284. quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
  285. quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
  286. quadra/utils/tests/fixtures/models/__init__.py +3 -0
  287. quadra/utils/tests/fixtures/models/anomaly.py +89 -0
  288. quadra/utils/tests/fixtures/models/classification.py +45 -0
  289. quadra/utils/tests/fixtures/models/segmentation.py +33 -0
  290. quadra/utils/tests/helpers.py +70 -0
  291. quadra/utils/tests/models.py +27 -0
  292. quadra/utils/utils.py +525 -0
  293. quadra/utils/validator.py +115 -0
  294. quadra/utils/visualization.py +422 -0
  295. quadra/utils/vit_explainability.py +349 -0
  296. quadra-2.1.13.dist-info/LICENSE +201 -0
  297. quadra-2.1.13.dist-info/METADATA +386 -0
  298. quadra-2.1.13.dist-info/RECORD +300 -0
  299. {quadra-0.0.1.dist-info → quadra-2.1.13.dist-info}/WHEEL +1 -1
  300. quadra-2.1.13.dist-info/entry_points.txt +3 -0
  301. quadra-0.0.1.dist-info/METADATA +0 -14
  302. quadra-0.0.1.dist-info/RECORD +0 -4
quadra/tasks/patch.py ADDED
@@ -0,0 +1,492 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import os
5
+ from pathlib import Path
6
+ from typing import Any, cast
7
+
8
+ import hydra
9
+ import torch
10
+ from joblib import dump, load
11
+ from omegaconf import DictConfig, OmegaConf
12
+ from sklearn.base import ClassifierMixin
13
+
14
+ from quadra.datamodules import PatchSklearnClassificationDataModule
15
+ from quadra.datasets.patch import PatchSklearnClassificationTrainDataset
16
+ from quadra.models.base import ModelSignatureWrapper
17
+ from quadra.models.evaluation import BaseEvaluationModel
18
+ from quadra.tasks.base import Evaluation, Task
19
+ from quadra.trainers.classification import SklearnClassificationTrainer
20
+ from quadra.utils import utils
21
+ from quadra.utils.classification import automatic_batch_size_computation
22
+ from quadra.utils.evaluation import automatic_datamodule_batch_size
23
+ from quadra.utils.export import export_model, import_deployment_model
24
+ from quadra.utils.patch import RleEncoder, compute_patch_metrics, save_classification_result
25
+ from quadra.utils.patch.dataset import PatchDatasetFileFormat
26
+
27
+ log = utils.get_logger(__name__)
28
+
29
+
30
class PatchSklearnClassification(Task[PatchSklearnClassificationDataModule]):
    """Patch classification using torch backbone for feature extraction and sklearn to learn a linear classifier.

    Args:
        config: The experiment configuration
        output: Dictionary defining which kind of outputs to generate (``folder``, ``report`` and
            ``reconstruction_method`` are read from it).
        device: The device to use
        automatic_batch_size: Configuration of the automatic batch size search (``disable`` and
            ``starting_batch_size`` are read from it).
        half_precision: Whether to use half precision. Not supported on CPU.
    """

    def __init__(
        self,
        config: DictConfig,
        output: DictConfig,
        device: str,
        automatic_batch_size: DictConfig,
        half_precision: bool = False,
    ):
        super().__init__(config=config)
        self.device: str = device
        self.output: DictConfig = output
        self.return_polygon: bool = True
        self.reconstruction_results: dict[str, Any]
        self._backbone: ModelSignatureWrapper
        self._trainer: SklearnClassificationTrainer
        self._model: ClassifierMixin
        self.metadata: dict[str, Any] = {
            "test_confusion_matrix": [],
            "test_accuracy": [],
            "test_results": [],
            "test_labels": [],
        }
        self.export_folder: str = "deployment_model"
        self.automatic_batch_size = automatic_batch_size
        self.half_precision = half_precision

    @property
    def model(self) -> ClassifierMixin:
        """sklearn.base.ClassifierMixin: The model."""
        return self._model

    @model.setter
    def model(self, model_config: DictConfig):
        """Instantiate the sklearn classifier from its hydra config."""
        log.info("Instantiating model <%s>", model_config["_target_"])
        self._model = hydra.utils.instantiate(model_config)

    @property
    def backbone(self) -> ModelSignatureWrapper:
        """Backbone: The backbone."""
        return self._backbone

    @backbone.setter
    def backbone(self, backbone_config):
        """Load the backbone, wrap it in ``ModelSignatureWrapper``, set eval mode and move it to ``self.device``.

        The weights come from ``metadata.checkpoint`` when it is set. Otherwise ``model`` is instantiated
        from the config.

        Raises:
            ValueError: If half precision is requested while running on CPU.
        """
        # Fail fast: do not load the weights if they cannot be used on the requested device.
        if self.half_precision and self.device == "cpu":
            raise ValueError("Half precision is not supported on CPU")

        if backbone_config.metadata.get("checkpoint"):
            log.info("Loading backbone from <%s>", backbone_config.metadata.checkpoint)
            backbone = torch.load(backbone_config.metadata.checkpoint)
        else:
            log.info("Loading backbone from <%s>", backbone_config.model["_target_"])
            backbone = hydra.utils.instantiate(backbone_config.model)

        self._backbone = ModelSignatureWrapper(backbone)
        self._backbone.eval()
        if self.half_precision:
            self._backbone.half()
        self._backbone = self._backbone.to(self.device)

    def _class_to_skip_training(self) -> list[str] | None:
        """Return the datamodule's ``class_to_skip_training``, or ``None`` when it is missing or unset."""
        return getattr(self.datamodule, "class_to_skip_training", None)

    def prepare(self) -> None:
        """Prepare the experiment: datamodule, backbone, classifier, optional batch size search and trainer."""
        self.datamodule = self.config.datamodule
        self.backbone = self.config.backbone
        self.model = self.config.model

        # The automatic batch size search is only run on accelerators
        if not self.automatic_batch_size.disable and self.device != "cpu":
            self.datamodule.batch_size = automatic_batch_size_computation(
                datamodule=self.datamodule,
                backbone=self.backbone,
                starting_batch_size=self.automatic_batch_size.starting_batch_size,
            )

        self.trainer = self.config.trainer

    @property
    def trainer(self) -> SklearnClassificationTrainer:
        """Trainer: The trainer."""
        return self._trainer

    @trainer.setter
    def trainer(self, trainer_config: DictConfig) -> None:
        """Instantiate the trainer with the current backbone and classifier."""
        log.info("Instantiating trainer <%s>", trainer_config["_target_"])
        trainer = hydra.utils.instantiate(trainer_config, backbone=self.backbone, classifier=self.model)
        self._trainer = trainer

    def train(self) -> None:
        """Fit the classifier on the training split and evaluate it on the validation split.

        Confusion matrix, accuracy, per-sample results and the labels present are stored in
        ``self.metadata``.
        """
        log.info("Starting training...!")
        # prepare_data() must be explicitly called if the task does not include a lightning training
        self.datamodule.prepare_data()
        self.datamodule.setup(stage="fit")

        class_to_skip = self._class_to_skip_training()
        class_to_keep = None
        if class_to_skip is not None:
            class_to_keep = [x for x in self.datamodule.class_to_idx if x not in class_to_skip]

        # Re-instantiate a fresh classifier from config and hand it to the trainer
        self.model = self.config.model
        self.trainer.change_classifier(self.model)
        train_dataloader = self.datamodule.train_dataloader()
        val_dataloader = self.datamodule.val_dataloader()
        train_dataset = cast(PatchSklearnClassificationTrainDataset, train_dataloader.dataset)
        self.trainer.fit(train_dataloader=train_dataloader)
        _, pd_cm, accuracy, res, _ = self.trainer.test(
            test_dataloader=val_dataloader,
            class_to_keep=class_to_keep,
            idx_to_class=train_dataset.idx_to_class,
            predict_proba=True,
        )

        # save results
        self.metadata["test_confusion_matrix"] = pd_cm
        self.metadata["test_accuracy"] = accuracy
        self.metadata["test_results"] = res
        # A real label of -1 is reported as "N/A"
        self.metadata["test_labels"] = [
            train_dataset.idx_to_class[i] if i != -1 else "N/A" for i in res["real_label"].unique().tolist()
        ]

    def generate_report(self) -> None:
        """Generate the report for the task.

        Patch predictions on the validation images are reconstructed and dumped to
        ``reconstruction_results.json`` in the current working directory. The classification results
        are then saved into ``self.output.folder``.
        """
        log.info("Generating report!")
        os.makedirs(self.output.folder, exist_ok=True)

        c_matrix = self.metadata["test_confusion_matrix"]
        idx_to_class = {v: k for k, v in self.datamodule.class_to_idx.items()}

        datamodule: PatchSklearnClassificationDataModule = self.datamodule
        val_img_info: list[PatchDatasetFileFormat] = datamodule.info.val_files
        # Prefix relative image/mask paths with the dataset root (mutates the dataset info in place)
        for img_info in val_img_info:
            if not os.path.isabs(img_info.image_path):
                img_info.image_path = os.path.join(datamodule.data_path, img_info.image_path)
            if img_info.mask_path is not None and not os.path.isabs(img_info.mask_path):
                img_info.mask_path = os.path.join(datamodule.data_path, img_info.mask_path)

        patch_number = datamodule.info.patch_number
        patch_size = datamodule.info.patch_size
        false_region_bad, false_region_good, true_region_bad, reconstructions = compute_patch_metrics(
            test_img_info=val_img_info,
            test_results=self.metadata["test_results"],
            patch_num_h=patch_number[0] if patch_number is not None else None,
            patch_num_w=patch_number[1] if patch_number is not None else None,
            patch_h=patch_size[0] if patch_size is not None else None,
            patch_w=patch_size[1] if patch_size is not None else None,
            overlap=datamodule.info.overlap,
            idx_to_class=idx_to_class,
            return_polygon=self.return_polygon,
            patch_reconstruction_method=self.output.reconstruction_method,
            annotated_good=datamodule.info.annotated_good,
        )

        self.reconstruction_results = {
            "false_region_bad": false_region_bad,
            "false_region_good": false_region_good,
            "true_region_bad": true_region_bad,
            "reconstructions": reconstructions,
            "reconstructions_type": "polygon" if self.return_polygon else "rle",
            "patch_reconstruction_method": self.output.reconstruction_method,
        }

        with open("reconstruction_results.json", "w") as f:
            json.dump(
                self.reconstruction_results,
                f,
                cls=RleEncoder,
            )

        class_to_skip = self._class_to_skip_training()
        if class_to_skip is not None:
            ignore_classes = [self.datamodule.class_to_idx[x] for x in class_to_skip]
        else:
            ignore_classes = None
        val_dataloader = self.datamodule.val_dataloader()
        save_classification_result(
            results=self.metadata["test_results"],
            output_folder=self.output.folder,
            confusion_matrix=c_matrix,
            accuracy=self.metadata["test_accuracy"],
            test_dataloader=val_dataloader,
            config=self.config,
            output=self.output,
            reconstructions=reconstructions,
            ignore_classes=ignore_classes,
        )

    def export(self) -> None:
        """Generate deployment model for the task.

        The backbone is exported through ``export_model``. When at least one export was produced, the
        patch layout and reconstruction settings are added to ``model.json``. The fitted classifier is
        dumped as ``classifier.joblib`` in ``self.export_folder``.
        """
        input_shapes = self.config.export.input_shapes

        idx_to_class = {v: k for k, v in self.datamodule.class_to_idx.items()}

        model_json, export_paths = export_model(
            config=self.config,
            model=self.backbone,
            export_folder=self.export_folder,
            half_precision=self.half_precision,
            input_shapes=input_shapes,
            idx_to_class=idx_to_class,
            pytorch_model_type="backbone",
        )

        if len(export_paths) > 0:
            dataset_info = self.datamodule.info

            horizontal_patches = dataset_info.patch_number[1] if dataset_info.patch_number is not None else None
            vertical_patches = dataset_info.patch_number[0] if dataset_info.patch_number is not None else None
            patch_height = dataset_info.patch_size[0] if dataset_info.patch_size is not None else None
            patch_width = dataset_info.patch_size[1] if dataset_info.patch_size is not None else None
            overlap = dataset_info.overlap

            model_json.update(
                {
                    "horizontal_patches": horizontal_patches,
                    "vertical_patches": vertical_patches,
                    "patch_height": patch_height,
                    "patch_width": patch_width,
                    "overlap": overlap,
                    "reconstruction_method": self.output.reconstruction_method,
                    # Same guarded lookup used by train/generate_report, so a datamodule without the
                    # attribute no longer raises AttributeError here
                    "class_to_skip": self._class_to_skip_training(),
                }
            )

            with open(os.path.join(self.export_folder, "model.json"), "w") as f:
                json.dump(model_json, f, cls=utils.HydraEncoder)

        # Ensure the destination exists before dumping the classifier
        os.makedirs(self.export_folder, exist_ok=True)
        dump(self.model, os.path.join(self.export_folder, "classifier.joblib"))

    def execute(self) -> None:
        """Execute the experiment and all the steps."""
        self.prepare()
        self.train()
        if self.output.report:
            self.generate_report()
        if self.config.export is not None and len(self.config.export.types) > 0:
            self.export()
        self.finalize()
273
+
274
+
275
+ class PatchSklearnTestClassification(Evaluation[PatchSklearnClassificationDataModule]):
276
+ """Perform a test of an already trained classification model.
277
+
278
+ Args:
279
+ config: The experiment configuration
280
+ output: where to save resultss
281
+ model_path: path to trained model from PatchSklearnClassification task.
282
+ device: the device where to run the model (cuda or cpu). Defaults to 'cpu'.
283
+ """
284
+
285
def __init__(
    self,
    config: DictConfig,
    output: DictConfig,
    model_path: str,
    device: str = "cpu",
):
    """Initialise the evaluation task.

    Args:
        config: The experiment configuration
        output: Where to save results
        model_path: Path to the trained model from the PatchSklearnClassification task.
        device: The device where to run the model (cuda or cpu). Defaults to 'cpu'.
    """
    super().__init__(config=config, model_path=model_path, device=device)
    self.output = output
    self._backbone: BaseEvaluationModel
    self._classifier: ClassifierMixin
    self.class_to_idx: dict[str, int]
    self.idx_to_class: dict[int, str]
    self.metadata: dict[str, Any] = {
        "test_confusion_matrix": None,
        "test_accuracy": None,
        "test_results": None,
        "test_labels": None,
    }
    # ``test`` overwrites this with the value read from the model data, which may be ``None``
    self.class_to_skip: list[str] | None = []
    self.reconstruction_results: dict[str, Any]
    self.return_polygon: bool = True
307
+
308
def prepare(self) -> None:
    """Derive the class mappings from the model data, then set up datamodule and trainer for testing."""
    super().prepare()

    # Class indices are converted with int(); the stored keys may be strings (presumably JSON keys)
    stored_classes = self.model_data["classes"]
    self.idx_to_class = {int(index): name for index, name in stored_classes.items()}
    self.class_to_idx = {name: int(index) for index, name in stored_classes.items()}
    self.config.datamodule.class_to_idx = self.class_to_idx

    self.datamodule = self.config.datamodule
    # The trainer setter consumes the already-loaded backbone and classifier
    self.trainer = self.config.trainer

    # No lightning training is involved, so prepare_data() has to be invoked by hand
    self.datamodule.prepare_data()
    self.datamodule.setup(stage="test")
329
+
330
@automatic_datamodule_batch_size(batch_size_attribute_name="batch_size")
def test(self) -> None:
    """Run backbone and classifier on the test split and store the results in ``self.metadata``.

    Classes listed under ``class_to_skip`` in the model data, when present, are excluded via
    ``class_to_keep``.
    """
    test_dataloader = self.datamodule.test_dataloader()

    # ``model_data`` is a mapping (``prepare`` indexes it with ["classes"]). ``hasattr`` never sees
    # mapping keys, so the stored ``class_to_skip`` was always ignored; ``get`` reads it correctly.
    self.class_to_skip = self.model_data.get("class_to_skip")
    class_to_keep = None

    if self.class_to_skip is not None:
        class_to_keep = [x for x in self.datamodule.class_to_idx if x not in self.class_to_skip]
    _, pd_cm, accuracy, res, _ = self.trainer.test(
        test_dataloader=test_dataloader,
        idx_to_class=self.idx_to_class,
        predict_proba=True,
        class_to_keep=class_to_keep,
    )

    # save results
    self.metadata["test_confusion_matrix"] = pd_cm
    self.metadata["test_accuracy"] = accuracy
    self.metadata["test_results"] = res
    # A real label of -1 is reported as "N/A"
    self.metadata["test_labels"] = [
        self.idx_to_class[i] if i != -1 else "N/A" for i in res["real_label"].unique().tolist()
    ]
354
+
355
@property
def deployment_model(self):
    """Write-only style property: reading it always yields ``None``."""
    return None

@deployment_model.setter
def deployment_model(self, model_path: str):
    """Load the backbone from ``model_path`` and the ``classifier.joblib`` stored in the same folder."""
    model_dir = Path(model_path).parent
    self.backbone = model_path  # type: ignore[assignment]
    # Load classifier
    self.classifier = os.path.join(model_dir, "classifier.joblib")
366
+
367
@property
def classifier(self) -> ClassifierMixin:
    """The sklearn classifier applied on top of the backbone features."""
    return self._classifier

@classifier.setter
def classifier(self, classifier_path: str) -> None:
    """Deserialize the classifier saved at ``classifier_path`` with joblib."""
    # NOTE(review): joblib files are pickles; only load classifiers from trusted locations.
    loaded_classifier = load(classifier_path)
    self._classifier = loaded_classifier
376
+
377
@property
def backbone(self) -> BaseEvaluationModel:
    """The deployment backbone used for feature extraction."""
    return self._backbone

@backbone.setter
def backbone(self, model_path: str) -> None:
    """Import the exported backbone stored at ``model_path``.

    For ``.pth`` files the architecture is rebuilt first from the ``model_config.yaml`` next to the
    model. Either ``metadata.checkpoint`` is loaded or ``model`` is instantiated from that config.
    Other formats are imported without an explicit architecture.
    """
    model_architecture = None
    if os.path.splitext(model_path)[1] == ".pth":
        config_file = os.path.join(Path(model_path).parent, "model_config.yaml")
        log.info("Loading backbone from config")
        architecture_config = OmegaConf.load(config_file)

        checkpoint = architecture_config.metadata.get("checkpoint")
        if not checkpoint:
            log.info("Loading backbone from <%s>", architecture_config.model["_target_"])
            model_architecture = hydra.utils.instantiate(architecture_config.model)
        else:
            log.info("Loading backbone from <%s>", architecture_config.metadata.checkpoint)
            model_architecture = torch.load(architecture_config.metadata.checkpoint)

    self._backbone = import_deployment_model(
        model_path=model_path,
        device=self.device,
        inference_config=self.config.inference,
        model_architecture=model_architecture,
    )
406
+
407
@property
def trainer(self) -> SklearnClassificationTrainer:
    """Return the trainer instantiated by the setter."""
    return self._trainer

@trainer.setter
def trainer(self, trainer_config: DictConfig) -> None:
    """Instantiate the trainer described by ``trainer_config`` through Hydra.

    The current backbone and classifier are passed to the trainer as keyword
    arguments. If the backbone is in training mode, it is switched to eval
    mode first.

    Args:
        trainer_config: Hydra config with a ``_target_`` entry.
    """
    log.info("Instantiating trainer <%s>", trainer_config["_target_"])

    model = self.backbone
    if model.training:
        model.eval()

    self._trainer = hydra.utils.instantiate(trainer_config, backbone=model, classifier=self.classifier)
422
+
423
def generate_report(self) -> None:
    """Generate the evaluation report for the patch classification task.

    Steps:
      1. Create ``self.output.folder`` if needed.
      2. Make relative image/mask paths of the test files absolute by joining
         them with ``datamodule.data_path``. This rewrites the entries of
         ``datamodule.info.test_files`` in place.
      3. Call ``compute_patch_metrics`` on the stored test results. Keep its
         outputs in ``self.reconstruction_results`` and dump them to
         ``reconstruction_results.json``.
      4. Call ``save_classification_result`` with the stored test metrics.
         When ``self.class_to_skip`` is set, the indices of those classes are
         passed as ``ignore_classes``.
    """
    log.info("Generating report!")
    os.makedirs(self.output.folder, exist_ok=True)

    confusion = self.metadata["test_confusion_matrix"]
    index_to_name = {idx: name for name, idx in self.datamodule.class_to_idx.items()}

    datamodule: PatchSklearnClassificationDataModule = self.datamodule
    test_files = datamodule.info.test_files
    for sample in test_files:
        if not os.path.isabs(sample.image_path):
            sample.image_path = os.path.join(datamodule.data_path, sample.image_path)
        if sample.mask_path is not None and not os.path.isabs(sample.mask_path):
            sample.mask_path = os.path.join(datamodule.data_path, sample.mask_path)

    # Missing patch layout information is forwarded as None for both axes.
    patch_number = datamodule.info.patch_number
    if patch_number is None:
        patch_num_h = patch_num_w = None
    else:
        patch_num_h, patch_num_w = patch_number[0], patch_number[1]

    patch_size = datamodule.info.patch_size
    if patch_size is None:
        patch_h = patch_w = None
    else:
        patch_h, patch_w = patch_size[0], patch_size[1]

    reconstruction_method = self.output.reconstruction_method
    false_region_bad, false_region_good, true_region_bad, reconstructions = compute_patch_metrics(
        test_img_info=test_files,
        test_results=self.metadata["test_results"],
        patch_num_h=patch_num_h,
        patch_num_w=patch_num_w,
        patch_h=patch_h,
        patch_w=patch_w,
        overlap=datamodule.info.overlap,
        idx_to_class=index_to_name,
        return_polygon=self.return_polygon,
        patch_reconstruction_method=reconstruction_method,
        annotated_good=datamodule.info.annotated_good,
    )

    self.reconstruction_results = {
        "false_region_bad": false_region_bad,
        "false_region_good": false_region_good,
        "true_region_bad": true_region_bad,
        "reconstructions": reconstructions,
        "reconstructions_type": "polygon" if self.return_polygon else "rle",
        "patch_reconstruction_method": reconstruction_method,
    }

    # NOTE(review): this file is written relative to the current working directory,
    # not inside self.output.folder. Confirm that this is intended.
    with open("reconstruction_results.json", "w") as out_file:
        json.dump(self.reconstruction_results, out_file, cls=RleEncoder)

    ignore_classes = (
        [datamodule.class_to_idx[name] for name in self.class_to_skip] if self.class_to_skip is not None else None
    )
    loader = self.datamodule.test_dataloader()
    save_classification_result(
        results=self.metadata["test_results"],
        output_folder=self.output.folder,
        confusion_matrix=confusion,
        accuracy=self.metadata["test_accuracy"],
        test_dataloader=loader,
        config=self.config,
        output=self.output,
        reconstructions=reconstructions,
        ignore_classes=ignore_classes,
    )
485
+
486
def execute(self) -> None:
    """Run every step of the experiment.

    Order: ``prepare`` -> ``test`` -> ``generate_report`` (only when
    ``self.output.report`` is truthy) -> ``finalize``.
    """
    self.prepare()
    self.test()

    wants_report = self.output.report
    if wants_report:
        self.generate_report()

    self.finalize()