quadra 0.0.1__py3-none-any.whl → 2.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hydra_plugins/quadra_searchpath_plugin.py +30 -0
- quadra/__init__.py +6 -0
- quadra/callbacks/__init__.py +0 -0
- quadra/callbacks/anomalib.py +289 -0
- quadra/callbacks/lightning.py +501 -0
- quadra/callbacks/mlflow.py +291 -0
- quadra/callbacks/scheduler.py +69 -0
- quadra/configs/__init__.py +0 -0
- quadra/configs/backbone/caformer_m36.yaml +8 -0
- quadra/configs/backbone/caformer_s36.yaml +8 -0
- quadra/configs/backbone/convnextv2_base.yaml +8 -0
- quadra/configs/backbone/convnextv2_femto.yaml +8 -0
- quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
- quadra/configs/backbone/dino_vitb8.yaml +12 -0
- quadra/configs/backbone/dino_vits8.yaml +12 -0
- quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
- quadra/configs/backbone/dinov2_vits14.yaml +12 -0
- quadra/configs/backbone/efficientnet_b0.yaml +8 -0
- quadra/configs/backbone/efficientnet_b1.yaml +8 -0
- quadra/configs/backbone/efficientnet_b2.yaml +8 -0
- quadra/configs/backbone/efficientnet_b3.yaml +8 -0
- quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
- quadra/configs/backbone/levit_128s.yaml +8 -0
- quadra/configs/backbone/mnasnet0_5.yaml +9 -0
- quadra/configs/backbone/resnet101.yaml +8 -0
- quadra/configs/backbone/resnet18.yaml +8 -0
- quadra/configs/backbone/resnet18_ssl.yaml +8 -0
- quadra/configs/backbone/resnet50.yaml +8 -0
- quadra/configs/backbone/smp.yaml +9 -0
- quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
- quadra/configs/backbone/unetr.yaml +15 -0
- quadra/configs/backbone/vit16_base.yaml +9 -0
- quadra/configs/backbone/vit16_small.yaml +9 -0
- quadra/configs/backbone/vit16_tiny.yaml +9 -0
- quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
- quadra/configs/callbacks/all.yaml +45 -0
- quadra/configs/callbacks/default.yaml +34 -0
- quadra/configs/callbacks/default_anomalib.yaml +64 -0
- quadra/configs/config.yaml +33 -0
- quadra/configs/core/default.yaml +11 -0
- quadra/configs/datamodule/base/anomaly.yaml +16 -0
- quadra/configs/datamodule/base/classification.yaml +21 -0
- quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
- quadra/configs/datamodule/base/segmentation.yaml +18 -0
- quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
- quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
- quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
- quadra/configs/datamodule/base/ssl.yaml +21 -0
- quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
- quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
- quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
- quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
- quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
- quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
- quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
- quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
- quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
- quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
- quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/base/classification/classification.yaml +73 -0
- quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
- quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
- quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
- quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
- quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
- quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
- quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
- quadra/configs/experiment/base/ssl/byol.yaml +43 -0
- quadra/configs/experiment/base/ssl/dino.yaml +46 -0
- quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
- quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
- quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
- quadra/configs/experiment/custom/cls.yaml +12 -0
- quadra/configs/experiment/default.yaml +15 -0
- quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
- quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
- quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
- quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
- quadra/configs/export/default.yaml +13 -0
- quadra/configs/hydra/anomaly_custom.yaml +15 -0
- quadra/configs/hydra/default.yaml +14 -0
- quadra/configs/inference/default.yaml +26 -0
- quadra/configs/logger/comet.yaml +10 -0
- quadra/configs/logger/csv.yaml +5 -0
- quadra/configs/logger/mlflow.yaml +12 -0
- quadra/configs/logger/tensorboard.yaml +8 -0
- quadra/configs/loss/asl.yaml +7 -0
- quadra/configs/loss/barlow.yaml +2 -0
- quadra/configs/loss/bce.yaml +1 -0
- quadra/configs/loss/byol.yaml +1 -0
- quadra/configs/loss/cross_entropy.yaml +1 -0
- quadra/configs/loss/dino.yaml +8 -0
- quadra/configs/loss/simclr.yaml +2 -0
- quadra/configs/loss/simsiam.yaml +1 -0
- quadra/configs/loss/smp_ce.yaml +3 -0
- quadra/configs/loss/smp_dice.yaml +2 -0
- quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
- quadra/configs/loss/smp_mcc.yaml +2 -0
- quadra/configs/loss/vicreg.yaml +5 -0
- quadra/configs/model/anomalib/cfa.yaml +35 -0
- quadra/configs/model/anomalib/cflow.yaml +30 -0
- quadra/configs/model/anomalib/csflow.yaml +34 -0
- quadra/configs/model/anomalib/dfm.yaml +19 -0
- quadra/configs/model/anomalib/draem.yaml +29 -0
- quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
- quadra/configs/model/anomalib/fastflow.yaml +32 -0
- quadra/configs/model/anomalib/padim.yaml +32 -0
- quadra/configs/model/anomalib/patchcore.yaml +36 -0
- quadra/configs/model/barlow.yaml +16 -0
- quadra/configs/model/byol.yaml +25 -0
- quadra/configs/model/classification.yaml +10 -0
- quadra/configs/model/dino.yaml +26 -0
- quadra/configs/model/logistic_regression.yaml +4 -0
- quadra/configs/model/multilabel_classification.yaml +9 -0
- quadra/configs/model/simclr.yaml +18 -0
- quadra/configs/model/simsiam.yaml +24 -0
- quadra/configs/model/smp.yaml +4 -0
- quadra/configs/model/smp_multiclass.yaml +4 -0
- quadra/configs/model/vicreg.yaml +16 -0
- quadra/configs/optimizer/adam.yaml +5 -0
- quadra/configs/optimizer/adamw.yaml +3 -0
- quadra/configs/optimizer/default.yaml +4 -0
- quadra/configs/optimizer/lars.yaml +8 -0
- quadra/configs/optimizer/sgd.yaml +4 -0
- quadra/configs/scheduler/default.yaml +5 -0
- quadra/configs/scheduler/rop.yaml +5 -0
- quadra/configs/scheduler/step.yaml +3 -0
- quadra/configs/scheduler/warmrestart.yaml +2 -0
- quadra/configs/scheduler/warmup.yaml +6 -0
- quadra/configs/task/anomalib/cfa.yaml +5 -0
- quadra/configs/task/anomalib/cflow.yaml +5 -0
- quadra/configs/task/anomalib/csflow.yaml +5 -0
- quadra/configs/task/anomalib/draem.yaml +5 -0
- quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
- quadra/configs/task/anomalib/fastflow.yaml +5 -0
- quadra/configs/task/anomalib/inference.yaml +3 -0
- quadra/configs/task/anomalib/padim.yaml +5 -0
- quadra/configs/task/anomalib/patchcore.yaml +5 -0
- quadra/configs/task/classification.yaml +6 -0
- quadra/configs/task/classification_evaluation.yaml +6 -0
- quadra/configs/task/default.yaml +1 -0
- quadra/configs/task/segmentation.yaml +9 -0
- quadra/configs/task/segmentation_evaluation.yaml +3 -0
- quadra/configs/task/sklearn_classification.yaml +13 -0
- quadra/configs/task/sklearn_classification_patch.yaml +11 -0
- quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
- quadra/configs/task/sklearn_classification_test.yaml +8 -0
- quadra/configs/task/ssl.yaml +2 -0
- quadra/configs/trainer/lightning_cpu.yaml +36 -0
- quadra/configs/trainer/lightning_gpu.yaml +35 -0
- quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
- quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
- quadra/configs/trainer/lightning_multigpu.yaml +37 -0
- quadra/configs/trainer/sklearn_classification.yaml +7 -0
- quadra/configs/transforms/byol.yaml +47 -0
- quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
- quadra/configs/transforms/default.yaml +37 -0
- quadra/configs/transforms/default_numpy.yaml +24 -0
- quadra/configs/transforms/default_resize.yaml +22 -0
- quadra/configs/transforms/dino.yaml +63 -0
- quadra/configs/transforms/linear_eval.yaml +18 -0
- quadra/datamodules/__init__.py +20 -0
- quadra/datamodules/anomaly.py +180 -0
- quadra/datamodules/base.py +375 -0
- quadra/datamodules/classification.py +1003 -0
- quadra/datamodules/generic/__init__.py +0 -0
- quadra/datamodules/generic/imagenette.py +144 -0
- quadra/datamodules/generic/mnist.py +81 -0
- quadra/datamodules/generic/mvtec.py +58 -0
- quadra/datamodules/generic/oxford_pet.py +163 -0
- quadra/datamodules/patch.py +190 -0
- quadra/datamodules/segmentation.py +742 -0
- quadra/datamodules/ssl.py +140 -0
- quadra/datasets/__init__.py +17 -0
- quadra/datasets/anomaly.py +287 -0
- quadra/datasets/classification.py +241 -0
- quadra/datasets/patch.py +138 -0
- quadra/datasets/segmentation.py +239 -0
- quadra/datasets/ssl.py +110 -0
- quadra/losses/__init__.py +0 -0
- quadra/losses/classification/__init__.py +6 -0
- quadra/losses/classification/asl.py +83 -0
- quadra/losses/classification/focal.py +320 -0
- quadra/losses/classification/prototypical.py +148 -0
- quadra/losses/ssl/__init__.py +17 -0
- quadra/losses/ssl/barlowtwins.py +47 -0
- quadra/losses/ssl/byol.py +37 -0
- quadra/losses/ssl/dino.py +129 -0
- quadra/losses/ssl/hyperspherical.py +45 -0
- quadra/losses/ssl/idmm.py +50 -0
- quadra/losses/ssl/simclr.py +67 -0
- quadra/losses/ssl/simsiam.py +30 -0
- quadra/losses/ssl/vicreg.py +76 -0
- quadra/main.py +49 -0
- quadra/metrics/__init__.py +3 -0
- quadra/metrics/segmentation.py +251 -0
- quadra/models/__init__.py +0 -0
- quadra/models/base.py +151 -0
- quadra/models/classification/__init__.py +8 -0
- quadra/models/classification/backbones.py +149 -0
- quadra/models/classification/base.py +92 -0
- quadra/models/evaluation.py +322 -0
- quadra/modules/__init__.py +0 -0
- quadra/modules/backbone.py +30 -0
- quadra/modules/base.py +312 -0
- quadra/modules/classification/__init__.py +3 -0
- quadra/modules/classification/base.py +327 -0
- quadra/modules/ssl/__init__.py +17 -0
- quadra/modules/ssl/barlowtwins.py +59 -0
- quadra/modules/ssl/byol.py +172 -0
- quadra/modules/ssl/common.py +285 -0
- quadra/modules/ssl/dino.py +186 -0
- quadra/modules/ssl/hyperspherical.py +206 -0
- quadra/modules/ssl/idmm.py +98 -0
- quadra/modules/ssl/simclr.py +73 -0
- quadra/modules/ssl/simsiam.py +68 -0
- quadra/modules/ssl/vicreg.py +67 -0
- quadra/optimizers/__init__.py +4 -0
- quadra/optimizers/lars.py +153 -0
- quadra/optimizers/sam.py +127 -0
- quadra/schedulers/__init__.py +3 -0
- quadra/schedulers/base.py +44 -0
- quadra/schedulers/warmup.py +127 -0
- quadra/tasks/__init__.py +24 -0
- quadra/tasks/anomaly.py +582 -0
- quadra/tasks/base.py +397 -0
- quadra/tasks/classification.py +1263 -0
- quadra/tasks/patch.py +492 -0
- quadra/tasks/segmentation.py +389 -0
- quadra/tasks/ssl.py +560 -0
- quadra/trainers/README.md +3 -0
- quadra/trainers/__init__.py +0 -0
- quadra/trainers/classification.py +179 -0
- quadra/utils/__init__.py +0 -0
- quadra/utils/anomaly.py +112 -0
- quadra/utils/classification.py +618 -0
- quadra/utils/deprecation.py +31 -0
- quadra/utils/evaluation.py +474 -0
- quadra/utils/export.py +585 -0
- quadra/utils/imaging.py +32 -0
- quadra/utils/logger.py +15 -0
- quadra/utils/mlflow.py +98 -0
- quadra/utils/model_manager.py +320 -0
- quadra/utils/models.py +523 -0
- quadra/utils/patch/__init__.py +15 -0
- quadra/utils/patch/dataset.py +1433 -0
- quadra/utils/patch/metrics.py +449 -0
- quadra/utils/patch/model.py +153 -0
- quadra/utils/patch/visualization.py +217 -0
- quadra/utils/resolver.py +42 -0
- quadra/utils/segmentation.py +31 -0
- quadra/utils/tests/__init__.py +0 -0
- quadra/utils/tests/fixtures/__init__.py +1 -0
- quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
- quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
- quadra/utils/tests/fixtures/dataset/classification.py +406 -0
- quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
- quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
- quadra/utils/tests/fixtures/models/__init__.py +3 -0
- quadra/utils/tests/fixtures/models/anomaly.py +89 -0
- quadra/utils/tests/fixtures/models/classification.py +45 -0
- quadra/utils/tests/fixtures/models/segmentation.py +33 -0
- quadra/utils/tests/helpers.py +70 -0
- quadra/utils/tests/models.py +27 -0
- quadra/utils/utils.py +525 -0
- quadra/utils/validator.py +115 -0
- quadra/utils/visualization.py +422 -0
- quadra/utils/vit_explainability.py +349 -0
- quadra-2.2.7.dist-info/LICENSE +201 -0
- quadra-2.2.7.dist-info/METADATA +381 -0
- quadra-2.2.7.dist-info/RECORD +300 -0
- {quadra-0.0.1.dist-info → quadra-2.2.7.dist-info}/WHEEL +1 -1
- quadra-2.2.7.dist-info/entry_points.txt +3 -0
- quadra-0.0.1.dist-info/METADATA +0 -14
- quadra-0.0.1.dist-info/RECORD +0 -4
quadra/utils/export.py
ADDED
|
@@ -0,0 +1,585 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import contextlib
|
|
4
|
+
import os
|
|
5
|
+
from collections.abc import Sequence
|
|
6
|
+
from typing import Any, Literal, TypeVar, cast
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
from anomalib.models.cflow import CflowLightning
|
|
10
|
+
from omegaconf import DictConfig, ListConfig, OmegaConf
|
|
11
|
+
from torch import nn
|
|
12
|
+
|
|
13
|
+
from quadra.models.base import ModelSignatureWrapper
|
|
14
|
+
from quadra.models.evaluation import (
|
|
15
|
+
BaseEvaluationModel,
|
|
16
|
+
ONNXEvaluationModel,
|
|
17
|
+
TorchEvaluationModel,
|
|
18
|
+
TorchscriptEvaluationModel,
|
|
19
|
+
)
|
|
20
|
+
from quadra.utils.logger import get_logger
|
|
21
|
+
|
|
22
|
+
try:
|
|
23
|
+
import onnx # noqa
|
|
24
|
+
from onnxsim import simplify as onnx_simplify # noqa
|
|
25
|
+
from onnxconverter_common import auto_convert_mixed_precision # noqa
|
|
26
|
+
|
|
27
|
+
ONNX_AVAILABLE = True
|
|
28
|
+
except ImportError:
|
|
29
|
+
ONNX_AVAILABLE = False
|
|
30
|
+
|
|
31
|
+
# Module level logger used by every export helper in this file
log = get_logger(__name__)

# Type variable bound to BaseEvaluationModel; presumably used to type the deployment-model loading helpers defined
# later in this file (not visible here) - TODO confirm
BaseDeploymentModelT = TypeVar("BaseDeploymentModelT", bound=BaseEvaluationModel)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def generate_torch_inputs(
    input_shapes: list[Any],
    device: str | torch.device,
    half_precision: bool = False,
    dtype: torch.dtype = torch.float32,
    batch_size: int = 1,
) -> list[Any] | tuple[Any, ...] | dict[str, Any] | torch.Tensor:
    """Given a list of input shapes that can contain either lists, tuples or dicts, with tuples being the input shapes
    of the model, generate a list of torch tensors with the given device and dtype.

    The nesting structure of ``input_shapes`` is mirrored in the output: lists become lists, tuples become tuples and
    dicts become dicts. A list or tuple that contains no nested sequence or dict is treated as a single tensor shape
    and produces one tensor of shape ``(batch_size, *shape)``.

    Args:
        input_shapes: (Possibly nested) description of the model inputs. OmegaConf containers are converted to plain
            Python containers first.
        device: Device on which the tensors are allocated.
        half_precision: If True, every generated tensor is cast with ``.half()``.
        dtype: Dtype used when sampling the tensors.
        batch_size: Size of the leading batch dimension prepended to every generated tensor, at every nesting level.

    Returns:
        Random tensors following the structure of ``input_shapes``.

    Raises:
        RuntimeError: If ``input_shapes`` is not a list, tuple or dict (or an OmegaConf container of them).
    """
    if isinstance(input_shapes, (ListConfig, DictConfig)):
        input_shapes = OmegaConf.to_container(input_shapes)

    def _recurse(shapes: Any) -> Any:
        # Forward every option, batch_size included, so nested inputs share the same batch dimension
        # (dropping batch_size here would silently ignore e.g. the ONNX fixed_batch_size setting).
        return generate_torch_inputs(shapes, device, half_precision, dtype, batch_size)

    if isinstance(input_shapes, dict):
        return {key: _recurse(value) for key, value in input_shapes.items()}

    if not isinstance(input_shapes, (list, tuple)):
        raise RuntimeError("Something went wrong during model export, unable to parse input shapes")

    if any(isinstance(element, (Sequence, dict)) for element in input_shapes):
        # The sequence contains a list, tuple or dict: keep the container type of the original specification
        nested = [_recurse(element) for element in input_shapes]
        return nested if isinstance(input_shapes, list) else tuple(nested)

    # Base case: a flat sequence describing a single tensor shape
    tensor = torch.randn((batch_size, *input_shapes), dtype=dtype, device=device)
    if half_precision:
        tensor = tensor.half()

    return tensor
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def extract_torch_model_inputs(
    model: nn.Module | ModelSignatureWrapper,
    input_shapes: list[Any] | None = None,
    half_precision: bool = False,
    batch_size: int = 1,
) -> tuple[list[Any] | tuple[Any, ...] | torch.Tensor, list[Any]] | None:
    """Build random example inputs for ``model`` starting from its input shapes.

    When ``input_shapes`` is not provided and ``model`` is a ``ModelSignatureWrapper``, the shapes recorded by the
    wrapper are used instead.

    Args:
        model: Module or ModelSignatureWrapper whose inputs have to be generated
        input_shapes: Inputs shapes; when given they take precedence over the ones stored in a ModelSignatureWrapper
        half_precision: If True, inputs are float16 tensors on ``cuda:0``, otherwise float32 tensors on CPU
        batch_size: Batch size for the input shapes

    Returns:
        A ``(inputs, input_shapes)`` pair, or None (after logging a warning) when no input shapes are available.
    """
    shapes = input_shapes
    if shapes is None and isinstance(model, ModelSignatureWrapper):
        shapes = model.input_shapes

    if shapes is None:
        log.warning(
            "Input shape is None, can not trace model! Please provide input_shapes in the task export configuration."
        )
        return None

    # TODO: half precision always means float16 here, bfloat16 is not supported
    if half_precision:
        target_device, target_dtype = "cuda:0", torch.float16
    else:
        target_device, target_dtype = "cpu", torch.float32

    example_inputs = generate_torch_inputs(
        input_shapes=shapes,
        device=target_device,
        half_precision=bool(half_precision),
        dtype=target_dtype,
        batch_size=batch_size,
    )

    return example_inputs, shapes
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
@torch.inference_mode()
def export_torchscript_model(
    model: nn.Module,
    output_path: str,
    input_shapes: list[Any] | None = None,
    half_precision: bool = False,
    model_name: str = "model.pt",
) -> tuple[str, Any] | None:
    """Export a PyTorch model with TorchScript.

    The model is traced with ``torch.jit.trace`` on randomly generated inputs; if strict tracing raises a
    ``RuntimeError`` tracing is retried with ``strict=False``.

    Args:
        model: PyTorch model to be exported (a ModelSignatureWrapper is unwrapped after its input shapes are read)
        output_path: Path of the folder where the model is saved, created if missing
        input_shapes: Inputs shape for tracing; if None the shapes stored in a ModelSignatureWrapper are used
        half_precision: If True, the model is moved to ``cuda:0`` and exported with half precision, otherwise it is
            moved to CPU
        model_name: Name of the exported model

    Returns:
        If the model is exported successfully, the absolute path to the model and the input shapes are returned,
        otherwise None.
    """
    if isinstance(model, CflowLightning):
        log.warning("Exporting cflow model with torchscript is not supported yet.")
        return None

    model.eval()
    if half_precision:
        model.to("cuda:0")
        model = model.half()
    else:
        model.cpu()

    # Inputs are extracted before unwrapping so that the shapes stored in a ModelSignatureWrapper can be used
    model_inputs = extract_torch_model_inputs(model, input_shapes, half_precision)
    if model_inputs is None:
        return None

    if isinstance(model, ModelSignatureWrapper):
        model = model.instance

    inp, input_shapes = model_inputs

    try:
        try:
            model_jit = torch.jit.trace(model, inp)
        except RuntimeError as e:
            log.warning("Standard tracing failed with exception %s, attempting tracing with strict=False", e)
            model_jit = torch.jit.trace(model, inp, strict=False)

        os.makedirs(output_path, exist_ok=True)

        model_path = os.path.join(output_path, model_name)
        model_jit.save(model_path)

        absolute_model_path = os.path.join(os.getcwd(), model_path)
        log.info("Torchscript model saved to %s", absolute_model_path)

        return absolute_model_path, input_shapes
    except Exception as e:
        # Export is best effort, but the failure must be visible: a debug-level message is hidden by default
        log.warning("Failed to export torchscript model with exception: %s", e)
        return None
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
@torch.inference_mode()
def export_onnx_model(
    model: nn.Module,
    output_path: str,
    onnx_config: DictConfig,
    input_shapes: list[Any] | None = None,
    half_precision: bool = False,
    model_name: str = "model.onnx",
) -> tuple[str, Any] | None:
    """Export a PyTorch model with ONNX.

    Args:
        model: PyTorch model to be exported (a ModelSignatureWrapper is unwrapped after its input shapes are read)
        output_path: Path of the folder where the model is saved, created if missing
        onnx_config: ONNX export configuration. ``simplify`` and ``fixed_batch_size`` are consumed here,
            ``input_names``, ``output_names`` and ``dynamic_axes`` get defaults when missing, every other key is
            forwarded to ``torch.onnx.export``
        input_shapes: Input shapes for tracing; if None the shapes stored in a ModelSignatureWrapper are used
        half_precision: If True, the model is moved to ``cuda:0`` and exported with half precision
        model_name: Name of the exported model

    Returns:
        If the model is exported successfully, the absolute path to the model (the simplified one when simplification
        is requested and succeeds) and the input shapes, otherwise None.

    Raises:
        ValueError: If the model inputs are described by a dict.
        RuntimeError: If the half precision model contains NaN values and the mixed precision fallback fails.
    """
    if not ONNX_AVAILABLE:
        log.warning("ONNX is not installed, can not export model in this format.")
        log.warning("Please install ONNX capabilities for quadra with: poetry install -E onnx")
        return None

    model.eval()
    if half_precision:
        model.to("cuda:0")
        model = model.half()
    else:
        model.cpu()

    fixed_batch_size = getattr(onnx_config, "fixed_batch_size", None)
    batch_size = fixed_batch_size if fixed_batch_size is not None else 1

    model_inputs = extract_torch_model_inputs(
        model=model, input_shapes=input_shapes, half_precision=half_precision, batch_size=batch_size
    )
    if model_inputs is None:
        return None

    if isinstance(model, ModelSignatureWrapper):
        model = model.instance

    inp, input_shapes = model_inputs

    # Reject dict inputs before calling the model: ``model(*inp)`` would otherwise receive the dict keys
    if isinstance(inp, dict):
        raise ValueError("ONNX export does not support model with dict inputs")

    if isinstance(inp, torch.Tensor):
        # A flat shape list produces a single tensor: wrap it so that it is treated as one input and not iterated,
        # unpacked or indexed along its batch dimension below
        inp = [inp]

    os.makedirs(output_path, exist_ok=True)

    model_path = os.path.join(output_path, model_name)

    input_names = getattr(onnx_config, "input_names", None)
    if input_names is None:
        input_names = [f"input_{i}" for i in range(len(inp))]

    # NOTE(review): the model output is wrapped in a one element list, so a single default output name is generated
    # even for models returning several tensors; provide output_names in the config for multi-output models.
    output = [model(*inp)]
    output_names = getattr(onnx_config, "output_names", None)
    if output_names is None:
        output_names = [f"output_{i}" for i in range(len(output))]

    if fixed_batch_size is not None:
        # A fixed batch size means every axis of the exported graph is static
        dynamic_axes = None
    else:
        dynamic_axes = getattr(onnx_config, "dynamic_axes", None)
        if dynamic_axes is None:
            dynamic_axes = {name: {0: "batch_size"} for name in [*input_names, *output_names]}

    modified_onnx_config = cast(dict[str, Any], OmegaConf.to_container(onnx_config, resolve=True))

    modified_onnx_config["input_names"] = input_names
    modified_onnx_config["output_names"] = output_names
    modified_onnx_config["dynamic_axes"] = dynamic_axes

    # These keys are quadra specific and must not be forwarded to torch.onnx.export
    simplify = modified_onnx_config.pop("simplify", False)
    modified_onnx_config.pop("fixed_batch_size", None)

    if len(inp) == 1:
        inp = inp[0]

    if isinstance(inp, list):
        inp = tuple(inp)  # onnx doesn't like lists representing tuples of inputs

    if isinstance(inp, dict):
        raise ValueError("ONNX export does not support model with dict inputs")

    try:
        torch.onnx.export(model=model, args=inp, f=model_path, **modified_onnx_config)

        onnx_model = onnx.load(model_path)
        # Check if ONNX model is valid
        onnx.checker.check_model(onnx_model)
    except Exception as e:
        # Export is best effort, but the failure must be visible: a debug-level message is hidden by default
        log.warning("ONNX export failed with error: %s", e)
        return None

    log.info("ONNX model saved to %s", os.path.join(os.getcwd(), model_path))

    if half_precision:
        is_export_ok = _safe_export_half_precision_onnx(
            model=model,
            export_model_path=model_path,
            inp=inp,
            onnx_config=onnx_config,
            input_shapes=input_shapes,
            input_names=input_names,
        )

        if not is_export_ok:
            return None

    if simplify:
        log.info("Attempting to simplify ONNX model")
        onnx_model = onnx.load(model_path)

        try:
            simplified_model, check = onnx_simplify(onnx_model)
        except Exception as e:
            log.debug("ONNX simplification failed with error: %s", e)
            check = False

        if not check:
            log.warning("Something failed during model simplification, only original ONNX model will be exported")
        else:
            model_filename, model_extension = os.path.splitext(model_name)
            model_name = f"{model_filename}_simplified{model_extension}"
            model_path = os.path.join(output_path, model_name)
            onnx.save(simplified_model, model_path)
            log.info("Simplified ONNX model saved to %s", os.path.join(os.getcwd(), model_path))

    return os.path.join(os.getcwd(), model_path), input_shapes
|
|
318
|
+
|
|
319
|
+
|
|
320
|
+
def _safe_export_half_precision_onnx(
    model: nn.Module,
    export_model_path: str,
    inp: torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor, ...],
    onnx_config: DictConfig,
    input_shapes: list[Any],
    input_names: list[str],
) -> bool:
    """Check that the exported half precision ONNX model does not contain NaN values. If it does, attempt to export
    the model with a more stable export and overwrite the original model.

    The fallback re-exports the model in full precision at ``export_model_path``, with ONNX simplification disabled.
    It then converts the result with ``auto_convert_mixed_precision`` and saves it back to the same path.

    Args:
        model: PyTorch model to be exported
        export_model_path: Path of the already exported half precision ONNX model, overwritten by the fallback
        inp: Input tensors for the model
        onnx_config: ONNX export configuration
        input_shapes: Input shapes for the model
        input_names: Input names for the model

    Returns:
        True if the model is stable or it was possible to export a more stable model, False otherwise.

    Raises:
        RuntimeError: If the mixed precision fallback fails.
    """
    test_fp_16_model: BaseEvaluationModel = import_deployment_model(
        export_model_path, OmegaConf.create({"onnx": {}}), "cuda:0"
    )
    if not isinstance(inp, Sequence):
        inp = [inp]

    test_output = test_fp_16_model(*inp)

    if not isinstance(test_output, Sequence):
        test_output = [test_output]

    # Check if there are nan values in any of the outputs
    is_broken_model = any(torch.isnan(out).any() for out in test_output)

    if not is_broken_model:
        log.info("Exported half precision ONNX model does not contain NaN values, model is stable")
        return True

    try:
        log.warning(
            "The exported half precision ONNX model contains NaN values, attempting with a more stable export..."
        )
        # Cast back the fp16 model to fp32 to simulate the export with fp32
        model = model.float()

        # Simplification must be disabled for the nested fp32 export. Otherwise it returns the path of a separate
        # "_simplified" file, the mixed precision model is saved there, and the caller later overwrites that file
        # with a simplified fp32 model (the caller runs its own simplification on export_model_path afterwards).
        fp32_onnx_config = OmegaConf.create(OmegaConf.to_container(onnx_config, resolve=True))
        fp32_onnx_config.simplify = False

        log.info("Starting to export model in full precision")
        export_output = export_onnx_model(
            model=model,
            output_path=os.path.dirname(export_model_path),
            onnx_config=fp32_onnx_config,
            input_shapes=input_shapes,
            half_precision=False,
            model_name=os.path.basename(export_model_path),
        )
        if export_output is None:
            log.warning("Failed to export model")
            return False

        export_model_path, _ = export_output

        model_fp32 = onnx.load(export_model_path)
        test_data = {input_names[i]: inp[i].float().cpu().numpy() for i in range(len(inp))}
        log.warning("Attempting to convert model in mixed precision, this may take a while...")
        with open(os.devnull, "w") as f, contextlib.redirect_stdout(f):
            # This function prints a lot of information that is not useful for the user
            model_fp16 = auto_convert_mixed_precision(model_fp32, test_data, rtol=0.01, atol=0.001, keep_io_types=False)
        onnx.save(model_fp16, export_model_path)

        onnx_model = onnx.load(export_model_path)
        # Check if ONNX model is valid
        onnx.checker.check_model(onnx_model)
        return True
    except Exception as e:
        raise RuntimeError(
            "Failed to export model with automatic mixed precision, check your model or disable ONNX export"
        ) from e
|
|
399
|
+
|
|
400
|
+
|
|
401
|
+
def export_pytorch_model(model: nn.Module, output_path: str, model_name: str = "model.pth") -> str:
    """Save the parameter dictionary (``state_dict``) of a PyTorch model to disk.

    If ``model`` is a ``ModelSignatureWrapper``, the wrapped ``instance`` is saved instead of the wrapper. The model
    is put in eval mode and moved to CPU before saving; this mutates the model object passed in.

    Args:
        model: PyTorch model to be exported
        output_path: Folder where the model is saved (created if it does not exist)
        model_name: File name of the exported model

    Returns:
        The saved file path joined onto the current working directory. When ``output_path`` is already absolute,
        ``os.path.join`` returns it unchanged.
    """
    # Save the real network, not the signature-tracking wrapper around it.
    target = model.instance if isinstance(model, ModelSignatureWrapper) else model

    os.makedirs(output_path, exist_ok=True)
    target.eval()
    target.cpu()

    saved_path = os.path.join(output_path, model_name)
    torch.save(target.state_dict(), saved_path)
    log.info("Pytorch model saved to %s", saved_path)

    return os.path.join(os.getcwd(), saved_path)
|
|
424
|
+
|
|
425
|
+
|
|
426
|
+
def export_model(
    config: DictConfig,
    model: Any,
    export_folder: str,
    half_precision: bool,
    input_shapes: list[Any] | None = None,
    idx_to_class: dict[int, str] | None = None,
    pytorch_model_type: Literal["backbone", "model"] = "model",
) -> tuple[dict[str, Any], dict[str, str]]:
    """Export ``model`` in every format listed in ``config.export.types``.

    Handled export types are ``"torchscript"``, ``"pytorch"`` and ``"onnx"``. Any other type is logged and skipped.
    Exporters run in the order given by the config. The input shapes returned by a successful torchscript or onnx
    export replace ``input_shapes`` for the exporters that follow and for the returned model info.

    Args:
        config: Experiment config. ``config.export`` drives the export and ``config.transforms`` supplies mean/std.
        model: Model to be exported
        export_folder: Folder where the exported models are written (created if it does not exist)
        half_precision: Whether the torchscript/onnx exports use half precision
        input_shapes: Input shapes for the exported model. When None, ``config.export.input_shapes`` is used.
        idx_to_class: Mapping from class index to class name, stored as-is in the returned model info
        pytorch_model_type: Name of the config node (``config.backbone`` or ``config.model``) that is saved as
            ``model_config.yaml`` next to a ``"pytorch"`` export

    Returns:
        A tuple ``(model_json, export_paths)``.
        ``export_paths`` maps each successful export type to the path of the exported file.
        ``model_json`` contains ``input_size``, ``classes``, ``mean`` and ``std``. It is an empty dict when no export
        type is configured or when none succeeded; ``export_paths`` is empty in both of those cases too.
    """
    export_cfg = config.export
    if export_cfg is None or len(export_cfg.types) == 0:
        log.info("No export type specified skipping export")
        return {}, {}

    os.makedirs(export_folder, exist_ok=True)

    if input_shapes is None:
        # Fall back to the shapes declared in the config. These may still be None; presumably the individual
        # exporters then infer them (e.g. from a ModelSignatureWrapper) -- TODO confirm in the exporters.
        input_shapes = export_cfg.input_shapes

    export_paths: dict[str, str] = {}

    for export_type in export_cfg.types:
        if export_type == "pytorch":
            export_paths[export_type] = export_pytorch_model(model=model, output_path=export_folder)
            # Save the model config alongside the weights so the architecture can be rebuilt at load time.
            with open(os.path.join(export_folder, "model_config.yaml"), "w") as cfg_file:
                OmegaConf.save(getattr(config, pytorch_model_type), cfg_file, resolve=True)
            continue

        if export_type == "torchscript":
            result = export_torchscript_model(
                model=model,
                input_shapes=input_shapes,
                output_path=export_folder,
                half_precision=half_precision,
            )
            failure_message = "Torchscript export failed, enable debug logging for more details"
        elif export_type == "onnx":
            if not hasattr(export_cfg, "onnx"):
                log.warning("No onnx configuration found, skipping onnx export")
                continue

            result = export_onnx_model(
                model=model,
                output_path=export_folder,
                onnx_config=export_cfg.onnx,
                input_shapes=input_shapes,
                half_precision=half_precision,
            )
            failure_message = "ONNX export failed, enable debug logging for more details"
        else:
            log.warning("Export type: %s not implemented", export_type)
            continue

        if result is None:
            log.warning(failure_message)
            continue

        # Successful exporters report the shapes they actually used; carry them forward.
        new_path, input_shapes = result
        export_paths[export_type] = new_path

    if not export_paths:
        log.warning("No export type was successful, no model will be available for deployment")
        return {}, export_paths

    model_json = {
        "input_size": input_shapes,
        "classes": idx_to_class,
        "mean": list(config.transforms.mean),
        "std": list(config.transforms.std),
    }

    return model_json, export_paths
|
|
521
|
+
|
|
522
|
+
|
|
523
|
+
def import_deployment_model(
    model_path: str,
    inference_config: DictConfig,
    device: str,
    model_architecture: nn.Module | None = None,
) -> BaseEvaluationModel:
    """Load an exported model and wrap it in the matching evaluation model.

    The wrapper is chosen from the file extension of ``model_path``:

    - ``.pt``: TorchScript model, configured with ``inference_config.torchscript``
    - ``.pth``: plain PyTorch state dict, configured with ``inference_config.pytorch``; needs ``model_architecture``
    - ``.onnx``: ONNX model, configured with ``inference_config.onnx``

    The extension comparison is case-sensitive.

    Args:
        model_path: Path to the model
        inference_config: Inference configuration. It must contain a key for the deployment model type being loaded
            (``torchscript``, ``pytorch`` or ``onnx``).
        device: Device to load the model on
        model_architecture: Model architecture used to load a plain pytorch ``.pth`` state dict

    Returns:
        The evaluation model, with its weights loaded on ``device``.

    Raises:
        ValueError: If a ``.pth`` file is given without ``model_architecture``, or if the file extension is not
            supported.
    """
    log.info("Importing trained model")

    file_extension = os.path.splitext(os.path.basename(model_path))[1]
    deployment_model: BaseEvaluationModel | None = None

    if file_extension == ".pt":
        deployment_model = TorchscriptEvaluationModel(config=inference_config.torchscript)
    elif file_extension == ".pth":
        # A state dict only stores weights, so the network structure must be supplied by the caller.
        if model_architecture is None:
            raise ValueError("model_architecture must be specified when loading a .pth file")

        deployment_model = TorchEvaluationModel(config=inference_config.pytorch, model_architecture=model_architecture)
    elif file_extension == ".onnx":
        deployment_model = ONNXEvaluationModel(config=inference_config.onnx)

    if deployment_model is not None:
        deployment_model.load_from_disk(model_path=model_path, device=device)

        log.info("Imported %s model", deployment_model.__class__.__name__)

        return deployment_model

    # Keep this list in sync with the branches above.
    raise ValueError(
        f"Unable to load model with extension {file_extension}, valid extensions are: ['.pt', '.pth', '.onnx']"
    )
|
|
564
|
+
|
|
565
|
+
|
|
566
|
+
def get_export_extension(export_type: str) -> str:
    """Map an export type to the file extension (without the leading dot) of its exported model.

    Args:
        export_type: The type of the exported model: ``"onnx"``, ``"torchscript"`` or ``"pytorch"``.

    Returns:
        ``"onnx"`` for ONNX, ``"pt"`` for TorchScript and ``"pth"`` for plain PyTorch state dicts.

    Raises:
        ValueError: If ``export_type`` is not one of the supported types.
    """
    # Lookup table of (export type, extension) pairs. Entries are matched with ``==`` rather than through a dict so
    # that any input, hashable or not, behaves exactly like a plain equality chain.
    for known_type, extension in (("onnx", "onnx"), ("torchscript", "pt"), ("pytorch", "pth")):
        if export_type == known_type:
            return extension

    raise ValueError(f"Unsupported export type {export_type}")
|
quadra/utils/imaging.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import cv2
|
|
4
|
+
import numpy as np
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def crop_image(image: np.ndarray, roi: tuple[int, int, int, int]) -> np.ndarray:
    """Return the region of ``image`` delimited by ``roi``.

    Args:
        image: array of size HxW or HxWxC
        roi: ``(x_min, y_min, x_max, y_max)``. This is the column and row of the upper-left corner, followed by the
            column and row of the bottom-right corner. The max bounds are exclusive, as in Python slicing.

    Returns:
        ``image`` restricted to rows ``y_min:y_max`` and columns ``x_min:x_max``. For a NumPy array this is a view
        produced by basic slicing, not a copy.
    """
    left, top, right, bottom = roi[0], roi[1], roi[2], roi[3]
    return image[top:bottom, left:right]
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def keep_aspect_ratio_resize(image: np.ndarray, size: int = 224, interpolation: int = 1) -> np.ndarray:
    """Resize ``image`` so that its shorter side equals ``size``, preserving the aspect ratio.

    The longer side is scaled by the same factor and truncated to an integer. Square images become ``size`` x
    ``size``.

    Args:
        image: Image array whose first two dimensions are height and width.
        size: Target length of the shorter side.
        interpolation: OpenCV interpolation flag passed to ``cv2.resize`` (1 is ``cv2.INTER_LINEAR``).

    Returns:
        The resized image.
    """
    rows, cols = image.shape[:2]

    if rows < cols:
        # Landscape: the height is the shorter side.
        target_w, target_h = int(cols * size / rows), size
    else:
        # Portrait or square: the width is the shorter (or equal) side.
        target_w, target_h = size, int(rows * size / cols)

    # cv2.resize expects the destination size as (width, height).
    return cv2.resize(image, (target_w, target_h), interpolation=interpolation)
|
quadra/utils/logger.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
|
|
3
|
+
from pytorch_lightning.utilities import rank_zero_only
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def get_logger(name=__name__) -> logging.Logger:
    """Return the named python logger with its logging methods restricted to the rank-zero process.

    Every level method is wrapped with ``rank_zero_only``. In a multi-GPU setup this emits each message once,
    instead of once per process.

    NOTE: ``logging.getLogger`` returns the same object for the same name, so calling this again for a name wraps the
    already-wrapped methods a second time.
    """
    rank_zero_logger = logging.getLogger(name)

    level_methods = ("debug", "info", "warning", "error", "exception", "fatal", "critical")
    for method_name in level_methods:
        wrapped = rank_zero_only(getattr(rank_zero_logger, method_name))
        setattr(rank_zero_logger, method_name, wrapped)

    return rank_zero_logger
|