quadra 0.0.1__py3-none-any.whl → 2.1.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hydra_plugins/quadra_searchpath_plugin.py +30 -0
- quadra/__init__.py +6 -0
- quadra/callbacks/__init__.py +0 -0
- quadra/callbacks/anomalib.py +289 -0
- quadra/callbacks/lightning.py +501 -0
- quadra/callbacks/mlflow.py +291 -0
- quadra/callbacks/scheduler.py +69 -0
- quadra/configs/__init__.py +0 -0
- quadra/configs/backbone/caformer_m36.yaml +8 -0
- quadra/configs/backbone/caformer_s36.yaml +8 -0
- quadra/configs/backbone/convnextv2_base.yaml +8 -0
- quadra/configs/backbone/convnextv2_femto.yaml +8 -0
- quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
- quadra/configs/backbone/dino_vitb8.yaml +12 -0
- quadra/configs/backbone/dino_vits8.yaml +12 -0
- quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
- quadra/configs/backbone/dinov2_vits14.yaml +12 -0
- quadra/configs/backbone/efficientnet_b0.yaml +8 -0
- quadra/configs/backbone/efficientnet_b1.yaml +8 -0
- quadra/configs/backbone/efficientnet_b2.yaml +8 -0
- quadra/configs/backbone/efficientnet_b3.yaml +8 -0
- quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
- quadra/configs/backbone/levit_128s.yaml +8 -0
- quadra/configs/backbone/mnasnet0_5.yaml +9 -0
- quadra/configs/backbone/resnet101.yaml +8 -0
- quadra/configs/backbone/resnet18.yaml +8 -0
- quadra/configs/backbone/resnet18_ssl.yaml +8 -0
- quadra/configs/backbone/resnet50.yaml +8 -0
- quadra/configs/backbone/smp.yaml +9 -0
- quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
- quadra/configs/backbone/unetr.yaml +15 -0
- quadra/configs/backbone/vit16_base.yaml +9 -0
- quadra/configs/backbone/vit16_small.yaml +9 -0
- quadra/configs/backbone/vit16_tiny.yaml +9 -0
- quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
- quadra/configs/callbacks/all.yaml +32 -0
- quadra/configs/callbacks/default.yaml +37 -0
- quadra/configs/callbacks/default_anomalib.yaml +67 -0
- quadra/configs/config.yaml +33 -0
- quadra/configs/core/default.yaml +11 -0
- quadra/configs/datamodule/base/anomaly.yaml +16 -0
- quadra/configs/datamodule/base/classification.yaml +21 -0
- quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
- quadra/configs/datamodule/base/segmentation.yaml +18 -0
- quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
- quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
- quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
- quadra/configs/datamodule/base/ssl.yaml +21 -0
- quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
- quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
- quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
- quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
- quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
- quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
- quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
- quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
- quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
- quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
- quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/base/classification/classification.yaml +73 -0
- quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
- quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
- quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
- quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
- quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
- quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
- quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
- quadra/configs/experiment/base/ssl/byol.yaml +43 -0
- quadra/configs/experiment/base/ssl/dino.yaml +46 -0
- quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
- quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
- quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
- quadra/configs/experiment/custom/cls.yaml +12 -0
- quadra/configs/experiment/default.yaml +15 -0
- quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
- quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
- quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
- quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
- quadra/configs/export/default.yaml +13 -0
- quadra/configs/hydra/anomaly_custom.yaml +15 -0
- quadra/configs/hydra/default.yaml +14 -0
- quadra/configs/inference/default.yaml +26 -0
- quadra/configs/logger/comet.yaml +10 -0
- quadra/configs/logger/csv.yaml +5 -0
- quadra/configs/logger/mlflow.yaml +12 -0
- quadra/configs/logger/tensorboard.yaml +8 -0
- quadra/configs/loss/asl.yaml +7 -0
- quadra/configs/loss/barlow.yaml +2 -0
- quadra/configs/loss/bce.yaml +1 -0
- quadra/configs/loss/byol.yaml +1 -0
- quadra/configs/loss/cross_entropy.yaml +1 -0
- quadra/configs/loss/dino.yaml +8 -0
- quadra/configs/loss/simclr.yaml +2 -0
- quadra/configs/loss/simsiam.yaml +1 -0
- quadra/configs/loss/smp_ce.yaml +3 -0
- quadra/configs/loss/smp_dice.yaml +2 -0
- quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
- quadra/configs/loss/smp_mcc.yaml +2 -0
- quadra/configs/loss/vicreg.yaml +5 -0
- quadra/configs/model/anomalib/cfa.yaml +35 -0
- quadra/configs/model/anomalib/cflow.yaml +30 -0
- quadra/configs/model/anomalib/csflow.yaml +34 -0
- quadra/configs/model/anomalib/dfm.yaml +19 -0
- quadra/configs/model/anomalib/draem.yaml +29 -0
- quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
- quadra/configs/model/anomalib/fastflow.yaml +32 -0
- quadra/configs/model/anomalib/padim.yaml +32 -0
- quadra/configs/model/anomalib/patchcore.yaml +36 -0
- quadra/configs/model/barlow.yaml +16 -0
- quadra/configs/model/byol.yaml +25 -0
- quadra/configs/model/classification.yaml +10 -0
- quadra/configs/model/dino.yaml +26 -0
- quadra/configs/model/logistic_regression.yaml +4 -0
- quadra/configs/model/multilabel_classification.yaml +9 -0
- quadra/configs/model/simclr.yaml +18 -0
- quadra/configs/model/simsiam.yaml +24 -0
- quadra/configs/model/smp.yaml +4 -0
- quadra/configs/model/smp_multiclass.yaml +4 -0
- quadra/configs/model/vicreg.yaml +16 -0
- quadra/configs/optimizer/adam.yaml +5 -0
- quadra/configs/optimizer/adamw.yaml +3 -0
- quadra/configs/optimizer/default.yaml +4 -0
- quadra/configs/optimizer/lars.yaml +8 -0
- quadra/configs/optimizer/sgd.yaml +4 -0
- quadra/configs/scheduler/default.yaml +5 -0
- quadra/configs/scheduler/rop.yaml +5 -0
- quadra/configs/scheduler/step.yaml +3 -0
- quadra/configs/scheduler/warmrestart.yaml +2 -0
- quadra/configs/scheduler/warmup.yaml +6 -0
- quadra/configs/task/anomalib/cfa.yaml +5 -0
- quadra/configs/task/anomalib/cflow.yaml +5 -0
- quadra/configs/task/anomalib/csflow.yaml +5 -0
- quadra/configs/task/anomalib/draem.yaml +5 -0
- quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
- quadra/configs/task/anomalib/fastflow.yaml +5 -0
- quadra/configs/task/anomalib/inference.yaml +3 -0
- quadra/configs/task/anomalib/padim.yaml +5 -0
- quadra/configs/task/anomalib/patchcore.yaml +5 -0
- quadra/configs/task/classification.yaml +6 -0
- quadra/configs/task/classification_evaluation.yaml +6 -0
- quadra/configs/task/default.yaml +1 -0
- quadra/configs/task/segmentation.yaml +9 -0
- quadra/configs/task/segmentation_evaluation.yaml +3 -0
- quadra/configs/task/sklearn_classification.yaml +13 -0
- quadra/configs/task/sklearn_classification_patch.yaml +11 -0
- quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
- quadra/configs/task/sklearn_classification_test.yaml +8 -0
- quadra/configs/task/ssl.yaml +2 -0
- quadra/configs/trainer/lightning_cpu.yaml +36 -0
- quadra/configs/trainer/lightning_gpu.yaml +35 -0
- quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
- quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
- quadra/configs/trainer/lightning_multigpu.yaml +37 -0
- quadra/configs/trainer/sklearn_classification.yaml +7 -0
- quadra/configs/transforms/byol.yaml +47 -0
- quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
- quadra/configs/transforms/default.yaml +37 -0
- quadra/configs/transforms/default_numpy.yaml +24 -0
- quadra/configs/transforms/default_resize.yaml +22 -0
- quadra/configs/transforms/dino.yaml +63 -0
- quadra/configs/transforms/linear_eval.yaml +18 -0
- quadra/datamodules/__init__.py +20 -0
- quadra/datamodules/anomaly.py +180 -0
- quadra/datamodules/base.py +375 -0
- quadra/datamodules/classification.py +1003 -0
- quadra/datamodules/generic/__init__.py +0 -0
- quadra/datamodules/generic/imagenette.py +144 -0
- quadra/datamodules/generic/mnist.py +81 -0
- quadra/datamodules/generic/mvtec.py +58 -0
- quadra/datamodules/generic/oxford_pet.py +163 -0
- quadra/datamodules/patch.py +190 -0
- quadra/datamodules/segmentation.py +742 -0
- quadra/datamodules/ssl.py +140 -0
- quadra/datasets/__init__.py +17 -0
- quadra/datasets/anomaly.py +287 -0
- quadra/datasets/classification.py +241 -0
- quadra/datasets/patch.py +138 -0
- quadra/datasets/segmentation.py +239 -0
- quadra/datasets/ssl.py +110 -0
- quadra/losses/__init__.py +0 -0
- quadra/losses/classification/__init__.py +6 -0
- quadra/losses/classification/asl.py +83 -0
- quadra/losses/classification/focal.py +320 -0
- quadra/losses/classification/prototypical.py +148 -0
- quadra/losses/ssl/__init__.py +17 -0
- quadra/losses/ssl/barlowtwins.py +47 -0
- quadra/losses/ssl/byol.py +37 -0
- quadra/losses/ssl/dino.py +129 -0
- quadra/losses/ssl/hyperspherical.py +45 -0
- quadra/losses/ssl/idmm.py +50 -0
- quadra/losses/ssl/simclr.py +67 -0
- quadra/losses/ssl/simsiam.py +30 -0
- quadra/losses/ssl/vicreg.py +76 -0
- quadra/main.py +46 -0
- quadra/metrics/__init__.py +3 -0
- quadra/metrics/segmentation.py +251 -0
- quadra/models/__init__.py +0 -0
- quadra/models/base.py +151 -0
- quadra/models/classification/__init__.py +8 -0
- quadra/models/classification/backbones.py +149 -0
- quadra/models/classification/base.py +92 -0
- quadra/models/evaluation.py +322 -0
- quadra/modules/__init__.py +0 -0
- quadra/modules/backbone.py +30 -0
- quadra/modules/base.py +312 -0
- quadra/modules/classification/__init__.py +3 -0
- quadra/modules/classification/base.py +331 -0
- quadra/modules/ssl/__init__.py +17 -0
- quadra/modules/ssl/barlowtwins.py +59 -0
- quadra/modules/ssl/byol.py +172 -0
- quadra/modules/ssl/common.py +285 -0
- quadra/modules/ssl/dino.py +186 -0
- quadra/modules/ssl/hyperspherical.py +206 -0
- quadra/modules/ssl/idmm.py +98 -0
- quadra/modules/ssl/simclr.py +73 -0
- quadra/modules/ssl/simsiam.py +68 -0
- quadra/modules/ssl/vicreg.py +67 -0
- quadra/optimizers/__init__.py +4 -0
- quadra/optimizers/lars.py +153 -0
- quadra/optimizers/sam.py +127 -0
- quadra/schedulers/__init__.py +3 -0
- quadra/schedulers/base.py +44 -0
- quadra/schedulers/warmup.py +127 -0
- quadra/tasks/__init__.py +24 -0
- quadra/tasks/anomaly.py +582 -0
- quadra/tasks/base.py +397 -0
- quadra/tasks/classification.py +1264 -0
- quadra/tasks/patch.py +492 -0
- quadra/tasks/segmentation.py +389 -0
- quadra/tasks/ssl.py +560 -0
- quadra/trainers/README.md +3 -0
- quadra/trainers/__init__.py +0 -0
- quadra/trainers/classification.py +179 -0
- quadra/utils/__init__.py +0 -0
- quadra/utils/anomaly.py +112 -0
- quadra/utils/classification.py +618 -0
- quadra/utils/deprecation.py +31 -0
- quadra/utils/evaluation.py +474 -0
- quadra/utils/export.py +579 -0
- quadra/utils/imaging.py +32 -0
- quadra/utils/logger.py +15 -0
- quadra/utils/mlflow.py +98 -0
- quadra/utils/model_manager.py +320 -0
- quadra/utils/models.py +524 -0
- quadra/utils/patch/__init__.py +15 -0
- quadra/utils/patch/dataset.py +1433 -0
- quadra/utils/patch/metrics.py +449 -0
- quadra/utils/patch/model.py +153 -0
- quadra/utils/patch/visualization.py +217 -0
- quadra/utils/resolver.py +42 -0
- quadra/utils/segmentation.py +31 -0
- quadra/utils/tests/__init__.py +0 -0
- quadra/utils/tests/fixtures/__init__.py +1 -0
- quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
- quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
- quadra/utils/tests/fixtures/dataset/classification.py +406 -0
- quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
- quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
- quadra/utils/tests/fixtures/models/__init__.py +3 -0
- quadra/utils/tests/fixtures/models/anomaly.py +89 -0
- quadra/utils/tests/fixtures/models/classification.py +45 -0
- quadra/utils/tests/fixtures/models/segmentation.py +33 -0
- quadra/utils/tests/helpers.py +70 -0
- quadra/utils/tests/models.py +27 -0
- quadra/utils/utils.py +525 -0
- quadra/utils/validator.py +115 -0
- quadra/utils/visualization.py +422 -0
- quadra/utils/vit_explainability.py +349 -0
- quadra-2.1.13.dist-info/LICENSE +201 -0
- quadra-2.1.13.dist-info/METADATA +386 -0
- quadra-2.1.13.dist-info/RECORD +300 -0
- {quadra-0.0.1.dist-info → quadra-2.1.13.dist-info}/WHEEL +1 -1
- quadra-2.1.13.dist-info/entry_points.txt +3 -0
- quadra-0.0.1.dist-info/METADATA +0 -14
- quadra-0.0.1.dist-info/RECORD +0 -4
quadra/utils/mlflow.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
try:
|
|
4
|
+
from mlflow.models import infer_signature # noqa
|
|
5
|
+
from mlflow.models.signature import ModelSignature # noqa
|
|
6
|
+
|
|
7
|
+
MLFLOW_AVAILABLE = True
|
|
8
|
+
except ImportError:
|
|
9
|
+
MLFLOW_AVAILABLE = False
|
|
10
|
+
|
|
11
|
+
from collections.abc import Sequence
|
|
12
|
+
from typing import Any
|
|
13
|
+
|
|
14
|
+
import torch
|
|
15
|
+
from pytorch_lightning import Trainer
|
|
16
|
+
from pytorch_lightning.loggers import MLFlowLogger
|
|
17
|
+
|
|
18
|
+
from quadra.models.evaluation import BaseEvaluationModel
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@torch.inference_mode()
def infer_signature_model(model: BaseEvaluationModel, data: list[Any]) -> ModelSignature | None:
    """Build an mlflow signature by running the model once on sample data.

    Args:
        model: Model to inspect, it is switched to eval mode before the forward pass.
        data: Positional inputs, unpacked into the model call.

    Returns:
        The signature computed by ``infer_signature`` or None when the structure of the inputs or of the
        outputs is rejected by ``infer_signature_input``.
    """
    eval_model = model.eval()
    prediction = eval_model(*data)

    try:
        outputs_description = infer_signature_input(prediction)
        # A single positional input is described on its own, multiple inputs are described as a collection
        inputs_description = infer_signature_input(data[0] if len(data) == 1 else data)
    except ValueError:
        # TODO: Solve circular import as it is not possible to import get_logger right now
        # log.warning("Unable to infer signature for model output type %s", type(model_output))
        return None

    return infer_signature(inputs_description, outputs_description)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def infer_signature_input(input_tensor: Any) -> Any:
    """Recursively infer the signature input format to pass to mlflow.models.infer_signature.

    Tensors are converted to numpy arrays. Dicts keep their keys, while sequences are converted to dicts whose
    keys are ``output_<index>`` as mlflow currently does not support sequence outputs.

    Args:
        input_tensor: A tensor, a flat sequence of tensors or a flat dict of tensors.

    Returns:
        A numpy array for a tensor input, otherwise a dict of numpy arrays.

    Raises:
        ValueError: If the input type is not supported or when nested dicts or sequences are encountered.
    """
    if isinstance(input_tensor, torch.Tensor):
        # Detach first: calling numpy() on a tensor that requires grad raises a RuntimeError
        return input_tensor.detach().cpu().numpy()

    if isinstance(input_tensor, (str, bytes, bytearray)):
        # Text and byte strings are sequences too, but they are never a valid container of tensors
        raise ValueError(f"Unable to infer signature for model output type {type(input_tensor)}")

    if isinstance(input_tensor, dict):
        named_items = list(input_tensor.items())
    elif isinstance(input_tensor, Sequence):
        # Mlflow currently does not support sequence outputs, so we use a dict instead
        named_items = [(f"output_{i}", x) for i, x in enumerate(input_tensor)]
    else:
        raise ValueError(f"Unable to infer signature for model output type {type(input_tensor)}")

    signature: dict[Any, Any] = {}
    for name, value in named_items:
        if isinstance(value, dict):
            # Nested dicts are not supported
            raise ValueError("Nested dicts are not supported")
        if isinstance(value, Sequence):
            # Nested signature is currently not supported by mlflow
            # TODO: Enable this once mlflow supports nested signatures
            raise ValueError("Nested sequences are not supported")

        signature[name] = infer_signature_input(value)

    return signature
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def get_mlflow_logger(trainer: Trainer) -> MLFlowLogger | None:
    """Safely get Mlflow logger from Trainer loggers.

    Args:
        trainer: Pytorch Lightning trainer.

    Returns:
        An mlflow logger if available, else None.
    """
    primary_logger = trainer.logger

    if isinstance(primary_logger, MLFlowLogger):
        return primary_logger

    candidates: list[Any] = []

    if isinstance(primary_logger, list):
        candidates.extend(primary_logger)

    # NOTE(review): on Lightning versions where `trainer.logger` only returns the first configured logger the
    # complete list is exposed as `trainer.loggers`, so that attribute is inspected too when it exists.
    candidates.extend(getattr(trainer, "loggers", None) or [])

    for logger in candidates:
        if isinstance(logger, MLFlowLogger):
            return logger

    return None
|
|
quadra/utils/model_manager.py
ADDED
|
@@ -0,0 +1,320 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import getpass
|
|
4
|
+
import os
|
|
5
|
+
from abc import ABC, abstractmethod
|
|
6
|
+
from datetime import datetime
|
|
7
|
+
from typing import Any, Literal
|
|
8
|
+
|
|
9
|
+
from quadra.utils.utils import get_logger
|
|
10
|
+
|
|
11
|
+
log = get_logger(__name__)
|
|
12
|
+
|
|
13
|
+
try:
|
|
14
|
+
import mlflow # noqa
|
|
15
|
+
from mlflow.entities import Run # noqa
|
|
16
|
+
from mlflow.entities.model_registry import ModelVersion # noqa
|
|
17
|
+
from mlflow.exceptions import RestException # noqa
|
|
18
|
+
from mlflow.tracking import MlflowClient # noqa
|
|
19
|
+
|
|
20
|
+
MLFLOW_AVAILABLE = True
|
|
21
|
+
except ImportError:
|
|
22
|
+
MLFLOW_AVAILABLE = False
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
VERSION_MD_TEMPLATE = "## **Version {}**\n"
|
|
26
|
+
DESCRIPTION_MD_TEMPLATE = "### Description: \n{}\n"
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class AbstractModelManager(ABC):
    """Abstract class for model managers.

    A model manager registers models in a model registry, moves them between stages, deletes them and
    downloads them. Optional arguments share the defaults used by the concrete managers of this module.
    """

    @abstractmethod
    def register_model(
        self, model_location: str, model_name: str, description: str | None = None, tags: dict[str, Any] | None = None
    ) -> Any:
        """Register a model in the model registry.

        Args:
            model_location: Where the model to register is located
            model_name: The name of the model after it is registered
            description: An optional description of the model
            tags: An optional dictionary of tags to add to the model
        """

    @abstractmethod
    def get_latest_version(self, model_name: str) -> Any:
        """Get the latest version of a model.

        Args:
            model_name: The name of the model
        """

    @abstractmethod
    def transition_model(self, model_name: str, version: int, stage: str, description: str | None = None) -> Any:
        """Transition the model with the given version to a new stage.

        Args:
            model_name: The name of the model
            version: The version of the model
            stage: The stage to move the model to
            description: An optional description of the transition
        """

    @abstractmethod
    def delete_model(self, model_name: str, version: int, description: str | None = None) -> None:
        """Delete a model with the given version.

        Args:
            model_name: The name of the model
            version: The version of the model
            description: An optional description of why the model is deleted
        """

    @abstractmethod
    def register_best_model(
        self,
        experiment_name: str,
        metric: str,
        model_name: str,
        description: str | None = None,
        tags: dict[str, Any] | None = None,
        mode: Literal["max", "min"] = "max",
        model_path: str = "deployment_model",
    ) -> Any:
        """Register the best model from an experiment.

        Args:
            experiment_name: The name of the experiment
            metric: The metric used to rank the runs of the experiment
            model_name: The name of the model after it is registered
            description: An optional description of the model
            tags: An optional dictionary of tags to add to the model
            mode: Whether the best run has the highest ("max") or the lowest ("min") metric value
            model_path: The path to the model within the experiment run
        """

    @abstractmethod
    def download_model(self, model_name: str, version: int, output_path: str) -> None:
        """Download the model with the given version to the given output path.

        Args:
            model_name: The name of the model
            version: The version of the model
            output_path: The path to save the model to
        """
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class MlflowModelManager(AbstractModelManager):
    """Model manager for Mlflow.

    Every operation that changes the registry also appends a markdown entry (version, author, date and an
    optional description) to the description of the registered model, which is used as a changelog.
    """

    def __init__(self):
        """Configure mlflow from the MLFLOW_TRACKING_URI environment variable and create the client.

        Raises:
            ImportError: If mlflow is not installed
            ValueError: If the MLFLOW_TRACKING_URI environment variable is not set
        """
        if not MLFLOW_AVAILABLE:
            raise ImportError("Mlflow is not available, please install it with pip install mlflow")

        tracking_uri = os.getenv("MLFLOW_TRACKING_URI")

        if tracking_uri is None:
            raise ValueError("MLFLOW_TRACKING_URI environment variable is not set")

        mlflow.set_tracking_uri(tracking_uri)
        self.client = MlflowClient()

    def register_model(
        self, model_location: str, model_name: str, description: str | None = None, tags: dict[str, Any] | None = None
    ) -> ModelVersion:
        """Register a model in the model registry.

        Args:
            model_location: The model uri
            model_name: The name of the model after it is registered
            description: A description of the model, this will be added to the model changelog
            tags: A dictionary of tags to add to the model

        Returns:
            The model version
        """
        model_version = mlflow.register_model(model_uri=model_location, name=model_name, tags=tags)
        log.info("Registered model %s with version %s", model_name, model_version.version)
        registered_model_description = self._get_registered_model_description(model_name)

        # The changelog title is written only once, together with the first version. The version is compared
        # as a string as it is not guaranteed here that every backend returns it with the same type.
        header = "# MODEL CHANGELOG\n" if str(model_version.version) == "1" else ""

        new_model_description = VERSION_MD_TEMPLATE.format(model_version.version)
        new_model_description += self._get_author_and_date()
        new_model_description += self._generate_description(description)

        self.client.update_registered_model(model_name, header + registered_model_description + new_model_description)

        # Each model version keeps its own changelog, so it always starts with the title
        self.client.update_model_version(
            model_name, model_version.version, "# MODEL CHANGELOG\n" + new_model_description
        )

        return model_version

    def get_latest_version(self, model_name: str) -> ModelVersion:
        """Get the latest version of a model.

        Args:
            model_name: The name of the model

        Returns:
            The model version

        Raises:
            ValueError: If the model has no versions
        """
        version_numbers = [int(x.version) for x in self.client.get_latest_versions(model_name)]

        if not version_numbers:
            # max() would raise the same exception type, but with a message that does not mention the model
            raise ValueError(f"No versions found for model {model_name}")

        return self.client.get_model_version(model_name, max(version_numbers))

    def transition_model(
        self, model_name: str, version: int, stage: str, description: str | None = None
    ) -> ModelVersion | None:
        """Transition a model to a new stage.

        Args:
            model_name: The name of the model
            version: The version of the model
            stage: The stage of the model
            description: A description of the transition, this will be added to the model changelog

        Returns:
            The updated model version, the unchanged model version if it is already in the requested stage,
            or None if the model version does not exist
        """
        previous_stage = self._safe_get_stage(model_name, version)

        if previous_stage is None:
            return None

        if previous_stage.lower() == stage.lower():
            log.warning("Model %s version %s is already in stage %s", model_name, version, stage)
            return self.client.get_model_version(model_name, version)

        log.info("Transitioning model %s version %s from %s to %s", model_name, version, previous_stage, stage)
        model_version = self.client.transition_model_version_stage(name=model_name, version=version, stage=stage)
        new_stage = model_version.current_stage
        registered_model_description = self._get_registered_model_description(model_name)
        # A version without a description is treated as an empty changelog
        single_model_description = self.client.get_model_version(model_name, version).description or ""

        new_model_description = "## **Transition:**\n"
        new_model_description += f"### Version {model_version.version} from {previous_stage} to {new_stage}\n"
        new_model_description += self._get_author_and_date()
        new_model_description += self._generate_description(description)

        self.client.update_registered_model(model_name, registered_model_description + new_model_description)
        self.client.update_model_version(
            model_name, model_version.version, single_model_description + new_model_description
        )

        return model_version

    def delete_model(self, model_name: str, version: int, description: str | None = None) -> None:
        """Delete a model.

        The user is asked to type the model name on the standard input to confirm the deletion.

        Args:
            model_name: The name of the model
            version: The version of the model
            description: Why the model was deleted, this will be added to the model changelog
        """
        model_stage = self._safe_get_stage(model_name, version)

        if model_stage is None:
            return

        confirmation = input(
            f"Model named `{model_name}`, version {version} is in stage {model_stage}, "
            "type the model name to continue deletion:"
        )

        if confirmation != model_name:
            log.warning("Model name did not match, aborting deletion")
            return

        log.info("Deleting model %s version %s", model_name, version)
        self.client.delete_model_version(model_name, version)

        registered_model_description = self._get_registered_model_description(model_name)

        new_model_description = "## **Deletion:**\n"
        new_model_description += f"### Version {version} from stage: {model_stage}\n"
        new_model_description += self._get_author_and_date()
        new_model_description += self._generate_description(description)

        self.client.update_registered_model(model_name, registered_model_description + new_model_description)

    def register_best_model(
        self,
        experiment_name: str,
        metric: str,
        model_name: str,
        description: str | None = None,
        tags: dict[str, Any] | None = None,
        mode: Literal["max", "min"] = "max",
        model_path: str = "deployment_model",
    ) -> ModelVersion | None:
        """Register the best model from an experiment.

        Only the runs that contain the top folder of `model_path` among their artifacts and that logged
        `metric` are considered.

        Args:
            experiment_name: The name of the experiment
            metric: The metric to use to determine the best model
            model_name: The name of the model after it is registered
            description: A description of the model, this will be added to the model changelog
            tags: A dictionary of tags to add to the model
            mode: The mode to use to determine the best model, either "max" or "min"
            model_path: The path to the model within the experiment run

        Returns:
            The registered model version if successful, otherwise None

        Raises:
            ValueError: If mode is neither "max" nor "min"
        """
        if mode not in ["max", "min"]:
            raise ValueError(f"Mode must be either 'max' or 'min', got {mode}")

        experiment = self.client.get_experiment_by_name(experiment_name)

        if experiment is None:
            # Handled like an experiment without runs instead of failing on a missing attribute
            log.error("Experiment %s not found", experiment_name)
            return None

        runs = self.client.search_runs(experiment_ids=[experiment.experiment_id])

        if len(runs) == 0:
            log.error("No runs found for experiment %s", experiment_name)
            return None

        best_run: Run | None = None
        best_value: Any = None

        # We can only make comparisons if the model is on the top folder, otherwise just check if the folder exists
        # TODO: Is there a better way to do this?
        base_model_path = model_path.split("/")[0]

        for run in runs:
            if not any(x.path == base_model_path for x in self.client.list_artifacts(run.info.run_id)):
                # If we don't find the given model path, skip this run
                continue

            run_value = run.data.metrics.get(metric)

            if run_value is None:
                # A run that did not log the metric can't be ranked, skip it instead of raising a KeyError
                continue

            if best_run is None:
                is_better = True
            elif mode == "max":
                is_better = run_value > best_value
            else:
                is_better = run_value < best_value

            if is_better:
                best_run, best_value = run, run_value

        if best_run is None:
            log.error("No runs found for experiment %s with the given metric", experiment_name)
            return None

        best_model_uri = f"runs:/{best_run.info.run_id}/{model_path}"

        return self.register_model(
            model_location=best_model_uri, model_name=model_name, tags=tags, description=description
        )

    def download_model(self, model_name: str, version: int, output_path: str) -> None:
        """Download the model with the given version to the given output path.

        Args:
            model_name: The name of the model
            version: The version of the model
            output_path: The path to save the model to
        """
        artifact_uri = self.client.get_model_version_download_uri(model_name, version)
        log.info("Downloading model %s version %s from %s to %s", model_name, version, artifact_uri, output_path)
        if not os.path.exists(output_path):
            log.info("Creating output path %s", output_path)
            # exist_ok avoids a failure if the folder is created by someone else after the check above
            os.makedirs(output_path, exist_ok=True)
        mlflow.artifacts.download_artifacts(artifact_uri=artifact_uri, dst_path=output_path)

    @staticmethod
    def _generate_description(description: str | None = None) -> str:
        """Generate the description markdown template, an empty string if no description is given."""
        if description is None:
            return ""

        return DESCRIPTION_MD_TEMPLATE.format(description)

    @staticmethod
    def _get_author_and_date() -> str:
        """Get the author and date markdown template using the current user and the current local time."""
        author_and_date = f"### Author: {getpass.getuser()}\n"
        author_and_date += f"### Date: {datetime.now().astimezone().strftime('%d/%m/%Y %H:%M:%S %Z')}\n"

        return author_and_date

    def _get_registered_model_description(self, model_name: str) -> str:
        """Get the description of a registered model, an empty string if the registry returns none."""
        return self.client.get_registered_model(model_name).description or ""

    def _safe_get_stage(self, model_name: str, version: int) -> str | None:
        """Get the stage of a model version.

        Args:
            model_name: The name of the model
            version: The version of the model

        Returns:
            The stage of the model version if it exists, otherwise None
        """
        try:
            return self.client.get_model_version(model_name, version).current_stage
        except RestException:
            log.error("Model named %s with version %s does not exist", model_name, version)
            return None
|