quadra-0.0.1-py3-none-any.whl → quadra-2.2.7-py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
- hydra_plugins/quadra_searchpath_plugin.py +30 -0
- quadra/__init__.py +6 -0
- quadra/callbacks/__init__.py +0 -0
- quadra/callbacks/anomalib.py +289 -0
- quadra/callbacks/lightning.py +501 -0
- quadra/callbacks/mlflow.py +291 -0
- quadra/callbacks/scheduler.py +69 -0
- quadra/configs/__init__.py +0 -0
- quadra/configs/backbone/caformer_m36.yaml +8 -0
- quadra/configs/backbone/caformer_s36.yaml +8 -0
- quadra/configs/backbone/convnextv2_base.yaml +8 -0
- quadra/configs/backbone/convnextv2_femto.yaml +8 -0
- quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
- quadra/configs/backbone/dino_vitb8.yaml +12 -0
- quadra/configs/backbone/dino_vits8.yaml +12 -0
- quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
- quadra/configs/backbone/dinov2_vits14.yaml +12 -0
- quadra/configs/backbone/efficientnet_b0.yaml +8 -0
- quadra/configs/backbone/efficientnet_b1.yaml +8 -0
- quadra/configs/backbone/efficientnet_b2.yaml +8 -0
- quadra/configs/backbone/efficientnet_b3.yaml +8 -0
- quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
- quadra/configs/backbone/levit_128s.yaml +8 -0
- quadra/configs/backbone/mnasnet0_5.yaml +9 -0
- quadra/configs/backbone/resnet101.yaml +8 -0
- quadra/configs/backbone/resnet18.yaml +8 -0
- quadra/configs/backbone/resnet18_ssl.yaml +8 -0
- quadra/configs/backbone/resnet50.yaml +8 -0
- quadra/configs/backbone/smp.yaml +9 -0
- quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
- quadra/configs/backbone/unetr.yaml +15 -0
- quadra/configs/backbone/vit16_base.yaml +9 -0
- quadra/configs/backbone/vit16_small.yaml +9 -0
- quadra/configs/backbone/vit16_tiny.yaml +9 -0
- quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
- quadra/configs/callbacks/all.yaml +45 -0
- quadra/configs/callbacks/default.yaml +34 -0
- quadra/configs/callbacks/default_anomalib.yaml +64 -0
- quadra/configs/config.yaml +33 -0
- quadra/configs/core/default.yaml +11 -0
- quadra/configs/datamodule/base/anomaly.yaml +16 -0
- quadra/configs/datamodule/base/classification.yaml +21 -0
- quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
- quadra/configs/datamodule/base/segmentation.yaml +18 -0
- quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
- quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
- quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
- quadra/configs/datamodule/base/ssl.yaml +21 -0
- quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
- quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
- quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
- quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
- quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
- quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
- quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
- quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
- quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
- quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
- quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/base/classification/classification.yaml +73 -0
- quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
- quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
- quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
- quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
- quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
- quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
- quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
- quadra/configs/experiment/base/ssl/byol.yaml +43 -0
- quadra/configs/experiment/base/ssl/dino.yaml +46 -0
- quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
- quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
- quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
- quadra/configs/experiment/custom/cls.yaml +12 -0
- quadra/configs/experiment/default.yaml +15 -0
- quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
- quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
- quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
- quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
- quadra/configs/export/default.yaml +13 -0
- quadra/configs/hydra/anomaly_custom.yaml +15 -0
- quadra/configs/hydra/default.yaml +14 -0
- quadra/configs/inference/default.yaml +26 -0
- quadra/configs/logger/comet.yaml +10 -0
- quadra/configs/logger/csv.yaml +5 -0
- quadra/configs/logger/mlflow.yaml +12 -0
- quadra/configs/logger/tensorboard.yaml +8 -0
- quadra/configs/loss/asl.yaml +7 -0
- quadra/configs/loss/barlow.yaml +2 -0
- quadra/configs/loss/bce.yaml +1 -0
- quadra/configs/loss/byol.yaml +1 -0
- quadra/configs/loss/cross_entropy.yaml +1 -0
- quadra/configs/loss/dino.yaml +8 -0
- quadra/configs/loss/simclr.yaml +2 -0
- quadra/configs/loss/simsiam.yaml +1 -0
- quadra/configs/loss/smp_ce.yaml +3 -0
- quadra/configs/loss/smp_dice.yaml +2 -0
- quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
- quadra/configs/loss/smp_mcc.yaml +2 -0
- quadra/configs/loss/vicreg.yaml +5 -0
- quadra/configs/model/anomalib/cfa.yaml +35 -0
- quadra/configs/model/anomalib/cflow.yaml +30 -0
- quadra/configs/model/anomalib/csflow.yaml +34 -0
- quadra/configs/model/anomalib/dfm.yaml +19 -0
- quadra/configs/model/anomalib/draem.yaml +29 -0
- quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
- quadra/configs/model/anomalib/fastflow.yaml +32 -0
- quadra/configs/model/anomalib/padim.yaml +32 -0
- quadra/configs/model/anomalib/patchcore.yaml +36 -0
- quadra/configs/model/barlow.yaml +16 -0
- quadra/configs/model/byol.yaml +25 -0
- quadra/configs/model/classification.yaml +10 -0
- quadra/configs/model/dino.yaml +26 -0
- quadra/configs/model/logistic_regression.yaml +4 -0
- quadra/configs/model/multilabel_classification.yaml +9 -0
- quadra/configs/model/simclr.yaml +18 -0
- quadra/configs/model/simsiam.yaml +24 -0
- quadra/configs/model/smp.yaml +4 -0
- quadra/configs/model/smp_multiclass.yaml +4 -0
- quadra/configs/model/vicreg.yaml +16 -0
- quadra/configs/optimizer/adam.yaml +5 -0
- quadra/configs/optimizer/adamw.yaml +3 -0
- quadra/configs/optimizer/default.yaml +4 -0
- quadra/configs/optimizer/lars.yaml +8 -0
- quadra/configs/optimizer/sgd.yaml +4 -0
- quadra/configs/scheduler/default.yaml +5 -0
- quadra/configs/scheduler/rop.yaml +5 -0
- quadra/configs/scheduler/step.yaml +3 -0
- quadra/configs/scheduler/warmrestart.yaml +2 -0
- quadra/configs/scheduler/warmup.yaml +6 -0
- quadra/configs/task/anomalib/cfa.yaml +5 -0
- quadra/configs/task/anomalib/cflow.yaml +5 -0
- quadra/configs/task/anomalib/csflow.yaml +5 -0
- quadra/configs/task/anomalib/draem.yaml +5 -0
- quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
- quadra/configs/task/anomalib/fastflow.yaml +5 -0
- quadra/configs/task/anomalib/inference.yaml +3 -0
- quadra/configs/task/anomalib/padim.yaml +5 -0
- quadra/configs/task/anomalib/patchcore.yaml +5 -0
- quadra/configs/task/classification.yaml +6 -0
- quadra/configs/task/classification_evaluation.yaml +6 -0
- quadra/configs/task/default.yaml +1 -0
- quadra/configs/task/segmentation.yaml +9 -0
- quadra/configs/task/segmentation_evaluation.yaml +3 -0
- quadra/configs/task/sklearn_classification.yaml +13 -0
- quadra/configs/task/sklearn_classification_patch.yaml +11 -0
- quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
- quadra/configs/task/sklearn_classification_test.yaml +8 -0
- quadra/configs/task/ssl.yaml +2 -0
- quadra/configs/trainer/lightning_cpu.yaml +36 -0
- quadra/configs/trainer/lightning_gpu.yaml +35 -0
- quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
- quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
- quadra/configs/trainer/lightning_multigpu.yaml +37 -0
- quadra/configs/trainer/sklearn_classification.yaml +7 -0
- quadra/configs/transforms/byol.yaml +47 -0
- quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
- quadra/configs/transforms/default.yaml +37 -0
- quadra/configs/transforms/default_numpy.yaml +24 -0
- quadra/configs/transforms/default_resize.yaml +22 -0
- quadra/configs/transforms/dino.yaml +63 -0
- quadra/configs/transforms/linear_eval.yaml +18 -0
- quadra/datamodules/__init__.py +20 -0
- quadra/datamodules/anomaly.py +180 -0
- quadra/datamodules/base.py +375 -0
- quadra/datamodules/classification.py +1003 -0
- quadra/datamodules/generic/__init__.py +0 -0
- quadra/datamodules/generic/imagenette.py +144 -0
- quadra/datamodules/generic/mnist.py +81 -0
- quadra/datamodules/generic/mvtec.py +58 -0
- quadra/datamodules/generic/oxford_pet.py +163 -0
- quadra/datamodules/patch.py +190 -0
- quadra/datamodules/segmentation.py +742 -0
- quadra/datamodules/ssl.py +140 -0
- quadra/datasets/__init__.py +17 -0
- quadra/datasets/anomaly.py +287 -0
- quadra/datasets/classification.py +241 -0
- quadra/datasets/patch.py +138 -0
- quadra/datasets/segmentation.py +239 -0
- quadra/datasets/ssl.py +110 -0
- quadra/losses/__init__.py +0 -0
- quadra/losses/classification/__init__.py +6 -0
- quadra/losses/classification/asl.py +83 -0
- quadra/losses/classification/focal.py +320 -0
- quadra/losses/classification/prototypical.py +148 -0
- quadra/losses/ssl/__init__.py +17 -0
- quadra/losses/ssl/barlowtwins.py +47 -0
- quadra/losses/ssl/byol.py +37 -0
- quadra/losses/ssl/dino.py +129 -0
- quadra/losses/ssl/hyperspherical.py +45 -0
- quadra/losses/ssl/idmm.py +50 -0
- quadra/losses/ssl/simclr.py +67 -0
- quadra/losses/ssl/simsiam.py +30 -0
- quadra/losses/ssl/vicreg.py +76 -0
- quadra/main.py +49 -0
- quadra/metrics/__init__.py +3 -0
- quadra/metrics/segmentation.py +251 -0
- quadra/models/__init__.py +0 -0
- quadra/models/base.py +151 -0
- quadra/models/classification/__init__.py +8 -0
- quadra/models/classification/backbones.py +149 -0
- quadra/models/classification/base.py +92 -0
- quadra/models/evaluation.py +322 -0
- quadra/modules/__init__.py +0 -0
- quadra/modules/backbone.py +30 -0
- quadra/modules/base.py +312 -0
- quadra/modules/classification/__init__.py +3 -0
- quadra/modules/classification/base.py +327 -0
- quadra/modules/ssl/__init__.py +17 -0
- quadra/modules/ssl/barlowtwins.py +59 -0
- quadra/modules/ssl/byol.py +172 -0
- quadra/modules/ssl/common.py +285 -0
- quadra/modules/ssl/dino.py +186 -0
- quadra/modules/ssl/hyperspherical.py +206 -0
- quadra/modules/ssl/idmm.py +98 -0
- quadra/modules/ssl/simclr.py +73 -0
- quadra/modules/ssl/simsiam.py +68 -0
- quadra/modules/ssl/vicreg.py +67 -0
- quadra/optimizers/__init__.py +4 -0
- quadra/optimizers/lars.py +153 -0
- quadra/optimizers/sam.py +127 -0
- quadra/schedulers/__init__.py +3 -0
- quadra/schedulers/base.py +44 -0
- quadra/schedulers/warmup.py +127 -0
- quadra/tasks/__init__.py +24 -0
- quadra/tasks/anomaly.py +582 -0
- quadra/tasks/base.py +397 -0
- quadra/tasks/classification.py +1263 -0
- quadra/tasks/patch.py +492 -0
- quadra/tasks/segmentation.py +389 -0
- quadra/tasks/ssl.py +560 -0
- quadra/trainers/README.md +3 -0
- quadra/trainers/__init__.py +0 -0
- quadra/trainers/classification.py +179 -0
- quadra/utils/__init__.py +0 -0
- quadra/utils/anomaly.py +112 -0
- quadra/utils/classification.py +618 -0
- quadra/utils/deprecation.py +31 -0
- quadra/utils/evaluation.py +474 -0
- quadra/utils/export.py +585 -0
- quadra/utils/imaging.py +32 -0
- quadra/utils/logger.py +15 -0
- quadra/utils/mlflow.py +98 -0
- quadra/utils/model_manager.py +320 -0
- quadra/utils/models.py +523 -0
- quadra/utils/patch/__init__.py +15 -0
- quadra/utils/patch/dataset.py +1433 -0
- quadra/utils/patch/metrics.py +449 -0
- quadra/utils/patch/model.py +153 -0
- quadra/utils/patch/visualization.py +217 -0
- quadra/utils/resolver.py +42 -0
- quadra/utils/segmentation.py +31 -0
- quadra/utils/tests/__init__.py +0 -0
- quadra/utils/tests/fixtures/__init__.py +1 -0
- quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
- quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
- quadra/utils/tests/fixtures/dataset/classification.py +406 -0
- quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
- quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
- quadra/utils/tests/fixtures/models/__init__.py +3 -0
- quadra/utils/tests/fixtures/models/anomaly.py +89 -0
- quadra/utils/tests/fixtures/models/classification.py +45 -0
- quadra/utils/tests/fixtures/models/segmentation.py +33 -0
- quadra/utils/tests/helpers.py +70 -0
- quadra/utils/tests/models.py +27 -0
- quadra/utils/utils.py +525 -0
- quadra/utils/validator.py +115 -0
- quadra/utils/visualization.py +422 -0
- quadra/utils/vit_explainability.py +349 -0
- quadra-2.2.7.dist-info/LICENSE +201 -0
- quadra-2.2.7.dist-info/METADATA +381 -0
- quadra-2.2.7.dist-info/RECORD +300 -0
- {quadra-0.0.1.dist-info → quadra-2.2.7.dist-info}/WHEEL +1 -1
- quadra-2.2.7.dist-info/entry_points.txt +3 -0
- quadra-0.0.1.dist-info/METADATA +0 -14
- quadra-0.0.1.dist-info/RECORD +0 -4
quadra/tasks/segmentation.py

@@ -0,0 +1,389 @@
from __future__ import annotations

import json
import os
import typing
from typing import Any, Generic

import cv2
import hydra
import torch
from omegaconf import DictConfig, OmegaConf
from torch.utils.data import DataLoader

from quadra.callbacks.mlflow import get_mlflow_logger
from quadra.datamodules import SegmentationDataModule, SegmentationMulticlassDataModule
from quadra.models.base import ModelSignatureWrapper
from quadra.models.evaluation import BaseEvaluationModel
from quadra.modules.base import SegmentationModel
from quadra.tasks.base import Evaluation, LightningTask
from quadra.utils import utils
from quadra.utils.evaluation import automatic_datamodule_batch_size, create_mask_report
from quadra.utils.export import export_model

log = utils.get_logger(__name__)

SegmentationDataModuleT = typing.TypeVar(
    "SegmentationDataModuleT", SegmentationDataModule, SegmentationMulticlassDataModule
)


class Segmentation(Generic[SegmentationDataModuleT], LightningTask[SegmentationDataModuleT]):
    """Task for segmentation.

    Args:
        config: Config object
        num_viz_samples: Number of samples to visualize. Defaults to 5.
        checkpoint_path: Path to the checkpoint to load the model from. Defaults to None.
        run_test: If True, run test after training. Defaults to False.
        evaluate: Dict with evaluation parameters. Defaults to None.
        report: If True, create report after training. Defaults to False.
    """

    def __init__(
        self,
        config: DictConfig,
        num_viz_samples: int = 5,
        checkpoint_path: str | None = None,
        run_test: bool = False,
        evaluate: DictConfig | None = None,
        report: bool = False,
    ):
        super().__init__(
            config=config,
            checkpoint_path=checkpoint_path,
            run_test=run_test,
            report=report,
        )
        self.evaluate = evaluate
        self.num_viz_samples = num_viz_samples
        self.export_folder: str = "deployment_model"
        self.exported_model_path: str | None = None
        if self.evaluate and any(self.evaluate.values()):
            if (
                self.config.export is None
                or len(self.config.export.types) == 0
                or "torchscript" not in self.config.export.types
            ):
                log.info(
                    "Evaluation is enabled, but training does not export a deployment model. Automatically export the "
                    "model as torchscript."
                )
                if self.config.export is None:
                    self.config.export = DictConfig({"types": ["torchscript"]})
                else:
                    self.config.export.types.append("torchscript")

            if not self.report:
                log.info("Evaluation is enabled, but reporting is disabled. Enabling reporting automatically.")
                self.report = True

    @property
    def module(self) -> SegmentationModel:
        """Get the module."""
        return self._module

    @module.setter
    def module(self, module_config) -> None:
        """Set the module."""
        log.info("Instantiating model <%s>", module_config.model["_target_"])

        if isinstance(self.datamodule, SegmentationMulticlassDataModule) and module_config.model.num_classes != (
            len(self.datamodule.idx_to_class) + 1
        ):
            log.warning(
                f"Number of classes in the model ({module_config.model.num_classes}) does not match the number of "
                + f"classes in the datamodule ({len(self.datamodule.idx_to_class)}). Updating the model..."
            )
            module_config.model.num_classes = len(self.datamodule.idx_to_class) + 1

        model = hydra.utils.instantiate(module_config.model)
        model = ModelSignatureWrapper(model)
        log.info("Instantiating optimizer <%s>", self.config.optimizer["_target_"])
        param_list = []
        for param in model.parameters():
            if param.requires_grad:
                param_list.append(param)
        optimizer = hydra.utils.instantiate(self.config.optimizer, param_list)
        log.info("Instantiating scheduler <%s>", self.config.scheduler["_target_"])
        scheduler = hydra.utils.instantiate(self.config.scheduler, optimizer=optimizer)
        log.info("Instantiating module <%s>", module_config.module["_target_"])
        module = hydra.utils.instantiate(module_config.module, model=model, optimizer=optimizer, lr_scheduler=scheduler)
        if self.checkpoint_path is not None:
            module.__class__.load_from_checkpoint(
                self.checkpoint_path, model=model, optimizer=optimizer, lr_scheduler=scheduler
            )
        self._module = module

    def prepare(self) -> None:
        """Prepare the task."""
        super().prepare()
        self.module = self.config.model

    def export(self) -> None:
        """Generate a deployment model for the task."""
        log.info("Exporting model ready for deployment")

        # Get best model!
        if (
            self.trainer.checkpoint_callback is not None
            and hasattr(self.trainer.checkpoint_callback, "best_model_path")
            and self.trainer.checkpoint_callback.best_model_path is not None
            and len(self.trainer.checkpoint_callback.best_model_path) > 0
        ):
            best_model_path = self.trainer.checkpoint_callback.best_model_path
            log.info("Loaded best model from %s", best_model_path)

            module = self.module.__class__.load_from_checkpoint(
                best_model_path,
                model=self.module.model,
                loss_fun=None,
                optimizer=self.module.optimizer,
                lr_scheduler=self.module.schedulers,
            )
        else:
            log.warning("No checkpoint callback found in the trainer, exporting the last model weights")
            module = self.module

        if "idx_to_class" not in self.config.datamodule:
            log.info("No idx_to_class key")
            idx_to_class = {0: "good", 1: "bad"}  # TODO: Why is this the default value?
        else:
            log.info("idx_to_class is present")
            idx_to_class = self.config.datamodule.idx_to_class

        if self.config.export is None:
            raise ValueError(
                "No export type specified. This should not happen, please check if you have set "
                "the export_type or assign it to a default value."
            )

        half_precision = "16" in self.trainer.precision

        input_shapes = self.config.export.input_shapes

        model_json, export_paths = export_model(
            config=self.config,
            model=module.model,
            export_folder=self.export_folder,
            half_precision=half_precision,
            input_shapes=input_shapes,
            idx_to_class=idx_to_class,
        )

        if len(export_paths) == 0:
            return

        # Pick one model for evaluation, it should be independent of the export type as the model is wrapped
        self.exported_model_path = next(iter(export_paths.values()))

        with open(os.path.join(self.export_folder, "model.json"), "w") as f:
            json.dump(model_json, f, cls=utils.HydraEncoder)

    def generate_report(self) -> None:
        """Generate a report for the task."""
        if self.evaluate is not None:
            log.info("Generating evaluation report!")
            eval_tasks: list[SegmentationEvaluation] = []
            if self.evaluate.analysis:
                if self.exported_model_path is None:
                    raise ValueError(
                        "Exported model path is not set yet but the task tries to do an analysis evaluation"
                    )
                eval_task = SegmentationAnalysisEvaluation(
                    config=self.config,
                    model_path=self.exported_model_path,
                )
                eval_tasks.append(eval_task)
            for task in eval_tasks:
                task.execute()

        if len(self.logger) > 0:
            mflow_logger = get_mlflow_logger(trainer=self.trainer)
            tensorboard_logger = utils.get_tensorboard_logger(trainer=self.trainer)

            if mflow_logger is not None and self.config.core.get("upload_artifacts"):
                for task in eval_tasks:
                    for file in task.metadata["report_files"]:
                        mflow_logger.experiment.log_artifact(
                            run_id=mflow_logger.run_id, local_path=file, artifact_path=task.report_path
                        )

            if tensorboard_logger is not None and self.config.core.get("upload_artifacts"):
                for task in eval_tasks:
                    for file in task.metadata["report_files"]:
                        ext = os.path.splitext(file)[1].lower()

                        if ext in [".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".tif", ".gif"]:
                            try:
                                img = cv2.imread(file)
                                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                            except cv2.error:
                                log.info("Could not upload artifact image %s", file)
                                continue

                            tensorboard_logger.experiment.add_image(
                                os.path.basename(file), img, 0, dataformats="HWC"
                            )
                        else:
                            utils.upload_file_tensorboard(file, tensorboard_logger)


class SegmentationEvaluation(Evaluation[SegmentationDataModuleT]):
    """Segmentation Evaluation Task with deployment models.

    Args:
        config: The experiment configuration
        model_path: The experiment path.
        device: Device to use for evaluation. If None, the device is automatically determined.

    Raises:
        ValueError: If the model path is not provided
    """

    def __init__(
        self,
        config: DictConfig,
        model_path: str,
        device: str | None = "cpu",
    ):
        super().__init__(config=config, model_path=model_path, device=device)
        self.config = config

    def save_config(self) -> None:
        """Skip saving the config."""

    def prepare(self) -> None:
        """Prepare the evaluation."""
        super().prepare()
        # TODO: Why we propagate mean and std only in Segmentation?
        self.config.transforms.mean = self.model_data["mean"]
        self.config.transforms.std = self.model_data["std"]
        # Setup datamodule
        if hasattr(self.config.datamodule, "idx_to_class"):
            idx_to_class = self.model_data["classes"]  # dict {index: class}
            self.config.datamodule.idx_to_class = idx_to_class
        self.datamodule = self.config.datamodule
        # prepare_data() must be explicitly called because there is no lightning training
        self.datamodule.prepare_data()

    @torch.no_grad()
    def inference(
        self, dataloader: DataLoader, deployment_model: BaseEvaluationModel, device: torch.device
    ) -> dict[str, torch.Tensor]:
        """Run inference on the dataloader and return the output.

        Args:
            dataloader: The dataloader to run inference on
            deployment_model: The deployment model to use
            device: The device to run inference on
        """
        image_list, mask_list, mask_pred_list, label_list = [], [], [], []
        for batch in dataloader:
            images, masks, labels = batch
            images = images.to(device)
            masks = masks.to(device)
            labels = labels.to(device)
            image_list.append(images.cpu())
            mask_list.append(masks.cpu())
            mask_pred_list.append(deployment_model(images.to(device)).cpu())
            label_list.append(labels.cpu())
        output = {
            "image": torch.cat(image_list, dim=0),
            "mask": torch.cat(mask_list, dim=0),
            "label": torch.cat(label_list, dim=0),
            "mask_pred": torch.cat(mask_pred_list, dim=0),
        }
        return output


class SegmentationAnalysisEvaluation(SegmentationEvaluation):
    """Segmentation Analysis Evaluation Task
    Args:
        config: The experiment configuration
        model_path: The model path.
        device: Device to use for evaluation. If None, the device is automatically determined.
    """

    def __init__(
        self,
        config: DictConfig,
        model_path: str,
        device: str | None = None,
    ):
        super().__init__(config=config, model_path=model_path, device=device)
        self.test_output: dict[str, Any] = {}

    def train(self) -> None:
        """Skip training."""

    def prepare(self) -> None:
        """Prepare the evaluation task."""
        super().prepare()
        self.datamodule.setup(stage="fit")
        self.datamodule.setup(stage="test")

    @automatic_datamodule_batch_size(batch_size_attribute_name="batch_size")
    def test(self) -> None:
        """Run testing."""
        log.info("Starting inference for analysis.")

        stages: list[str] = []
        dataloaders: list[torch.utils.data.DataLoader] = []

        # if self.datamodule.train_dataset_available:
        #     stages.append("train")
        #     dataloaders.append(self.datamodule.train_dataloader())
        # if self.datamodule.val_dataset_available:
        #     stages.append("val")
        #     dataloaders.append(self.datamodule.val_dataloader())

        if self.datamodule.test_dataset_available:
            stages.append("test")
            dataloaders.append(self.datamodule.test_dataloader())
        for stage, dataloader in zip(stages, dataloaders):
            log.info("Running inference on %s set with batch size: %d", stage, dataloader.batch_size)
            image_list, mask_list, mask_pred_list, label_list = [], [], [], []
            for batch in dataloader:
                images, masks, labels = batch
                images = images.to(device=self.device, dtype=self.deployment_model.model_dtype)
                if len(masks.shape) == 3:  # BxHxW -> Bx1xHxW
                    masks = masks.unsqueeze(1)
                with torch.no_grad():
                    image_list.append(images)
                    mask_list.append(masks)
                    mask_pred_list.append(self.deployment_model(images.to(self.device)))
                    label_list.append(labels)

            output = {
                "image": torch.cat(image_list, dim=0),
                "mask": torch.cat(mask_list, dim=0),
                "label": torch.cat(label_list, dim=0),
                "mask_pred": torch.cat(mask_pred_list, dim=0),
            }
            self.test_output[stage] = output

    def generate_report(self) -> None:
        """Generate a report."""
        log.info("Generating analysis report")

        for stage, output in self.test_output.items():
            image_mean = OmegaConf.to_container(self.config.transforms.mean)
            if not isinstance(image_mean, list) or any(not isinstance(x, (int, float)) for x in image_mean):
                raise ValueError("Image mean is not a list of float or integer values, please check your config")
            image_std = OmegaConf.to_container(self.config.transforms.std)
            if not isinstance(image_std, list) or any(not isinstance(x, (int, float)) for x in image_std):
                raise ValueError("Image std is not a list of float or integer values, please check your config")
            reports = create_mask_report(
                stage=stage,
                output=output,
                report_path="analysis_report",
                mean=image_mean,
                std=image_std,
                analysis=True,
                nb_samples=10,
                apply_sigmoid=True,
                show_orj_predictions=True,
            )
            self.metadata["report_files"].extend(reports)
            log.info("%s analysis report completed.", stage)