quadra 0.0.1__py3-none-any.whl → 2.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hydra_plugins/quadra_searchpath_plugin.py +30 -0
- quadra/__init__.py +6 -0
- quadra/callbacks/__init__.py +0 -0
- quadra/callbacks/anomalib.py +289 -0
- quadra/callbacks/lightning.py +501 -0
- quadra/callbacks/mlflow.py +291 -0
- quadra/callbacks/scheduler.py +69 -0
- quadra/configs/__init__.py +0 -0
- quadra/configs/backbone/caformer_m36.yaml +8 -0
- quadra/configs/backbone/caformer_s36.yaml +8 -0
- quadra/configs/backbone/convnextv2_base.yaml +8 -0
- quadra/configs/backbone/convnextv2_femto.yaml +8 -0
- quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
- quadra/configs/backbone/dino_vitb8.yaml +12 -0
- quadra/configs/backbone/dino_vits8.yaml +12 -0
- quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
- quadra/configs/backbone/dinov2_vits14.yaml +12 -0
- quadra/configs/backbone/efficientnet_b0.yaml +8 -0
- quadra/configs/backbone/efficientnet_b1.yaml +8 -0
- quadra/configs/backbone/efficientnet_b2.yaml +8 -0
- quadra/configs/backbone/efficientnet_b3.yaml +8 -0
- quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
- quadra/configs/backbone/levit_128s.yaml +8 -0
- quadra/configs/backbone/mnasnet0_5.yaml +9 -0
- quadra/configs/backbone/resnet101.yaml +8 -0
- quadra/configs/backbone/resnet18.yaml +8 -0
- quadra/configs/backbone/resnet18_ssl.yaml +8 -0
- quadra/configs/backbone/resnet50.yaml +8 -0
- quadra/configs/backbone/smp.yaml +9 -0
- quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
- quadra/configs/backbone/unetr.yaml +15 -0
- quadra/configs/backbone/vit16_base.yaml +9 -0
- quadra/configs/backbone/vit16_small.yaml +9 -0
- quadra/configs/backbone/vit16_tiny.yaml +9 -0
- quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
- quadra/configs/callbacks/all.yaml +45 -0
- quadra/configs/callbacks/default.yaml +34 -0
- quadra/configs/callbacks/default_anomalib.yaml +64 -0
- quadra/configs/config.yaml +33 -0
- quadra/configs/core/default.yaml +11 -0
- quadra/configs/datamodule/base/anomaly.yaml +16 -0
- quadra/configs/datamodule/base/classification.yaml +21 -0
- quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
- quadra/configs/datamodule/base/segmentation.yaml +18 -0
- quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
- quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
- quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
- quadra/configs/datamodule/base/ssl.yaml +21 -0
- quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
- quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
- quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
- quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
- quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
- quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
- quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
- quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
- quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
- quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
- quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/base/classification/classification.yaml +73 -0
- quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
- quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
- quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
- quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
- quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
- quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
- quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
- quadra/configs/experiment/base/ssl/byol.yaml +43 -0
- quadra/configs/experiment/base/ssl/dino.yaml +46 -0
- quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
- quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
- quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
- quadra/configs/experiment/custom/cls.yaml +12 -0
- quadra/configs/experiment/default.yaml +15 -0
- quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
- quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
- quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
- quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
- quadra/configs/export/default.yaml +13 -0
- quadra/configs/hydra/anomaly_custom.yaml +15 -0
- quadra/configs/hydra/default.yaml +14 -0
- quadra/configs/inference/default.yaml +26 -0
- quadra/configs/logger/comet.yaml +10 -0
- quadra/configs/logger/csv.yaml +5 -0
- quadra/configs/logger/mlflow.yaml +12 -0
- quadra/configs/logger/tensorboard.yaml +8 -0
- quadra/configs/loss/asl.yaml +7 -0
- quadra/configs/loss/barlow.yaml +2 -0
- quadra/configs/loss/bce.yaml +1 -0
- quadra/configs/loss/byol.yaml +1 -0
- quadra/configs/loss/cross_entropy.yaml +1 -0
- quadra/configs/loss/dino.yaml +8 -0
- quadra/configs/loss/simclr.yaml +2 -0
- quadra/configs/loss/simsiam.yaml +1 -0
- quadra/configs/loss/smp_ce.yaml +3 -0
- quadra/configs/loss/smp_dice.yaml +2 -0
- quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
- quadra/configs/loss/smp_mcc.yaml +2 -0
- quadra/configs/loss/vicreg.yaml +5 -0
- quadra/configs/model/anomalib/cfa.yaml +35 -0
- quadra/configs/model/anomalib/cflow.yaml +30 -0
- quadra/configs/model/anomalib/csflow.yaml +34 -0
- quadra/configs/model/anomalib/dfm.yaml +19 -0
- quadra/configs/model/anomalib/draem.yaml +29 -0
- quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
- quadra/configs/model/anomalib/fastflow.yaml +32 -0
- quadra/configs/model/anomalib/padim.yaml +32 -0
- quadra/configs/model/anomalib/patchcore.yaml +36 -0
- quadra/configs/model/barlow.yaml +16 -0
- quadra/configs/model/byol.yaml +25 -0
- quadra/configs/model/classification.yaml +10 -0
- quadra/configs/model/dino.yaml +26 -0
- quadra/configs/model/logistic_regression.yaml +4 -0
- quadra/configs/model/multilabel_classification.yaml +9 -0
- quadra/configs/model/simclr.yaml +18 -0
- quadra/configs/model/simsiam.yaml +24 -0
- quadra/configs/model/smp.yaml +4 -0
- quadra/configs/model/smp_multiclass.yaml +4 -0
- quadra/configs/model/vicreg.yaml +16 -0
- quadra/configs/optimizer/adam.yaml +5 -0
- quadra/configs/optimizer/adamw.yaml +3 -0
- quadra/configs/optimizer/default.yaml +4 -0
- quadra/configs/optimizer/lars.yaml +8 -0
- quadra/configs/optimizer/sgd.yaml +4 -0
- quadra/configs/scheduler/default.yaml +5 -0
- quadra/configs/scheduler/rop.yaml +5 -0
- quadra/configs/scheduler/step.yaml +3 -0
- quadra/configs/scheduler/warmrestart.yaml +2 -0
- quadra/configs/scheduler/warmup.yaml +6 -0
- quadra/configs/task/anomalib/cfa.yaml +5 -0
- quadra/configs/task/anomalib/cflow.yaml +5 -0
- quadra/configs/task/anomalib/csflow.yaml +5 -0
- quadra/configs/task/anomalib/draem.yaml +5 -0
- quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
- quadra/configs/task/anomalib/fastflow.yaml +5 -0
- quadra/configs/task/anomalib/inference.yaml +3 -0
- quadra/configs/task/anomalib/padim.yaml +5 -0
- quadra/configs/task/anomalib/patchcore.yaml +5 -0
- quadra/configs/task/classification.yaml +6 -0
- quadra/configs/task/classification_evaluation.yaml +6 -0
- quadra/configs/task/default.yaml +1 -0
- quadra/configs/task/segmentation.yaml +9 -0
- quadra/configs/task/segmentation_evaluation.yaml +3 -0
- quadra/configs/task/sklearn_classification.yaml +13 -0
- quadra/configs/task/sklearn_classification_patch.yaml +11 -0
- quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
- quadra/configs/task/sklearn_classification_test.yaml +8 -0
- quadra/configs/task/ssl.yaml +2 -0
- quadra/configs/trainer/lightning_cpu.yaml +36 -0
- quadra/configs/trainer/lightning_gpu.yaml +35 -0
- quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
- quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
- quadra/configs/trainer/lightning_multigpu.yaml +37 -0
- quadra/configs/trainer/sklearn_classification.yaml +7 -0
- quadra/configs/transforms/byol.yaml +47 -0
- quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
- quadra/configs/transforms/default.yaml +37 -0
- quadra/configs/transforms/default_numpy.yaml +24 -0
- quadra/configs/transforms/default_resize.yaml +22 -0
- quadra/configs/transforms/dino.yaml +63 -0
- quadra/configs/transforms/linear_eval.yaml +18 -0
- quadra/datamodules/__init__.py +20 -0
- quadra/datamodules/anomaly.py +180 -0
- quadra/datamodules/base.py +375 -0
- quadra/datamodules/classification.py +1003 -0
- quadra/datamodules/generic/__init__.py +0 -0
- quadra/datamodules/generic/imagenette.py +144 -0
- quadra/datamodules/generic/mnist.py +81 -0
- quadra/datamodules/generic/mvtec.py +58 -0
- quadra/datamodules/generic/oxford_pet.py +163 -0
- quadra/datamodules/patch.py +190 -0
- quadra/datamodules/segmentation.py +742 -0
- quadra/datamodules/ssl.py +140 -0
- quadra/datasets/__init__.py +17 -0
- quadra/datasets/anomaly.py +287 -0
- quadra/datasets/classification.py +241 -0
- quadra/datasets/patch.py +138 -0
- quadra/datasets/segmentation.py +239 -0
- quadra/datasets/ssl.py +110 -0
- quadra/losses/__init__.py +0 -0
- quadra/losses/classification/__init__.py +6 -0
- quadra/losses/classification/asl.py +83 -0
- quadra/losses/classification/focal.py +320 -0
- quadra/losses/classification/prototypical.py +148 -0
- quadra/losses/ssl/__init__.py +17 -0
- quadra/losses/ssl/barlowtwins.py +47 -0
- quadra/losses/ssl/byol.py +37 -0
- quadra/losses/ssl/dino.py +129 -0
- quadra/losses/ssl/hyperspherical.py +45 -0
- quadra/losses/ssl/idmm.py +50 -0
- quadra/losses/ssl/simclr.py +67 -0
- quadra/losses/ssl/simsiam.py +30 -0
- quadra/losses/ssl/vicreg.py +76 -0
- quadra/main.py +49 -0
- quadra/metrics/__init__.py +3 -0
- quadra/metrics/segmentation.py +251 -0
- quadra/models/__init__.py +0 -0
- quadra/models/base.py +151 -0
- quadra/models/classification/__init__.py +8 -0
- quadra/models/classification/backbones.py +149 -0
- quadra/models/classification/base.py +92 -0
- quadra/models/evaluation.py +322 -0
- quadra/modules/__init__.py +0 -0
- quadra/modules/backbone.py +30 -0
- quadra/modules/base.py +312 -0
- quadra/modules/classification/__init__.py +3 -0
- quadra/modules/classification/base.py +327 -0
- quadra/modules/ssl/__init__.py +17 -0
- quadra/modules/ssl/barlowtwins.py +59 -0
- quadra/modules/ssl/byol.py +172 -0
- quadra/modules/ssl/common.py +285 -0
- quadra/modules/ssl/dino.py +186 -0
- quadra/modules/ssl/hyperspherical.py +206 -0
- quadra/modules/ssl/idmm.py +98 -0
- quadra/modules/ssl/simclr.py +73 -0
- quadra/modules/ssl/simsiam.py +68 -0
- quadra/modules/ssl/vicreg.py +67 -0
- quadra/optimizers/__init__.py +4 -0
- quadra/optimizers/lars.py +153 -0
- quadra/optimizers/sam.py +127 -0
- quadra/schedulers/__init__.py +3 -0
- quadra/schedulers/base.py +44 -0
- quadra/schedulers/warmup.py +127 -0
- quadra/tasks/__init__.py +24 -0
- quadra/tasks/anomaly.py +582 -0
- quadra/tasks/base.py +397 -0
- quadra/tasks/classification.py +1263 -0
- quadra/tasks/patch.py +492 -0
- quadra/tasks/segmentation.py +389 -0
- quadra/tasks/ssl.py +560 -0
- quadra/trainers/README.md +3 -0
- quadra/trainers/__init__.py +0 -0
- quadra/trainers/classification.py +179 -0
- quadra/utils/__init__.py +0 -0
- quadra/utils/anomaly.py +112 -0
- quadra/utils/classification.py +618 -0
- quadra/utils/deprecation.py +31 -0
- quadra/utils/evaluation.py +474 -0
- quadra/utils/export.py +585 -0
- quadra/utils/imaging.py +32 -0
- quadra/utils/logger.py +15 -0
- quadra/utils/mlflow.py +98 -0
- quadra/utils/model_manager.py +320 -0
- quadra/utils/models.py +523 -0
- quadra/utils/patch/__init__.py +15 -0
- quadra/utils/patch/dataset.py +1433 -0
- quadra/utils/patch/metrics.py +449 -0
- quadra/utils/patch/model.py +153 -0
- quadra/utils/patch/visualization.py +217 -0
- quadra/utils/resolver.py +42 -0
- quadra/utils/segmentation.py +31 -0
- quadra/utils/tests/__init__.py +0 -0
- quadra/utils/tests/fixtures/__init__.py +1 -0
- quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
- quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
- quadra/utils/tests/fixtures/dataset/classification.py +406 -0
- quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
- quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
- quadra/utils/tests/fixtures/models/__init__.py +3 -0
- quadra/utils/tests/fixtures/models/anomaly.py +89 -0
- quadra/utils/tests/fixtures/models/classification.py +45 -0
- quadra/utils/tests/fixtures/models/segmentation.py +33 -0
- quadra/utils/tests/helpers.py +70 -0
- quadra/utils/tests/models.py +27 -0
- quadra/utils/utils.py +525 -0
- quadra/utils/validator.py +115 -0
- quadra/utils/visualization.py +422 -0
- quadra/utils/vit_explainability.py +349 -0
- quadra-2.2.7.dist-info/LICENSE +201 -0
- quadra-2.2.7.dist-info/METADATA +381 -0
- quadra-2.2.7.dist-info/RECORD +300 -0
- {quadra-0.0.1.dist-info → quadra-2.2.7.dist-info}/WHEEL +1 -1
- quadra-2.2.7.dist-info/entry_points.txt +3 -0
- quadra-0.0.1.dist-info/METADATA +0 -14
- quadra-0.0.1.dist-info/RECORD +0 -4
|
@@ -0,0 +1,322 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from typing import Any, cast
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import torch
|
|
8
|
+
from hydra.utils import instantiate
|
|
9
|
+
from omegaconf import DictConfig, OmegaConf
|
|
10
|
+
from torch import nn
|
|
11
|
+
from torch.jit import RecursiveScriptModule
|
|
12
|
+
|
|
13
|
+
from quadra.utils.logger import get_logger
|
|
14
|
+
|
|
15
|
+
try:
|
|
16
|
+
import onnxruntime as ort # noqa
|
|
17
|
+
|
|
18
|
+
ONNX_AVAILABLE = True
|
|
19
|
+
except ImportError:
|
|
20
|
+
ONNX_AVAILABLE = False
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# Module-level logger for this file, obtained through quadra's ``get_logger`` helper.
log = get_logger(__name__)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class BaseEvaluationModel(ABC):
    """Base interface for all evaluation models.

    Subclasses wrap an already trained model and expose a small torch-like API
    (``__call__``, ``to``, ``eval``, ``half``, ``cpu``, ``training``, ``device``).

    Args:
        config: Configuration object, stored as ``self.config`` for subclasses to read.
    """

    def __init__(self, config: DictConfig) -> None:
        # Declared here, assigned by subclasses (typically in ``load_from_disk``).
        self.model: Any
        self.device: str
        self.model_dtype: np.dtype | torch.dtype
        # Explicitly initialised so that reading it before ``load_from_disk`` yields None
        # instead of raising AttributeError.
        self.model_path: str | None = None
        self.config = config
        self.is_loaded = False

    @abstractmethod
    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        pass

    @abstractmethod
    def load_from_disk(self, model_path: str, device: str = "cpu"):
        """Load model from disk."""

    @abstractmethod
    def to(self, device: str):
        """Move model to device."""

    @abstractmethod
    def eval(self):
        """Set model to evaluation mode."""

    @abstractmethod
    def half(self):
        """Convert model to half precision."""

    @abstractmethod
    def cpu(self):
        """Move model to cpu."""

    @property
    def training(self) -> bool:
        """Return whether model is in training mode (always False for the base class)."""
        return False

    @property
    def device(self) -> str:
        """Return the device of the model."""
        return self._device

    @device.setter
    def device(self, device: str):
        """Set the device of the model, normalising a bare ``"cuda"`` to ``"cuda:0"``."""
        # A bare "cuda" carries no index; code that splits the device on ":" needs one.
        if device == "cuda":
            device = "cuda:0"

        self._device = device
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class TorchscriptEvaluationModel(BaseEvaluationModel):
    """Wrapper for torchscript models."""

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Forward all arguments to the wrapped torchscript module."""
        return self.model(*args, **kwargs)

    def load_from_disk(self, model_path: str, device: str = "cpu"):
        """Load a torchscript model from disk, put it in eval mode and move it to ``device``.

        Args:
            model_path: Path of the serialized torchscript model.
            device: Target device, e.g. ``"cpu"`` or ``"cuda:0"``.

        Raises:
            ValueError: If the model parameters use more than one dtype.
        """
        self.model_path = model_path
        self.device = device

        # map_location lets a model saved on a GPU be deserialized on hosts where that GPU is absent.
        model = cast(RecursiveScriptModule, torch.jit.load(self.model_path, map_location=self.device))
        model.eval()
        model.to(self.device)

        parameter_types = {param.dtype for param in model.parameters()}
        if len(parameter_types) > 1:
            # TODO: There could be models with mixed precision?
            raise ValueError(f"Expected only one type of parameters, found {parameter_types}")

        if parameter_types:
            self.model_dtype = next(iter(parameter_types))
        else:
            # A parameter-free model gives no dtype hint; fall back to torch's default floating dtype.
            default_dtype = torch.get_default_dtype()
            log.warning(f"Model {self.model_path} has no parameters, assuming dtype {default_dtype}")
            self.model_dtype = default_dtype

        self.model = model
        self.is_loaded = True

    def to(self, device: str):
        """Move model to device."""
        self.model.to(device)
        self.device = device

    def eval(self):
        """Set model to evaluation mode."""
        self.model.eval()

    @property
    def training(self) -> bool:
        """Return whether the wrapped model is in training mode."""
        return self.model.training

    def half(self):
        """Convert model to half precision."""
        self.model.half()

    def cpu(self):
        """Move model to cpu."""
        self.model.cpu()
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
class TorchEvaluationModel(TorchscriptEvaluationModel):
    """Wrapper for plain torch models whose weights are stored as a state dict.

    Args:
        config: Configuration object forwarded to the base class.
        model_architecture: Instantiated torch module; ``load_from_disk`` loads weights into it.
    """

    def __init__(self, config: DictConfig, model_architecture: nn.Module) -> None:
        super().__init__(config=config)
        self.model = model_architecture
        self.model.eval()
        # Parameter-free modules have no parameter to read a device from; treat them as cpu.
        first_parameter = next(self.model.parameters(), None)
        self.device = str(first_parameter.device) if first_parameter is not None else "cpu"

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Forward all arguments to the wrapped torch module."""
        return self.model(*args, **kwargs)

    def load_from_disk(self, model_path: str, device: str = "cpu"):
        """Load a state dict from disk into the wrapped module and move it to ``device``.

        Args:
            model_path: Path of the saved state dict.
            device: Target device, e.g. ``"cpu"`` or ``"cuda:0"``.

        Raises:
            ValueError: If the model parameters use more than one dtype.
        """
        self.model_path = model_path
        self.device = device
        # Deserialize on cpu so checkpoints saved from a GPU can be read on any host;
        # load_state_dict then copies the values into the module's existing tensors.
        state_dict = torch.load(self.model_path, map_location="cpu")
        self.model.load_state_dict(state_dict)
        self.model.eval()
        self.model.to(self.device)

        parameter_types = {param.dtype for param in self.model.parameters()}
        if len(parameter_types) > 1:
            # TODO: There could be models with mixed precision?
            raise ValueError(f"Expected only one type of parameters, found {parameter_types}")

        if parameter_types:
            self.model_dtype = next(iter(parameter_types))
        else:
            # A parameter-free model gives no dtype hint; fall back to torch's default floating dtype.
            default_dtype = torch.get_default_dtype()
            log.warning(f"Model {self.model_path} has no parameters, assuming dtype {default_dtype}")
            self.model_dtype = default_dtype

        self.is_loaded = True
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
# Maps the tensor type strings reported by onnxruntime (e.g. ``session.get_inputs()[0].type``)
# to torch dtypes. onnxruntime spells float64 as "tensor(double)"; the "tensor(float32)" and
# "tensor(float64)" spellings are kept for backward compatibility.
onnx_to_torch_dtype_dict = {
    "tensor(bool)": torch.bool,
    "tensor(uint8)": torch.uint8,
    "tensor(int8)": torch.int8,
    "tensor(int16)": torch.int16,
    "tensor(int32)": torch.int32,
    "tensor(int64)": torch.int64,
    "tensor(float16)": torch.float16,
    "tensor(bfloat16)": torch.bfloat16,
    "tensor(float32)": torch.float32,
    "tensor(float)": torch.float32,
    "tensor(float64)": torch.float64,
    "tensor(double)": torch.float64,
    "tensor(complex64)": torch.complex64,
    "tensor(complex128)": torch.complex128,
}
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
class ONNXEvaluationModel(BaseEvaluationModel):
    """Wrapper for ONNX models. It's designed to provide a similar interface to standard torch models."""

    def __init__(self, config: DictConfig) -> None:
        if not ONNX_AVAILABLE:
            raise ImportError(
                "onnxruntime is not installed. Please install ONNX capabilities for quadra with: poetry install -E onnx"
            )
        super().__init__(config=config)
        self.session_options = self.generate_session_options()

    def generate_session_options(self) -> ort.SessionOptions:
        """Generate onnxruntime session options from the ``session_options`` entry of the config.

        Each key is set as an attribute of ``ort.SessionOptions``; dict values carrying a
        ``_target_`` key are instantiated with hydra first.
        """
        session_options = ort.SessionOptions()

        if hasattr(self.config, "session_options") and self.config.session_options is not None:
            session_options_dict = cast(
                dict[str, Any], OmegaConf.to_container(self.config.session_options, resolve=True)
            )
            for key, value in session_options_dict.items():
                final_value = value
                if isinstance(value, dict) and "_target_" in value:
                    final_value = instantiate(final_value)

                setattr(session_options, key, final_value)

        return session_options

    def __call__(self, *inputs: np.ndarray | torch.Tensor) -> Any:
        """Run inference on the model and return the output as torch tensors.

        Inputs are matched positionally to ``self.model.get_inputs()`` and must be either all
        torch tensors (bound through onnxruntime IO binding) or all numpy arrays.

        Returns:
            A single tensor when the model has one output, otherwise a list of tensors,
            all moved to ``self.device``.

        Raises:
            ValueError: If an input is neither a tensor nor an array, or torch and numpy inputs are mixed.
        """
        # TODO: Maybe we can support also kwargs
        onnx_inputs: dict[str, np.ndarray | torch.Tensor] = {}
        has_torch_input = False
        has_numpy_input = False

        for onnx_input, current_input in zip(self.model.get_inputs(), inputs):
            if isinstance(current_input, torch.Tensor):
                has_torch_input = True
            elif isinstance(current_input, np.ndarray):
                has_numpy_input = True
            else:
                raise ValueError(f"Invalid input type: {type(current_input)}")
            onnx_inputs[onnx_input.name] = current_input

        # Checked after the loop so that mixing is detected regardless of the input order.
        if has_torch_input and has_numpy_input:
            raise ValueError("Cannot mix torch and numpy inputs")

        if has_torch_input:
            onnx_output = self._forward_from_pytorch(cast(dict[str, torch.Tensor], onnx_inputs))
        else:
            onnx_output = self._forward_from_numpy(cast(dict[str, np.ndarray], onnx_inputs))

        onnx_output = [torch.from_numpy(x).to(self.device) if isinstance(x, np.ndarray) else x for x in onnx_output]

        if len(onnx_output) == 1:
            onnx_output = onnx_output[0]

        return onnx_output

    @staticmethod
    def _numpy_element_type(dtype: torch.dtype) -> Any:
        """Return the numpy scalar type matching ``dtype``, as required by ``bind_input``."""
        try:
            return torch.empty(0, dtype=dtype).numpy().dtype.type
        except (TypeError, RuntimeError):
            # Dtypes without a numpy counterpart (e.g. bfloat16) keep the historical float32 fallback.
            return np.float32

    def _forward_from_pytorch(self, input_dict: dict[str, torch.Tensor]):
        """Run inference through onnxruntime IO binding on torch tensors.

        Returns:
            The model outputs as returned by ``io_binding.copy_outputs_to_cpu()``.
        """
        io_binding = self.model.io_binding()
        device_type, _, device_index = self.device.partition(":")
        # Weirdly enough onnx wants 0 for cpu
        device_id = 0 if device_type == "cpu" or not device_index else int(device_index)
        target_device = torch.device(self.device)
        # IO binding only records raw pointers: every tensor we bind must stay alive until the run
        # completes, otherwise temporaries created below (device/contiguous copies) could be freed early.
        bound_tensors: list[torch.Tensor] = []

        for name, value in input_dict.items():
            tensor = value
            if tensor.device != target_device:
                # The buffer must live on the device the binding declares.
                tensor = tensor.to(target_device)
            if not tensor.is_contiguous():
                # If not contiguous onnx give wrong results
                tensor = tensor.contiguous()
            bound_tensors.append(tensor)

            io_binding.bind_input(
                name=name,
                device_type=device_type,
                device_id=device_id,
                element_type=self._numpy_element_type(tensor.dtype),
                shape=tuple(tensor.shape),
                buffer_ptr=tensor.data_ptr(),
            )

        for x in self.model.get_outputs():
            # TODO: Is it possible to also bind the output? We require info about output dimensions
            io_binding.bind_output(name=x.name)

        self.model.run_with_iobinding(io_binding)

        output = io_binding.copy_outputs_to_cpu()
        # Inputs are no longer needed once the run has completed.
        del bound_tensors

        return output

    def _forward_from_numpy(self, input_dict: dict[str, np.ndarray]):
        """Run inference on the model and return the output as numpy array."""
        ort_outputs = [x.name for x in self.model.get_outputs()]

        onnx_output = self.model.run(ort_outputs, input_dict)

        return onnx_output

    def load_from_disk(self, model_path: str, device: str = "cpu"):
        """Create an onnxruntime inference session for ``model_path`` on ``device``."""
        self.model_path = model_path
        self.device = device

        # Use the normalised device (e.g. "cuda" -> "cuda:0") for provider selection.
        ort_providers = self._get_providers(self.device)
        self.model = ort.InferenceSession(self.model_path, providers=ort_providers, sess_options=self.session_options)
        self.model_dtype = self.cast_onnx_dtype(self.model.get_inputs()[0].type)
        self.is_loaded = True

    def _get_providers(self, device: str) -> list[tuple[str, dict[str, Any]] | str]:
        """Return the onnxruntime execution providers for ``device``.

        A ``cpu`` device maps to ``CPUExecutionProvider``; any other device maps to
        ``CUDAExecutionProvider`` with the index after ``":"`` as ``device_id`` (0 if absent).
        """
        device_type, _, device_index = device.partition(":")
        if device_type == "cpu":
            return ["CPUExecutionProvider"]

        return [
            (
                "CUDAExecutionProvider",
                {
                    "device_id": int(device_index) if device_index else 0,
                },
            )
        ]

    def to(self, device: str):
        """Move model to device by switching the session execution providers."""
        self.device = device
        ort_providers = self._get_providers(self.device)
        self.model.set_providers(ort_providers)

    def eval(self):
        """Fake interface to match torch models."""
        return self

    def half(self):
        """Convert model to half precision."""
        raise NotImplementedError("At the moment ONNX models do not support half method.")

    def cpu(self):
        """Move model to cpu."""
        self.to("cpu")

    def cast_onnx_dtype(self, onnx_dtype: str) -> torch.dtype | np.dtype:
        """Cast an onnxruntime type string (e.g. ``"tensor(float)"``) to a torch dtype.

        Raises:
            KeyError: If the type string is not in ``onnx_to_torch_dtype_dict``.
        """
        return onnx_to_torch_dtype_dict[onnx_dtype]
|
|
File without changes
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
import segmentation_models_pytorch as smp
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def create_smp_backbone(
    arch: str,
    encoder_name: str,
    freeze_encoder: bool = False,
    in_channels: int = 3,
    num_classes: int = 0,
    **kwargs: Any,
):
    """Create a segmentation_models_pytorch model.

    Args:
        arch: smp architecture name, forwarded to ``smp.create_model`` as ``arch``.
        encoder_name: encoder (backbone) name, forwarded as ``encoder_name``.
        freeze_encoder: if True, disable gradients for every encoder parameter.
        in_channels: number of input channels.
        num_classes: number of classes, forwarded as ``classes``.
        **kwargs: extra arguments for model (for example classification head).

    Returns:
        The model returned by ``smp.create_model``.
    """
    model = smp.create_model(
        arch=arch, encoder_name=encoder_name, in_channels=in_channels, classes=num_classes, **kwargs
    )
    if freeze_encoder:
        # requires_grad_ recurses over all encoder parameters, including any registered directly
        # on the encoder module itself, which iterating over ``children()`` would miss.
        model.encoder.requires_grad_(False)
    return model
|
quadra/modules/base.py
ADDED
|
@@ -0,0 +1,312 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Callable
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
import pytorch_lightning as pl
|
|
7
|
+
import sklearn
|
|
8
|
+
import torch
|
|
9
|
+
import torchmetrics
|
|
10
|
+
from sklearn.linear_model import LogisticRegression
|
|
11
|
+
from torch import nn
|
|
12
|
+
from torch.optim import Optimizer
|
|
13
|
+
|
|
14
|
+
from quadra.models.base import ModelSignatureWrapper
|
|
15
|
+
|
|
16
|
+
# Public API exported by this module.
__all__ = ["BaseLightningModule", "SSLModule"]
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class BaseLightningModule(pl.LightningModule):
    """Lightning module shared by quadra tasks.

    The network is stored as ``self.model`` wrapped in a ``ModelSignatureWrapper``. Missing
    optimizer or scheduler are replaced by defaults when ``configure_optimizers`` runs.

    Args:
        model: Network Module used for extract features
        optimizer: optimizer of the training. If None a default Adam is used.
        lr_scheduler: lr scheduler. If None a default ReduceLROnPlateau is used.
        lr_scheduler_interval: value passed to Lightning as the scheduler ``interval``.
    """

    def __init__(
        self,
        model: nn.Module,
        optimizer: Optimizer | None = None,
        lr_scheduler: object | None = None,
        lr_scheduler_interval: str | None = "epoch",
    ):
        super().__init__()
        wrapped_model = ModelSignatureWrapper(model)
        self.model = wrapped_model
        self.optimizer = optimizer
        self.schedulers = lr_scheduler
        self.lr_scheduler_interval = lr_scheduler_interval

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the wrapped model.

        Args:
            x: input tensor.

        Returns:
            The output of ``self.model`` for ``x``.
        """
        prediction = self.model(x)
        return prediction

    def configure_optimizers(self) -> tuple[list[Any], list[dict[str, Any]]]:
        """Return the optimizer and scheduler configuration, creating defaults when missing.

        Returns:
            A one-element list with the optimizer and a one-element list with the Lightning
            scheduler configuration dict.
        """
        current_optimizer = getattr(self, "optimizer", None)
        if current_optimizer is None or not current_optimizer:
            # Fallback optimizer when none was supplied
            self.optimizer = torch.optim.Adam(self.model.parameters(), lr=3e-4)

        current_scheduler = getattr(self, "schedulers", None)
        if current_scheduler is None or not current_scheduler:
            # Fallback scheduler bound to the (possibly just created) optimizer
            self.schedulers = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, factor=0.5, patience=30)

        optimizers = [self.optimizer]
        scheduler_configs = [
            {
                "scheduler": self.schedulers,
                "interval": self.lr_scheduler_interval,
                "monitor": "val_loss",
                "strict": False,
            }
        ]
        return optimizers, scheduler_configs

    # pylint: disable=unused-argument
    def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx: int = 0):
        """Reset gradients by setting them to ``None`` rather than zero-filling them."""
        optimizer.zero_grad(set_to_none=True)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
class SSLModule(BaseLightningModule):
    """Base module for self supervised learning.

    Validation works in one of two modes:

    - if a classifier training loader is available, ``classifier`` is fitted on the embeddings of
      that loader at the start of validation, and ``val_acc`` is logged on the validation batches;
    - otherwise the SSL ``criterion`` between the two views of each validation sample is logged
      as ``val_loss``.

    Args:
        model: Network module used to extract features.
        criterion: SSL loss to be applied.
        classifier: Standard sklearn classifier. If None, a ``LogisticRegression`` is used.
        optimizer: optimizer of the training. If None a default Adam is used.
        lr_scheduler: lr scheduler. If None a default ReduceLROnPlateau is used.
        lr_scheduler_interval: scheduler interval forwarded to the base module.
    """

    def __init__(
        self,
        model: nn.Module,
        criterion: nn.Module,
        classifier: sklearn.base.ClassifierMixin | None = None,
        optimizer: Optimizer | None = None,
        lr_scheduler: object | None = None,
        lr_scheduler_interval: str | None = "epoch",
    ):
        super().__init__(model, optimizer, lr_scheduler, lr_scheduler_interval)
        self.criterion = criterion
        # Declared but deliberately NOT assigned.
        # on_validation_start uses the attribute's absence to decide whether to fetch the loader
        # from the datamodule, so a loader assigned from outside beforehand is kept.
        self.classifier_train_loader: torch.utils.data.DataLoader | None
        if classifier is None:
            self.classifier = LogisticRegression(max_iter=10000, n_jobs=8, random_state=42)
        else:
            self.classifier = classifier

        # NOTE(review): recent torchmetrics versions require a ``task`` argument for
        # ``Accuracy``; this call presumably relies on an older pinned version — confirm.
        self.val_acc = torchmetrics.Accuracy()

    def fit_estimator(self):
        """Fit ``self.classifier`` on embeddings extracted from ``self.classifier_train_loader``.

        The model is put in eval mode and run without gradients.
        Embeddings and targets are moved to CPU one batch at a time, then concatenated into
        numpy arrays for the sklearn ``fit`` call.
        """
        target_batches = []
        embedding_batches = []
        self.model.eval()
        with torch.no_grad():
            for images, targets in self.classifier_train_loader:
                embeddings = self.model(images.to(self.device))
                target_batches.append(targets.cpu())
                # Move each batch off the device immediately so the whole embedding set is never
                # held in device memory at once; the concatenated result is unchanged.
                embedding_batches.append(embeddings.cpu())
        all_targets = torch.cat(target_batches, dim=0).numpy()
        all_embeddings = torch.cat(embedding_batches, dim=0).numpy()
        self.classifier.fit(all_embeddings, all_targets)

    def calculate_accuracy(self, batch):
        """Return the accuracy of ``self.classifier`` on an ``(images, labels)`` batch.

        Images are embedded with ``self.model`` (no gradients), classified with the sklearn
        classifier, and scored with ``self.val_acc``.
        """
        images, labels = batch
        with torch.no_grad():
            embedding = self.model(images).cpu().numpy()

        predictions = self.classifier.predict(embedding)
        labels = labels.detach()
        return self.val_acc(torch.tensor(predictions, device=self.device), labels)

    # TODO: In multiprocessing mode, this function is called multiple times, how can we avoid this?
    def on_validation_start(self) -> None:
        """Resolve the classifier training loader (once) and refit the classifier on it.

        If ``classifier_train_loader`` was never set, it is taken from the trainer's datamodule
        ``classifier_train_dataloader()``.
        When there is no datamodule, or it does not provide that method, the loader is set to
        ``None``; validation then falls back to logging the SSL loss.
        """
        if not hasattr(self, "classifier_train_loader"):
            datamodule = getattr(self.trainer, "datamodule", None)
            get_loader = getattr(datamodule, "classifier_train_dataloader", None)
            self.classifier_train_loader = get_loader() if callable(get_loader) else None

        if self.classifier_train_loader is not None:
            self.fit_estimator()

    def validation_step(
        self, batch: tuple[tuple[torch.Tensor, torch.Tensor], torch.Tensor], batch_idx: int
    ) -> torch.Tensor | None:
        """Validation step.

        Without a classifier training loader, ``batch`` is ``((view_1, view_2), _)``.
        The criterion between the embeddings of the two views is logged and returned as
        ``val_loss``.
        Otherwise ``batch`` is ``(images, labels)``; the classifier accuracy is logged as
        ``val_acc`` and ``None`` is returned.
        """
        # pylint: disable=unused-argument
        if self.classifier_train_loader is None:
            # Compute loss
            (im_x, im_y), _ = batch
            z1 = self(im_x)
            z2 = self(im_y)
            loss = self.criterion(z1, z2)

            self.log(
                "val_loss",
                loss,
                on_epoch=True,
                on_step=True,
                logger=True,
                prog_bar=True,
            )
            return loss

        acc = self.calculate_accuracy(batch)
        self.log("val_acc", acc, on_epoch=True, on_step=False, logger=True, prog_bar=True)
        return None
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
class SegmentationModel(BaseLightningModule):
    """Generic segmentation model.

    Args:
        model: segmentation model to be used.
        loss_fun: loss function called as ``loss_fun(pred_masks, target_masks)``.
        optimizer: Optimizer to be used. Defaults to None.
        lr_scheduler: lr scheduler to be used. Defaults to None.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        loss_fun: Callable,
        optimizer: Optimizer | None = None,
        lr_scheduler: object | None = None,
    ):
        super().__init__(model, optimizer, lr_scheduler)
        self.loss_fun = loss_fun

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward method.

        Args:
            x: input tensor.

        Returns:
            model inference
        """
        return self.model(x)

    def step(self, batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the model on a batch and align predicted and target masks.

        Args:
            batch: ``(images, target_masks, _)``; the third element is ignored.

        Returns:
            Prediction and target masks, both with the same shape.

        Raises:
            AssertionError: if predicted and target masks do not have the same shape.
        """
        images, target_masks, _ = batch
        pred_masks = self(images)
        # 3-D tensors get a singleton dimension inserted at position 1.
        # Presumably this turns (B, H, W) into (B, 1, H, W) — confirm.
        if pred_masks.ndim == 3:
            pred_masks = pred_masks.unsqueeze(1)
        if target_masks.ndim == 3:
            target_masks = target_masks.unsqueeze(1)
        assert pred_masks.shape == target_masks.shape, (
            f"Predicted mask shape {tuple(pred_masks.shape)} does not match "
            f"target mask shape {tuple(target_masks.shape)}"
        )

        return pred_masks, target_masks

    def compute_loss(self, pred_masks: torch.Tensor, target_masks: torch.Tensor) -> torch.Tensor:
        """Compute loss.

        Args:
            pred_masks: predicted masks.
            target_masks: target masks.

        Returns:
            The computed loss.
        """
        return self.loss_fun(pred_masks, target_masks)

    def _step_and_log_loss(self, batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor], name: str) -> torch.Tensor:
        """Run ``step``, compute the loss, log it under ``name`` (per step and per epoch) and return it."""
        pred_masks, target_masks = self.step(batch)
        loss = self.compute_loss(pred_masks, target_masks)
        self.log_dict(
            {name: loss},
            on_step=True,
            on_epoch=True,
            prog_bar=True,
        )
        return loss

    def training_step(self, batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int):
        """Training step: compute and log ``loss``."""
        # pylint: disable=unused-argument
        return self._step_and_log_loss(batch, "loss")

    def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int):
        """Validation step: compute and log ``val_loss``."""
        # pylint: disable=unused-argument
        return self._step_and_log_loss(batch, "val_loss")

    def test_step(self, batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int):
        """Test step: compute and log ``test_loss``."""
        # pylint: disable=unused-argument
        return self._step_and_log_loss(batch, "test_loss")

    def predict_step(
        self,
        batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
        batch_idx: int,
        dataloader_idx: int | None = None,
    ) -> Any:
        """Predict step.

        Returns:
            ``(images, masks, pred_masks, labels)``, all moved to CPU.
        """
        # pylint: disable=unused-argument
        images, masks, labels = batch
        pred_masks = self(images)
        return images.cpu(), masks.cpu(), pred_masks.cpu(), labels.cpu()
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
class SegmentationModelMulticlass(SegmentationModel):
    """Generic multiclass segmentation model.

    Same training/validation/test logic as ``SegmentationModel``, except that ``step`` returns
    predictions and targets exactly as produced. There is no dimension insertion and no shape
    check between them.

    Args:
        model: segmentation model to be used.
        loss_fun: loss function to be used.
        optimizer: Optimizer to be used. Defaults to None.
        lr_scheduler: lr scheduler to be used. Defaults to None.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        loss_fun: Callable,
        optimizer: Optimizer | None = None,
        lr_scheduler: object | None = None,
    ):
        super().__init__(model=model, loss_fun=loss_fun, optimizer=optimizer, lr_scheduler=lr_scheduler)

    def step(self, batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the model on the images of ``batch``.

        Args:
            batch: ``(images, target_masks, _)``; the third element is ignored.

        Returns:
            The raw model output for the images and the untouched target masks.
        """
        images, target_masks, _unused = batch
        return self(images), target_masks
|