quadra 0.0.1__py3-none-any.whl → 2.2.7__py3-none-any.whl
This diff shows the content of publicly released package versions as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- hydra_plugins/quadra_searchpath_plugin.py +30 -0
- quadra/__init__.py +6 -0
- quadra/callbacks/__init__.py +0 -0
- quadra/callbacks/anomalib.py +289 -0
- quadra/callbacks/lightning.py +501 -0
- quadra/callbacks/mlflow.py +291 -0
- quadra/callbacks/scheduler.py +69 -0
- quadra/configs/__init__.py +0 -0
- quadra/configs/backbone/caformer_m36.yaml +8 -0
- quadra/configs/backbone/caformer_s36.yaml +8 -0
- quadra/configs/backbone/convnextv2_base.yaml +8 -0
- quadra/configs/backbone/convnextv2_femto.yaml +8 -0
- quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
- quadra/configs/backbone/dino_vitb8.yaml +12 -0
- quadra/configs/backbone/dino_vits8.yaml +12 -0
- quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
- quadra/configs/backbone/dinov2_vits14.yaml +12 -0
- quadra/configs/backbone/efficientnet_b0.yaml +8 -0
- quadra/configs/backbone/efficientnet_b1.yaml +8 -0
- quadra/configs/backbone/efficientnet_b2.yaml +8 -0
- quadra/configs/backbone/efficientnet_b3.yaml +8 -0
- quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
- quadra/configs/backbone/levit_128s.yaml +8 -0
- quadra/configs/backbone/mnasnet0_5.yaml +9 -0
- quadra/configs/backbone/resnet101.yaml +8 -0
- quadra/configs/backbone/resnet18.yaml +8 -0
- quadra/configs/backbone/resnet18_ssl.yaml +8 -0
- quadra/configs/backbone/resnet50.yaml +8 -0
- quadra/configs/backbone/smp.yaml +9 -0
- quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
- quadra/configs/backbone/unetr.yaml +15 -0
- quadra/configs/backbone/vit16_base.yaml +9 -0
- quadra/configs/backbone/vit16_small.yaml +9 -0
- quadra/configs/backbone/vit16_tiny.yaml +9 -0
- quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
- quadra/configs/callbacks/all.yaml +45 -0
- quadra/configs/callbacks/default.yaml +34 -0
- quadra/configs/callbacks/default_anomalib.yaml +64 -0
- quadra/configs/config.yaml +33 -0
- quadra/configs/core/default.yaml +11 -0
- quadra/configs/datamodule/base/anomaly.yaml +16 -0
- quadra/configs/datamodule/base/classification.yaml +21 -0
- quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
- quadra/configs/datamodule/base/segmentation.yaml +18 -0
- quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
- quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
- quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
- quadra/configs/datamodule/base/ssl.yaml +21 -0
- quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
- quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
- quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
- quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
- quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
- quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
- quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
- quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
- quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
- quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
- quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/base/classification/classification.yaml +73 -0
- quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
- quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
- quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
- quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
- quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
- quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
- quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
- quadra/configs/experiment/base/ssl/byol.yaml +43 -0
- quadra/configs/experiment/base/ssl/dino.yaml +46 -0
- quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
- quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
- quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
- quadra/configs/experiment/custom/cls.yaml +12 -0
- quadra/configs/experiment/default.yaml +15 -0
- quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
- quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
- quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
- quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
- quadra/configs/export/default.yaml +13 -0
- quadra/configs/hydra/anomaly_custom.yaml +15 -0
- quadra/configs/hydra/default.yaml +14 -0
- quadra/configs/inference/default.yaml +26 -0
- quadra/configs/logger/comet.yaml +10 -0
- quadra/configs/logger/csv.yaml +5 -0
- quadra/configs/logger/mlflow.yaml +12 -0
- quadra/configs/logger/tensorboard.yaml +8 -0
- quadra/configs/loss/asl.yaml +7 -0
- quadra/configs/loss/barlow.yaml +2 -0
- quadra/configs/loss/bce.yaml +1 -0
- quadra/configs/loss/byol.yaml +1 -0
- quadra/configs/loss/cross_entropy.yaml +1 -0
- quadra/configs/loss/dino.yaml +8 -0
- quadra/configs/loss/simclr.yaml +2 -0
- quadra/configs/loss/simsiam.yaml +1 -0
- quadra/configs/loss/smp_ce.yaml +3 -0
- quadra/configs/loss/smp_dice.yaml +2 -0
- quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
- quadra/configs/loss/smp_mcc.yaml +2 -0
- quadra/configs/loss/vicreg.yaml +5 -0
- quadra/configs/model/anomalib/cfa.yaml +35 -0
- quadra/configs/model/anomalib/cflow.yaml +30 -0
- quadra/configs/model/anomalib/csflow.yaml +34 -0
- quadra/configs/model/anomalib/dfm.yaml +19 -0
- quadra/configs/model/anomalib/draem.yaml +29 -0
- quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
- quadra/configs/model/anomalib/fastflow.yaml +32 -0
- quadra/configs/model/anomalib/padim.yaml +32 -0
- quadra/configs/model/anomalib/patchcore.yaml +36 -0
- quadra/configs/model/barlow.yaml +16 -0
- quadra/configs/model/byol.yaml +25 -0
- quadra/configs/model/classification.yaml +10 -0
- quadra/configs/model/dino.yaml +26 -0
- quadra/configs/model/logistic_regression.yaml +4 -0
- quadra/configs/model/multilabel_classification.yaml +9 -0
- quadra/configs/model/simclr.yaml +18 -0
- quadra/configs/model/simsiam.yaml +24 -0
- quadra/configs/model/smp.yaml +4 -0
- quadra/configs/model/smp_multiclass.yaml +4 -0
- quadra/configs/model/vicreg.yaml +16 -0
- quadra/configs/optimizer/adam.yaml +5 -0
- quadra/configs/optimizer/adamw.yaml +3 -0
- quadra/configs/optimizer/default.yaml +4 -0
- quadra/configs/optimizer/lars.yaml +8 -0
- quadra/configs/optimizer/sgd.yaml +4 -0
- quadra/configs/scheduler/default.yaml +5 -0
- quadra/configs/scheduler/rop.yaml +5 -0
- quadra/configs/scheduler/step.yaml +3 -0
- quadra/configs/scheduler/warmrestart.yaml +2 -0
- quadra/configs/scheduler/warmup.yaml +6 -0
- quadra/configs/task/anomalib/cfa.yaml +5 -0
- quadra/configs/task/anomalib/cflow.yaml +5 -0
- quadra/configs/task/anomalib/csflow.yaml +5 -0
- quadra/configs/task/anomalib/draem.yaml +5 -0
- quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
- quadra/configs/task/anomalib/fastflow.yaml +5 -0
- quadra/configs/task/anomalib/inference.yaml +3 -0
- quadra/configs/task/anomalib/padim.yaml +5 -0
- quadra/configs/task/anomalib/patchcore.yaml +5 -0
- quadra/configs/task/classification.yaml +6 -0
- quadra/configs/task/classification_evaluation.yaml +6 -0
- quadra/configs/task/default.yaml +1 -0
- quadra/configs/task/segmentation.yaml +9 -0
- quadra/configs/task/segmentation_evaluation.yaml +3 -0
- quadra/configs/task/sklearn_classification.yaml +13 -0
- quadra/configs/task/sklearn_classification_patch.yaml +11 -0
- quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
- quadra/configs/task/sklearn_classification_test.yaml +8 -0
- quadra/configs/task/ssl.yaml +2 -0
- quadra/configs/trainer/lightning_cpu.yaml +36 -0
- quadra/configs/trainer/lightning_gpu.yaml +35 -0
- quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
- quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
- quadra/configs/trainer/lightning_multigpu.yaml +37 -0
- quadra/configs/trainer/sklearn_classification.yaml +7 -0
- quadra/configs/transforms/byol.yaml +47 -0
- quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
- quadra/configs/transforms/default.yaml +37 -0
- quadra/configs/transforms/default_numpy.yaml +24 -0
- quadra/configs/transforms/default_resize.yaml +22 -0
- quadra/configs/transforms/dino.yaml +63 -0
- quadra/configs/transforms/linear_eval.yaml +18 -0
- quadra/datamodules/__init__.py +20 -0
- quadra/datamodules/anomaly.py +180 -0
- quadra/datamodules/base.py +375 -0
- quadra/datamodules/classification.py +1003 -0
- quadra/datamodules/generic/__init__.py +0 -0
- quadra/datamodules/generic/imagenette.py +144 -0
- quadra/datamodules/generic/mnist.py +81 -0
- quadra/datamodules/generic/mvtec.py +58 -0
- quadra/datamodules/generic/oxford_pet.py +163 -0
- quadra/datamodules/patch.py +190 -0
- quadra/datamodules/segmentation.py +742 -0
- quadra/datamodules/ssl.py +140 -0
- quadra/datasets/__init__.py +17 -0
- quadra/datasets/anomaly.py +287 -0
- quadra/datasets/classification.py +241 -0
- quadra/datasets/patch.py +138 -0
- quadra/datasets/segmentation.py +239 -0
- quadra/datasets/ssl.py +110 -0
- quadra/losses/__init__.py +0 -0
- quadra/losses/classification/__init__.py +6 -0
- quadra/losses/classification/asl.py +83 -0
- quadra/losses/classification/focal.py +320 -0
- quadra/losses/classification/prototypical.py +148 -0
- quadra/losses/ssl/__init__.py +17 -0
- quadra/losses/ssl/barlowtwins.py +47 -0
- quadra/losses/ssl/byol.py +37 -0
- quadra/losses/ssl/dino.py +129 -0
- quadra/losses/ssl/hyperspherical.py +45 -0
- quadra/losses/ssl/idmm.py +50 -0
- quadra/losses/ssl/simclr.py +67 -0
- quadra/losses/ssl/simsiam.py +30 -0
- quadra/losses/ssl/vicreg.py +76 -0
- quadra/main.py +49 -0
- quadra/metrics/__init__.py +3 -0
- quadra/metrics/segmentation.py +251 -0
- quadra/models/__init__.py +0 -0
- quadra/models/base.py +151 -0
- quadra/models/classification/__init__.py +8 -0
- quadra/models/classification/backbones.py +149 -0
- quadra/models/classification/base.py +92 -0
- quadra/models/evaluation.py +322 -0
- quadra/modules/__init__.py +0 -0
- quadra/modules/backbone.py +30 -0
- quadra/modules/base.py +312 -0
- quadra/modules/classification/__init__.py +3 -0
- quadra/modules/classification/base.py +327 -0
- quadra/modules/ssl/__init__.py +17 -0
- quadra/modules/ssl/barlowtwins.py +59 -0
- quadra/modules/ssl/byol.py +172 -0
- quadra/modules/ssl/common.py +285 -0
- quadra/modules/ssl/dino.py +186 -0
- quadra/modules/ssl/hyperspherical.py +206 -0
- quadra/modules/ssl/idmm.py +98 -0
- quadra/modules/ssl/simclr.py +73 -0
- quadra/modules/ssl/simsiam.py +68 -0
- quadra/modules/ssl/vicreg.py +67 -0
- quadra/optimizers/__init__.py +4 -0
- quadra/optimizers/lars.py +153 -0
- quadra/optimizers/sam.py +127 -0
- quadra/schedulers/__init__.py +3 -0
- quadra/schedulers/base.py +44 -0
- quadra/schedulers/warmup.py +127 -0
- quadra/tasks/__init__.py +24 -0
- quadra/tasks/anomaly.py +582 -0
- quadra/tasks/base.py +397 -0
- quadra/tasks/classification.py +1263 -0
- quadra/tasks/patch.py +492 -0
- quadra/tasks/segmentation.py +389 -0
- quadra/tasks/ssl.py +560 -0
- quadra/trainers/README.md +3 -0
- quadra/trainers/__init__.py +0 -0
- quadra/trainers/classification.py +179 -0
- quadra/utils/__init__.py +0 -0
- quadra/utils/anomaly.py +112 -0
- quadra/utils/classification.py +618 -0
- quadra/utils/deprecation.py +31 -0
- quadra/utils/evaluation.py +474 -0
- quadra/utils/export.py +585 -0
- quadra/utils/imaging.py +32 -0
- quadra/utils/logger.py +15 -0
- quadra/utils/mlflow.py +98 -0
- quadra/utils/model_manager.py +320 -0
- quadra/utils/models.py +523 -0
- quadra/utils/patch/__init__.py +15 -0
- quadra/utils/patch/dataset.py +1433 -0
- quadra/utils/patch/metrics.py +449 -0
- quadra/utils/patch/model.py +153 -0
- quadra/utils/patch/visualization.py +217 -0
- quadra/utils/resolver.py +42 -0
- quadra/utils/segmentation.py +31 -0
- quadra/utils/tests/__init__.py +0 -0
- quadra/utils/tests/fixtures/__init__.py +1 -0
- quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
- quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
- quadra/utils/tests/fixtures/dataset/classification.py +406 -0
- quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
- quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
- quadra/utils/tests/fixtures/models/__init__.py +3 -0
- quadra/utils/tests/fixtures/models/anomaly.py +89 -0
- quadra/utils/tests/fixtures/models/classification.py +45 -0
- quadra/utils/tests/fixtures/models/segmentation.py +33 -0
- quadra/utils/tests/helpers.py +70 -0
- quadra/utils/tests/models.py +27 -0
- quadra/utils/utils.py +525 -0
- quadra/utils/validator.py +115 -0
- quadra/utils/visualization.py +422 -0
- quadra/utils/vit_explainability.py +349 -0
- quadra-2.2.7.dist-info/LICENSE +201 -0
- quadra-2.2.7.dist-info/METADATA +381 -0
- quadra-2.2.7.dist-info/RECORD +300 -0
- {quadra-0.0.1.dist-info → quadra-2.2.7.dist-info}/WHEEL +1 -1
- quadra-2.2.7.dist-info/entry_points.txt +3 -0
- quadra-0.0.1.dist-info/METADATA +0 -14
- quadra-0.0.1.dist-info/RECORD +0 -4
quadra/tasks/patch.py
ADDED
@@ -0,0 +1,492 @@
from __future__ import annotations

import json
import os
from pathlib import Path
from typing import Any, cast

import hydra
import torch
from joblib import dump, load
from omegaconf import DictConfig, OmegaConf
from sklearn.base import ClassifierMixin

from quadra.datamodules import PatchSklearnClassificationDataModule
from quadra.datasets.patch import PatchSklearnClassificationTrainDataset
from quadra.models.base import ModelSignatureWrapper
from quadra.models.evaluation import BaseEvaluationModel
from quadra.tasks.base import Evaluation, Task
from quadra.trainers.classification import SklearnClassificationTrainer
from quadra.utils import utils
from quadra.utils.classification import automatic_batch_size_computation
from quadra.utils.evaluation import automatic_datamodule_batch_size
from quadra.utils.export import export_model, import_deployment_model
from quadra.utils.patch import RleEncoder, compute_patch_metrics, save_classification_result
from quadra.utils.patch.dataset import PatchDatasetFileFormat

log = utils.get_logger(__name__)


class PatchSklearnClassification(Task[PatchSklearnClassificationDataModule]):
    """Patch classification using torch backbone for feature extraction and sklearn to learn a linear classifier.

    Args:
        config: The experiment configuration
        device: The device to use
        output: Dictionary defining which kind of outputs to generate. Defaults to None.
        automatic_batch_size: Whether to automatically find the largest batch size that fits in memory.
        half_precision: Whether to use half precision.
    """

    def __init__(
        self,
        config: DictConfig,
        output: DictConfig,
        device: str,
        automatic_batch_size: DictConfig,
        half_precision: bool = False,
    ):
        super().__init__(config=config)
        self.device: str = device
        self.output: DictConfig = output
        self.return_polygon: bool = True
        self.reconstruction_results: dict[str, Any]
        self._backbone: ModelSignatureWrapper
        self._trainer: SklearnClassificationTrainer
        self._model: ClassifierMixin
        self.metadata: dict[str, Any] = {
            "test_confusion_matrix": [],
            "test_accuracy": [],
            "test_results": [],
            "test_labels": [],
        }
        self.export_folder: str = "deployment_model"
        self.automatic_batch_size = automatic_batch_size
        self.half_precision = half_precision

    @property
    def model(self) -> ClassifierMixin:
        """sklearn.base.ClassifierMixin: The model."""
        return self._model

    @model.setter
    def model(self, model_config: DictConfig):
        """sklearn.base.ClassifierMixin: The model."""
        log.info("Instantiating model <%s>", model_config["_target_"])
        self._model = hydra.utils.instantiate(model_config)

    @property
    def backbone(self) -> ModelSignatureWrapper:
        """Backbone: The backbone."""
        return self._backbone

    @backbone.setter
    def backbone(self, backbone_config):
        """Load backbone."""
        if backbone_config.metadata.get("checkpoint"):
            log.info("Loading backbone from <%s>", backbone_config.metadata.checkpoint)
            self._backbone = torch.load(backbone_config.metadata.checkpoint)
        else:
            log.info("Loading backbone from <%s>", backbone_config.model["_target_"])
            self._backbone = hydra.utils.instantiate(backbone_config.model)

        self._backbone = ModelSignatureWrapper(self._backbone)
        self._backbone.eval()
        if self.half_precision:
            if self.device == "cpu":
                raise ValueError("Half precision is not supported on CPU")
            self._backbone.half()
        self._backbone = self._backbone.to(self.device)

    def prepare(self) -> None:
        """Prepare the experiment."""
        self.datamodule = self.config.datamodule
        self.backbone = self.config.backbone
        self.model = self.config.model

        if not self.automatic_batch_size.disable and self.device != "cpu":
            self.datamodule.batch_size = automatic_batch_size_computation(
                datamodule=self.datamodule,
                backbone=self.backbone,
                starting_batch_size=self.automatic_batch_size.starting_batch_size,
            )

        self.trainer = self.config.trainer

    @property
    def trainer(self) -> SklearnClassificationTrainer:
        """Trainer: The trainer."""
        return self._trainer

    @trainer.setter
    def trainer(self, trainer_config: DictConfig) -> None:
        """Trainer: The trainer."""
        log.info("Instantiating trainer <%s>", trainer_config["_target_"])
        trainer = hydra.utils.instantiate(trainer_config, backbone=self.backbone, classifier=self.model)
        self._trainer = trainer

    def train(self) -> None:
        """Train the model."""
        log.info("Starting training...!")
        # prepare_data() must be explicitly called if the task does not include a lightning training
        self.datamodule.prepare_data()
        self.datamodule.setup(stage="fit")
        class_to_keep = None
        if hasattr(self.datamodule, "class_to_skip_training") and self.datamodule.class_to_skip_training is not None:
            class_to_keep = [x for x in self.datamodule.class_to_idx if x not in self.datamodule.class_to_skip_training]

        self.model = self.config.model
        self.trainer.change_classifier(self.model)
        train_dataloader = self.datamodule.train_dataloader()
        val_dataloader = self.datamodule.val_dataloader()
        train_dataset = cast(PatchSklearnClassificationTrainDataset, train_dataloader.dataset)
        self.trainer.fit(train_dataloader=train_dataloader)
        _, pd_cm, accuracy, res, _ = self.trainer.test(
            test_dataloader=val_dataloader,
            class_to_keep=class_to_keep,
            idx_to_class=train_dataset.idx_to_class,
            predict_proba=True,
        )

        # save results
        self.metadata["test_confusion_matrix"] = pd_cm
        self.metadata["test_accuracy"] = accuracy
        self.metadata["test_results"] = res
        self.metadata["test_labels"] = [
            train_dataset.idx_to_class[i] if i != -1 else "N/A" for i in res["real_label"].unique().tolist()
        ]

    def generate_report(self) -> None:
        """Generate the report for the task."""
        log.info("Generating report!")
        os.makedirs(self.output.folder, exist_ok=True)

        c_matrix = self.metadata["test_confusion_matrix"]
        idx_to_class = {v: k for k, v in self.datamodule.class_to_idx.items()}

        datamodule: PatchSklearnClassificationDataModule = self.datamodule
        val_img_info: list[PatchDatasetFileFormat] = datamodule.info.val_files
        for img_info in val_img_info:
            if not os.path.isabs(img_info.image_path):
                img_info.image_path = os.path.join(datamodule.data_path, img_info.image_path)
            if img_info.mask_path is not None and not os.path.isabs(img_info.mask_path):
                img_info.mask_path = os.path.join(datamodule.data_path, img_info.mask_path)

        false_region_bad, false_region_good, true_region_bad, reconstructions = compute_patch_metrics(
            test_img_info=val_img_info,
            test_results=self.metadata["test_results"],
            patch_num_h=datamodule.info.patch_number[0] if datamodule.info.patch_number is not None else None,
            patch_num_w=datamodule.info.patch_number[1] if datamodule.info.patch_number is not None else None,
            patch_h=datamodule.info.patch_size[0] if datamodule.info.patch_size is not None else None,
            patch_w=datamodule.info.patch_size[1] if datamodule.info.patch_size is not None else None,
            overlap=datamodule.info.overlap,
            idx_to_class=idx_to_class,
            return_polygon=self.return_polygon,
            patch_reconstruction_method=self.output.reconstruction_method,
            annotated_good=datamodule.info.annotated_good,
        )

        self.reconstruction_results = {
            "false_region_bad": false_region_bad,
            "false_region_good": false_region_good,
            "true_region_bad": true_region_bad,
            "reconstructions": reconstructions,
            "reconstructions_type": "polygon" if self.return_polygon else "rle",
            "patch_reconstruction_method": self.output.reconstruction_method,
        }

        with open("reconstruction_results.json", "w") as f:
            json.dump(
                self.reconstruction_results,
                f,
                cls=RleEncoder,
            )

        if hasattr(self.datamodule, "class_to_skip_training") and self.datamodule.class_to_skip_training is not None:
            ignore_classes = [self.datamodule.class_to_idx[x] for x in self.datamodule.class_to_skip_training]
        else:
            ignore_classes = None
        val_dataloader = self.datamodule.val_dataloader()
        save_classification_result(
            results=self.metadata["test_results"],
            output_folder=self.output.folder,
            confusion_matrix=c_matrix,
            accuracy=self.metadata["test_accuracy"],
            test_dataloader=val_dataloader,
            config=self.config,
            output=self.output,
            reconstructions=reconstructions,
            ignore_classes=ignore_classes,
        )

    def export(self) -> None:
        """Generate deployment model for the task."""
        input_shapes = self.config.export.input_shapes

        idx_to_class = {v: k for k, v in self.datamodule.class_to_idx.items()}

        model_json, export_paths = export_model(
            config=self.config,
            model=self.backbone,
            export_folder=self.export_folder,
            half_precision=self.half_precision,
            input_shapes=input_shapes,
            idx_to_class=idx_to_class,
            pytorch_model_type="backbone",
        )

        if len(export_paths) > 0:
            dataset_info = self.datamodule.info

            horizontal_patches = dataset_info.patch_number[1] if dataset_info.patch_number is not None else None
            vertical_patches = dataset_info.patch_number[0] if dataset_info.patch_number is not None else None
            patch_height = dataset_info.patch_size[0] if dataset_info.patch_size is not None else None
            patch_width = dataset_info.patch_size[1] if dataset_info.patch_size is not None else None
            overlap = dataset_info.overlap

            model_json.update(
                {
                    "horizontal_patches": horizontal_patches,
                    "vertical_patches": vertical_patches,
                    "patch_height": patch_height,
                    "patch_width": patch_width,
                    "overlap": overlap,
                    "reconstruction_method": self.output.reconstruction_method,
                    "class_to_skip": self.datamodule.class_to_skip_training,
                }
            )

            with open(os.path.join(self.export_folder, "model.json"), "w") as f:
                json.dump(model_json, f, cls=utils.HydraEncoder)

            dump(self.model, os.path.join(self.export_folder, "classifier.joblib"))

    def execute(self) -> None:
        """Execute the experiment and all the steps."""
        self.prepare()
        self.train()
        if self.output.report:
            self.generate_report()
        if self.config.export is not None and len(self.config.export.types) > 0:
            self.export()
        self.finalize()


|
+
class PatchSklearnTestClassification(Evaluation[PatchSklearnClassificationDataModule]):
|
|
276
|
+
"""Perform a test of an already trained classification model.
|
|
277
|
+
|
|
278
|
+
Args:
|
|
279
|
+
config: The experiment configuration
|
|
280
|
+
output: where to save resultss
|
|
281
|
+
model_path: path to trained model from PatchSklearnClassification task.
|
|
282
|
+
device: the device where to run the model (cuda or cpu). Defaults to 'cpu'.
|
|
283
|
+
"""
|
|
284
|
+
|
|
285
|
+
def __init__(
|
|
286
|
+
self,
|
|
287
|
+
config: DictConfig,
|
|
288
|
+
output: DictConfig,
|
|
289
|
+
model_path: str,
|
|
290
|
+
device: str = "cpu",
|
|
291
|
+
):
|
|
292
|
+
super().__init__(config=config, model_path=model_path, device=device)
|
|
293
|
+
self.output = output
|
|
294
|
+
self._backbone: BaseEvaluationModel
|
|
295
|
+
self._classifier: ClassifierMixin
|
|
296
|
+
self.class_to_idx: dict[str, int]
|
|
297
|
+
self.idx_to_class: dict[int, str]
|
|
298
|
+
self.metadata: dict[str, Any] = {
|
|
299
|
+
"test_confusion_matrix": None,
|
|
300
|
+
"test_accuracy": None,
|
|
301
|
+
"test_results": None,
|
|
302
|
+
"test_labels": None,
|
|
303
|
+
}
|
|
304
|
+
self.class_to_skip: list[str] = []
|
|
305
|
+
self.reconstruction_results: dict[str, Any]
|
|
306
|
+
self.return_polygon: bool = True
|
|
307
|
+
|
|
308
|
+
def prepare(self) -> None:
|
|
309
|
+
"""Prepare the experiment."""
|
|
310
|
+
super().prepare()
|
|
311
|
+
|
|
312
|
+
idx_to_class = {}
|
|
313
|
+
class_to_idx = {}
|
|
314
|
+
for k, v in self.model_data["classes"].items():
|
|
315
|
+
idx_to_class[int(k)] = v
|
|
316
|
+
class_to_idx[v] = int(k)
|
|
317
|
+
|
|
318
|
+
self.idx_to_class = idx_to_class
|
|
319
|
+
self.class_to_idx = class_to_idx
|
|
320
|
+
self.config.datamodule.class_to_idx = class_to_idx
|
|
321
|
+
|
|
322
|
+
self.datamodule = self.config.datamodule
|
|
323
|
+
# Configure trainer
|
|
324
|
+
self.trainer = self.config.trainer
|
|
325
|
+
|
|
326
|
+
# prepare_data() must be explicitly called because there is no lightning training
|
|
327
|
+
self.datamodule.prepare_data()
|
|
328
|
+
self.datamodule.setup(stage="test")
|
|
329
|
+
|
|
330
|
+
@automatic_datamodule_batch_size(batch_size_attribute_name="batch_size")
|
|
331
|
+
def test(self) -> None:
|
|
332
|
+
"""Run the test."""
|
|
333
|
+
test_dataloader = self.datamodule.test_dataloader()
|
|
334
|
+
|
|
335
|
+
self.class_to_skip = self.model_data["class_to_skip"] if hasattr(self.model_data, "class_to_skip") else None
|
|
336
|
+
class_to_keep = None
|
|
337
|
+
|
|
338
|
+
if self.class_to_skip is not None:
|
|
339
|
+
class_to_keep = [x for x in self.datamodule.class_to_idx if x not in self.class_to_skip]
|
|
340
|
+
_, pd_cm, accuracy, res, _ = self.trainer.test(
|
|
341
|
+
test_dataloader=test_dataloader,
|
|
342
|
+
idx_to_class=self.idx_to_class,
|
|
343
|
+
predict_proba=True,
|
|
344
|
+
class_to_keep=class_to_keep,
|
|
345
|
+
)
|
|
346
|
+
|
|
347
|
+
# save results
|
|
348
|
+
self.metadata["test_confusion_matrix"] = pd_cm
|
|
349
|
+
self.metadata["test_accuracy"] = accuracy
|
|
350
|
+
self.metadata["test_results"] = res
|
|
351
|
+
self.metadata["test_labels"] = [
|
|
352
|
+
self.idx_to_class[i] if i != -1 else "N/A" for i in res["real_label"].unique().tolist()
|
|
353
|
+
]
|
|
354
|
+
|
|
355
|
+
@property
|
|
356
|
+
def deployment_model(self):
|
|
357
|
+
"""Deployment model."""
|
|
358
|
+
return None
|
|
359
|
+
|
|
360
|
+
@deployment_model.setter
|
|
361
|
+
def deployment_model(self, model_path: str):
|
|
362
|
+
"""Set backbone and classifier."""
|
|
363
|
+
self.backbone = model_path # type: ignore[assignment]
|
|
364
|
+
# Load classifier
|
|
365
|
+
self.classifier = os.path.join(Path(model_path).parent, "classifier.joblib")
|
|
366
|
+
|
|
367
|
+
@property
|
|
368
|
+
def classifier(self) -> ClassifierMixin:
|
|
369
|
+
"""Classifier: The classifier."""
|
|
370
|
+
return self._classifier
|
|
371
|
+
|
|
372
|
+
@classifier.setter
|
|
373
|
+
def classifier(self, classifier_path: str) -> None:
|
|
374
|
+
"""Load classifier."""
|
|
375
|
+
self._classifier = load(classifier_path)
|
|
376
|
+
|
|
377
|
+
@property
|
|
378
|
+
def backbone(self) -> BaseEvaluationModel:
|
|
379
|
+
"""Backbone: The backbone."""
|
|
380
|
+
return self._backbone
|
|
381
|
+
|
|
382
|
+
@backbone.setter
|
|
383
|
+
def backbone(self, model_path: str) -> None:
|
|
384
|
+
"""Load backbone."""
|
|
385
|
+
file_extension = os.path.splitext(model_path)[1]
|
|
386
|
+
|
|
387
|
+
model_architecture = None
|
|
388
|
+
if file_extension == ".pth":
|
|
389
|
+
backbone_config_path = os.path.join(Path(model_path).parent, "model_config.yaml")
|
|
390
|
+
log.info("Loading backbone from config")
|
|
391
|
+
backbone_config = OmegaConf.load(backbone_config_path)
|
|
392
|
+
|
|
393
|
+
if backbone_config.metadata.get("checkpoint"):
|
|
394
|
+
log.info("Loading backbone from <%s>", backbone_config.metadata.checkpoint)
|
|
395
|
+
model_architecture = torch.load(backbone_config.metadata.checkpoint)
|
|
396
|
+
else:
|
|
397
|
+
log.info("Loading backbone from <%s>", backbone_config.model["_target_"])
|
|
398
|
+
model_architecture = hydra.utils.instantiate(backbone_config.model)
|
|
399
|
+
|
|
400
|
+
self._backbone = import_deployment_model(
|
|
401
|
+
model_path=model_path,
|
|
402
|
+
device=self.device,
|
|
403
|
+
inference_config=self.config.inference,
|
|
404
|
+
model_architecture=model_architecture,
|
|
405
|
+
)
|
|
406
|
+
|
|
407
|
+
@property
|
|
408
|
+
def trainer(self) -> SklearnClassificationTrainer:
|
|
409
|
+
"""Trainer: The trainer."""
|
|
410
|
+
return self._trainer
|
|
411
|
+
|
|
412
|
+
@trainer.setter
|
|
413
|
+
def trainer(self, trainer_config: DictConfig) -> None:
|
|
414
|
+
"""Trainer: The trainer."""
|
|
415
|
+
log.info("Instantiating trainer <%s>", trainer_config["_target_"])
|
|
416
|
+
|
|
417
|
+
if self.backbone.training:
|
|
418
|
+
self.backbone.eval()
|
|
419
|
+
|
|
420
|
+
trainer = hydra.utils.instantiate(trainer_config, backbone=self.backbone, classifier=self.classifier)
|
|
421
|
+
self._trainer = trainer
|
|
422
|
+
|
|
423
|
+
def generate_report(self) -> None:
|
|
424
|
+
"""Generate a report for the task."""
|
|
425
|
+
log.info("Generating report!")
|
|
426
|
+
os.makedirs(self.output.folder, exist_ok=True)
|
|
427
|
+
|
|
428
|
+
c_matrix = self.metadata["test_confusion_matrix"]
|
|
429
|
+
idx_to_class = {v: k for k, v in self.datamodule.class_to_idx.items()}
|
|
430
|
+
|
|
431
|
+
datamodule: PatchSklearnClassificationDataModule = self.datamodule
|
|
432
|
+
test_img_info = datamodule.info.test_files
|
|
433
|
+
for img_info in test_img_info:
|
|
434
|
+
if not os.path.isabs(img_info.image_path):
|
|
435
|
+
img_info.image_path = os.path.join(datamodule.data_path, img_info.image_path)
|
|
436
|
+
if img_info.mask_path is not None and not os.path.isabs(img_info.mask_path):
|
|
437
|
+
img_info.mask_path = os.path.join(datamodule.data_path, img_info.mask_path)
|
|
438
|
+
|
|
439
|
+
false_region_bad, false_region_good, true_region_bad, reconstructions = compute_patch_metrics(
|
|
440
|
+
test_img_info=test_img_info,
|
|
441
|
+
test_results=self.metadata["test_results"],
|
|
442
|
+
patch_num_h=datamodule.info.patch_number[0] if datamodule.info.patch_number is not None else None,
|
|
443
|
+
patch_num_w=datamodule.info.patch_number[1] if datamodule.info.patch_number is not None else None,
|
|
444
|
+
patch_h=datamodule.info.patch_size[0] if datamodule.info.patch_size is not None else None,
|
|
445
|
+
patch_w=datamodule.info.patch_size[1] if datamodule.info.patch_size is not None else None,
|
|
446
|
+
overlap=datamodule.info.overlap,
|
|
447
|
+
idx_to_class=idx_to_class,
|
|
448
|
+
return_polygon=self.return_polygon,
|
|
449
|
+
patch_reconstruction_method=self.output.reconstruction_method,
|
|
450
|
+
annotated_good=datamodule.info.annotated_good,
|
|
451
|
+
)
|
|
452
|
+
|
|
453
|
+
self.reconstruction_results = {
|
|
454
|
+
"false_region_bad": false_region_bad,
|
|
455
|
+
"false_region_good": false_region_good,
|
|
456
|
+
"true_region_bad": true_region_bad,
|
|
457
|
+
"reconstructions": reconstructions,
|
|
458
|
+
"reconstructions_type": "polygon" if self.return_polygon else "rle",
|
|
459
|
+
"patch_reconstruction_method": self.output.reconstruction_method,
|
|
460
|
+
}
|
|
461
|
+
|
|
462
|
+
with open("reconstruction_results.json", "w") as f:
|
|
463
|
+
json.dump(
|
|
464
|
+
self.reconstruction_results,
|
|
465
|
+
f,
|
|
466
|
+
cls=RleEncoder,
|
|
467
|
+
)
|
|
468
|
+
|
|
469
|
+
if self.class_to_skip is not None:
|
|
470
|
+
ignore_classes = [datamodule.class_to_idx[x] for x in self.class_to_skip]
|
|
471
|
+
else:
|
|
472
|
+
ignore_classes = None
|
|
473
|
+
test_dataloader = self.datamodule.test_dataloader()
|
|
474
|
+
save_classification_result(
|
|
475
|
+
results=self.metadata["test_results"],
|
|
476
|
+
output_folder=self.output.folder,
|
|
477
|
+
confusion_matrix=c_matrix,
|
|
478
|
+
accuracy=self.metadata["test_accuracy"],
|
|
479
|
+
test_dataloader=test_dataloader,
|
|
480
|
+
config=self.config,
|
|
481
|
+
output=self.output,
|
|
482
|
+
reconstructions=reconstructions,
|
|
483
|
+
ignore_classes=ignore_classes,
|
|
484
|
+
)
|
|
485
|
+
|
|
486
|
+
def execute(self) -> None:
|
|
487
|
+
"""Execute the experiment and all the steps."""
|
|
488
|
+
self.prepare()
|
|
489
|
+
self.test()
|
|
490
|
+
if self.output.report:
|
|
491
|
+
self.generate_report()
|
|
492
|
+
self.finalize()
|
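Note: the snippet below is an illustrative sketch and is not part of the package. PatchSklearnClassification.export() dumps the fitted sklearn classifier to classifier.joblib and writes patch-layout metadata (horizontal_patches, vertical_patches, patch_height, patch_width, overlap, reconstruction_method, class_to_skip) into model.json inside the export folder, which defaults to "deployment_model". The sketch assumes those two files already exist and simply reloads them; within the package the same artifacts are consumed by PatchSklearnTestClassification.deployment_model, which additionally restores the backbone through import_deployment_model.

# Illustrative sketch only, not part of quadra: reload the artifacts that
# PatchSklearnClassification.export() writes into its export folder.
import json
import os

from joblib import load

export_folder = "deployment_model"  # default value of self.export_folder; adjust if configured differently

# model.json carries the metadata added in export(), e.g. the patch layout and overlap
with open(os.path.join(export_folder, "model.json"), encoding="utf-8") as f:
    model_json = json.load(f)

# classifier.joblib is the fitted sklearn classifier saved with joblib's dump()
classifier = load(os.path.join(export_folder, "classifier.joblib"))

print(model_json.get("horizontal_patches"), model_json.get("vertical_patches"))
print(model_json.get("patch_height"), model_json.get("patch_width"), model_json.get("overlap"))
print(type(classifier).__name__)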