quadra 0.0.1__py3-none-any.whl → 2.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hydra_plugins/quadra_searchpath_plugin.py +30 -0
- quadra/__init__.py +6 -0
- quadra/callbacks/__init__.py +0 -0
- quadra/callbacks/anomalib.py +289 -0
- quadra/callbacks/lightning.py +501 -0
- quadra/callbacks/mlflow.py +291 -0
- quadra/callbacks/scheduler.py +69 -0
- quadra/configs/__init__.py +0 -0
- quadra/configs/backbone/caformer_m36.yaml +8 -0
- quadra/configs/backbone/caformer_s36.yaml +8 -0
- quadra/configs/backbone/convnextv2_base.yaml +8 -0
- quadra/configs/backbone/convnextv2_femto.yaml +8 -0
- quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
- quadra/configs/backbone/dino_vitb8.yaml +12 -0
- quadra/configs/backbone/dino_vits8.yaml +12 -0
- quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
- quadra/configs/backbone/dinov2_vits14.yaml +12 -0
- quadra/configs/backbone/efficientnet_b0.yaml +8 -0
- quadra/configs/backbone/efficientnet_b1.yaml +8 -0
- quadra/configs/backbone/efficientnet_b2.yaml +8 -0
- quadra/configs/backbone/efficientnet_b3.yaml +8 -0
- quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
- quadra/configs/backbone/levit_128s.yaml +8 -0
- quadra/configs/backbone/mnasnet0_5.yaml +9 -0
- quadra/configs/backbone/resnet101.yaml +8 -0
- quadra/configs/backbone/resnet18.yaml +8 -0
- quadra/configs/backbone/resnet18_ssl.yaml +8 -0
- quadra/configs/backbone/resnet50.yaml +8 -0
- quadra/configs/backbone/smp.yaml +9 -0
- quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
- quadra/configs/backbone/unetr.yaml +15 -0
- quadra/configs/backbone/vit16_base.yaml +9 -0
- quadra/configs/backbone/vit16_small.yaml +9 -0
- quadra/configs/backbone/vit16_tiny.yaml +9 -0
- quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
- quadra/configs/callbacks/all.yaml +45 -0
- quadra/configs/callbacks/default.yaml +34 -0
- quadra/configs/callbacks/default_anomalib.yaml +64 -0
- quadra/configs/config.yaml +33 -0
- quadra/configs/core/default.yaml +11 -0
- quadra/configs/datamodule/base/anomaly.yaml +16 -0
- quadra/configs/datamodule/base/classification.yaml +21 -0
- quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
- quadra/configs/datamodule/base/segmentation.yaml +18 -0
- quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
- quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
- quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
- quadra/configs/datamodule/base/ssl.yaml +21 -0
- quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
- quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
- quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
- quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
- quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
- quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
- quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
- quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
- quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
- quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
- quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/base/classification/classification.yaml +73 -0
- quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
- quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
- quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
- quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
- quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
- quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
- quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
- quadra/configs/experiment/base/ssl/byol.yaml +43 -0
- quadra/configs/experiment/base/ssl/dino.yaml +46 -0
- quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
- quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
- quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
- quadra/configs/experiment/custom/cls.yaml +12 -0
- quadra/configs/experiment/default.yaml +15 -0
- quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
- quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
- quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
- quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
- quadra/configs/export/default.yaml +13 -0
- quadra/configs/hydra/anomaly_custom.yaml +15 -0
- quadra/configs/hydra/default.yaml +14 -0
- quadra/configs/inference/default.yaml +26 -0
- quadra/configs/logger/comet.yaml +10 -0
- quadra/configs/logger/csv.yaml +5 -0
- quadra/configs/logger/mlflow.yaml +12 -0
- quadra/configs/logger/tensorboard.yaml +8 -0
- quadra/configs/loss/asl.yaml +7 -0
- quadra/configs/loss/barlow.yaml +2 -0
- quadra/configs/loss/bce.yaml +1 -0
- quadra/configs/loss/byol.yaml +1 -0
- quadra/configs/loss/cross_entropy.yaml +1 -0
- quadra/configs/loss/dino.yaml +8 -0
- quadra/configs/loss/simclr.yaml +2 -0
- quadra/configs/loss/simsiam.yaml +1 -0
- quadra/configs/loss/smp_ce.yaml +3 -0
- quadra/configs/loss/smp_dice.yaml +2 -0
- quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
- quadra/configs/loss/smp_mcc.yaml +2 -0
- quadra/configs/loss/vicreg.yaml +5 -0
- quadra/configs/model/anomalib/cfa.yaml +35 -0
- quadra/configs/model/anomalib/cflow.yaml +30 -0
- quadra/configs/model/anomalib/csflow.yaml +34 -0
- quadra/configs/model/anomalib/dfm.yaml +19 -0
- quadra/configs/model/anomalib/draem.yaml +29 -0
- quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
- quadra/configs/model/anomalib/fastflow.yaml +32 -0
- quadra/configs/model/anomalib/padim.yaml +32 -0
- quadra/configs/model/anomalib/patchcore.yaml +36 -0
- quadra/configs/model/barlow.yaml +16 -0
- quadra/configs/model/byol.yaml +25 -0
- quadra/configs/model/classification.yaml +10 -0
- quadra/configs/model/dino.yaml +26 -0
- quadra/configs/model/logistic_regression.yaml +4 -0
- quadra/configs/model/multilabel_classification.yaml +9 -0
- quadra/configs/model/simclr.yaml +18 -0
- quadra/configs/model/simsiam.yaml +24 -0
- quadra/configs/model/smp.yaml +4 -0
- quadra/configs/model/smp_multiclass.yaml +4 -0
- quadra/configs/model/vicreg.yaml +16 -0
- quadra/configs/optimizer/adam.yaml +5 -0
- quadra/configs/optimizer/adamw.yaml +3 -0
- quadra/configs/optimizer/default.yaml +4 -0
- quadra/configs/optimizer/lars.yaml +8 -0
- quadra/configs/optimizer/sgd.yaml +4 -0
- quadra/configs/scheduler/default.yaml +5 -0
- quadra/configs/scheduler/rop.yaml +5 -0
- quadra/configs/scheduler/step.yaml +3 -0
- quadra/configs/scheduler/warmrestart.yaml +2 -0
- quadra/configs/scheduler/warmup.yaml +6 -0
- quadra/configs/task/anomalib/cfa.yaml +5 -0
- quadra/configs/task/anomalib/cflow.yaml +5 -0
- quadra/configs/task/anomalib/csflow.yaml +5 -0
- quadra/configs/task/anomalib/draem.yaml +5 -0
- quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
- quadra/configs/task/anomalib/fastflow.yaml +5 -0
- quadra/configs/task/anomalib/inference.yaml +3 -0
- quadra/configs/task/anomalib/padim.yaml +5 -0
- quadra/configs/task/anomalib/patchcore.yaml +5 -0
- quadra/configs/task/classification.yaml +6 -0
- quadra/configs/task/classification_evaluation.yaml +6 -0
- quadra/configs/task/default.yaml +1 -0
- quadra/configs/task/segmentation.yaml +9 -0
- quadra/configs/task/segmentation_evaluation.yaml +3 -0
- quadra/configs/task/sklearn_classification.yaml +13 -0
- quadra/configs/task/sklearn_classification_patch.yaml +11 -0
- quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
- quadra/configs/task/sklearn_classification_test.yaml +8 -0
- quadra/configs/task/ssl.yaml +2 -0
- quadra/configs/trainer/lightning_cpu.yaml +36 -0
- quadra/configs/trainer/lightning_gpu.yaml +35 -0
- quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
- quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
- quadra/configs/trainer/lightning_multigpu.yaml +37 -0
- quadra/configs/trainer/sklearn_classification.yaml +7 -0
- quadra/configs/transforms/byol.yaml +47 -0
- quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
- quadra/configs/transforms/default.yaml +37 -0
- quadra/configs/transforms/default_numpy.yaml +24 -0
- quadra/configs/transforms/default_resize.yaml +22 -0
- quadra/configs/transforms/dino.yaml +63 -0
- quadra/configs/transforms/linear_eval.yaml +18 -0
- quadra/datamodules/__init__.py +20 -0
- quadra/datamodules/anomaly.py +180 -0
- quadra/datamodules/base.py +375 -0
- quadra/datamodules/classification.py +1003 -0
- quadra/datamodules/generic/__init__.py +0 -0
- quadra/datamodules/generic/imagenette.py +144 -0
- quadra/datamodules/generic/mnist.py +81 -0
- quadra/datamodules/generic/mvtec.py +58 -0
- quadra/datamodules/generic/oxford_pet.py +163 -0
- quadra/datamodules/patch.py +190 -0
- quadra/datamodules/segmentation.py +742 -0
- quadra/datamodules/ssl.py +140 -0
- quadra/datasets/__init__.py +17 -0
- quadra/datasets/anomaly.py +287 -0
- quadra/datasets/classification.py +241 -0
- quadra/datasets/patch.py +138 -0
- quadra/datasets/segmentation.py +239 -0
- quadra/datasets/ssl.py +110 -0
- quadra/losses/__init__.py +0 -0
- quadra/losses/classification/__init__.py +6 -0
- quadra/losses/classification/asl.py +83 -0
- quadra/losses/classification/focal.py +320 -0
- quadra/losses/classification/prototypical.py +148 -0
- quadra/losses/ssl/__init__.py +17 -0
- quadra/losses/ssl/barlowtwins.py +47 -0
- quadra/losses/ssl/byol.py +37 -0
- quadra/losses/ssl/dino.py +129 -0
- quadra/losses/ssl/hyperspherical.py +45 -0
- quadra/losses/ssl/idmm.py +50 -0
- quadra/losses/ssl/simclr.py +67 -0
- quadra/losses/ssl/simsiam.py +30 -0
- quadra/losses/ssl/vicreg.py +76 -0
- quadra/main.py +49 -0
- quadra/metrics/__init__.py +3 -0
- quadra/metrics/segmentation.py +251 -0
- quadra/models/__init__.py +0 -0
- quadra/models/base.py +151 -0
- quadra/models/classification/__init__.py +8 -0
- quadra/models/classification/backbones.py +149 -0
- quadra/models/classification/base.py +92 -0
- quadra/models/evaluation.py +322 -0
- quadra/modules/__init__.py +0 -0
- quadra/modules/backbone.py +30 -0
- quadra/modules/base.py +312 -0
- quadra/modules/classification/__init__.py +3 -0
- quadra/modules/classification/base.py +327 -0
- quadra/modules/ssl/__init__.py +17 -0
- quadra/modules/ssl/barlowtwins.py +59 -0
- quadra/modules/ssl/byol.py +172 -0
- quadra/modules/ssl/common.py +285 -0
- quadra/modules/ssl/dino.py +186 -0
- quadra/modules/ssl/hyperspherical.py +206 -0
- quadra/modules/ssl/idmm.py +98 -0
- quadra/modules/ssl/simclr.py +73 -0
- quadra/modules/ssl/simsiam.py +68 -0
- quadra/modules/ssl/vicreg.py +67 -0
- quadra/optimizers/__init__.py +4 -0
- quadra/optimizers/lars.py +153 -0
- quadra/optimizers/sam.py +127 -0
- quadra/schedulers/__init__.py +3 -0
- quadra/schedulers/base.py +44 -0
- quadra/schedulers/warmup.py +127 -0
- quadra/tasks/__init__.py +24 -0
- quadra/tasks/anomaly.py +582 -0
- quadra/tasks/base.py +397 -0
- quadra/tasks/classification.py +1263 -0
- quadra/tasks/patch.py +492 -0
- quadra/tasks/segmentation.py +389 -0
- quadra/tasks/ssl.py +560 -0
- quadra/trainers/README.md +3 -0
- quadra/trainers/__init__.py +0 -0
- quadra/trainers/classification.py +179 -0
- quadra/utils/__init__.py +0 -0
- quadra/utils/anomaly.py +112 -0
- quadra/utils/classification.py +618 -0
- quadra/utils/deprecation.py +31 -0
- quadra/utils/evaluation.py +474 -0
- quadra/utils/export.py +585 -0
- quadra/utils/imaging.py +32 -0
- quadra/utils/logger.py +15 -0
- quadra/utils/mlflow.py +98 -0
- quadra/utils/model_manager.py +320 -0
- quadra/utils/models.py +523 -0
- quadra/utils/patch/__init__.py +15 -0
- quadra/utils/patch/dataset.py +1433 -0
- quadra/utils/patch/metrics.py +449 -0
- quadra/utils/patch/model.py +153 -0
- quadra/utils/patch/visualization.py +217 -0
- quadra/utils/resolver.py +42 -0
- quadra/utils/segmentation.py +31 -0
- quadra/utils/tests/__init__.py +0 -0
- quadra/utils/tests/fixtures/__init__.py +1 -0
- quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
- quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
- quadra/utils/tests/fixtures/dataset/classification.py +406 -0
- quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
- quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
- quadra/utils/tests/fixtures/models/__init__.py +3 -0
- quadra/utils/tests/fixtures/models/anomaly.py +89 -0
- quadra/utils/tests/fixtures/models/classification.py +45 -0
- quadra/utils/tests/fixtures/models/segmentation.py +33 -0
- quadra/utils/tests/helpers.py +70 -0
- quadra/utils/tests/models.py +27 -0
- quadra/utils/utils.py +525 -0
- quadra/utils/validator.py +115 -0
- quadra/utils/visualization.py +422 -0
- quadra/utils/vit_explainability.py +349 -0
- quadra-2.2.7.dist-info/LICENSE +201 -0
- quadra-2.2.7.dist-info/METADATA +381 -0
- quadra-2.2.7.dist-info/RECORD +300 -0
- {quadra-0.0.1.dist-info → quadra-2.2.7.dist-info}/WHEEL +1 -1
- quadra-2.2.7.dist-info/entry_points.txt +3 -0
- quadra-0.0.1.dist-info/METADATA +0 -14
- quadra-0.0.1.dist-info/RECORD +0 -4
quadra/callbacks/mlflow.py
@@ -0,0 +1,291 @@
from __future__ import annotations

import glob
import os
from typing import Any, Literal

import torch
from pytorch_lightning import Callback, LightningModule, Trainer
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.loggers import MLFlowLogger
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.types import STEP_OUTPUT

from quadra.utils.mlflow import get_mlflow_logger


def check_minio_credentials() -> None:
    """Check MinIO credentials for AWS-based storage such as MinIO.

    Returns:
        None
    """
    check = os.environ.get("AWS_ACCESS_KEY_ID") is not None and os.environ.get("AWS_SECRET_ACCESS_KEY") is not None
    if not check:
        raise ValueError(
            "You are trying to upload mlflow artifacts, but minio credentials are not set. Please set them in your"
            " environment variables."
        )


def check_file_server_dependencies() -> None:
    """Check file server dependencies such as boto3 and minio.

    Returns:
        None
    """
    try:
        # pylint: disable=unused-import,import-outside-toplevel
        import boto3  # noqa
        import minio  # noqa
    except ImportError as e:
        raise ImportError(
            "You are trying to upload mlflow artifacts, but boto3 and minio are not installed. Please install them by"
            " calling pip install minio boto3."
        ) from e


def validate_artifact_storage(logger: MLFlowLogger):
    """Validate artifact storage.

    Args:
        logger: Mlflow logger from pytorch lightning.

    """
    from quadra.utils.utils import get_logger  # pylint: disable=[import-outside-toplevel]

    log = get_logger(__name__)

    client = logger.experiment
    # TODO: we have to access the internal api to get the artifact uri, however there could be a better way
    artifact_uri = client._tracking_client._get_artifact_repo(  # pylint: disable=protected-access
        logger.run_id
    ).artifact_uri
    if artifact_uri.startswith("s3://"):
        check_minio_credentials()
        check_file_server_dependencies()
        log.info("Mlflow artifact storage is AWS/S3 based and credentials and dependencies are satisfied.")
    else:
        log.info("Mlflow artifact storage uri is %s. Validation checks are not implemented.", artifact_uri)


class UploadCodeAsArtifact(Callback):
    """Callback used to upload code as an artifact.

    Uploads all *.py files to mlflow as an artifact at the end of the test stage.
    It creates a project-source folder under the mlflow artifacts along with the
    necessary subfolders.

    Args:
        source_dir: Folder where all the source files are stored.
    """

    def __init__(self, source_dir: str):
        self.source_dir = source_dir

    @rank_zero_only
    def on_test_end(self, trainer: Trainer, pl_module: LightningModule):
        """Triggered at the end of test. Uploads all *.py files to mlflow as an artifact.

        Args:
            trainer: Pytorch Lightning trainer.
            pl_module: Pytorch Lightning module.
        """
        logger = get_mlflow_logger(trainer=trainer)

        if logger is None:
            return

        experiment = logger.experiment

        for path in glob.glob(os.path.join(self.source_dir, "**/*.py"), recursive=True):
            stripped_path = path.replace(self.source_dir, "")
            if len(stripped_path.split("/")) > 1:
                file_path_tree = "/" + "/".join(stripped_path.split("/")[:-1])
            else:
                file_path_tree = ""
            experiment.log_artifact(
                run_id=logger.run_id,
                local_path=path,
                artifact_path=f"project-source{file_path_tree}",
            )


class LogGradients(Callback):
    """Callback used to log the gradients of the model at the end of each training step.

    Args:
        norm: Norm to use for the gradient. Default is L2 norm.
        tag: Tag to add to the gradients. If None, no tag will be added.
        sep: Separator to use in the log.
        round_to: Number of decimals to round the gradients to.
        log_all_grads: If True, log all gradients, not just the total norm.
    """

    def __init__(
        self,
        norm: int = 2,
        tag: str | None = None,
        sep: str = "/",
        round_to: int = 3,
        log_all_grads: bool = False,
    ):
        self.norm = norm
        self.tag = tag
        self.sep = sep
        self.round_to = round_to
        self.log_all_grads = log_all_grads

    def _grad_norm(self, named_params) -> dict:
        """Compute the gradient norm and return it in a dictionary."""
        grad_tag = "" if self.tag is None else "_" + self.tag
        results = {}
        for name, p in named_params:
            if p.requires_grad and p.grad is not None:
                norm = float(p.grad.data.norm(self.norm))
                key = f"grad_norm_{self.norm}{grad_tag}{self.sep}{name}"
                results[key] = round(norm, 3)
        total_norm = float(torch.tensor(list(results.values())).norm(self.norm))
        if not self.log_all_grads:
            # clear dictionary
            results = {}
        results[f"grad_norm_{self.norm}_total{grad_tag}"] = round(total_norm, self.round_to)
        return results

    @rank_zero_only
    def on_train_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
        unused: int | None = 0,
    ) -> None:
        """Method called at the end of each training batch.

        Args:
            trainer: Pytorch Lightning trainer.
            pl_module: Pytorch Lightning module.
            outputs: Training step outputs.
            batch: Current batch.
            batch_idx: Batch index.
            unused: Dataloader index.

        Returns:
            None
        """
        # pylint: disable=unused-argument
        logger = get_mlflow_logger(trainer=trainer)

        if logger is None:
            return

        named_params = pl_module.named_parameters()
        grads = self._grad_norm(named_params)
        logger.log_metrics(grads)


class UploadCheckpointsAsArtifact(Callback):
    """Callback used to upload checkpoints as artifacts.

    Args:
        ckpt_dir: Folder where all the checkpoints are stored.
        ckpt_ext: Extension of checkpoint files (default: ckpt).
        upload_best_only: Only upload the best checkpoint (default: False).
        delete_after_upload: Delete the checkpoints from local storage after uploading (default: True).
        upload: If True, upload the checkpoints. If False, only save them on the local machine.
    """

    def __init__(
        self,
        ckpt_dir: str = "checkpoints/",
        ckpt_ext: str = "ckpt",
        upload_best_only: bool = False,
        delete_after_upload: bool = True,
        upload: bool = True,
    ):
        self.ckpt_dir = ckpt_dir
        self.upload_best_only = upload_best_only
        self.ckpt_ext = ckpt_ext
        self.delete_after_upload = delete_after_upload
        self.upload = upload

    @rank_zero_only
    def on_test_end(self, trainer: Trainer, pl_module: LightningModule):
        """Triggered at the end of test. Uploads all model checkpoints to mlflow as an artifact.

        Args:
            trainer: Pytorch Lightning trainer.
            pl_module: Pytorch Lightning module.
        """
        logger = get_mlflow_logger(trainer=trainer)

        if logger is None:
            return

        experiment = logger.experiment

        if (
            trainer.checkpoint_callback
            and self.upload_best_only
            and hasattr(trainer.checkpoint_callback, "best_model_path")
        ):
            if self.upload:
                experiment.log_artifact(
                    run_id=logger.run_id,
                    local_path=trainer.checkpoint_callback.best_model_path,
                    artifact_path="checkpoints",
                )
        else:
            for path in glob.glob(os.path.join(self.ckpt_dir, f"**/*.{self.ckpt_ext}"), recursive=True):
                if self.upload:
                    experiment.log_artifact(
                        run_id=logger.run_id,
                        local_path=path,
                        artifact_path="checkpoints",
                    )
        if self.delete_after_upload:
            for path in glob.glob(os.path.join(self.ckpt_dir, f"**/*.{self.ckpt_ext}"), recursive=True):
                os.remove(path)


class LogLearningRate(LearningRateMonitor):
    """Learning rate logger at the end of the training step/epoch.

    Args:
        logging_interval: Logging interval.
        log_momentum: If True, log momentum as well.
    """

    def __init__(self, logging_interval: Literal["step", "epoch"] | None = None, log_momentum: bool = False):
        super().__init__(logging_interval=logging_interval, log_momentum=log_momentum)

    def on_train_batch_start(self, trainer, *args, **kwargs):
        """Log learning rate at the beginning of the training step if logging interval is set to step."""
        if not trainer.logger_connector.should_update_logs:
            return
        if self.logging_interval != "epoch":
            logger = get_mlflow_logger(trainer=trainer)

            if logger is None:
                return

            interval = "step" if self.logging_interval is None else "any"
            latest_stat = self._extract_stats(trainer, interval)

            if latest_stat:
                logger.log_metrics(latest_stat, step=trainer.global_step)

    def on_train_epoch_start(self, trainer, *args, **kwargs):
        """Log learning rate at the beginning of the epoch if logging interval is set to epoch."""
        if self.logging_interval != "step":
            interval = "epoch" if self.logging_interval is None else "any"
            latest_stat = self._extract_stats(trainer, interval)
            logger = get_mlflow_logger(trainer=trainer)

            if logger is None:
                return

            if latest_stat:
                logger.log_metrics(latest_stat, step=trainer.global_step)
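For orientation, below is a minimal usage sketch (not part of the package) of how these new MLflow callbacks might be attached to a PyTorch Lightning Trainer. The experiment name, tracking URI, source directory and the commented-out model/datamodule are placeholders.

# Sketch only: assumes an MLflow tracking server is reachable at the given URI.
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import MLFlowLogger

from quadra.callbacks.mlflow import (
    LogGradients,
    LogLearningRate,
    UploadCheckpointsAsArtifact,
    UploadCodeAsArtifact,
)

mlf_logger = MLFlowLogger(experiment_name="quadra-demo", tracking_uri="http://localhost:5000")

trainer = Trainer(
    max_epochs=10,
    logger=mlf_logger,
    callbacks=[
        UploadCodeAsArtifact(source_dir="my_project/"),        # *.py files -> "project-source" artifacts
        UploadCheckpointsAsArtifact(ckpt_dir="checkpoints/"),  # *.ckpt files -> "checkpoints" artifacts
        LogGradients(norm=2),                                  # total L2 gradient norm per training step
        LogLearningRate(logging_interval="step"),              # learning rate forwarded to the MLflow logger
    ],
)
# trainer.fit(model, datamodule=datamodule)  # model and datamodule are user-supplied

Note that both upload callbacks fire in on_test_end and quietly do nothing when get_mlflow_logger finds no MLFlowLogger attached to the trainer.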
quadra/callbacks/scheduler.py
@@ -0,0 +1,69 @@
import hydra
import pytorch_lightning as pl
from omegaconf import DictConfig
from pytorch_lightning import Callback
from pytorch_lightning.utilities import rank_zero_only

from quadra.schedulers.warmup import CosineAnnealingWithLinearWarmUp
from quadra.utils.utils import get_logger

log = get_logger(__name__)


class WarmupInit(Callback):
    """Custom callback used to set up a warmup scheduler.

    Args:
        scheduler_config: Scheduler configuration.
    """

    def __init__(
        self,
        scheduler_config: DictConfig,
    ) -> None:
        self.scheduler_config = scheduler_config

    @rank_zero_only
    def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        """Called when fit begins."""
        if not hasattr(trainer, "datamodule"):
            raise ValueError("Trainer must have a datamodule attribute.")

        if not any(isinstance(s.scheduler, CosineAnnealingWithLinearWarmUp) for s in trainer.lr_scheduler_configs):
            return

        log.info("Using warmup scheduler, forcing optimizer learning rate to zero.")
        for i, _ in enumerate(trainer.optimizers):
            for param_group in trainer.optimizers[i].param_groups:
                param_group["lr"] = 0.0
            trainer.optimizers[i].defaults["lr"] = 0.0

        batch_size = trainer.datamodule.batch_size
        train_dataloader = trainer.datamodule.train_dataloader()
        len_train_dataloader = len(train_dataloader)
        if isinstance(trainer.device_ids, list) and pl_module.device.type == "cuda":
            num_gpus = len(trainer.device_ids)
            len_train_dataloader = len_train_dataloader // num_gpus
            if not train_dataloader.drop_last:
                len_train_dataloader += int((len_train_dataloader % num_gpus) != 0)

        if len_train_dataloader == 1:
            log.warning(
                "From this dataset size we can only generate a single batch. The batch size will be set to the"
                " length of the dataset."
            )
            batch_size = len(train_dataloader.dataset)

        if isinstance(trainer.device_ids, list) and pl_module.device.type == "cuda":
            batch_size = batch_size * len(trainer.device_ids)

        scheduler = hydra.utils.instantiate(
            self.scheduler_config,
            optimizer=pl_module.optimizer,
            batch_size=batch_size,
            len_loader=len_train_dataloader,
        )

        for i, s in enumerate(trainer.lr_scheduler_configs):
            if isinstance(s.scheduler, CosineAnnealingWithLinearWarmUp):
                trainer.lr_scheduler_configs[i].scheduler = scheduler
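As a rough illustration (not taken from the package docs), WarmupInit could be wired in as below. In practice the scheduler config would come from the scheduler/warmup.yaml added in this release; the warmup_epochs and max_epochs keys shown here are hypothetical placeholders for whatever CosineAnnealingWithLinearWarmUp actually accepts.

# Sketch only: the config keys besides _target_ are assumptions, not the real schema.
from omegaconf import OmegaConf
from pytorch_lightning import Trainer

from quadra.callbacks.scheduler import WarmupInit

scheduler_cfg = OmegaConf.create(
    {
        "_target_": "quadra.schedulers.warmup.CosineAnnealingWithLinearWarmUp",
        # hypothetical fields; optimizer, batch_size and len_loader are injected by WarmupInit itself
        "warmup_epochs": 10,
        "max_epochs": 100,
    }
)

trainer = Trainer(max_epochs=100, callbacks=[WarmupInit(scheduler_config=scheduler_cfg)])
# trainer.fit(model, datamodule=datamodule)  # the datamodule must expose `batch_size`

On on_fit_start the callback zeroes the optimizer learning rates and re-instantiates any CosineAnnealingWithLinearWarmUp scheduler with the effective batch size and dataloader length, so it only has an effect when that scheduler type is already configured on the module.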
File without changes
quadra/configs/backbone/dino_vitb8.yaml
@@ -0,0 +1,12 @@
model:
  _target_: quadra.models.classification.TorchHubNetworkBuilder
  repo_or_dir: facebookresearch/dino:main
  model_name: dino_vitb8
  pretrained: true
  freeze: false
  hyperspherical: false
metadata:
  input_size: 224
  output_dim: 768
  patch_size: 8
  nb_heads: 12
quadra/configs/backbone/dino_vits8.yaml
@@ -0,0 +1,12 @@
model:
  _target_: quadra.models.classification.TorchHubNetworkBuilder
  repo_or_dir: facebookresearch/dino:main
  model_name: dino_vits8
  pretrained: true
  freeze: false
  hyperspherical: false
metadata:
  input_size: 224
  output_dim: 384
  patch_size: 8
  nb_heads: 6
quadra/configs/backbone/dinov2_vitb14.yaml
@@ -0,0 +1,12 @@
model:
  _target_: quadra.models.classification.TorchHubNetworkBuilder
  repo_or_dir: facebookresearch/dinov2
  model_name: dinov2_vitb14
  pretrained: true
  freeze: false
  hyperspherical: false
metadata:
  input_size: 224
  output_dim: 768
  patch_size: 14
  nb_heads: 12
quadra/configs/backbone/dinov2_vits14.yaml
@@ -0,0 +1,12 @@
model:
  _target_: quadra.models.classification.TorchHubNetworkBuilder
  repo_or_dir: facebookresearch/dinov2
  model_name: dinov2_vits14
  pretrained: true
  freeze: false
  hyperspherical: false
metadata:
  input_size: 224
  output_dim: 384
  patch_size: 14
  nb_heads: 6
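These backbone configs are plain Hydra nodes, so one of them can be instantiated directly. The following is a hedged sketch, assuming TorchHubNetworkBuilder behaves like a regular nn.Module and torch.hub access is available; the config path refers to the installed package.

# Sketch only: downloads the DINOv2 weights from torch.hub on first use.
import hydra
import torch
from omegaconf import OmegaConf

cfg = OmegaConf.load("quadra/configs/backbone/dinov2_vits14.yaml")

# Builds the torch.hub model described by the `model` node above.
backbone = hydra.utils.instantiate(cfg.model)

x = torch.randn(1, 3, cfg.metadata.input_size, cfg.metadata.input_size)
with torch.no_grad():
    feats = backbone(x)  # feature dimension should match cfg.metadata.output_dim (384 here)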
quadra/configs/backbone/unetr.yaml
@@ -0,0 +1,15 @@
model:
  _target_: monai.networks.nets.unetr.UNETR
  in_channels: 3
  out_channels: 1
  img_size: [448, 448]
  feature_size: 16
  hidden_size: 384 # 192
  mlp_dim: 1536 # 768
  num_heads: 8 # 3
  pos_embed: conv
  norm_name: instance
  conv_block: true
  res_block: true
  dropout_rate: 0
  spatial_dims: 2