quadra 0.0.1__py3-none-any.whl → 2.1.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hydra_plugins/quadra_searchpath_plugin.py +30 -0
- quadra/__init__.py +6 -0
- quadra/callbacks/__init__.py +0 -0
- quadra/callbacks/anomalib.py +289 -0
- quadra/callbacks/lightning.py +501 -0
- quadra/callbacks/mlflow.py +291 -0
- quadra/callbacks/scheduler.py +69 -0
- quadra/configs/__init__.py +0 -0
- quadra/configs/backbone/caformer_m36.yaml +8 -0
- quadra/configs/backbone/caformer_s36.yaml +8 -0
- quadra/configs/backbone/convnextv2_base.yaml +8 -0
- quadra/configs/backbone/convnextv2_femto.yaml +8 -0
- quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
- quadra/configs/backbone/dino_vitb8.yaml +12 -0
- quadra/configs/backbone/dino_vits8.yaml +12 -0
- quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
- quadra/configs/backbone/dinov2_vits14.yaml +12 -0
- quadra/configs/backbone/efficientnet_b0.yaml +8 -0
- quadra/configs/backbone/efficientnet_b1.yaml +8 -0
- quadra/configs/backbone/efficientnet_b2.yaml +8 -0
- quadra/configs/backbone/efficientnet_b3.yaml +8 -0
- quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
- quadra/configs/backbone/levit_128s.yaml +8 -0
- quadra/configs/backbone/mnasnet0_5.yaml +9 -0
- quadra/configs/backbone/resnet101.yaml +8 -0
- quadra/configs/backbone/resnet18.yaml +8 -0
- quadra/configs/backbone/resnet18_ssl.yaml +8 -0
- quadra/configs/backbone/resnet50.yaml +8 -0
- quadra/configs/backbone/smp.yaml +9 -0
- quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
- quadra/configs/backbone/unetr.yaml +15 -0
- quadra/configs/backbone/vit16_base.yaml +9 -0
- quadra/configs/backbone/vit16_small.yaml +9 -0
- quadra/configs/backbone/vit16_tiny.yaml +9 -0
- quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
- quadra/configs/callbacks/all.yaml +32 -0
- quadra/configs/callbacks/default.yaml +37 -0
- quadra/configs/callbacks/default_anomalib.yaml +67 -0
- quadra/configs/config.yaml +33 -0
- quadra/configs/core/default.yaml +11 -0
- quadra/configs/datamodule/base/anomaly.yaml +16 -0
- quadra/configs/datamodule/base/classification.yaml +21 -0
- quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
- quadra/configs/datamodule/base/segmentation.yaml +18 -0
- quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
- quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
- quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
- quadra/configs/datamodule/base/ssl.yaml +21 -0
- quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
- quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
- quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
- quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
- quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
- quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
- quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
- quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
- quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
- quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
- quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/base/classification/classification.yaml +73 -0
- quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
- quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
- quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
- quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
- quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
- quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
- quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
- quadra/configs/experiment/base/ssl/byol.yaml +43 -0
- quadra/configs/experiment/base/ssl/dino.yaml +46 -0
- quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
- quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
- quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
- quadra/configs/experiment/custom/cls.yaml +12 -0
- quadra/configs/experiment/default.yaml +15 -0
- quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
- quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
- quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
- quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
- quadra/configs/export/default.yaml +13 -0
- quadra/configs/hydra/anomaly_custom.yaml +15 -0
- quadra/configs/hydra/default.yaml +14 -0
- quadra/configs/inference/default.yaml +26 -0
- quadra/configs/logger/comet.yaml +10 -0
- quadra/configs/logger/csv.yaml +5 -0
- quadra/configs/logger/mlflow.yaml +12 -0
- quadra/configs/logger/tensorboard.yaml +8 -0
- quadra/configs/loss/asl.yaml +7 -0
- quadra/configs/loss/barlow.yaml +2 -0
- quadra/configs/loss/bce.yaml +1 -0
- quadra/configs/loss/byol.yaml +1 -0
- quadra/configs/loss/cross_entropy.yaml +1 -0
- quadra/configs/loss/dino.yaml +8 -0
- quadra/configs/loss/simclr.yaml +2 -0
- quadra/configs/loss/simsiam.yaml +1 -0
- quadra/configs/loss/smp_ce.yaml +3 -0
- quadra/configs/loss/smp_dice.yaml +2 -0
- quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
- quadra/configs/loss/smp_mcc.yaml +2 -0
- quadra/configs/loss/vicreg.yaml +5 -0
- quadra/configs/model/anomalib/cfa.yaml +35 -0
- quadra/configs/model/anomalib/cflow.yaml +30 -0
- quadra/configs/model/anomalib/csflow.yaml +34 -0
- quadra/configs/model/anomalib/dfm.yaml +19 -0
- quadra/configs/model/anomalib/draem.yaml +29 -0
- quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
- quadra/configs/model/anomalib/fastflow.yaml +32 -0
- quadra/configs/model/anomalib/padim.yaml +32 -0
- quadra/configs/model/anomalib/patchcore.yaml +36 -0
- quadra/configs/model/barlow.yaml +16 -0
- quadra/configs/model/byol.yaml +25 -0
- quadra/configs/model/classification.yaml +10 -0
- quadra/configs/model/dino.yaml +26 -0
- quadra/configs/model/logistic_regression.yaml +4 -0
- quadra/configs/model/multilabel_classification.yaml +9 -0
- quadra/configs/model/simclr.yaml +18 -0
- quadra/configs/model/simsiam.yaml +24 -0
- quadra/configs/model/smp.yaml +4 -0
- quadra/configs/model/smp_multiclass.yaml +4 -0
- quadra/configs/model/vicreg.yaml +16 -0
- quadra/configs/optimizer/adam.yaml +5 -0
- quadra/configs/optimizer/adamw.yaml +3 -0
- quadra/configs/optimizer/default.yaml +4 -0
- quadra/configs/optimizer/lars.yaml +8 -0
- quadra/configs/optimizer/sgd.yaml +4 -0
- quadra/configs/scheduler/default.yaml +5 -0
- quadra/configs/scheduler/rop.yaml +5 -0
- quadra/configs/scheduler/step.yaml +3 -0
- quadra/configs/scheduler/warmrestart.yaml +2 -0
- quadra/configs/scheduler/warmup.yaml +6 -0
- quadra/configs/task/anomalib/cfa.yaml +5 -0
- quadra/configs/task/anomalib/cflow.yaml +5 -0
- quadra/configs/task/anomalib/csflow.yaml +5 -0
- quadra/configs/task/anomalib/draem.yaml +5 -0
- quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
- quadra/configs/task/anomalib/fastflow.yaml +5 -0
- quadra/configs/task/anomalib/inference.yaml +3 -0
- quadra/configs/task/anomalib/padim.yaml +5 -0
- quadra/configs/task/anomalib/patchcore.yaml +5 -0
- quadra/configs/task/classification.yaml +6 -0
- quadra/configs/task/classification_evaluation.yaml +6 -0
- quadra/configs/task/default.yaml +1 -0
- quadra/configs/task/segmentation.yaml +9 -0
- quadra/configs/task/segmentation_evaluation.yaml +3 -0
- quadra/configs/task/sklearn_classification.yaml +13 -0
- quadra/configs/task/sklearn_classification_patch.yaml +11 -0
- quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
- quadra/configs/task/sklearn_classification_test.yaml +8 -0
- quadra/configs/task/ssl.yaml +2 -0
- quadra/configs/trainer/lightning_cpu.yaml +36 -0
- quadra/configs/trainer/lightning_gpu.yaml +35 -0
- quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
- quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
- quadra/configs/trainer/lightning_multigpu.yaml +37 -0
- quadra/configs/trainer/sklearn_classification.yaml +7 -0
- quadra/configs/transforms/byol.yaml +47 -0
- quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
- quadra/configs/transforms/default.yaml +37 -0
- quadra/configs/transforms/default_numpy.yaml +24 -0
- quadra/configs/transforms/default_resize.yaml +22 -0
- quadra/configs/transforms/dino.yaml +63 -0
- quadra/configs/transforms/linear_eval.yaml +18 -0
- quadra/datamodules/__init__.py +20 -0
- quadra/datamodules/anomaly.py +180 -0
- quadra/datamodules/base.py +375 -0
- quadra/datamodules/classification.py +1003 -0
- quadra/datamodules/generic/__init__.py +0 -0
- quadra/datamodules/generic/imagenette.py +144 -0
- quadra/datamodules/generic/mnist.py +81 -0
- quadra/datamodules/generic/mvtec.py +58 -0
- quadra/datamodules/generic/oxford_pet.py +163 -0
- quadra/datamodules/patch.py +190 -0
- quadra/datamodules/segmentation.py +742 -0
- quadra/datamodules/ssl.py +140 -0
- quadra/datasets/__init__.py +17 -0
- quadra/datasets/anomaly.py +287 -0
- quadra/datasets/classification.py +241 -0
- quadra/datasets/patch.py +138 -0
- quadra/datasets/segmentation.py +239 -0
- quadra/datasets/ssl.py +110 -0
- quadra/losses/__init__.py +0 -0
- quadra/losses/classification/__init__.py +6 -0
- quadra/losses/classification/asl.py +83 -0
- quadra/losses/classification/focal.py +320 -0
- quadra/losses/classification/prototypical.py +148 -0
- quadra/losses/ssl/__init__.py +17 -0
- quadra/losses/ssl/barlowtwins.py +47 -0
- quadra/losses/ssl/byol.py +37 -0
- quadra/losses/ssl/dino.py +129 -0
- quadra/losses/ssl/hyperspherical.py +45 -0
- quadra/losses/ssl/idmm.py +50 -0
- quadra/losses/ssl/simclr.py +67 -0
- quadra/losses/ssl/simsiam.py +30 -0
- quadra/losses/ssl/vicreg.py +76 -0
- quadra/main.py +46 -0
- quadra/metrics/__init__.py +3 -0
- quadra/metrics/segmentation.py +251 -0
- quadra/models/__init__.py +0 -0
- quadra/models/base.py +151 -0
- quadra/models/classification/__init__.py +8 -0
- quadra/models/classification/backbones.py +149 -0
- quadra/models/classification/base.py +92 -0
- quadra/models/evaluation.py +322 -0
- quadra/modules/__init__.py +0 -0
- quadra/modules/backbone.py +30 -0
- quadra/modules/base.py +312 -0
- quadra/modules/classification/__init__.py +3 -0
- quadra/modules/classification/base.py +331 -0
- quadra/modules/ssl/__init__.py +17 -0
- quadra/modules/ssl/barlowtwins.py +59 -0
- quadra/modules/ssl/byol.py +172 -0
- quadra/modules/ssl/common.py +285 -0
- quadra/modules/ssl/dino.py +186 -0
- quadra/modules/ssl/hyperspherical.py +206 -0
- quadra/modules/ssl/idmm.py +98 -0
- quadra/modules/ssl/simclr.py +73 -0
- quadra/modules/ssl/simsiam.py +68 -0
- quadra/modules/ssl/vicreg.py +67 -0
- quadra/optimizers/__init__.py +4 -0
- quadra/optimizers/lars.py +153 -0
- quadra/optimizers/sam.py +127 -0
- quadra/schedulers/__init__.py +3 -0
- quadra/schedulers/base.py +44 -0
- quadra/schedulers/warmup.py +127 -0
- quadra/tasks/__init__.py +24 -0
- quadra/tasks/anomaly.py +582 -0
- quadra/tasks/base.py +397 -0
- quadra/tasks/classification.py +1264 -0
- quadra/tasks/patch.py +492 -0
- quadra/tasks/segmentation.py +389 -0
- quadra/tasks/ssl.py +560 -0
- quadra/trainers/README.md +3 -0
- quadra/trainers/__init__.py +0 -0
- quadra/trainers/classification.py +179 -0
- quadra/utils/__init__.py +0 -0
- quadra/utils/anomaly.py +112 -0
- quadra/utils/classification.py +618 -0
- quadra/utils/deprecation.py +31 -0
- quadra/utils/evaluation.py +474 -0
- quadra/utils/export.py +579 -0
- quadra/utils/imaging.py +32 -0
- quadra/utils/logger.py +15 -0
- quadra/utils/mlflow.py +98 -0
- quadra/utils/model_manager.py +320 -0
- quadra/utils/models.py +524 -0
- quadra/utils/patch/__init__.py +15 -0
- quadra/utils/patch/dataset.py +1433 -0
- quadra/utils/patch/metrics.py +449 -0
- quadra/utils/patch/model.py +153 -0
- quadra/utils/patch/visualization.py +217 -0
- quadra/utils/resolver.py +42 -0
- quadra/utils/segmentation.py +31 -0
- quadra/utils/tests/__init__.py +0 -0
- quadra/utils/tests/fixtures/__init__.py +1 -0
- quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
- quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
- quadra/utils/tests/fixtures/dataset/classification.py +406 -0
- quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
- quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
- quadra/utils/tests/fixtures/models/__init__.py +3 -0
- quadra/utils/tests/fixtures/models/anomaly.py +89 -0
- quadra/utils/tests/fixtures/models/classification.py +45 -0
- quadra/utils/tests/fixtures/models/segmentation.py +33 -0
- quadra/utils/tests/helpers.py +70 -0
- quadra/utils/tests/models.py +27 -0
- quadra/utils/utils.py +525 -0
- quadra/utils/validator.py +115 -0
- quadra/utils/visualization.py +422 -0
- quadra/utils/vit_explainability.py +349 -0
- quadra-2.1.13.dist-info/LICENSE +201 -0
- quadra-2.1.13.dist-info/METADATA +386 -0
- quadra-2.1.13.dist-info/RECORD +300 -0
- {quadra-0.0.1.dist-info → quadra-2.1.13.dist-info}/WHEEL +1 -1
- quadra-2.1.13.dist-info/entry_points.txt +3 -0
- quadra-0.0.1.dist-info/METADATA +0 -14
- quadra-0.0.1.dist-info/RECORD +0 -4

quadra/datamodules/classification.py
@@ -0,0 +1,1003 @@
# pylint: disable=unsupported-assignment-operation,unsubscriptable-object
from __future__ import annotations

import os
import random
from collections.abc import Callable
from typing import Any

import albumentations
import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from skmultilearn.model_selection import iterative_train_test_split
from timm.data.readers.reader_image_folder import find_images_and_targets
from torch.utils.data import DataLoader

from quadra.datamodules.base import BaseDataModule
from quadra.datasets import ImageClassificationListDataset
from quadra.datasets.classification import MultilabelClassificationDataset
from quadra.utils import utils
from quadra.utils.classification import find_test_image, get_split, group_labels, natural_key

log = utils.get_logger(__name__)

class ClassificationDataModule(BaseDataModule):
    """Base class for single-folder classification datamodules. If there are no nested folders, use this class.

    Args:
        data_path: Path to the data main folder.
        name: The name for the data module. Defaults to "classification_datamodule".
        num_workers: Number of workers for dataloaders. Defaults to 8.
        batch_size: Batch size. Defaults to 32.
        seed: Random generator seed. Defaults to 42.
        dataset: Dataset class.
        val_size: The validation split. Defaults to 0.2.
        test_size: The test split. Defaults to 0.2.
        exclude_filter: The filter for excluding folders. Defaults to None.
        include_filter: The filter for including folders. Defaults to None.
        label_map: The mapping for labels. Defaults to None.
        num_data_class: The number of samples per class. Defaults to None.
        train_transform: Transformations for train dataset. Defaults to None.
        val_transform: Transformations for validation dataset. Defaults to None.
        test_transform: Transformations for test dataset. Defaults to None.
        train_split_file: The file with train split. Defaults to None.
        val_split_file: The file with validation split. Defaults to None.
        test_split_file: The file with test split. Defaults to None.
        class_to_idx: The mapping from class name to index. Defaults to None.
        **kwargs: Additional arguments for BaseDataModule.
    """

    def __init__(
        self,
        data_path: str,
        dataset: type[ImageClassificationListDataset] = ImageClassificationListDataset,
        name: str = "classification_datamodule",
        num_workers: int = 8,
        batch_size: int = 32,
        seed: int = 42,
        val_size: float | None = 0.2,
        test_size: float = 0.2,
        num_data_class: int | None = None,
        exclude_filter: list[str] | None = None,
        include_filter: list[str] | None = None,
        label_map: dict[str, Any] | None = None,
        load_aug_images: bool = False,
        aug_name: str | None = None,
        n_aug_to_take: int | None = 4,
        replace_str_from: str | None = None,
        replace_str_to: str | None = None,
        train_transform: albumentations.Compose | None = None,
        val_transform: albumentations.Compose | None = None,
        test_transform: albumentations.Compose | None = None,
        train_split_file: str | None = None,
        test_split_file: str | None = None,
        val_split_file: str | None = None,
        class_to_idx: dict[str, int] | None = None,
        **kwargs: Any,
    ):
        super().__init__(
            data_path=data_path,
            name=name,
            seed=seed,
            batch_size=batch_size,
            num_workers=num_workers,
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            load_aug_images=load_aug_images,
            aug_name=aug_name,
            n_aug_to_take=n_aug_to_take,
            replace_str_from=replace_str_from,
            replace_str_to=replace_str_to,
            **kwargs,
        )
        self.replace_str = None
        self.exclude_filter = exclude_filter
        self.include_filter = include_filter
        self.val_size = val_size
        self.test_size = test_size
        self.label_map = label_map
        self.num_data_class = num_data_class
        self.dataset = dataset
        self.train_split_file = train_split_file
        self.test_split_file = test_split_file
        self.val_split_file = val_split_file
        self.class_to_idx: dict[str, int] | None

        if class_to_idx is not None:
            self.class_to_idx = class_to_idx
            self.num_classes = len(self.class_to_idx)
        else:
            self.class_to_idx = self._find_classes_from_data_path(self.data_path)
            if self.class_to_idx is None:
                log.warning("Could not build a class_to_idx from the data_path subdirectories")
                self.num_classes = 0
            else:
                self.num_classes = len(self.class_to_idx)

    def _read_split(self, split_file: str) -> tuple[list[str], list[str]]:
        """Reads split file.

        Args:
            split_file: Path to the split file.

        Returns:
            Tuple containing the list of image paths and the list of targets.
        """
        samples, targets = [], []
        with open(split_file) as f:
            split = f.readlines()
        for row in split:
            csv_values = row.split(",")
            sample = str(",".join(csv_values[:-1])).strip()
            target = csv_values[-1].strip()
            sample_path = os.path.join(self.data_path, sample)
            if os.path.exists(sample_path):
                samples.append(sample_path)
                targets.append(target)
            else:
                continue
                # log.warning(f"{sample_path} does not exist")
        return samples, targets

    def _find_classes_from_data_path(self, data_path: str) -> dict[str, int] | None:
        """Given a data_path, build a class_to_idx from the subdirectories.

        Args:
            data_path: Path to the data main folder.

        Returns:
            The class_to_idx dictionary, or None if no valid subdirectory is found.
        """
        subdirectories = []

        # Check if the directory exists
        if os.path.exists(data_path) and os.path.isdir(data_path):
            # Iterate through the items in the directory
            for item in os.listdir(data_path):
                item_path = os.path.join(data_path, item)

                # Check if it's a directory and not starting with "."
                if (
                    os.path.isdir(item_path)
                    and not item.startswith(".")
                    # Check if there's at least one image file in the subdirectory
                    and any(
                        os.path.splitext(file)[1].lower().endswith(tuple(utils.IMAGE_EXTENSIONS))
                        for file in os.listdir(item_path)
                    )
                ):
                    subdirectories.append(item)

            if len(subdirectories) > 0:
                return {cl: idx for idx, cl in enumerate(sorted(subdirectories))}
            return None

        return None

    @staticmethod
    def _find_images_and_targets(
        root_folder: str, class_to_idx: dict[str, int] | None = None
    ) -> tuple[list[tuple[str, int]], dict[str, int]]:
        """Collects the samples from item folders."""
        images_and_targets, class_to_idx = find_images_and_targets(
            folder=root_folder, types=utils.IMAGE_EXTENSIONS, class_to_idx=class_to_idx
        )
        return images_and_targets, class_to_idx

    def _filter_images_and_targets(
        self, images_and_targets: list[tuple[str, int]], class_to_idx: dict[str, int]
    ) -> tuple[list[str], list[str]]:
        """Filters the images and targets."""
        samples: list[str] = []
        targets: list[str] = []
        idx_to_class = {v: k for k, v in class_to_idx.items()}
        images_and_targets = [(str(image_path), target) for image_path, target in images_and_targets]
        for image_path, target in images_and_targets:
            target_class = idx_to_class[target]
            if self.exclude_filter is not None and any(
                exclude_filter in image_path for exclude_filter in self.exclude_filter
            ):
                continue
            if self.include_filter is not None:
                if any(include_filter in image_path for include_filter in self.include_filter):
                    samples.append(str(image_path))
                    targets.append(target_class)
            else:
                samples.append(str(image_path))
                targets.append(target_class)
        return (
            samples,
            targets,
        )

    def _prepare_data(self) -> None:
        """Prepares Classification data for the data module."""
        images_and_targets, class_to_idx = self._find_images_and_targets(self.data_path, self.class_to_idx)
        all_samples, all_targets = self._filter_images_and_targets(images_and_targets, class_to_idx)
        if self.label_map is not None:
            all_targets, _ = group_labels(all_targets, self.label_map)

        samples_train: list[str] = []
        targets_train: list[str] = []
        samples_test: list[str] = []
        targets_test: list[str] = []
        samples_val: list[str] = []
        targets_val: list[str] = []

        if self.test_size < 1.0:
            samples_train, samples_test, targets_train, targets_test = train_test_split(
                all_samples,
                all_targets,
                test_size=self.test_size,
                random_state=self.seed,
                stratify=all_targets,
            )
            if self.test_split_file:
                samples_test, targets_test = self._read_split(self.test_split_file)
                if not self.train_split_file:
                    samples_train, targets_train = [], []
                    for sample, target in zip(all_samples, all_targets):
                        if sample not in samples_test:
                            samples_train.append(sample)
                            targets_train.append(target)
            if self.train_split_file:
                samples_train, targets_train = self._read_split(self.train_split_file)
                if not self.test_split_file:
                    samples_test, targets_test = [], []
                    for sample, target in zip(all_samples, all_targets):
                        if sample not in samples_train:
                            samples_test.append(sample)
                            targets_test.append(target)
            if self.val_split_file:
                samples_val, targets_val = self._read_split(self.val_split_file)
                if not self.test_split_file or not self.train_split_file:
                    raise ValueError("Validation split file is specified but no train or test split file is specified.")
            else:
                samples_train, samples_val, targets_train, targets_val = train_test_split(
                    samples_train,
                    targets_train,
                    test_size=self.val_size,
                    random_state=self.seed,
                    stratify=targets_train,
                )

            if self.num_data_class is not None:
                samples_train_topick = []
                targets_train_topick = []
                for cl in np.unique(targets_train):
                    idx = np.where(np.array(targets_train) == cl)[0]
                    random.seed(self.seed)
                    random.shuffle(idx)  # type: ignore[arg-type]
                    to_pick = idx[: self.num_data_class]
                    for i in to_pick:
                        samples_train_topick.append(samples_train[i])
                        targets_train_topick.append(cl)

                samples_train = samples_train_topick
                targets_train = targets_train_topick
        else:
            log.info("Test size is set to 1.0: all samples will be put in test-set")
            samples_test = all_samples
            targets_test = all_targets
        train_df = pd.DataFrame({"samples": samples_train, "targets": targets_train})
        train_df["split"] = "train"
        val_df = pd.DataFrame({"samples": samples_val, "targets": targets_val})
        val_df["split"] = "val"
        test_df = pd.DataFrame({"samples": samples_test, "targets": targets_test})
        test_df["split"] = "test"
        self.data = pd.concat([train_df, val_df, test_df], axis=0)

        # if self.load_aug_images:
        #     samples_train, targets_train = self.load_augmented_samples(
        #         samples_train, targets_train, self.replace_str, shuffle=True
        #     )
        #     samples_val, targets_val = self.load_augmented_samples(
        #         samples_val, targets_val, self.replace_str, shuffle=True
        #     )
        unique_targets = [str(t) for t in np.unique(targets_train)]
        if self.class_to_idx is None:
            sorted_targets = sorted(unique_targets, key=natural_key)
            class_to_idx = {c: idx for idx, c in enumerate(sorted_targets)}
            self.class_to_idx = class_to_idx
            log.info("Class_to_idx not provided in config, building it from targets: %s", class_to_idx)

        if len(unique_targets) == 0:
            log.warning("Unique_targets length is 0, training set is empty")
        else:
            if len(self.class_to_idx.keys()) != len(unique_targets):
                raise ValueError(
                    "The number of classes in the class_to_idx dictionary does not match the number of unique targets."
                    f" `class_to_idx`: {self.class_to_idx}, `unique_targets`: {unique_targets}"
                )
            if not all(c in unique_targets for c in self.class_to_idx):
                raise ValueError(
                    "The classes in the class_to_idx dictionary do not match the available unique targets in the"
                    f" dataset. `class_to_idx`: {self.class_to_idx}, `unique_targets`: {unique_targets}"
                )

    def setup(self, stage: str | None = None) -> None:
        """Setup data module based on stages of training."""
        if stage in ["train", "fit"]:
            self.train_dataset = self.dataset(
                samples=self.data[self.data["split"] == "train"]["samples"].tolist(),
                targets=self.data[self.data["split"] == "train"]["targets"].tolist(),
                transform=self.train_transform,
                class_to_idx=self.class_to_idx,
            )
            self.val_dataset = self.dataset(
                samples=self.data[self.data["split"] == "val"]["samples"].tolist(),
                targets=self.data[self.data["split"] == "val"]["targets"].tolist(),
                transform=self.val_transform,
                class_to_idx=self.class_to_idx,
            )
        if stage in ["test", "predict"]:
            self.test_dataset = self.dataset(
                samples=self.data[self.data["split"] == "test"]["samples"].tolist(),
                targets=self.data[self.data["split"] == "test"]["targets"].tolist(),
                transform=self.test_transform,
                class_to_idx=self.class_to_idx,
            )

    def train_dataloader(self) -> DataLoader:
        """Returns the train dataloader.

        Raises:
            ValueError: If train dataset is not initialized.

        Returns:
            Train dataloader.
        """
        if not self.train_dataset_available:
            raise ValueError("Train dataset is not initialized")
        if not isinstance(self.train_dataset, torch.utils.data.Dataset):
            raise ValueError("Train dataset has to be a single `torch.utils.data.Dataset` instance.")
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            drop_last=False,
            pin_memory=True,
            persistent_workers=self.num_workers > 0,
        )

    def val_dataloader(self) -> DataLoader:
        """Returns the validation dataloader.

        Raises:
            ValueError: If validation dataset is not initialized.

        Returns:
            Val dataloader.
        """
        if not self.val_dataset_available:
            raise ValueError("Validation dataset is not initialized")
        if not isinstance(self.val_dataset, torch.utils.data.Dataset):
            raise ValueError("Validation dataset has to be a single `torch.utils.data.Dataset` instance.")
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            drop_last=False,
            pin_memory=True,
            persistent_workers=self.num_workers > 0,
        )

    def test_dataloader(self) -> DataLoader:
        """Returns the test dataloader.

        Raises:
            ValueError: If test dataset is not initialized.

        Returns:
            Test dataloader.
        """
        if not self.test_dataset_available:
            raise ValueError("Test dataset is not initialized")

        loader = DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            drop_last=False,
            pin_memory=True,
            persistent_workers=self.num_workers > 0,
        )
        return loader

    def predict_dataloader(self) -> DataLoader:
        """Returns a dataloader used for predictions."""
        return self.test_dataloader()

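For orientation, a minimal usage sketch of ClassificationDataModule follows. The folder layout and paths are hypothetical, and it assumes BaseDataModule exposes a public prepare_data that delegates to the _prepare_data hook shown above, as is conventional for Lightning-style datamodules:

# Hypothetical layout: my_dataset/cats/*.png, my_dataset/dogs/*.png
from quadra.datamodules.classification import ClassificationDataModule

datamodule = ClassificationDataModule(
    data_path="my_dataset",  # hypothetical path
    batch_size=32,
    num_workers=8,
    val_size=0.2,
    test_size=0.2,
)
datamodule.prepare_data()      # builds the train/val/test dataframe
datamodule.setup(stage="fit")  # instantiates the train and validation datasets
train_loader = datamodule.train_dataloader()
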
class SklearnClassificationDataModule(BaseDataModule):
    """A generic Data Module for classification with frozen torch backbone and sklearn classifier.

    It can also handle k-fold cross validation.

    Args:
        name: The name for the data module. Defaults to "sklearn_classification_datamodule".
        data_path: Path to the images main folder.
        exclude_filter: List of string filters used to exclude images. If None, no filter will be applied.
        include_filter: List of string filters used to include images. Only images that satisfy at least one of
            the filters will be included.
        val_size: The validation split. Defaults to 0.2.
        class_to_idx: Dictionary mapping folder names to indices. Only files whose label is in the dictionary
            keys will be considered. If None, all files will be considered and a custom mapping is created.
        seed: Fixed seed for random operations.
        batch_size: Dimension of batches for the dataloaders.
        num_workers: Number of workers for the dataloaders.
        train_transform: Albumentations transformations for the training set.
        val_transform: Albumentations transformations for the validation set.
        test_transform: Albumentations transformations for the test set.
        roi: Optional cropping region.
        n_splits: Number of dataset subdivisions (default 1 -> train/test). Use a value >= 2 for cross validation.
        phase: Either train or test.
        cache: If True, disable shuffling in all dataloaders to enable feature caching.
        limit_training_data: If defined, each class will be downsampled to this number. It must be >= 2 to allow
            splitting.
        label_map: Dictionary mapping folder names to labels.
        train_split_file: Optional path to a csv file containing the train split samples.
        test_split_file: Optional path to a csv file containing the test split samples.
        **kwargs: Additional arguments for BaseDataModule.
    """

    def __init__(
        self,
        data_path: str,
        exclude_filter: list[str] | None = None,
        include_filter: list[str] | None = None,
        val_size: float = 0.2,
        class_to_idx: dict[str, int] | None = None,
        label_map: dict[str, Any] | None = None,
        seed: int = 42,
        batch_size: int = 32,
        num_workers: int = 6,
        train_transform: albumentations.Compose | None = None,
        val_transform: albumentations.Compose | None = None,
        test_transform: albumentations.Compose | None = None,
        roi: tuple[int, int, int, int] | None = None,
        n_splits: int = 1,
        phase: str = "train",
        cache: bool = False,
        limit_training_data: int | None = None,
        train_split_file: str | None = None,
        test_split_file: str | None = None,
        name: str = "sklearn_classification_datamodule",
        dataset: type[ImageClassificationListDataset] = ImageClassificationListDataset,
        **kwargs: Any,
    ):
        super().__init__(
            data_path=data_path,
            name=name,
            seed=seed,
            batch_size=batch_size,
            num_workers=num_workers,
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            **kwargs,
        )

        self.class_to_idx = class_to_idx
        self.roi = roi
        self.cache = cache
        self.limit_training_data = limit_training_data

        self.dataset = dataset
        self.phase = phase
        self.n_splits = n_splits
        self.train_split_file = train_split_file
        self.test_split_file = test_split_file
        self.exclude_filter = exclude_filter
        self.include_filter = include_filter
        self.val_size = val_size
        self.label_map = label_map
        self.full_dataset: ImageClassificationListDataset
        self.train_dataset: list[ImageClassificationListDataset]
        self.val_dataset: list[ImageClassificationListDataset]

    def _prepare_data(self) -> None:
        """Prepares the data for the data module."""
        assert os.path.isdir(self.data_path), f"Folder {self.data_path} does not exist."

        list_df = []
        if self.phase == "train":
            samples, targets, split_generator, self.class_to_idx = get_split(
                image_dir=self.data_path,
                exclude_filter=self.exclude_filter,
                include_filter=self.include_filter,
                test_size=self.val_size,
                random_state=self.seed,
                class_to_idx=self.class_to_idx,
                n_splits=self.n_splits,
                limit_training_data=self.limit_training_data,
                train_split_file=self.train_split_file,
                label_map=self.label_map,
            )

            for cv_idx, split in enumerate(split_generator):
                train_idx, val_idx = split
                train_val_df = pd.DataFrame({"samples": samples, "targets": targets})
                train_val_df["cv"] = 0
                train_val_df["split"] = "train"
                train_val_df.loc[val_idx, "split"] = "val"
                train_val_df.loc[train_idx, "cv"] = cv_idx
                train_val_df.loc[val_idx, "cv"] = cv_idx
                list_df.append(train_val_df)

        test_samples, test_targets = find_test_image(
            folder=self.data_path,
            exclude_filter=self.exclude_filter,
            include_filter=self.include_filter,
            test_split_file=self.test_split_file,
        )
        if self.label_map is not None:
            test_targets, _ = group_labels(test_targets, self.label_map)
        test_df = pd.DataFrame({"samples": test_samples, "targets": test_targets})
        test_df["split"] = "test"
        test_df["cv"] = np.nan

        list_df.append(test_df)
        self.data = pd.concat(list_df, axis=0)

    def setup(self, stage: str) -> None:
        """Setup data module based on stages of training."""
        if stage == "fit":
            self.train_dataset = []
            self.val_dataset = []

            for cv_idx in range(self.n_splits):
                cv_df = self.data[self.data["cv"] == cv_idx]
                train_samples = cv_df[cv_df["split"] == "train"]["samples"].tolist()
                train_targets = cv_df[cv_df["split"] == "train"]["targets"].tolist()
                val_samples = cv_df[cv_df["split"] == "val"]["samples"].tolist()
                val_targets = cv_df[cv_df["split"] == "val"]["targets"].tolist()
                self.train_dataset.append(
                    self.dataset(
                        class_to_idx=self.class_to_idx,
                        samples=train_samples,
                        targets=train_targets,
                        transform=self.train_transform,
                        roi=self.roi,
                    )
                )
                self.val_dataset.append(
                    self.dataset(
                        class_to_idx=self.class_to_idx,
                        samples=val_samples,
                        targets=val_targets,
                        transform=self.val_transform,
                        roi=self.roi,
                    )
                )
            all_samples = self.data[self.data["cv"] == 0]["samples"].tolist()
            all_targets = self.data[self.data["cv"] == 0]["targets"].tolist()
            self.full_dataset = self.dataset(
                class_to_idx=self.class_to_idx,
                samples=all_samples,
                targets=all_targets,
                transform=self.train_transform,
                roi=self.roi,
            )
        if stage == "test":
            test_samples = self.data[self.data["split"] == "test"]["samples"].tolist()
            test_targets = self.data[self.data["split"] == "test"]["targets"]
            self.test_dataset = self.dataset(
                class_to_idx=self.class_to_idx,
                samples=test_samples,
                targets=test_targets.tolist(),
                transform=self.test_transform,
                roi=self.roi,
                allow_missing_label=True,
            )

    def predict_dataloader(self) -> DataLoader:
        """Returns a dataloader used for predictions."""
        return self.test_dataloader()

    def train_dataloader(self) -> list[DataLoader]:
        """Returns a list of train dataloaders.

        Raises:
            ValueError: If train dataset is not initialized.

        Returns:
            List of train dataloaders.
        """
        if not self.train_dataset_available:
            raise ValueError("Train dataset is not initialized")

        loader = []
        for dataset in self.train_dataset:
            loader.append(
                DataLoader(
                    dataset,
                    batch_size=self.batch_size,
                    shuffle=not self.cache,
                    num_workers=self.num_workers,
                    drop_last=False,
                    pin_memory=True,
                )
            )
        return loader

    def val_dataloader(self) -> list[DataLoader]:
        """Returns a list of validation dataloaders.

        Raises:
            ValueError: If validation dataset is not initialized.

        Returns:
            List of validation dataloaders.
        """
        if not self.val_dataset_available:
            raise ValueError("Validation dataset is not initialized")

        loader = []
        for dataset in self.val_dataset:
            loader.append(
                DataLoader(
                    dataset,
                    batch_size=self.batch_size,
                    shuffle=False,
                    num_workers=self.num_workers,
                    drop_last=False,
                    pin_memory=True,
                )
            )

        return loader

    def test_dataloader(self) -> DataLoader:
        """Returns the test dataloader.

        Raises:
            ValueError: If test dataset is not initialized.

        Returns:
            Test dataloader.
        """
        if not self.test_dataset_available:
            raise ValueError("Test dataset is not initialized")

        loader = DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            drop_last=False,
            pin_memory=True,
            persistent_workers=self.num_workers > 0,
        )
        return loader

    def full_dataloader(self) -> DataLoader:
        """Return a dataloader to perform training on the entire dataset.

        Returns:
            Dataloader covering the entire dataset. This is useful to perform a final
            training on all available data after the evaluation phase.
        """
        if self.full_dataset is None:
            raise ValueError("Full dataset is not initialized")

        return DataLoader(
            self.full_dataset,
            batch_size=self.batch_size,
            shuffle=not self.cache,
            num_workers=self.num_workers,
            drop_last=False,
            pin_memory=True,
        )

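A sketch of how the per-fold loaders above pair up; the fold loop is illustrative only (in quadra the feature extraction and sklearn fitting are normally driven by the sklearn classification task, not by user code), and it again assumes prepare_data delegates to _prepare_data:

from quadra.datamodules.classification import SklearnClassificationDataModule

datamodule = SklearnClassificationDataModule(
    data_path="my_dataset",  # hypothetical path
    n_splits=5,              # 5-fold cross validation
    phase="train",
)
datamodule.prepare_data()
datamodule.setup(stage="fit")

# train_dataloader()/val_dataloader() return one loader per fold
for train_loader, val_loader in zip(datamodule.train_dataloader(), datamodule.val_dataloader()):
    ...  # extract features with the frozen backbone, fit/evaluate the sklearn classifier

# After cross validation, full_dataloader() supports a final fit on all train/val data
full_loader = datamodule.full_dataloader()
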
class MultilabelClassificationDataModule(BaseDataModule):
    """Base class for all multi-label modules.

    Args:
        data_path: Path to the data main folder.
        images_and_labels_file: A path to a txt file containing the relative (to `data_path`) paths
            of images with their corresponding labels, in a comma-separated way.
            E.g.:

            * path1,l1,l2,l3
            * path2,l4,l5
            * ...

            Either `images_and_labels_file` or both `train_split_file` and `test_split_file` must be set.
            Defaults to None.
        name: The name for the data module. Defaults to "multilabel_datamodule".
        dataset: A callable returning a torch.utils.data.Dataset class.
        num_classes: The number of classes in the dataset. This is used to create one-hot encoded
            targets. Defaults to None.
        num_workers: Number of workers for dataloaders. Defaults to 16.
        batch_size: Training batch size. Defaults to 64.
        test_batch_size: Testing batch size. Defaults to 64.
        seed: Random generator seed. Defaults to 42.
        val_size: The validation split. Defaults to 0.2.
        test_size: The test split. Defaults to 0.2.
        train_transform: Transformations for train dataset. Defaults to None.
        val_transform: Transformations for validation dataset. Defaults to None.
        test_transform: Transformations for test dataset. Defaults to None.
        train_split_file: The file with train split. Defaults to None.
        val_split_file: The file with validation split. Defaults to None.
        test_split_file: The file with test split. Defaults to None.
        class_to_idx: A class to idx dictionary. Defaults to None.
    """

    def __init__(
        self,
        data_path: str,
        images_and_labels_file: str | None = None,
        train_split_file: str | None = None,
        test_split_file: str | None = None,
        val_split_file: str | None = None,
        name: str = "multilabel_datamodule",
        dataset: Callable = MultilabelClassificationDataset,
        num_classes: int | None = None,
        num_workers: int = 16,
        batch_size: int = 64,
        test_batch_size: int = 64,
        seed: int = 42,
        val_size: float | None = 0.2,
        test_size: float | None = 0.2,
        train_transform: albumentations.Compose | None = None,
        val_transform: albumentations.Compose | None = None,
        test_transform: albumentations.Compose | None = None,
        class_to_idx: dict[str, int] | None = None,
        **kwargs,
    ):
        super().__init__(
            data_path=data_path,
            name=name,
            num_workers=num_workers,
            batch_size=batch_size,
            seed=seed,
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            **kwargs,
        )
        if not (images_and_labels_file is not None or (train_split_file is not None and test_split_file is not None)):
            raise ValueError(
                "Either `images_and_labels_file` or both `train_split_file` and `test_split_file` must be set"
            )
        self.images_and_labels_file = images_and_labels_file
        self.dataset = dataset
        self.num_classes = num_classes
        self.train_batch_size = batch_size
        self.test_batch_size = test_batch_size
        self.val_size = val_size
        self.test_size = test_size
        self.train_split_file = train_split_file
        self.test_split_file = test_split_file
        self.val_split_file = val_split_file
        self.class_to_idx = class_to_idx
        self.train_dataset: MultilabelClassificationDataset
        self.val_dataset: MultilabelClassificationDataset
        self.test_dataset: MultilabelClassificationDataset

    def _read_split(self, split_file: str) -> tuple[list[str], list[list[str]]]:
        """Reads split file.

        Args:
            split_file: Path to the split file.

        Returns:
            Tuple containing list of paths to images and list of labels.
        """
        all_samples, all_targets = [], []
        with open(split_file) as f:
            for line in f.readlines():
                split_line = line.split(",")
                sample = os.path.join(self.data_path, split_line[0])
                targets = [t.strip() for t in split_line[1:]]
                if len(targets) == 0:
                    continue
                all_samples.append(sample)
                all_targets.append(targets)
        return all_samples, all_targets

    def _prepare_data(self) -> None:
        """Prepares the data for the data module."""
        if self.images_and_labels_file is not None:
            # Read all images and targets
            all_samples, all_targets = self._read_split(self.images_and_labels_file)
            all_samples = np.array(all_samples).reshape(-1, 1)

            # Targets to idx
            unique_targets = set(utils.flatten_list(all_targets))
            if self.class_to_idx is None:
                self.class_to_idx = {c: i for i, c in enumerate(unique_targets)}

            all_targets = [[self.class_to_idx[t] for t in targets] for targets in all_targets]

            # Transform targets to one-hot
            if self.num_classes is None:
                self.num_classes = len(unique_targets)
            all_targets = np.array([[i in targets for i in range(self.num_classes)] for targets in all_targets]).astype(
                int
            )

            # Create splits
            samples_train, targets_train, samples_test, targets_test = iterative_train_test_split(
                all_samples, all_targets, test_size=self.test_size
            )
        elif self.train_split_file is not None and self.test_split_file is not None:
            # Both train_split_file and test_split_file are set
            samples_train, targets_train = self._read_split(self.train_split_file)
            samples_test, targets_test = self._read_split(self.test_split_file)

            # Create class_to_idx from all targets
            unique_targets = set(utils.flatten_list(targets_test + targets_train))
            if self.class_to_idx is None:
                self.class_to_idx = {c: i for i, c in enumerate(unique_targets)}

            # Transform targets to one-hot
            if self.num_classes is None:
                self.num_classes = len(unique_targets)
            targets_test = [[self.class_to_idx[t] for t in targets] for targets in targets_test]
            targets_test = np.array(
                [[i in targets for i in range(self.num_classes)] for targets in targets_test]
            ).astype(int)
            targets_train = [[self.class_to_idx[t] for t in targets] for targets in targets_train]
            targets_train = np.array(
                [[i in targets for i in range(self.num_classes)] for targets in targets_train]
            ).astype(int)
        else:
            raise ValueError(
                "Either `images_and_labels_file` or both `train_split_file` and `test_split_file` must be set"
            )

        if self.val_split_file:
            if not self.test_split_file or not self.train_split_file:
                raise ValueError("Validation split file is specified but no train or test split file is specified.")
            samples_val, targets_val = self._read_split(self.val_split_file)
            targets_val = [[self.class_to_idx[t] for t in targets] for targets in targets_val]
            targets_val = np.array([[i in targets for i in range(self.num_classes)] for targets in targets_val]).astype(
                int
            )
        else:
            samples_train = np.array(samples_train).reshape(-1, 1)
            targets_train = np.array(targets_train).reshape(-1, self.num_classes)
            samples_train, targets_train, samples_val, targets_val = iterative_train_test_split(
                samples_train, targets_train, test_size=self.val_size
            )

        if isinstance(samples_train, np.ndarray):
            samples_train = samples_train.flatten().tolist()
        if isinstance(samples_val, np.ndarray):
            samples_val = samples_val.flatten().tolist()
        if isinstance(samples_test, np.ndarray):
            samples_test = samples_test.flatten().tolist()

        if isinstance(targets_train, np.ndarray):
            targets_train = list(targets_train)
        if isinstance(targets_val, np.ndarray):
            targets_val = list(targets_val)  # type: ignore[assignment]
        if isinstance(targets_test, np.ndarray):
            targets_test = list(targets_test)

        # Create data
        train_df = pd.DataFrame({"samples": samples_train, "targets": targets_train})
        train_df["split"] = "train"
        val_df = pd.DataFrame({"samples": samples_val, "targets": targets_val})
        val_df["split"] = "val"
        test_df = pd.DataFrame({"samples": samples_test, "targets": targets_test})
        test_df["split"] = "test"
        self.data = pd.concat([train_df, val_df, test_df], axis=0)

    def setup(self, stage: str | None = None) -> None:
        """Setup data module based on stages of training."""
        if stage in ["train", "fit"]:
            train_samples = self.data[self.data["split"] == "train"]["samples"].tolist()
            train_targets = self.data[self.data["split"] == "train"]["targets"].tolist()
            val_samples = self.data[self.data["split"] == "val"]["samples"].tolist()
            val_targets = self.data[self.data["split"] == "val"]["targets"].tolist()
            self.train_dataset = self.dataset(
                samples=train_samples,
                targets=train_targets,
                transform=self.train_transform,
                class_to_idx=self.class_to_idx,
            )
            self.val_dataset = self.dataset(
                samples=val_samples,
                targets=val_targets,
                transform=self.val_transform,
                class_to_idx=self.class_to_idx,
            )
        if stage == "test":
            test_samples = self.data[self.data["split"] == "test"]["samples"].tolist()
            test_targets = self.data[self.data["split"] == "test"]["targets"].tolist()
            self.test_dataset = self.dataset(
                samples=test_samples,
                targets=test_targets,
                transform=self.test_transform,
                class_to_idx=self.class_to_idx,
            )

    def train_dataloader(self) -> DataLoader:
        """Returns the train dataloader.

        Raises:
            ValueError: If train dataset is not initialized.

        Returns:
            Train dataloader.
        """
        if not self.train_dataset_available:
            raise ValueError("Train dataset is not initialized")
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            drop_last=False,
            pin_memory=True,
            persistent_workers=self.num_workers > 0,
        )

    def val_dataloader(self) -> DataLoader:
        """Returns the validation dataloader.

        Raises:
            ValueError: If validation dataset is not initialized.

        Returns:
            Val dataloader.
        """
        if not self.val_dataset_available:
            raise ValueError("Validation dataset is not initialized")
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            drop_last=False,
            pin_memory=True,
            persistent_workers=self.num_workers > 0,
        )

    def test_dataloader(self) -> DataLoader:
        """Returns the test dataloader.

        Raises:
            ValueError: If test dataset is not initialized.

        Returns:
            Test dataloader.
        """
        if not self.test_dataset_available:
            raise ValueError("Test dataset is not initialized")

        loader = DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            drop_last=False,
            pin_memory=True,
            persistent_workers=self.num_workers > 0,
        )
        return loader

    def predict_dataloader(self) -> DataLoader:
        """Returns a dataloader used for predictions."""
        return self.test_dataloader()
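
To make the one-hot encoding in MultilabelClassificationDataModule._prepare_data concrete, here is a small worked sketch; the file contents and labels are invented for illustration:

# A hypothetical images_and_labels_file:
#   img_001.png,red,striped
#   img_002.png,blue
# After mapping each label through class_to_idx = {"red": 0, "blue": 1, "striped": 2}:
all_targets = [[0, 2], [1]]

# The listing above then builds one-hot rows via `i in targets` per class index:
num_classes = 3
one_hot = [[int(i in targets) for i in range(num_classes)] for targets in all_targets]
print(one_hot)  # [[1, 0, 1], [0, 1, 0]]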