quadra 0.0.1__py3-none-any.whl → 2.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hydra_plugins/quadra_searchpath_plugin.py +30 -0
- quadra/__init__.py +6 -0
- quadra/callbacks/__init__.py +0 -0
- quadra/callbacks/anomalib.py +289 -0
- quadra/callbacks/lightning.py +501 -0
- quadra/callbacks/mlflow.py +291 -0
- quadra/callbacks/scheduler.py +69 -0
- quadra/configs/__init__.py +0 -0
- quadra/configs/backbone/caformer_m36.yaml +8 -0
- quadra/configs/backbone/caformer_s36.yaml +8 -0
- quadra/configs/backbone/convnextv2_base.yaml +8 -0
- quadra/configs/backbone/convnextv2_femto.yaml +8 -0
- quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
- quadra/configs/backbone/dino_vitb8.yaml +12 -0
- quadra/configs/backbone/dino_vits8.yaml +12 -0
- quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
- quadra/configs/backbone/dinov2_vits14.yaml +12 -0
- quadra/configs/backbone/efficientnet_b0.yaml +8 -0
- quadra/configs/backbone/efficientnet_b1.yaml +8 -0
- quadra/configs/backbone/efficientnet_b2.yaml +8 -0
- quadra/configs/backbone/efficientnet_b3.yaml +8 -0
- quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
- quadra/configs/backbone/levit_128s.yaml +8 -0
- quadra/configs/backbone/mnasnet0_5.yaml +9 -0
- quadra/configs/backbone/resnet101.yaml +8 -0
- quadra/configs/backbone/resnet18.yaml +8 -0
- quadra/configs/backbone/resnet18_ssl.yaml +8 -0
- quadra/configs/backbone/resnet50.yaml +8 -0
- quadra/configs/backbone/smp.yaml +9 -0
- quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
- quadra/configs/backbone/unetr.yaml +15 -0
- quadra/configs/backbone/vit16_base.yaml +9 -0
- quadra/configs/backbone/vit16_small.yaml +9 -0
- quadra/configs/backbone/vit16_tiny.yaml +9 -0
- quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
- quadra/configs/callbacks/all.yaml +45 -0
- quadra/configs/callbacks/default.yaml +34 -0
- quadra/configs/callbacks/default_anomalib.yaml +64 -0
- quadra/configs/config.yaml +33 -0
- quadra/configs/core/default.yaml +11 -0
- quadra/configs/datamodule/base/anomaly.yaml +16 -0
- quadra/configs/datamodule/base/classification.yaml +21 -0
- quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
- quadra/configs/datamodule/base/segmentation.yaml +18 -0
- quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
- quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
- quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
- quadra/configs/datamodule/base/ssl.yaml +21 -0
- quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
- quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
- quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
- quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
- quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
- quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
- quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
- quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
- quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
- quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
- quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/base/classification/classification.yaml +73 -0
- quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
- quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
- quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
- quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
- quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
- quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
- quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
- quadra/configs/experiment/base/ssl/byol.yaml +43 -0
- quadra/configs/experiment/base/ssl/dino.yaml +46 -0
- quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
- quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
- quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
- quadra/configs/experiment/custom/cls.yaml +12 -0
- quadra/configs/experiment/default.yaml +15 -0
- quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
- quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
- quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
- quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
- quadra/configs/export/default.yaml +13 -0
- quadra/configs/hydra/anomaly_custom.yaml +15 -0
- quadra/configs/hydra/default.yaml +14 -0
- quadra/configs/inference/default.yaml +26 -0
- quadra/configs/logger/comet.yaml +10 -0
- quadra/configs/logger/csv.yaml +5 -0
- quadra/configs/logger/mlflow.yaml +12 -0
- quadra/configs/logger/tensorboard.yaml +8 -0
- quadra/configs/loss/asl.yaml +7 -0
- quadra/configs/loss/barlow.yaml +2 -0
- quadra/configs/loss/bce.yaml +1 -0
- quadra/configs/loss/byol.yaml +1 -0
- quadra/configs/loss/cross_entropy.yaml +1 -0
- quadra/configs/loss/dino.yaml +8 -0
- quadra/configs/loss/simclr.yaml +2 -0
- quadra/configs/loss/simsiam.yaml +1 -0
- quadra/configs/loss/smp_ce.yaml +3 -0
- quadra/configs/loss/smp_dice.yaml +2 -0
- quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
- quadra/configs/loss/smp_mcc.yaml +2 -0
- quadra/configs/loss/vicreg.yaml +5 -0
- quadra/configs/model/anomalib/cfa.yaml +35 -0
- quadra/configs/model/anomalib/cflow.yaml +30 -0
- quadra/configs/model/anomalib/csflow.yaml +34 -0
- quadra/configs/model/anomalib/dfm.yaml +19 -0
- quadra/configs/model/anomalib/draem.yaml +29 -0
- quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
- quadra/configs/model/anomalib/fastflow.yaml +32 -0
- quadra/configs/model/anomalib/padim.yaml +32 -0
- quadra/configs/model/anomalib/patchcore.yaml +36 -0
- quadra/configs/model/barlow.yaml +16 -0
- quadra/configs/model/byol.yaml +25 -0
- quadra/configs/model/classification.yaml +10 -0
- quadra/configs/model/dino.yaml +26 -0
- quadra/configs/model/logistic_regression.yaml +4 -0
- quadra/configs/model/multilabel_classification.yaml +9 -0
- quadra/configs/model/simclr.yaml +18 -0
- quadra/configs/model/simsiam.yaml +24 -0
- quadra/configs/model/smp.yaml +4 -0
- quadra/configs/model/smp_multiclass.yaml +4 -0
- quadra/configs/model/vicreg.yaml +16 -0
- quadra/configs/optimizer/adam.yaml +5 -0
- quadra/configs/optimizer/adamw.yaml +3 -0
- quadra/configs/optimizer/default.yaml +4 -0
- quadra/configs/optimizer/lars.yaml +8 -0
- quadra/configs/optimizer/sgd.yaml +4 -0
- quadra/configs/scheduler/default.yaml +5 -0
- quadra/configs/scheduler/rop.yaml +5 -0
- quadra/configs/scheduler/step.yaml +3 -0
- quadra/configs/scheduler/warmrestart.yaml +2 -0
- quadra/configs/scheduler/warmup.yaml +6 -0
- quadra/configs/task/anomalib/cfa.yaml +5 -0
- quadra/configs/task/anomalib/cflow.yaml +5 -0
- quadra/configs/task/anomalib/csflow.yaml +5 -0
- quadra/configs/task/anomalib/draem.yaml +5 -0
- quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
- quadra/configs/task/anomalib/fastflow.yaml +5 -0
- quadra/configs/task/anomalib/inference.yaml +3 -0
- quadra/configs/task/anomalib/padim.yaml +5 -0
- quadra/configs/task/anomalib/patchcore.yaml +5 -0
- quadra/configs/task/classification.yaml +6 -0
- quadra/configs/task/classification_evaluation.yaml +6 -0
- quadra/configs/task/default.yaml +1 -0
- quadra/configs/task/segmentation.yaml +9 -0
- quadra/configs/task/segmentation_evaluation.yaml +3 -0
- quadra/configs/task/sklearn_classification.yaml +13 -0
- quadra/configs/task/sklearn_classification_patch.yaml +11 -0
- quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
- quadra/configs/task/sklearn_classification_test.yaml +8 -0
- quadra/configs/task/ssl.yaml +2 -0
- quadra/configs/trainer/lightning_cpu.yaml +36 -0
- quadra/configs/trainer/lightning_gpu.yaml +35 -0
- quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
- quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
- quadra/configs/trainer/lightning_multigpu.yaml +37 -0
- quadra/configs/trainer/sklearn_classification.yaml +7 -0
- quadra/configs/transforms/byol.yaml +47 -0
- quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
- quadra/configs/transforms/default.yaml +37 -0
- quadra/configs/transforms/default_numpy.yaml +24 -0
- quadra/configs/transforms/default_resize.yaml +22 -0
- quadra/configs/transforms/dino.yaml +63 -0
- quadra/configs/transforms/linear_eval.yaml +18 -0
- quadra/datamodules/__init__.py +20 -0
- quadra/datamodules/anomaly.py +180 -0
- quadra/datamodules/base.py +375 -0
- quadra/datamodules/classification.py +1003 -0
- quadra/datamodules/generic/__init__.py +0 -0
- quadra/datamodules/generic/imagenette.py +144 -0
- quadra/datamodules/generic/mnist.py +81 -0
- quadra/datamodules/generic/mvtec.py +58 -0
- quadra/datamodules/generic/oxford_pet.py +163 -0
- quadra/datamodules/patch.py +190 -0
- quadra/datamodules/segmentation.py +742 -0
- quadra/datamodules/ssl.py +140 -0
- quadra/datasets/__init__.py +17 -0
- quadra/datasets/anomaly.py +287 -0
- quadra/datasets/classification.py +241 -0
- quadra/datasets/patch.py +138 -0
- quadra/datasets/segmentation.py +239 -0
- quadra/datasets/ssl.py +110 -0
- quadra/losses/__init__.py +0 -0
- quadra/losses/classification/__init__.py +6 -0
- quadra/losses/classification/asl.py +83 -0
- quadra/losses/classification/focal.py +320 -0
- quadra/losses/classification/prototypical.py +148 -0
- quadra/losses/ssl/__init__.py +17 -0
- quadra/losses/ssl/barlowtwins.py +47 -0
- quadra/losses/ssl/byol.py +37 -0
- quadra/losses/ssl/dino.py +129 -0
- quadra/losses/ssl/hyperspherical.py +45 -0
- quadra/losses/ssl/idmm.py +50 -0
- quadra/losses/ssl/simclr.py +67 -0
- quadra/losses/ssl/simsiam.py +30 -0
- quadra/losses/ssl/vicreg.py +76 -0
- quadra/main.py +49 -0
- quadra/metrics/__init__.py +3 -0
- quadra/metrics/segmentation.py +251 -0
- quadra/models/__init__.py +0 -0
- quadra/models/base.py +151 -0
- quadra/models/classification/__init__.py +8 -0
- quadra/models/classification/backbones.py +149 -0
- quadra/models/classification/base.py +92 -0
- quadra/models/evaluation.py +322 -0
- quadra/modules/__init__.py +0 -0
- quadra/modules/backbone.py +30 -0
- quadra/modules/base.py +312 -0
- quadra/modules/classification/__init__.py +3 -0
- quadra/modules/classification/base.py +327 -0
- quadra/modules/ssl/__init__.py +17 -0
- quadra/modules/ssl/barlowtwins.py +59 -0
- quadra/modules/ssl/byol.py +172 -0
- quadra/modules/ssl/common.py +285 -0
- quadra/modules/ssl/dino.py +186 -0
- quadra/modules/ssl/hyperspherical.py +206 -0
- quadra/modules/ssl/idmm.py +98 -0
- quadra/modules/ssl/simclr.py +73 -0
- quadra/modules/ssl/simsiam.py +68 -0
- quadra/modules/ssl/vicreg.py +67 -0
- quadra/optimizers/__init__.py +4 -0
- quadra/optimizers/lars.py +153 -0
- quadra/optimizers/sam.py +127 -0
- quadra/schedulers/__init__.py +3 -0
- quadra/schedulers/base.py +44 -0
- quadra/schedulers/warmup.py +127 -0
- quadra/tasks/__init__.py +24 -0
- quadra/tasks/anomaly.py +582 -0
- quadra/tasks/base.py +397 -0
- quadra/tasks/classification.py +1263 -0
- quadra/tasks/patch.py +492 -0
- quadra/tasks/segmentation.py +389 -0
- quadra/tasks/ssl.py +560 -0
- quadra/trainers/README.md +3 -0
- quadra/trainers/__init__.py +0 -0
- quadra/trainers/classification.py +179 -0
- quadra/utils/__init__.py +0 -0
- quadra/utils/anomaly.py +112 -0
- quadra/utils/classification.py +618 -0
- quadra/utils/deprecation.py +31 -0
- quadra/utils/evaluation.py +474 -0
- quadra/utils/export.py +585 -0
- quadra/utils/imaging.py +32 -0
- quadra/utils/logger.py +15 -0
- quadra/utils/mlflow.py +98 -0
- quadra/utils/model_manager.py +320 -0
- quadra/utils/models.py +523 -0
- quadra/utils/patch/__init__.py +15 -0
- quadra/utils/patch/dataset.py +1433 -0
- quadra/utils/patch/metrics.py +449 -0
- quadra/utils/patch/model.py +153 -0
- quadra/utils/patch/visualization.py +217 -0
- quadra/utils/resolver.py +42 -0
- quadra/utils/segmentation.py +31 -0
- quadra/utils/tests/__init__.py +0 -0
- quadra/utils/tests/fixtures/__init__.py +1 -0
- quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
- quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
- quadra/utils/tests/fixtures/dataset/classification.py +406 -0
- quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
- quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
- quadra/utils/tests/fixtures/models/__init__.py +3 -0
- quadra/utils/tests/fixtures/models/anomaly.py +89 -0
- quadra/utils/tests/fixtures/models/classification.py +45 -0
- quadra/utils/tests/fixtures/models/segmentation.py +33 -0
- quadra/utils/tests/helpers.py +70 -0
- quadra/utils/tests/models.py +27 -0
- quadra/utils/utils.py +525 -0
- quadra/utils/validator.py +115 -0
- quadra/utils/visualization.py +422 -0
- quadra/utils/vit_explainability.py +349 -0
- quadra-2.2.7.dist-info/LICENSE +201 -0
- quadra-2.2.7.dist-info/METADATA +381 -0
- quadra-2.2.7.dist-info/RECORD +300 -0
- {quadra-0.0.1.dist-info → quadra-2.2.7.dist-info}/WHEEL +1 -1
- quadra-2.2.7.dist-info/entry_points.txt +3 -0
- quadra-0.0.1.dist-info/METADATA +0 -14
- quadra-0.0.1.dist-info/RECORD +0 -4
quadra/utils/utils.py
ADDED
|
@@ -0,0 +1,525 @@
|
|
|
1
|
+
"""Common utility functions.
|
|
2
|
+
Some of them are mostly based on https://github.com/ashleve/lightning-hydra-template.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import glob
|
|
8
|
+
import json
|
|
9
|
+
import logging
|
|
10
|
+
import os
|
|
11
|
+
import subprocess
|
|
12
|
+
import sys
|
|
13
|
+
import warnings
|
|
14
|
+
from collections.abc import Iterable, Iterator, Sequence
|
|
15
|
+
from typing import Any, cast
|
|
16
|
+
|
|
17
|
+
import cv2
|
|
18
|
+
import dotenv
|
|
19
|
+
import mlflow
|
|
20
|
+
import numpy as np
|
|
21
|
+
import pytorch_lightning as pl
|
|
22
|
+
import rich.syntax
|
|
23
|
+
import rich.tree
|
|
24
|
+
import torch
|
|
25
|
+
from hydra.core.hydra_config import HydraConfig
|
|
26
|
+
from hydra.utils import get_original_cwd
|
|
27
|
+
from lightning_fabric.utilities.device_parser import _parse_gpu_ids
|
|
28
|
+
from omegaconf import DictConfig, OmegaConf
|
|
29
|
+
from pytorch_lightning.loggers import TensorBoardLogger
|
|
30
|
+
from pytorch_lightning.utilities import rank_zero_only
|
|
31
|
+
|
|
32
|
+
import quadra
|
|
33
|
+
import quadra.utils.export as quadra_export
|
|
34
|
+
from quadra.callbacks.mlflow import get_mlflow_logger
|
|
35
|
+
from quadra.utils.mlflow import infer_signature_model
|
|
36
|
+
|
|
37
|
+
try:
|
|
38
|
+
import onnx # noqa
|
|
39
|
+
|
|
40
|
+
ONNX_AVAILABLE = True
|
|
41
|
+
except ImportError:
|
|
42
|
+
ONNX_AVAILABLE = False
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
# Image file extensions, written lower-case with the leading dot included.
IMAGE_EXTENSIONS: list[str] = [
    ".png",
    ".jpg",
    ".jpeg",
    ".bmp",
    ".tiff",
    ".tif",
    ".pbm",
    ".pgm",
    ".ppm",
    ".pxm",
    ".pnm",
]
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def get_logger(name=__name__) -> logging.Logger:
    """Initializes multi-GPU-friendly python logger.

    The logging methods of the returned logger are wrapped with ``rank_zero_only`` so that, in a multi-GPU
    setup, messages are emitted only by the rank zero process instead of once per process.

    Args:
        name: Name of the logger to retrieve, defaults to this module's name.

    Returns:
        The ``logging.Logger`` registered under ``name`` (the same object on every call).
    """
    logger = logging.getLogger(name)

    # ``logging.getLogger`` returns the same object for a given name: without this guard every call would wrap
    # the already wrapped methods again, nesting one more rank_zero_only layer per call.
    if getattr(logger, "_quadra_rank_zero_only", False):
        return logger

    # this ensures all logging levels get marked with the rank zero decorator
    # otherwise logs would get multiplied for each GPU process in multi-GPU setup
    for level in ("debug", "info", "warning", "error", "exception", "fatal", "critical"):
        setattr(logger, level, rank_zero_only(getattr(logger, level)))

    setattr(logger, "_quadra_rank_zero_only", True)
    return logger
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def extras(config: DictConfig) -> None:
    """A couple of optional utilities, controlled by main config file:
    - setting the root logging level from ``config.core.log_level``
    - appending the command line to ``config.core.command`` and storing the current working directory in
      ``config.core.experiment_path``
    - disabling warnings
    - forcing debug friendly configuration.
    Modifies DictConfig in place.

    Args:
        config: Configuration composed by Hydra.
    """
    logging.basicConfig()
    logging.getLogger().setLevel(config.core.log_level.upper())

    log = get_logger(__name__)
    # The command line is appended to whatever prefix ``core.command`` already holds
    config.core.command += " ".join(sys.argv)
    config.core.experiment_path = os.getcwd()

    # disable python warnings if <config.ignore_warnings=True>
    if config.get("ignore_warnings"):
        log.info("Disabling python warnings! <config.ignore_warnings=True>")
        warnings.filterwarnings("ignore")

    # force debugger friendly configuration if <config.trainer.fast_dev_run=True>
    # debuggers don't like GPUs and multiprocessing
    if config.get("trainer") and config.trainer.get("fast_dev_run"):
        log.info("Forcing debugger friendly configuration! <config.trainer.fast_dev_run=True>")
        if config.trainer.get("gpus"):
            config.trainer.devices = 1
            config.trainer.accelerator = "cpu"
            config.trainer.gpus = None
        # Not every configuration defines a datamodule section; attribute access on a missing key may raise
        if config.get("datamodule") is not None:
            if config.datamodule.get("pin_memory"):
                config.datamodule.pin_memory = False
            if config.datamodule.get("num_workers"):
                config.datamodule.num_workers = 0
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
@rank_zero_only
def print_config(
    config: DictConfig,
    fields: Sequence[str] = (
        "task",
        "trainer",
        "model",
        "datamodule",
        "callbacks",
        "logger",
        "core",
        "backbone",
        "transforms",
        "optimizer",
        "scheduler",
    ),
    resolve: bool = True,
) -> None:
    """Prints content of DictConfig using Rich library and its tree structure.

    The same tree is also written to ``config_tree.txt`` in the current working directory.

    Args:
        config: Configuration composed by Hydra.
        fields: Determines which main fields from config will
            be printed and in what order.
        resolve: Whether to resolve reference fields of DictConfig.
    """
    style = "dim"
    tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)

    for field in fields:
        branch = tree.add(field, style=style, guide_style=style)

        config_section = config.get(field)
        # Any OmegaConf container (dict or list config) is rendered as YAML, plain values with str()
        if OmegaConf.is_config(config_section):
            branch_content = OmegaConf.to_yaml(config_section, resolve=resolve)
        else:
            branch_content = str(config_section)

        branch.add(rich.syntax.Syntax(branch_content, "yaml"))

    rich.print(tree)

    # Explicit encoding: the rendered config may contain non-ASCII text, which must not depend on the
    # platform default encoding to be written.
    with open("config_tree.txt", "w", encoding="utf-8") as fp:
        rich.print(tree, file=fp)
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def _git_output(repo_dir: str, *args: str) -> str:
    """Run ``git -C <repo_dir> <args>`` and return its stripped standard output.

    Standard error is discarded; ``subprocess.CalledProcessError`` is raised when git exits with a non-zero code.
    """
    output = subprocess.check_output(["git", "-C", repo_dir, *args], stderr=subprocess.DEVNULL)
    return output.decode("utf-8", errors="replace").strip()


def _git_hparams(repo_dir: str, log: logging.Logger) -> dict[str, str]:
    """Collect commit hash, branch name and ``origin`` url of the git repository containing ``repo_dir``.

    Returns an empty dict (after logging a warning) when ``repo_dir`` is not in a git repository or git cannot be
    executed. If one of the git queries fails, the values gathered before the failure are returned.
    """
    try:
        is_repo = (
            subprocess.call(["git", "-C", repo_dir, "status"], stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT)
            == 0
        )
    except OSError:
        # The git executable is missing or cannot be run
        is_repo = False

    if not is_repo:
        log.warning("Could not find git repository, skipping git commit and branch info")
        return {}

    git_info: dict[str, str] = {}
    try:
        git_info["git/commit"] = _git_output(repo_dir, "rev-parse", "HEAD")
        git_info["git/branch"] = _git_output(repo_dir, "rev-parse", "--abbrev-ref", "HEAD")
        git_info["git/remote"] = _git_output(repo_dir, "remote", "get-url", "origin")
    except subprocess.CalledProcessError:
        log.warning(
            "Could not get git commit, branch or remote information, the repository might not have any commits "
            " yet or it might have been initialized wrongly."
        )
    return git_info


def _hydra_choices_hparams(hydra_cfg: Any) -> dict[str, Any] | None:
    """Return the Hydra config group choices as loggable hparams, or None when they are not a dictionary.

    Group overrides given on the command line (``group=option``) replace the stored choices; "@" is replaced
    by "_at_" in both keys and string values, and non-string keys are dropped.
    """
    hydra_choices = OmegaConf.to_container(hydra_cfg.runtime.choices)
    if not isinstance(hydra_choices, dict):
        return None

    # For multirun override the choices that are not automatically updated
    for item in hydra_cfg.overrides.task:
        if "." in item:
            continue
        override, separator, value = item.partition("=")
        if not separator or override.startswith("~"):
            # Deletion overrides (``~group``) and items without a value do not select any choice
            continue
        hydra_choices[override.lstrip("+")] = value

    return {
        key.replace("@", "_at_"): value.replace("@", "_at_") if isinstance(value, str) else value
        for key, value in hydra_choices.items()
        if isinstance(key, str)
    }


@rank_zero_only
def log_hyperparameters(
    config: DictConfig,
    model: pl.LightningModule,
    trainer: pl.Trainer,
) -> None:
    """This method controls which parameters from Hydra config are saved by Lightning loggers.

    Additionaly saves:
        - the Hydra config group choices
        - number of total, trainable and non trainable model parameters
        - experiment path, launch command and library version
        - git commit, branch and remote of the directory the run was launched from (when available)

    Nothing is logged when Hydra is not initialized or the trainer has no logger.
    """
    log = get_logger(__name__)

    if not HydraConfig.initialized() or trainer.logger is None:
        return

    log.info("Logging hyperparameters!")
    hparams: dict[str, Any] = {}
    hydra_choices = _hydra_choices_hparams(HydraConfig.get())
    if hydra_choices is not None:
        hparams.update(hydra_choices)
    else:
        log.warning("Hydra choices is not a dictionary, skip adding them to the logger")

    # save number of model parameters
    hparams["model/params_total"] = sum(p.numel() for p in model.parameters())
    hparams["model/params_trainable"] = sum(p.numel() for p in model.parameters() if p.requires_grad)
    hparams["model/params_not_trainable"] = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    hparams["experiment_path"] = config.core.experiment_path
    hparams["command"] = config.core.command
    hparams["library/version"] = str(quadra.__version__)

    # Every git command targets the launch directory: Hydra may have changed the current working directory to
    # the experiment output folder, which is not necessarily inside the same repository.
    hparams.update(_git_hparams(get_original_cwd(), log))

    # send hparams to all loggers
    trainer.logger.log_hyperparams(hparams)
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
def upload_file_tensorboard(file_path: str, tensorboard_logger: TensorBoardLogger) -> None:
    """Upload a file to tensorboard handling different extensions.

    The content is logged as text at global step 0, tagged with the file name:
    JSON files are re-serialized with 4-space indentation inside a json fenced block, YAML files are wrapped
    in a yaml fenced block and any other file is logged as-is with a space inserted before each newline.

    Args:
        file_path: Path to the file to upload.
        tensorboard_logger: Tensorboard logger instance.
    """
    tag = os.path.basename(file_path)
    ext = os.path.splitext(file_path)[1].lower()

    # Every file is decoded as UTF-8 so that the result does not depend on the platform default encoding
    with open(file_path, encoding="utf-8") as f:
        if ext == ".json":
            text = f"```json\n{json.dumps(json.load(f), indent=4)}\n```"
        elif ext in (".yaml", ".yml"):
            text = f"```yaml\n{f.read()}\n```"
        else:
            text = f.read().replace("\n", " \n")

    tensorboard_logger.experiment.add_text(tag=tag, text_string=text, global_step=0)
    tensorboard_logger.experiment.flush()
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
def _existing_cwd_files(file_names: Sequence[str]) -> list[str]:
    """Return the absolute paths of the entries of ``file_names`` that exist as files in the current directory."""
    candidate_paths = (os.path.join(os.getcwd(), file_name) for file_name in file_names)
    return [path for path in candidate_paths if os.path.isfile(path)]


def _upload_deployed_models_mlflow(
    config: DictConfig,
    mlflow_logger: pl.loggers.MLFlowLogger,
    export_folder: str,
    device: str,
    half_precision: bool,
) -> None:
    """Log to MLflow the models in ``export_folder`` whose type is listed in ``config.core.upload_models``.

    Model signatures are inferred from example inputs generated from the ``input_size`` stored in
    ``export_folder/model.json``; when that file is missing no model is uploaded.
    """
    log = get_logger(__name__)
    deployed_models = glob.glob(os.path.join(export_folder, "*"))
    # ``upload_models`` may be unset (None), which means nothing has to be uploaded
    types_to_upload = config.core.get("upload_models") or []

    model_json_path = os.path.join(export_folder, "model.json")
    if not os.path.exists(model_json_path):
        if types_to_upload and deployed_models:
            log.warning("%s not found, skipping deployment models upload to mlflow", model_json_path)
        return

    with open(model_json_path) as json_file:
        model_json: dict[str, Any] = json.load(json_file)

    input_size = model_json["input_size"]
    # Not a huge fan of this check
    if not isinstance(input_size[0], list):
        # Input size is not a list of lists
        input_size = [input_size]
    inputs = cast(
        list[Any],
        quadra_export.generate_torch_inputs(input_size, device=device, half_precision=half_precision),
    )

    for model_path in deployed_models:
        model_type = model_type_from_path(model_path)
        if model_type is None:
            log.warning("%s model type not supported", model_path)
            continue
        if model_type not in types_to_upload:
            continue
        if model_type == "pytorch":
            log.warning("Pytorch format still not supported for mlflow upload")
            continue

        model = quadra_export.import_deployment_model(
            model_path,
            device=device,
            inference_config=config.inference,
        )

        if model_type == "torchscript":
            signature = infer_signature_model(model.model, inputs)
            with mlflow.start_run(run_id=mlflow_logger.run_id) as _:
                mlflow.pytorch.log_model(
                    model.model,
                    artifact_path=model_path,
                    signature=signature,
                )
        elif model_type in ("onnx", "simplified_onnx"):
            if not ONNX_AVAILABLE:
                log.warning("onnx package not available, cannot log %s on mlflow", model_path)
                continue
            signature = infer_signature_model(model, inputs)
            with mlflow.start_run(run_id=mlflow_logger.run_id) as _:
                if model.model_path is None:
                    log.warning("Cannot log onnx model on mlflow, BaseEvaluationModel 'model_path' attribute is None")
                else:
                    model_proto = onnx.load(model.model_path)
                    mlflow.onnx.log_model(
                        model_proto,
                        artifact_path=model_path,
                        signature=signature,
                    )


def finish(
    config: DictConfig,
    module: pl.LightningModule,
    datamodule: pl.LightningDataModule,
    trainer: pl.Trainer,
    callbacks: list[pl.Callback],
    logger: list[pl.loggers.Logger],
    export_folder: str,
) -> None:
    """Upload config files and exported models to the experiment loggers.

    Runs only when at least one logger is given and ``config.core.upload_artifacts`` is truthy.
    With an MLflow logger the metadata files found in the current directory are logged as artifacts under
    ``metadata`` and the exported models whose type is in ``config.core.upload_models`` are logged as models;
    with a TensorBoard logger the same metadata files are logged as text.

    Args:
        config: Configuration composed by Hydra.
        module: LightningModule.
        datamodule: LightningDataModule.
        trainer: LightningTrainer.
        callbacks: List of LightningCallbacks.
        logger: List of LightningLoggers.
        export_folder: Folder where the deployment models are exported.
    """
    # pylint: disable=unused-argument
    if len(logger) == 0 or not config.core.get("upload_artifacts"):
        return

    mlflow_logger = get_mlflow_logger(trainer=trainer)
    tensorboard_logger = get_tensorboard_logger(trainer=trainer)
    file_names = ["config.yaml", "config_resolved.yaml", "config_tree.txt", "data/dataset.csv"]
    if "16" in str(trainer.precision):
        # 16 bit precision: models are loaded on the first GPU id parsed from the trainer devices
        index = _parse_gpu_ids(config.trainer.devices, include_cuda=True)[0]
        device = "cuda:" + str(index)
        half_precision = True
    else:
        device = "cpu"
        half_precision = False

    if mlflow_logger is not None:
        for path in _existing_cwd_files(file_names):
            mlflow_logger.experiment.log_artifact(
                run_id=mlflow_logger.run_id, local_path=path, artifact_path="metadata"
            )
        _upload_deployed_models_mlflow(config, mlflow_logger, export_folder, device, half_precision)

    if tensorboard_logger is not None:
        for path in _existing_cwd_files(file_names):
            upload_file_tensorboard(file_path=path, tensorboard_logger=tensorboard_logger)

        tensorboard_logger.experiment.flush()
|
|
352
|
+
|
|
353
|
+
|
|
354
|
+
def load_envs(env_file: str | None = None) -> None:
    """Load all the environment variables defined in the `env_file`.
    This is equivalent to `. env_file` in bash.

    It is possible to define all the system specific variables in the `env_file`.
    Variables already present in the environment are overridden by the values in the file.

    Args:
        env_file: the file that defines the environment variables to use. If None, a `.env` file is searched
            starting from the current working directory and walking up its parents; when none is found there,
            python-dotenv's default lookup (starting from the directory of this module) is used.
    """
    if env_file is None:
        # python-dotenv's default lookup starts from the directory of the calling source file (this module),
        # which for an installed package lives in site-packages rather than in the user's project.
        env_file = dotenv.find_dotenv(usecwd=True) or None
    dotenv.load_dotenv(dotenv_path=env_file, override=True)
|
|
365
|
+
|
|
366
|
+
|
|
367
|
+
def model_type_from_path(model_path: str) -> str | None:
    """Infer the kind of a serialized model from the ending of its file path.

    Args:
        model_path: Path of the model file.

    Returns:
        One of the following strings, or None when the ending is not recognized:
            - "torchscript" for paths ending in '.pt' (TorchScript)
            - "pytorch" for paths ending in '.pth' (PyTorch)
            - "simplified_onnx" for paths ending in 'simplified.onnx' (Simplified ONNX)
            - "onnx" for any other path ending in '.onnx' (ONNX)
            - "json" for paths ending in '.json' (JSON)

    Example:
        ```python
        model_path = "path/to/your/model.onnx"
        model_type = model_type_from_path(model_path)
        print(f"The model type is: {model_type}")
        ```
    """
    # Checked in order: 'simplified.onnx' has to win over the generic '.onnx' ending
    suffix_to_type = (
        (".pt", "torchscript"),
        (".pth", "pytorch"),
        ("simplified.onnx", "simplified_onnx"),
        (".onnx", "onnx"),
        (".json", "json"),
    )
    return next((kind for suffix, kind in suffix_to_type if model_path.endswith(suffix)), None)
|
|
400
|
+
|
|
401
|
+
|
|
402
|
+
def setup_opencv() -> None:
    """Configure OpenCV globally: a single worker thread and no OpenCL backend."""
    # Limit OpenCV's internal threading to one thread.
    # NOTE(review): presumably to avoid oversubscription with data-loader workers — confirm.
    cv2.setNumThreads(1)
    # Turn off the OpenCL (T-API) backend so OpenCV stays on its CPU code paths.
    cv2.ocl.setUseOpenCL(False)
|
|
406
|
+
|
|
407
|
+
|
|
408
|
+
def get_device(cuda: bool = True) -> str:
    """Choose the torch device string used for training.

    Args:
        cuda: whether CUDA should be used when it is available.

    Returns:
        "cuda:0" when CUDA is available and requested, otherwise "cpu".
    """
    gpu_usable = torch.cuda.is_available() and cuda
    return "cuda:0" if gpu_usable else "cpu"
|
|
421
|
+
|
|
422
|
+
|
|
423
|
+
def nested_set(dic: dict, keys: list[str], value: str) -> None:
    """Set ``value`` at the path ``keys`` inside ``dic``, creating intermediate dicts.

    Missing intermediate levels are created as empty dicts; existing ones are reused.
    An empty ``keys`` list raises IndexError.
    """
    if len(keys) > 1:
        # Descend one level (creating it if absent) and handle the remaining path there.
        nested_set(dic.setdefault(keys[0], {}), keys[1:], value)
        return
    dic[keys[-1]] = value
|
|
429
|
+
|
|
430
|
+
|
|
431
|
+
def flatten_list(input_list: Iterable[Any]) -> Iterator[Any]:
    """Lazily flatten arbitrarily nested iterables.

    Strings and bytes are treated as single items and are never split into characters.

    Args:
        input_list: the (possibly nested) iterable to flatten

    Yields:
        Each non-iterable leaf (or str/bytes value) in depth-first order.
    """
    for element in input_list:
        is_leaf = isinstance(element, (str, bytes)) or not isinstance(element, Iterable)
        if is_leaf:
            yield element
            continue
        yield from flatten_list(element)
|
|
445
|
+
|
|
446
|
+
|
|
447
|
+
class HydraEncoder(json.JSONEncoder):
    """JSON encoder that can serialize OmegaConf containers."""

    def default(self, o):
        """Turn OmegaConf configs into plain Python containers; defer everything else."""
        is_omegaconf_node = o is not None and OmegaConf.is_config(o)
        if not is_omegaconf_node:
            # Let the base encoder raise its usual TypeError for unsupported objects.
            return json.JSONEncoder.default(self, o)
        return OmegaConf.to_container(o)
|
|
455
|
+
|
|
456
|
+
|
|
457
|
+
class NumpyEncoder(json.JSONEncoder):
    """Custom JSON encoder to handle numpy objects."""

    def default(self, o):
        """Convert numpy arrays and scalars to native Python objects.

        - Arrays with exactly one element (of any shape) are collapsed to a Python scalar.
        - Other arrays become (nested) Python lists.
        - Numpy numeric and boolean scalars become the matching Python scalar.

        Anything else is delegated to the base encoder, which raises TypeError.
        """
        if isinstance(o, np.ndarray):
            return o.item() if o.size == 1 else o.tolist()
        # np.bool_ is not a subclass of np.number, so it needs to be listed explicitly;
        # without it numpy booleans (e.g. results of comparisons) could not be serialized.
        if isinstance(o, (np.number, np.bool_)):
            return o.item()
        return json.JSONEncoder.default(self, o)
|
|
470
|
+
|
|
471
|
+
|
|
472
|
+
class AllGatherSyncFunction(torch.autograd.Function):
    """Function to gather gradients from multiple GPUs.

    Forward concatenates ``tensor`` from every rank along dim 0. Backward sums the
    incoming gradient across ranks and returns the slice belonging to this rank.
    """

    @staticmethod
    def forward(ctx, tensor):
        # Remember the local batch size so backward can locate this rank's slice.
        ctx.batch_size = tensor.shape[0]

        # NOTE(review): the slicing in backward assumes every rank contributes the same batch size.
        world_size = torch.distributed.get_world_size()
        chunks = [torch.zeros_like(tensor) for _ in range(world_size)]
        torch.distributed.all_gather(chunks, tensor)
        return torch.cat(chunks, 0)

    @staticmethod
    def backward(ctx, grad_output):
        summed = grad_output.clone()
        torch.distributed.all_reduce(summed, op=torch.distributed.ReduceOp.SUM, async_op=False)

        start = torch.distributed.get_rank() * ctx.batch_size
        return summed[start : start + ctx.batch_size]
|
|
494
|
+
|
|
495
|
+
|
|
496
|
+
@torch.no_grad()
def concat_all_gather(tensor):
    """Performs all_gather operation on the provided tensors.

    Collects ``tensor`` from every rank and concatenates the results along dim 0.

    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    world_size = torch.distributed.get_world_size()
    # Receive buffers, one per rank, overwritten by all_gather.
    buffers = [torch.ones_like(tensor) for _ in range(world_size)]
    torch.distributed.all_gather(buffers, tensor, async_op=False)
    return torch.cat(buffers, dim=0)
|
|
506
|
+
|
|
507
|
+
|
|
508
|
+
def get_tensorboard_logger(trainer: pl.Trainer) -> TensorBoardLogger | None:
    """Safely get tensorboard logger from Lightning Trainer loggers.

    Args:
        trainer: Pytorch Lightning Trainer.

    Returns:
        The first TensorBoardLogger attached to the trainer if available, else None.
    """
    if isinstance(trainer.logger, TensorBoardLogger):
        return trainer.logger

    # Recent Lightning versions expose only the first logger through ``trainer.logger``;
    # the full list lives in ``trainer.loggers``. Search it so that a TensorBoardLogger
    # configured after another logger is still found.
    candidates = list(getattr(trainer, "loggers", None) or [])
    # Keep supporting the legacy case where ``trainer.logger`` itself is a list.
    if isinstance(trainer.logger, list):
        candidates.extend(trainer.logger)

    for candidate in candidates:
        if isinstance(candidate, TensorBoardLogger):
            return candidate

    return None
|
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import difflib
|
|
4
|
+
import importlib
|
|
5
|
+
import inspect
|
|
6
|
+
from collections.abc import Iterable
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
from omegaconf import DictConfig, ListConfig, OmegaConf
|
|
10
|
+
|
|
11
|
+
from quadra.utils.utils import get_logger
|
|
12
|
+
|
|
13
|
+
# Special instantiation keys; validate_config removes them from the list of
# configuration arguments before comparing against the target's signature.
OMEGACONF_FIELDS: tuple[str, ...] = ("_target_", "_convert_", "_recursive_", "_args_")
# validate_config skips any string key that contains one of these substrings.
EXCLUDE_KEYS: tuple[str, ...] = ("hydra",)

logger = get_logger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def get_callable_arguments(full_module_path: str) -> tuple[list[str], bool]:
    """Collect the argument names accepted by the class or function at a dotted path.

    For a class, the ``__init__`` signatures of every class in its MRO (excluding
    ``object``) are merged, with ``self`` dropped and duplicates removed. For a function,
    positional/keyword and keyword-only argument names are returned in declaration order.

    Args:
        full_module_path: Full module path to the target class or function.

    Raises:
        ValueError: If the target is not a class or a function.

    Returns:
        A tuple ``(argument_names, accepts_kwargs)``, where ``accepts_kwargs`` is True when
        any inspected signature declares ``**kwargs`` (arguments then cannot be checked).
    """
    module_path, callable_name = full_module_path.rsplit(".", 1)
    target = getattr(importlib.import_module(module_path), callable_name)

    if inspect.isfunction(target):
        spec = inspect.getfullargspec(target)
        return [*spec.args, *spec.kwonlyargs], spec.varkw is not None

    if not inspect.isclass(target):
        raise ValueError("The target must be a class or a function.")

    collected: set[str] = set()
    has_var_keyword = False
    for klass in target.__mro__:
        if klass is object:
            break
        # We don't access the instance but mypy complains
        spec = inspect.getfullargspec(klass.__init__)  # type: ignore
        collected.update(spec.args[1:])  # skip ``self``
        collected.update(spec.kwonlyargs)
        has_var_keyword = has_var_keyword or spec.varkw is not None
    return list(collected), has_var_keyword
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def check_all_arguments(callable_variable: str, configuration_arguments: list[str], argument_names: list[str]) -> None:
    """Ensure every configured argument exists in the target's signature.

    Only the first unknown argument is reported. When a sufficiently similar valid name
    exists, it is suggested in the error message.

    Args:
        callable_variable: Full module path to the target class or function.
        configuration_arguments: All arguments passed from configuration.
        argument_names: All arguments from the target class or function.

    Raises:
        ValueError: If the argument is not valid for the target class or function.
    """
    unknown = [arg for arg in configuration_arguments if arg not in argument_names]
    if not unknown:
        return

    offending = unknown[0]
    message = f"`{offending}` is not a valid argument passed " f"from configuration to `{callable_variable}`."
    suggestions = difflib.get_close_matches(offending, argument_names, n=1, cutoff=0.5)
    if suggestions:
        message += f" Did you mean `{suggestions[0]}`?"
    raise ValueError(message)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def validate_config(_cfg: DictConfig | ListConfig, package_name: str = "quadra") -> None:
    """Recursively traverse OmegaConf object and check if arguments are valid for the target class or function.

    If not, raise a ValueError with a suggestion for the closest match of the argument name.

    Only ``_target_`` values starting with ``package_name`` are checked, and targets whose
    signature accepts ``**kwargs`` are skipped because any argument could be valid there.
    Missing values and keys containing an entry of ``EXCLUDE_KEYS`` are ignored.

    Args:
        _cfg: OmegaConf object
        package_name: package name to check for instantiation. The same filter is applied
            to every nested node.

    Raises:
        ValueError: If a configuration argument is not valid for its target, or if a
            checked target is neither a class nor a function.
    """
    # The below lines of code for looping over a DictConfig/ListConfig are
    # borrowed from OmegaConf PR #719.
    itr: Iterable[Any]
    if isinstance(_cfg, ListConfig):
        itr = range(len(_cfg))
    else:
        itr = _cfg
    for key in itr:
        if OmegaConf.is_missing(_cfg, key):
            continue
        if isinstance(key, str) and any(x in key for x in EXCLUDE_KEYS):
            continue
        if OmegaConf.is_config(_cfg[key]):
            # Propagate package_name; otherwise nested nodes would silently fall back
            # to the default package filter.
            validate_config(_cfg[key], package_name=package_name)
        elif isinstance(_cfg[key], str):
            if key == "_target_":
                callable_variable = str(_cfg[key])
                if callable_variable.startswith(package_name):
                    # Sibling keys of ``_target_`` are the arguments passed to the target.
                    configuration_arguments = [str(x) for x in _cfg if x not in OMEGACONF_FIELDS]
                    argument_names, accepts_kwargs = get_callable_arguments(callable_variable)
                    if not accepts_kwargs:
                        check_all_arguments(callable_variable, configuration_arguments, argument_names)
        else:
            logger.debug("Skipping %s from config. It is not supported.", key)
|