quadra 0.0.1__py3-none-any.whl → 2.1.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hydra_plugins/quadra_searchpath_plugin.py +30 -0
- quadra/__init__.py +6 -0
- quadra/callbacks/__init__.py +0 -0
- quadra/callbacks/anomalib.py +289 -0
- quadra/callbacks/lightning.py +501 -0
- quadra/callbacks/mlflow.py +291 -0
- quadra/callbacks/scheduler.py +69 -0
- quadra/configs/__init__.py +0 -0
- quadra/configs/backbone/caformer_m36.yaml +8 -0
- quadra/configs/backbone/caformer_s36.yaml +8 -0
- quadra/configs/backbone/convnextv2_base.yaml +8 -0
- quadra/configs/backbone/convnextv2_femto.yaml +8 -0
- quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
- quadra/configs/backbone/dino_vitb8.yaml +12 -0
- quadra/configs/backbone/dino_vits8.yaml +12 -0
- quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
- quadra/configs/backbone/dinov2_vits14.yaml +12 -0
- quadra/configs/backbone/efficientnet_b0.yaml +8 -0
- quadra/configs/backbone/efficientnet_b1.yaml +8 -0
- quadra/configs/backbone/efficientnet_b2.yaml +8 -0
- quadra/configs/backbone/efficientnet_b3.yaml +8 -0
- quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
- quadra/configs/backbone/levit_128s.yaml +8 -0
- quadra/configs/backbone/mnasnet0_5.yaml +9 -0
- quadra/configs/backbone/resnet101.yaml +8 -0
- quadra/configs/backbone/resnet18.yaml +8 -0
- quadra/configs/backbone/resnet18_ssl.yaml +8 -0
- quadra/configs/backbone/resnet50.yaml +8 -0
- quadra/configs/backbone/smp.yaml +9 -0
- quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
- quadra/configs/backbone/unetr.yaml +15 -0
- quadra/configs/backbone/vit16_base.yaml +9 -0
- quadra/configs/backbone/vit16_small.yaml +9 -0
- quadra/configs/backbone/vit16_tiny.yaml +9 -0
- quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
- quadra/configs/callbacks/all.yaml +32 -0
- quadra/configs/callbacks/default.yaml +37 -0
- quadra/configs/callbacks/default_anomalib.yaml +67 -0
- quadra/configs/config.yaml +33 -0
- quadra/configs/core/default.yaml +11 -0
- quadra/configs/datamodule/base/anomaly.yaml +16 -0
- quadra/configs/datamodule/base/classification.yaml +21 -0
- quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
- quadra/configs/datamodule/base/segmentation.yaml +18 -0
- quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
- quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
- quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
- quadra/configs/datamodule/base/ssl.yaml +21 -0
- quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
- quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
- quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
- quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
- quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
- quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
- quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
- quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
- quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
- quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
- quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/base/classification/classification.yaml +73 -0
- quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
- quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
- quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
- quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
- quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
- quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
- quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
- quadra/configs/experiment/base/ssl/byol.yaml +43 -0
- quadra/configs/experiment/base/ssl/dino.yaml +46 -0
- quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
- quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
- quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
- quadra/configs/experiment/custom/cls.yaml +12 -0
- quadra/configs/experiment/default.yaml +15 -0
- quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
- quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
- quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
- quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
- quadra/configs/export/default.yaml +13 -0
- quadra/configs/hydra/anomaly_custom.yaml +15 -0
- quadra/configs/hydra/default.yaml +14 -0
- quadra/configs/inference/default.yaml +26 -0
- quadra/configs/logger/comet.yaml +10 -0
- quadra/configs/logger/csv.yaml +5 -0
- quadra/configs/logger/mlflow.yaml +12 -0
- quadra/configs/logger/tensorboard.yaml +8 -0
- quadra/configs/loss/asl.yaml +7 -0
- quadra/configs/loss/barlow.yaml +2 -0
- quadra/configs/loss/bce.yaml +1 -0
- quadra/configs/loss/byol.yaml +1 -0
- quadra/configs/loss/cross_entropy.yaml +1 -0
- quadra/configs/loss/dino.yaml +8 -0
- quadra/configs/loss/simclr.yaml +2 -0
- quadra/configs/loss/simsiam.yaml +1 -0
- quadra/configs/loss/smp_ce.yaml +3 -0
- quadra/configs/loss/smp_dice.yaml +2 -0
- quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
- quadra/configs/loss/smp_mcc.yaml +2 -0
- quadra/configs/loss/vicreg.yaml +5 -0
- quadra/configs/model/anomalib/cfa.yaml +35 -0
- quadra/configs/model/anomalib/cflow.yaml +30 -0
- quadra/configs/model/anomalib/csflow.yaml +34 -0
- quadra/configs/model/anomalib/dfm.yaml +19 -0
- quadra/configs/model/anomalib/draem.yaml +29 -0
- quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
- quadra/configs/model/anomalib/fastflow.yaml +32 -0
- quadra/configs/model/anomalib/padim.yaml +32 -0
- quadra/configs/model/anomalib/patchcore.yaml +36 -0
- quadra/configs/model/barlow.yaml +16 -0
- quadra/configs/model/byol.yaml +25 -0
- quadra/configs/model/classification.yaml +10 -0
- quadra/configs/model/dino.yaml +26 -0
- quadra/configs/model/logistic_regression.yaml +4 -0
- quadra/configs/model/multilabel_classification.yaml +9 -0
- quadra/configs/model/simclr.yaml +18 -0
- quadra/configs/model/simsiam.yaml +24 -0
- quadra/configs/model/smp.yaml +4 -0
- quadra/configs/model/smp_multiclass.yaml +4 -0
- quadra/configs/model/vicreg.yaml +16 -0
- quadra/configs/optimizer/adam.yaml +5 -0
- quadra/configs/optimizer/adamw.yaml +3 -0
- quadra/configs/optimizer/default.yaml +4 -0
- quadra/configs/optimizer/lars.yaml +8 -0
- quadra/configs/optimizer/sgd.yaml +4 -0
- quadra/configs/scheduler/default.yaml +5 -0
- quadra/configs/scheduler/rop.yaml +5 -0
- quadra/configs/scheduler/step.yaml +3 -0
- quadra/configs/scheduler/warmrestart.yaml +2 -0
- quadra/configs/scheduler/warmup.yaml +6 -0
- quadra/configs/task/anomalib/cfa.yaml +5 -0
- quadra/configs/task/anomalib/cflow.yaml +5 -0
- quadra/configs/task/anomalib/csflow.yaml +5 -0
- quadra/configs/task/anomalib/draem.yaml +5 -0
- quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
- quadra/configs/task/anomalib/fastflow.yaml +5 -0
- quadra/configs/task/anomalib/inference.yaml +3 -0
- quadra/configs/task/anomalib/padim.yaml +5 -0
- quadra/configs/task/anomalib/patchcore.yaml +5 -0
- quadra/configs/task/classification.yaml +6 -0
- quadra/configs/task/classification_evaluation.yaml +6 -0
- quadra/configs/task/default.yaml +1 -0
- quadra/configs/task/segmentation.yaml +9 -0
- quadra/configs/task/segmentation_evaluation.yaml +3 -0
- quadra/configs/task/sklearn_classification.yaml +13 -0
- quadra/configs/task/sklearn_classification_patch.yaml +11 -0
- quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
- quadra/configs/task/sklearn_classification_test.yaml +8 -0
- quadra/configs/task/ssl.yaml +2 -0
- quadra/configs/trainer/lightning_cpu.yaml +36 -0
- quadra/configs/trainer/lightning_gpu.yaml +35 -0
- quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
- quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
- quadra/configs/trainer/lightning_multigpu.yaml +37 -0
- quadra/configs/trainer/sklearn_classification.yaml +7 -0
- quadra/configs/transforms/byol.yaml +47 -0
- quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
- quadra/configs/transforms/default.yaml +37 -0
- quadra/configs/transforms/default_numpy.yaml +24 -0
- quadra/configs/transforms/default_resize.yaml +22 -0
- quadra/configs/transforms/dino.yaml +63 -0
- quadra/configs/transforms/linear_eval.yaml +18 -0
- quadra/datamodules/__init__.py +20 -0
- quadra/datamodules/anomaly.py +180 -0
- quadra/datamodules/base.py +375 -0
- quadra/datamodules/classification.py +1003 -0
- quadra/datamodules/generic/__init__.py +0 -0
- quadra/datamodules/generic/imagenette.py +144 -0
- quadra/datamodules/generic/mnist.py +81 -0
- quadra/datamodules/generic/mvtec.py +58 -0
- quadra/datamodules/generic/oxford_pet.py +163 -0
- quadra/datamodules/patch.py +190 -0
- quadra/datamodules/segmentation.py +742 -0
- quadra/datamodules/ssl.py +140 -0
- quadra/datasets/__init__.py +17 -0
- quadra/datasets/anomaly.py +287 -0
- quadra/datasets/classification.py +241 -0
- quadra/datasets/patch.py +138 -0
- quadra/datasets/segmentation.py +239 -0
- quadra/datasets/ssl.py +110 -0
- quadra/losses/__init__.py +0 -0
- quadra/losses/classification/__init__.py +6 -0
- quadra/losses/classification/asl.py +83 -0
- quadra/losses/classification/focal.py +320 -0
- quadra/losses/classification/prototypical.py +148 -0
- quadra/losses/ssl/__init__.py +17 -0
- quadra/losses/ssl/barlowtwins.py +47 -0
- quadra/losses/ssl/byol.py +37 -0
- quadra/losses/ssl/dino.py +129 -0
- quadra/losses/ssl/hyperspherical.py +45 -0
- quadra/losses/ssl/idmm.py +50 -0
- quadra/losses/ssl/simclr.py +67 -0
- quadra/losses/ssl/simsiam.py +30 -0
- quadra/losses/ssl/vicreg.py +76 -0
- quadra/main.py +46 -0
- quadra/metrics/__init__.py +3 -0
- quadra/metrics/segmentation.py +251 -0
- quadra/models/__init__.py +0 -0
- quadra/models/base.py +151 -0
- quadra/models/classification/__init__.py +8 -0
- quadra/models/classification/backbones.py +149 -0
- quadra/models/classification/base.py +92 -0
- quadra/models/evaluation.py +322 -0
- quadra/modules/__init__.py +0 -0
- quadra/modules/backbone.py +30 -0
- quadra/modules/base.py +312 -0
- quadra/modules/classification/__init__.py +3 -0
- quadra/modules/classification/base.py +331 -0
- quadra/modules/ssl/__init__.py +17 -0
- quadra/modules/ssl/barlowtwins.py +59 -0
- quadra/modules/ssl/byol.py +172 -0
- quadra/modules/ssl/common.py +285 -0
- quadra/modules/ssl/dino.py +186 -0
- quadra/modules/ssl/hyperspherical.py +206 -0
- quadra/modules/ssl/idmm.py +98 -0
- quadra/modules/ssl/simclr.py +73 -0
- quadra/modules/ssl/simsiam.py +68 -0
- quadra/modules/ssl/vicreg.py +67 -0
- quadra/optimizers/__init__.py +4 -0
- quadra/optimizers/lars.py +153 -0
- quadra/optimizers/sam.py +127 -0
- quadra/schedulers/__init__.py +3 -0
- quadra/schedulers/base.py +44 -0
- quadra/schedulers/warmup.py +127 -0
- quadra/tasks/__init__.py +24 -0
- quadra/tasks/anomaly.py +582 -0
- quadra/tasks/base.py +397 -0
- quadra/tasks/classification.py +1264 -0
- quadra/tasks/patch.py +492 -0
- quadra/tasks/segmentation.py +389 -0
- quadra/tasks/ssl.py +560 -0
- quadra/trainers/README.md +3 -0
- quadra/trainers/__init__.py +0 -0
- quadra/trainers/classification.py +179 -0
- quadra/utils/__init__.py +0 -0
- quadra/utils/anomaly.py +112 -0
- quadra/utils/classification.py +618 -0
- quadra/utils/deprecation.py +31 -0
- quadra/utils/evaluation.py +474 -0
- quadra/utils/export.py +579 -0
- quadra/utils/imaging.py +32 -0
- quadra/utils/logger.py +15 -0
- quadra/utils/mlflow.py +98 -0
- quadra/utils/model_manager.py +320 -0
- quadra/utils/models.py +524 -0
- quadra/utils/patch/__init__.py +15 -0
- quadra/utils/patch/dataset.py +1433 -0
- quadra/utils/patch/metrics.py +449 -0
- quadra/utils/patch/model.py +153 -0
- quadra/utils/patch/visualization.py +217 -0
- quadra/utils/resolver.py +42 -0
- quadra/utils/segmentation.py +31 -0
- quadra/utils/tests/__init__.py +0 -0
- quadra/utils/tests/fixtures/__init__.py +1 -0
- quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
- quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
- quadra/utils/tests/fixtures/dataset/classification.py +406 -0
- quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
- quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
- quadra/utils/tests/fixtures/models/__init__.py +3 -0
- quadra/utils/tests/fixtures/models/anomaly.py +89 -0
- quadra/utils/tests/fixtures/models/classification.py +45 -0
- quadra/utils/tests/fixtures/models/segmentation.py +33 -0
- quadra/utils/tests/helpers.py +70 -0
- quadra/utils/tests/models.py +27 -0
- quadra/utils/utils.py +525 -0
- quadra/utils/validator.py +115 -0
- quadra/utils/visualization.py +422 -0
- quadra/utils/vit_explainability.py +349 -0
- quadra-2.1.13.dist-info/LICENSE +201 -0
- quadra-2.1.13.dist-info/METADATA +386 -0
- quadra-2.1.13.dist-info/RECORD +300 -0
- {quadra-0.0.1.dist-info → quadra-2.1.13.dist-info}/WHEEL +1 -1
- quadra-2.1.13.dist-info/entry_points.txt +3 -0
- quadra-0.0.1.dist-info/METADATA +0 -14
- quadra-0.0.1.dist-info/RECORD +0 -4

quadra/datamodules/generic/imagenette.py
@@ -0,0 +1,144 @@
+from __future__ import annotations
+
+import os
+import shutil
+from typing import Any
+
+import pandas as pd
+from sklearn.model_selection import train_test_split
+from torchvision.datasets.utils import download_and_extract_archive
+
+from quadra.datamodules import ClassificationDataModule, SSLDataModule
+from quadra.utils.utils import get_logger
+
+IMAGENETTE_LABEL_MAPPER = {
+    "n01440764": "tench",
+    "n02102040": "english_springer",
+    "n02979186": "cassette_player",
+    "n03000684": "chain_saw",
+    "n03028079": "church",
+    "n03394916": "french_horn",
+    "n03417042": "garbage_truck",
+    "n03425413": "gas_pump",
+    "n03445777": "golf_ball",
+    "n03888257": "parachute",
+}
+
+DEFAULT_CLASS_TO_IDX = {cl: idx for idx, cl in enumerate(sorted(IMAGENETTE_LABEL_MAPPER.values()))}
+
+log = get_logger(__name__)
+
+
+class ImagenetteClassificationDataModule(ClassificationDataModule):
+    """Initializes the classification data module for the Imagenette dataset.
+
+    Args:
+        data_path: Path to the dataset.
+        name: Name of the dataset.
+        imagenette_version: Version of the Imagenette dataset. Can be 320, 160 or full.
+        force_download: If True, the dataset will be downloaded even if the data_path already exists. The data_path
+            will be deleted and recreated.
+        class_to_idx: Dictionary mapping class names to class indices.
+        **kwargs: Keyword arguments for the ClassificationDataModule.
+    """
+
+    def __init__(
+        self,
+        data_path: str,
+        name: str = "imagenette_classification_datamodule",
+        imagenette_version: str = "320",
+        force_download: bool = False,
+        class_to_idx: dict[str, int] | None = None,
+        **kwargs: Any,
+    ):
+        if imagenette_version not in ["320", "160", "full"]:
+            raise ValueError(f"imagenette_version must be one of 320, 160 or full. Got {imagenette_version} instead.")
+
+        if imagenette_version == "full":
+            imagenette_version = ""
+        else:
+            imagenette_version = f"-{imagenette_version}"
+
+        self.download_url = f"https://s3.amazonaws.com/fast-ai-imageclas/imagenette2{imagenette_version}.tgz"
+        self.force_download = force_download
+        self.imagenette_version = imagenette_version
+
+        if class_to_idx is None:
+            class_to_idx = DEFAULT_CLASS_TO_IDX
+
+        super().__init__(
+            data_path=data_path,
+            name=name,
+            test_split_file=None,
+            train_split_file=None,
+            val_size=None,
+            class_to_idx=class_to_idx,
+            **kwargs,
+        )
+
+    def download_data(self, download_url: str, force_download: bool = False) -> None:
+        """Download the Imagenette dataset.
+
+        Args:
+            download_url: Dataset download url.
+            force_download: If True, the dataset will be downloaded even if the data_path already exists. The data_path
+                will be removed.
+        """
+        if os.path.exists(self.data_path):
+            if force_download:
+                log.info("The path %s already exists. Removing it and downloading the dataset again.", self.data_path)
+                shutil.rmtree(self.data_path)
+            else:
+                log.info("The path %s already exists. Skipping download.", self.data_path)
+                return
+
+        log.info("Downloading and extracting Imagenette dataset to %s", self.data_path)
+        download_and_extract_archive(download_url, self.data_path, remove_finished=True)
+
+    def _prepare_data(self) -> None:
+        """Prepares the data for the data module."""
+        self.download_data(download_url=self.download_url, force_download=self.force_download)
+        self.data_path = os.path.join(self.data_path, f"imagenette2{self.imagenette_version}")
+
+        train_images_and_targets, class_to_idx = self._find_images_and_targets(os.path.join(self.data_path, "train"))
+        self.class_to_idx = {IMAGENETTE_LABEL_MAPPER[k]: v for k, v in class_to_idx.items()}
+
+        samples_train, targets_train = [], []
+        idx_to_class = {v: k for k, v in self.class_to_idx.items()}
+        for image, target in train_images_and_targets:
+            samples_train.append(image)
+            targets_train.append(idx_to_class[target])
+
+        samples_train, samples_val, targets_train, targets_val = train_test_split(
+            samples_train,
+            targets_train,
+            test_size=self.val_size,
+            random_state=self.seed,
+            stratify=targets_train,
+        )
+
+        test_images_and_targets, _ = self._find_images_and_targets(os.path.join(self.data_path, "val"))
+        samples_test, targets_test = [], []
+        for image, target in test_images_and_targets:
+            samples_test.append(image)
+            targets_test.append(idx_to_class[target])
+
+        train_df = pd.DataFrame({"samples": samples_train, "targets": targets_train})
+        train_df["split"] = "train"
+        val_df = pd.DataFrame({"samples": samples_val, "targets": targets_val})
+        val_df["split"] = "val"
+        test_df = pd.DataFrame({"samples": samples_test, "targets": targets_test})
+        test_df["split"] = "test"
+        self.data = pd.concat([train_df, val_df, test_df], axis=0)
+
+
+class ImagenetteSSLDataModule(ImagenetteClassificationDataModule, SSLDataModule):
+    """Initializes the SSL data module for the Imagenette dataset."""
+
+    def __init__(
+        self,
+        *args: Any,
+        name="imagenette_ssl",
+        **kwargs: Any,
+    ):
+        super().__init__(*args, name=name, **kwargs)  # type: ignore[misc]
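
For orientation, a minimal usage sketch (not part of the diff) showing how the new datamodule selects its download archive; it assumes quadra is installed:

    from quadra.datamodules.generic.imagenette import ImagenetteClassificationDataModule

    # imagenette_version picks the archive: "320" -> imagenette2-320.tgz,
    # "160" -> imagenette2-160.tgz, "full" -> imagenette2.tgz (fast.ai S3 bucket).
    datamodule = ImagenetteClassificationDataModule(
        data_path="./data/imagenette",
        imagenette_version="320",
        force_download=False,  # True deletes data_path and downloads again
    )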

quadra/datamodules/generic/mnist.py
@@ -0,0 +1,81 @@
+from __future__ import annotations
+
+import os
+import shutil
+from typing import Any
+
+import cv2
+from torchvision.datasets.mnist import MNIST
+
+from quadra.datamodules import AnomalyDataModule
+from quadra.utils.utils import get_logger
+
+log = get_logger(__name__)
+
+
+class MNISTAnomalyDataModule(AnomalyDataModule):
+    """Standard anomaly datamodule with automatic download of the MNIST dataset."""
+
+    def __init__(
+        self, data_path: str, good_number: int, limit_data: int = 100, category: str | None = None, **kwargs: Any
+    ):
+        """Initialize the MNIST anomaly datamodule.
+
+        Args:
+            data_path: Path to the dataset
+            good_number: Which number to use as a good class, all other numbers are considered anomalies.
+            category: The category of the dataset. For MNIST this is always None.
+            limit_data: Limit the number of images to use for training and testing. Defaults to 100.
+            **kwargs: Additional arguments to pass to the AnomalyDataModule.
+        """
+        super().__init__(data_path=data_path, category=None, **kwargs)
+        self.good_number = good_number
+        self.limit_data = limit_data
+
+    def download_data(self) -> None:
+        """Download the MNIST dataset and move the images into the right folders."""
+        log.info("Generating MNIST anomaly dataset for good number %s", self.good_number)
+
+        mnist_train_dataset = MNIST(root=self.data_path, train=True, download=True)
+        mnist_test_dataset = MNIST(root=self.data_path, train=False, download=True)
+
+        self.data_path = os.path.join(self.data_path, "quadra_mnist_anomaly")
+
+        if os.path.exists(self.data_path):
+            shutil.rmtree(self.data_path)
+
+        # Create the folder structure
+        train_good_folder = os.path.join(self.data_path, "train", "good")
+        test_good_folder = os.path.join(self.data_path, "test", "good")
+
+        os.makedirs(train_good_folder, exist_ok=True)
+        os.makedirs(test_good_folder, exist_ok=True)
+
+        # Copy the good train images to the correct folder
+        good_train_samples = mnist_train_dataset.data[mnist_train_dataset.targets == self.good_number]
+        for i, image in enumerate(good_train_samples.numpy()):
+            if i == self.limit_data:
+                break
+            cv2.imwrite(os.path.join(train_good_folder, f"{i}.png"), image)
+
+        for number in range(10):
+            if number == self.good_number:
+                good_train_samples = mnist_test_dataset.data[mnist_test_dataset.targets == number]
+                for i, image in enumerate(good_train_samples.numpy()):
+                    if i == self.limit_data:
+                        break
+                    cv2.imwrite(os.path.join(test_good_folder, f"{number}_{i}.png"), image)
+            else:
+                test_bad_folder = os.path.join(self.data_path, "test", str(number))
+                os.makedirs(test_bad_folder, exist_ok=True)
+                bad_train_samples = mnist_train_dataset.data[mnist_train_dataset.targets == number]
+                for i, image in enumerate(bad_train_samples.numpy()):
+                    if i == self.limit_data:
+                        break
+
+                    cv2.imwrite(os.path.join(test_bad_folder, f"{number}_{i}.png"), image)
+
+    def _prepare_data(self) -> None:
+        """Prepare the MNIST dataset."""
+        self.download_data()
+        return super()._prepare_data()
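
A sketch of the folder layout the code above generates, plus construction of the datamodule (not part of the diff; assumes quadra is installed):

    from quadra.datamodules.generic.mnist import MNISTAnomalyDataModule

    # Digit 9 is treated as "good"; the other nine digits become test-time anomalies.
    # download_data() builds, under <data_path>/quadra_mnist_anomaly:
    #   train/good/<i>.png          up to limit_data good digits from the MNIST train split
    #   test/good/<n>_<i>.png       good digits from the MNIST test split
    #   test/<digit>/<n>_<i>.png    anomalous digits (taken from the MNIST train split)
    datamodule = MNISTAnomalyDataModule(
        data_path="./data/mnist",
        good_number=9,
        limit_data=100,
    )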

quadra/datamodules/generic/mvtec.py
@@ -0,0 +1,58 @@
+import os
+from pathlib import Path
+
+from torchvision.datasets.utils import download_and_extract_archive
+
+from quadra.datamodules import AnomalyDataModule
+from quadra.utils.utils import get_logger
+
+log = get_logger(__name__)
+
+
+DATASET_BASE_URL = "https://www.mydrive.ch/shares/38536/3830184030e49fe74747669442f0f282/download/"
+
+DATASET_URL = {
+    "bottle": DATASET_BASE_URL + "420937370-1629951468/bottle.tar.xz",
+    "capsule": DATASET_BASE_URL + "420937454-1629951595/capsule.tar.xz",
+    "carpet": DATASET_BASE_URL + "420937484-1629951672/carpet.tar.xz",
+    "grid": DATASET_BASE_URL + "420937487-1629951814/grid.tar.xz",
+    "hazelnut": DATASET_BASE_URL + "420937545-1629951845/hazelnut.tar.xz",
+    "leather": DATASET_BASE_URL + "420937607-1629951964/leather.tar.xz",
+    "metal_nut": DATASET_BASE_URL + "420937637-1629952063/metal_nut.tar.xz",
+    "pill": DATASET_BASE_URL + "420938129-1629953099/pill.tar.xz",
+    "screw": DATASET_BASE_URL + "420938130-1629953152/screw.tar.xz",
+    "tile": DATASET_BASE_URL + "420938133-1629953189/tile.tar.xz",
+    "toothbrush": DATASET_BASE_URL + "420938134-1629953256/toothbrush.tar.xz",
+    "transistor": DATASET_BASE_URL + "420938166-1629953277/transistor.tar.xz",
+    "wood": DATASET_BASE_URL + "420938383-1629953354/wood.tar.xz",
+    "zipper": DATASET_BASE_URL + "420938385-1629953449/zipper.tar.xz",
+}
+
+
+class MVTecDataModule(AnomalyDataModule):
+    """Standard anomaly datamodule with automatic download of the MVTec dataset."""
+
+    def __init__(self, data_path: str, category: str, **kwargs):
+        if category not in DATASET_URL:
+            raise ValueError(f"Unknown category {category}. Available categories are {list(DATASET_URL.keys())}")
+
+        super().__init__(data_path=data_path, category=category, **kwargs)
+
+    def download_data(self) -> None:
+        """Download the MVTec dataset."""
+        if self.category is None:
+            raise ValueError("Category must be specified for MVTec dataset.")
+
+        if os.path.exists(self.data_path):
+            log.info("The path %s already exists. Skipping download.", os.path.join(self.data_path, self.category))
+            return
+
+        log.info("Downloading and extracting MVTec dataset for category %s to %s", self.category, self.data_path)
+        # self.data_path is the path to the category folder that will be created by the download_and_extract_archive
+        data_path_no_category = str(Path(self.data_path).parent)
+        download_and_extract_archive(DATASET_URL[self.category], data_path_no_category, remove_finished=True)
+
+    def _prepare_data(self) -> None:
+        """Prepare the MVTec dataset."""
+        self.download_data()
+        return super()._prepare_data()
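
A short sketch of the category validation above (not part of the diff; assumes quadra is installed). Note the URL table covers 14 MVTec categories, so a category such as "cable" is rejected:

    from quadra.datamodules.generic.mvtec import MVTecDataModule

    datamodule = MVTecDataModule(data_path="./data/mvtec/bottle", category="bottle")

    # Unknown categories fail fast, before any download is attempted.
    try:
        MVTecDataModule(data_path="./data/mvtec/cable", category="cable")
    except ValueError as err:
        print(err)  # Unknown category cable. Available categories are [...]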

quadra/datamodules/generic/oxford_pet.py
@@ -0,0 +1,163 @@
+from __future__ import annotations
+
+import os
+from typing import Any
+
+import albumentations
+import cv2
+import numpy as np
+import pandas as pd
+from torchvision.datasets.utils import download_and_extract_archive
+
+from quadra.datamodules import SegmentationMulticlassDataModule
+from quadra.datasets.segmentation import SegmentationDatasetMulticlass
+from quadra.utils import utils
+
+log = utils.get_logger(__name__)
+
+
+class OxfordPetSegmentationDataModule(SegmentationMulticlassDataModule):
+    """OxfordPetSegmentationDataModule.
+
+    Args:
+        data_path: path to the oxford pet dataset
+        idx_to_class: dict with the correspondence between mask indices and classes: {1: class_1, 2: class_2, ..., N: class_N},
+            except the background class which is 0.
+        name: Defaults to "oxford_pet_segmentation_datamodule".
+        dataset: Defaults to SegmentationDatasetMulticlass.
+        batch_size: batch size for training. Defaults to 32.
+        test_size: Defaults to 0.3.
+        val_size: Defaults to 0.3.
+        seed: Defaults to 42.
+        num_workers: number of workers for data loading. Defaults to 6.
+        train_transform: Train transform. Defaults to None.
+        test_transform: Test transform. Defaults to None.
+        val_transform: Validation transform. Defaults to None.
+    """
+
+    def __init__(
+        self,
+        data_path: str,
+        idx_to_class: dict,
+        name: str = "oxford_pet_segmentation_datamodule",
+        dataset: type[SegmentationDatasetMulticlass] = SegmentationDatasetMulticlass,
+        batch_size: int = 32,
+        test_size: float = 0.3,
+        val_size: float = 0.3,
+        seed: int = 42,
+        num_workers: int = 6,
+        train_transform: albumentations.Compose | None = None,
+        test_transform: albumentations.Compose | None = None,
+        val_transform: albumentations.Compose | None = None,
+        **kwargs: Any,
+    ):
+        super().__init__(
+            data_path=data_path,
+            idx_to_class=idx_to_class,
+            name=name,
+            dataset=dataset,
+            batch_size=batch_size,
+            test_size=test_size,
+            val_size=val_size,
+            seed=seed,
+            num_workers=num_workers,
+            train_transform=train_transform,
+            test_transform=test_transform,
+            val_transform=val_transform,
+            **kwargs,
+        )
+
+    _RESOURCES = (
+        ("https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz", "5c4f3ee8e5d25df40f4fd59a7f44e54c"),
+        ("https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz", "95a8c909bbe2e81eed6a22bccdf3f68f"),
+    )
+
+    def _preprocess_mask(self, mask: np.ndarray) -> np.ndarray:
+        """Preprocess mask function that is adapted from
+        https://albumentations.ai/docs/examples/pytorch_semantic_segmentation/.
+
+        Args:
+            mask: mask to be preprocessed
+
+        Returns:
+            binarized mask
+        """
+        mask = mask.astype(np.float32)
+        mask[mask == 2.0] = 0.0
+        mask[(mask == 1.0) | (mask == 3.0)] = 1.0
+        mask = (mask > 0).astype(np.uint8)
+        return mask
+
+    def _check_exists(self, image_folder: str, annotation_folder: str) -> bool:
+        """Check if the dataset is already downloaded."""
+        return all(os.path.exists(folder) and os.path.isdir(folder) for folder in (image_folder, annotation_folder))
+
+    def download_data(self):
+        """Download the dataset if it is not already downloaded."""
+        image_folder = os.path.join(self.data_path, "images")
+        annotation_folder = os.path.join(self.data_path, "annotations")
+        if not self._check_exists(image_folder, annotation_folder):
+            for url, md5 in self._RESOURCES:
+                download_and_extract_archive(url, download_root=self.data_path, md5=md5, remove_finished=True)
+            log.info("Fixing corrupted files...")
+            images_filenames = sorted(os.listdir(image_folder))
+            for filename in images_filenames:
+                file_wo_ext = os.path.splitext(os.path.basename(filename))[0]
+                try:
+                    mask = cv2.imread(os.path.join(annotation_folder, "trimaps", file_wo_ext + ".png"))
+                    mask = self._preprocess_mask(mask)
+                    if np.sum(mask) == 0:
+                        os.remove(os.path.join(image_folder, filename))
+                        os.remove(os.path.join(annotation_folder, "trimaps", file_wo_ext + ".png"))
+                        log.info("Removed %s", filename)
+                    else:
+                        img = cv2.imread(os.path.join(image_folder, filename))
+                        cv2.imwrite(os.path.join(image_folder, file_wo_ext + ".jpg"), img)
+                except Exception:
+                    ip = os.path.join(image_folder, filename)
+                    mp = os.path.join(annotation_folder, "trimaps", file_wo_ext + ".png")
+                    if os.path.exists(ip):
+                        os.remove(ip)
+                    if os.path.exists(mp):
+                        os.remove(mp)
+                    log.info("Removed %s", filename)
+
+    def _prepare_data(self) -> None:
+        """Prepare the data to be used by the DataModule."""
+        self.download_data()
+
+        trainval_split_filepath = os.path.join(self.data_path, "annotations", "trainval.txt")
+        with open(trainval_split_filepath) as f:
+            split_data = f.read().strip("\n").split("\n")
+        trainval_filenames = [
+            x.split(" ")[0]
+            for x in split_data
+            if os.path.exists(os.path.join(self.data_path, "images", x.split(" ")[0] + ".jpg"))
+        ]
+        train_filenames = [x for i, x in enumerate(trainval_filenames) if i % 10 != 0]
+        val_filenames = [x for i, x in enumerate(trainval_filenames) if i % 10 == 0]
+
+        test_split_filepath = os.path.join(self.data_path, "annotations", "test.txt")
+        with open(test_split_filepath) as f:
+            split_data = f.read().strip("\n").split("\n")
+        test_filenames = [
+            x.split(" ")[0]
+            for x in split_data
+            if os.path.exists(os.path.join(self.data_path, "images", x.split(" ")[0] + ".jpg"))
+        ]
+
+        df_list = []
+        for split_name, filenames in [
+            ("train", train_filenames),
+            ("val", val_filenames),
+            ("test", test_filenames),
+        ]:
+            samples = [os.path.join(self.data_path, "images", f + ".jpg") for f in filenames]
+            masks = [os.path.join(self.data_path, "annotations", "trimaps", f + ".png") for f in filenames]
+            targets = [1] * len(filenames)
+
+            df = pd.DataFrame({"samples": samples, "masks": masks, "targets": targets})
+            df["split"] = split_name
+            df_list.append(df)
+
+        self.data = pd.concat(df_list, axis=0)
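
The `_preprocess_mask` rule above maps the Oxford-IIIT trimap labels (1 = pet, 2 = background, 3 = border) to a binary pet/background mask. A self-contained numpy check of that rule (`binarize_trimap` is a hypothetical stand-in name, not part of the package):

    import numpy as np

    def binarize_trimap(mask: np.ndarray) -> np.ndarray:
        # Same rule as _preprocess_mask: background (2) -> 0, pet (1) and border (3) -> 1.
        mask = mask.astype(np.float32)
        mask[mask == 2.0] = 0.0
        mask[(mask == 1.0) | (mask == 3.0)] = 1.0
        return (mask > 0).astype(np.uint8)

    print(binarize_trimap(np.array([[1, 2], [3, 2]])))
    # [[1 0]
    #  [1 0]]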

quadra/datamodules/patch.py
@@ -0,0 +1,190 @@
+from __future__ import annotations
+
+import json
+import os
+
+import albumentations
+import pandas as pd
+from torch.utils.data import DataLoader
+
+from quadra.datamodules.base import BaseDataModule
+from quadra.datasets import ImageClassificationListDataset, PatchSklearnClassificationTrainDataset
+from quadra.utils.classification import find_test_image
+from quadra.utils.patch.dataset import PatchDatasetInfo, load_train_file
+
+
+class PatchSklearnClassificationDataModule(BaseDataModule):
+    """DataModule for patch classification.
+
+    Args:
+        data_path: Location of the dataset
+        name: Name of the datamodule
+        train_filename: Name of the file containing the list of training samples
+        exclude_filter: Filter to exclude samples from the dataset
+        include_filter: Filter to include samples from the dataset
+        class_to_idx: Dictionary mapping class names to indices
+        seed: Random seed
+        batch_size: Batch size
+        num_workers: Number of workers
+        train_transform: Transform to apply to the training samples
+        val_transform: Transform to apply to the validation samples
+        test_transform: Transform to apply to the test samples
+        balance_classes: If True, repeat underrepresented classes
+        class_to_skip_training: List of classes skipped during training.
+    """
+
+    def __init__(
+        self,
+        data_path: str,
+        class_to_idx: dict,
+        name: str = "patch_classification_datamodule",
+        train_filename: str = "dataset.txt",
+        exclude_filter: list[str] | None = None,
+        include_filter: list[str] | None = None,
+        seed: int = 42,
+        batch_size: int = 32,
+        num_workers: int = 6,
+        train_transform: albumentations.Compose | None = None,
+        val_transform: albumentations.Compose | None = None,
+        test_transform: albumentations.Compose | None = None,
+        balance_classes: bool = False,
+        class_to_skip_training: list | None = None,
+        **kwargs,
+    ):
+        super().__init__(
+            data_path=data_path,
+            name=name,
+            seed=seed,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            train_transform=train_transform,
+            val_transform=val_transform,
+            test_transform=test_transform,
+            **kwargs,
+        )
+        self.class_to_idx = class_to_idx
+        self.balance_classes = balance_classes
+        self.train_filename = train_filename
+        self.include_filter = include_filter
+        self.exclude_filter = exclude_filter
+        self.class_to_skip_training = class_to_skip_training
+
+        self.train_folder = os.path.join(self.data_path, "train")
+        self.val_folder = os.path.join(self.data_path, "val")
+        self.test_folder = os.path.join(self.data_path, "test")
+        self.info: PatchDatasetInfo
+        self.train_dataset: PatchSklearnClassificationTrainDataset
+        self.val_dataset: ImageClassificationListDataset
+        self.test_dataset: ImageClassificationListDataset
+
+    def _prepare_data(self):
+        """Prepare data function."""
+        if os.path.isfile(os.path.join(self.data_path, "info.json")):
+            with open(os.path.join(self.data_path, "info.json")) as f:
+                self.info = PatchDatasetInfo(**json.load(f))
+        else:
+            raise FileNotFoundError("No `info.json` file found in the dataset folder")
+
+        split_df_list: list[pd.DataFrame] = []
+        if os.path.isfile(os.path.join(self.train_folder, self.train_filename)):
+            train_samples, train_labels = load_train_file(
+                train_file_path=os.path.join(self.train_folder, self.train_filename),
+                include_filter=self.include_filter,
+                exclude_filter=self.exclude_filter,
+                class_to_skip=self.class_to_skip_training,
+            )
+            train_df = pd.DataFrame({"samples": train_samples, "targets": train_labels})
+            train_df["split"] = "train"
+            split_df_list.append(train_df)
+        if os.path.isdir(self.val_folder):
+            val_samples, val_labels = find_test_image(
+                folder=self.val_folder,
+                exclude_filter=self.exclude_filter,
+                include_filter=self.include_filter,
+                include_none_class=False,
+            )
+            val_df = pd.DataFrame({"samples": val_samples, "targets": val_labels})
+            val_df["split"] = "val"
+            split_df_list.append(val_df)
+        if os.path.isdir(self.test_folder):
+            test_samples, test_labels = find_test_image(
+                folder=self.test_folder,
+                exclude_filter=self.exclude_filter,
+                include_filter=self.include_filter,
+                include_none_class=True,
+            )
+            test_df = pd.DataFrame({"samples": test_samples, "targets": test_labels})
+            test_df["split"] = "test"
+            split_df_list.append(test_df)
+        if len(split_df_list) == 0:
+            raise ValueError("No data found in all split folders")
+        self.data = pd.concat(split_df_list, axis=0)
+
+    def setup(self, stage: str | None = None) -> None:
+        """Setup function."""
+        if stage == "fit":
+            self.train_dataset = PatchSklearnClassificationTrainDataset(
+                data_path=self.data_path,
+                class_to_idx=self.class_to_idx,
+                samples=self.data[self.data["split"] == "train"]["samples"].tolist(),
+                targets=self.data[self.data["split"] == "train"]["targets"].tolist(),
+                transform=self.train_transform,
+                balance_classes=self.balance_classes,
+            )
+
+            self.val_dataset = ImageClassificationListDataset(
+                class_to_idx=self.class_to_idx,
+                samples=self.data[self.data["split"] == "val"]["samples"].tolist(),
+                targets=self.data[self.data["split"] == "val"]["targets"].tolist(),
+                transform=self.val_transform,
+                allow_missing_label=False,
+            )
+
+        elif stage in ["test", "predict"]:
+            self.test_dataset = ImageClassificationListDataset(
+                class_to_idx=self.class_to_idx,
+                samples=self.data[self.data["split"] == "test"]["samples"].tolist(),
+                targets=self.data[self.data["split"] == "test"]["targets"].tolist(),
+                transform=self.test_transform,
+                allow_missing_label=True,
+            )
+
+    def train_dataloader(self) -> DataLoader:
+        """Return the train dataloader."""
+        if not self.train_dataset_available:
+            raise ValueError("No training sample is available")
+        return DataLoader(
+            self.train_dataset,
+            batch_size=self.batch_size,
+            shuffle=True,
+            num_workers=self.num_workers,
+            drop_last=False,
+            pin_memory=True,
+        )
+
+    def val_dataloader(self) -> DataLoader:
+        """Return the validation dataloader."""
+        if not self.val_dataset_available:
+            raise ValueError("No validation dataset is available")
+        return DataLoader(
+            self.val_dataset,
+            batch_size=self.batch_size,
+            shuffle=False,
+            num_workers=self.num_workers,
+            drop_last=False,
+            pin_memory=True,
+        )
+
+    def test_dataloader(self) -> DataLoader:
+        """Return the test dataloader."""
+        if not self.test_dataset_available:
+            raise ValueError("No test dataset is available")
+
+        return DataLoader(
+            self.test_dataset,
+            batch_size=self.batch_size,
+            shuffle=False,
+            num_workers=self.num_workers,
+            drop_last=False,
+            pin_memory=True,
+        )
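
The datamodule above keeps every split in a single DataFrame with a `split` column and slices it per stage in `setup()`. A toy pandas sketch of that convention (hypothetical file names, not part of the package):

    import pandas as pd

    data = pd.concat(
        [
            pd.DataFrame({"samples": ["a.png", "b.png"], "targets": ["good", "bad"], "split": "train"}),
            pd.DataFrame({"samples": ["c.png"], "targets": ["good"], "split": "val"}),
            pd.DataFrame({"samples": ["d.png"], "targets": [None], "split": "test"}),
        ],
        axis=0,
    )

    # Mirrors the selection done in setup(stage="fit") / setup(stage="test").
    print(data[data["split"] == "train"]["samples"].tolist())  # ['a.png', 'b.png']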