quadra 0.0.1__py3-none-any.whl → 2.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hydra_plugins/quadra_searchpath_plugin.py +30 -0
- quadra/__init__.py +6 -0
- quadra/callbacks/__init__.py +0 -0
- quadra/callbacks/anomalib.py +289 -0
- quadra/callbacks/lightning.py +501 -0
- quadra/callbacks/mlflow.py +291 -0
- quadra/callbacks/scheduler.py +69 -0
- quadra/configs/__init__.py +0 -0
- quadra/configs/backbone/caformer_m36.yaml +8 -0
- quadra/configs/backbone/caformer_s36.yaml +8 -0
- quadra/configs/backbone/convnextv2_base.yaml +8 -0
- quadra/configs/backbone/convnextv2_femto.yaml +8 -0
- quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
- quadra/configs/backbone/dino_vitb8.yaml +12 -0
- quadra/configs/backbone/dino_vits8.yaml +12 -0
- quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
- quadra/configs/backbone/dinov2_vits14.yaml +12 -0
- quadra/configs/backbone/efficientnet_b0.yaml +8 -0
- quadra/configs/backbone/efficientnet_b1.yaml +8 -0
- quadra/configs/backbone/efficientnet_b2.yaml +8 -0
- quadra/configs/backbone/efficientnet_b3.yaml +8 -0
- quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
- quadra/configs/backbone/levit_128s.yaml +8 -0
- quadra/configs/backbone/mnasnet0_5.yaml +9 -0
- quadra/configs/backbone/resnet101.yaml +8 -0
- quadra/configs/backbone/resnet18.yaml +8 -0
- quadra/configs/backbone/resnet18_ssl.yaml +8 -0
- quadra/configs/backbone/resnet50.yaml +8 -0
- quadra/configs/backbone/smp.yaml +9 -0
- quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
- quadra/configs/backbone/unetr.yaml +15 -0
- quadra/configs/backbone/vit16_base.yaml +9 -0
- quadra/configs/backbone/vit16_small.yaml +9 -0
- quadra/configs/backbone/vit16_tiny.yaml +9 -0
- quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
- quadra/configs/callbacks/all.yaml +45 -0
- quadra/configs/callbacks/default.yaml +34 -0
- quadra/configs/callbacks/default_anomalib.yaml +64 -0
- quadra/configs/config.yaml +33 -0
- quadra/configs/core/default.yaml +11 -0
- quadra/configs/datamodule/base/anomaly.yaml +16 -0
- quadra/configs/datamodule/base/classification.yaml +21 -0
- quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
- quadra/configs/datamodule/base/segmentation.yaml +18 -0
- quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
- quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
- quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
- quadra/configs/datamodule/base/ssl.yaml +21 -0
- quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
- quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
- quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
- quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
- quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
- quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
- quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
- quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
- quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
- quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
- quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/base/classification/classification.yaml +73 -0
- quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
- quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
- quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
- quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
- quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
- quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
- quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
- quadra/configs/experiment/base/ssl/byol.yaml +43 -0
- quadra/configs/experiment/base/ssl/dino.yaml +46 -0
- quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
- quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
- quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
- quadra/configs/experiment/custom/cls.yaml +12 -0
- quadra/configs/experiment/default.yaml +15 -0
- quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
- quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
- quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
- quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
- quadra/configs/export/default.yaml +13 -0
- quadra/configs/hydra/anomaly_custom.yaml +15 -0
- quadra/configs/hydra/default.yaml +14 -0
- quadra/configs/inference/default.yaml +26 -0
- quadra/configs/logger/comet.yaml +10 -0
- quadra/configs/logger/csv.yaml +5 -0
- quadra/configs/logger/mlflow.yaml +12 -0
- quadra/configs/logger/tensorboard.yaml +8 -0
- quadra/configs/loss/asl.yaml +7 -0
- quadra/configs/loss/barlow.yaml +2 -0
- quadra/configs/loss/bce.yaml +1 -0
- quadra/configs/loss/byol.yaml +1 -0
- quadra/configs/loss/cross_entropy.yaml +1 -0
- quadra/configs/loss/dino.yaml +8 -0
- quadra/configs/loss/simclr.yaml +2 -0
- quadra/configs/loss/simsiam.yaml +1 -0
- quadra/configs/loss/smp_ce.yaml +3 -0
- quadra/configs/loss/smp_dice.yaml +2 -0
- quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
- quadra/configs/loss/smp_mcc.yaml +2 -0
- quadra/configs/loss/vicreg.yaml +5 -0
- quadra/configs/model/anomalib/cfa.yaml +35 -0
- quadra/configs/model/anomalib/cflow.yaml +30 -0
- quadra/configs/model/anomalib/csflow.yaml +34 -0
- quadra/configs/model/anomalib/dfm.yaml +19 -0
- quadra/configs/model/anomalib/draem.yaml +29 -0
- quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
- quadra/configs/model/anomalib/fastflow.yaml +32 -0
- quadra/configs/model/anomalib/padim.yaml +32 -0
- quadra/configs/model/anomalib/patchcore.yaml +36 -0
- quadra/configs/model/barlow.yaml +16 -0
- quadra/configs/model/byol.yaml +25 -0
- quadra/configs/model/classification.yaml +10 -0
- quadra/configs/model/dino.yaml +26 -0
- quadra/configs/model/logistic_regression.yaml +4 -0
- quadra/configs/model/multilabel_classification.yaml +9 -0
- quadra/configs/model/simclr.yaml +18 -0
- quadra/configs/model/simsiam.yaml +24 -0
- quadra/configs/model/smp.yaml +4 -0
- quadra/configs/model/smp_multiclass.yaml +4 -0
- quadra/configs/model/vicreg.yaml +16 -0
- quadra/configs/optimizer/adam.yaml +5 -0
- quadra/configs/optimizer/adamw.yaml +3 -0
- quadra/configs/optimizer/default.yaml +4 -0
- quadra/configs/optimizer/lars.yaml +8 -0
- quadra/configs/optimizer/sgd.yaml +4 -0
- quadra/configs/scheduler/default.yaml +5 -0
- quadra/configs/scheduler/rop.yaml +5 -0
- quadra/configs/scheduler/step.yaml +3 -0
- quadra/configs/scheduler/warmrestart.yaml +2 -0
- quadra/configs/scheduler/warmup.yaml +6 -0
- quadra/configs/task/anomalib/cfa.yaml +5 -0
- quadra/configs/task/anomalib/cflow.yaml +5 -0
- quadra/configs/task/anomalib/csflow.yaml +5 -0
- quadra/configs/task/anomalib/draem.yaml +5 -0
- quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
- quadra/configs/task/anomalib/fastflow.yaml +5 -0
- quadra/configs/task/anomalib/inference.yaml +3 -0
- quadra/configs/task/anomalib/padim.yaml +5 -0
- quadra/configs/task/anomalib/patchcore.yaml +5 -0
- quadra/configs/task/classification.yaml +6 -0
- quadra/configs/task/classification_evaluation.yaml +6 -0
- quadra/configs/task/default.yaml +1 -0
- quadra/configs/task/segmentation.yaml +9 -0
- quadra/configs/task/segmentation_evaluation.yaml +3 -0
- quadra/configs/task/sklearn_classification.yaml +13 -0
- quadra/configs/task/sklearn_classification_patch.yaml +11 -0
- quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
- quadra/configs/task/sklearn_classification_test.yaml +8 -0
- quadra/configs/task/ssl.yaml +2 -0
- quadra/configs/trainer/lightning_cpu.yaml +36 -0
- quadra/configs/trainer/lightning_gpu.yaml +35 -0
- quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
- quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
- quadra/configs/trainer/lightning_multigpu.yaml +37 -0
- quadra/configs/trainer/sklearn_classification.yaml +7 -0
- quadra/configs/transforms/byol.yaml +47 -0
- quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
- quadra/configs/transforms/default.yaml +37 -0
- quadra/configs/transforms/default_numpy.yaml +24 -0
- quadra/configs/transforms/default_resize.yaml +22 -0
- quadra/configs/transforms/dino.yaml +63 -0
- quadra/configs/transforms/linear_eval.yaml +18 -0
- quadra/datamodules/__init__.py +20 -0
- quadra/datamodules/anomaly.py +180 -0
- quadra/datamodules/base.py +375 -0
- quadra/datamodules/classification.py +1003 -0
- quadra/datamodules/generic/__init__.py +0 -0
- quadra/datamodules/generic/imagenette.py +144 -0
- quadra/datamodules/generic/mnist.py +81 -0
- quadra/datamodules/generic/mvtec.py +58 -0
- quadra/datamodules/generic/oxford_pet.py +163 -0
- quadra/datamodules/patch.py +190 -0
- quadra/datamodules/segmentation.py +742 -0
- quadra/datamodules/ssl.py +140 -0
- quadra/datasets/__init__.py +17 -0
- quadra/datasets/anomaly.py +287 -0
- quadra/datasets/classification.py +241 -0
- quadra/datasets/patch.py +138 -0
- quadra/datasets/segmentation.py +239 -0
- quadra/datasets/ssl.py +110 -0
- quadra/losses/__init__.py +0 -0
- quadra/losses/classification/__init__.py +6 -0
- quadra/losses/classification/asl.py +83 -0
- quadra/losses/classification/focal.py +320 -0
- quadra/losses/classification/prototypical.py +148 -0
- quadra/losses/ssl/__init__.py +17 -0
- quadra/losses/ssl/barlowtwins.py +47 -0
- quadra/losses/ssl/byol.py +37 -0
- quadra/losses/ssl/dino.py +129 -0
- quadra/losses/ssl/hyperspherical.py +45 -0
- quadra/losses/ssl/idmm.py +50 -0
- quadra/losses/ssl/simclr.py +67 -0
- quadra/losses/ssl/simsiam.py +30 -0
- quadra/losses/ssl/vicreg.py +76 -0
- quadra/main.py +49 -0
- quadra/metrics/__init__.py +3 -0
- quadra/metrics/segmentation.py +251 -0
- quadra/models/__init__.py +0 -0
- quadra/models/base.py +151 -0
- quadra/models/classification/__init__.py +8 -0
- quadra/models/classification/backbones.py +149 -0
- quadra/models/classification/base.py +92 -0
- quadra/models/evaluation.py +322 -0
- quadra/modules/__init__.py +0 -0
- quadra/modules/backbone.py +30 -0
- quadra/modules/base.py +312 -0
- quadra/modules/classification/__init__.py +3 -0
- quadra/modules/classification/base.py +327 -0
- quadra/modules/ssl/__init__.py +17 -0
- quadra/modules/ssl/barlowtwins.py +59 -0
- quadra/modules/ssl/byol.py +172 -0
- quadra/modules/ssl/common.py +285 -0
- quadra/modules/ssl/dino.py +186 -0
- quadra/modules/ssl/hyperspherical.py +206 -0
- quadra/modules/ssl/idmm.py +98 -0
- quadra/modules/ssl/simclr.py +73 -0
- quadra/modules/ssl/simsiam.py +68 -0
- quadra/modules/ssl/vicreg.py +67 -0
- quadra/optimizers/__init__.py +4 -0
- quadra/optimizers/lars.py +153 -0
- quadra/optimizers/sam.py +127 -0
- quadra/schedulers/__init__.py +3 -0
- quadra/schedulers/base.py +44 -0
- quadra/schedulers/warmup.py +127 -0
- quadra/tasks/__init__.py +24 -0
- quadra/tasks/anomaly.py +582 -0
- quadra/tasks/base.py +397 -0
- quadra/tasks/classification.py +1263 -0
- quadra/tasks/patch.py +492 -0
- quadra/tasks/segmentation.py +389 -0
- quadra/tasks/ssl.py +560 -0
- quadra/trainers/README.md +3 -0
- quadra/trainers/__init__.py +0 -0
- quadra/trainers/classification.py +179 -0
- quadra/utils/__init__.py +0 -0
- quadra/utils/anomaly.py +112 -0
- quadra/utils/classification.py +618 -0
- quadra/utils/deprecation.py +31 -0
- quadra/utils/evaluation.py +474 -0
- quadra/utils/export.py +585 -0
- quadra/utils/imaging.py +32 -0
- quadra/utils/logger.py +15 -0
- quadra/utils/mlflow.py +98 -0
- quadra/utils/model_manager.py +320 -0
- quadra/utils/models.py +523 -0
- quadra/utils/patch/__init__.py +15 -0
- quadra/utils/patch/dataset.py +1433 -0
- quadra/utils/patch/metrics.py +449 -0
- quadra/utils/patch/model.py +153 -0
- quadra/utils/patch/visualization.py +217 -0
- quadra/utils/resolver.py +42 -0
- quadra/utils/segmentation.py +31 -0
- quadra/utils/tests/__init__.py +0 -0
- quadra/utils/tests/fixtures/__init__.py +1 -0
- quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
- quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
- quadra/utils/tests/fixtures/dataset/classification.py +406 -0
- quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
- quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
- quadra/utils/tests/fixtures/models/__init__.py +3 -0
- quadra/utils/tests/fixtures/models/anomaly.py +89 -0
- quadra/utils/tests/fixtures/models/classification.py +45 -0
- quadra/utils/tests/fixtures/models/segmentation.py +33 -0
- quadra/utils/tests/helpers.py +70 -0
- quadra/utils/tests/models.py +27 -0
- quadra/utils/utils.py +525 -0
- quadra/utils/validator.py +115 -0
- quadra/utils/visualization.py +422 -0
- quadra/utils/vit_explainability.py +349 -0
- quadra-2.2.7.dist-info/LICENSE +201 -0
- quadra-2.2.7.dist-info/METADATA +381 -0
- quadra-2.2.7.dist-info/RECORD +300 -0
- {quadra-0.0.1.dist-info → quadra-2.2.7.dist-info}/WHEEL +1 -1
- quadra-2.2.7.dist-info/entry_points.txt +3 -0
- quadra-0.0.1.dist-info/METADATA +0 -14
- quadra-0.0.1.dist-info/RECORD +0 -4
|
@@ -0,0 +1,140 @@
|
|
|
1
|
+
# pylint: disable=unsubscriptable-object
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import torch
|
|
8
|
+
from sklearn.model_selection import train_test_split
|
|
9
|
+
from torch.utils.data import DataLoader
|
|
10
|
+
|
|
11
|
+
from quadra.datamodules.classification import ClassificationDataModule
|
|
12
|
+
from quadra.datasets import TwoAugmentationDataset, TwoSetAugmentationDataset
|
|
13
|
+
from quadra.utils import utils
|
|
14
|
+
|
|
15
|
+
# Module-level logger; SSLDataModule uses it to emit warnings about how the datasets are built.
log = utils.get_logger(__name__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class SSLDataModule(ClassificationDataModule):
    """Base class for all data modules for self supervised learning data modules.

    Args:
        data_path: Path to the data main folder.
        augmentation_dataset: Augmentation dataset
            for training dataset.
        name: The name for the data module. Defaults to "ssl_datamodule".
        split_validation: Whether to split the validation set when the training set contains only one class
            (and therefore cannot be used to train a classifier). When enabled, 70% of the validation set is
            used to train the classifier and the remaining 30% is kept as validation. When disabled, the
            training set is always used to train the classifier. Defaults to True.
        **kwargs: The keyword arguments for the classification data module. Defaults to None.
    """

    def __init__(
        self,
        data_path: str,
        augmentation_dataset: TwoAugmentationDataset | TwoSetAugmentationDataset,
        name: str = "ssl_datamodule",
        split_validation: bool = True,
        **kwargs: Any,
    ):
        super().__init__(
            data_path=data_path,
            name=name,
            **kwargs,
        )
        self.augmentation_dataset = augmentation_dataset
        self.classifier_train_dataset: torch.utils.data.Dataset | None = None
        self.split_validation = split_validation

    def setup(self, stage: str | None = None) -> None:
        """Setup data module based on stages of training.

        For ``stage == "fit"`` this builds ``train_dataset`` (with ``train_transform``),
        ``classifier_train_dataset`` and ``val_dataset``. For ``stage == "test"`` it builds ``test_dataset``
        (with ``test_transform``).

        The classifier is trained on the training set, unless the training set contains a single class and
        ``split_validation`` is True. In that case, the validation set is split (70% classifier training,
        30% validation) with a stratified split seeded by ``self.seed``.
        """
        if stage == "fit":
            self.train_dataset = self.dataset(
                samples=self.train_data["samples"].tolist(),
                targets=self.train_data["targets"].tolist(),
                transform=self.train_transform,
            )

            train_has_multiple_classes = np.unique(self.train_data["targets"]).shape[0] > 1

            if train_has_multiple_classes or not self.split_validation:
                if not train_has_multiple_classes:
                    # The user explicitly disabled the split: honour it, but make the consequence visible.
                    log.warning(
                        "The training set contains only one class but `split_validation` is False: the classifier "
                        "will be trained on the training set anyway."
                    )
                self.classifier_train_dataset = self.dataset(
                    samples=self.train_data["samples"].tolist(),
                    targets=self.train_data["targets"].tolist(),
                    transform=self.val_transform,
                )
                self.val_dataset = self.dataset(
                    samples=self.val_data["samples"].tolist(),
                    targets=self.val_data["targets"].tolist(),
                    transform=self.val_transform,
                )
            else:
                train_classifier_samples, val_samples, train_classifier_targets, val_targets = train_test_split(
                    self.val_data["samples"],
                    self.val_data["targets"],
                    test_size=0.3,
                    random_state=self.seed,
                    stratify=self.val_data["targets"],
                )

                # train_test_split returns containers of the input type (e.g. Series with shuffled labels);
                # convert to plain lists like the other branch so positional indexing in the dataset is safe.
                # NOTE(review): the other branch uses val_transform for the classifier set — confirm
                # test_transform is intended here.
                self.classifier_train_dataset = self.dataset(
                    samples=np.asarray(train_classifier_samples).tolist(),
                    targets=np.asarray(train_classifier_targets).tolist(),
                    transform=self.test_transform,
                )

                self.val_dataset = self.dataset(
                    samples=np.asarray(val_samples).tolist(),
                    targets=np.asarray(val_targets).tolist(),
                    transform=self.val_transform,
                )

                log.warning(
                    "The training set contains only one class and cannot be used to train a classifier. To overcome "
                    "this issue 70% of the validation set is used to train the classifier. The remaining will be used "
                    "as standard validation. To disable this behaviour set the `split_validation` parameter to False."
                )
            self._check_train_dataset_config()
        if stage == "test":
            self.test_dataset = self.dataset(
                samples=self.test_data["samples"].tolist(),
                targets=self.test_data["targets"].tolist(),
                transform=self.test_transform,
            )

    def _check_train_dataset_config(self):
        """Check if train dataset is configured correctly.

        Raises:
            ValueError: If the train dataset or the augmentation dataset is not initialized.
        """
        if self.train_dataset is None:
            raise ValueError("Train dataset is not initialized")
        if self.augmentation_dataset is None:
            raise ValueError("Augmentation dataset is not initialized")
        if self.train_dataset.transform is not None:
            log.warning("Train dataset transform is not None. It will be applied before SSL augmentations")

    def train_dataloader(self) -> DataLoader:
        """Returns train dataloader.

        The train dataset is attached to the augmentation dataset, which is what the loader iterates over.

        Raises:
            ValueError: If the train dataset is not a ``torch.utils.data.Dataset``.
        """
        if not isinstance(self.train_dataset, torch.utils.data.Dataset):
            raise ValueError("Train dataset is not a subclass of `torch.utils.data.Dataset`")
        self.augmentation_dataset.dataset = self.train_dataset
        loader = DataLoader(
            self.augmentation_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            drop_last=False,
            pin_memory=True,
            persistent_workers=self.num_workers > 0,
        )
        return loader

    def classifier_train_dataloader(self) -> DataLoader:
        """Returns classifier train dataloader.

        Raises:
            ValueError: If ``setup("fit")`` has not built the classifier train dataset yet.
        """
        if self.classifier_train_dataset is None:
            raise ValueError("Classifier train dataset is not initialized")

        loader = DataLoader(
            self.classifier_train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            drop_last=False,
            pin_memory=True,
            persistent_workers=self.num_workers > 0,
        )
        return loader
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
from .anomaly import AnomalyDataset
|
|
2
|
+
from .classification import ClassificationDataset, ImageClassificationListDataset, MultilabelClassificationDataset
|
|
3
|
+
from .patch import PatchSklearnClassificationTrainDataset
|
|
4
|
+
from .segmentation import SegmentationDataset, SegmentationDatasetMulticlass
|
|
5
|
+
from .ssl import TwoAugmentationDataset, TwoSetAugmentationDataset
|
|
6
|
+
|
|
7
|
+
# Public API of this package: the dataset classes re-exported from the submodule imports above.
__all__ = [
    "ImageClassificationListDataset",
    "ClassificationDataset",
    "SegmentationDataset",
    "SegmentationDatasetMulticlass",
    "PatchSklearnClassificationTrainDataset",
    "MultilabelClassificationDataset",
    "AnomalyDataset",
    "TwoAugmentationDataset",
    "TwoSetAugmentationDataset",
]
|
|
@@ -0,0 +1,287 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import random
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
import albumentations as alb
|
|
8
|
+
import cv2
|
|
9
|
+
import numpy as np
|
|
10
|
+
import pandas as pd
|
|
11
|
+
from pandas import DataFrame
|
|
12
|
+
from torch import Tensor
|
|
13
|
+
from torch.utils.data import Dataset
|
|
14
|
+
|
|
15
|
+
from quadra.utils.utils import IMAGE_EXTENSIONS
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def create_validation_set_from_test_set(samples: DataFrame, seed: int = 0) -> DataFrame:
    """Create Validation Set from Test Set.

    This function creates a validation set from test set by moving half (rounded down) of the normal
    (``targets == "good"``) test samples and half of the abnormal test samples to the ``"val"`` split.

    Args:
        samples: Dataframe containing dataset info such as filenames, splits etc. It must have ``split`` and
            ``targets`` columns. It is modified in place.
        seed: Random seed to ensure reproducibility. Only values greater than 0 seed the sampling; with the
            default of 0 the module-level ``random`` generator is used as-is. Defaults to 0.

    Returns:
        The same ``samples`` dataframe with the selected test rows re-assigned to the ``"val"`` split.
    """
    # A private generator avoids reseeding the process-wide ``random`` state as a side effect;
    # ``random.Random(seed)`` produces the same sequence that ``random.seed(seed)`` would.
    sample = random.Random(seed).sample if seed > 0 else random.sample

    # Split normal images.
    normal_test_image_indices = samples.index[(samples.split == "test") & (samples.targets == "good")].to_list()
    num_normal_valid_images = len(normal_test_image_indices) // 2

    indices_to_sample = sample(population=normal_test_image_indices, k=num_normal_valid_images)
    samples.loc[indices_to_sample, "split"] = "val"

    # Split abnormal images.
    abnormal_test_image_indices = samples.index[(samples.split == "test") & (samples.targets != "good")].to_list()
    num_abnormal_valid_images = len(abnormal_test_image_indices) // 2

    indices_to_sample = sample(population=abnormal_test_image_indices, k=num_abnormal_valid_images)
    samples.loc[indices_to_sample, "split"] = "val"

    return samples
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def split_normal_images_in_train_set(samples: DataFrame, split_ratio: float = 0.1, seed: int = 0) -> DataFrame:
    """Split normal images in train set.

    This function moves a fraction of the normal (``targets == "good"``) training images to the test split.
    This is particularly useful when the test set does not contain any normal images.

    This is important because when the test set doesn't have any normal images,
    AUC computation fails due to having single class.

    Args:
        samples: Dataframe containing dataset info such as filenames, splits etc. It must have ``split`` and
            ``targets`` columns. It is modified in place.
        split_ratio: Fraction of normal training images moved to the test split; the count is
            ``int(num_normal_train_images * split_ratio)``. Defaults to 0.1.
        seed: Random seed to ensure reproducibility. Only values greater than 0 seed the sampling; with the
            default of 0 the module-level ``random`` generator is used as-is. Defaults to 0.

    Returns:
        Output dataframe where the part of the training set is assigned to test set.
    """
    # A private generator avoids reseeding the process-wide ``random`` state as a side effect;
    # ``random.Random(seed)`` produces the same sequence that ``random.seed(seed)`` would.
    sample = random.Random(seed).sample if seed > 0 else random.sample

    normal_train_image_indices = samples.index[(samples.split == "train") & (samples.targets == "good")].to_list()
    num_normal_valid_images = int(len(normal_train_image_indices) * split_ratio)

    indices_to_split_from_train_set = sample(population=normal_train_image_indices, k=num_normal_valid_images)
    samples.loc[indices_to_split_from_train_set, "split"] = "test"

    return samples
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def make_anomaly_dataset(
    path: Path,
    split: str | None = None,
    split_ratio: float = 0.1,
    seed: int = 0,
    mask_suffix: str | None = None,
    create_test_set_if_empty: bool = True,
) -> DataFrame:
    """Create dataframe by parsing a folder following the MVTec data file structure.

    The files are expected to follow the structure:
        path/to/dataset/split/label/image_filename.xyz
        path/to/dataset/ground_truth/label/mask_filename.png

    Masks MUST be png images, no other format is allowed
    Split can be either train/val/test

    This function creates a dataframe to store the parsed information based on the following format:
    |---|---------------|-------|---------|--------------|-----------------------------------------------|-------------|
    |   | path          | split | targets | samples      | mask_path                                     | label_index |
    |---|---------------|-------|---------|--------------|-----------------------------------------------|-------------|
    | 0 | datasets/name | test  | defect  | filename.xyz | ground_truth/defect/filename{mask_suffix}.png | 1           |
    |---|---------------|-------|---------|--------------|-----------------------------------------------|-------------|

    ``label_index`` is 0 for "good" images, -1 for "unknown" images and 1 for every other label. Normal images in
    the test split get an empty ``mask_path``.

    Args:
        path: Path to dataset. A string path is also accepted.
        split: Dataset split (i.e., either train, val or test). Any other value (including None) returns all
            splits. Defaults to None.
        split_ratio: Ratio to split normal training images and add to the
            test set in case test set doesn't contain any normal images.
            Defaults to 0.1.
        seed: Random seed to ensure reproducibility when splitting. Defaults to 0.
        mask_suffix: String to append to the base filename to get the mask name, by default for MVTec dataset masks
            are saved as imagename_mask.png in this case the parameter should be filled with "_mask"
        create_test_set_if_empty: If True, move part of the normal training images to the test split when the
            test split contains no normal images.

    Example:
        The following example shows how to get training samples from MVTec bottle category:

        >>> root = Path('./MVTec')
        >>> category = 'bottle'
        >>> path = root / category
        >>> path
        PosixPath('MVTec/bottle')

        >>> samples = make_anomaly_dataset(path, split='train', split_ratio=0.1, seed=0, mask_suffix='_mask')
        >>> samples.head()
                   path  split targets                          samples                                    mask_path  label_index
        0  MVTec/bottle  train    good  MVTec/bottle/train/good/000.png  MVTec/bottle/ground_truth/good/000_mask.png            0
        1  MVTec/bottle  train    good  MVTec/bottle/train/good/001.png  MVTec/bottle/ground_truth/good/001_mask.png            0
        2  MVTec/bottle  train    good  MVTec/bottle/train/good/002.png  MVTec/bottle/ground_truth/good/002_mask.png            0
        3  MVTec/bottle  train    good  MVTec/bottle/train/good/003.png  MVTec/bottle/ground_truth/good/003_mask.png            0
        4  MVTec/bottle  train    good  MVTec/bottle/train/good/004.png  MVTec/bottle/ground_truth/good/004_mask.png            0

    Returns:
        An output dataframe containing samples for the requested split (ie., train, val or test)

    Raises:
        RuntimeError: If no image is found under ``path``.
    """
    path = Path(path)

    # Each entry is (dataset_root, split, label, filename), taken from the last three path components.
    samples_list = [
        (str(path),) + filename.parts[-3:]
        for filename in path.glob("**/*")
        if filename.is_file()
        and os.path.splitext(filename)[-1].lower() in IMAGE_EXTENSIONS
        and ".ipynb_checkpoints" not in str(filename)
    ]

    if len(samples_list) == 0:
        raise RuntimeError(f"Found 0 images in {path}")

    samples_list.sort()

    data = pd.DataFrame(samples_list, columns=["path", "split", "targets", "samples"])
    # Drop the mask files and take an explicit copy so that the column assignments below act on an
    # independent frame rather than a filtered view (avoids pandas' SettingWithCopyWarning).
    data = data[data.split != "ground_truth"].copy()

    # Create mask_path column, masks MUST have png extension
    data["mask_path"] = (
        data.path
        + "/ground_truth/"
        + data.targets
        + "/"
        + data.samples.apply(lambda x: os.path.splitext(os.path.basename(x))[0])
        + (f"{mask_suffix}.png" if mask_suffix is not None else ".png")
    )

    # Modify image_path column by converting to absolute path
    data["samples"] = data.path + "/" + data.split + "/" + data.targets + "/" + data.samples

    # Split the normal images in training set if test set doesn't
    # contain any normal images. This is needed because AUC score
    # cannot be computed based on 1-class
    if ((data.split == "test") & (data.targets == "good")).sum() == 0 and create_test_set_if_empty:
        data = split_normal_images_in_train_set(data, split_ratio, seed)

    # Good images don't have mask
    data.loc[(data.split == "test") & (data.targets == "good"), "mask_path"] = ""

    # Create label index for normal (0), anomalous (1) and unknown (-1) images.
    data.loc[data.targets == "good", "label_index"] = 0
    data.loc[~data.targets.isin(["good", "unknown"]), "label_index"] = 1
    data.loc[data.targets == "unknown", "label_index"] = -1
    data.label_index = data.label_index.astype(int)

    # Get the data frame for the split.
    if split in ["train", "val", "test"]:
        data = data[data.split == split]
        data = data.reset_index(drop=True)

    return data
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
class AnomalyDataset(Dataset):
    """Anomaly Dataset.

    Args:
        transform: Albumentations compose.
        task: ``classification`` or ``segmentation``
        samples: Pandas dataframe containing samples following the same structure created by make_anomaly_dataset
        valid_area_mask: Optional path to the mask to use to filter out the valid area of the image. If None, the
            whole image is considered valid.
        crop_area: Optional tuple of 4 integers (x1, y1, x2, y2) to crop the image to the specified area. If None, the
            whole image is used.

    Raises:
        RuntimeError: If ``valid_area_mask`` is given but the file does not exist or cannot be read.
    """

    def __init__(
        self,
        transform: alb.Compose,
        samples: DataFrame,
        task: str = "segmentation",
        valid_area_mask: str | None = None,
        crop_area: tuple[int, int, int, int] | None = None,
    ) -> None:
        self.task = task
        self.transform = transform

        # Re-index to 0..n-1 so that positional lookups in __getitem__ are well defined
        self.samples = samples.reset_index(drop=True)
        # The split of the first unique value decides how __getitem__ builds items
        self.split = self.samples.split.unique()[0]

        self.crop_area = crop_area
        self.valid_area_mask: np.ndarray | None = None

        if valid_area_mask is not None:
            if not os.path.exists(valid_area_mask):
                raise RuntimeError(f"Valid area mask {valid_area_mask} does not exist.")

            # cv2.imread signals failure by returning None rather than raising
            valid_area_image = cv2.imread(valid_area_mask, 0)
            if valid_area_image is None:
                raise RuntimeError(f"Valid area mask {valid_area_mask} could not be read.")
            # Boolean mask: True where the pixel belongs to the valid area
            self.valid_area_mask = valid_area_image > 0

    def _apply_crop(self, array: np.ndarray) -> np.ndarray:
        """Crop ``array`` (rows x columns first) to ``self.crop_area`` = (x1, y1, x2, y2), if it is set."""
        if self.crop_area is None:
            return array
        x1, y1, x2, y2 = self.crop_area
        return array[y1:y2, x1:x2]

    def __len__(self) -> int:
        """Get length of the dataset."""
        return len(self.samples)

    def __getitem__(self, index: int) -> dict[str, str | Tensor]:
        """Get dataset item for the index ``index``.

        Args:
            index: Index to get the item.

        Returns:
            For the ``train`` split, a dict with the transformed ``image`` and its ``label``.
            For ``val``/``test``, a dict with ``image_path``, ``label`` and the transformed ``image``; when the task
            is ``segmentation`` it also contains ``mask_path`` and the transformed ``mask``.
            For any other split an empty dict is returned.

        Raises:
            RuntimeError: If the image, or an existing mask file, cannot be read.
        """
        item: dict[str, str | Tensor] = {}

        image_path = self.samples.samples.iloc[index]
        image = cv2.imread(image_path)
        if image is None:
            # Fail with the offending path instead of an opaque cvtColor assertion
            raise RuntimeError(f"Image {image_path} could not be read.")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Masks built below are sized on the original image, before any crop
        original_image_shape = image.shape
        if self.valid_area_mask is not None:
            # Zero out every pixel outside the valid area
            image = image * self.valid_area_mask[:, :, np.newaxis]

        image = self._apply_crop(image)

        label_index = self.samples.label_index.iloc[index]

        if self.split == "train":
            pre_processed = self.transform(image=image)
            item = {"image": pre_processed["image"], "label": label_index}
        elif self.split in ["val", "test"]:
            item["image_path"] = image_path
            item["label"] = label_index

            if self.task == "segmentation":
                mask_path = self.samples.mask_path.iloc[index]

                # If good images have no associated mask create an empty one
                if label_index == 0:
                    mask = np.zeros(shape=original_image_shape[:2])
                elif os.path.isfile(mask_path):
                    mask_image = cv2.imread(mask_path, flags=0)
                    if mask_image is None:
                        raise RuntimeError(f"Mask {mask_path} could not be read.")
                    mask = mask_image / 255.0
                else:
                    # We need ones in the mask to compute correctly at least image level f1 score
                    mask = np.ones(shape=original_image_shape[:2])

                if self.valid_area_mask is not None:
                    mask = mask * self.valid_area_mask

                mask = self._apply_crop(mask)

                pre_processed = self.transform(image=image, mask=mask)

                item["mask_path"] = mask_path
                item["mask"] = pre_processed["mask"]
            else:
                pre_processed = self.transform(image=image)

            item["image"] = pre_processed["image"]
        return item
|
|
@@ -0,0 +1,241 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import warnings
|
|
4
|
+
from collections.abc import Callable
|
|
5
|
+
|
|
6
|
+
import cv2
|
|
7
|
+
import numpy as np
|
|
8
|
+
import torch
|
|
9
|
+
from torch.utils.data import Dataset
|
|
10
|
+
|
|
11
|
+
from quadra.utils.imaging import crop_image, keep_aspect_ratio_resize
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class ImageClassificationListDataset(Dataset):
    """Standard classification dataset.

    Args:
        samples: List of paths to images to be read
        targets: List of labels, one for every image in samples. ``None`` marks a missing label; missing labels
            (as well as labels equal to ``-1``/``"-1"``) get class index -1.
        class_to_idx: mapping from classes to unique indexes. If None, it is built from the sorted unique labels,
            ignoring missing ones. Defaults to None.
        resize: Integer specifying the size of a first optional resize keeping the aspect ratio: the smaller side
            of the image will be resized to `resize`, while the longer will be resized keeping the aspect ratio.
            Defaults to None.
        roi: Optional ROI, with (x_upper_left, y_upper_left, x_bottom_right, y_bottom_right).
            Defaults to None.
        transform: Optional Albumentations transform.
            Defaults to None.
        rgb: if False, image will be converted in grayscale
        channel: 1 or 3. If rgb is True, then channel will be set at 3.
        allow_missing_label: If set to False, warn the user if the dataset contains missing labels
    """

    def __init__(
        self,
        samples: list[str],
        targets: list[str | int],
        class_to_idx: dict | None = None,
        resize: int | None = None,
        roi: tuple[int, int, int, int] | None = None,
        transform: Callable | None = None,
        rgb: bool = True,
        channel: int = 3,
        allow_missing_label: bool | None = False,
    ):
        super().__init__()
        assert len(samples) == len(
            targets
        ), f"Samples ({len(samples)}) and targets ({len(targets)}) must have the same length"
        # Setting the ROI
        self.roi = roi

        # Keep-Aspect-Ratio resize
        self.resize = resize

        if not allow_missing_label and None in targets:
            warnings.warn(
                (
                    "Dataset contains empty targets but allow_missing_label is set to False, "
                    "be careful because None labels will not work inside Dataloaders"
                ),
                UserWarning,
                stacklevel=2,
            )

        # -1 is the placeholder for a missing label
        targets = [-1 if target is None else target for target in targets]
        # Data
        self.x = np.array(samples)
        # With string labels numpy turns the -1 placeholder into "-1", hence the two-way check below
        self.y = np.array(targets)

        if class_to_idx is None:
            # Missing labels are not a class: if the placeholder were included it would get an index of its own
            # and shift the indexes of all the real classes.
            labelled_targets = [target for target in self.y if not self._is_missing_label(target)]
            class_to_idx = {c: i for i, c in enumerate(np.unique(labelled_targets))}

        self.class_to_idx = class_to_idx
        self.idx_to_class = {v: k for k, v in class_to_idx.items()}
        self.samples = [
            (path, -1 if self._is_missing_label(label) else self.class_to_idx[label])
            for path, label in zip(self.x, self.y)
        ]

        self.rgb = rgb
        self.channel = 3 if rgb else channel

        self.transform = transform

    @staticmethod
    def _is_missing_label(label) -> bool:
        """Return True if ``label`` is the missing-label placeholder (-1, or "-1" once numpy stringified it)."""
        return bool(label == -1 or label == "-1")

    def __getitem__(self, idx) -> tuple[np.ndarray, np.ndarray]:
        """Load and preprocess the sample at ``idx``.

        Processing order: read with OpenCV, colour conversion, optional single-channel selection, ROI crop,
        keep-aspect-ratio resize, then ``transform``.

        Returns:
            The processed image and its class index (-1 for a missing label).

        Raises:
            RuntimeError: If the image cannot be read.
        """
        path, y = self.samples[idx]

        # Load image
        x = cv2.imread(str(path))
        if x is None:
            # cv2.imread returns None instead of raising on unreadable files
            raise RuntimeError(f"Image {path} could not be read.")
        if self.rgb:
            x = cv2.cvtColor(x, cv2.COLOR_BGR2RGB)
        else:
            # Grayscale, replicated back to 3 channels
            x = cv2.cvtColor(x, cv2.COLOR_BGR2GRAY)
            x = cv2.cvtColor(x, cv2.COLOR_GRAY2RGB)

        if self.channel == 1:
            x = x[:, :, 0]

        # Crop with ROI
        if self.roi:
            x = crop_image(x, self.roi)

        # Resize keeping aspect ratio
        if self.resize:
            x = keep_aspect_ratio_resize(x, self.resize)

        if self.transform:
            aug = self.transform(image=x)
            x = aug["image"]

        return x, y

    def __len__(self):
        """Return the number of samples."""
        return len(self.samples)
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
class ClassificationDataset(ImageClassificationListDataset):
    """Custom Classification Dataset.

    Args:
        samples: List of paths to images
        targets: List of targets
        class_to_idx: Defaults to None.
        resize: Resize image to this size. Defaults to None.
        roi: Region of interest. Defaults to None.
        transform: transform function. Defaults to None.
        rgb: Use RGB space
        channel: Number of channels. Defaults to 3.
        random_padding: Random padding. Defaults to False.
        circular_crop: Circular crop. Defaults to False.
        allow_missing_label: If set to False, warn the user if the dataset contains missing labels.
            Defaults to False.

    NOTE(review): ``resize`` and ``roi`` are forwarded to the parent and stored, but this class's ``__getitem__``
    does not apply them. ``random_padding`` and ``circular_crop`` are only stored as attributes here.
    """

    def __init__(
        self,
        samples: list[str],
        targets: list[str | int],
        class_to_idx: dict | None = None,
        resize: int | None = None,
        roi: tuple[int, int, int, int] | None = None,
        transform: Callable | None = None,
        rgb: bool = True,
        channel: int = 3,
        random_padding: bool = False,
        circular_crop: bool = False,
        allow_missing_label: bool | None = False,
    ):
        # The parent stores ``transform`` as-is, so no extra handling of a None transform is needed here
        super().__init__(
            samples,
            targets,
            class_to_idx,
            resize,
            roi,
            transform,
            rgb,
            channel,
            allow_missing_label=allow_missing_label,
        )

        self.random_padding = random_padding
        self.circular_crop = circular_crop

    def __getitem__(self, idx):
        """Load the sample at ``idx``, convert its colour space and apply ``transform``.

        Returns:
            The processed image and its class index (-1 for a missing label).

        Raises:
            RuntimeError: If the image cannot be read.
        """
        path, y = self.samples[idx]
        path = str(path)

        # Load image
        x = cv2.imread(path)
        if x is None:
            # cv2.imread returns None instead of raising on unreadable files
            raise RuntimeError(f"Image {path} could not be read.")
        if self.rgb:
            x = cv2.cvtColor(x, cv2.COLOR_BGR2RGB)
        else:
            # Grayscale, replicated back to 3 channels
            x = cv2.cvtColor(x, cv2.COLOR_BGR2GRAY)
            x = cv2.cvtColor(x, cv2.COLOR_GRAY2RGB)

        if self.transform is not None:
            aug = self.transform(image=x)
            x = aug["image"]

        if self.channel == 1:
            # Keep only the first channel. NOTE(review): this assumes channel-first output from ``transform``
            # (e.g. a ToTensorV2 tensor); on an HWC array from cv2 it would keep the first row instead — TODO confirm.
            x = x[:1]

        return x, y
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
class MultilabelClassificationDataset(torch.utils.data.Dataset):
    """Custom MultilabelClassification Dataset.

    Args:
        samples: list of paths to images.
        targets: array of multiple targets per sample. The array must be a one-hot encoding.
            It must have a shape of (n_samples, n_targets). Any array-like accepted by ``np.asarray``
            (e.g. a list of lists) is converted to a numpy array.
        class_to_idx: Mapping from classes to target columns. If None, column ``i`` is mapped to class ``i``.
            Defaults to None.
        transform: transform function. Defaults to None.
        rgb: Use RGB space; if False the image is converted to grayscale and replicated over 3 channels.
    """

    def __init__(
        self,
        samples: list[str],
        targets: np.ndarray,
        class_to_idx: dict | None = None,
        transform: Callable | None = None,
        rgb: bool = True,
    ):
        super().__init__()
        # Accept any array-like; an existing ndarray is kept as the very same object
        targets = np.asarray(targets)
        assert len(samples) == len(
            targets
        ), f"Samples ({len(samples)}) and targets ({len(targets)}) must have the same length"

        # Data
        self.x = samples
        self.y = targets

        # Class to idx and the other way around
        if class_to_idx is None:
            num_classes = targets.shape[1]
            class_to_idx = {i: i for i in range(num_classes)}
        self.class_to_idx = class_to_idx
        self.idx_to_class = {v: k for k, v in class_to_idx.items()}
        self.samples = list(zip(self.x, self.y))
        self.rgb = rgb
        self.transform = transform

    def __getitem__(self, idx):
        """Load the sample at ``idx`` and return the processed image with its targets as a float tensor.

        Raises:
            RuntimeError: If the image cannot be read.
        """
        path, y = self.samples[idx]
        path = str(path)

        # Load image
        x = cv2.imread(path)
        if x is None:
            # cv2.imread returns None instead of raising on unreadable files
            raise RuntimeError(f"Image {path} could not be read.")
        if self.rgb:
            x = cv2.cvtColor(x, cv2.COLOR_BGR2RGB)
        else:
            x = cv2.cvtColor(x, cv2.COLOR_BGR2GRAY)
            x = cv2.cvtColor(x, cv2.COLOR_GRAY2RGB)

        if self.transform is not None:
            aug = self.transform(image=x)
            x = aug["image"]

        return x, torch.from_numpy(y).float()

    def __len__(self):
        """Return the number of samples."""
        return len(self.samples)
|