quadra 0.0.1__py3-none-any.whl → 2.1.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (302) hide show
  1. hydra_plugins/quadra_searchpath_plugin.py +30 -0
  2. quadra/__init__.py +6 -0
  3. quadra/callbacks/__init__.py +0 -0
  4. quadra/callbacks/anomalib.py +289 -0
  5. quadra/callbacks/lightning.py +501 -0
  6. quadra/callbacks/mlflow.py +291 -0
  7. quadra/callbacks/scheduler.py +69 -0
  8. quadra/configs/__init__.py +0 -0
  9. quadra/configs/backbone/caformer_m36.yaml +8 -0
  10. quadra/configs/backbone/caformer_s36.yaml +8 -0
  11. quadra/configs/backbone/convnextv2_base.yaml +8 -0
  12. quadra/configs/backbone/convnextv2_femto.yaml +8 -0
  13. quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
  14. quadra/configs/backbone/dino_vitb8.yaml +12 -0
  15. quadra/configs/backbone/dino_vits8.yaml +12 -0
  16. quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
  17. quadra/configs/backbone/dinov2_vits14.yaml +12 -0
  18. quadra/configs/backbone/efficientnet_b0.yaml +8 -0
  19. quadra/configs/backbone/efficientnet_b1.yaml +8 -0
  20. quadra/configs/backbone/efficientnet_b2.yaml +8 -0
  21. quadra/configs/backbone/efficientnet_b3.yaml +8 -0
  22. quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
  23. quadra/configs/backbone/levit_128s.yaml +8 -0
  24. quadra/configs/backbone/mnasnet0_5.yaml +9 -0
  25. quadra/configs/backbone/resnet101.yaml +8 -0
  26. quadra/configs/backbone/resnet18.yaml +8 -0
  27. quadra/configs/backbone/resnet18_ssl.yaml +8 -0
  28. quadra/configs/backbone/resnet50.yaml +8 -0
  29. quadra/configs/backbone/smp.yaml +9 -0
  30. quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
  31. quadra/configs/backbone/unetr.yaml +15 -0
  32. quadra/configs/backbone/vit16_base.yaml +9 -0
  33. quadra/configs/backbone/vit16_small.yaml +9 -0
  34. quadra/configs/backbone/vit16_tiny.yaml +9 -0
  35. quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
  36. quadra/configs/callbacks/all.yaml +32 -0
  37. quadra/configs/callbacks/default.yaml +37 -0
  38. quadra/configs/callbacks/default_anomalib.yaml +67 -0
  39. quadra/configs/config.yaml +33 -0
  40. quadra/configs/core/default.yaml +11 -0
  41. quadra/configs/datamodule/base/anomaly.yaml +16 -0
  42. quadra/configs/datamodule/base/classification.yaml +21 -0
  43. quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
  44. quadra/configs/datamodule/base/segmentation.yaml +18 -0
  45. quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
  46. quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
  47. quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
  48. quadra/configs/datamodule/base/ssl.yaml +21 -0
  49. quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
  50. quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
  51. quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
  52. quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
  53. quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
  54. quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
  55. quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
  56. quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
  57. quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
  58. quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
  59. quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
  60. quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
  61. quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
  62. quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
  63. quadra/configs/experiment/base/classification/classification.yaml +73 -0
  64. quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
  65. quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
  66. quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
  67. quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
  68. quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
  69. quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
  70. quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
  71. quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
  72. quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
  73. quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
  74. quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
  75. quadra/configs/experiment/base/ssl/byol.yaml +43 -0
  76. quadra/configs/experiment/base/ssl/dino.yaml +46 -0
  77. quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
  78. quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
  79. quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
  80. quadra/configs/experiment/custom/cls.yaml +12 -0
  81. quadra/configs/experiment/default.yaml +15 -0
  82. quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
  83. quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
  84. quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
  85. quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
  86. quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
  87. quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
  88. quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
  89. quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
  90. quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
  91. quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
  92. quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
  93. quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
  94. quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
  95. quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
  96. quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
  97. quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
  98. quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
  99. quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
  100. quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
  101. quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
  102. quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
  103. quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
  104. quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
  105. quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
  106. quadra/configs/export/default.yaml +13 -0
  107. quadra/configs/hydra/anomaly_custom.yaml +15 -0
  108. quadra/configs/hydra/default.yaml +14 -0
  109. quadra/configs/inference/default.yaml +26 -0
  110. quadra/configs/logger/comet.yaml +10 -0
  111. quadra/configs/logger/csv.yaml +5 -0
  112. quadra/configs/logger/mlflow.yaml +12 -0
  113. quadra/configs/logger/tensorboard.yaml +8 -0
  114. quadra/configs/loss/asl.yaml +7 -0
  115. quadra/configs/loss/barlow.yaml +2 -0
  116. quadra/configs/loss/bce.yaml +1 -0
  117. quadra/configs/loss/byol.yaml +1 -0
  118. quadra/configs/loss/cross_entropy.yaml +1 -0
  119. quadra/configs/loss/dino.yaml +8 -0
  120. quadra/configs/loss/simclr.yaml +2 -0
  121. quadra/configs/loss/simsiam.yaml +1 -0
  122. quadra/configs/loss/smp_ce.yaml +3 -0
  123. quadra/configs/loss/smp_dice.yaml +2 -0
  124. quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
  125. quadra/configs/loss/smp_mcc.yaml +2 -0
  126. quadra/configs/loss/vicreg.yaml +5 -0
  127. quadra/configs/model/anomalib/cfa.yaml +35 -0
  128. quadra/configs/model/anomalib/cflow.yaml +30 -0
  129. quadra/configs/model/anomalib/csflow.yaml +34 -0
  130. quadra/configs/model/anomalib/dfm.yaml +19 -0
  131. quadra/configs/model/anomalib/draem.yaml +29 -0
  132. quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
  133. quadra/configs/model/anomalib/fastflow.yaml +32 -0
  134. quadra/configs/model/anomalib/padim.yaml +32 -0
  135. quadra/configs/model/anomalib/patchcore.yaml +36 -0
  136. quadra/configs/model/barlow.yaml +16 -0
  137. quadra/configs/model/byol.yaml +25 -0
  138. quadra/configs/model/classification.yaml +10 -0
  139. quadra/configs/model/dino.yaml +26 -0
  140. quadra/configs/model/logistic_regression.yaml +4 -0
  141. quadra/configs/model/multilabel_classification.yaml +9 -0
  142. quadra/configs/model/simclr.yaml +18 -0
  143. quadra/configs/model/simsiam.yaml +24 -0
  144. quadra/configs/model/smp.yaml +4 -0
  145. quadra/configs/model/smp_multiclass.yaml +4 -0
  146. quadra/configs/model/vicreg.yaml +16 -0
  147. quadra/configs/optimizer/adam.yaml +5 -0
  148. quadra/configs/optimizer/adamw.yaml +3 -0
  149. quadra/configs/optimizer/default.yaml +4 -0
  150. quadra/configs/optimizer/lars.yaml +8 -0
  151. quadra/configs/optimizer/sgd.yaml +4 -0
  152. quadra/configs/scheduler/default.yaml +5 -0
  153. quadra/configs/scheduler/rop.yaml +5 -0
  154. quadra/configs/scheduler/step.yaml +3 -0
  155. quadra/configs/scheduler/warmrestart.yaml +2 -0
  156. quadra/configs/scheduler/warmup.yaml +6 -0
  157. quadra/configs/task/anomalib/cfa.yaml +5 -0
  158. quadra/configs/task/anomalib/cflow.yaml +5 -0
  159. quadra/configs/task/anomalib/csflow.yaml +5 -0
  160. quadra/configs/task/anomalib/draem.yaml +5 -0
  161. quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
  162. quadra/configs/task/anomalib/fastflow.yaml +5 -0
  163. quadra/configs/task/anomalib/inference.yaml +3 -0
  164. quadra/configs/task/anomalib/padim.yaml +5 -0
  165. quadra/configs/task/anomalib/patchcore.yaml +5 -0
  166. quadra/configs/task/classification.yaml +6 -0
  167. quadra/configs/task/classification_evaluation.yaml +6 -0
  168. quadra/configs/task/default.yaml +1 -0
  169. quadra/configs/task/segmentation.yaml +9 -0
  170. quadra/configs/task/segmentation_evaluation.yaml +3 -0
  171. quadra/configs/task/sklearn_classification.yaml +13 -0
  172. quadra/configs/task/sklearn_classification_patch.yaml +11 -0
  173. quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
  174. quadra/configs/task/sklearn_classification_test.yaml +8 -0
  175. quadra/configs/task/ssl.yaml +2 -0
  176. quadra/configs/trainer/lightning_cpu.yaml +36 -0
  177. quadra/configs/trainer/lightning_gpu.yaml +35 -0
  178. quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
  179. quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
  180. quadra/configs/trainer/lightning_multigpu.yaml +37 -0
  181. quadra/configs/trainer/sklearn_classification.yaml +7 -0
  182. quadra/configs/transforms/byol.yaml +47 -0
  183. quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
  184. quadra/configs/transforms/default.yaml +37 -0
  185. quadra/configs/transforms/default_numpy.yaml +24 -0
  186. quadra/configs/transforms/default_resize.yaml +22 -0
  187. quadra/configs/transforms/dino.yaml +63 -0
  188. quadra/configs/transforms/linear_eval.yaml +18 -0
  189. quadra/datamodules/__init__.py +20 -0
  190. quadra/datamodules/anomaly.py +180 -0
  191. quadra/datamodules/base.py +375 -0
  192. quadra/datamodules/classification.py +1003 -0
  193. quadra/datamodules/generic/__init__.py +0 -0
  194. quadra/datamodules/generic/imagenette.py +144 -0
  195. quadra/datamodules/generic/mnist.py +81 -0
  196. quadra/datamodules/generic/mvtec.py +58 -0
  197. quadra/datamodules/generic/oxford_pet.py +163 -0
  198. quadra/datamodules/patch.py +190 -0
  199. quadra/datamodules/segmentation.py +742 -0
  200. quadra/datamodules/ssl.py +140 -0
  201. quadra/datasets/__init__.py +17 -0
  202. quadra/datasets/anomaly.py +287 -0
  203. quadra/datasets/classification.py +241 -0
  204. quadra/datasets/patch.py +138 -0
  205. quadra/datasets/segmentation.py +239 -0
  206. quadra/datasets/ssl.py +110 -0
  207. quadra/losses/__init__.py +0 -0
  208. quadra/losses/classification/__init__.py +6 -0
  209. quadra/losses/classification/asl.py +83 -0
  210. quadra/losses/classification/focal.py +320 -0
  211. quadra/losses/classification/prototypical.py +148 -0
  212. quadra/losses/ssl/__init__.py +17 -0
  213. quadra/losses/ssl/barlowtwins.py +47 -0
  214. quadra/losses/ssl/byol.py +37 -0
  215. quadra/losses/ssl/dino.py +129 -0
  216. quadra/losses/ssl/hyperspherical.py +45 -0
  217. quadra/losses/ssl/idmm.py +50 -0
  218. quadra/losses/ssl/simclr.py +67 -0
  219. quadra/losses/ssl/simsiam.py +30 -0
  220. quadra/losses/ssl/vicreg.py +76 -0
  221. quadra/main.py +46 -0
  222. quadra/metrics/__init__.py +3 -0
  223. quadra/metrics/segmentation.py +251 -0
  224. quadra/models/__init__.py +0 -0
  225. quadra/models/base.py +151 -0
  226. quadra/models/classification/__init__.py +8 -0
  227. quadra/models/classification/backbones.py +149 -0
  228. quadra/models/classification/base.py +92 -0
  229. quadra/models/evaluation.py +322 -0
  230. quadra/modules/__init__.py +0 -0
  231. quadra/modules/backbone.py +30 -0
  232. quadra/modules/base.py +312 -0
  233. quadra/modules/classification/__init__.py +3 -0
  234. quadra/modules/classification/base.py +331 -0
  235. quadra/modules/ssl/__init__.py +17 -0
  236. quadra/modules/ssl/barlowtwins.py +59 -0
  237. quadra/modules/ssl/byol.py +172 -0
  238. quadra/modules/ssl/common.py +285 -0
  239. quadra/modules/ssl/dino.py +186 -0
  240. quadra/modules/ssl/hyperspherical.py +206 -0
  241. quadra/modules/ssl/idmm.py +98 -0
  242. quadra/modules/ssl/simclr.py +73 -0
  243. quadra/modules/ssl/simsiam.py +68 -0
  244. quadra/modules/ssl/vicreg.py +67 -0
  245. quadra/optimizers/__init__.py +4 -0
  246. quadra/optimizers/lars.py +153 -0
  247. quadra/optimizers/sam.py +127 -0
  248. quadra/schedulers/__init__.py +3 -0
  249. quadra/schedulers/base.py +44 -0
  250. quadra/schedulers/warmup.py +127 -0
  251. quadra/tasks/__init__.py +24 -0
  252. quadra/tasks/anomaly.py +582 -0
  253. quadra/tasks/base.py +397 -0
  254. quadra/tasks/classification.py +1264 -0
  255. quadra/tasks/patch.py +492 -0
  256. quadra/tasks/segmentation.py +389 -0
  257. quadra/tasks/ssl.py +560 -0
  258. quadra/trainers/README.md +3 -0
  259. quadra/trainers/__init__.py +0 -0
  260. quadra/trainers/classification.py +179 -0
  261. quadra/utils/__init__.py +0 -0
  262. quadra/utils/anomaly.py +112 -0
  263. quadra/utils/classification.py +618 -0
  264. quadra/utils/deprecation.py +31 -0
  265. quadra/utils/evaluation.py +474 -0
  266. quadra/utils/export.py +579 -0
  267. quadra/utils/imaging.py +32 -0
  268. quadra/utils/logger.py +15 -0
  269. quadra/utils/mlflow.py +98 -0
  270. quadra/utils/model_manager.py +320 -0
  271. quadra/utils/models.py +524 -0
  272. quadra/utils/patch/__init__.py +15 -0
  273. quadra/utils/patch/dataset.py +1433 -0
  274. quadra/utils/patch/metrics.py +449 -0
  275. quadra/utils/patch/model.py +153 -0
  276. quadra/utils/patch/visualization.py +217 -0
  277. quadra/utils/resolver.py +42 -0
  278. quadra/utils/segmentation.py +31 -0
  279. quadra/utils/tests/__init__.py +0 -0
  280. quadra/utils/tests/fixtures/__init__.py +1 -0
  281. quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
  282. quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
  283. quadra/utils/tests/fixtures/dataset/classification.py +406 -0
  284. quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
  285. quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
  286. quadra/utils/tests/fixtures/models/__init__.py +3 -0
  287. quadra/utils/tests/fixtures/models/anomaly.py +89 -0
  288. quadra/utils/tests/fixtures/models/classification.py +45 -0
  289. quadra/utils/tests/fixtures/models/segmentation.py +33 -0
  290. quadra/utils/tests/helpers.py +70 -0
  291. quadra/utils/tests/models.py +27 -0
  292. quadra/utils/utils.py +525 -0
  293. quadra/utils/validator.py +115 -0
  294. quadra/utils/visualization.py +422 -0
  295. quadra/utils/vit_explainability.py +349 -0
  296. quadra-2.1.13.dist-info/LICENSE +201 -0
  297. quadra-2.1.13.dist-info/METADATA +386 -0
  298. quadra-2.1.13.dist-info/RECORD +300 -0
  299. {quadra-0.0.1.dist-info → quadra-2.1.13.dist-info}/WHEEL +1 -1
  300. quadra-2.1.13.dist-info/entry_points.txt +3 -0
  301. quadra-0.0.1.dist-info/METADATA +0 -14
  302. quadra-0.0.1.dist-info/RECORD +0 -4
quadra/utils/mlflow.py ADDED
@@ -0,0 +1,98 @@
1
+ from __future__ import annotations
2
+
3
+ try:
4
+ from mlflow.models import infer_signature # noqa
5
+ from mlflow.models.signature import ModelSignature # noqa
6
+
7
+ MLFLOW_AVAILABLE = True
8
+ except ImportError:
9
+ MLFLOW_AVAILABLE = False
10
+
11
+ from collections.abc import Sequence
12
+ from typing import Any
13
+
14
+ import torch
15
+ from pytorch_lightning import Trainer
16
+ from pytorch_lightning.loggers import MLFlowLogger
17
+
18
+ from quadra.models.evaluation import BaseEvaluationModel
19
+
20
+
21
@torch.inference_mode()
def infer_signature_model(model: BaseEvaluationModel, data: list[Any]) -> ModelSignature | None:
    """Infer the mlflow input/output signature of a PyTorch/Torchscript model.

    The model is put in evaluation mode and called once as ``model(*data)``; both the inputs and the
    produced output are then converted with ``infer_signature_input`` before being handed to
    ``mlflow.models.infer_signature``.

    Args:
        model: Model to run, it is called with the unpacked content of ``data``.
        data: Positional inputs for the model. A single input is described on its own, multiple inputs
            are described together as one sequence.

    Returns:
        The inferred signature, or None when the inputs or the output can not be converted (the
        conversion raised ``ValueError``).
    """
    model = model.eval()
    prediction = model(*data)

    try:
        converted_output = infer_signature_input(prediction)
        converted_input = infer_signature_input(data[0] if len(data) == 1 else data)
    except ValueError:
        # TODO: Solve circular import as it is not possible to import get_logger right now
        # log.warning("Unable to infer signature for model output type %s", type(model_output))
        return None

    return infer_signature(converted_input, converted_output)
40
+
41
+
42
def infer_signature_input(input_tensor: Any) -> Any:
    """Recursively convert a value to the format expected by mlflow.models.infer_signature.

    Tensors are moved to cpu and converted to numpy arrays. Sequences become dicts keyed by
    ``output_<index>`` and dicts keep their keys, in both cases every contained value is converted
    with this same function.

    Args:
        input_tensor: A tensor, a flat sequence or a flat dict.

    Returns:
        A numpy array for a tensor, otherwise a dict of converted values.

    Raises:
        ValueError: If the input type is not supported or when nested dicts or sequences are encountered.
    """
    if isinstance(input_tensor, Sequence):
        # Mlflow currently does not support sequence outputs, so we use a dict instead
        converted = {}
        for position, element in enumerate(input_tensor):
            # TODO: Allow nested sequences once mlflow supports nested signatures
            if isinstance(element, Sequence):
                raise ValueError("Nested sequences are not supported")
            if isinstance(element, dict):
                raise ValueError("Nested dicts are not supported")

            converted[f"output_{position}"] = infer_signature_input(element)

        return converted

    if isinstance(input_tensor, torch.Tensor):
        return input_tensor.cpu().numpy()

    if isinstance(input_tensor, dict):
        converted = {}
        for key, value in input_tensor.items():
            if isinstance(value, dict):
                raise ValueError("Nested dicts are not supported")
            if isinstance(value, Sequence):
                raise ValueError("Nested sequences are not supported")

            converted[key] = infer_signature_input(value)

        return converted

    raise ValueError(f"Unable to infer signature for model output type {type(input_tensor)}")
79
+
80
+
81
def get_mlflow_logger(trainer: Trainer) -> MLFlowLogger | None:
    """Safely get Mlflow logger from Trainer loggers.

    Args:
        trainer: Pytorch Lightning trainer.

    Returns:
        The first mlflow logger found among the trainer loggers, or None if there is none.
    """
    primary_logger = trainer.logger
    if isinstance(primary_logger, MLFlowLogger):
        return primary_logger

    # Kept for trainers exposing a plain list of loggers through `trainer.logger`
    candidates = list(primary_logger) if isinstance(primary_logger, list) else []
    # `trainer.logger` may expose only one of the configured loggers, so also inspect `trainer.loggers`
    # when the trainer provides it, otherwise an mlflow logger that is not the first one is never found
    candidates.extend(getattr(trainer, "loggers", None) or [])

    for candidate in candidates:
        if isinstance(candidate, MLFlowLogger):
            return candidate

    return None
@@ -0,0 +1,320 @@
1
+ from __future__ import annotations
2
+
3
+ import getpass
4
+ import os
5
+ from abc import ABC, abstractmethod
6
+ from datetime import datetime
7
+ from typing import Any, Literal
8
+
9
+ from quadra.utils.utils import get_logger
10
+
11
+ log = get_logger(__name__)
12
+
13
+ try:
14
+ import mlflow # noqa
15
+ from mlflow.entities import Run # noqa
16
+ from mlflow.entities.model_registry import ModelVersion # noqa
17
+ from mlflow.exceptions import RestException # noqa
18
+ from mlflow.tracking import MlflowClient # noqa
19
+
20
+ MLFLOW_AVAILABLE = True
21
+ except ImportError:
22
+ MLFLOW_AVAILABLE = False
23
+
24
+
25
+ VERSION_MD_TEMPLATE = "## **Version {}**\n"
26
+ DESCRIPTION_MD_TEMPLATE = "### Description: \n{}\n"
27
+
28
+
29
class AbstractModelManager(ABC):
    """Abstract class for model managers.

    Declares the operations a model registry backend has to provide: registering, retrieving,
    transitioning, deleting and downloading model versions.
    """

    @abstractmethod
    def register_model(
        self, model_location: str, model_name: str, description: str | None = None, tags: dict[str, Any] | None = None
    ) -> Any:
        """Register a model in the model registry.

        Args:
            model_location: Location of the model to register
            model_name: The name of the model after it is registered
            description: Optional description of the model
            tags: Optional dictionary of tags to add to the model
        """

    @abstractmethod
    def get_latest_version(self, model_name: str) -> Any:
        """Get the latest version of a model.

        Args:
            model_name: The name of the model
        """

    @abstractmethod
    def transition_model(self, model_name: str, version: int, stage: str, description: str | None = None) -> Any:
        """Transition the model with the given version to a new stage.

        Args:
            model_name: The name of the model
            version: The version of the model
            stage: The stage to move the model to
            description: Optional description of the transition
        """

    @abstractmethod
    def delete_model(self, model_name: str, version: int, description: str | None = None) -> None:
        """Delete a model with the given version.

        Args:
            model_name: The name of the model
            version: The version of the model
            description: Optional reason for the deletion
        """

    @abstractmethod
    def register_best_model(
        self,
        experiment_name: str,
        metric: str,
        model_name: str,
        description: str | None = None,
        tags: dict[str, Any] | None = None,
        mode: Literal["max", "min"] = "max",
        model_path: str = "deployment_model",
    ) -> Any:
        """Register the best model from an experiment.

        Args:
            experiment_name: The name of the experiment
            metric: The metric to use to determine the best model
            model_name: The name of the model after it is registered
            description: Optional description of the model
            tags: Optional dictionary of tags to add to the model
            mode: Whether the best model has the highest ("max") or the lowest ("min") metric
            model_path: The path to the model within the experiment run
        """

    @abstractmethod
    def download_model(self, model_name: str, version: int, output_path: str) -> None:
        """Download the model with the given version to the given output path.

        Args:
            model_name: The name of the model
            version: The version of the model
            output_path: The path to save the model to
        """
66
+
67
+
68
class MlflowModelManager(AbstractModelManager):
    """Model manager for Mlflow.

    Every registration, transition and deletion appends a markdown entry (author, date and optional
    description) to the description of the registered model, which is used as a changelog.
    """

    def __init__(self):
        """Configure mlflow from the `MLFLOW_TRACKING_URI` environment variable and create the client.

        Raises:
            ImportError: If mlflow is not installed
            ValueError: If the `MLFLOW_TRACKING_URI` environment variable is not set
        """
        if not MLFLOW_AVAILABLE:
            raise ImportError("Mlflow is not available, please install it with pip install mlflow")

        tracking_uri = os.getenv("MLFLOW_TRACKING_URI")
        if tracking_uri is None:
            raise ValueError("MLFLOW_TRACKING_URI environment variable is not set")

        mlflow.set_tracking_uri(tracking_uri)
        self.client = MlflowClient()

    def register_model(
        self, model_location: str, model_name: str, description: str | None = None, tags: dict[str, Any] | None = None
    ) -> ModelVersion:
        """Register a model in the model registry.

        Args:
            model_location: The model uri
            model_name: The name of the model after it is registered
            description: A description of the model, this will be added to the model changelog
            tags: A dictionary of tags to add to the model

        Returns:
            The model version
        """
        model_version = mlflow.register_model(model_uri=model_location, name=model_name, tags=tags)
        log.info("Registered model %s with version %s", model_name, model_version.version)
        # A registered model without a description may report None, which can not be concatenated
        registered_model_description = self.client.get_registered_model(model_name).description or ""

        # The changelog header is written only once, together with the first version.
        # NOTE: the version is normalised to str so the check does not depend on the type used by the backend
        header = "# MODEL CHANGELOG\n" if str(model_version.version) == "1" else ""

        new_model_description = VERSION_MD_TEMPLATE.format(model_version.version)
        new_model_description += self._get_author_and_date()
        new_model_description += self._generate_description(description)

        self.client.update_registered_model(model_name, header + registered_model_description + new_model_description)

        # Each version keeps its own changelog, starting from its registration entry
        self.client.update_model_version(
            model_name, model_version.version, "# MODEL CHANGELOG\n" + new_model_description
        )

        return model_version

    def get_latest_version(self, model_name: str) -> ModelVersion:
        """Get the latest version of a model.

        Args:
            model_name: The name of the model

        Returns:
            The model version

        Raises:
            ValueError: If the model has no versions
        """
        available_versions = [int(x.version) for x in self.client.get_latest_versions(model_name)]

        if not available_versions:
            # Same exception type `max` would raise on an empty sequence, with an explicit message
            raise ValueError(f"No versions found for model {model_name}")

        return self.client.get_model_version(model_name, max(available_versions))

    def transition_model(
        self, model_name: str, version: int, stage: str, description: str | None = None
    ) -> ModelVersion | None:
        """Transition a model to a new stage.

        Args:
            model_name: The name of the model
            version: The version of the model
            stage: The stage of the model
            description: A description of the transition, this will be added to the model changelog

        Returns:
            The updated model version, the unchanged model version if it is already in the requested
            stage, or None if the model version does not exist
        """
        previous_stage = self._safe_get_stage(model_name, version)

        if previous_stage is None:
            return None

        if previous_stage.lower() == stage.lower():
            log.warning("Model %s version %s is already in stage %s", model_name, version, stage)
            return self.client.get_model_version(model_name, version)

        log.info("Transitioning model %s version %s from %s to %s", model_name, version, previous_stage, stage)
        model_version = self.client.transition_model_version_stage(name=model_name, version=version, stage=stage)
        new_stage = model_version.current_stage
        # Missing descriptions may be reported as None, fall back to an empty changelog
        registered_model_description = self.client.get_registered_model(model_name).description or ""
        single_model_description = self.client.get_model_version(model_name, version).description or ""

        new_model_description = "## **Transition:**\n"
        new_model_description += f"### Version {model_version.version} from {previous_stage} to {new_stage}\n"
        new_model_description += self._get_author_and_date()
        new_model_description += self._generate_description(description)

        self.client.update_registered_model(model_name, registered_model_description + new_model_description)
        self.client.update_model_version(
            model_name, model_version.version, single_model_description + new_model_description
        )

        return model_version

    def delete_model(self, model_name: str, version: int, description: str | None = None) -> None:
        """Delete a model.

        The user is asked to type the model name on standard input to confirm the deletion.

        Args:
            model_name: The name of the model
            version: The version of the model
            description: Why the model was deleted, this will be added to the model changelog
        """
        model_stage = self._safe_get_stage(model_name, version)

        if model_stage is None:
            return

        confirmation = input(
            f"Model named `{model_name}`, version {version} is in stage {model_stage}, "
            "type the model name to continue deletion:"
        )
        if confirmation != model_name:
            log.warning("Model name did not match, aborting deletion")
            return

        log.info("Deleting model %s version %s", model_name, version)
        self.client.delete_model_version(model_name, version)

        # Missing descriptions may be reported as None, fall back to an empty changelog
        registered_model_description = self.client.get_registered_model(model_name).description or ""

        new_model_description = "## **Deletion:**\n"
        new_model_description += f"### Version {version} from stage: {model_stage}\n"
        new_model_description += self._get_author_and_date()
        new_model_description += self._generate_description(description)

        self.client.update_registered_model(model_name, registered_model_description + new_model_description)

    def register_best_model(
        self,
        experiment_name: str,
        metric: str,
        model_name: str,
        description: str | None = None,
        tags: dict[str, Any] | None = None,
        mode: Literal["max", "min"] = "max",
        model_path: str = "deployment_model",
    ) -> ModelVersion | None:
        """Register the best model from an experiment.

        Only runs that contain both the model artifact and the requested metric are considered.

        Args:
            experiment_name: The name of the experiment
            metric: The metric to use to determine the best model
            model_name: The name of the model after it is registered
            description: A description of the model, this will be added to the model changelog
            tags: A dictionary of tags to add to the model
            mode: The mode to use to determine the best model, either "max" or "min"
            model_path: The path to the model within the experiment run

        Returns:
            The registered model version if successful, otherwise None

        Raises:
            ValueError: If mode is neither "max" nor "min"
        """
        if mode not in ["max", "min"]:
            raise ValueError(f"Mode must be either 'max' or 'min', got {mode}")

        experiment = self.client.get_experiment_by_name(experiment_name)

        if experiment is None:
            # Handled like an experiment without runs instead of failing on a missing attribute
            log.error("Experiment %s does not exist", experiment_name)
            return None

        runs = self.client.search_runs(experiment_ids=[experiment.experiment_id])

        if len(runs) == 0:
            log.error("No runs found for experiment %s", experiment_name)
            return None

        best_run: Run | None = None

        # We can only make comparisons if the model is on the top folder, otherwise just check if the folder exists
        # TODO: Is there a better way to do this?
        base_model_path = model_path.split("/")[0]

        for run in runs:
            has_model = any(x.path == base_model_path for x in self.client.list_artifacts(run.info.run_id))

            if not has_model:
                # If we don't find the given model path, skip this run
                continue

            run_metric = run.data.metrics.get(metric)

            if run_metric is None:
                # A run without the metric can not be compared, skip it instead of failing on the lookup
                continue

            if best_run is None:
                best_run = run
                continue

            best_metric = best_run.data.metrics[metric]
            is_better = run_metric > best_metric if mode == "max" else run_metric < best_metric

            if is_better:
                best_run = run

        if best_run is None:
            log.error("No runs found for experiment %s with the given metric", experiment_name)
            return None

        best_model_uri = f"runs:/{best_run.info.run_id}/{model_path}"

        return self.register_model(
            model_location=best_model_uri, model_name=model_name, tags=tags, description=description
        )

    def download_model(self, model_name: str, version: int, output_path: str) -> None:
        """Download the model with the given version to the given output path.

        Args:
            model_name: The name of the model
            version: The version of the model
            output_path: The path to save the model to, it is created if it does not exist
        """
        artifact_uri = self.client.get_model_version_download_uri(model_name, version)
        log.info("Downloading model %s version %s from %s to %s", model_name, version, artifact_uri, output_path)
        if not os.path.exists(output_path):
            log.info("Creating output path %s", output_path)
        # exist_ok avoids failing if the folder appears between the check and the creation
        os.makedirs(output_path, exist_ok=True)
        mlflow.artifacts.download_artifacts(artifact_uri=artifact_uri, dst_path=output_path)

    @staticmethod
    def _generate_description(description: str | None = None) -> str:
        """Generate the description markdown template, empty if no description is given."""
        if description is None:
            return ""

        return DESCRIPTION_MD_TEMPLATE.format(description)

    @staticmethod
    def _get_author_and_date() -> str:
        """Get the author and date markdown template using the current user and the local time."""
        author_and_date = f"### Author: {getpass.getuser()}\n"
        author_and_date += f"### Date: {datetime.now().astimezone().strftime('%d/%m/%Y %H:%M:%S %Z')}\n"

        return author_and_date

    def _safe_get_stage(self, model_name: str, version: int) -> str | None:
        """Get the stage of a model version.

        Args:
            model_name: The name of the model
            version: The version of the model

        Returns:
            The stage of the model version if it exists, otherwise None
        """
        try:
            return self.client.get_model_version(model_name, version).current_stage
        except RestException:
            log.error("Model named %s with version %s does not exist", model_name, version)
            return None