quadra 0.0.1__py3-none-any.whl → 2.1.13__py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
- hydra_plugins/quadra_searchpath_plugin.py +30 -0
- quadra/__init__.py +6 -0
- quadra/callbacks/__init__.py +0 -0
- quadra/callbacks/anomalib.py +289 -0
- quadra/callbacks/lightning.py +501 -0
- quadra/callbacks/mlflow.py +291 -0
- quadra/callbacks/scheduler.py +69 -0
- quadra/configs/__init__.py +0 -0
- quadra/configs/backbone/caformer_m36.yaml +8 -0
- quadra/configs/backbone/caformer_s36.yaml +8 -0
- quadra/configs/backbone/convnextv2_base.yaml +8 -0
- quadra/configs/backbone/convnextv2_femto.yaml +8 -0
- quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
- quadra/configs/backbone/dino_vitb8.yaml +12 -0
- quadra/configs/backbone/dino_vits8.yaml +12 -0
- quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
- quadra/configs/backbone/dinov2_vits14.yaml +12 -0
- quadra/configs/backbone/efficientnet_b0.yaml +8 -0
- quadra/configs/backbone/efficientnet_b1.yaml +8 -0
- quadra/configs/backbone/efficientnet_b2.yaml +8 -0
- quadra/configs/backbone/efficientnet_b3.yaml +8 -0
- quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
- quadra/configs/backbone/levit_128s.yaml +8 -0
- quadra/configs/backbone/mnasnet0_5.yaml +9 -0
- quadra/configs/backbone/resnet101.yaml +8 -0
- quadra/configs/backbone/resnet18.yaml +8 -0
- quadra/configs/backbone/resnet18_ssl.yaml +8 -0
- quadra/configs/backbone/resnet50.yaml +8 -0
- quadra/configs/backbone/smp.yaml +9 -0
- quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
- quadra/configs/backbone/unetr.yaml +15 -0
- quadra/configs/backbone/vit16_base.yaml +9 -0
- quadra/configs/backbone/vit16_small.yaml +9 -0
- quadra/configs/backbone/vit16_tiny.yaml +9 -0
- quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
- quadra/configs/callbacks/all.yaml +32 -0
- quadra/configs/callbacks/default.yaml +37 -0
- quadra/configs/callbacks/default_anomalib.yaml +67 -0
- quadra/configs/config.yaml +33 -0
- quadra/configs/core/default.yaml +11 -0
- quadra/configs/datamodule/base/anomaly.yaml +16 -0
- quadra/configs/datamodule/base/classification.yaml +21 -0
- quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
- quadra/configs/datamodule/base/segmentation.yaml +18 -0
- quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
- quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
- quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
- quadra/configs/datamodule/base/ssl.yaml +21 -0
- quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
- quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
- quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
- quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
- quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
- quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
- quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
- quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
- quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
- quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
- quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/base/classification/classification.yaml +73 -0
- quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
- quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
- quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
- quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
- quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
- quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
- quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
- quadra/configs/experiment/base/ssl/byol.yaml +43 -0
- quadra/configs/experiment/base/ssl/dino.yaml +46 -0
- quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
- quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
- quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
- quadra/configs/experiment/custom/cls.yaml +12 -0
- quadra/configs/experiment/default.yaml +15 -0
- quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
- quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
- quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
- quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
- quadra/configs/export/default.yaml +13 -0
- quadra/configs/hydra/anomaly_custom.yaml +15 -0
- quadra/configs/hydra/default.yaml +14 -0
- quadra/configs/inference/default.yaml +26 -0
- quadra/configs/logger/comet.yaml +10 -0
- quadra/configs/logger/csv.yaml +5 -0
- quadra/configs/logger/mlflow.yaml +12 -0
- quadra/configs/logger/tensorboard.yaml +8 -0
- quadra/configs/loss/asl.yaml +7 -0
- quadra/configs/loss/barlow.yaml +2 -0
- quadra/configs/loss/bce.yaml +1 -0
- quadra/configs/loss/byol.yaml +1 -0
- quadra/configs/loss/cross_entropy.yaml +1 -0
- quadra/configs/loss/dino.yaml +8 -0
- quadra/configs/loss/simclr.yaml +2 -0
- quadra/configs/loss/simsiam.yaml +1 -0
- quadra/configs/loss/smp_ce.yaml +3 -0
- quadra/configs/loss/smp_dice.yaml +2 -0
- quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
- quadra/configs/loss/smp_mcc.yaml +2 -0
- quadra/configs/loss/vicreg.yaml +5 -0
- quadra/configs/model/anomalib/cfa.yaml +35 -0
- quadra/configs/model/anomalib/cflow.yaml +30 -0
- quadra/configs/model/anomalib/csflow.yaml +34 -0
- quadra/configs/model/anomalib/dfm.yaml +19 -0
- quadra/configs/model/anomalib/draem.yaml +29 -0
- quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
- quadra/configs/model/anomalib/fastflow.yaml +32 -0
- quadra/configs/model/anomalib/padim.yaml +32 -0
- quadra/configs/model/anomalib/patchcore.yaml +36 -0
- quadra/configs/model/barlow.yaml +16 -0
- quadra/configs/model/byol.yaml +25 -0
- quadra/configs/model/classification.yaml +10 -0
- quadra/configs/model/dino.yaml +26 -0
- quadra/configs/model/logistic_regression.yaml +4 -0
- quadra/configs/model/multilabel_classification.yaml +9 -0
- quadra/configs/model/simclr.yaml +18 -0
- quadra/configs/model/simsiam.yaml +24 -0
- quadra/configs/model/smp.yaml +4 -0
- quadra/configs/model/smp_multiclass.yaml +4 -0
- quadra/configs/model/vicreg.yaml +16 -0
- quadra/configs/optimizer/adam.yaml +5 -0
- quadra/configs/optimizer/adamw.yaml +3 -0
- quadra/configs/optimizer/default.yaml +4 -0
- quadra/configs/optimizer/lars.yaml +8 -0
- quadra/configs/optimizer/sgd.yaml +4 -0
- quadra/configs/scheduler/default.yaml +5 -0
- quadra/configs/scheduler/rop.yaml +5 -0
- quadra/configs/scheduler/step.yaml +3 -0
- quadra/configs/scheduler/warmrestart.yaml +2 -0
- quadra/configs/scheduler/warmup.yaml +6 -0
- quadra/configs/task/anomalib/cfa.yaml +5 -0
- quadra/configs/task/anomalib/cflow.yaml +5 -0
- quadra/configs/task/anomalib/csflow.yaml +5 -0
- quadra/configs/task/anomalib/draem.yaml +5 -0
- quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
- quadra/configs/task/anomalib/fastflow.yaml +5 -0
- quadra/configs/task/anomalib/inference.yaml +3 -0
- quadra/configs/task/anomalib/padim.yaml +5 -0
- quadra/configs/task/anomalib/patchcore.yaml +5 -0
- quadra/configs/task/classification.yaml +6 -0
- quadra/configs/task/classification_evaluation.yaml +6 -0
- quadra/configs/task/default.yaml +1 -0
- quadra/configs/task/segmentation.yaml +9 -0
- quadra/configs/task/segmentation_evaluation.yaml +3 -0
- quadra/configs/task/sklearn_classification.yaml +13 -0
- quadra/configs/task/sklearn_classification_patch.yaml +11 -0
- quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
- quadra/configs/task/sklearn_classification_test.yaml +8 -0
- quadra/configs/task/ssl.yaml +2 -0
- quadra/configs/trainer/lightning_cpu.yaml +36 -0
- quadra/configs/trainer/lightning_gpu.yaml +35 -0
- quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
- quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
- quadra/configs/trainer/lightning_multigpu.yaml +37 -0
- quadra/configs/trainer/sklearn_classification.yaml +7 -0
- quadra/configs/transforms/byol.yaml +47 -0
- quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
- quadra/configs/transforms/default.yaml +37 -0
- quadra/configs/transforms/default_numpy.yaml +24 -0
- quadra/configs/transforms/default_resize.yaml +22 -0
- quadra/configs/transforms/dino.yaml +63 -0
- quadra/configs/transforms/linear_eval.yaml +18 -0
- quadra/datamodules/__init__.py +20 -0
- quadra/datamodules/anomaly.py +180 -0
- quadra/datamodules/base.py +375 -0
- quadra/datamodules/classification.py +1003 -0
- quadra/datamodules/generic/__init__.py +0 -0
- quadra/datamodules/generic/imagenette.py +144 -0
- quadra/datamodules/generic/mnist.py +81 -0
- quadra/datamodules/generic/mvtec.py +58 -0
- quadra/datamodules/generic/oxford_pet.py +163 -0
- quadra/datamodules/patch.py +190 -0
- quadra/datamodules/segmentation.py +742 -0
- quadra/datamodules/ssl.py +140 -0
- quadra/datasets/__init__.py +17 -0
- quadra/datasets/anomaly.py +287 -0
- quadra/datasets/classification.py +241 -0
- quadra/datasets/patch.py +138 -0
- quadra/datasets/segmentation.py +239 -0
- quadra/datasets/ssl.py +110 -0
- quadra/losses/__init__.py +0 -0
- quadra/losses/classification/__init__.py +6 -0
- quadra/losses/classification/asl.py +83 -0
- quadra/losses/classification/focal.py +320 -0
- quadra/losses/classification/prototypical.py +148 -0
- quadra/losses/ssl/__init__.py +17 -0
- quadra/losses/ssl/barlowtwins.py +47 -0
- quadra/losses/ssl/byol.py +37 -0
- quadra/losses/ssl/dino.py +129 -0
- quadra/losses/ssl/hyperspherical.py +45 -0
- quadra/losses/ssl/idmm.py +50 -0
- quadra/losses/ssl/simclr.py +67 -0
- quadra/losses/ssl/simsiam.py +30 -0
- quadra/losses/ssl/vicreg.py +76 -0
- quadra/main.py +46 -0
- quadra/metrics/__init__.py +3 -0
- quadra/metrics/segmentation.py +251 -0
- quadra/models/__init__.py +0 -0
- quadra/models/base.py +151 -0
- quadra/models/classification/__init__.py +8 -0
- quadra/models/classification/backbones.py +149 -0
- quadra/models/classification/base.py +92 -0
- quadra/models/evaluation.py +322 -0
- quadra/modules/__init__.py +0 -0
- quadra/modules/backbone.py +30 -0
- quadra/modules/base.py +312 -0
- quadra/modules/classification/__init__.py +3 -0
- quadra/modules/classification/base.py +331 -0
- quadra/modules/ssl/__init__.py +17 -0
- quadra/modules/ssl/barlowtwins.py +59 -0
- quadra/modules/ssl/byol.py +172 -0
- quadra/modules/ssl/common.py +285 -0
- quadra/modules/ssl/dino.py +186 -0
- quadra/modules/ssl/hyperspherical.py +206 -0
- quadra/modules/ssl/idmm.py +98 -0
- quadra/modules/ssl/simclr.py +73 -0
- quadra/modules/ssl/simsiam.py +68 -0
- quadra/modules/ssl/vicreg.py +67 -0
- quadra/optimizers/__init__.py +4 -0
- quadra/optimizers/lars.py +153 -0
- quadra/optimizers/sam.py +127 -0
- quadra/schedulers/__init__.py +3 -0
- quadra/schedulers/base.py +44 -0
- quadra/schedulers/warmup.py +127 -0
- quadra/tasks/__init__.py +24 -0
- quadra/tasks/anomaly.py +582 -0
- quadra/tasks/base.py +397 -0
- quadra/tasks/classification.py +1264 -0
- quadra/tasks/patch.py +492 -0
- quadra/tasks/segmentation.py +389 -0
- quadra/tasks/ssl.py +560 -0
- quadra/trainers/README.md +3 -0
- quadra/trainers/__init__.py +0 -0
- quadra/trainers/classification.py +179 -0
- quadra/utils/__init__.py +0 -0
- quadra/utils/anomaly.py +112 -0
- quadra/utils/classification.py +618 -0
- quadra/utils/deprecation.py +31 -0
- quadra/utils/evaluation.py +474 -0
- quadra/utils/export.py +579 -0
- quadra/utils/imaging.py +32 -0
- quadra/utils/logger.py +15 -0
- quadra/utils/mlflow.py +98 -0
- quadra/utils/model_manager.py +320 -0
- quadra/utils/models.py +524 -0
- quadra/utils/patch/__init__.py +15 -0
- quadra/utils/patch/dataset.py +1433 -0
- quadra/utils/patch/metrics.py +449 -0
- quadra/utils/patch/model.py +153 -0
- quadra/utils/patch/visualization.py +217 -0
- quadra/utils/resolver.py +42 -0
- quadra/utils/segmentation.py +31 -0
- quadra/utils/tests/__init__.py +0 -0
- quadra/utils/tests/fixtures/__init__.py +1 -0
- quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
- quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
- quadra/utils/tests/fixtures/dataset/classification.py +406 -0
- quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
- quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
- quadra/utils/tests/fixtures/models/__init__.py +3 -0
- quadra/utils/tests/fixtures/models/anomaly.py +89 -0
- quadra/utils/tests/fixtures/models/classification.py +45 -0
- quadra/utils/tests/fixtures/models/segmentation.py +33 -0
- quadra/utils/tests/helpers.py +70 -0
- quadra/utils/tests/models.py +27 -0
- quadra/utils/utils.py +525 -0
- quadra/utils/validator.py +115 -0
- quadra/utils/visualization.py +422 -0
- quadra/utils/vit_explainability.py +349 -0
- quadra-2.1.13.dist-info/LICENSE +201 -0
- quadra-2.1.13.dist-info/METADATA +386 -0
- quadra-2.1.13.dist-info/RECORD +300 -0
- {quadra-0.0.1.dist-info → quadra-2.1.13.dist-info}/WHEEL +1 -1
- quadra-2.1.13.dist-info/entry_points.txt +3 -0
- quadra-0.0.1.dist-info/METADATA +0 -14
- quadra-0.0.1.dist-info/RECORD +0 -4
quadra/models/base.py
ADDED
@@ -0,0 +1,151 @@
+from __future__ import annotations
+
+import inspect
+from collections.abc import Sequence
+from typing import Any
+
+import torch
+from torch import nn
+
+from quadra.utils.logger import get_logger
+
+log = get_logger(__name__)
+
+
+class ModelSignatureWrapper(nn.Module):
+    """Model wrapper used to retrieve input shape. It can be used as a decorator of nn.Module, the first call to the
+    forward method will retrieve the input shape and store it in the input_shapes attribute.
+    It will also save the model summary in a file called model_summary.txt in the current working directory.
+    """
+
+    def __init__(self, model: nn.Module):
+        super().__init__()
+        self.instance = model
+        self.input_shapes: Any = None
+        self.disable = False
+
+        if isinstance(self.instance, ModelSignatureWrapper):
+            # Handle nested ModelSignatureWrapper
+            self.input_shapes = self.instance.input_shapes
+            self.instance = self.instance.instance
+
+    def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
+        """Retrieve the input shape and forward the model, if the input shape is already retrieved it will just forward
+        the model.
+        """
+        if self.input_shapes is None and not self.disable:
+            try:
+                self.input_shapes = self._get_input_shapes(*args, **kwargs)
+            except Exception:
+                log.warning(
+                    "Failed to retrieve input shapes after forward! To export the model you'll need to "
+                    "provide the input shapes manually setting the config.export.input_shapes parameter! "
+                    "Alternatively you could try to use a forward with supported input types (and their compositions) "
+                    "(list, tuple, dict, tensors)."
+                )
+                self.disable = True
+
+        return self.instance.forward(*args, **kwargs)
+
+    def to(self, *args, **kwargs):
+        """Handle calls to to method returning the underlying model."""
+        self.instance = self.instance.to(*args, **kwargs)
+
+        return self
+
+    def half(self, *args, **kwargs):
+        """Handle calls to to method returning the underlying model."""
+        self.instance = self.instance.half(*args, **kwargs)
+
+        return self
+
+    def cpu(self, *args, **kwargs):
+        """Handle calls to to method returning the underlying model."""
+        self.instance = self.instance.cpu(*args, **kwargs)
+
+        return self
+
+    def _get_input_shapes(self, *args: Any, **kwargs: Any) -> list[Any]:
+        """Retrieve the input shapes from the input. Inputs will be in the same order as the forward method
+        signature.
+        """
+        input_shapes = []
+
+        for arg in args:
+            input_shapes.append(self._get_input_shape(arg))
+
+        if isinstance(self.instance.forward, torch.ScriptMethod):
+            # Handle torchscript backbones
+            for i, argument in enumerate(self.instance.forward.schema.arguments):
+                if i < (len(args) + 1):  # +1 for self
+                    continue
+
+                if argument.name == "self":
+                    continue
+
+                if argument.name in kwargs:
+                    input_shapes.append(self._get_input_shape(kwargs[argument.name]))
+                else:
+                    # Retrieve the default value
+                    input_shapes.append(self._get_input_shape(argument.default_value))
+        else:
+            signature = inspect.signature(self.instance.forward)
+
+            for i, key in enumerate(signature.parameters.keys()):
+                if i < len(args):
+                    continue
+
+                if key in kwargs:
+                    input_shapes.append(self._get_input_shape(kwargs[key]))
+                else:
+                    # Retrieve the default value
+                    input_shapes.append(self._get_input_shape(signature.parameters[key].default))
+
+        return input_shapes
+
+    def _get_input_shape(self, inp: Sequence | torch.Tensor) -> list[Any] | tuple[Any, ...] | dict[str, Any]:
+        """Recursive function to retrieve the input shapes."""
+        if isinstance(inp, list):
+            return [self._get_input_shape(i) for i in inp]
+
+        if isinstance(inp, tuple):
+            return tuple(self._get_input_shape(i) for i in inp)
+
+        if isinstance(inp, torch.Tensor):
+            return tuple(inp.shape[1:])
+
+        if isinstance(inp, dict):
+            return {k: self._get_input_shape(v) for k, v in inp.items()}
+
+        raise ValueError(f"Input type {type(inp)} not supported")
+
+    def __getattr__(self, name: str) -> torch.Tensor | nn.Module:
+        if name in ["instance", "input_shapes"]:
+            return self.__dict__[name]
+
+        return getattr(self.__dict__["instance"], name)
+
+    def __setattr__(self, name: str, value: torch.Tensor | nn.Module) -> None:
+        if name in ["instance", "input_shapes"]:
+            self.__dict__[name] = value
+        else:
+            setattr(self.instance, name, value)
+
+    def __getattribute__(self, __name: str) -> Any:
+        if __name in [
+            "instance",
+            "input_shapes",
+            "__dict__",
+            "forward",
+            "_get_input_shapes",
+            "_get_input_shape",
+            "to",
+            "half",
+            "cpu",
+            "call_super_init",
+            "_call_impl",
+            "_compiled_call_impl",
+        ]:
+            return super().__getattribute__(__name)
+
+        return getattr(self.instance, __name)
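For readers skimming the diff, a short usage sketch may help: on the first forward call the wrapper records the per-sample input shapes (batch dimension dropped) into `input_shapes`, which the export utilities can later reuse. The toy `nn.Sequential` model and tensor sizes below are assumptions made for illustration, not code from the package.

```python
# Illustrative sketch only: wrapping a toy model with ModelSignatureWrapper.
import torch
from torch import nn

from quadra.models.base import ModelSignatureWrapper

# Hypothetical toy model; any nn.Module taking tensors (or lists/tuples/dicts of them) works.
model = ModelSignatureWrapper(nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10)))

out = model(torch.randn(8, 3, 32, 32))  # first call records the input signature
print(model.input_shapes)               # -> [(3, 32, 32)]; the batch dimension is stripped
```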
quadra/models/classification/backbones.py
ADDED
@@ -0,0 +1,149 @@
+from __future__ import annotations
+
+from typing import Any
+
+import timm
+import torch
+from timm.models.helpers import load_checkpoint
+from torch import nn
+from torchvision import models
+
+from quadra.models.classification.base import BaseNetworkBuilder
+from quadra.utils.logger import get_logger
+
+log = get_logger(__name__)
+
+
+class TorchHubNetworkBuilder(BaseNetworkBuilder):
+    """TorchHub feature extractor, with the possibility to map features to an hypersphere.
+
+    Args:
+        repo_or_dir: The name of the repository or the path to the directory containing the model.
+        model_name: The name of the model within the repository.
+        pretrained: Whether to load the pretrained weights for the model.
+        pre_classifier: Pre classifier as a torch.nn.Module. Defaults to nn.Identity() if None.
+        classifier: Classifier as a torch.nn.Module. Defaults to nn.Identity() if None.
+        freeze: Whether to freeze the feature extractor. Defaults to True.
+        hyperspherical: Whether to map features to an hypersphere. Defaults to False.
+        flatten_features: Whether to flatten the features before the pre_classifier. Defaults to True.
+        checkpoint_path: Path to a checkpoint to load after the model is initialized. Defaults to None.
+        **torch_hub_kwargs: Additional arguments to pass to torch.hub.load
+    """
+
+    def __init__(
+        self,
+        repo_or_dir: str,
+        model_name: str,
+        pretrained: bool = True,
+        pre_classifier: nn.Module | None = None,
+        classifier: nn.Module | None = None,
+        freeze: bool = True,
+        hyperspherical: bool = False,
+        flatten_features: bool = True,
+        checkpoint_path: str | None = None,
+        **torch_hub_kwargs: Any,
+    ):
+        self.pretrained = pretrained
+        features_extractor = torch.hub.load(
+            repo_or_dir=repo_or_dir, model=model_name, pretrained=self.pretrained, **torch_hub_kwargs
+        )
+        if checkpoint_path:
+            log.info("Loading checkpoint from %s", checkpoint_path)
+            load_checkpoint(model=features_extractor, checkpoint_path=checkpoint_path)
+
+        super().__init__(
+            features_extractor=features_extractor,
+            pre_classifier=pre_classifier,
+            classifier=classifier,
+            freeze=freeze,
+            hyperspherical=hyperspherical,
+            flatten_features=flatten_features,
+        )
+
+
+class TorchVisionNetworkBuilder(BaseNetworkBuilder):
+    """Torchvision feature extractor, with the possibility to map features to an hypersphere.
+
+    Args:
+        model_name: Torchvision model function that will be evaluated, for example: torchvision.models.resnet18.
+        pretrained: Whether to load the pretrained weights for the model.
+        pre_classifier: Pre classifier as a torch.nn.Module. Defaults to nn.Identity() if None.
+        classifier: Classifier as a torch.nn.Module. Defaults to nn.Identity() if None.
+        freeze: Whether to freeze the feature extractor. Defaults to True.
+        hyperspherical: Whether to map features to an hypersphere. Defaults to False.
+        flatten_features: Whether to flatten the features before the pre_classifier. Defaults to True.
+        checkpoint_path: Path to a checkpoint to load after the model is initialized. Defaults to None.
+        **torchvision_kwargs: Additional arguments to pass to the model function.
+    """
+
+    def __init__(
+        self,
+        model_name: str,
+        pretrained: bool = True,
+        pre_classifier: nn.Module | None = None,
+        classifier: nn.Module | None = None,
+        freeze: bool = True,
+        hyperspherical: bool = False,
+        flatten_features: bool = True,
+        checkpoint_path: str | None = None,
+        **torchvision_kwargs: Any,
+    ):
+        self.pretrained = pretrained
+        model_function = models.__dict__[model_name]
+        features_extractor = model_function(pretrained=self.pretrained, progress=True, **torchvision_kwargs)
+        if checkpoint_path:
+            log.info("Loading checkpoint from %s", checkpoint_path)
+            load_checkpoint(model=features_extractor, checkpoint_path=checkpoint_path)
+
+        # Remove classifier
+        features_extractor.classifier = nn.Identity()
+        super().__init__(
+            features_extractor=features_extractor,
+            pre_classifier=pre_classifier,
+            classifier=classifier,
+            freeze=freeze,
+            hyperspherical=hyperspherical,
+            flatten_features=flatten_features,
+        )
+
+
+class TimmNetworkBuilder(BaseNetworkBuilder):
+    """Torchvision feature extractor, with the possibility to map features to an hypersphere.
+
+    Args:
+        model_name: Timm model name
+        pretrained: Whether to load the pretrained weights for the model.
+        pre_classifier: Pre classifier as a torch.nn.Module. Defaults to nn.Identity() if None.
+        classifier: Classifier as a torch.nn.Module. Defaults to nn.Identity() if None.
+        freeze: Whether to freeze the feature extractor. Defaults to True.
+        hyperspherical: Whether to map features to an hypersphere. Defaults to False.
+        flatten_features: Whether to flatten the features before the pre_classifier. Defaults to True.
+        checkpoint_path: Path to a checkpoint to load after the model is initialized. Defaults to None.
+        **timm_kwargs: Additional arguments to pass to timm.create_model
+    """
+
+    def __init__(
+        self,
+        model_name: str,
+        pretrained: bool = True,
+        pre_classifier: nn.Module | None = None,
+        classifier: nn.Module | None = None,
+        freeze: bool = True,
+        hyperspherical: bool = False,
+        flatten_features: bool = True,
+        checkpoint_path: str | None = None,
+        **timm_kwargs: Any,
+    ):
+        self.pretrained = pretrained
+        features_extractor = timm.create_model(
+            model_name, pretrained=self.pretrained, num_classes=0, checkpoint_path=checkpoint_path, **timm_kwargs
+        )
+
+        super().__init__(
+            features_extractor=features_extractor,
+            pre_classifier=pre_classifier,
+            classifier=classifier,
+            freeze=freeze,
+            hyperspherical=hyperspherical,
+            flatten_features=flatten_features,
+        )
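As a hedged illustration of how these builders are meant to be composed (not an excerpt from the package documentation), the sketch below builds a frozen timm backbone with a small linear head; the model name, the 512-dimensional pooled feature size of resnet18, and the two-class head are assumptions chosen for the example. TorchHubNetworkBuilder and TorchVisionNetworkBuilder follow the same pattern with their respective model sources.

```python
# Illustrative sketch only: frozen timm backbone with a linear classification head.
import torch
from torch import nn

from quadra.models.classification.backbones import TimmNetworkBuilder

backbone = TimmNetworkBuilder(
    model_name="resnet18",         # assumed example model
    pretrained=False,              # avoid downloading weights in this sketch
    classifier=nn.Linear(512, 2),  # resnet18 pooled features are 512-dim
    freeze=True,                   # extractor parameters get requires_grad=False
)

logits = backbone(torch.randn(4, 3, 224, 224))  # -> torch.Size([4, 2])
```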
quadra/models/classification/base.py
ADDED
@@ -0,0 +1,92 @@
+from __future__ import annotations
+
+from torch import nn
+
+from quadra.utils.models import L2Norm
+
+
+class BaseNetworkBuilder(nn.Module):
+    """Baseline Feature Extractor, with the possibility to map features to an hypersphere.
+    If hypershperical is True the classifier is ignored.
+
+    Args:
+        features_extractor: Feature extractor as a toch.nn.Module.
+        pre_classifier: Pre classifier as a torch.nn.Module. Defaults to nn.Identity() if None.
+        classifier: Classifier as a torch.nn.Module. Defaults to nn.Identity() if None.
+        freeze: Whether to freeze the feature extractor. Defaults to True.
+        hyperspherical: Whether to map features to an hypersphere. Defaults to False.
+        flatten_features: Whether to flatten the features before the pre_classifier. May be required if your model
+            is outputting a feature map rather than a vector. Defaults to True.
+    """
+
+    def __init__(
+        self,
+        features_extractor: nn.Module,
+        pre_classifier: nn.Module | None = None,
+        classifier: nn.Module | None = None,
+        freeze: bool = True,
+        hyperspherical: bool = False,
+        flatten_features: bool = True,
+    ):
+        super().__init__()
+        if pre_classifier is None:
+            pre_classifier = nn.Identity()
+
+        if classifier is None:
+            classifier = nn.Identity()
+
+        self.features_extractor = features_extractor
+        self.freeze = freeze
+        self.hyperspherical = hyperspherical
+        self.pre_classifier = pre_classifier
+        self.classifier = classifier
+        self.flatten: bool = False
+        self._hyperspherical: bool = False
+        self.l2: L2Norm | None = None
+        self.flatten_features = flatten_features
+
+        self.freeze = freeze
+        self.hyperspherical = hyperspherical
+
+        if self.freeze:
+            for p in self.features_extractor.parameters():
+                p.requires_grad = False
+
+    @property
+    def freeze(self) -> bool:
+        """Whether to freeze the feature extractor."""
+        return self._freeze
+
+    @freeze.setter
+    def freeze(self, value: bool) -> None:
+        """Whether to freeze the feature extractor."""
+        for p in self.features_extractor.parameters():
+            p.requires_grad = not value
+
+        self._freeze = value
+
+    @property
+    def hyperspherical(self) -> bool:
+        """Whether to map the extracted features into an hypersphere."""
+        return self._hyperspherical
+
+    @hyperspherical.setter
+    def hyperspherical(self, value: bool) -> None:
+        """Whether to map the extracted features into an hypersphere."""
+        self._hyperspherical = value
+        self.l2 = L2Norm() if value else None
+
+    def forward(self, x):
+        x = self.features_extractor(x)
+
+        if self.flatten_features:
+            x = x.view(x.size(0), -1)
+
+        x = self.pre_classifier(x)
+
+        if self.hyperspherical:
+            x = self.l2(x)
+
+        x = self.classifier(x)
+
+        return x
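Finally, a hedged sketch of using BaseNetworkBuilder directly with a custom extractor and the hyperspherical option. The toy layers are made up for the example, and the assumption that L2Norm (imported above from quadra.utils.models) L2-normalises the feature vector is inferred from its name rather than shown in this diff.

```python
# Illustrative sketch only: BaseNetworkBuilder with a custom extractor and
# the hyperspherical mapping enabled (classifier left at its identity default).
import torch
from torch import nn

from quadra.models.classification.base import BaseNetworkBuilder

net = BaseNetworkBuilder(
    features_extractor=nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 16)),  # toy extractor
    pre_classifier=None,   # defaults to nn.Identity()
    classifier=None,       # defaults to nn.Identity()
    freeze=False,          # keep the extractor trainable
    hyperspherical=True,   # route features through L2Norm before the (identity) classifier
)

feats = net(torch.randn(2, 3, 8, 8))  # shape (2, 16)
print(feats.norm(dim=1))              # approximately 1.0 per row, assuming L2Norm normalises
```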