quadra 0.0.1__py3-none-any.whl → 2.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hydra_plugins/quadra_searchpath_plugin.py +30 -0
- quadra/__init__.py +6 -0
- quadra/callbacks/__init__.py +0 -0
- quadra/callbacks/anomalib.py +289 -0
- quadra/callbacks/lightning.py +501 -0
- quadra/callbacks/mlflow.py +291 -0
- quadra/callbacks/scheduler.py +69 -0
- quadra/configs/__init__.py +0 -0
- quadra/configs/backbone/caformer_m36.yaml +8 -0
- quadra/configs/backbone/caformer_s36.yaml +8 -0
- quadra/configs/backbone/convnextv2_base.yaml +8 -0
- quadra/configs/backbone/convnextv2_femto.yaml +8 -0
- quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
- quadra/configs/backbone/dino_vitb8.yaml +12 -0
- quadra/configs/backbone/dino_vits8.yaml +12 -0
- quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
- quadra/configs/backbone/dinov2_vits14.yaml +12 -0
- quadra/configs/backbone/efficientnet_b0.yaml +8 -0
- quadra/configs/backbone/efficientnet_b1.yaml +8 -0
- quadra/configs/backbone/efficientnet_b2.yaml +8 -0
- quadra/configs/backbone/efficientnet_b3.yaml +8 -0
- quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
- quadra/configs/backbone/levit_128s.yaml +8 -0
- quadra/configs/backbone/mnasnet0_5.yaml +9 -0
- quadra/configs/backbone/resnet101.yaml +8 -0
- quadra/configs/backbone/resnet18.yaml +8 -0
- quadra/configs/backbone/resnet18_ssl.yaml +8 -0
- quadra/configs/backbone/resnet50.yaml +8 -0
- quadra/configs/backbone/smp.yaml +9 -0
- quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
- quadra/configs/backbone/unetr.yaml +15 -0
- quadra/configs/backbone/vit16_base.yaml +9 -0
- quadra/configs/backbone/vit16_small.yaml +9 -0
- quadra/configs/backbone/vit16_tiny.yaml +9 -0
- quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
- quadra/configs/callbacks/all.yaml +45 -0
- quadra/configs/callbacks/default.yaml +34 -0
- quadra/configs/callbacks/default_anomalib.yaml +64 -0
- quadra/configs/config.yaml +33 -0
- quadra/configs/core/default.yaml +11 -0
- quadra/configs/datamodule/base/anomaly.yaml +16 -0
- quadra/configs/datamodule/base/classification.yaml +21 -0
- quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
- quadra/configs/datamodule/base/segmentation.yaml +18 -0
- quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
- quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
- quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
- quadra/configs/datamodule/base/ssl.yaml +21 -0
- quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
- quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
- quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
- quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
- quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
- quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
- quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
- quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
- quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
- quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
- quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/base/classification/classification.yaml +73 -0
- quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
- quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
- quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
- quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
- quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
- quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
- quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
- quadra/configs/experiment/base/ssl/byol.yaml +43 -0
- quadra/configs/experiment/base/ssl/dino.yaml +46 -0
- quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
- quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
- quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
- quadra/configs/experiment/custom/cls.yaml +12 -0
- quadra/configs/experiment/default.yaml +15 -0
- quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
- quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
- quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
- quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
- quadra/configs/export/default.yaml +13 -0
- quadra/configs/hydra/anomaly_custom.yaml +15 -0
- quadra/configs/hydra/default.yaml +14 -0
- quadra/configs/inference/default.yaml +26 -0
- quadra/configs/logger/comet.yaml +10 -0
- quadra/configs/logger/csv.yaml +5 -0
- quadra/configs/logger/mlflow.yaml +12 -0
- quadra/configs/logger/tensorboard.yaml +8 -0
- quadra/configs/loss/asl.yaml +7 -0
- quadra/configs/loss/barlow.yaml +2 -0
- quadra/configs/loss/bce.yaml +1 -0
- quadra/configs/loss/byol.yaml +1 -0
- quadra/configs/loss/cross_entropy.yaml +1 -0
- quadra/configs/loss/dino.yaml +8 -0
- quadra/configs/loss/simclr.yaml +2 -0
- quadra/configs/loss/simsiam.yaml +1 -0
- quadra/configs/loss/smp_ce.yaml +3 -0
- quadra/configs/loss/smp_dice.yaml +2 -0
- quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
- quadra/configs/loss/smp_mcc.yaml +2 -0
- quadra/configs/loss/vicreg.yaml +5 -0
- quadra/configs/model/anomalib/cfa.yaml +35 -0
- quadra/configs/model/anomalib/cflow.yaml +30 -0
- quadra/configs/model/anomalib/csflow.yaml +34 -0
- quadra/configs/model/anomalib/dfm.yaml +19 -0
- quadra/configs/model/anomalib/draem.yaml +29 -0
- quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
- quadra/configs/model/anomalib/fastflow.yaml +32 -0
- quadra/configs/model/anomalib/padim.yaml +32 -0
- quadra/configs/model/anomalib/patchcore.yaml +36 -0
- quadra/configs/model/barlow.yaml +16 -0
- quadra/configs/model/byol.yaml +25 -0
- quadra/configs/model/classification.yaml +10 -0
- quadra/configs/model/dino.yaml +26 -0
- quadra/configs/model/logistic_regression.yaml +4 -0
- quadra/configs/model/multilabel_classification.yaml +9 -0
- quadra/configs/model/simclr.yaml +18 -0
- quadra/configs/model/simsiam.yaml +24 -0
- quadra/configs/model/smp.yaml +4 -0
- quadra/configs/model/smp_multiclass.yaml +4 -0
- quadra/configs/model/vicreg.yaml +16 -0
- quadra/configs/optimizer/adam.yaml +5 -0
- quadra/configs/optimizer/adamw.yaml +3 -0
- quadra/configs/optimizer/default.yaml +4 -0
- quadra/configs/optimizer/lars.yaml +8 -0
- quadra/configs/optimizer/sgd.yaml +4 -0
- quadra/configs/scheduler/default.yaml +5 -0
- quadra/configs/scheduler/rop.yaml +5 -0
- quadra/configs/scheduler/step.yaml +3 -0
- quadra/configs/scheduler/warmrestart.yaml +2 -0
- quadra/configs/scheduler/warmup.yaml +6 -0
- quadra/configs/task/anomalib/cfa.yaml +5 -0
- quadra/configs/task/anomalib/cflow.yaml +5 -0
- quadra/configs/task/anomalib/csflow.yaml +5 -0
- quadra/configs/task/anomalib/draem.yaml +5 -0
- quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
- quadra/configs/task/anomalib/fastflow.yaml +5 -0
- quadra/configs/task/anomalib/inference.yaml +3 -0
- quadra/configs/task/anomalib/padim.yaml +5 -0
- quadra/configs/task/anomalib/patchcore.yaml +5 -0
- quadra/configs/task/classification.yaml +6 -0
- quadra/configs/task/classification_evaluation.yaml +6 -0
- quadra/configs/task/default.yaml +1 -0
- quadra/configs/task/segmentation.yaml +9 -0
- quadra/configs/task/segmentation_evaluation.yaml +3 -0
- quadra/configs/task/sklearn_classification.yaml +13 -0
- quadra/configs/task/sklearn_classification_patch.yaml +11 -0
- quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
- quadra/configs/task/sklearn_classification_test.yaml +8 -0
- quadra/configs/task/ssl.yaml +2 -0
- quadra/configs/trainer/lightning_cpu.yaml +36 -0
- quadra/configs/trainer/lightning_gpu.yaml +35 -0
- quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
- quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
- quadra/configs/trainer/lightning_multigpu.yaml +37 -0
- quadra/configs/trainer/sklearn_classification.yaml +7 -0
- quadra/configs/transforms/byol.yaml +47 -0
- quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
- quadra/configs/transforms/default.yaml +37 -0
- quadra/configs/transforms/default_numpy.yaml +24 -0
- quadra/configs/transforms/default_resize.yaml +22 -0
- quadra/configs/transforms/dino.yaml +63 -0
- quadra/configs/transforms/linear_eval.yaml +18 -0
- quadra/datamodules/__init__.py +20 -0
- quadra/datamodules/anomaly.py +180 -0
- quadra/datamodules/base.py +375 -0
- quadra/datamodules/classification.py +1003 -0
- quadra/datamodules/generic/__init__.py +0 -0
- quadra/datamodules/generic/imagenette.py +144 -0
- quadra/datamodules/generic/mnist.py +81 -0
- quadra/datamodules/generic/mvtec.py +58 -0
- quadra/datamodules/generic/oxford_pet.py +163 -0
- quadra/datamodules/patch.py +190 -0
- quadra/datamodules/segmentation.py +742 -0
- quadra/datamodules/ssl.py +140 -0
- quadra/datasets/__init__.py +17 -0
- quadra/datasets/anomaly.py +287 -0
- quadra/datasets/classification.py +241 -0
- quadra/datasets/patch.py +138 -0
- quadra/datasets/segmentation.py +239 -0
- quadra/datasets/ssl.py +110 -0
- quadra/losses/__init__.py +0 -0
- quadra/losses/classification/__init__.py +6 -0
- quadra/losses/classification/asl.py +83 -0
- quadra/losses/classification/focal.py +320 -0
- quadra/losses/classification/prototypical.py +148 -0
- quadra/losses/ssl/__init__.py +17 -0
- quadra/losses/ssl/barlowtwins.py +47 -0
- quadra/losses/ssl/byol.py +37 -0
- quadra/losses/ssl/dino.py +129 -0
- quadra/losses/ssl/hyperspherical.py +45 -0
- quadra/losses/ssl/idmm.py +50 -0
- quadra/losses/ssl/simclr.py +67 -0
- quadra/losses/ssl/simsiam.py +30 -0
- quadra/losses/ssl/vicreg.py +76 -0
- quadra/main.py +49 -0
- quadra/metrics/__init__.py +3 -0
- quadra/metrics/segmentation.py +251 -0
- quadra/models/__init__.py +0 -0
- quadra/models/base.py +151 -0
- quadra/models/classification/__init__.py +8 -0
- quadra/models/classification/backbones.py +149 -0
- quadra/models/classification/base.py +92 -0
- quadra/models/evaluation.py +322 -0
- quadra/modules/__init__.py +0 -0
- quadra/modules/backbone.py +30 -0
- quadra/modules/base.py +312 -0
- quadra/modules/classification/__init__.py +3 -0
- quadra/modules/classification/base.py +327 -0
- quadra/modules/ssl/__init__.py +17 -0
- quadra/modules/ssl/barlowtwins.py +59 -0
- quadra/modules/ssl/byol.py +172 -0
- quadra/modules/ssl/common.py +285 -0
- quadra/modules/ssl/dino.py +186 -0
- quadra/modules/ssl/hyperspherical.py +206 -0
- quadra/modules/ssl/idmm.py +98 -0
- quadra/modules/ssl/simclr.py +73 -0
- quadra/modules/ssl/simsiam.py +68 -0
- quadra/modules/ssl/vicreg.py +67 -0
- quadra/optimizers/__init__.py +4 -0
- quadra/optimizers/lars.py +153 -0
- quadra/optimizers/sam.py +127 -0
- quadra/schedulers/__init__.py +3 -0
- quadra/schedulers/base.py +44 -0
- quadra/schedulers/warmup.py +127 -0
- quadra/tasks/__init__.py +24 -0
- quadra/tasks/anomaly.py +582 -0
- quadra/tasks/base.py +397 -0
- quadra/tasks/classification.py +1263 -0
- quadra/tasks/patch.py +492 -0
- quadra/tasks/segmentation.py +389 -0
- quadra/tasks/ssl.py +560 -0
- quadra/trainers/README.md +3 -0
- quadra/trainers/__init__.py +0 -0
- quadra/trainers/classification.py +179 -0
- quadra/utils/__init__.py +0 -0
- quadra/utils/anomaly.py +112 -0
- quadra/utils/classification.py +618 -0
- quadra/utils/deprecation.py +31 -0
- quadra/utils/evaluation.py +474 -0
- quadra/utils/export.py +585 -0
- quadra/utils/imaging.py +32 -0
- quadra/utils/logger.py +15 -0
- quadra/utils/mlflow.py +98 -0
- quadra/utils/model_manager.py +320 -0
- quadra/utils/models.py +523 -0
- quadra/utils/patch/__init__.py +15 -0
- quadra/utils/patch/dataset.py +1433 -0
- quadra/utils/patch/metrics.py +449 -0
- quadra/utils/patch/model.py +153 -0
- quadra/utils/patch/visualization.py +217 -0
- quadra/utils/resolver.py +42 -0
- quadra/utils/segmentation.py +31 -0
- quadra/utils/tests/__init__.py +0 -0
- quadra/utils/tests/fixtures/__init__.py +1 -0
- quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
- quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
- quadra/utils/tests/fixtures/dataset/classification.py +406 -0
- quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
- quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
- quadra/utils/tests/fixtures/models/__init__.py +3 -0
- quadra/utils/tests/fixtures/models/anomaly.py +89 -0
- quadra/utils/tests/fixtures/models/classification.py +45 -0
- quadra/utils/tests/fixtures/models/segmentation.py +33 -0
- quadra/utils/tests/helpers.py +70 -0
- quadra/utils/tests/models.py +27 -0
- quadra/utils/utils.py +525 -0
- quadra/utils/validator.py +115 -0
- quadra/utils/visualization.py +422 -0
- quadra/utils/vit_explainability.py +349 -0
- quadra-2.2.7.dist-info/LICENSE +201 -0
- quadra-2.2.7.dist-info/METADATA +381 -0
- quadra-2.2.7.dist-info/RECORD +300 -0
- {quadra-0.0.1.dist-info → quadra-2.2.7.dist-info}/WHEEL +1 -1
- quadra-2.2.7.dist-info/entry_points.txt +3 -0
- quadra-0.0.1.dist-info/METADATA +0 -14
- quadra-0.0.1.dist-info/RECORD +0 -4
quadra/utils/tests/fixtures/dataset/classification.py

@@ -0,0 +1,406 @@
+from __future__ import annotations
+
+import glob
+import os
+import random
+import shutil
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any
+
+import cv2
+import numpy as np
+import pytest
+
+from quadra.utils.patch import generate_patch_dataset, get_image_mask_association
+from quadra.utils.tests.helpers import _random_image
+
+
+@dataclass
+class ClassificationDatasetArguments:
+    """Classification dataset arguments.
+
+    Args:
+        samples: number of samples per class
+        classes: class names, if set it must be the same length as samples
+        val_size: validation set size
+        test_size: test set size
+    """
+
+    samples: list[int]
+    classes: list[str] | None = None
+    val_size: float | None = None
+    test_size: float | None = None
+
+
+@dataclass
+class ClassificationMultilabelDatasetArguments:
+    """Classification dataset arguments.
+
+    Args:
+        samples: number of samples per class
+        classes: class names, if set it must be the same length as samples
+        val_size: validation set size
+        test_size: test set size
+        percentage_other_classes: probability of adding other classes to the labels of each sample
+    """
+
+    samples: list[int]
+    classes: list[str] | None = None
+    val_size: float | None = None
+    test_size: float | None = None
+    percentage_other_classes: float | None = 0.0
+
+
+@dataclass
+class ClassificationPatchDatasetArguments:
+    """Classification patch dataset arguments.
+
+    Args:
+        samples: number of samples per class
+        overlap: overlap between patches
+        patch_size: patch size
+        patch_number: number of patches
+        classes: class names, if set it must be the same length as samples
+        val_size: validation set size
+        test_size: test set size
+        annotated_good: list of class names that are considered as good annotations (e.g. ["good"])
+    """
+
+    samples: list[int]
+    overlap: float
+    patch_size: tuple[int, int] | None = None
+    patch_number: tuple[int, int] | None = None
+    classes: list[str] | None = None
+    val_size: float | None = 0.0
+    test_size: float | None = 0.0
+    annotated_good: list[str] | None = None
+
+
+def _build_classification_dataset(
+    tmp_path: Path, dataset_arguments: ClassificationDatasetArguments
+) -> tuple[str, ClassificationDatasetArguments]:
+    """Generate a classification dataset. If val_size or test_size are set, it will generate a train.txt, val.txt and
+    test.txt file in the dataset directory. By default generated images are 10x10 pixels.
+
+    Args:
+        tmp_path: path to temporary directory
+        dataset_arguments: dataset arguments
+
+    Returns:
+        Tuple containing path to created dataset and dataset arguments
+    """
+    classification_dataset_path = tmp_path / "classification_dataset"
+    classification_dataset_path.mkdir()
+
+    classes = dataset_arguments.classes if dataset_arguments.classes else range(len(dataset_arguments.samples))
+
+    for class_name, samples in zip(classes, dataset_arguments.samples):
+        class_path = classification_dataset_path / str(class_name)
+        class_path.mkdir()
+        for i in range(samples):
+            image = _random_image()
+            image_path = class_path / f"{class_name}_{i}.png"
+            cv2.imwrite(str(image_path), image)
+
+    if dataset_arguments.val_size is not None or dataset_arguments.test_size is not None:
+        all_images = glob.glob(os.path.join(str(classification_dataset_path), "**", "*.png"))
+        all_images = [f"{os.path.basename(os.path.dirname(image))}/{os.path.basename(image)}" for image in all_images]
+        val_size = dataset_arguments.val_size if dataset_arguments.val_size is not None else 0
+        test_size = dataset_arguments.test_size if dataset_arguments.test_size is not None else 0
+        train_size = 1 - val_size - test_size
+
+        # pylint: disable=unbalanced-tuple-unpacking
+        train_images, val_images, test_images = np.split(
+            np.random.permutation(all_images),
+            [int(train_size * len(all_images)), int((train_size + val_size) * len(all_images))],
+        )
+
+        with open(classification_dataset_path / "train.txt", "w") as f:
+            f.write("\n".join(train_images))
+
+        with open(classification_dataset_path / "val.txt", "w") as f:
+            f.write("\n".join(val_images))
+
+        with open(classification_dataset_path / "test.txt", "w") as f:
+            f.write("\n".join(test_images))
+
+    return str(classification_dataset_path), dataset_arguments
+
+
+@pytest.fixture
+def classification_dataset(
+    tmp_path: Path, dataset_arguments: ClassificationDatasetArguments
+) -> tuple[str, ClassificationDatasetArguments]:
+    """Generate a classification dataset. If val_size or test_size are set, it will generate a train.txt, val.txt and
+    test.txt file in the dataset directory. By default generated images are 10x10 pixels.
+
+    Args:
+        tmp_path: path to temporary directory
+        dataset_arguments: dataset arguments
+
+    Yields:
+        Tuple containing path to created dataset and dataset arguments
+    """
+    yield _build_classification_dataset(tmp_path, dataset_arguments)
+    if tmp_path.exists():
+        shutil.rmtree(tmp_path)
+
+
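For reference, a minimal sketch of how a test might consume this fixture. `classification_dataset` depends on a `dataset_arguments` fixture that is not defined anywhere in this diff, so the forwarding fixture below is an assumption about how quadra wires it up (for example in a conftest.py); the rest is plain pytest indirect parametrization, and the fixture import stands in for conftest registration:

```python
import os

import pytest

from quadra.utils.tests.fixtures.dataset.classification import (
    ClassificationDatasetArguments,
    classification_dataset,  # noqa: F401  # importing registers the fixture in this module
)


@pytest.fixture
def dataset_arguments(request):
    # Assumed glue: forward the parametrized value to fixtures that depend on it.
    return request.param


@pytest.mark.parametrize(
    "dataset_arguments",
    [ClassificationDatasetArguments(samples=[5, 5], classes=["good", "bad"], val_size=0.2, test_size=0.2)],
    indirect=True,  # route the argument into the `dataset_arguments` fixture
)
def test_dataset_layout(classification_dataset):
    dataset_path, arguments = classification_dataset
    # One folder per class plus the three split files, since val/test sizes were set.
    assert sorted(os.listdir(dataset_path)) == ["bad", "good", "test.txt", "train.txt", "val.txt"]
```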
+@pytest.fixture(
+    params=[
+        ClassificationDatasetArguments(
+            **{"samples": [10, 10], "classes": ["class_1", "class_2"], "val_size": 0.1, "test_size": 0.1}
+        )
+    ]
+)
+def base_classification_dataset(tmp_path: Path, request: Any) -> tuple[str, ClassificationDatasetArguments]:
+    """Generate a base classification dataset with the following parameters:
+        - 10 samples per class
+        - 2 classes (class_1 and class_2)
+        - 10% of samples in validation set
+        - 10% of samples in test set
+    By default generated images are grayscale and 10x10 pixels.
+
+    Args:
+        tmp_path: path to temporary directory
+        request: pytest request
+
+    Yields:
+        Tuple containing path to created dataset and dataset arguments
+    """
+    yield _build_classification_dataset(tmp_path, request.param)
+    if tmp_path.exists():
+        shutil.rmtree(tmp_path)
+
+
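The `base_*` fixtures such as the one above are already parametrized, so a test can request them directly without any `parametrize` marker. A minimal sketch, again using a fixture import in place of conftest registration:

```python
from quadra.utils.tests.fixtures.dataset.classification import base_classification_dataset  # noqa: F401


def test_base_dataset(base_classification_dataset):
    dataset_path, arguments = base_classification_dataset
    # The baked-in params give 10 images per class plus the three split files.
    assert arguments.samples == [10, 10]
```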
+def _build_multilabel_classification_dataset(
+    tmp_path: Path, dataset_arguments: ClassificationMultilabelDatasetArguments
+) -> tuple[str, ClassificationMultilabelDatasetArguments]:
+    """Generate a multilabel classification dataset.
+    Generates a samples.txt file in the dataset directory containing the path to the image and the corresponding
+    classes. If val_size or test_size are set, it will generate a train.txt, val.txt and test.txt file in the
+    dataset directory. By default generated images are 10x10 pixels.
+
+    Args:
+        tmp_path: path to temporary directory
+        dataset_arguments: dataset arguments
+
+    Returns:
+        Tuple containing path to created dataset and dataset arguments
+    """
+    classification_dataset_path = tmp_path / "multilabel_classification_dataset"
+    images_path = classification_dataset_path / "images"
+    classification_dataset_path.mkdir()
+    images_path.mkdir()
+
+    classes = dataset_arguments.classes if dataset_arguments.classes else range(len(dataset_arguments.samples))
+    percentage_other_classes = dataset_arguments.percentage_other_classes
+
+    generated_samples = []
+    counter = 0
+    for class_name, samples in zip(classes, dataset_arguments.samples):
+        for _ in range(samples):
+            image = _random_image()
+            image_path = images_path / f"{counter}.png"
+            counter += 1
+            cv2.imwrite(str(image_path), image)
+            targets = [class_name]
+            targets = targets + [
+                cl_name for cl_name in classes if cl_name != class_name and random.random() < percentage_other_classes
+            ]
+            generated_samples.append(f"images/{image_path.name},{','.join(targets)}")
+
+    with open(classification_dataset_path / "samples.txt", "w") as f:
+        f.write("\n".join(generated_samples))
+
+    if dataset_arguments.val_size is not None or dataset_arguments.test_size is not None:
+        val_size = dataset_arguments.val_size if dataset_arguments.val_size is not None else 0
+        test_size = dataset_arguments.test_size if dataset_arguments.test_size is not None else 0
+        train_size = 1 - val_size - test_size
+
+        # pylint: disable=unbalanced-tuple-unpacking
+        train_images, val_images, test_images = np.split(
+            np.random.permutation(generated_samples),
+            [int(train_size * len(generated_samples)), int((train_size + val_size) * len(generated_samples))],
+        )
+
+        with open(classification_dataset_path / "train.txt", "w") as f:
+            f.write("\n".join(train_images))
+
+        with open(classification_dataset_path / "val.txt", "w") as f:
+            f.write("\n".join(val_images))
+
+        with open(classification_dataset_path / "test.txt", "w") as f:
+            f.write("\n".join(test_images))
+
+    return str(classification_dataset_path), dataset_arguments
+
+
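For illustration (not part of the diff): with the three classes above and a non-zero `percentage_other_classes`, a generated samples.txt could look like the following; the extra labels per line are drawn at random, so the exact content varies between runs:

```
images/0.png,class_1
images/1.png,class_1,class_3
images/2.png,class_1
...
images/29.png,class_3,class_2
```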
+@pytest.fixture
+def multilabel_classification_dataset(
+    tmp_path: Path, dataset_arguments: ClassificationMultilabelDatasetArguments
+) -> tuple[str, ClassificationMultilabelDatasetArguments]:
+    """Fixture to dynamically generate a multilabel classification dataset.
+    Generates a samples.txt file in the dataset directory containing the path to the image and the corresponding
+    classes. If val_size or test_size are set, it will generate a train.txt, val.txt and test.txt file in the
+    dataset directory. By default generated images are 10x10 pixels.
+
+    Args:
+        tmp_path: path to temporary directory
+        dataset_arguments: dataset arguments
+
+    Yields:
+        Tuple containing path to created dataset and dataset arguments
+    """
+    yield _build_multilabel_classification_dataset(tmp_path, dataset_arguments)
+    if tmp_path.exists():
+        shutil.rmtree(tmp_path)
+
+
+@pytest.fixture(
+    params=[
+        ClassificationMultilabelDatasetArguments(
+            **{
+                "samples": [10, 10, 10],
+                "classes": ["class_1", "class_2", "class_3"],
+                "val_size": 0.1,
+                "test_size": 0.1,
+                "percentage_other_classes": 0.3,
+            }
+        )
+    ]
+)
+def base_multilabel_classification_dataset(
+    tmp_path: Path, request: Any
+) -> tuple[str, ClassificationMultilabelDatasetArguments]:
+    """Fixture to generate a base multilabel classification dataset with the following parameters:
+        - 10 samples per class
+        - 3 classes (class_1, class_2 and class_3)
+        - 10% of samples in validation set
+        - 10% of samples in test set
+        - 30% probability of adding each of the other classes to a sample
+    By default generated images are grayscale and 10x10 pixels.
+
+    Args:
+        tmp_path: path to temporary directory
+        request: pytest request
+
+    Yields:
+        Tuple containing path to created dataset and dataset arguments
+    """
+    yield _build_multilabel_classification_dataset(tmp_path, request.param)
+    if tmp_path.exists():
+        shutil.rmtree(tmp_path)
+
+
+def _build_classification_patch_dataset(
+    tmp_path: Path, dataset_arguments: ClassificationPatchDatasetArguments
+) -> tuple[str, ClassificationPatchDatasetArguments, dict[str, int]]:
+    """Generate a classification patch dataset. By default generated images are 224x224 pixels
+    and each associated mask contains a 50x50 pixel square with the corresponding image class, so at the current
+    stage it is not possible to have images with multiple annotations. The patch dataset will be generated using
+    the standard parameters of the generate_patch_dataset function.
+
+    Args:
+        tmp_path: path to temporary directory
+        dataset_arguments: dataset arguments
+
+    Returns:
+        Tuple containing path to created dataset, dataset arguments and class to index mapping
+    """
+    initial_dataset_path = tmp_path / "initial_dataset"
+    initial_dataset_path.mkdir()
+
+    images_path = initial_dataset_path / "images"
+    masks_path = initial_dataset_path / "masks"
+    images_path.mkdir()
+    masks_path.mkdir()
+
+    classes = dataset_arguments.classes if dataset_arguments.classes else range(len(dataset_arguments.samples))
+
+    class_to_idx = {class_name: i for i, class_name in enumerate(classes)}
+
+    for class_name, samples in zip(classes, dataset_arguments.samples):
+        for i in range(samples):
+            image = _random_image(size=(224, 224))
+            mask = np.zeros((224, 224), dtype=np.uint8)
+            mask[100:150, 100:150] = class_to_idx[class_name]
+            image_path = images_path / f"{class_name}_{i}.png"
+            mask_path = masks_path / f"{class_name}_{i}.png"
+            cv2.imwrite(str(image_path), image)
+            cv2.imwrite(str(mask_path), mask)
+
+    patch_dataset_path = tmp_path / "patch_dataset"
+    patch_dataset_path.mkdir()
+
+    data_dictionary = get_image_mask_association(data_folder=str(images_path), mask_folder=str(masks_path))
+
+    _ = generate_patch_dataset(
+        data_dictionary=data_dictionary,
+        class_to_idx=class_to_idx,
+        val_size=dataset_arguments.val_size,
+        test_size=dataset_arguments.test_size,
+        patch_number=dataset_arguments.patch_number,
+        patch_size=dataset_arguments.patch_size,
+        overlap=dataset_arguments.overlap,
+        output_folder=str(patch_dataset_path),
+        annotated_good=dataset_arguments.annotated_good,
+    )
+
+    return str(patch_dataset_path), dataset_arguments, class_to_idx
+
+
+@pytest.fixture
+def classification_patch_dataset(
+    tmp_path: Path, dataset_arguments: ClassificationPatchDatasetArguments
+) -> tuple[str, ClassificationPatchDatasetArguments, dict[str, int]]:
+    """Fixture to dynamically generate a classification patch dataset.
+
+    By default generated images are 224x224 pixels
+    and each associated mask contains a 50x50 pixel square with the corresponding image class, so at the current
+    stage it is not possible to have images with multiple annotations. The patch dataset will be generated using
+    the standard parameters of the generate_patch_dataset function.
+
+    Args:
+        tmp_path: path to temporary directory
+        dataset_arguments: dataset arguments
+
+    Yields:
+        Tuple containing path to created dataset, dataset arguments and class to index mapping
+    """
+    yield _build_classification_patch_dataset(tmp_path, dataset_arguments)
+    if tmp_path.exists():
+        shutil.rmtree(tmp_path)
+
+
+@pytest.fixture(
+    params=[
+        ClassificationPatchDatasetArguments(
+            **{
+                "samples": [5, 5, 5],
+                "classes": ["bg", "a", "b"],
+                "patch_number": [2, 2],
+                "overlap": 0,
+                "val_size": 0.1,
+                "test_size": 0.1,
+            }
+        )
+    ]
+)
+def base_patch_classification_dataset(
+    tmp_path: Path, request: Any
+) -> tuple[str, ClassificationPatchDatasetArguments, dict[str, int]]:
+    """Generate a classification patch dataset with the following parameters:
+        - 3 classes named bg, a and b
+        - 5, 5 and 5 samples for each class
+        - 2 horizontal patches and 2 vertical patches
+        - 0% overlap
+        - 10% validation set
+        - 10% test set.
+
+    Args:
+        tmp_path: path to temporary directory
+        request: pytest SubRequest object
+
+    Yields:
+        Tuple containing path to created dataset, dataset arguments and class to index mapping
+    """
+    yield _build_classification_patch_dataset(tmp_path, request.param)
+    if tmp_path.exists():
+        shutil.rmtree(tmp_path)
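As with the plain classification fixtures, a sketch of driving `classification_patch_dataset` with explicit arguments, under the same assumptions as the earlier sketch (the forwarding `dataset_arguments` fixture is assumed glue, and the fixture import stands in for conftest registration):

```python
import pytest

from quadra.utils.tests.fixtures.dataset.classification import (
    ClassificationPatchDatasetArguments,
    classification_patch_dataset,  # noqa: F401
)


@pytest.fixture
def dataset_arguments(request):
    return request.param  # assumed glue, as in the earlier sketch


@pytest.mark.parametrize(
    "dataset_arguments",
    [ClassificationPatchDatasetArguments(samples=[5, 5], classes=["bg", "a"], overlap=0.0, patch_number=(2, 2))],
    indirect=True,
)
def test_patch_dataset(classification_patch_dataset):
    dataset_path, arguments, class_to_idx = classification_patch_dataset
    # class_to_idx enumerates the classes list in order.
    assert class_to_idx == {"bg": 0, "a": 1}
```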
quadra/utils/tests/fixtures/dataset/imagenette.py

@@ -0,0 +1,53 @@
+import shutil
+from pathlib import Path
+
+import cv2
+import pytest
+
+from quadra.utils.tests.helpers import _random_image
+
+
+def _build_imagenette_dataset(tmp_path: Path, classes: int, class_samples: int) -> str:
+    """Generate an imagenette dataset in the format required by the efficient_ad model.
+
+    Args:
+        tmp_path: Path to temporary directory
+        classes: Number of mock imagenette classes
+        class_samples: Number of samples for each mock imagenette class
+
+    Returns:
+        Path to imagenette dataset
+    """
+    parent_path = tmp_path / "imagenette2"
+    parent_path.mkdir()
+    train_path = parent_path / "train"
+    train_path.mkdir()
+    val_path = parent_path / "val"
+    val_path.mkdir()
+
+    for split in [train_path, val_path]:
+        for i in range(classes):
+            cl_path = split / f"class_{i}"
+            cl_path.mkdir()
+            for j in range(class_samples):
+                image = _random_image()
+                image_path = cl_path / f"fake_{j}.png"
+                cv2.imwrite(str(image_path), image)
+
+    return str(parent_path)  # cast so the declared str return type holds
+
+
+@pytest.fixture
+def imagenette_dataset(tmp_path: Path) -> str:
+    """Generate a mock imagenette dataset to test the efficient_ad model.
+
+    Args:
+        tmp_path: Path to temporary directory
+
+    Yields:
+        Path to imagenette dataset folder
+    """
+    yield _build_imagenette_dataset(tmp_path, classes=3, class_samples=3)
+
+    if tmp_path.exists():
+        shutil.rmtree(tmp_path)
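With the defaults baked into the fixture (classes=3, class_samples=3), the generated tree follows the standard imagenette2 layout:

```
imagenette2/
├── train/
│   ├── class_0/fake_0.png ... fake_2.png
│   ├── class_1/...
│   └── class_2/...
└── val/
    ├── class_0/...
    ├── class_1/...
    └── class_2/...
```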
quadra/utils/tests/fixtures/dataset/segmentation.py

@@ -0,0 +1,161 @@
+from __future__ import annotations
+
+import shutil
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any
+
+import cv2
+import numpy as np
+import pytest
+
+from quadra.utils.tests.helpers import _random_image
+
+
+@dataclass
+class SegmentationDatasetArguments:
+    """Segmentation dataset arguments.
+
+    Args:
+        train_samples: List of samples per class in train set, the element at index 0 counts good samples
+        val_samples: List of samples per class in validation set, same as above.
+        test_samples: List of samples per class in test set, same as above.
+        classes: Optional list of class names, its length must be equal to len(train_samples) - 1
+    """
+
+    train_samples: list[int]
+    val_samples: list[int] | None = None
+    test_samples: list[int] | None = None
+    classes: list[str] | None = None
+
+
+def _build_segmentation_dataset(
+    tmp_path: Path, dataset_arguments: SegmentationDatasetArguments
+) -> tuple[str, SegmentationDatasetArguments, dict[str, int]]:
+    """Generate a segmentation dataset.
+
+    Args:
+        tmp_path: path to temporary directory
+        dataset_arguments: dataset arguments
+
+    Returns:
+        Tuple containing path to dataset, dataset arguments and class to index mapping
+    """
+    train_samples = dataset_arguments.train_samples
+    val_samples = dataset_arguments.val_samples
+    test_samples = dataset_arguments.test_samples
+    classes = (
+        dataset_arguments.classes if dataset_arguments.classes else list(range(1, len(dataset_arguments.train_samples)))
+    )
+
+    segmentation_dataset_path = tmp_path / "segmentation_dataset"
+    segmentation_dataset_path.mkdir()
+    images_path = segmentation_dataset_path / "images"
+    masks_path = segmentation_dataset_path / "masks"
+    images_path.mkdir(parents=True)
+    masks_path.mkdir(parents=True)
+    class_to_idx = {class_name: i + 1 for i, class_name in enumerate(classes)}
+    classes = [0] + classes
+
+    counter = 0
+    for split_name, split_samples in zip(["train", "val", "test"], [train_samples, val_samples, test_samples]):
+        if split_samples is None:
+            continue
+
+        with open(segmentation_dataset_path / f"{split_name}.txt", "w") as split_file:
+            for class_name, samples in zip(classes, split_samples):
+                for _ in range(samples):
+                    image = _random_image(size=(224, 224))
+                    mask = np.zeros((224, 224), dtype=np.uint8)
+                    if class_name != 0:
+                        mask[100:150, 100:150] = class_to_idx[class_name]
+                    image_path = images_path / f"{class_name}_{counter}.png"
+                    mask_path = masks_path / f"{class_name}_{counter}.png"
+                    cv2.imwrite(str(image_path), image)
+                    cv2.imwrite(str(mask_path), mask)
+                    split_file.write(f"images/{image_path.name}\n")
+                    counter += 1
+
+    return str(segmentation_dataset_path), dataset_arguments, class_to_idx
+
+
+@pytest.fixture
+def segmentation_dataset(
+    tmp_path: Path, dataset_arguments: SegmentationDatasetArguments
+) -> tuple[str, SegmentationDatasetArguments, dict[str, int]]:
+    """Fixture to dynamically generate a segmentation dataset. By default generated images are 224x224 pixels
+    and each associated mask contains a 50x50 pixel square with the corresponding image class, so at the current
+    stage it is not possible to have images with multiple annotations. Split files are saved as train.txt,
+    val.txt and test.txt.
+
+    Args:
+        tmp_path: path to temporary directory
+        dataset_arguments: dataset arguments
+
+    Yields:
+        Tuple containing path to dataset, dataset arguments and class to index mapping
+    """
+    yield _build_segmentation_dataset(tmp_path, dataset_arguments)
+    if tmp_path.exists():
+        shutil.rmtree(tmp_path)
+
+
+@pytest.fixture(
+    params=[
+        SegmentationDatasetArguments(
+            **{"train_samples": [3, 2], "val_samples": [2, 2], "test_samples": [1, 1], "classes": ["bad"]}
+        )
+    ]
+)
+def base_binary_segmentation_dataset(
+    tmp_path: Path, request: Any
+) -> tuple[str, SegmentationDatasetArguments, dict[str, int]]:
+    """Generate a base binary segmentation dataset with the following structure:
+        - 3 good and 2 bad samples in train set
+        - 2 good and 2 bad samples in validation set
+        - 1 good and 1 bad sample in test set
+        - 2 classes: good and bad.
+
+    Args:
+        tmp_path: path to temporary directory
+        request: pytest request
+
+    Yields:
+        Tuple containing path to dataset, dataset arguments and class to index mapping
+    """
+    yield _build_segmentation_dataset(tmp_path, request.param)
+    if tmp_path.exists():
+        shutil.rmtree(tmp_path)
+
+
+@pytest.fixture(
+    params=[
+        SegmentationDatasetArguments(
+            **{
+                "train_samples": [2, 2, 2],
+                "val_samples": [2, 2, 2],
+                "test_samples": [1, 1, 1],
+                "classes": ["defect_1", "defect_2"],
+            }
+        )
+    ]
+)
+def base_multiclass_segmentation_dataset(
+    tmp_path: Path, request: Any
+) -> tuple[str, SegmentationDatasetArguments, dict[str, int]]:
+    """Generate a base multiclass segmentation dataset with the following structure:
+        - 2 good, 2 defect_1 and 2 defect_2 samples in train set
+        - 2 good, 2 defect_1 and 2 defect_2 samples in validation set
+        - 1 good, 1 defect_1 and 1 defect_2 sample in test set
+        - 3 classes: good, defect_1 and defect_2.
+
+    Args:
+        tmp_path: path to temporary directory
+        request: pytest request
+
+    Yields:
+        Tuple containing path to dataset, dataset arguments and class to index mapping
+    """
+    yield _build_segmentation_dataset(tmp_path, request.param)
+    if tmp_path.exists():
+        shutil.rmtree(tmp_path)
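Finally, a sketch consuming the pre-parametrized binary segmentation fixture, under the same registration assumption as the earlier sketches. Masks are grayscale PNGs whose pixel values are class indices: background pixels are 0 and the single "bad" class is drawn with value 1, so the mapping is simply:

```python
from quadra.utils.tests.fixtures.dataset.segmentation import base_binary_segmentation_dataset  # noqa: F401


def test_binary_segmentation(base_binary_segmentation_dataset):
    dataset_path, arguments, class_to_idx = base_binary_segmentation_dataset
    assert class_to_idx == {"bad": 1}
```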