quadra 0.0.1__py3-none-any.whl → 2.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hydra_plugins/quadra_searchpath_plugin.py +30 -0
- quadra/__init__.py +6 -0
- quadra/callbacks/__init__.py +0 -0
- quadra/callbacks/anomalib.py +289 -0
- quadra/callbacks/lightning.py +501 -0
- quadra/callbacks/mlflow.py +291 -0
- quadra/callbacks/scheduler.py +69 -0
- quadra/configs/__init__.py +0 -0
- quadra/configs/backbone/caformer_m36.yaml +8 -0
- quadra/configs/backbone/caformer_s36.yaml +8 -0
- quadra/configs/backbone/convnextv2_base.yaml +8 -0
- quadra/configs/backbone/convnextv2_femto.yaml +8 -0
- quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
- quadra/configs/backbone/dino_vitb8.yaml +12 -0
- quadra/configs/backbone/dino_vits8.yaml +12 -0
- quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
- quadra/configs/backbone/dinov2_vits14.yaml +12 -0
- quadra/configs/backbone/efficientnet_b0.yaml +8 -0
- quadra/configs/backbone/efficientnet_b1.yaml +8 -0
- quadra/configs/backbone/efficientnet_b2.yaml +8 -0
- quadra/configs/backbone/efficientnet_b3.yaml +8 -0
- quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
- quadra/configs/backbone/levit_128s.yaml +8 -0
- quadra/configs/backbone/mnasnet0_5.yaml +9 -0
- quadra/configs/backbone/resnet101.yaml +8 -0
- quadra/configs/backbone/resnet18.yaml +8 -0
- quadra/configs/backbone/resnet18_ssl.yaml +8 -0
- quadra/configs/backbone/resnet50.yaml +8 -0
- quadra/configs/backbone/smp.yaml +9 -0
- quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
- quadra/configs/backbone/unetr.yaml +15 -0
- quadra/configs/backbone/vit16_base.yaml +9 -0
- quadra/configs/backbone/vit16_small.yaml +9 -0
- quadra/configs/backbone/vit16_tiny.yaml +9 -0
- quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
- quadra/configs/callbacks/all.yaml +45 -0
- quadra/configs/callbacks/default.yaml +34 -0
- quadra/configs/callbacks/default_anomalib.yaml +64 -0
- quadra/configs/config.yaml +33 -0
- quadra/configs/core/default.yaml +11 -0
- quadra/configs/datamodule/base/anomaly.yaml +16 -0
- quadra/configs/datamodule/base/classification.yaml +21 -0
- quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
- quadra/configs/datamodule/base/segmentation.yaml +18 -0
- quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
- quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
- quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
- quadra/configs/datamodule/base/ssl.yaml +21 -0
- quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
- quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
- quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
- quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
- quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
- quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
- quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
- quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
- quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
- quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
- quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/base/classification/classification.yaml +73 -0
- quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
- quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
- quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
- quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
- quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
- quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
- quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
- quadra/configs/experiment/base/ssl/byol.yaml +43 -0
- quadra/configs/experiment/base/ssl/dino.yaml +46 -0
- quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
- quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
- quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
- quadra/configs/experiment/custom/cls.yaml +12 -0
- quadra/configs/experiment/default.yaml +15 -0
- quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
- quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
- quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
- quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
- quadra/configs/export/default.yaml +13 -0
- quadra/configs/hydra/anomaly_custom.yaml +15 -0
- quadra/configs/hydra/default.yaml +14 -0
- quadra/configs/inference/default.yaml +26 -0
- quadra/configs/logger/comet.yaml +10 -0
- quadra/configs/logger/csv.yaml +5 -0
- quadra/configs/logger/mlflow.yaml +12 -0
- quadra/configs/logger/tensorboard.yaml +8 -0
- quadra/configs/loss/asl.yaml +7 -0
- quadra/configs/loss/barlow.yaml +2 -0
- quadra/configs/loss/bce.yaml +1 -0
- quadra/configs/loss/byol.yaml +1 -0
- quadra/configs/loss/cross_entropy.yaml +1 -0
- quadra/configs/loss/dino.yaml +8 -0
- quadra/configs/loss/simclr.yaml +2 -0
- quadra/configs/loss/simsiam.yaml +1 -0
- quadra/configs/loss/smp_ce.yaml +3 -0
- quadra/configs/loss/smp_dice.yaml +2 -0
- quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
- quadra/configs/loss/smp_mcc.yaml +2 -0
- quadra/configs/loss/vicreg.yaml +5 -0
- quadra/configs/model/anomalib/cfa.yaml +35 -0
- quadra/configs/model/anomalib/cflow.yaml +30 -0
- quadra/configs/model/anomalib/csflow.yaml +34 -0
- quadra/configs/model/anomalib/dfm.yaml +19 -0
- quadra/configs/model/anomalib/draem.yaml +29 -0
- quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
- quadra/configs/model/anomalib/fastflow.yaml +32 -0
- quadra/configs/model/anomalib/padim.yaml +32 -0
- quadra/configs/model/anomalib/patchcore.yaml +36 -0
- quadra/configs/model/barlow.yaml +16 -0
- quadra/configs/model/byol.yaml +25 -0
- quadra/configs/model/classification.yaml +10 -0
- quadra/configs/model/dino.yaml +26 -0
- quadra/configs/model/logistic_regression.yaml +4 -0
- quadra/configs/model/multilabel_classification.yaml +9 -0
- quadra/configs/model/simclr.yaml +18 -0
- quadra/configs/model/simsiam.yaml +24 -0
- quadra/configs/model/smp.yaml +4 -0
- quadra/configs/model/smp_multiclass.yaml +4 -0
- quadra/configs/model/vicreg.yaml +16 -0
- quadra/configs/optimizer/adam.yaml +5 -0
- quadra/configs/optimizer/adamw.yaml +3 -0
- quadra/configs/optimizer/default.yaml +4 -0
- quadra/configs/optimizer/lars.yaml +8 -0
- quadra/configs/optimizer/sgd.yaml +4 -0
- quadra/configs/scheduler/default.yaml +5 -0
- quadra/configs/scheduler/rop.yaml +5 -0
- quadra/configs/scheduler/step.yaml +3 -0
- quadra/configs/scheduler/warmrestart.yaml +2 -0
- quadra/configs/scheduler/warmup.yaml +6 -0
- quadra/configs/task/anomalib/cfa.yaml +5 -0
- quadra/configs/task/anomalib/cflow.yaml +5 -0
- quadra/configs/task/anomalib/csflow.yaml +5 -0
- quadra/configs/task/anomalib/draem.yaml +5 -0
- quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
- quadra/configs/task/anomalib/fastflow.yaml +5 -0
- quadra/configs/task/anomalib/inference.yaml +3 -0
- quadra/configs/task/anomalib/padim.yaml +5 -0
- quadra/configs/task/anomalib/patchcore.yaml +5 -0
- quadra/configs/task/classification.yaml +6 -0
- quadra/configs/task/classification_evaluation.yaml +6 -0
- quadra/configs/task/default.yaml +1 -0
- quadra/configs/task/segmentation.yaml +9 -0
- quadra/configs/task/segmentation_evaluation.yaml +3 -0
- quadra/configs/task/sklearn_classification.yaml +13 -0
- quadra/configs/task/sklearn_classification_patch.yaml +11 -0
- quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
- quadra/configs/task/sklearn_classification_test.yaml +8 -0
- quadra/configs/task/ssl.yaml +2 -0
- quadra/configs/trainer/lightning_cpu.yaml +36 -0
- quadra/configs/trainer/lightning_gpu.yaml +35 -0
- quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
- quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
- quadra/configs/trainer/lightning_multigpu.yaml +37 -0
- quadra/configs/trainer/sklearn_classification.yaml +7 -0
- quadra/configs/transforms/byol.yaml +47 -0
- quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
- quadra/configs/transforms/default.yaml +37 -0
- quadra/configs/transforms/default_numpy.yaml +24 -0
- quadra/configs/transforms/default_resize.yaml +22 -0
- quadra/configs/transforms/dino.yaml +63 -0
- quadra/configs/transforms/linear_eval.yaml +18 -0
- quadra/datamodules/__init__.py +20 -0
- quadra/datamodules/anomaly.py +180 -0
- quadra/datamodules/base.py +375 -0
- quadra/datamodules/classification.py +1003 -0
- quadra/datamodules/generic/__init__.py +0 -0
- quadra/datamodules/generic/imagenette.py +144 -0
- quadra/datamodules/generic/mnist.py +81 -0
- quadra/datamodules/generic/mvtec.py +58 -0
- quadra/datamodules/generic/oxford_pet.py +163 -0
- quadra/datamodules/patch.py +190 -0
- quadra/datamodules/segmentation.py +742 -0
- quadra/datamodules/ssl.py +140 -0
- quadra/datasets/__init__.py +17 -0
- quadra/datasets/anomaly.py +287 -0
- quadra/datasets/classification.py +241 -0
- quadra/datasets/patch.py +138 -0
- quadra/datasets/segmentation.py +239 -0
- quadra/datasets/ssl.py +110 -0
- quadra/losses/__init__.py +0 -0
- quadra/losses/classification/__init__.py +6 -0
- quadra/losses/classification/asl.py +83 -0
- quadra/losses/classification/focal.py +320 -0
- quadra/losses/classification/prototypical.py +148 -0
- quadra/losses/ssl/__init__.py +17 -0
- quadra/losses/ssl/barlowtwins.py +47 -0
- quadra/losses/ssl/byol.py +37 -0
- quadra/losses/ssl/dino.py +129 -0
- quadra/losses/ssl/hyperspherical.py +45 -0
- quadra/losses/ssl/idmm.py +50 -0
- quadra/losses/ssl/simclr.py +67 -0
- quadra/losses/ssl/simsiam.py +30 -0
- quadra/losses/ssl/vicreg.py +76 -0
- quadra/main.py +49 -0
- quadra/metrics/__init__.py +3 -0
- quadra/metrics/segmentation.py +251 -0
- quadra/models/__init__.py +0 -0
- quadra/models/base.py +151 -0
- quadra/models/classification/__init__.py +8 -0
- quadra/models/classification/backbones.py +149 -0
- quadra/models/classification/base.py +92 -0
- quadra/models/evaluation.py +322 -0
- quadra/modules/__init__.py +0 -0
- quadra/modules/backbone.py +30 -0
- quadra/modules/base.py +312 -0
- quadra/modules/classification/__init__.py +3 -0
- quadra/modules/classification/base.py +327 -0
- quadra/modules/ssl/__init__.py +17 -0
- quadra/modules/ssl/barlowtwins.py +59 -0
- quadra/modules/ssl/byol.py +172 -0
- quadra/modules/ssl/common.py +285 -0
- quadra/modules/ssl/dino.py +186 -0
- quadra/modules/ssl/hyperspherical.py +206 -0
- quadra/modules/ssl/idmm.py +98 -0
- quadra/modules/ssl/simclr.py +73 -0
- quadra/modules/ssl/simsiam.py +68 -0
- quadra/modules/ssl/vicreg.py +67 -0
- quadra/optimizers/__init__.py +4 -0
- quadra/optimizers/lars.py +153 -0
- quadra/optimizers/sam.py +127 -0
- quadra/schedulers/__init__.py +3 -0
- quadra/schedulers/base.py +44 -0
- quadra/schedulers/warmup.py +127 -0
- quadra/tasks/__init__.py +24 -0
- quadra/tasks/anomaly.py +582 -0
- quadra/tasks/base.py +397 -0
- quadra/tasks/classification.py +1263 -0
- quadra/tasks/patch.py +492 -0
- quadra/tasks/segmentation.py +389 -0
- quadra/tasks/ssl.py +560 -0
- quadra/trainers/README.md +3 -0
- quadra/trainers/__init__.py +0 -0
- quadra/trainers/classification.py +179 -0
- quadra/utils/__init__.py +0 -0
- quadra/utils/anomaly.py +112 -0
- quadra/utils/classification.py +618 -0
- quadra/utils/deprecation.py +31 -0
- quadra/utils/evaluation.py +474 -0
- quadra/utils/export.py +585 -0
- quadra/utils/imaging.py +32 -0
- quadra/utils/logger.py +15 -0
- quadra/utils/mlflow.py +98 -0
- quadra/utils/model_manager.py +320 -0
- quadra/utils/models.py +523 -0
- quadra/utils/patch/__init__.py +15 -0
- quadra/utils/patch/dataset.py +1433 -0
- quadra/utils/patch/metrics.py +449 -0
- quadra/utils/patch/model.py +153 -0
- quadra/utils/patch/visualization.py +217 -0
- quadra/utils/resolver.py +42 -0
- quadra/utils/segmentation.py +31 -0
- quadra/utils/tests/__init__.py +0 -0
- quadra/utils/tests/fixtures/__init__.py +1 -0
- quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
- quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
- quadra/utils/tests/fixtures/dataset/classification.py +406 -0
- quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
- quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
- quadra/utils/tests/fixtures/models/__init__.py +3 -0
- quadra/utils/tests/fixtures/models/anomaly.py +89 -0
- quadra/utils/tests/fixtures/models/classification.py +45 -0
- quadra/utils/tests/fixtures/models/segmentation.py +33 -0
- quadra/utils/tests/helpers.py +70 -0
- quadra/utils/tests/models.py +27 -0
- quadra/utils/utils.py +525 -0
- quadra/utils/validator.py +115 -0
- quadra/utils/visualization.py +422 -0
- quadra/utils/vit_explainability.py +349 -0
- quadra-2.2.7.dist-info/LICENSE +201 -0
- quadra-2.2.7.dist-info/METADATA +381 -0
- quadra-2.2.7.dist-info/RECORD +300 -0
- {quadra-0.0.1.dist-info → quadra-2.2.7.dist-info}/WHEEL +1 -1
- quadra-2.2.7.dist-info/entry_points.txt +3 -0
- quadra-0.0.1.dist-info/METADATA +0 -14
- quadra-0.0.1.dist-info/RECORD +0 -4
quadra/datamodules/segmentation.py (new file)

@@ -0,0 +1,742 @@
```python
# pylint: disable=unsubscriptable-object,unsupported-assignment-operation,unsupported-membership-test
from __future__ import annotations

import glob
import os
import random
from typing import Any

import albumentations
import cv2
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from skmultilearn.model_selection import iterative_train_test_split
from torch.utils.data import DataLoader

from quadra.datamodules.base import BaseDataModule
from quadra.datasets.segmentation import SegmentationDataset, SegmentationDatasetMulticlass
from quadra.utils import utils

log = utils.get_logger(__name__)


class SegmentationDataModule(BaseDataModule):
    """Base class for segmentation datasets.

    Args:
        data_path: Path to the data main folder.
        name: The name for the data module. Defaults to "segmentation_datamodule".
        test_size: The test split. Defaults to 0.3.
        val_size: The validation split. Defaults to 0.3.
        seed: Random generator seed. Defaults to 42.
        dataset: Dataset class.
        batch_size: Batch size. Defaults to 32.
        num_workers: Number of workers for dataloaders. Defaults to 6.
        train_transform: Transformations for train dataset. Defaults to None.
        test_transform: Transformations for test dataset. Defaults to None.
        val_transform: Transformations for validation dataset. Defaults to None.
        train_split_file: Path to a txt file with the training samples list. Defaults to None.
        test_split_file: Path to a txt file with the test samples list. Defaults to None.
        val_split_file: Path to a txt file with the validation samples list. Defaults to None.
        num_data_class: The number of samples per class. Defaults to None.
        exclude_good: If True, exclude good samples from the dataset. Defaults to False.
    """

    def __init__(
        self,
        data_path: str,
        name: str = "segmentation_datamodule",
        test_size: float = 0.3,
        val_size: float = 0.3,
        seed: int = 42,
        dataset: type[SegmentationDataset] = SegmentationDataset,
        batch_size: int = 32,
        num_workers: int = 6,
        train_transform: albumentations.Compose | None = None,
        test_transform: albumentations.Compose | None = None,
        val_transform: albumentations.Compose | None = None,
        train_split_file: str | None = None,
        test_split_file: str | None = None,
        val_split_file: str | None = None,
        num_data_class: int | None = None,
        exclude_good: bool = False,
        **kwargs: Any,
    ):
        super().__init__(
            data_path=data_path,
            name=name,
            seed=seed,
            batch_size=batch_size,
            num_workers=num_workers,
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            **kwargs,
        )
        self.test_size = test_size
        self.val_size = val_size
        self.num_data_class = num_data_class
        self.exclude_good = exclude_good
        self.train_split_file = train_split_file
        self.test_split_file = test_split_file
        self.val_split_file = val_split_file
        self.dataset = dataset
        self.train_dataset: SegmentationDataset
        self.val_dataset: SegmentationDataset
        self.test_dataset: SegmentationDataset

    def _preprocess_mask(self, mask: np.ndarray) -> np.ndarray:
        """Binarize mask using 0 as threshold."""
        mask = (mask > 0).astype(np.uint8)
        return mask

    @staticmethod
    def _resolve_label(path: str) -> int:
        """Resolve label from mask.

        Args:
            path: Path to the mask.

        Returns:
            0 if the mask is empty, 1 otherwise.
        """
        if cv2.imread(path).sum() == 0:
            return 0

        return 1

    def _read_folder(self, data_path: str) -> tuple[list[str], list[int], list[str]]:
        """Read a folder containing images and masks subfolders.

        Args:
            data_path: Path to the data folder.

        Returns:
            List of paths to the images, associated binary targets and list of paths to the masks.
        """
        samples = []
        targets = []
        masks = []

        for im in glob.glob(os.path.join(data_path, "images", "*")):
            # Skip hidden files
            if os.path.basename(im).startswith("."):
                continue

            mask_path = glob.glob(os.path.splitext(im.replace("images", "masks"))[0] + ".*")

            if len(mask_path) == 0:
                log.debug("Mask not found: %s", os.path.basename(im))
                continue

            if len(mask_path) > 1:
                raise ValueError(f"Multiple masks found for image: {os.path.basename(im)}, this is not supported")

            target = self._resolve_label(mask_path[0])
            samples.append(im)
            targets.append(target)
            masks.append(mask_path[0])

        return samples, targets, masks

    def _read_split(self, split_file: str) -> tuple[list[str], list[int], list[str]]:
        """Read a split file.

        Args:
            split_file: Path to the split file.

        Returns:
            List of paths to the images, associated binary targets and list of paths to the masks.
        """
        samples, targets, masks = [], [], []
        with open(split_file) as f:
            split = f.read().splitlines()
        for sample in split:
            sample_path = os.path.join(self.data_path, sample)
            mask_path = glob.glob(os.path.splitext(sample_path.replace("images", "masks"))[0] + ".*")

            if len(mask_path) == 0:
                log.debug("Mask not found: %s", os.path.basename(sample_path))
                continue

            if len(mask_path) > 1:
                raise ValueError(
                    f"Multiple masks found for image: {os.path.basename(sample_path)}, this is not supported"
                )

            target = self._resolve_label(mask_path[0])
            samples.append(sample_path)
            targets.append(target)
            masks.append(mask_path[0])

        return samples, targets, masks

    def _prepare_data(self) -> None:
        """Prepare data for training and testing."""
        if not (self.test_split_file and self.train_split_file and self.val_split_file):
            all_samples, all_targets, all_masks = self._read_folder(self.data_path)
            samples_train, samples_test, targets_train, targets_test, masks_train, masks_test = train_test_split(
                all_samples,
                all_targets,
                all_masks,
                test_size=self.test_size,
                random_state=self.seed,
                stratify=all_targets,
            )
        if self.test_split_file:
            samples_test, targets_test, masks_test = self._read_split(self.test_split_file)
            if not self.train_split_file:
                samples_train, targets_train, masks_train = [], [], []
                for sample, target, mask in zip(all_samples, all_targets, all_masks):
                    if sample not in samples_test:
                        samples_train.append(sample)
                        targets_train.append(target)
                        masks_train.append(mask)

        if self.train_split_file:
            samples_train, targets_train, masks_train = self._read_split(self.train_split_file)
            if not self.test_split_file:
                samples_test, targets_test, masks_test = [], [], []
                for sample, target, mask in zip(all_samples, all_targets, all_masks):
                    if sample not in samples_train:
                        samples_test.append(sample)
                        targets_test.append(target)
                        masks_test.append(mask)

        if self.val_split_file:
            if not self.test_split_file or not self.train_split_file:
                raise ValueError("Validation split file is specified but no train or test split file is specified.")
            samples_val, targets_val, masks_val = self._read_split(self.val_split_file)
        else:
            samples_train, samples_val, targets_train, targets_val, masks_train, masks_val = train_test_split(
                samples_train,
                targets_train,
                masks_train,
                test_size=self.val_size,
                random_state=self.seed,
                stratify=targets_train,
            )

        if self.exclude_good:
            samples_train = list(np.array(samples_train)[np.array(targets_train) != 0])
            masks_train = list(np.array(masks_train)[np.array(targets_train) != 0])
            targets_train = list(np.array(targets_train)[np.array(targets_train) != 0])

        if self.num_data_class is not None:
            samples_train_topick = []
            targets_train_topick = []
            masks_train_topick = []

            for cl in np.unique(targets_train):
                idx = np.where(np.array(targets_train) == cl)[0].tolist()
                random.seed(self.seed)
                random.shuffle(idx)
                to_pick = idx[: self.num_data_class]
                for i in to_pick:
                    samples_train_topick.append(samples_train[i])
                    targets_train_topick.append(cl)
                    masks_train_topick.append(masks_train[i])

            samples_train = samples_train_topick
            targets_train = targets_train_topick
            masks_train = masks_train_topick

        df_list = []
        for split_name, samples, targets, masks in [
            ("train", samples_train, targets_train, masks_train),
            ("val", samples_val, targets_val, masks_val),
            ("test", samples_test, targets_test, masks_test),
        ]:
            df = pd.DataFrame({"samples": samples, "targets": targets, "masks": masks})
            df["split"] = split_name
            df_list.append(df)

        self.data = pd.concat(df_list, axis=0)

    def setup(self, stage=None):
        """Setup data module based on stages of training."""
        if stage in ["fit", "train"]:
            self.train_dataset = self.dataset(
                image_paths=self.data[self.data["split"] == "train"]["samples"].tolist(),
                mask_paths=self.data[self.data["split"] == "train"]["masks"].tolist(),
                mask_preprocess=self._preprocess_mask,
                labels=self.data[self.data["split"] == "train"]["targets"].tolist(),
                object_masks=None,
                transform=self.train_transform,
                batch_size=None,
                defect_transform=None,
                resize=None,
            )
            self.val_dataset = self.dataset(
                image_paths=self.data[self.data["split"] == "val"]["samples"].tolist(),
                mask_paths=self.data[self.data["split"] == "val"]["masks"].tolist(),
                defect_transform=None,
                labels=self.data[self.data["split"] == "val"]["targets"].tolist(),
                object_masks=None,
                batch_size=None,
                mask_preprocess=self._preprocess_mask,
                transform=self.test_transform,  # validation uses the test-time transform
                resize=None,
            )
        elif stage == "test":
            self.test_dataset = self.dataset(
                image_paths=self.data[self.data["split"] == "test"]["samples"].tolist(),
                mask_paths=self.data[self.data["split"] == "test"]["masks"].tolist(),
                labels=self.data[self.data["split"] == "test"]["targets"].tolist(),
                object_masks=None,
                batch_size=None,
                mask_preprocess=self._preprocess_mask,
                transform=self.test_transform,
                resize=None,
            )
        elif stage == "predict":
            pass
        else:
            raise ValueError(f"Unknown stage {stage}")

    def train_dataloader(self) -> DataLoader:
        """Returns the train dataloader.

        Raises:
            ValueError: If train dataset is not initialized.

        Returns:
            Train dataloader.
        """
        if not self.train_dataset_available:
            raise ValueError("Train dataset is not initialized")

        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            drop_last=False,
            pin_memory=True,
            persistent_workers=self.num_workers > 0,
        )

    def val_dataloader(self) -> DataLoader:
        """Returns the validation dataloader.

        Raises:
            ValueError: If validation dataset is not initialized.

        Returns:
            Validation dataloader.
        """
        if not self.val_dataset_available:
            raise ValueError("Validation dataset is not initialized")

        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            drop_last=False,
            pin_memory=True,
            persistent_workers=self.num_workers > 0,
        )

    def test_dataloader(self) -> DataLoader:
        """Returns the test dataloader.

        Raises:
            ValueError: If test dataset is not initialized.

        Returns:
            Test dataloader.
        """
        if not self.test_dataset_available:
            raise ValueError("Test dataset is not initialized")

        loader = DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            drop_last=False,
            pin_memory=True,
            persistent_workers=self.num_workers > 0,
        )
        return loader

    def predict_dataloader(self) -> DataLoader:
        """Returns a dataloader used for predictions."""
        return self.test_dataloader()
```
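For orientation, here is a minimal usage sketch of the new `SegmentationDataModule`. The `data` path and file names are hypothetical, and the `prepare_data()`/`setup()` calls assume the standard Lightning datamodule lifecycle inherited through `BaseDataModule`; the on-disk layout with matching stems under `images/` and `masks/` is what `_read_folder` expects.

```python
from quadra.datamodules.segmentation import SegmentationDataModule

# Hypothetical dataset layout expected by _read_folder:
#   data/images/part_001.png, data/images/part_002.png, ...
#   data/masks/part_001.png,  data/masks/part_002.png,  ...
# Masks share the image stem (any extension); an all-zero mask gives label 0, anything else label 1.

datamodule = SegmentationDataModule(
    data_path="data",  # hypothetical path
    batch_size=8,
    num_workers=0,
    test_size=0.2,
    val_size=0.2,
)
datamodule.prepare_data()  # assumed to invoke _prepare_data() through BaseDataModule
datamodule.setup(stage="fit")
batch = next(iter(datamodule.train_dataloader()))
```

The hunk then continues with the multiclass variant: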
```python
class SegmentationMulticlassDataModule(BaseDataModule):
    """Base class for segmentation datasets with multiple classes.

    Args:
        data_path: Path to the data main folder.
        idx_to_class: Dict mapping mask indices to class names: {1: class_1, 2: class_2, ..., N: class_N},
            excluding the background class, which is 0.
        name: The name for the data module. Defaults to "multiclass_segmentation_datamodule".
        dataset: Dataset class.
        batch_size: Batch size. Defaults to 32.
        test_size: The test split. Defaults to 0.3.
        val_size: The validation split. Defaults to 0.3.
        seed: Random generator seed. Defaults to 42.
        num_workers: Number of workers for dataloaders. Defaults to 6.
        train_transform: Transformations for train dataset. Defaults to None.
        test_transform: Transformations for test dataset. Defaults to None.
        val_transform: Transformations for validation dataset. Defaults to None.
        train_split_file: Path to a txt file with the training samples list.
        test_split_file: Path to a txt file with the test samples list.
        val_split_file: Path to a txt file with the validation samples list.
        exclude_good: If True, exclude good samples from the dataset. Defaults to False.
        num_data_train: Number of samples to use in the train split (shuffle the samples and pick the
            first num_data_train).
        one_hot_encoding: If True, the labels are one-hot encoded to N channels, where N is the number of classes.
            If False, masks are single channel with values used as class indices. Defaults to False.
    """

    def __init__(
        self,
        data_path: str,
        idx_to_class: dict,
        name: str = "multiclass_segmentation_datamodule",
        dataset: type[SegmentationDatasetMulticlass] = SegmentationDatasetMulticlass,
        batch_size: int = 32,
        test_size: float = 0.3,
        val_size: float = 0.3,
        seed: int = 42,
        num_workers: int = 6,
        train_transform: albumentations.Compose | None = None,
        test_transform: albumentations.Compose | None = None,
        val_transform: albumentations.Compose | None = None,
        train_split_file: str | None = None,
        test_split_file: str | None = None,
        val_split_file: str | None = None,
        exclude_good: bool = False,
        num_data_train: int | None = None,
        one_hot_encoding: bool = False,
        **kwargs: Any,
    ):
        super().__init__(
            data_path=data_path,
            name=name,
            seed=seed,
            batch_size=batch_size,
            num_workers=num_workers,
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            **kwargs,
        )
        self.test_size = test_size
        self.val_size = val_size
        self.exclude_good = exclude_good
        self.train_split_file = train_split_file
        self.test_split_file = test_split_file
        self.val_split_file = val_split_file
        self.dataset = dataset
        self.idx_to_class = idx_to_class
        self.num_data_train = num_data_train
        self.one_hot_encoding = one_hot_encoding
        self.train_dataset: SegmentationDatasetMulticlass
        self.val_dataset: SegmentationDatasetMulticlass
        self.test_dataset: SegmentationDatasetMulticlass

    def _preprocess_mask(self, mask: np.ndarray) -> np.ndarray:
        """Preprocess the mask.

        Args:
            mask: A numpy array of shape HxW with values in [0] + self.idx_to_class.

        Returns:
            A binary numpy array of shape (len(self.idx_to_class) + 1)xHxW.
        """
        # For each class we must have a channel
        multilayer_mask = np.zeros((len(self.idx_to_class) + 1, *mask.shape[:2]))
        for idx in self.idx_to_class:
            multilayer_mask[int(idx)] = (mask == int(idx)).astype(np.uint8)

        return multilayer_mask

    def _resolve_label(self, path: str) -> np.ndarray:
        """Return a binary array of length 1 + len(self.idx_to_class) with 1 if that class is present in the mask."""
        one_hot = np.zeros([len(self.idx_to_class) + 1], np.uint8)  # add class 0
        mask = cv2.imread(path, 0)
        if mask.sum() == 0:
            one_hot[0] = 1
        else:
            indices = np.unique(mask)
            one_hot[indices] = 1
            one_hot[0] = 0

        return one_hot

    def _read_folder(self, data_path: str) -> tuple[list[str], list[np.ndarray], list[str]]:
        """Read a folder containing images and masks subfolders.

        Args:
            data_path: Path to the data folder.

        Returns:
            List of paths to the images, list of associated one-hot encoded targets and list of mask paths.
        """
        samples = []
        targets = []
        masks = []

        for im in glob.glob(os.path.join(data_path, "images", "*")):
            # Skip hidden files
            if os.path.basename(im).startswith("."):
                continue

            mask_path = glob.glob(os.path.splitext(im.replace("images", "masks"))[0] + ".*")

            if len(mask_path) == 0:
                log.debug("Mask not found: %s", os.path.basename(im))
                continue

            if len(mask_path) > 1:
                raise ValueError(f"Multiple masks found for image: {os.path.basename(im)}, this is not supported")

            target = self._resolve_label(mask_path[0])
            samples.append(im)
            targets.append(target)
            masks.append(mask_path[0])

        return samples, targets, masks

    def _read_split(self, split_file: str) -> tuple[list[str], list[np.ndarray], list[str]]:
        """Read a split file.

        Args:
            split_file: Path to the split file.

        Returns:
            List of paths to the images, associated labels and mask paths.
        """
        samples, targets, masks = [], [], []
        with open(split_file) as f:
            split = f.read().splitlines()
        for sample in split:
            sample_path = os.path.join(self.data_path, sample)
            mask_path = glob.glob(os.path.splitext(sample_path.replace("images", "masks"))[0] + ".*")

            if len(mask_path) == 0:
                log.debug("Mask not found: %s", os.path.basename(sample_path))
                continue

            if len(mask_path) > 1:
                raise ValueError(
                    f"Multiple masks found for image: {os.path.basename(sample_path)}, this is not supported"
                )

            target = self._resolve_label(mask_path[0])
            samples.append(sample_path)
            targets.append(target)
            masks.append(mask_path[0])

        return samples, targets, masks

    def _prepare_data(self) -> None:
        """Prepare data for training and testing."""
        if not (self.train_split_file and self.test_split_file and self.val_split_file):
            all_samples, all_targets, all_masks = self._read_folder(self.data_path)

            (
                samples_and_masks_train,
                targets_train,
                samples_and_masks_test,
                targets_test,
            ) = iterative_train_test_split(
                np.expand_dims(np.array(list(zip(all_samples, all_masks))), 1),
                np.array(all_targets),
                test_size=self.test_size,
            )

            samples_train, samples_test = samples_and_masks_train[:, 0, 0], samples_and_masks_test[:, 0, 0]
            masks_train, masks_test = samples_and_masks_train[:, 0, 1], samples_and_masks_test[:, 0, 1]

        if self.test_split_file:
            samples_test, targets_test, masks_test = self._read_split(self.test_split_file)
            if not self.train_split_file:
                samples_train, targets_train, masks_train = [], [], []
                for sample, target, mask in zip(all_samples, all_targets, all_masks):
                    if sample not in samples_test:
                        samples_train.append(sample)
                        targets_train.append(target)
                        masks_train.append(mask)

        if self.train_split_file:
            samples_train, targets_train, masks_train = self._read_split(self.train_split_file)
            if not self.test_split_file:
                samples_test, targets_test, masks_test = [], [], []
                for sample, target, mask in zip(all_samples, all_targets, all_masks):
                    if sample not in samples_train:
                        samples_test.append(sample)
                        targets_test.append(target)
                        masks_test.append(mask)

        if self.val_split_file:
            if not self.test_split_file or not self.train_split_file:
                raise ValueError("Validation split file is specified but no train or test split file is specified.")
            samples_val, targets_val, masks_val = self._read_split(self.val_split_file)
        else:
            samples_and_masks_train, targets_train, samples_and_masks_val, targets_val = iterative_train_test_split(
                np.expand_dims(np.array(list(zip(samples_train, masks_train))), 1),
                np.array(targets_train),
                test_size=self.val_size,
            )
            samples_train = samples_and_masks_train[:, 0, 0]
            samples_val = samples_and_masks_val[:, 0, 0]
            masks_train = samples_and_masks_train[:, 0, 1]
            masks_val = samples_and_masks_val[:, 0, 1]

        # Pre-ordering train and val samples for determinism
        # They will be shuffled (with a seed) during training
        sorting_indices_train = np.argsort(list(samples_train))
        samples_train = [samples_train[i] for i in sorting_indices_train]
        targets_train = [targets_train[i] for i in sorting_indices_train]
        masks_train = [masks_train[i] for i in sorting_indices_train]

        sorting_indices_val = np.argsort(samples_val)
        samples_val = [samples_val[i] for i in sorting_indices_val]
        targets_val = [targets_val[i] for i in sorting_indices_val]
        masks_val = [masks_val[i] for i in sorting_indices_val]

        if self.exclude_good:
            samples_train = list(np.array(samples_train)[np.array(targets_train)[:, 0] == 0])
            masks_train = list(np.array(masks_train)[np.array(targets_train)[:, 0] == 0])
            targets_train = list(np.array(targets_train)[np.array(targets_train)[:, 0] == 0])

        if self.num_data_train is not None:
            # Generate a random permutation
            random_permutation = list(range(len(samples_train)))
            random.seed(self.seed)
            random.shuffle(random_permutation)

            # Shuffle samples_train, targets_train, and masks_train using the same permutation
            samples_train = [samples_train[i] for i in random_permutation]
            targets_train = [targets_train[i] for i in random_permutation]
            masks_train = [masks_train[i] for i in random_permutation]

            samples_train = np.array(samples_train)[: self.num_data_train]
            targets_train = np.array(targets_train)[: self.num_data_train]
            masks_train = np.array(masks_train)[: self.num_data_train]

        df_list = []
        for split_name, samples, targets, masks in [
            ("train", samples_train, targets_train, masks_train),
            ("val", samples_val, targets_val, masks_val),
            ("test", samples_test, targets_test, masks_test),
        ]:
            df = pd.DataFrame({"samples": samples, "targets": list(targets), "masks": masks})
            df["split"] = split_name
            df_list.append(df)

        self.data = pd.concat(df_list, axis=0)

    def setup(self, stage=None):
        """Setup data module based on stages of training."""
        if stage in ["fit", "train"]:
            train_data = self.data[self.data["split"] == "train"]
            val_data = self.data[self.data["split"] == "val"]

            self.train_dataset = self.dataset(
                image_paths=train_data["samples"].tolist(),
                mask_paths=train_data["masks"].tolist(),
                idx_to_class=self.idx_to_class,
                transform=self.train_transform,
                one_hot=self.one_hot_encoding,
            )
            self.val_dataset = self.dataset(
                image_paths=val_data["samples"].tolist(),
                mask_paths=val_data["masks"].tolist(),
                transform=self.val_transform,
                idx_to_class=self.idx_to_class,
                one_hot=self.one_hot_encoding,
            )
        elif stage == "test":
            self.test_dataset = self.dataset(
                image_paths=self.data[self.data["split"] == "test"]["samples"].tolist(),
                mask_paths=self.data[self.data["split"] == "test"]["masks"].tolist(),
                transform=self.test_transform,
                idx_to_class=self.idx_to_class,
                one_hot=self.one_hot_encoding,
            )
        elif stage == "predict":
            pass
        else:
            raise ValueError(f"Unknown stage {stage}")

    def train_dataloader(self) -> DataLoader:
        """Returns the train dataloader.

        Raises:
            ValueError: If train dataset is not initialized.

        Returns:
            Train dataloader.
        """
        if not self.train_dataset_available:
            raise ValueError("Train dataset is not initialized")

        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            drop_last=False,
            pin_memory=True,
            persistent_workers=self.num_workers > 0,
        )

    def val_dataloader(self) -> DataLoader:
        """Returns the validation dataloader.

        Raises:
            ValueError: If validation dataset is not initialized.

        Returns:
            Validation dataloader.
        """
        if not self.val_dataset_available:
            raise ValueError("Validation dataset is not initialized")

        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            drop_last=False,
            pin_memory=True,
            persistent_workers=self.num_workers > 0,
        )

    def test_dataloader(self) -> DataLoader:
        """Returns the test dataloader.

        Raises:
            ValueError: If test dataset is not initialized.

        Returns:
            Test dataloader.
        """
        if not self.test_dataset_available:
            raise ValueError("Test dataset is not initialized")

        loader = DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            drop_last=False,
            pin_memory=True,
            persistent_workers=self.num_workers > 0,
        )
        return loader

    def predict_dataloader(self) -> DataLoader:
        """Returns a dataloader used for predictions."""
        return self.test_dataloader()
```
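A similar hedged sketch for `SegmentationMulticlassDataModule`; the data path and `idx_to_class` mapping are hypothetical. The toy array at the end reproduces the channel expansion that `_preprocess_mask` applies: one channel per class index, plus a background channel at index 0.

```python
import numpy as np

from quadra.datamodules.segmentation import SegmentationMulticlassDataModule

datamodule = SegmentationMulticlassDataModule(
    data_path="data",  # hypothetical path
    idx_to_class={1: "scratch", 2: "dent"},  # hypothetical classes; 0 is background
    batch_size=8,
    num_workers=0,
)

# Channel expansion as performed by _preprocess_mask, shown on a toy 2x2 mask
mask = np.array([[0, 1], [2, 1]])  # HxW mask holding class indices
multilayer = np.zeros((len(datamodule.idx_to_class) + 1, *mask.shape))
for idx in datamodule.idx_to_class:
    multilayer[int(idx)] = (mask == int(idx)).astype(np.uint8)
assert multilayer[1].sum() == 2 and multilayer[2].sum() == 1
```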