quadra 0.0.1__py3-none-any.whl → 2.1.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (302)
  1. hydra_plugins/quadra_searchpath_plugin.py +30 -0
  2. quadra/__init__.py +6 -0
  3. quadra/callbacks/__init__.py +0 -0
  4. quadra/callbacks/anomalib.py +289 -0
  5. quadra/callbacks/lightning.py +501 -0
  6. quadra/callbacks/mlflow.py +291 -0
  7. quadra/callbacks/scheduler.py +69 -0
  8. quadra/configs/__init__.py +0 -0
  9. quadra/configs/backbone/caformer_m36.yaml +8 -0
  10. quadra/configs/backbone/caformer_s36.yaml +8 -0
  11. quadra/configs/backbone/convnextv2_base.yaml +8 -0
  12. quadra/configs/backbone/convnextv2_femto.yaml +8 -0
  13. quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
  14. quadra/configs/backbone/dino_vitb8.yaml +12 -0
  15. quadra/configs/backbone/dino_vits8.yaml +12 -0
  16. quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
  17. quadra/configs/backbone/dinov2_vits14.yaml +12 -0
  18. quadra/configs/backbone/efficientnet_b0.yaml +8 -0
  19. quadra/configs/backbone/efficientnet_b1.yaml +8 -0
  20. quadra/configs/backbone/efficientnet_b2.yaml +8 -0
  21. quadra/configs/backbone/efficientnet_b3.yaml +8 -0
  22. quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
  23. quadra/configs/backbone/levit_128s.yaml +8 -0
  24. quadra/configs/backbone/mnasnet0_5.yaml +9 -0
  25. quadra/configs/backbone/resnet101.yaml +8 -0
  26. quadra/configs/backbone/resnet18.yaml +8 -0
  27. quadra/configs/backbone/resnet18_ssl.yaml +8 -0
  28. quadra/configs/backbone/resnet50.yaml +8 -0
  29. quadra/configs/backbone/smp.yaml +9 -0
  30. quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
  31. quadra/configs/backbone/unetr.yaml +15 -0
  32. quadra/configs/backbone/vit16_base.yaml +9 -0
  33. quadra/configs/backbone/vit16_small.yaml +9 -0
  34. quadra/configs/backbone/vit16_tiny.yaml +9 -0
  35. quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
  36. quadra/configs/callbacks/all.yaml +32 -0
  37. quadra/configs/callbacks/default.yaml +37 -0
  38. quadra/configs/callbacks/default_anomalib.yaml +67 -0
  39. quadra/configs/config.yaml +33 -0
  40. quadra/configs/core/default.yaml +11 -0
  41. quadra/configs/datamodule/base/anomaly.yaml +16 -0
  42. quadra/configs/datamodule/base/classification.yaml +21 -0
  43. quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
  44. quadra/configs/datamodule/base/segmentation.yaml +18 -0
  45. quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
  46. quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
  47. quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
  48. quadra/configs/datamodule/base/ssl.yaml +21 -0
  49. quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
  50. quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
  51. quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
  52. quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
  53. quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
  54. quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
  55. quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
  56. quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
  57. quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
  58. quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
  59. quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
  60. quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
  61. quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
  62. quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
  63. quadra/configs/experiment/base/classification/classification.yaml +73 -0
  64. quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
  65. quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
  66. quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
  67. quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
  68. quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
  69. quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
  70. quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
  71. quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
  72. quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
  73. quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
  74. quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
  75. quadra/configs/experiment/base/ssl/byol.yaml +43 -0
  76. quadra/configs/experiment/base/ssl/dino.yaml +46 -0
  77. quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
  78. quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
  79. quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
  80. quadra/configs/experiment/custom/cls.yaml +12 -0
  81. quadra/configs/experiment/default.yaml +15 -0
  82. quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
  83. quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
  84. quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
  85. quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
  86. quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
  87. quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
  88. quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
  89. quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
  90. quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
  91. quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
  92. quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
  93. quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
  94. quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
  95. quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
  96. quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
  97. quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
  98. quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
  99. quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
  100. quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
  101. quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
  102. quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
  103. quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
  104. quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
  105. quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
  106. quadra/configs/export/default.yaml +13 -0
  107. quadra/configs/hydra/anomaly_custom.yaml +15 -0
  108. quadra/configs/hydra/default.yaml +14 -0
  109. quadra/configs/inference/default.yaml +26 -0
  110. quadra/configs/logger/comet.yaml +10 -0
  111. quadra/configs/logger/csv.yaml +5 -0
  112. quadra/configs/logger/mlflow.yaml +12 -0
  113. quadra/configs/logger/tensorboard.yaml +8 -0
  114. quadra/configs/loss/asl.yaml +7 -0
  115. quadra/configs/loss/barlow.yaml +2 -0
  116. quadra/configs/loss/bce.yaml +1 -0
  117. quadra/configs/loss/byol.yaml +1 -0
  118. quadra/configs/loss/cross_entropy.yaml +1 -0
  119. quadra/configs/loss/dino.yaml +8 -0
  120. quadra/configs/loss/simclr.yaml +2 -0
  121. quadra/configs/loss/simsiam.yaml +1 -0
  122. quadra/configs/loss/smp_ce.yaml +3 -0
  123. quadra/configs/loss/smp_dice.yaml +2 -0
  124. quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
  125. quadra/configs/loss/smp_mcc.yaml +2 -0
  126. quadra/configs/loss/vicreg.yaml +5 -0
  127. quadra/configs/model/anomalib/cfa.yaml +35 -0
  128. quadra/configs/model/anomalib/cflow.yaml +30 -0
  129. quadra/configs/model/anomalib/csflow.yaml +34 -0
  130. quadra/configs/model/anomalib/dfm.yaml +19 -0
  131. quadra/configs/model/anomalib/draem.yaml +29 -0
  132. quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
  133. quadra/configs/model/anomalib/fastflow.yaml +32 -0
  134. quadra/configs/model/anomalib/padim.yaml +32 -0
  135. quadra/configs/model/anomalib/patchcore.yaml +36 -0
  136. quadra/configs/model/barlow.yaml +16 -0
  137. quadra/configs/model/byol.yaml +25 -0
  138. quadra/configs/model/classification.yaml +10 -0
  139. quadra/configs/model/dino.yaml +26 -0
  140. quadra/configs/model/logistic_regression.yaml +4 -0
  141. quadra/configs/model/multilabel_classification.yaml +9 -0
  142. quadra/configs/model/simclr.yaml +18 -0
  143. quadra/configs/model/simsiam.yaml +24 -0
  144. quadra/configs/model/smp.yaml +4 -0
  145. quadra/configs/model/smp_multiclass.yaml +4 -0
  146. quadra/configs/model/vicreg.yaml +16 -0
  147. quadra/configs/optimizer/adam.yaml +5 -0
  148. quadra/configs/optimizer/adamw.yaml +3 -0
  149. quadra/configs/optimizer/default.yaml +4 -0
  150. quadra/configs/optimizer/lars.yaml +8 -0
  151. quadra/configs/optimizer/sgd.yaml +4 -0
  152. quadra/configs/scheduler/default.yaml +5 -0
  153. quadra/configs/scheduler/rop.yaml +5 -0
  154. quadra/configs/scheduler/step.yaml +3 -0
  155. quadra/configs/scheduler/warmrestart.yaml +2 -0
  156. quadra/configs/scheduler/warmup.yaml +6 -0
  157. quadra/configs/task/anomalib/cfa.yaml +5 -0
  158. quadra/configs/task/anomalib/cflow.yaml +5 -0
  159. quadra/configs/task/anomalib/csflow.yaml +5 -0
  160. quadra/configs/task/anomalib/draem.yaml +5 -0
  161. quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
  162. quadra/configs/task/anomalib/fastflow.yaml +5 -0
  163. quadra/configs/task/anomalib/inference.yaml +3 -0
  164. quadra/configs/task/anomalib/padim.yaml +5 -0
  165. quadra/configs/task/anomalib/patchcore.yaml +5 -0
  166. quadra/configs/task/classification.yaml +6 -0
  167. quadra/configs/task/classification_evaluation.yaml +6 -0
  168. quadra/configs/task/default.yaml +1 -0
  169. quadra/configs/task/segmentation.yaml +9 -0
  170. quadra/configs/task/segmentation_evaluation.yaml +3 -0
  171. quadra/configs/task/sklearn_classification.yaml +13 -0
  172. quadra/configs/task/sklearn_classification_patch.yaml +11 -0
  173. quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
  174. quadra/configs/task/sklearn_classification_test.yaml +8 -0
  175. quadra/configs/task/ssl.yaml +2 -0
  176. quadra/configs/trainer/lightning_cpu.yaml +36 -0
  177. quadra/configs/trainer/lightning_gpu.yaml +35 -0
  178. quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
  179. quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
  180. quadra/configs/trainer/lightning_multigpu.yaml +37 -0
  181. quadra/configs/trainer/sklearn_classification.yaml +7 -0
  182. quadra/configs/transforms/byol.yaml +47 -0
  183. quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
  184. quadra/configs/transforms/default.yaml +37 -0
  185. quadra/configs/transforms/default_numpy.yaml +24 -0
  186. quadra/configs/transforms/default_resize.yaml +22 -0
  187. quadra/configs/transforms/dino.yaml +63 -0
  188. quadra/configs/transforms/linear_eval.yaml +18 -0
  189. quadra/datamodules/__init__.py +20 -0
  190. quadra/datamodules/anomaly.py +180 -0
  191. quadra/datamodules/base.py +375 -0
  192. quadra/datamodules/classification.py +1003 -0
  193. quadra/datamodules/generic/__init__.py +0 -0
  194. quadra/datamodules/generic/imagenette.py +144 -0
  195. quadra/datamodules/generic/mnist.py +81 -0
  196. quadra/datamodules/generic/mvtec.py +58 -0
  197. quadra/datamodules/generic/oxford_pet.py +163 -0
  198. quadra/datamodules/patch.py +190 -0
  199. quadra/datamodules/segmentation.py +742 -0
  200. quadra/datamodules/ssl.py +140 -0
  201. quadra/datasets/__init__.py +17 -0
  202. quadra/datasets/anomaly.py +287 -0
  203. quadra/datasets/classification.py +241 -0
  204. quadra/datasets/patch.py +138 -0
  205. quadra/datasets/segmentation.py +239 -0
  206. quadra/datasets/ssl.py +110 -0
  207. quadra/losses/__init__.py +0 -0
  208. quadra/losses/classification/__init__.py +6 -0
  209. quadra/losses/classification/asl.py +83 -0
  210. quadra/losses/classification/focal.py +320 -0
  211. quadra/losses/classification/prototypical.py +148 -0
  212. quadra/losses/ssl/__init__.py +17 -0
  213. quadra/losses/ssl/barlowtwins.py +47 -0
  214. quadra/losses/ssl/byol.py +37 -0
  215. quadra/losses/ssl/dino.py +129 -0
  216. quadra/losses/ssl/hyperspherical.py +45 -0
  217. quadra/losses/ssl/idmm.py +50 -0
  218. quadra/losses/ssl/simclr.py +67 -0
  219. quadra/losses/ssl/simsiam.py +30 -0
  220. quadra/losses/ssl/vicreg.py +76 -0
  221. quadra/main.py +46 -0
  222. quadra/metrics/__init__.py +3 -0
  223. quadra/metrics/segmentation.py +251 -0
  224. quadra/models/__init__.py +0 -0
  225. quadra/models/base.py +151 -0
  226. quadra/models/classification/__init__.py +8 -0
  227. quadra/models/classification/backbones.py +149 -0
  228. quadra/models/classification/base.py +92 -0
  229. quadra/models/evaluation.py +322 -0
  230. quadra/modules/__init__.py +0 -0
  231. quadra/modules/backbone.py +30 -0
  232. quadra/modules/base.py +312 -0
  233. quadra/modules/classification/__init__.py +3 -0
  234. quadra/modules/classification/base.py +331 -0
  235. quadra/modules/ssl/__init__.py +17 -0
  236. quadra/modules/ssl/barlowtwins.py +59 -0
  237. quadra/modules/ssl/byol.py +172 -0
  238. quadra/modules/ssl/common.py +285 -0
  239. quadra/modules/ssl/dino.py +186 -0
  240. quadra/modules/ssl/hyperspherical.py +206 -0
  241. quadra/modules/ssl/idmm.py +98 -0
  242. quadra/modules/ssl/simclr.py +73 -0
  243. quadra/modules/ssl/simsiam.py +68 -0
  244. quadra/modules/ssl/vicreg.py +67 -0
  245. quadra/optimizers/__init__.py +4 -0
  246. quadra/optimizers/lars.py +153 -0
  247. quadra/optimizers/sam.py +127 -0
  248. quadra/schedulers/__init__.py +3 -0
  249. quadra/schedulers/base.py +44 -0
  250. quadra/schedulers/warmup.py +127 -0
  251. quadra/tasks/__init__.py +24 -0
  252. quadra/tasks/anomaly.py +582 -0
  253. quadra/tasks/base.py +397 -0
  254. quadra/tasks/classification.py +1264 -0
  255. quadra/tasks/patch.py +492 -0
  256. quadra/tasks/segmentation.py +389 -0
  257. quadra/tasks/ssl.py +560 -0
  258. quadra/trainers/README.md +3 -0
  259. quadra/trainers/__init__.py +0 -0
  260. quadra/trainers/classification.py +179 -0
  261. quadra/utils/__init__.py +0 -0
  262. quadra/utils/anomaly.py +112 -0
  263. quadra/utils/classification.py +618 -0
  264. quadra/utils/deprecation.py +31 -0
  265. quadra/utils/evaluation.py +474 -0
  266. quadra/utils/export.py +579 -0
  267. quadra/utils/imaging.py +32 -0
  268. quadra/utils/logger.py +15 -0
  269. quadra/utils/mlflow.py +98 -0
  270. quadra/utils/model_manager.py +320 -0
  271. quadra/utils/models.py +524 -0
  272. quadra/utils/patch/__init__.py +15 -0
  273. quadra/utils/patch/dataset.py +1433 -0
  274. quadra/utils/patch/metrics.py +449 -0
  275. quadra/utils/patch/model.py +153 -0
  276. quadra/utils/patch/visualization.py +217 -0
  277. quadra/utils/resolver.py +42 -0
  278. quadra/utils/segmentation.py +31 -0
  279. quadra/utils/tests/__init__.py +0 -0
  280. quadra/utils/tests/fixtures/__init__.py +1 -0
  281. quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
  282. quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
  283. quadra/utils/tests/fixtures/dataset/classification.py +406 -0
  284. quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
  285. quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
  286. quadra/utils/tests/fixtures/models/__init__.py +3 -0
  287. quadra/utils/tests/fixtures/models/anomaly.py +89 -0
  288. quadra/utils/tests/fixtures/models/classification.py +45 -0
  289. quadra/utils/tests/fixtures/models/segmentation.py +33 -0
  290. quadra/utils/tests/helpers.py +70 -0
  291. quadra/utils/tests/models.py +27 -0
  292. quadra/utils/utils.py +525 -0
  293. quadra/utils/validator.py +115 -0
  294. quadra/utils/visualization.py +422 -0
  295. quadra/utils/vit_explainability.py +349 -0
  296. quadra-2.1.13.dist-info/LICENSE +201 -0
  297. quadra-2.1.13.dist-info/METADATA +386 -0
  298. quadra-2.1.13.dist-info/RECORD +300 -0
  299. {quadra-0.0.1.dist-info → quadra-2.1.13.dist-info}/WHEEL +1 -1
  300. quadra-2.1.13.dist-info/entry_points.txt +3 -0
  301. quadra-0.0.1.dist-info/METADATA +0 -14
  302. quadra-0.0.1.dist-info/RECORD +0 -4
@@ -0,0 +1,389 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import os
5
+ import typing
6
+ from typing import Any, Generic
7
+
8
+ import cv2
9
+ import hydra
10
+ import torch
11
+ from omegaconf import DictConfig, OmegaConf
12
+ from torch.utils.data import DataLoader
13
+
14
+ from quadra.callbacks.mlflow import get_mlflow_logger
15
+ from quadra.datamodules import SegmentationDataModule, SegmentationMulticlassDataModule
16
+ from quadra.models.base import ModelSignatureWrapper
17
+ from quadra.models.evaluation import BaseEvaluationModel
18
+ from quadra.modules.base import SegmentationModel
19
+ from quadra.tasks.base import Evaluation, LightningTask
20
+ from quadra.utils import utils
21
+ from quadra.utils.evaluation import automatic_datamodule_batch_size, create_mask_report
22
+ from quadra.utils.export import export_model
23
+
24
log = utils.get_logger(__name__)

# Type variable used to parametrize the segmentation tasks; it is constrained to
# the two segmentation datamodule classes imported above.
SegmentationDataModuleT = typing.TypeVar(
    "SegmentationDataModuleT",
    SegmentationDataModule,
    SegmentationMulticlassDataModule,
)
29
+
30
+
31
class Segmentation(Generic[SegmentationDataModuleT], LightningTask[SegmentationDataModuleT]):
    """Task for segmentation.

    Args:
        config: Config object
        num_viz_samples: Number of samples to visualize. Defaults to 5.
        checkpoint_path: Path to the checkpoint to load the model from. Defaults to None.
        run_test: If True, run test after training. Defaults to False.
        evaluate: Dict with evaluation parameters. Defaults to None. If any of its values is truthy, a torchscript
            export and report generation are enabled automatically.
        report: If True, create report after training. Defaults to False.
    """

    def __init__(
        self,
        config: DictConfig,
        num_viz_samples: int = 5,
        checkpoint_path: str | None = None,
        run_test: bool = False,
        evaluate: DictConfig | None = None,
        report: bool = False,
    ):
        super().__init__(
            config=config,
            checkpoint_path=checkpoint_path,
            run_test=run_test,
            report=report,
        )
        self.evaluate = evaluate
        self.num_viz_samples = num_viz_samples
        self.export_folder: str = "deployment_model"
        # Set by `export` to the path of one exported model; consumed by `generate_report`.
        self.exported_model_path: str | None = None
        if self.evaluate and any(self.evaluate.values()):
            # The evaluation tasks run on the exported deployment model, so make sure a torchscript export exists.
            if (
                self.config.export is None
                or len(self.config.export.types) == 0
                or "torchscript" not in self.config.export.types
            ):
                log.info(
                    "Evaluation is enabled, but training does not export a deployment model. Automatically export the "
                    "model as torchscript."
                )
                if self.config.export is None:
                    self.config.export = DictConfig({"types": ["torchscript"]})
                else:
                    self.config.export.types.append("torchscript")

            if not self.report:
                log.info("Evaluation is enabled, but reporting is disabled. Enabling reporting automatically.")
                self.report = True

    @property
    def module(self) -> SegmentationModel:
        """Get the module."""
        return self._module

    @module.setter
    def module(self, module_config) -> None:
        """Set the module.

        Instantiates the model (wrapped in ``ModelSignatureWrapper``), an optimizer over its trainable parameters,
        the scheduler and the lightning module described by ``module_config``. If ``checkpoint_path`` was given,
        the checkpoint is loaded as well.

        Args:
            module_config: Config with ``model`` and ``module`` entries instantiated through hydra.
        """
        log.info("Instantiating model <%s>", module_config.model["_target_"])

        if isinstance(self.datamodule, SegmentationMulticlassDataModule):
            # The model is expected to have one more output class than the classes listed by the datamodule.
            expected_num_classes = len(self.datamodule.idx_to_class) + 1
            if module_config.model.num_classes != expected_num_classes:
                log.warning(
                    "Number of classes in the model (%s) does not match the number of "
                    "classes in the datamodule (%s). Updating the model...",
                    module_config.model.num_classes,
                    len(self.datamodule.idx_to_class),
                )
                module_config.model.num_classes = expected_num_classes

        model = hydra.utils.instantiate(module_config.model)
        model = ModelSignatureWrapper(model)
        log.info("Instantiating optimizer <%s>", self.config.optimizer["_target_"])
        # Only parameters that require gradients are handed to the optimizer.
        param_list = [param for param in model.parameters() if param.requires_grad]
        optimizer = hydra.utils.instantiate(self.config.optimizer, param_list)
        log.info("Instantiating scheduler <%s>", self.config.scheduler["_target_"])
        scheduler = hydra.utils.instantiate(self.config.scheduler, optimizer=optimizer)
        log.info("Instantiating module <%s>", module_config.module["_target_"])
        module = hydra.utils.instantiate(module_config.module, model=model, optimizer=optimizer, lr_scheduler=scheduler)
        if self.checkpoint_path is not None:
            # NOTE(review): `load_from_checkpoint` builds and returns a new module instance and its return value is
            # discarded here. The weights presumably still reach `module` because the very same `model` object is
            # passed in and loaded in place; any state living outside `model` is not restored — confirm.
            module.__class__.load_from_checkpoint(
                self.checkpoint_path, model=model, optimizer=optimizer, lr_scheduler=scheduler
            )
        self._module = module

    def prepare(self) -> None:
        """Prepare the task."""
        super().prepare()
        self.module = self.config.model

    def export(self) -> None:
        """Generate a deployment model for the task.

        The best checkpoint tracked by the trainer's checkpoint callback is exported when available, otherwise the
        current module weights are used. When at least one export succeeds, ``self.exported_model_path`` is set and
        a ``model.json`` file is written inside ``self.export_folder``.

        Raises:
            ValueError: If no export configuration is available.
        """
        log.info("Exporting model ready for deployment")

        # Get best model! (`getattr` on a missing callback / attribute yields None)
        best_model_path = getattr(self.trainer.checkpoint_callback, "best_model_path", None)
        if best_model_path:
            log.info("Loaded best model from %s", best_model_path)

            module = self.module.__class__.load_from_checkpoint(
                best_model_path,
                model=self.module.model,
                loss_fun=None,
                optimizer=self.module.optimizer,
                lr_scheduler=self.module.schedulers,
            )
        else:
            log.warning("No checkpoint callback found in the trainer, exporting the last model weights")
            module = self.module

        if "idx_to_class" not in self.config.datamodule:
            log.info("No idx_to_class key")
            idx_to_class = {0: "good", 1: "bad"}  # TODO: Why is this the default value?
        else:
            log.info("idx_to_class is present")
            idx_to_class = self.config.datamodule.idx_to_class

        if self.config.export is None:
            raise ValueError(
                "No export type specified. This should not happen, please check if you have set "
                "the export_type or assign it to a default value."
            )

        # Normalize to str so the check does not raise if `trainer.precision` is not a string.
        # Note that any precision containing "16" (including "bf16...") enables half precision export.
        half_precision = "16" in str(self.trainer.precision)

        # `input_shapes` is optional: the export config auto-created in `__init__` only defines `types`.
        input_shapes = self.config.export.get("input_shapes")

        model_json, export_paths = export_model(
            config=self.config,
            model=module.model,
            export_folder=self.export_folder,
            half_precision=half_precision,
            input_shapes=input_shapes,
            idx_to_class=idx_to_class,
        )

        if len(export_paths) == 0:
            return

        # Pick one model for evaluation, it should be independent of the export type as the model is wrapped
        self.exported_model_path = next(iter(export_paths.values()))

        with open(os.path.join(self.export_folder, "model.json"), "w") as f:
            json.dump(model_json, f, cls=utils.HydraEncoder)

    def generate_report(self) -> None:
        """Generate a report for the task.

        When ``evaluate`` is set, runs the requested evaluation tasks on the exported model and, if enabled through
        ``core.upload_artifacts``, uploads the produced report files to the MLflow and/or TensorBoard loggers.

        Raises:
            ValueError: If an analysis evaluation is requested but no model has been exported.
        """
        if self.evaluate is not None:
            log.info("Generating evaluation report!")
            eval_tasks: list[SegmentationEvaluation] = []
            if self.evaluate.get("analysis"):
                if self.exported_model_path is None:
                    raise ValueError(
                        "Exported model path is not set yet but the task tries to do an analysis evaluation"
                    )
                eval_task = SegmentationAnalysisEvaluation(
                    config=self.config,
                    model_path=self.exported_model_path,
                )
                eval_tasks.append(eval_task)
            for task in eval_tasks:
                task.execute()

            if len(self.logger) > 0:
                mlflow_logger = get_mlflow_logger(trainer=self.trainer)
                tensorboard_logger = utils.get_tensorboard_logger(trainer=self.trainer)

                if mlflow_logger is not None and self.config.core.get("upload_artifacts"):
                    for task in eval_tasks:
                        for file in task.metadata["report_files"]:
                            mlflow_logger.experiment.log_artifact(
                                run_id=mlflow_logger.run_id, local_path=file, artifact_path=task.report_path
                            )

                if tensorboard_logger is not None and self.config.core.get("upload_artifacts"):
                    for task in eval_tasks:
                        for file in task.metadata["report_files"]:
                            ext = os.path.splitext(file)[1].lower()

                            if ext in [".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".tif", ".gif"]:
                                try:
                                    img = cv2.imread(file)
                                    # cv2.imread signals a read failure by returning None instead of raising.
                                    if img is None:
                                        log.info("Could not upload artifact image %s", file)
                                        continue
                                    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                                except cv2.error:
                                    log.info("Could not upload artifact image %s", file)
                                    continue

                                tensorboard_logger.experiment.add_image(
                                    os.path.basename(file), img, 0, dataformats="HWC"
                                )
                            else:
                                utils.upload_file_tensorboard(file, tensorboard_logger)
230
+
231
+
232
class SegmentationEvaluation(Evaluation[SegmentationDataModuleT]):
    """Segmentation Evaluation Task with deployment models.

    Args:
        config: The experiment configuration
        model_path: Path to the exported deployment model.
        device: Device to use for evaluation. Defaults to "cpu". If None, the device is automatically determined.

    Raises:
        ValueError: If the model path is not provided
    """

    def __init__(
        self,
        config: DictConfig,
        model_path: str,
        device: str | None = "cpu",
    ):
        super().__init__(config=config, model_path=model_path, device=device)
        self.config = config

    def save_config(self) -> None:
        """Skip saving the config."""

    def prepare(self) -> None:
        """Prepare the evaluation.

        Copies the normalization statistics (and, when the datamodule config has an ``idx_to_class`` entry, the
        class mapping) stored with the deployment model into the config, then sets up the datamodule.
        """
        super().prepare()
        # TODO: Why we propagate mean and std only in Segmentation?
        self.config.transforms.mean = self.model_data["mean"]
        self.config.transforms.std = self.model_data["std"]
        # Setup datamodule
        if hasattr(self.config.datamodule, "idx_to_class"):
            idx_to_class = self.model_data["classes"]  # dict {index: class}
            self.config.datamodule.idx_to_class = idx_to_class
        self.datamodule = self.config.datamodule
        # prepare_data() must be explicitly called because there is no lightning training
        self.datamodule.prepare_data()

    @torch.no_grad()
    def inference(
        self, dataloader: DataLoader, deployment_model: BaseEvaluationModel, device: torch.device
    ) -> dict[str, torch.Tensor]:
        """Run inference on the dataloader and return the output.

        Args:
            dataloader: The dataloader to run inference on, yielding ``(images, masks, labels)`` batches
            deployment_model: The deployment model to use
            device: The device to run inference on

        Returns:
            Dictionary with the batch-concatenated CPU tensors under the keys "image", "mask", "label" and
            "mask_pred".
        """
        image_list, mask_list, mask_pred_list, label_list = [], [], [], []
        for images, masks, labels in dataloader:
            # Only the images are needed on the inference device; everything is collected on CPU anyway,
            # so masks and labels are not shuttled to the device and back.
            mask_pred_list.append(deployment_model(images.to(device)).cpu())
            image_list.append(images.cpu())
            mask_list.append(masks.cpu())
            label_list.append(labels.cpu())
        output = {
            "image": torch.cat(image_list, dim=0),
            "mask": torch.cat(mask_list, dim=0),
            "label": torch.cat(label_list, dim=0),
            "mask_pred": torch.cat(mask_pred_list, dim=0),
        }
        return output
298
+
299
+
300
class SegmentationAnalysisEvaluation(SegmentationEvaluation):
    """Segmentation Analysis Evaluation Task.

    Args:
        config: The experiment configuration
        model_path: The model path.
        device: Device to use for evaluation. If None, the device is automatically determined.
    """

    def __init__(
        self,
        config: DictConfig,
        model_path: str,
        device: str | None = None,
    ):
        super().__init__(config=config, model_path=model_path, device=device)
        # Maps a stage name (e.g. "test") to the concatenated inference outputs of that stage.
        self.test_output: dict[str, Any] = {}

    def train(self) -> None:
        """Skip training."""

    def prepare(self) -> None:
        """Prepare the evaluation task."""
        super().prepare()
        self.datamodule.setup(stage="fit")
        self.datamodule.setup(stage="test")

    @automatic_datamodule_batch_size(batch_size_attribute_name="batch_size")
    def test(self) -> None:
        """Run inference on the available stages and store the results in ``self.test_output``.

        Collected tensors are kept on CPU so that accumulating a whole split does not grow memory on the inference
        device. Stages whose dataloader yields no batch are skipped with a warning.
        """
        log.info("Starting inference for analysis.")

        # Only the test split is analysed; train/val analysis is currently disabled.
        stage_loaders: list[tuple[str, DataLoader]] = []
        if self.datamodule.test_dataset_available:
            stage_loaders.append(("test", self.datamodule.test_dataloader()))

        for stage, dataloader in stage_loaders:
            log.info("Running inference on %s set with batch size: %d", stage, dataloader.batch_size)
            image_list, mask_list, mask_pred_list, label_list = [], [], [], []
            for images, masks, labels in dataloader:
                images = images.to(device=self.device, dtype=self.deployment_model.model_dtype)
                if masks.ndim == 3:  # BxHxW -> Bx1xHxW
                    masks = masks.unsqueeze(1)
                with torch.no_grad():
                    mask_pred = self.deployment_model(images)
                image_list.append(images.cpu())
                mask_list.append(masks.cpu())
                mask_pred_list.append(mask_pred.cpu())
                label_list.append(labels.cpu())

            if not image_list:
                # torch.cat would fail on an empty list.
                log.warning("No batches available for the %s set, skipping analysis.", stage)
                continue

            output = {
                "image": torch.cat(image_list, dim=0),
                "mask": torch.cat(mask_list, dim=0),
                "label": torch.cat(label_list, dim=0),
                "mask_pred": torch.cat(mask_pred_list, dim=0),
            }
            self.test_output[stage] = output

    def _validated_transform_stat(self, name: str) -> list[int | float]:
        """Return ``config.transforms.<name>`` as a plain list, checking that it only contains numbers.

        Args:
            name: Name of the transforms entry, e.g. "mean" or "std".

        Raises:
            ValueError: If the value is not a list of int/float values.
        """
        values = OmegaConf.to_container(self.config.transforms[name])
        if not isinstance(values, list) or any(not isinstance(x, (int, float)) for x in values):
            raise ValueError(f"Image {name} is not a list of float or integer values, please check your config")
        return values

    def generate_report(self) -> None:
        """Generate an analysis report for every stage collected by ``test``.

        The produced files are appended to ``self.metadata["report_files"]``.

        Raises:
            ValueError: If ``config.transforms.mean`` or ``config.transforms.std`` is not a list of numbers.
        """
        log.info("Generating analysis report")
        if not self.test_output:
            return

        # The normalization statistics are shared by all stages, so validate them only once.
        image_mean = self._validated_transform_stat("mean")
        image_std = self._validated_transform_stat("std")

        for stage, output in self.test_output.items():
            reports = create_mask_report(
                stage=stage,
                output=output,
                report_path="analysis_report",
                mean=image_mean,
                std=image_std,
                analysis=True,
                nb_samples=10,
                apply_sigmoid=True,
                show_orj_predictions=True,
            )
            self.metadata["report_files"].extend(reports)
            log.info("%s analysis report completed.", stage)