quadra 0.0.1__py3-none-any.whl → 2.1.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hydra_plugins/quadra_searchpath_plugin.py +30 -0
- quadra/__init__.py +6 -0
- quadra/callbacks/__init__.py +0 -0
- quadra/callbacks/anomalib.py +289 -0
- quadra/callbacks/lightning.py +501 -0
- quadra/callbacks/mlflow.py +291 -0
- quadra/callbacks/scheduler.py +69 -0
- quadra/configs/__init__.py +0 -0
- quadra/configs/backbone/caformer_m36.yaml +8 -0
- quadra/configs/backbone/caformer_s36.yaml +8 -0
- quadra/configs/backbone/convnextv2_base.yaml +8 -0
- quadra/configs/backbone/convnextv2_femto.yaml +8 -0
- quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
- quadra/configs/backbone/dino_vitb8.yaml +12 -0
- quadra/configs/backbone/dino_vits8.yaml +12 -0
- quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
- quadra/configs/backbone/dinov2_vits14.yaml +12 -0
- quadra/configs/backbone/efficientnet_b0.yaml +8 -0
- quadra/configs/backbone/efficientnet_b1.yaml +8 -0
- quadra/configs/backbone/efficientnet_b2.yaml +8 -0
- quadra/configs/backbone/efficientnet_b3.yaml +8 -0
- quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
- quadra/configs/backbone/levit_128s.yaml +8 -0
- quadra/configs/backbone/mnasnet0_5.yaml +9 -0
- quadra/configs/backbone/resnet101.yaml +8 -0
- quadra/configs/backbone/resnet18.yaml +8 -0
- quadra/configs/backbone/resnet18_ssl.yaml +8 -0
- quadra/configs/backbone/resnet50.yaml +8 -0
- quadra/configs/backbone/smp.yaml +9 -0
- quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
- quadra/configs/backbone/unetr.yaml +15 -0
- quadra/configs/backbone/vit16_base.yaml +9 -0
- quadra/configs/backbone/vit16_small.yaml +9 -0
- quadra/configs/backbone/vit16_tiny.yaml +9 -0
- quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
- quadra/configs/callbacks/all.yaml +32 -0
- quadra/configs/callbacks/default.yaml +37 -0
- quadra/configs/callbacks/default_anomalib.yaml +67 -0
- quadra/configs/config.yaml +33 -0
- quadra/configs/core/default.yaml +11 -0
- quadra/configs/datamodule/base/anomaly.yaml +16 -0
- quadra/configs/datamodule/base/classification.yaml +21 -0
- quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
- quadra/configs/datamodule/base/segmentation.yaml +18 -0
- quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
- quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
- quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
- quadra/configs/datamodule/base/ssl.yaml +21 -0
- quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
- quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
- quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
- quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
- quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
- quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
- quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
- quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
- quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
- quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
- quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/base/classification/classification.yaml +73 -0
- quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
- quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
- quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
- quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
- quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
- quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
- quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
- quadra/configs/experiment/base/ssl/byol.yaml +43 -0
- quadra/configs/experiment/base/ssl/dino.yaml +46 -0
- quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
- quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
- quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
- quadra/configs/experiment/custom/cls.yaml +12 -0
- quadra/configs/experiment/default.yaml +15 -0
- quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
- quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
- quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
- quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
- quadra/configs/export/default.yaml +13 -0
- quadra/configs/hydra/anomaly_custom.yaml +15 -0
- quadra/configs/hydra/default.yaml +14 -0
- quadra/configs/inference/default.yaml +26 -0
- quadra/configs/logger/comet.yaml +10 -0
- quadra/configs/logger/csv.yaml +5 -0
- quadra/configs/logger/mlflow.yaml +12 -0
- quadra/configs/logger/tensorboard.yaml +8 -0
- quadra/configs/loss/asl.yaml +7 -0
- quadra/configs/loss/barlow.yaml +2 -0
- quadra/configs/loss/bce.yaml +1 -0
- quadra/configs/loss/byol.yaml +1 -0
- quadra/configs/loss/cross_entropy.yaml +1 -0
- quadra/configs/loss/dino.yaml +8 -0
- quadra/configs/loss/simclr.yaml +2 -0
- quadra/configs/loss/simsiam.yaml +1 -0
- quadra/configs/loss/smp_ce.yaml +3 -0
- quadra/configs/loss/smp_dice.yaml +2 -0
- quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
- quadra/configs/loss/smp_mcc.yaml +2 -0
- quadra/configs/loss/vicreg.yaml +5 -0
- quadra/configs/model/anomalib/cfa.yaml +35 -0
- quadra/configs/model/anomalib/cflow.yaml +30 -0
- quadra/configs/model/anomalib/csflow.yaml +34 -0
- quadra/configs/model/anomalib/dfm.yaml +19 -0
- quadra/configs/model/anomalib/draem.yaml +29 -0
- quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
- quadra/configs/model/anomalib/fastflow.yaml +32 -0
- quadra/configs/model/anomalib/padim.yaml +32 -0
- quadra/configs/model/anomalib/patchcore.yaml +36 -0
- quadra/configs/model/barlow.yaml +16 -0
- quadra/configs/model/byol.yaml +25 -0
- quadra/configs/model/classification.yaml +10 -0
- quadra/configs/model/dino.yaml +26 -0
- quadra/configs/model/logistic_regression.yaml +4 -0
- quadra/configs/model/multilabel_classification.yaml +9 -0
- quadra/configs/model/simclr.yaml +18 -0
- quadra/configs/model/simsiam.yaml +24 -0
- quadra/configs/model/smp.yaml +4 -0
- quadra/configs/model/smp_multiclass.yaml +4 -0
- quadra/configs/model/vicreg.yaml +16 -0
- quadra/configs/optimizer/adam.yaml +5 -0
- quadra/configs/optimizer/adamw.yaml +3 -0
- quadra/configs/optimizer/default.yaml +4 -0
- quadra/configs/optimizer/lars.yaml +8 -0
- quadra/configs/optimizer/sgd.yaml +4 -0
- quadra/configs/scheduler/default.yaml +5 -0
- quadra/configs/scheduler/rop.yaml +5 -0
- quadra/configs/scheduler/step.yaml +3 -0
- quadra/configs/scheduler/warmrestart.yaml +2 -0
- quadra/configs/scheduler/warmup.yaml +6 -0
- quadra/configs/task/anomalib/cfa.yaml +5 -0
- quadra/configs/task/anomalib/cflow.yaml +5 -0
- quadra/configs/task/anomalib/csflow.yaml +5 -0
- quadra/configs/task/anomalib/draem.yaml +5 -0
- quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
- quadra/configs/task/anomalib/fastflow.yaml +5 -0
- quadra/configs/task/anomalib/inference.yaml +3 -0
- quadra/configs/task/anomalib/padim.yaml +5 -0
- quadra/configs/task/anomalib/patchcore.yaml +5 -0
- quadra/configs/task/classification.yaml +6 -0
- quadra/configs/task/classification_evaluation.yaml +6 -0
- quadra/configs/task/default.yaml +1 -0
- quadra/configs/task/segmentation.yaml +9 -0
- quadra/configs/task/segmentation_evaluation.yaml +3 -0
- quadra/configs/task/sklearn_classification.yaml +13 -0
- quadra/configs/task/sklearn_classification_patch.yaml +11 -0
- quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
- quadra/configs/task/sklearn_classification_test.yaml +8 -0
- quadra/configs/task/ssl.yaml +2 -0
- quadra/configs/trainer/lightning_cpu.yaml +36 -0
- quadra/configs/trainer/lightning_gpu.yaml +35 -0
- quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
- quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
- quadra/configs/trainer/lightning_multigpu.yaml +37 -0
- quadra/configs/trainer/sklearn_classification.yaml +7 -0
- quadra/configs/transforms/byol.yaml +47 -0
- quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
- quadra/configs/transforms/default.yaml +37 -0
- quadra/configs/transforms/default_numpy.yaml +24 -0
- quadra/configs/transforms/default_resize.yaml +22 -0
- quadra/configs/transforms/dino.yaml +63 -0
- quadra/configs/transforms/linear_eval.yaml +18 -0
- quadra/datamodules/__init__.py +20 -0
- quadra/datamodules/anomaly.py +180 -0
- quadra/datamodules/base.py +375 -0
- quadra/datamodules/classification.py +1003 -0
- quadra/datamodules/generic/__init__.py +0 -0
- quadra/datamodules/generic/imagenette.py +144 -0
- quadra/datamodules/generic/mnist.py +81 -0
- quadra/datamodules/generic/mvtec.py +58 -0
- quadra/datamodules/generic/oxford_pet.py +163 -0
- quadra/datamodules/patch.py +190 -0
- quadra/datamodules/segmentation.py +742 -0
- quadra/datamodules/ssl.py +140 -0
- quadra/datasets/__init__.py +17 -0
- quadra/datasets/anomaly.py +287 -0
- quadra/datasets/classification.py +241 -0
- quadra/datasets/patch.py +138 -0
- quadra/datasets/segmentation.py +239 -0
- quadra/datasets/ssl.py +110 -0
- quadra/losses/__init__.py +0 -0
- quadra/losses/classification/__init__.py +6 -0
- quadra/losses/classification/asl.py +83 -0
- quadra/losses/classification/focal.py +320 -0
- quadra/losses/classification/prototypical.py +148 -0
- quadra/losses/ssl/__init__.py +17 -0
- quadra/losses/ssl/barlowtwins.py +47 -0
- quadra/losses/ssl/byol.py +37 -0
- quadra/losses/ssl/dino.py +129 -0
- quadra/losses/ssl/hyperspherical.py +45 -0
- quadra/losses/ssl/idmm.py +50 -0
- quadra/losses/ssl/simclr.py +67 -0
- quadra/losses/ssl/simsiam.py +30 -0
- quadra/losses/ssl/vicreg.py +76 -0
- quadra/main.py +46 -0
- quadra/metrics/__init__.py +3 -0
- quadra/metrics/segmentation.py +251 -0
- quadra/models/__init__.py +0 -0
- quadra/models/base.py +151 -0
- quadra/models/classification/__init__.py +8 -0
- quadra/models/classification/backbones.py +149 -0
- quadra/models/classification/base.py +92 -0
- quadra/models/evaluation.py +322 -0
- quadra/modules/__init__.py +0 -0
- quadra/modules/backbone.py +30 -0
- quadra/modules/base.py +312 -0
- quadra/modules/classification/__init__.py +3 -0
- quadra/modules/classification/base.py +331 -0
- quadra/modules/ssl/__init__.py +17 -0
- quadra/modules/ssl/barlowtwins.py +59 -0
- quadra/modules/ssl/byol.py +172 -0
- quadra/modules/ssl/common.py +285 -0
- quadra/modules/ssl/dino.py +186 -0
- quadra/modules/ssl/hyperspherical.py +206 -0
- quadra/modules/ssl/idmm.py +98 -0
- quadra/modules/ssl/simclr.py +73 -0
- quadra/modules/ssl/simsiam.py +68 -0
- quadra/modules/ssl/vicreg.py +67 -0
- quadra/optimizers/__init__.py +4 -0
- quadra/optimizers/lars.py +153 -0
- quadra/optimizers/sam.py +127 -0
- quadra/schedulers/__init__.py +3 -0
- quadra/schedulers/base.py +44 -0
- quadra/schedulers/warmup.py +127 -0
- quadra/tasks/__init__.py +24 -0
- quadra/tasks/anomaly.py +582 -0
- quadra/tasks/base.py +397 -0
- quadra/tasks/classification.py +1264 -0
- quadra/tasks/patch.py +492 -0
- quadra/tasks/segmentation.py +389 -0
- quadra/tasks/ssl.py +560 -0
- quadra/trainers/README.md +3 -0
- quadra/trainers/__init__.py +0 -0
- quadra/trainers/classification.py +179 -0
- quadra/utils/__init__.py +0 -0
- quadra/utils/anomaly.py +112 -0
- quadra/utils/classification.py +618 -0
- quadra/utils/deprecation.py +31 -0
- quadra/utils/evaluation.py +474 -0
- quadra/utils/export.py +579 -0
- quadra/utils/imaging.py +32 -0
- quadra/utils/logger.py +15 -0
- quadra/utils/mlflow.py +98 -0
- quadra/utils/model_manager.py +320 -0
- quadra/utils/models.py +524 -0
- quadra/utils/patch/__init__.py +15 -0
- quadra/utils/patch/dataset.py +1433 -0
- quadra/utils/patch/metrics.py +449 -0
- quadra/utils/patch/model.py +153 -0
- quadra/utils/patch/visualization.py +217 -0
- quadra/utils/resolver.py +42 -0
- quadra/utils/segmentation.py +31 -0
- quadra/utils/tests/__init__.py +0 -0
- quadra/utils/tests/fixtures/__init__.py +1 -0
- quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
- quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
- quadra/utils/tests/fixtures/dataset/classification.py +406 -0
- quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
- quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
- quadra/utils/tests/fixtures/models/__init__.py +3 -0
- quadra/utils/tests/fixtures/models/anomaly.py +89 -0
- quadra/utils/tests/fixtures/models/classification.py +45 -0
- quadra/utils/tests/fixtures/models/segmentation.py +33 -0
- quadra/utils/tests/helpers.py +70 -0
- quadra/utils/tests/models.py +27 -0
- quadra/utils/utils.py +525 -0
- quadra/utils/validator.py +115 -0
- quadra/utils/visualization.py +422 -0
- quadra/utils/vit_explainability.py +349 -0
- quadra-2.1.13.dist-info/LICENSE +201 -0
- quadra-2.1.13.dist-info/METADATA +386 -0
- quadra-2.1.13.dist-info/RECORD +300 -0
- {quadra-0.0.1.dist-info → quadra-2.1.13.dist-info}/WHEEL +1 -1
- quadra-2.1.13.dist-info/entry_points.txt +3 -0
- quadra-0.0.1.dist-info/METADATA +0 -14
- quadra-0.0.1.dist-info/RECORD +0 -4
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
import torch
|
|
3
|
+
from anomalib.models.draem.torch_model import DraemModel
|
|
4
|
+
from anomalib.models.efficient_ad.torch_model import EfficientAdModel
|
|
5
|
+
from anomalib.models.padim.torch_model import PadimModel
|
|
6
|
+
from anomalib.models.patchcore.torch_model import PatchcoreModel
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@pytest.fixture
def padim_resnet18():
    """Provide a PaDiM model whose resnet18 backbone is not pretrained (``pre_trained=False``)."""
    # TODO: the input size is hardcoded; consider making it configurable.
    model = PadimModel(
        input_size=[224, 224],
        backbone="resnet18",
        layers=["layer1", "layer2", "layer3"],
        pretrained_weights=None,
        tied_covariance=False,
        pre_trained=False,
    )
    yield model
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@torch.inference_mode()
def _initialize_patchcore_model(
    patchcore_model: PatchcoreModel,
    coreset_sampling_ratio: float = 0.1,
    memory_bank_images: int = 5,
) -> PatchcoreModel:
    """Initialize a Patchcore model by simulating a training step.

    A single random image is passed through the model, the resulting features are
    coreset-subsampled into the memory bank, and the memory bank is then replaced by random
    features with ``memory_bank_images`` times as many rows as the subsampled bank.

    Args:
        patchcore_model: Patchcore model to initialize. It is called before ``eval()`` is set,
            presumably so that the forward pass returns training embeddings (TODO confirm
            against the pinned anomalib version).
        coreset_sampling_ratio: Coreset sampling ratio to use for the initialization.
        memory_bank_images: Multiplier applied to the number of rows of the subsampled memory
            bank when building the random replacement bank.

    Returns:
        The same Patchcore model with an initialized memory bank, switched back to train mode.
    """
    # ``inference_mode`` already disables autograd, so no extra ``no_grad`` context is needed.
    random_input = torch.rand([1, 3, *patchcore_model.input_size])
    training_features = patchcore_model(random_input)

    patchcore_model.eval()
    patchcore_model.subsample_embedding(training_features, sampling_ratio=coreset_sampling_ratio)

    # Patchcore ONNX export currently does not handle large memory banks well, so a small random
    # memory bank (``memory_bank_images`` x the single-image coreset size) is used instead.
    memory_bank_number, memory_bank_n_features = patchcore_model.memory_bank.shape
    patchcore_model.memory_bank = torch.rand([memory_bank_images * memory_bank_number, memory_bank_n_features])
    patchcore_model.train()

    return patchcore_model
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
@pytest.fixture
def patchcore_resnet18():
    """Provide a Patchcore model (non-pretrained resnet18 backbone) with a populated memory bank."""
    # TODO: the input size is hardcoded; consider making it configurable.
    untrained_model = PatchcoreModel(
        input_size=[224, 224],
        backbone="resnet18",
        layers=["layer2", "layer3"],
        pre_trained=False,
    )
    initialized_model = _initialize_patchcore_model(untrained_model)
    yield initialized_model
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
@pytest.fixture
def draem():
    """Provide a DRAEM model built with its default constructor arguments."""
    model = DraemModel()
    yield model
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
@pytest.fixture
def efficient_ad_small():
    """Yield an EfficientAd model whose ``forward`` takes only the input batch.

    NOTE(review): the model size is left to ``EfficientAdModel``'s defaults; the fixture name
    suggests the small variant, so confirm that default if it matters.
    """

    class EfficientAdForwardWrapper(EfficientAdModel):
        """Wrap the forward method to avoid passing optional parameters."""

        def forward(self, x):
            # Always pass ``None`` for the optional second argument so callers (and export
            # tracing) only need to supply ``x``.
            return super().forward(x, None)

    model = EfficientAdForwardWrapper(
        teacher_out_channels=384,
        input_size=[256, 256],  # TODO: the input size is hardcoded; consider making it configurable.
        pretrained_teacher_type="nelson",
    )

    yield model
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
|
|
3
|
+
from quadra.models.classification import TimmNetworkBuilder, TorchHubNetworkBuilder
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@pytest.fixture
def resnet18():
    """Provide a timm resnet18 built with ``pretrained=False``, ``freeze=True`` and ``exportable=True``."""
    network = TimmNetworkBuilder(
        "resnet18",
        pretrained=False,
        freeze=True,
        exportable=True,
    )
    yield network
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@pytest.fixture
def resnet50():
    """Provide a timm resnet50 built with ``pretrained=False``, ``freeze=True`` and ``exportable=True``."""
    network = TimmNetworkBuilder(
        "resnet50",
        pretrained=False,
        freeze=True,
        exportable=True,
    )
    yield network
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@pytest.fixture
def vit_tiny_patch16_224():
    """Provide a timm vit_tiny_patch16_224 built with ``pretrained=False``, ``freeze=True``, ``exportable=True``."""
    network = TimmNetworkBuilder(
        "vit_tiny_patch16_224",
        pretrained=False,
        freeze=True,
        exportable=True,
    )
    yield network
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@pytest.fixture
def dino_vits8():
    """Provide the ``dino_vits8`` torch-hub model from ``facebookresearch/dino`` (not pretrained, frozen, exportable)."""
    hub_kwargs = {
        "repo_or_dir": "facebookresearch/dino:main",
        "model_name": "dino_vits8",
        "pretrained": False,
        "freeze": True,
        "exportable": True,
    }
    yield TorchHubNetworkBuilder(**hub_kwargs)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@pytest.fixture
def dino_vitb8():
    """Provide the ``dino_vitb8`` torch-hub model from ``facebookresearch/dino`` (not pretrained, frozen, exportable)."""
    hub_kwargs = {
        "repo_or_dir": "facebookresearch/dino:main",
        "model_name": "dino_vitb8",
        "pretrained": False,
        "freeze": True,
        "exportable": True,
    }
    yield TorchHubNetworkBuilder(**hub_kwargs)
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
|
|
3
|
+
from quadra.modules.backbone import create_smp_backbone
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@pytest.fixture
def smp_resnet18_unet():
    """Provide a single-class smp Unet with a frozen, randomly initialized resnet18 encoder."""
    backbone_kwargs = {
        "arch": "unet",
        "encoder_name": "resnet18",
        "encoder_weights": None,
        "encoder_depth": 5,
        "freeze_encoder": True,
        "in_channels": 3,
        "num_classes": 1,
        "activation": None,
    }
    yield create_smp_backbone(**backbone_kwargs)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@pytest.fixture
def smp_resnet18_unetplusplus():
    """Provide a single-class smp Unet++ with a frozen, randomly initialized resnet18 encoder."""
    backbone_kwargs = {
        "arch": "unetplusplus",
        "encoder_name": "resnet18",
        "encoder_weights": None,
        "encoder_depth": 5,
        "freeze_encoder": True,
        "in_channels": 3,
        "num_classes": 1,
        "activation": None,
    }
    yield create_smp_backbone(**backbone_kwargs)
|
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import torch
|
|
8
|
+
from hydra import compose, initialize_config_module
|
|
9
|
+
from hydra.core.hydra_config import HydraConfig
|
|
10
|
+
|
|
11
|
+
from quadra.main import main
|
|
12
|
+
from quadra.utils.export import get_export_extension
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
# taken from hydra unit tests
|
|
16
|
+
def _random_image(size: tuple[int, int] = (10, 10)) -> np.ndarray:
|
|
17
|
+
"""Generate random image."""
|
|
18
|
+
return np.random.randint(0, 255, size=size, dtype=np.uint8)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def execute_quadra_experiment(overrides: list[str], experiment_path: Path | str) -> None:
    """Compose the quadra Hydra config and run ``quadra.main.main`` inside ``experiment_path``.

    The experiment directory is created if missing and the process working directory is
    changed to it. The working directory is not restored afterwards; ``check_deployment_model``
    checks artifacts through relative paths, so presumably callers rely on this — do not add a
    restore without updating them.

    Args:
        overrides: Hydra overrides applied on top of the ``config`` primary config.
        experiment_path: Directory in which the experiment is executed (a ``str`` is accepted too).
    """
    experiment_path = Path(experiment_path)
    with initialize_config_module(config_module="quadra.configs", version_base="1.3.0"):
        experiment_path.mkdir(parents=True, exist_ok=True)
        os.chdir(experiment_path)
        cfg = compose(config_name="config", overrides=overrides, return_hydra_config=True)
        # Workaround for running without the ``@hydra.main`` entry point: register the composed
        # config as the global HydraConfig.
        # See https://github.com/facebookresearch/hydra/issues/2017 for more details.
        HydraConfig.instance().set_config(cfg)

        main(cfg)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def check_deployment_model(export_type: str, deployment_dir: str | os.PathLike = "deployment_model") -> None:
    """Check that the exported deployment model and its ``model.json`` file exist.

    Only existence is checked; the files' contents are not validated. Paths are resolved
    relative to the current working directory (``execute_quadra_experiment`` changes it to the
    experiment directory).

    Args:
        export_type: The type of the exported model, mapped to a file extension through
            ``get_export_extension``.
        deployment_dir: Directory expected to contain ``model.<extension>`` and ``model.json``.

    Raises:
        AssertionError: If either file is missing.
    """
    extension = get_export_extension(export_type)
    model_dir = Path(deployment_dir)

    model_file = model_dir / f"model.{extension}"
    assert model_file.exists(), f"Exported model not found: {model_file}"

    metadata_file = model_dir / "model.json"
    assert metadata_file.exists(), f"Exported model metadata not found: {metadata_file}"
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def get_quadra_test_device():
    """Return the device string the tests should run on.

    The value of the ``QUADRA_TEST_DEVICE`` environment variable is returned when it is set
    (even if empty); otherwise ``"cpu"`` is used.
    """
    configured = os.environ.get("QUADRA_TEST_DEVICE")
    if configured is None:
        return "cpu"
    return configured
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def setup_trainer_for_lightning() -> list[str]:
    """Build the Hydra trainer overrides matching the device selected for the tests.

    The device string comes from ``get_quadra_test_device``. For a CUDA device the
    ``lightning_gpu`` trainer is selected and pinned to the device index (GPU 0 when no index
    is given, e.g. ``"cuda"``). Any other device type selects the ``lightning_cpu`` trainer.

    Returns:
        A list of overrides for the trainer.
    """
    torch_device = torch.device(get_quadra_test_device())
    if torch_device.type != "cuda":
        return ["trainer=lightning_cpu"]

    # ``torch.device("cuda")`` has ``index is None``; fall back to GPU 0 instead of emitting the
    # invalid override ``trainer.devices=[None]``.
    device_index = 0 if torch_device.index is None else torch_device.index
    return ["trainer=lightning_gpu", f"trainer.devices=[{device_index}]"]
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch import nn
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class SingleInputModel(nn.Module):
    """Identity module whose ``forward`` accepts exactly one argument."""

    def forward(self, x: Any):
        # Hand the single input back untouched.
        output = x
        return output
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class DoubleInputModel(nn.Module):
    """Module whose ``forward`` takes two arguments and returns both, unchanged, as a tuple."""

    def forward(self, x: Any, y: Any):
        pair = (x, y)
        return pair
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class UnsupportedInputModel(nn.Module):
    """Model taking an unsupported input: a ``str`` argument next to the tensor ``x``."""

    def forward(self, x: torch.Tensor, y: str):
        # ``y`` exists only to give the signature an unsupported input type; it is intentionally
        # unused, and ``del`` marks that explicitly instead of a dead reassignment.
        del y
        return x
|