quadra 0.0.1__py3-none-any.whl → 2.1.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hydra_plugins/quadra_searchpath_plugin.py +30 -0
- quadra/__init__.py +6 -0
- quadra/callbacks/__init__.py +0 -0
- quadra/callbacks/anomalib.py +289 -0
- quadra/callbacks/lightning.py +501 -0
- quadra/callbacks/mlflow.py +291 -0
- quadra/callbacks/scheduler.py +69 -0
- quadra/configs/__init__.py +0 -0
- quadra/configs/backbone/caformer_m36.yaml +8 -0
- quadra/configs/backbone/caformer_s36.yaml +8 -0
- quadra/configs/backbone/convnextv2_base.yaml +8 -0
- quadra/configs/backbone/convnextv2_femto.yaml +8 -0
- quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
- quadra/configs/backbone/dino_vitb8.yaml +12 -0
- quadra/configs/backbone/dino_vits8.yaml +12 -0
- quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
- quadra/configs/backbone/dinov2_vits14.yaml +12 -0
- quadra/configs/backbone/efficientnet_b0.yaml +8 -0
- quadra/configs/backbone/efficientnet_b1.yaml +8 -0
- quadra/configs/backbone/efficientnet_b2.yaml +8 -0
- quadra/configs/backbone/efficientnet_b3.yaml +8 -0
- quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
- quadra/configs/backbone/levit_128s.yaml +8 -0
- quadra/configs/backbone/mnasnet0_5.yaml +9 -0
- quadra/configs/backbone/resnet101.yaml +8 -0
- quadra/configs/backbone/resnet18.yaml +8 -0
- quadra/configs/backbone/resnet18_ssl.yaml +8 -0
- quadra/configs/backbone/resnet50.yaml +8 -0
- quadra/configs/backbone/smp.yaml +9 -0
- quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
- quadra/configs/backbone/unetr.yaml +15 -0
- quadra/configs/backbone/vit16_base.yaml +9 -0
- quadra/configs/backbone/vit16_small.yaml +9 -0
- quadra/configs/backbone/vit16_tiny.yaml +9 -0
- quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
- quadra/configs/callbacks/all.yaml +32 -0
- quadra/configs/callbacks/default.yaml +37 -0
- quadra/configs/callbacks/default_anomalib.yaml +67 -0
- quadra/configs/config.yaml +33 -0
- quadra/configs/core/default.yaml +11 -0
- quadra/configs/datamodule/base/anomaly.yaml +16 -0
- quadra/configs/datamodule/base/classification.yaml +21 -0
- quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
- quadra/configs/datamodule/base/segmentation.yaml +18 -0
- quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
- quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
- quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
- quadra/configs/datamodule/base/ssl.yaml +21 -0
- quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
- quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
- quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
- quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
- quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
- quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
- quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
- quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
- quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
- quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
- quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/base/classification/classification.yaml +73 -0
- quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
- quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
- quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
- quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
- quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
- quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
- quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
- quadra/configs/experiment/base/ssl/byol.yaml +43 -0
- quadra/configs/experiment/base/ssl/dino.yaml +46 -0
- quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
- quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
- quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
- quadra/configs/experiment/custom/cls.yaml +12 -0
- quadra/configs/experiment/default.yaml +15 -0
- quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
- quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
- quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
- quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
- quadra/configs/export/default.yaml +13 -0
- quadra/configs/hydra/anomaly_custom.yaml +15 -0
- quadra/configs/hydra/default.yaml +14 -0
- quadra/configs/inference/default.yaml +26 -0
- quadra/configs/logger/comet.yaml +10 -0
- quadra/configs/logger/csv.yaml +5 -0
- quadra/configs/logger/mlflow.yaml +12 -0
- quadra/configs/logger/tensorboard.yaml +8 -0
- quadra/configs/loss/asl.yaml +7 -0
- quadra/configs/loss/barlow.yaml +2 -0
- quadra/configs/loss/bce.yaml +1 -0
- quadra/configs/loss/byol.yaml +1 -0
- quadra/configs/loss/cross_entropy.yaml +1 -0
- quadra/configs/loss/dino.yaml +8 -0
- quadra/configs/loss/simclr.yaml +2 -0
- quadra/configs/loss/simsiam.yaml +1 -0
- quadra/configs/loss/smp_ce.yaml +3 -0
- quadra/configs/loss/smp_dice.yaml +2 -0
- quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
- quadra/configs/loss/smp_mcc.yaml +2 -0
- quadra/configs/loss/vicreg.yaml +5 -0
- quadra/configs/model/anomalib/cfa.yaml +35 -0
- quadra/configs/model/anomalib/cflow.yaml +30 -0
- quadra/configs/model/anomalib/csflow.yaml +34 -0
- quadra/configs/model/anomalib/dfm.yaml +19 -0
- quadra/configs/model/anomalib/draem.yaml +29 -0
- quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
- quadra/configs/model/anomalib/fastflow.yaml +32 -0
- quadra/configs/model/anomalib/padim.yaml +32 -0
- quadra/configs/model/anomalib/patchcore.yaml +36 -0
- quadra/configs/model/barlow.yaml +16 -0
- quadra/configs/model/byol.yaml +25 -0
- quadra/configs/model/classification.yaml +10 -0
- quadra/configs/model/dino.yaml +26 -0
- quadra/configs/model/logistic_regression.yaml +4 -0
- quadra/configs/model/multilabel_classification.yaml +9 -0
- quadra/configs/model/simclr.yaml +18 -0
- quadra/configs/model/simsiam.yaml +24 -0
- quadra/configs/model/smp.yaml +4 -0
- quadra/configs/model/smp_multiclass.yaml +4 -0
- quadra/configs/model/vicreg.yaml +16 -0
- quadra/configs/optimizer/adam.yaml +5 -0
- quadra/configs/optimizer/adamw.yaml +3 -0
- quadra/configs/optimizer/default.yaml +4 -0
- quadra/configs/optimizer/lars.yaml +8 -0
- quadra/configs/optimizer/sgd.yaml +4 -0
- quadra/configs/scheduler/default.yaml +5 -0
- quadra/configs/scheduler/rop.yaml +5 -0
- quadra/configs/scheduler/step.yaml +3 -0
- quadra/configs/scheduler/warmrestart.yaml +2 -0
- quadra/configs/scheduler/warmup.yaml +6 -0
- quadra/configs/task/anomalib/cfa.yaml +5 -0
- quadra/configs/task/anomalib/cflow.yaml +5 -0
- quadra/configs/task/anomalib/csflow.yaml +5 -0
- quadra/configs/task/anomalib/draem.yaml +5 -0
- quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
- quadra/configs/task/anomalib/fastflow.yaml +5 -0
- quadra/configs/task/anomalib/inference.yaml +3 -0
- quadra/configs/task/anomalib/padim.yaml +5 -0
- quadra/configs/task/anomalib/patchcore.yaml +5 -0
- quadra/configs/task/classification.yaml +6 -0
- quadra/configs/task/classification_evaluation.yaml +6 -0
- quadra/configs/task/default.yaml +1 -0
- quadra/configs/task/segmentation.yaml +9 -0
- quadra/configs/task/segmentation_evaluation.yaml +3 -0
- quadra/configs/task/sklearn_classification.yaml +13 -0
- quadra/configs/task/sklearn_classification_patch.yaml +11 -0
- quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
- quadra/configs/task/sklearn_classification_test.yaml +8 -0
- quadra/configs/task/ssl.yaml +2 -0
- quadra/configs/trainer/lightning_cpu.yaml +36 -0
- quadra/configs/trainer/lightning_gpu.yaml +35 -0
- quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
- quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
- quadra/configs/trainer/lightning_multigpu.yaml +37 -0
- quadra/configs/trainer/sklearn_classification.yaml +7 -0
- quadra/configs/transforms/byol.yaml +47 -0
- quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
- quadra/configs/transforms/default.yaml +37 -0
- quadra/configs/transforms/default_numpy.yaml +24 -0
- quadra/configs/transforms/default_resize.yaml +22 -0
- quadra/configs/transforms/dino.yaml +63 -0
- quadra/configs/transforms/linear_eval.yaml +18 -0
- quadra/datamodules/__init__.py +20 -0
- quadra/datamodules/anomaly.py +180 -0
- quadra/datamodules/base.py +375 -0
- quadra/datamodules/classification.py +1003 -0
- quadra/datamodules/generic/__init__.py +0 -0
- quadra/datamodules/generic/imagenette.py +144 -0
- quadra/datamodules/generic/mnist.py +81 -0
- quadra/datamodules/generic/mvtec.py +58 -0
- quadra/datamodules/generic/oxford_pet.py +163 -0
- quadra/datamodules/patch.py +190 -0
- quadra/datamodules/segmentation.py +742 -0
- quadra/datamodules/ssl.py +140 -0
- quadra/datasets/__init__.py +17 -0
- quadra/datasets/anomaly.py +287 -0
- quadra/datasets/classification.py +241 -0
- quadra/datasets/patch.py +138 -0
- quadra/datasets/segmentation.py +239 -0
- quadra/datasets/ssl.py +110 -0
- quadra/losses/__init__.py +0 -0
- quadra/losses/classification/__init__.py +6 -0
- quadra/losses/classification/asl.py +83 -0
- quadra/losses/classification/focal.py +320 -0
- quadra/losses/classification/prototypical.py +148 -0
- quadra/losses/ssl/__init__.py +17 -0
- quadra/losses/ssl/barlowtwins.py +47 -0
- quadra/losses/ssl/byol.py +37 -0
- quadra/losses/ssl/dino.py +129 -0
- quadra/losses/ssl/hyperspherical.py +45 -0
- quadra/losses/ssl/idmm.py +50 -0
- quadra/losses/ssl/simclr.py +67 -0
- quadra/losses/ssl/simsiam.py +30 -0
- quadra/losses/ssl/vicreg.py +76 -0
- quadra/main.py +46 -0
- quadra/metrics/__init__.py +3 -0
- quadra/metrics/segmentation.py +251 -0
- quadra/models/__init__.py +0 -0
- quadra/models/base.py +151 -0
- quadra/models/classification/__init__.py +8 -0
- quadra/models/classification/backbones.py +149 -0
- quadra/models/classification/base.py +92 -0
- quadra/models/evaluation.py +322 -0
- quadra/modules/__init__.py +0 -0
- quadra/modules/backbone.py +30 -0
- quadra/modules/base.py +312 -0
- quadra/modules/classification/__init__.py +3 -0
- quadra/modules/classification/base.py +331 -0
- quadra/modules/ssl/__init__.py +17 -0
- quadra/modules/ssl/barlowtwins.py +59 -0
- quadra/modules/ssl/byol.py +172 -0
- quadra/modules/ssl/common.py +285 -0
- quadra/modules/ssl/dino.py +186 -0
- quadra/modules/ssl/hyperspherical.py +206 -0
- quadra/modules/ssl/idmm.py +98 -0
- quadra/modules/ssl/simclr.py +73 -0
- quadra/modules/ssl/simsiam.py +68 -0
- quadra/modules/ssl/vicreg.py +67 -0
- quadra/optimizers/__init__.py +4 -0
- quadra/optimizers/lars.py +153 -0
- quadra/optimizers/sam.py +127 -0
- quadra/schedulers/__init__.py +3 -0
- quadra/schedulers/base.py +44 -0
- quadra/schedulers/warmup.py +127 -0
- quadra/tasks/__init__.py +24 -0
- quadra/tasks/anomaly.py +582 -0
- quadra/tasks/base.py +397 -0
- quadra/tasks/classification.py +1264 -0
- quadra/tasks/patch.py +492 -0
- quadra/tasks/segmentation.py +389 -0
- quadra/tasks/ssl.py +560 -0
- quadra/trainers/README.md +3 -0
- quadra/trainers/__init__.py +0 -0
- quadra/trainers/classification.py +179 -0
- quadra/utils/__init__.py +0 -0
- quadra/utils/anomaly.py +112 -0
- quadra/utils/classification.py +618 -0
- quadra/utils/deprecation.py +31 -0
- quadra/utils/evaluation.py +474 -0
- quadra/utils/export.py +579 -0
- quadra/utils/imaging.py +32 -0
- quadra/utils/logger.py +15 -0
- quadra/utils/mlflow.py +98 -0
- quadra/utils/model_manager.py +320 -0
- quadra/utils/models.py +524 -0
- quadra/utils/patch/__init__.py +15 -0
- quadra/utils/patch/dataset.py +1433 -0
- quadra/utils/patch/metrics.py +449 -0
- quadra/utils/patch/model.py +153 -0
- quadra/utils/patch/visualization.py +217 -0
- quadra/utils/resolver.py +42 -0
- quadra/utils/segmentation.py +31 -0
- quadra/utils/tests/__init__.py +0 -0
- quadra/utils/tests/fixtures/__init__.py +1 -0
- quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
- quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
- quadra/utils/tests/fixtures/dataset/classification.py +406 -0
- quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
- quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
- quadra/utils/tests/fixtures/models/__init__.py +3 -0
- quadra/utils/tests/fixtures/models/anomaly.py +89 -0
- quadra/utils/tests/fixtures/models/classification.py +45 -0
- quadra/utils/tests/fixtures/models/segmentation.py +33 -0
- quadra/utils/tests/helpers.py +70 -0
- quadra/utils/tests/models.py +27 -0
- quadra/utils/utils.py +525 -0
- quadra/utils/validator.py +115 -0
- quadra/utils/visualization.py +422 -0
- quadra/utils/vit_explainability.py +349 -0
- quadra-2.1.13.dist-info/LICENSE +201 -0
- quadra-2.1.13.dist-info/METADATA +386 -0
- quadra-2.1.13.dist-info/RECORD +300 -0
- {quadra-0.0.1.dist-info → quadra-2.1.13.dist-info}/WHEEL +1 -1
- quadra-2.1.13.dist-info/entry_points.txt +3 -0
- quadra-0.0.1.dist-info/METADATA +0 -14
- quadra-0.0.1.dist-info/RECORD +0 -4
quadra/tasks/base.py
ADDED
|
@@ -0,0 +1,397 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import os
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any, Generic, TypeVar
|
|
7
|
+
|
|
8
|
+
import hydra
|
|
9
|
+
import torch
|
|
10
|
+
from hydra.core.hydra_config import HydraConfig
|
|
11
|
+
from lightning_fabric.utilities.device_parser import _parse_gpu_ids
|
|
12
|
+
from omegaconf import DictConfig, OmegaConf, open_dict
|
|
13
|
+
from pytorch_lightning import Callback, LightningModule, Trainer
|
|
14
|
+
from pytorch_lightning.loggers import Logger, MLFlowLogger
|
|
15
|
+
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
|
16
|
+
|
|
17
|
+
from quadra import get_version
|
|
18
|
+
from quadra.callbacks.mlflow import validate_artifact_storage
|
|
19
|
+
from quadra.datamodules.base import BaseDataModule
|
|
20
|
+
from quadra.models.evaluation import BaseEvaluationModel
|
|
21
|
+
from quadra.utils import utils
|
|
22
|
+
from quadra.utils.export import import_deployment_model
|
|
23
|
+
|
|
24
|
+
# Module-level logger shared by every task defined in this file.
log = utils.get_logger(__name__)
# Type variable constraining the datamodule handled by a Task to quadra's BaseDataModule.
DataModuleT = TypeVar("DataModuleT", bound=BaseDataModule)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class Task(Generic[DataModuleT]):
    """Base Experiment Task.

    Defines the generic experiment lifecycle (prepare -> train -> test ->
    export -> generate_report -> finalize) that concrete tasks override.

    Args:
        config: The experiment configuration.
    """

    def __init__(self, config: DictConfig):
        self.config = config
        # Folder (relative to the run directory) where exported models are written.
        self.export_folder: str = "deployment_model"
        self._datamodule: DataModuleT
        self.metadata: dict[str, Any]
        self.save_config()

    def save_config(self) -> None:
        """Save the experiment configuration when running an Hydra experiment.

        Outside of a Hydra run (e.g. unit tests) this is a no-op.
        """
        if HydraConfig.initialized():
            with open("config_resolved.yaml", "w") as fp:
                # Resolve interpolations so the saved file is self-contained, and
                # write through the already-open handle instead of reopening the
                # file by name (the original passed fp.name, reopening the file
                # a second time while it was already held open for writing).
                OmegaConf.save(config=OmegaConf.to_container(self.config, resolve=True), f=fp)

    def prepare(self) -> None:
        """Prepare the experiment by instantiating the datamodule from its config."""
        self.datamodule = self.config.datamodule

    @property
    def datamodule(self) -> DataModuleT:
        """T_DATAMODULE: The datamodule."""
        return self._datamodule

    @datamodule.setter
    def datamodule(self, datamodule_config: DictConfig) -> None:
        """DataModuleT: The datamodule. Instantiated from the datamodule config."""
        # Log the target class path itself (the original wrapped it in {...},
        # accidentally logging a one-element set literal).
        log.info("Instantiating datamodule <%s>", datamodule_config["_target_"])
        datamodule: DataModuleT = hydra.utils.instantiate(datamodule_config)
        self._datamodule = datamodule

    def train(self) -> Any:
        """Train the model."""
        log.info("Training not implemented for this task!")

    def test(self) -> Any:
        """Test the model."""
        log.info("Testing not implemented for this task!")

    def export(self) -> None:
        """Export model for production."""
        log.info("Export model for production not implemented for this task!")

    def generate_report(self) -> None:
        """Generate a report."""
        log.info("Report generation not implemented for this task!")

    def finalize(self) -> None:
        """Finalize the experiment."""
        log.info("Results are saved in %s", os.getcwd())

    def execute(self) -> None:
        """Execute the experiment and all the steps."""
        self.prepare()
        self.train()
        self.test()
        # Export only when at least one export type is configured.
        if self.config.export is not None and len(self.config.export.types) > 0:
            self.export()
        self.generate_report()
        self.finalize()
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class LightningTask(Generic[DataModuleT], Task[DataModuleT]):
    """Base Experiment Task.

    Args:
        config: The experiment configuration
        checkpoint_path: The path to the checkpoint to load the model from. Defaults to None.
        run_test: Whether to run the test after training. Defaults to False.
        report: Whether to generate a report. Defaults to False.
    """

    def __init__(
        self,
        config: DictConfig,
        checkpoint_path: str | None = None,
        run_test: bool = False,
        report: bool = False,
    ):
        super().__init__(config=config)
        self.checkpoint_path = checkpoint_path
        self.run_test = run_test
        self.report = report
        self._module: LightningModule
        self._devices: int | list[int]
        # Initialize to empty lists instead of bare annotations: the setters
        # below return early (unit-test mode / missing config section) without
        # assigning these, and the trainer setter then reads them — with bare
        # annotations that raised AttributeError.
        self._callbacks: list[Callback] = []
        self._logger: list[Logger] = []
        self._trainer: Trainer

    def prepare(self) -> None:
        """Prepare the experiment."""
        super().prepare()

        # First setup loggers since some callbacks might need logger setup correctly.
        if "logger" in self.config:
            self.logger = self.config.logger

        if "callbacks" in self.config:
            self.callbacks = self.config.callbacks

        self.devices = self.config.trainer.devices
        self.trainer = self.config.trainer

    @property
    def module(self) -> LightningModule:
        """LightningModule: The model."""
        return self._module

    @module.setter
    def module(self, module_config) -> None:
        """LightningModule: The model."""
        raise NotImplementedError("module must be set in subclass")

    @property
    def trainer(self) -> Trainer:
        """Trainer: The trainer."""
        return self._trainer

    @trainer.setter
    def trainer(self, trainer_config: DictConfig) -> None:
        """Trainer: The trainer. Instantiated from the trainer config."""
        log.info("Instantiating trainer <%s>", trainer_config["_target_"])
        # Propagate the (possibly CPU-fallback adjusted) devices into the config.
        trainer_config.devices = self.devices
        trainer: Trainer = hydra.utils.instantiate(
            trainer_config,
            callbacks=self.callbacks,
            logger=self.logger,
            _convert_="partial",
        )
        self._trainer = trainer

    @property
    def callbacks(self) -> list[Callback]:
        """List[Callback]: The callbacks."""
        return self._callbacks

    @callbacks.setter
    def callbacks(self, callbacks_config) -> None:
        """List[Callback]: The callbacks. Instantiated from the callbacks config."""
        if self.config.core.get("unit_test"):
            log.info("Unit Testing, skipping callbacks")
            return
        instantiated_callbacks = []
        for _, cb_conf in callbacks_config.items():
            if "_target_" in cb_conf:
                # Disable is a reserved keyword for callbacks, hopefully no callback will use it
                if "disable" in cb_conf:
                    if cb_conf["disable"]:
                        log.info("Skipping callback <%s> as it is disabled", cb_conf["_target_"])
                        continue

                    with open_dict(cb_conf):
                        del cb_conf.disable

                # Skip the gpu stats logger callback if no gpu is available to avoid errors
                if not torch.cuda.is_available() and cb_conf["_target_"] == "nvitop.callbacks.lightning.GpuStatsLogger":
                    continue

                log.info("Instantiating callback <%s>", cb_conf["_target_"])
                instantiated_callbacks.append(hydra.utils.instantiate(cb_conf))
        self._callbacks = instantiated_callbacks
        if len(instantiated_callbacks) <= 0:
            log.warning("No callback found in configuration.")

    @property
    def logger(self) -> list[Logger]:
        """List[Logger]: The loggers."""
        return self._logger

    @logger.setter
    def logger(self, logger_config) -> None:
        """List[Logger]: The loggers. Instantiated from the logger config."""
        if self.config.core.get("unit_test"):
            log.info("Unit Testing, skipping loggers")
            return
        instantiated_loggers = []
        for _, lg_conf in logger_config.items():
            if "_target_" in lg_conf:
                log.info("Instantiating logger <%s>", lg_conf["_target_"])
                logger = hydra.utils.instantiate(lg_conf)
                if isinstance(logger, MLFlowLogger):
                    # MLflow needs its artifact storage validated before use.
                    validate_artifact_storage(logger)
                instantiated_loggers.append(logger)

        self._logger = instantiated_loggers

        if len(instantiated_loggers) <= 0:
            log.warning("No logger found in configuration.")

    @property
    def devices(self) -> int | list[int]:
        """List[int]: The devices ids."""
        return self._devices

    @devices.setter
    def devices(self, devices) -> None:
        """List[int]: The devices ids."""
        if self.config.trainer.get("accelerator") == "cpu":
            self._devices = self.config.trainer.devices
            return

        try:
            self._devices = _parse_gpu_ids(devices, include_cuda=True)
        except MisconfigurationException:
            # Requested GPUs are unavailable: fall back to a single CPU device.
            self._devices = 1
            self.config.trainer["accelerator"] = "cpu"
            log.warning("Trying to instantiate GPUs but no GPUs are available, training will be done on CPU")

    def train(self) -> None:
        """Train the model."""
        log.info("Starting training!")
        utils.log_hyperparameters(
            config=self.config,
            model=self.module,
            trainer=self.trainer,
        )

        self.trainer.fit(model=self.module, datamodule=self.datamodule)

    def test(self) -> Any:
        """Test the model."""
        log.info("Starting testing!")

        # Prefer the best checkpoint saved by the checkpoint callback, if any.
        best_model = None
        if (
            self.trainer.checkpoint_callback is not None
            and hasattr(self.trainer.checkpoint_callback, "best_model_path")
            and self.trainer.checkpoint_callback.best_model_path is not None
            and len(self.trainer.checkpoint_callback.best_model_path) > 0
        ):
            best_model = self.trainer.checkpoint_callback.best_model_path

        if best_model is None:
            log.warning(
                "No best checkpoint model found, using last weights for test, this might lead to worse results, "
                "consider using a checkpoint callback."
            )

        return self.trainer.test(model=self.module, datamodule=self.datamodule, ckpt_path=best_model)

    def finalize(self) -> None:
        """Finalize the experiment."""
        super().finalize()
        utils.finish(
            config=self.config,
            module=self.module,
            datamodule=self.datamodule,
            trainer=self.trainer,
            callbacks=self.callbacks,
            logger=self.logger,
            export_folder=self.export_folder,
        )

        if (
            not self.config.trainer.get("fast_dev_run")
            and self.trainer.checkpoint_callback is not None
            and hasattr(self.trainer.checkpoint_callback, "best_model_path")
        ):
            log.info("Best model ckpt: %s", self.trainer.checkpoint_callback.best_model_path)

    def add_callback(self, callback: Callback):
        """Add a callback to the trainer.

        Args:
            callback: The callback to add
        """
        if hasattr(self.trainer, "callbacks") and isinstance(self.trainer.callbacks, list):
            self.trainer.callbacks.append(callback)

    def execute(self) -> None:
        """Execute the experiment and all the steps."""
        self.prepare()
        self.train()
        if self.run_test:
            self.test()
        if self.config.export is not None and len(self.config.export.types) > 0:
            self.export()
        if self.report:
            self.generate_report()
        self.finalize()
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
class PlaceholderTask(Task):
    """Placeholder task."""

    def execute(self) -> None:
        """Execute the task and all the steps."""
        version = str(get_version())
        log.info("Running Placeholder Task.")
        log.info("Quadra Version: %s", version)
        log.info("If you are reading this, it means that library is installed correctly!")
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
class Evaluation(Generic[DataModuleT], Task[DataModuleT]):
    """Base Evaluation Task with deployment models.

    Args:
        config: The experiment configuration
        model_path: The model path.
        device: Device to use for evaluation. If None, the device is automatically determined.

    """

    def __init__(
        self,
        config: DictConfig,
        model_path: str,
        device: str | None = None,
    ):
        super().__init__(config=config)

        # Auto-detect the device when the caller does not pin one.
        if device is None:
            self.device = utils.get_device()
        else:
            self.device = device

        self.model_data: dict[str, Any]
        self.model_path = model_path
        self._deployment_model: BaseEvaluationModel
        self.deployment_model_type: str
        # Sidecar file (next to the model) describing the exported model.
        self.model_info_filename = "model.json"
        self.report_path = ""
        self.metadata = {"report_files": []}

    @property
    def deployment_model(self) -> BaseEvaluationModel:
        """Deployment model."""
        return self._deployment_model

    @deployment_model.setter
    def deployment_model(self, model_path: str):
        """Set the deployment model by importing it from the given path."""
        self._deployment_model = import_deployment_model(
            model_path=model_path, device=self.device, inference_config=self.config.inference
        )

    def _align_transform_size(self, dimension: str, model_value: int) -> None:
        """Overwrite ``config.transforms.input_<dimension>`` with the model's value if they differ.

        Args:
            dimension: Either ``"height"`` or ``"width"``.
            model_value: The size the exported model expects for that dimension.
        """
        attr = f"input_{dimension}"
        config_value = getattr(self.config.transforms, attr)
        if model_value != config_value:
            log.warning(
                f"Input {dimension} of the model ({model_value}) is different from the one specified "
                + f"in the config ({config_value}). Fixing the config."
            )
            setattr(self.config.transforms, attr, model_value)

    def prepare(self) -> None:
        """Prepare the evaluation.

        Loads the model info json next to the model file and aligns the
        configured transform sizes with the model's expected input sizes,
        then loads the deployment model.

        Raises:
            ValueError: If the model info file does not contain a json object.
        """
        with open(os.path.join(Path(self.model_path).parent, self.model_info_filename), encoding="utf-8") as f:
            self.model_data = json.load(f)

        if not isinstance(self.model_data, dict):
            raise ValueError("Model info file is not a valid json")

        for input_size in self.model_data["input_size"]:
            if len(input_size) != 3:
                continue

            # Adjust the transform for 2D models (CxHxW)
            # We assume that each input size has the same height and width
            self._align_transform_size("height", input_size[1])
            self._align_transform_size("width", input_size[2])

        # Triggers the deployment_model setter, which loads the model from disk.
        self.deployment_model = self.model_path  # type: ignore[assignment]
|