quadra 0.0.1__py3-none-any.whl → 2.1.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (302)
  1. hydra_plugins/quadra_searchpath_plugin.py +30 -0
  2. quadra/__init__.py +6 -0
  3. quadra/callbacks/__init__.py +0 -0
  4. quadra/callbacks/anomalib.py +289 -0
  5. quadra/callbacks/lightning.py +501 -0
  6. quadra/callbacks/mlflow.py +291 -0
  7. quadra/callbacks/scheduler.py +69 -0
  8. quadra/configs/__init__.py +0 -0
  9. quadra/configs/backbone/caformer_m36.yaml +8 -0
  10. quadra/configs/backbone/caformer_s36.yaml +8 -0
  11. quadra/configs/backbone/convnextv2_base.yaml +8 -0
  12. quadra/configs/backbone/convnextv2_femto.yaml +8 -0
  13. quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
  14. quadra/configs/backbone/dino_vitb8.yaml +12 -0
  15. quadra/configs/backbone/dino_vits8.yaml +12 -0
  16. quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
  17. quadra/configs/backbone/dinov2_vits14.yaml +12 -0
  18. quadra/configs/backbone/efficientnet_b0.yaml +8 -0
  19. quadra/configs/backbone/efficientnet_b1.yaml +8 -0
  20. quadra/configs/backbone/efficientnet_b2.yaml +8 -0
  21. quadra/configs/backbone/efficientnet_b3.yaml +8 -0
  22. quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
  23. quadra/configs/backbone/levit_128s.yaml +8 -0
  24. quadra/configs/backbone/mnasnet0_5.yaml +9 -0
  25. quadra/configs/backbone/resnet101.yaml +8 -0
  26. quadra/configs/backbone/resnet18.yaml +8 -0
  27. quadra/configs/backbone/resnet18_ssl.yaml +8 -0
  28. quadra/configs/backbone/resnet50.yaml +8 -0
  29. quadra/configs/backbone/smp.yaml +9 -0
  30. quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
  31. quadra/configs/backbone/unetr.yaml +15 -0
  32. quadra/configs/backbone/vit16_base.yaml +9 -0
  33. quadra/configs/backbone/vit16_small.yaml +9 -0
  34. quadra/configs/backbone/vit16_tiny.yaml +9 -0
  35. quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
  36. quadra/configs/callbacks/all.yaml +32 -0
  37. quadra/configs/callbacks/default.yaml +37 -0
  38. quadra/configs/callbacks/default_anomalib.yaml +67 -0
  39. quadra/configs/config.yaml +33 -0
  40. quadra/configs/core/default.yaml +11 -0
  41. quadra/configs/datamodule/base/anomaly.yaml +16 -0
  42. quadra/configs/datamodule/base/classification.yaml +21 -0
  43. quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
  44. quadra/configs/datamodule/base/segmentation.yaml +18 -0
  45. quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
  46. quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
  47. quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
  48. quadra/configs/datamodule/base/ssl.yaml +21 -0
  49. quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
  50. quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
  51. quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
  52. quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
  53. quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
  54. quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
  55. quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
  56. quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
  57. quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
  58. quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
  59. quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
  60. quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
  61. quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
  62. quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
  63. quadra/configs/experiment/base/classification/classification.yaml +73 -0
  64. quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
  65. quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
  66. quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
  67. quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
  68. quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
  69. quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
  70. quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
  71. quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
  72. quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
  73. quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
  74. quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
  75. quadra/configs/experiment/base/ssl/byol.yaml +43 -0
  76. quadra/configs/experiment/base/ssl/dino.yaml +46 -0
  77. quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
  78. quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
  79. quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
  80. quadra/configs/experiment/custom/cls.yaml +12 -0
  81. quadra/configs/experiment/default.yaml +15 -0
  82. quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
  83. quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
  84. quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
  85. quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
  86. quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
  87. quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
  88. quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
  89. quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
  90. quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
  91. quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
  92. quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
  93. quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
  94. quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
  95. quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
  96. quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
  97. quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
  98. quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
  99. quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
  100. quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
  101. quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
  102. quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
  103. quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
  104. quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
  105. quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
  106. quadra/configs/export/default.yaml +13 -0
  107. quadra/configs/hydra/anomaly_custom.yaml +15 -0
  108. quadra/configs/hydra/default.yaml +14 -0
  109. quadra/configs/inference/default.yaml +26 -0
  110. quadra/configs/logger/comet.yaml +10 -0
  111. quadra/configs/logger/csv.yaml +5 -0
  112. quadra/configs/logger/mlflow.yaml +12 -0
  113. quadra/configs/logger/tensorboard.yaml +8 -0
  114. quadra/configs/loss/asl.yaml +7 -0
  115. quadra/configs/loss/barlow.yaml +2 -0
  116. quadra/configs/loss/bce.yaml +1 -0
  117. quadra/configs/loss/byol.yaml +1 -0
  118. quadra/configs/loss/cross_entropy.yaml +1 -0
  119. quadra/configs/loss/dino.yaml +8 -0
  120. quadra/configs/loss/simclr.yaml +2 -0
  121. quadra/configs/loss/simsiam.yaml +1 -0
  122. quadra/configs/loss/smp_ce.yaml +3 -0
  123. quadra/configs/loss/smp_dice.yaml +2 -0
  124. quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
  125. quadra/configs/loss/smp_mcc.yaml +2 -0
  126. quadra/configs/loss/vicreg.yaml +5 -0
  127. quadra/configs/model/anomalib/cfa.yaml +35 -0
  128. quadra/configs/model/anomalib/cflow.yaml +30 -0
  129. quadra/configs/model/anomalib/csflow.yaml +34 -0
  130. quadra/configs/model/anomalib/dfm.yaml +19 -0
  131. quadra/configs/model/anomalib/draem.yaml +29 -0
  132. quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
  133. quadra/configs/model/anomalib/fastflow.yaml +32 -0
  134. quadra/configs/model/anomalib/padim.yaml +32 -0
  135. quadra/configs/model/anomalib/patchcore.yaml +36 -0
  136. quadra/configs/model/barlow.yaml +16 -0
  137. quadra/configs/model/byol.yaml +25 -0
  138. quadra/configs/model/classification.yaml +10 -0
  139. quadra/configs/model/dino.yaml +26 -0
  140. quadra/configs/model/logistic_regression.yaml +4 -0
  141. quadra/configs/model/multilabel_classification.yaml +9 -0
  142. quadra/configs/model/simclr.yaml +18 -0
  143. quadra/configs/model/simsiam.yaml +24 -0
  144. quadra/configs/model/smp.yaml +4 -0
  145. quadra/configs/model/smp_multiclass.yaml +4 -0
  146. quadra/configs/model/vicreg.yaml +16 -0
  147. quadra/configs/optimizer/adam.yaml +5 -0
  148. quadra/configs/optimizer/adamw.yaml +3 -0
  149. quadra/configs/optimizer/default.yaml +4 -0
  150. quadra/configs/optimizer/lars.yaml +8 -0
  151. quadra/configs/optimizer/sgd.yaml +4 -0
  152. quadra/configs/scheduler/default.yaml +5 -0
  153. quadra/configs/scheduler/rop.yaml +5 -0
  154. quadra/configs/scheduler/step.yaml +3 -0
  155. quadra/configs/scheduler/warmrestart.yaml +2 -0
  156. quadra/configs/scheduler/warmup.yaml +6 -0
  157. quadra/configs/task/anomalib/cfa.yaml +5 -0
  158. quadra/configs/task/anomalib/cflow.yaml +5 -0
  159. quadra/configs/task/anomalib/csflow.yaml +5 -0
  160. quadra/configs/task/anomalib/draem.yaml +5 -0
  161. quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
  162. quadra/configs/task/anomalib/fastflow.yaml +5 -0
  163. quadra/configs/task/anomalib/inference.yaml +3 -0
  164. quadra/configs/task/anomalib/padim.yaml +5 -0
  165. quadra/configs/task/anomalib/patchcore.yaml +5 -0
  166. quadra/configs/task/classification.yaml +6 -0
  167. quadra/configs/task/classification_evaluation.yaml +6 -0
  168. quadra/configs/task/default.yaml +1 -0
  169. quadra/configs/task/segmentation.yaml +9 -0
  170. quadra/configs/task/segmentation_evaluation.yaml +3 -0
  171. quadra/configs/task/sklearn_classification.yaml +13 -0
  172. quadra/configs/task/sklearn_classification_patch.yaml +11 -0
  173. quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
  174. quadra/configs/task/sklearn_classification_test.yaml +8 -0
  175. quadra/configs/task/ssl.yaml +2 -0
  176. quadra/configs/trainer/lightning_cpu.yaml +36 -0
  177. quadra/configs/trainer/lightning_gpu.yaml +35 -0
  178. quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
  179. quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
  180. quadra/configs/trainer/lightning_multigpu.yaml +37 -0
  181. quadra/configs/trainer/sklearn_classification.yaml +7 -0
  182. quadra/configs/transforms/byol.yaml +47 -0
  183. quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
  184. quadra/configs/transforms/default.yaml +37 -0
  185. quadra/configs/transforms/default_numpy.yaml +24 -0
  186. quadra/configs/transforms/default_resize.yaml +22 -0
  187. quadra/configs/transforms/dino.yaml +63 -0
  188. quadra/configs/transforms/linear_eval.yaml +18 -0
  189. quadra/datamodules/__init__.py +20 -0
  190. quadra/datamodules/anomaly.py +180 -0
  191. quadra/datamodules/base.py +375 -0
  192. quadra/datamodules/classification.py +1003 -0
  193. quadra/datamodules/generic/__init__.py +0 -0
  194. quadra/datamodules/generic/imagenette.py +144 -0
  195. quadra/datamodules/generic/mnist.py +81 -0
  196. quadra/datamodules/generic/mvtec.py +58 -0
  197. quadra/datamodules/generic/oxford_pet.py +163 -0
  198. quadra/datamodules/patch.py +190 -0
  199. quadra/datamodules/segmentation.py +742 -0
  200. quadra/datamodules/ssl.py +140 -0
  201. quadra/datasets/__init__.py +17 -0
  202. quadra/datasets/anomaly.py +287 -0
  203. quadra/datasets/classification.py +241 -0
  204. quadra/datasets/patch.py +138 -0
  205. quadra/datasets/segmentation.py +239 -0
  206. quadra/datasets/ssl.py +110 -0
  207. quadra/losses/__init__.py +0 -0
  208. quadra/losses/classification/__init__.py +6 -0
  209. quadra/losses/classification/asl.py +83 -0
  210. quadra/losses/classification/focal.py +320 -0
  211. quadra/losses/classification/prototypical.py +148 -0
  212. quadra/losses/ssl/__init__.py +17 -0
  213. quadra/losses/ssl/barlowtwins.py +47 -0
  214. quadra/losses/ssl/byol.py +37 -0
  215. quadra/losses/ssl/dino.py +129 -0
  216. quadra/losses/ssl/hyperspherical.py +45 -0
  217. quadra/losses/ssl/idmm.py +50 -0
  218. quadra/losses/ssl/simclr.py +67 -0
  219. quadra/losses/ssl/simsiam.py +30 -0
  220. quadra/losses/ssl/vicreg.py +76 -0
  221. quadra/main.py +46 -0
  222. quadra/metrics/__init__.py +3 -0
  223. quadra/metrics/segmentation.py +251 -0
  224. quadra/models/__init__.py +0 -0
  225. quadra/models/base.py +151 -0
  226. quadra/models/classification/__init__.py +8 -0
  227. quadra/models/classification/backbones.py +149 -0
  228. quadra/models/classification/base.py +92 -0
  229. quadra/models/evaluation.py +322 -0
  230. quadra/modules/__init__.py +0 -0
  231. quadra/modules/backbone.py +30 -0
  232. quadra/modules/base.py +312 -0
  233. quadra/modules/classification/__init__.py +3 -0
  234. quadra/modules/classification/base.py +331 -0
  235. quadra/modules/ssl/__init__.py +17 -0
  236. quadra/modules/ssl/barlowtwins.py +59 -0
  237. quadra/modules/ssl/byol.py +172 -0
  238. quadra/modules/ssl/common.py +285 -0
  239. quadra/modules/ssl/dino.py +186 -0
  240. quadra/modules/ssl/hyperspherical.py +206 -0
  241. quadra/modules/ssl/idmm.py +98 -0
  242. quadra/modules/ssl/simclr.py +73 -0
  243. quadra/modules/ssl/simsiam.py +68 -0
  244. quadra/modules/ssl/vicreg.py +67 -0
  245. quadra/optimizers/__init__.py +4 -0
  246. quadra/optimizers/lars.py +153 -0
  247. quadra/optimizers/sam.py +127 -0
  248. quadra/schedulers/__init__.py +3 -0
  249. quadra/schedulers/base.py +44 -0
  250. quadra/schedulers/warmup.py +127 -0
  251. quadra/tasks/__init__.py +24 -0
  252. quadra/tasks/anomaly.py +582 -0
  253. quadra/tasks/base.py +397 -0
  254. quadra/tasks/classification.py +1264 -0
  255. quadra/tasks/patch.py +492 -0
  256. quadra/tasks/segmentation.py +389 -0
  257. quadra/tasks/ssl.py +560 -0
  258. quadra/trainers/README.md +3 -0
  259. quadra/trainers/__init__.py +0 -0
  260. quadra/trainers/classification.py +179 -0
  261. quadra/utils/__init__.py +0 -0
  262. quadra/utils/anomaly.py +112 -0
  263. quadra/utils/classification.py +618 -0
  264. quadra/utils/deprecation.py +31 -0
  265. quadra/utils/evaluation.py +474 -0
  266. quadra/utils/export.py +579 -0
  267. quadra/utils/imaging.py +32 -0
  268. quadra/utils/logger.py +15 -0
  269. quadra/utils/mlflow.py +98 -0
  270. quadra/utils/model_manager.py +320 -0
  271. quadra/utils/models.py +524 -0
  272. quadra/utils/patch/__init__.py +15 -0
  273. quadra/utils/patch/dataset.py +1433 -0
  274. quadra/utils/patch/metrics.py +449 -0
  275. quadra/utils/patch/model.py +153 -0
  276. quadra/utils/patch/visualization.py +217 -0
  277. quadra/utils/resolver.py +42 -0
  278. quadra/utils/segmentation.py +31 -0
  279. quadra/utils/tests/__init__.py +0 -0
  280. quadra/utils/tests/fixtures/__init__.py +1 -0
  281. quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
  282. quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
  283. quadra/utils/tests/fixtures/dataset/classification.py +406 -0
  284. quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
  285. quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
  286. quadra/utils/tests/fixtures/models/__init__.py +3 -0
  287. quadra/utils/tests/fixtures/models/anomaly.py +89 -0
  288. quadra/utils/tests/fixtures/models/classification.py +45 -0
  289. quadra/utils/tests/fixtures/models/segmentation.py +33 -0
  290. quadra/utils/tests/helpers.py +70 -0
  291. quadra/utils/tests/models.py +27 -0
  292. quadra/utils/utils.py +525 -0
  293. quadra/utils/validator.py +115 -0
  294. quadra/utils/visualization.py +422 -0
  295. quadra/utils/vit_explainability.py +349 -0
  296. quadra-2.1.13.dist-info/LICENSE +201 -0
  297. quadra-2.1.13.dist-info/METADATA +386 -0
  298. quadra-2.1.13.dist-info/RECORD +300 -0
  299. {quadra-0.0.1.dist-info → quadra-2.1.13.dist-info}/WHEEL +1 -1
  300. quadra-2.1.13.dist-info/entry_points.txt +3 -0
  301. quadra-0.0.1.dist-info/METADATA +0 -14
  302. quadra-0.0.1.dist-info/RECORD +0 -4
quadra/tasks/base.py ADDED
@@ -0,0 +1,397 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import os
5
+ from pathlib import Path
6
+ from typing import Any, Generic, TypeVar
7
+
8
+ import hydra
9
+ import torch
10
+ from hydra.core.hydra_config import HydraConfig
11
+ from lightning_fabric.utilities.device_parser import _parse_gpu_ids
12
+ from omegaconf import DictConfig, OmegaConf, open_dict
13
+ from pytorch_lightning import Callback, LightningModule, Trainer
14
+ from pytorch_lightning.loggers import Logger, MLFlowLogger
15
+ from pytorch_lightning.utilities.exceptions import MisconfigurationException
16
+
17
+ from quadra import get_version
18
+ from quadra.callbacks.mlflow import validate_artifact_storage
19
+ from quadra.datamodules.base import BaseDataModule
20
+ from quadra.models.evaluation import BaseEvaluationModel
21
+ from quadra.utils import utils
22
+ from quadra.utils.export import import_deployment_model
23
+
24
# Module-level logger created through quadra's logging helper.
log = utils.get_logger(__name__)
# Type variable for the datamodule handled by a task; bound to BaseDataModule subclasses.
DataModuleT = TypeVar("DataModuleT", bound=BaseDataModule)
26
+
27
+
28
class Task(Generic[DataModuleT]):
    """Base Experiment Task.

    A task runs an experiment through the steps called by ``execute``: prepare, train, test,
    optional export, report generation and finalization. Subclasses override the steps they support;
    the base implementations of train/test/export/generate_report only log that they are not implemented.

    Args:
        config: The experiment configuration.
    """

    def __init__(self, config: DictConfig):
        self.config = config
        # Folder name used for exported deployment models.
        self.export_folder: str = "deployment_model"
        # Declared here, assigned later (the datamodule in ``prepare``, metadata by subclasses).
        self._datamodule: DataModuleT
        self.metadata: dict[str, Any]
        self.save_config()

    def save_config(self) -> None:
        """Save the resolved experiment configuration when running a Hydra experiment.

        Writes ``config_resolved.yaml`` (relative to the current working directory) with all
        interpolations resolved. Nothing is written when Hydra is not initialized.
        """
        if HydraConfig.initialized():
            # Let OmegaConf open and close the file itself instead of opening it twice.
            OmegaConf.save(config=OmegaConf.to_container(self.config, resolve=True), f="config_resolved.yaml")

    def prepare(self) -> None:
        """Prepare the experiment by instantiating the datamodule from ``config.datamodule``."""
        self.datamodule = self.config.datamodule

    @property
    def datamodule(self) -> DataModuleT:
        """DataModuleT: The datamodule."""
        return self._datamodule

    @datamodule.setter
    def datamodule(self, datamodule_config: DictConfig) -> None:
        """DataModuleT: The datamodule. Instantiated from the datamodule config."""
        # Log the target string itself (wrapping it in braces would log a one-element set).
        log.info("Instantiating datamodule <%s>", datamodule_config["_target_"])
        datamodule: DataModuleT = hydra.utils.instantiate(datamodule_config)
        self._datamodule = datamodule

    def train(self) -> Any:
        """Train the model."""
        log.info("Training not implemented for this task!")

    def test(self) -> Any:
        """Test the model."""
        log.info("Testing not implemented for this task!")

    def export(self) -> None:
        """Export model for production."""
        log.info("Export model for production not implemented for this task!")

    def generate_report(self) -> None:
        """Generate a report."""
        log.info("Report generation not implemented for this task!")

    def finalize(self) -> None:
        """Finalize the experiment."""
        log.info("Results are saved in %s", os.getcwd())

    def execute(self) -> None:
        """Execute the experiment and all the steps.

        The export step runs only when the config has an ``export`` section with a non-empty ``types``.
        """
        self.prepare()
        self.train()
        self.test()
        # ``.get`` so that a config without an ``export`` section (or without ``types``) skips the export.
        export_config = self.config.get("export")
        if export_config is not None and export_config.get("types"):
            self.export()
        self.generate_report()
        self.finalize()
93
+
94
+
95
class LightningTask(Generic[DataModuleT], Task[DataModuleT]):
    """Base Experiment Task for PyTorch Lightning experiments.

    ``prepare`` instantiates, in order, the datamodule, the loggers, the callbacks, the devices and the
    Lightning ``Trainer`` from the configuration.

    Args:
        config: The experiment configuration
        checkpoint_path: The path to the checkpoint to load the model from. Defaults to None.
        run_test: Whether to run the test after training. Defaults to False.
        report: Whether to generate a report. Defaults to False.
    """

    def __init__(
        self,
        config: DictConfig,
        checkpoint_path: str | None = None,
        run_test: bool = False,
        report: bool = False,
    ):
        super().__init__(config=config)
        self.checkpoint_path = checkpoint_path
        self.run_test = run_test
        self.report = report
        self._module: LightningModule
        self._devices: int | list[int]
        # Default to "no callbacks / no loggers" so the trainer can still be built when the config has no
        # ``callbacks``/``logger`` section, or when they are skipped during unit tests.
        self._callbacks: list[Callback] = []
        self._logger: list[Logger] = []
        self._trainer: Trainer

    def prepare(self) -> None:
        """Prepare the experiment: datamodule, loggers, callbacks, devices and trainer."""
        super().prepare()

        # First setup loggers since some callbacks might need logger setup correctly.
        if "logger" in self.config:
            self.logger = self.config.logger

        if "callbacks" in self.config:
            self.callbacks = self.config.callbacks

        # Devices must be resolved before the trainer, whose setter reads ``self.devices``.
        self.devices = self.config.trainer.devices
        self.trainer = self.config.trainer

    @property
    def module(self) -> LightningModule:
        """LightningModule: The model."""
        return self._module

    @module.setter
    def module(self, module_config) -> None:
        """LightningModule: The model. Must be implemented by subclasses."""
        raise NotImplementedError("module must be set in subclass")

    @property
    def trainer(self) -> Trainer:
        """Trainer: The trainer."""
        return self._trainer

    @trainer.setter
    def trainer(self, trainer_config: DictConfig) -> None:
        """Trainer: The trainer. Instantiated with the task devices, callbacks and loggers."""
        log.info("Instantiating trainer <%s>", trainer_config["_target_"])
        # Overwrite the configured devices with the ones resolved by the ``devices`` setter.
        trainer_config.devices = self.devices
        trainer: Trainer = hydra.utils.instantiate(
            trainer_config,
            callbacks=self.callbacks,
            logger=self.logger,
            _convert_="partial",
        )
        self._trainer = trainer

    @property
    def callbacks(self) -> list[Callback]:
        """List[Callback]: The callbacks."""
        return self._callbacks

    @callbacks.setter
    def callbacks(self, callbacks_config) -> None:
        """List[Callback]: The callbacks, instantiated from every config entry that has a ``_target_``.

        Entries with a truthy ``disable`` are skipped; the ``disable`` key is removed from the other entries
        before instantiation. The nvitop ``GpuStatsLogger`` is skipped when CUDA is not available.
        """
        if self.config.core.get("unit_test"):
            log.info("Unit Testing, skipping callbacks")
            self._callbacks = []
            return

        instantiated_callbacks = []
        for cb_conf in callbacks_config.values():
            # Entries set to null (e.g. ``callbacks.foo=null``) or without a target are ignored.
            if cb_conf is None or "_target_" not in cb_conf:
                continue

            # Disable is a reserved keyword for callbacks, hopefully no callback will use it
            if "disable" in cb_conf:
                if cb_conf["disable"]:
                    log.info("Skipping callback <%s> as it is disabled", cb_conf["_target_"])
                    continue

                with open_dict(cb_conf):
                    del cb_conf.disable

            # Skip the gpu stats logger callback if no gpu is available to avoid errors
            if not torch.cuda.is_available() and cb_conf["_target_"] == "nvitop.callbacks.lightning.GpuStatsLogger":
                continue

            log.info("Instantiating callback <%s>", cb_conf["_target_"])
            instantiated_callbacks.append(hydra.utils.instantiate(cb_conf))

        self._callbacks = instantiated_callbacks
        if not instantiated_callbacks:
            log.warning("No callback found in configuration.")

    @property
    def logger(self) -> list[Logger]:
        """List[Logger]: The loggers."""
        return self._logger

    @logger.setter
    def logger(self, logger_config) -> None:
        """List[Logger]: The loggers, instantiated from every config entry that has a ``_target_``.

        MLflow loggers additionally go through ``validate_artifact_storage``.
        """
        if self.config.core.get("unit_test"):
            log.info("Unit Testing, skipping loggers")
            self._logger = []
            return

        instantiated_loggers = []
        for lg_conf in logger_config.values():
            # Entries set to null or without a target are ignored.
            if lg_conf is None or "_target_" not in lg_conf:
                continue

            log.info("Instantiating logger <%s>", lg_conf["_target_"])
            logger = hydra.utils.instantiate(lg_conf)
            if isinstance(logger, MLFlowLogger):
                validate_artifact_storage(logger)
            instantiated_loggers.append(logger)

        self._logger = instantiated_loggers

        if not instantiated_loggers:
            log.warning("No logger found in configuration.")

    @property
    def devices(self) -> int | list[int]:
        """List[int]: The devices ids."""
        return self._devices

    @devices.setter
    def devices(self, devices) -> None:
        """List[int]: The devices ids.

        With a ``cpu`` accelerator the given value is kept as is. Otherwise it is parsed as GPU ids; if that
        fails, a single CPU device is used and the config accelerator is switched to ``cpu``.
        """
        if self.config.trainer.get("accelerator") == "cpu":
            # Use the value passed to the setter (``prepare`` passes ``config.trainer.devices``).
            self._devices = devices
            return

        try:
            self._devices = _parse_gpu_ids(devices, include_cuda=True)
        except MisconfigurationException:
            self._devices = 1
            self.config.trainer["accelerator"] = "cpu"
            log.warning("Trying to instantiate GPUs but no GPUs are available, training will be done on CPU")

    def train(self) -> None:
        """Log the hyperparameters, then fit the module on the datamodule."""
        log.info("Starting training!")
        utils.log_hyperparameters(
            config=self.config,
            model=self.module,
            trainer=self.trainer,
        )

        self.trainer.fit(model=self.module, datamodule=self.datamodule)

    def test(self) -> Any:
        """Test the model, using the best checkpoint of the checkpoint callback when one is available.

        Returns:
            The output of ``Trainer.test``.
        """
        log.info("Starting testing!")

        # The checkpoint callback may be None or lack ``best_model_path``; an empty path means nothing was saved.
        best_model: str | None = getattr(self.trainer.checkpoint_callback, "best_model_path", None) or None

        if best_model is None:
            log.warning(
                "No best checkpoint model found, using last weights for test, this might lead to worse results, "
                "consider using a checkpoint callback."
            )

        return self.trainer.test(model=self.module, datamodule=self.datamodule, ckpt_path=best_model)

    def finalize(self) -> None:
        """Finalize the experiment and log the best checkpoint path when available."""
        super().finalize()
        utils.finish(
            config=self.config,
            module=self.module,
            datamodule=self.datamodule,
            trainer=self.trainer,
            callbacks=self.callbacks,
            logger=self.logger,
            export_folder=self.export_folder,
        )

        if (
            not self.config.trainer.get("fast_dev_run")
            and self.trainer.checkpoint_callback is not None
            and hasattr(self.trainer.checkpoint_callback, "best_model_path")
        ):
            log.info("Best model ckpt: %s", self.trainer.checkpoint_callback.best_model_path)

    def add_callback(self, callback: Callback):
        """Add a callback to the trainer.

        The callback is only appended when the trainer exposes ``callbacks`` as a list.

        Args:
            callback: The callback to add
        """
        if hasattr(self.trainer, "callbacks") and isinstance(self.trainer.callbacks, list):
            self.trainer.callbacks.append(callback)

    def execute(self) -> None:
        """Execute the experiment and all the steps.

        Test and report run only when requested (``run_test``/``report``). Export runs only when the config
        has an ``export`` section with a non-empty ``types``.
        """
        self.prepare()
        self.train()
        if self.run_test:
            self.test()
        # ``.get`` so that a config without an ``export`` section (or without ``types``) skips the export.
        export_config = self.config.get("export")
        if export_config is not None and export_config.get("types"):
            self.export()
        if self.report:
            self.generate_report()
        self.finalize()
313
+
314
+
315
class PlaceholderTask(Task):
    """Task that runs no experiment step and only logs installation information.

    Useful as a smoke test that the library can be imported and launched.
    """

    def execute(self) -> None:
        """Log a start message, the installed Quadra version and an installation confirmation."""
        log.info("Running Placeholder Task.")
        quadra_version = str(get_version())
        log.info("Quadra Version: %s", quadra_version)
        log.info("If you are reading this, it means that library is installed correctly!")
323
+
324
+
325
class Evaluation(Generic[DataModuleT], Task[DataModuleT]):
    """Base Evaluation Task with deployment models.

    Args:
        config: The experiment configuration
        model_path: The model path.
        device: Device to use for evaluation. If None, the device is automatically determined.

    """

    def __init__(
        self,
        config: DictConfig,
        model_path: str,
        device: str | None = None,
    ):
        # ``Task.__init__`` already stores ``config`` and saves the resolved configuration.
        super().__init__(config=config)

        self.device = utils.get_device() if device is None else device

        self.model_data: dict[str, Any]
        self.model_path = model_path
        self._deployment_model: BaseEvaluationModel
        self.deployment_model_type: str
        # Model info file looked up in the same folder as ``model_path`` (see ``prepare``).
        self.model_info_filename = "model.json"
        self.report_path = ""
        self.metadata = {"report_files": []}

    @property
    def deployment_model(self) -> BaseEvaluationModel:
        """Deployment model."""
        return self._deployment_model

    @deployment_model.setter
    def deployment_model(self, model_path: str):
        """Set the deployment model by importing it from ``model_path`` on ``self.device``."""
        self._deployment_model = import_deployment_model(
            model_path=model_path, device=self.device, inference_config=self.config.inference
        )

    def prepare(self) -> None:
        """Prepare the evaluation.

        Reads the model info file next to ``model_path``. For every 3-dimensional entry of its
        ``input_size``, it aligns ``config.transforms.input_height``/``input_width`` with the model
        (indices 1 and 2). Finally it imports the deployment model.

        Raises:
            ValueError: If the model info file does not contain a JSON object.
        """
        model_info_path = Path(self.model_path).parent / self.model_info_filename
        # JSON is UTF-8 by specification; do not depend on the platform default encoding.
        with open(model_info_path, encoding="utf-8") as f:
            self.model_data = json.load(f)

        if not isinstance(self.model_data, dict):
            raise ValueError("Model info file is not a valid json")

        # NOTE(review): ``input_size`` is assumed to be a list of per-input shapes; confirm against the exporter.
        for input_size in self.model_data["input_size"]:
            if len(input_size) != 3:
                continue

            # Adjust the transform for 2D models (CxHxW)
            # We assume that each input size has the same height and width
            if input_size[1] != self.config.transforms.input_height:
                log.warning(
                    "Input height of the model (%s) is different from the one specified "
                    "in the config (%s). Fixing the config.",
                    input_size[1],
                    self.config.transforms.input_height,
                )
                self.config.transforms.input_height = input_size[1]

            if input_size[2] != self.config.transforms.input_width:
                log.warning(
                    "Input width of the model (%s) is different from the one specified "
                    "in the config (%s). Fixing the config.",
                    input_size[2],
                    self.config.transforms.input_width,
                )
                self.config.transforms.input_width = input_size[2]

        self.deployment_model = self.model_path  # type: ignore[assignment]