quadra 0.0.1__py3-none-any.whl → 2.1.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hydra_plugins/quadra_searchpath_plugin.py +30 -0
- quadra/__init__.py +6 -0
- quadra/callbacks/__init__.py +0 -0
- quadra/callbacks/anomalib.py +289 -0
- quadra/callbacks/lightning.py +501 -0
- quadra/callbacks/mlflow.py +291 -0
- quadra/callbacks/scheduler.py +69 -0
- quadra/configs/__init__.py +0 -0
- quadra/configs/backbone/caformer_m36.yaml +8 -0
- quadra/configs/backbone/caformer_s36.yaml +8 -0
- quadra/configs/backbone/convnextv2_base.yaml +8 -0
- quadra/configs/backbone/convnextv2_femto.yaml +8 -0
- quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
- quadra/configs/backbone/dino_vitb8.yaml +12 -0
- quadra/configs/backbone/dino_vits8.yaml +12 -0
- quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
- quadra/configs/backbone/dinov2_vits14.yaml +12 -0
- quadra/configs/backbone/efficientnet_b0.yaml +8 -0
- quadra/configs/backbone/efficientnet_b1.yaml +8 -0
- quadra/configs/backbone/efficientnet_b2.yaml +8 -0
- quadra/configs/backbone/efficientnet_b3.yaml +8 -0
- quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
- quadra/configs/backbone/levit_128s.yaml +8 -0
- quadra/configs/backbone/mnasnet0_5.yaml +9 -0
- quadra/configs/backbone/resnet101.yaml +8 -0
- quadra/configs/backbone/resnet18.yaml +8 -0
- quadra/configs/backbone/resnet18_ssl.yaml +8 -0
- quadra/configs/backbone/resnet50.yaml +8 -0
- quadra/configs/backbone/smp.yaml +9 -0
- quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
- quadra/configs/backbone/unetr.yaml +15 -0
- quadra/configs/backbone/vit16_base.yaml +9 -0
- quadra/configs/backbone/vit16_small.yaml +9 -0
- quadra/configs/backbone/vit16_tiny.yaml +9 -0
- quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
- quadra/configs/callbacks/all.yaml +32 -0
- quadra/configs/callbacks/default.yaml +37 -0
- quadra/configs/callbacks/default_anomalib.yaml +67 -0
- quadra/configs/config.yaml +33 -0
- quadra/configs/core/default.yaml +11 -0
- quadra/configs/datamodule/base/anomaly.yaml +16 -0
- quadra/configs/datamodule/base/classification.yaml +21 -0
- quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
- quadra/configs/datamodule/base/segmentation.yaml +18 -0
- quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
- quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
- quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
- quadra/configs/datamodule/base/ssl.yaml +21 -0
- quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
- quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
- quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
- quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
- quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
- quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
- quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
- quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
- quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
- quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
- quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/base/classification/classification.yaml +73 -0
- quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
- quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
- quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
- quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
- quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
- quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
- quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
- quadra/configs/experiment/base/ssl/byol.yaml +43 -0
- quadra/configs/experiment/base/ssl/dino.yaml +46 -0
- quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
- quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
- quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
- quadra/configs/experiment/custom/cls.yaml +12 -0
- quadra/configs/experiment/default.yaml +15 -0
- quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
- quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
- quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
- quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
- quadra/configs/export/default.yaml +13 -0
- quadra/configs/hydra/anomaly_custom.yaml +15 -0
- quadra/configs/hydra/default.yaml +14 -0
- quadra/configs/inference/default.yaml +26 -0
- quadra/configs/logger/comet.yaml +10 -0
- quadra/configs/logger/csv.yaml +5 -0
- quadra/configs/logger/mlflow.yaml +12 -0
- quadra/configs/logger/tensorboard.yaml +8 -0
- quadra/configs/loss/asl.yaml +7 -0
- quadra/configs/loss/barlow.yaml +2 -0
- quadra/configs/loss/bce.yaml +1 -0
- quadra/configs/loss/byol.yaml +1 -0
- quadra/configs/loss/cross_entropy.yaml +1 -0
- quadra/configs/loss/dino.yaml +8 -0
- quadra/configs/loss/simclr.yaml +2 -0
- quadra/configs/loss/simsiam.yaml +1 -0
- quadra/configs/loss/smp_ce.yaml +3 -0
- quadra/configs/loss/smp_dice.yaml +2 -0
- quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
- quadra/configs/loss/smp_mcc.yaml +2 -0
- quadra/configs/loss/vicreg.yaml +5 -0
- quadra/configs/model/anomalib/cfa.yaml +35 -0
- quadra/configs/model/anomalib/cflow.yaml +30 -0
- quadra/configs/model/anomalib/csflow.yaml +34 -0
- quadra/configs/model/anomalib/dfm.yaml +19 -0
- quadra/configs/model/anomalib/draem.yaml +29 -0
- quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
- quadra/configs/model/anomalib/fastflow.yaml +32 -0
- quadra/configs/model/anomalib/padim.yaml +32 -0
- quadra/configs/model/anomalib/patchcore.yaml +36 -0
- quadra/configs/model/barlow.yaml +16 -0
- quadra/configs/model/byol.yaml +25 -0
- quadra/configs/model/classification.yaml +10 -0
- quadra/configs/model/dino.yaml +26 -0
- quadra/configs/model/logistic_regression.yaml +4 -0
- quadra/configs/model/multilabel_classification.yaml +9 -0
- quadra/configs/model/simclr.yaml +18 -0
- quadra/configs/model/simsiam.yaml +24 -0
- quadra/configs/model/smp.yaml +4 -0
- quadra/configs/model/smp_multiclass.yaml +4 -0
- quadra/configs/model/vicreg.yaml +16 -0
- quadra/configs/optimizer/adam.yaml +5 -0
- quadra/configs/optimizer/adamw.yaml +3 -0
- quadra/configs/optimizer/default.yaml +4 -0
- quadra/configs/optimizer/lars.yaml +8 -0
- quadra/configs/optimizer/sgd.yaml +4 -0
- quadra/configs/scheduler/default.yaml +5 -0
- quadra/configs/scheduler/rop.yaml +5 -0
- quadra/configs/scheduler/step.yaml +3 -0
- quadra/configs/scheduler/warmrestart.yaml +2 -0
- quadra/configs/scheduler/warmup.yaml +6 -0
- quadra/configs/task/anomalib/cfa.yaml +5 -0
- quadra/configs/task/anomalib/cflow.yaml +5 -0
- quadra/configs/task/anomalib/csflow.yaml +5 -0
- quadra/configs/task/anomalib/draem.yaml +5 -0
- quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
- quadra/configs/task/anomalib/fastflow.yaml +5 -0
- quadra/configs/task/anomalib/inference.yaml +3 -0
- quadra/configs/task/anomalib/padim.yaml +5 -0
- quadra/configs/task/anomalib/patchcore.yaml +5 -0
- quadra/configs/task/classification.yaml +6 -0
- quadra/configs/task/classification_evaluation.yaml +6 -0
- quadra/configs/task/default.yaml +1 -0
- quadra/configs/task/segmentation.yaml +9 -0
- quadra/configs/task/segmentation_evaluation.yaml +3 -0
- quadra/configs/task/sklearn_classification.yaml +13 -0
- quadra/configs/task/sklearn_classification_patch.yaml +11 -0
- quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
- quadra/configs/task/sklearn_classification_test.yaml +8 -0
- quadra/configs/task/ssl.yaml +2 -0
- quadra/configs/trainer/lightning_cpu.yaml +36 -0
- quadra/configs/trainer/lightning_gpu.yaml +35 -0
- quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
- quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
- quadra/configs/trainer/lightning_multigpu.yaml +37 -0
- quadra/configs/trainer/sklearn_classification.yaml +7 -0
- quadra/configs/transforms/byol.yaml +47 -0
- quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
- quadra/configs/transforms/default.yaml +37 -0
- quadra/configs/transforms/default_numpy.yaml +24 -0
- quadra/configs/transforms/default_resize.yaml +22 -0
- quadra/configs/transforms/dino.yaml +63 -0
- quadra/configs/transforms/linear_eval.yaml +18 -0
- quadra/datamodules/__init__.py +20 -0
- quadra/datamodules/anomaly.py +180 -0
- quadra/datamodules/base.py +375 -0
- quadra/datamodules/classification.py +1003 -0
- quadra/datamodules/generic/__init__.py +0 -0
- quadra/datamodules/generic/imagenette.py +144 -0
- quadra/datamodules/generic/mnist.py +81 -0
- quadra/datamodules/generic/mvtec.py +58 -0
- quadra/datamodules/generic/oxford_pet.py +163 -0
- quadra/datamodules/patch.py +190 -0
- quadra/datamodules/segmentation.py +742 -0
- quadra/datamodules/ssl.py +140 -0
- quadra/datasets/__init__.py +17 -0
- quadra/datasets/anomaly.py +287 -0
- quadra/datasets/classification.py +241 -0
- quadra/datasets/patch.py +138 -0
- quadra/datasets/segmentation.py +239 -0
- quadra/datasets/ssl.py +110 -0
- quadra/losses/__init__.py +0 -0
- quadra/losses/classification/__init__.py +6 -0
- quadra/losses/classification/asl.py +83 -0
- quadra/losses/classification/focal.py +320 -0
- quadra/losses/classification/prototypical.py +148 -0
- quadra/losses/ssl/__init__.py +17 -0
- quadra/losses/ssl/barlowtwins.py +47 -0
- quadra/losses/ssl/byol.py +37 -0
- quadra/losses/ssl/dino.py +129 -0
- quadra/losses/ssl/hyperspherical.py +45 -0
- quadra/losses/ssl/idmm.py +50 -0
- quadra/losses/ssl/simclr.py +67 -0
- quadra/losses/ssl/simsiam.py +30 -0
- quadra/losses/ssl/vicreg.py +76 -0
- quadra/main.py +46 -0
- quadra/metrics/__init__.py +3 -0
- quadra/metrics/segmentation.py +251 -0
- quadra/models/__init__.py +0 -0
- quadra/models/base.py +151 -0
- quadra/models/classification/__init__.py +8 -0
- quadra/models/classification/backbones.py +149 -0
- quadra/models/classification/base.py +92 -0
- quadra/models/evaluation.py +322 -0
- quadra/modules/__init__.py +0 -0
- quadra/modules/backbone.py +30 -0
- quadra/modules/base.py +312 -0
- quadra/modules/classification/__init__.py +3 -0
- quadra/modules/classification/base.py +331 -0
- quadra/modules/ssl/__init__.py +17 -0
- quadra/modules/ssl/barlowtwins.py +59 -0
- quadra/modules/ssl/byol.py +172 -0
- quadra/modules/ssl/common.py +285 -0
- quadra/modules/ssl/dino.py +186 -0
- quadra/modules/ssl/hyperspherical.py +206 -0
- quadra/modules/ssl/idmm.py +98 -0
- quadra/modules/ssl/simclr.py +73 -0
- quadra/modules/ssl/simsiam.py +68 -0
- quadra/modules/ssl/vicreg.py +67 -0
- quadra/optimizers/__init__.py +4 -0
- quadra/optimizers/lars.py +153 -0
- quadra/optimizers/sam.py +127 -0
- quadra/schedulers/__init__.py +3 -0
- quadra/schedulers/base.py +44 -0
- quadra/schedulers/warmup.py +127 -0
- quadra/tasks/__init__.py +24 -0
- quadra/tasks/anomaly.py +582 -0
- quadra/tasks/base.py +397 -0
- quadra/tasks/classification.py +1264 -0
- quadra/tasks/patch.py +492 -0
- quadra/tasks/segmentation.py +389 -0
- quadra/tasks/ssl.py +560 -0
- quadra/trainers/README.md +3 -0
- quadra/trainers/__init__.py +0 -0
- quadra/trainers/classification.py +179 -0
- quadra/utils/__init__.py +0 -0
- quadra/utils/anomaly.py +112 -0
- quadra/utils/classification.py +618 -0
- quadra/utils/deprecation.py +31 -0
- quadra/utils/evaluation.py +474 -0
- quadra/utils/export.py +579 -0
- quadra/utils/imaging.py +32 -0
- quadra/utils/logger.py +15 -0
- quadra/utils/mlflow.py +98 -0
- quadra/utils/model_manager.py +320 -0
- quadra/utils/models.py +524 -0
- quadra/utils/patch/__init__.py +15 -0
- quadra/utils/patch/dataset.py +1433 -0
- quadra/utils/patch/metrics.py +449 -0
- quadra/utils/patch/model.py +153 -0
- quadra/utils/patch/visualization.py +217 -0
- quadra/utils/resolver.py +42 -0
- quadra/utils/segmentation.py +31 -0
- quadra/utils/tests/__init__.py +0 -0
- quadra/utils/tests/fixtures/__init__.py +1 -0
- quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
- quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
- quadra/utils/tests/fixtures/dataset/classification.py +406 -0
- quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
- quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
- quadra/utils/tests/fixtures/models/__init__.py +3 -0
- quadra/utils/tests/fixtures/models/anomaly.py +89 -0
- quadra/utils/tests/fixtures/models/classification.py +45 -0
- quadra/utils/tests/fixtures/models/segmentation.py +33 -0
- quadra/utils/tests/helpers.py +70 -0
- quadra/utils/tests/models.py +27 -0
- quadra/utils/utils.py +525 -0
- quadra/utils/validator.py +115 -0
- quadra/utils/visualization.py +422 -0
- quadra/utils/vit_explainability.py +349 -0
- quadra-2.1.13.dist-info/LICENSE +201 -0
- quadra-2.1.13.dist-info/METADATA +386 -0
- quadra-2.1.13.dist-info/RECORD +300 -0
- {quadra-0.0.1.dist-info → quadra-2.1.13.dist-info}/WHEEL +1 -1
- quadra-2.1.13.dist-info/entry_points.txt +3 -0
- quadra-0.0.1.dist-info/METADATA +0 -14
- quadra-0.0.1.dist-info/RECORD +0 -4
|
@@ -0,0 +1,375 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import multiprocessing as mp
|
|
4
|
+
import multiprocessing.pool as mpp
|
|
5
|
+
import os
|
|
6
|
+
import pickle as pkl
|
|
7
|
+
import typing
|
|
8
|
+
from collections.abc import Callable, Iterable, Sequence
|
|
9
|
+
from functools import wraps
|
|
10
|
+
from typing import Any, Literal, Union, cast
|
|
11
|
+
|
|
12
|
+
import albumentations
|
|
13
|
+
import numpy as np
|
|
14
|
+
import pandas as pd
|
|
15
|
+
import torch
|
|
16
|
+
import xxhash
|
|
17
|
+
from pytorch_lightning import LightningDataModule
|
|
18
|
+
from tqdm import tqdm
|
|
19
|
+
|
|
20
|
+
from quadra.utils import utils
|
|
21
|
+
|
|
22
|
+
# Module-level logger, shared by all datamodules in this file.
log = utils.get_logger(__name__)
# Train/val datasets may be a single torch Dataset or a sequence of them
# (e.g. multiple validation sets); the test dataset is always a single one.
TrainDataset = Union[torch.utils.data.Dataset, Sequence[torch.utils.data.Dataset]]
ValDataset = Union[torch.utils.data.Dataset, Sequence[torch.utils.data.Dataset]]
TestDataset = torch.utils.data.Dataset
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def load_data_from_disk_dec(func):
    """Decorator restoring a datamodule's persisted state before running ``func``.

    Applied (via ``DecorateParentMethod``) to ``setup`` so that data prepared on
    rank zero in ``prepare_data`` becomes available on every process.
    """

    @wraps(func)
    def inner(*args, **kwargs):
        """Restore any checkpoint saved to disk, then delegate to the wrapped method."""
        datamodule = cast(BaseDataModule, args[0])
        datamodule.restore_checkpoint()
        return func(*args, **kwargs)

    return inner
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class DecorateParentMethod(type):
    """Metaclass that transparently wraps chosen methods of every subclass.

    Currently only ``setup`` is wrapped, with ``load_data_from_disk_dec``, so
    subclasses get checkpoint restoration without writing any boilerplate.
    """

    def __new__(cls, name, bases, dct):
        """Create the class object, decorating any method found in the mapping."""
        decorators_by_method = {
            "setup": load_data_from_disk_dec,
        }
        for target_name, decorator in decorators_by_method.items():
            if target_name in dct:
                dct[target_name] = decorator(dct[target_name])

        return super().__new__(cls, name, bases, dct)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def compute_file_content_hash(path: str, hash_size: Literal[32, 64, 128] = 64) -> str:
    """Hash a file based on its full binary content.

    Args:
        path: Path to the file.
        hash_size: Size of the hash. Must be one of [32, 64, 128].

    Returns:
        The hex digest of the file content.
    """
    with open(path, "rb") as file_handle:
        content = file_handle.read()

    # Fixed seed keeps hashes stable across runs so they can be compared.
    if hash_size == 32:
        return xxhash.xxh32(content, seed=42).hexdigest()
    if hash_size == 64:
        return xxhash.xxh64(content, seed=42).hexdigest()
    if hash_size == 128:
        return xxhash.xxh128(content, seed=42).hexdigest()

    raise ValueError(f"Invalid hash size {hash_size}. Must be one of [32, 64, 128].")
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def compute_file_size_hash(path: str, hash_size: Literal[32, 64, 128] = 64) -> str:
    """Hash a file based only on its size on disk (fast, but collision-prone).

    Args:
        path: Path to the file.
        hash_size: Size of the hash. Must be one of [32, 64, 128].

    Returns:
        The hex digest computed from the stringified file size.
    """
    size_repr = str(os.path.getsize(path))

    if hash_size not in (32, 64, 128):
        raise ValueError(f"Invalid hash size {hash_size}. Must be one of [32, 64, 128].")

    # Fixed seed keeps hashes stable across runs so they can be compared.
    hasher = {32: xxhash.xxh32, 64: xxhash.xxh64, 128: xxhash.xxh128}[hash_size]
    return hasher(size_repr, seed=42).hexdigest()
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
@typing.no_type_check
def istarmap(self, func: Callable, iterable: Iterable, chunksize: int = 1):
    # pylint: disable=all
    """Starmap-version of imap."""
    # NOTE(review): this reaches into private multiprocessing.pool internals
    # (_check_running, _get_tasks, IMapIterator, _guarded_task_generation,
    # starmapstar, _taskqueue) and may break if a Python release changes them.
    self._check_running()
    if chunksize < 1:
        raise ValueError(f"Chunksize must be 1+, not {chunksize:n}")

    # Chunk the (args-tuple) stream exactly as Pool.starmap would.
    task_batches = mpp.Pool._get_tasks(func, iterable, chunksize)
    result = mpp.IMapIterator(self)
    # Enqueue the chunked tasks; starmapstar unpacks each args tuple into func.
    self._taskqueue.put((self._guarded_task_generation(result._job, mpp.starmapstar, task_batches), result._set_length))
    # Flatten per-chunk results back into a single lazy iterator of items.
    return (item for chunk in result for item in chunk)


# Patch Pool class to include istarmap. This gives starmap semantics with imap's
# laziness, which BaseDataModule.hash_data uses to drive a tqdm progress bar.
mpp.Pool.istarmap = istarmap  # type: ignore[attr-defined]
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
class BaseDataModule(LightningDataModule, metaclass=DecorateParentMethod):
    """Base class for all data modules.

    Args:
        data_path: Path to the data main folder.
        name: The name for the data module. Defaults to "base_datamodule".
        num_workers: Number of workers for dataloaders. Defaults to 16.
        batch_size: Batch size. Defaults to 32.
        seed: Random generator seed. Defaults to 42.
        load_aug_images: Whether pre-generated augmented samples should also be loaded.
            Defaults to False.
        aug_name: Name tag identifying the augmented samples. Defaults to None.
        n_aug_to_take: Number of augmented samples to load per original sample.
            Defaults to None.
        replace_str_from: Substring to replace in sample paths when resolving
            augmented files. Defaults to None.
        replace_str_to: Replacement for ``replace_str_from``. Defaults to None.
        train_transform: Transformations for train dataset.
            Defaults to None.
        val_transform: Transformations for validation dataset.
            Defaults to None.
        test_transform: Transformations for test dataset.
            Defaults to None.
        enable_hashing: Whether to enable hashing of images. Defaults to True.
        hash_size: Size of the hash. Must be one of [32, 64, 128]. Defaults to 64.
        hash_type: Type of hash to use, if content hash is used, the hash is computed on the file content, otherwise
            the hash is computed on the file size which is faster but less safe. Defaults to "content".
    """

    def __init__(
        self,
        data_path: str,
        name: str = "base_datamodule",
        num_workers: int = 16,
        batch_size: int = 32,
        seed: int = 42,
        load_aug_images: bool = False,
        aug_name: str | None = None,
        n_aug_to_take: int | None = None,
        replace_str_from: str | None = None,
        replace_str_to: str | None = None,
        train_transform: albumentations.Compose | None = None,
        val_transform: albumentations.Compose | None = None,
        test_transform: albumentations.Compose | None = None,
        enable_hashing: bool = True,
        hash_size: Literal[32, 64, 128] = 64,
        hash_type: Literal["content", "size"] = "content",
    ):
        super().__init__()
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.seed = seed
        self.data_path = data_path
        self.name = name
        self.train_transform = train_transform
        self.val_transform = val_transform
        self.test_transform = test_transform
        self.enable_hashing = enable_hashing
        self.hash_size = hash_size
        self.hash_type = hash_type

        if self.hash_size not in [32, 64, 128]:
            raise ValueError(f"Invalid hash size {self.hash_size}. Must be one of [32, 64, 128].")

        self.load_aug_images = load_aug_images
        self.aug_name = aug_name
        self.n_aug_to_take = n_aug_to_take
        self.replace_str_from = replace_str_from
        self.replace_str_to = replace_str_to
        self.extra_args: dict[str, Any] = {}
        # Declared (not assigned) here; subclasses are expected to populate
        # these in setup().
        self.train_dataset: TrainDataset
        self.val_dataset: ValDataset
        self.test_dataset: TestDataset
        self.data: pd.DataFrame
        # Checkpoint files let prepare_data (run on a single process) hand its
        # results to setup() on every process via disk.
        # NOTE(review): "data" is relative to the current working directory --
        # presumably hydra sets a per-run cwd; confirm before reusing elsewhere.
        self.data_folder = "data"
        os.makedirs(self.data_folder, exist_ok=True)
        self.datamodule_checkpoint_file = os.path.join(self.data_folder, "datamodule.pkl")
        self.dataset_file = os.path.join(self.data_folder, "dataset.csv")

    @property
    def train_data(self) -> pd.DataFrame:
        """Get train data."""
        if not hasattr(self, "data"):
            raise ValueError("`data` attribute is not set. Cannot load train data.")
        return self.data[self.data["split"] == "train"]

    @property
    def val_data(self) -> pd.DataFrame:
        """Get validation data."""
        if not hasattr(self, "data"):
            raise ValueError("`data` attribute is not set. Cannot load val data.")
        return self.data[self.data["split"] == "val"]

    @property
    def test_data(self) -> pd.DataFrame:
        """Get test data."""
        if not hasattr(self, "data"):
            raise ValueError("`data` attribute is not set. Cannot load test data.")
        return self.data[self.data["split"] == "test"]

    def _dataset_available(self, dataset_name: str) -> bool:
        """Checks if the dataset is available.

        Args:
            dataset_name: Name of the dataset attribute.

        Returns:
            True if the dataset is available, False otherwise.
        """
        available = hasattr(self, dataset_name) and getattr(self, dataset_name) is not None
        if available:
            dataset_attr = getattr(self, dataset_name)
            # A dataset may be a list of datasets (e.g. multiple val sets):
            # every one of them must be non-empty to count as available.
            if isinstance(dataset_attr, list):
                available = all(len(d) > 0 for d in dataset_attr)
            else:
                available = len(dataset_attr) > 0
        return available

    @property
    def train_dataset_available(self) -> bool:
        """Checks if the train dataset is available."""
        return self._dataset_available("train_dataset")

    @property
    def val_dataset_available(self) -> bool:
        """Checks if the validation dataset is available."""
        return self._dataset_available("val_dataset")

    @property
    def test_dataset_available(self) -> bool:
        """Checks if the test dataset is available."""
        return self._dataset_available("test_dataset")

    def _prepare_data(self) -> None:
        """Prepares the data, this should have exactly the same logic as the prepare_data method
        of a LightningModule.
        """
        raise NotImplementedError(
            "This method must be implemented, it should contain all the logic that normally is "
            "contained in the prepare_data method of a LightningModule."
        )

    def hash_data(self) -> None:
        """Computes the hash of the files inside the datasets."""
        if not self.enable_hashing:
            return

        # TODO: We need to find a way to annotate the columns of data.
        paths_and_hash_length = zip(self.data["samples"], [self.hash_size] * len(self.data))

        # max(1, ...) guards against single-core machines where
        # cpu_count() - 1 == 0, which would make mp.Pool raise ValueError.
        with mp.Pool(max(1, min(8, mp.cpu_count() - 1))) as pool:
            self.data["hash"] = list(
                tqdm(
                    pool.istarmap(  # type: ignore[attr-defined]
                        compute_file_content_hash if self.hash_type == "content" else compute_file_size_hash,
                        paths_and_hash_length,
                    ),
                    total=len(self.data),
                    desc="Computing hashes",
                )
            )

        self.data["hash_type"] = self.hash_type

    def prepare_data(self) -> None:
        """Prepares the data, should be overridden by subclasses."""
        # `data` already present means preparation (or a restore) already ran.
        if hasattr(self, "data"):
            return

        self._prepare_data()
        self.hash_data()
        self.save_checkpoint()

    def __getstate__(self) -> dict[str, Any]:
        """This method is called when pickling the object.

        It's useful to remove attributes that shouldn't be pickled.
        """
        state = self.__dict__.copy()
        if "trainer" in state:
            # Lightning injects the trainer in the datamodule, we don't want to pickle it.
            del state["trainer"]

        return state

    def save_checkpoint(self) -> None:
        """Saves the datamodule to disk, utility function that is called from prepare_data. We are required to save
        datamodule to disk because we can't assign attributes to the datamodule in prepare_data when working with
        multiple gpus.
        """
        # Save only once: if either artifact already exists, this run (or a
        # previous one sharing the folder) has already checkpointed.
        if not os.path.exists(self.datamodule_checkpoint_file) and not os.path.exists(self.dataset_file):
            with open(self.datamodule_checkpoint_file, "wb") as f:
                pkl.dump(self, f)

            self.data.to_csv(self.dataset_file, index=False)
            log.info("Datamodule checkpoint saved to disk.")

            if "targets" in self.data:
                if isinstance(self.data["targets"].iloc[0], np.ndarray):
                    # If we find a numpy array target it's very likely one hot encoded,
                    # in that case we just print the number of train/val/test samples
                    grouping = ["split"]
                else:
                    grouping = ["split", "targets"]
                log.info("Dataset Info:")
                # Sort splits in train/val/test order rather than alphabetically.
                split_order = {"train": 0, "val": 1, "test": 2}
                log.info(
                    "\n%s",
                    self.data.groupby(grouping)
                    .size()
                    .to_frame()
                    .reset_index()
                    .sort_values(by=["split"], key=lambda x: x.map(split_order))
                    .rename(columns={0: "count"})
                    .to_string(index=False),
                )

    def restore_checkpoint(self) -> None:
        """Loads the data from disk, utility function that should be called from setup."""
        if hasattr(self, "data"):
            return

        if not os.path.isfile(self.datamodule_checkpoint_file):
            raise ValueError(f"Dataset file {self.datamodule_checkpoint_file} does not exist.")

        with open(self.datamodule_checkpoint_file, "rb") as f:
            checkpoint_datamodule = pkl.load(f)
            # Copy state attribute-by-attribute so the current (already
            # constructed) instance absorbs the checkpointed one.
            for key, value in checkpoint_datamodule.__dict__.items():
                setattr(self, key, value)

    # TODO: Check if this function can be removed
    def load_augmented_samples(
        self,
        samples: list[str],
        targets: list[Any],
        replace_str_from: str | None = None,
        replace_str_to: str | None = None,
        shuffle: bool = False,
    ) -> tuple[list[str], list[Any]]:
        """Loads augmented samples.

        For every (sample, target) pair, appends the original plus
        ``n_aug_to_take`` augmented variants named ``<base>_<k><ext>``.

        Args:
            samples: Paths of the original samples.
            targets: Labels aligned with ``samples``.
            replace_str_from: Optional substring to replace in each path before
                deriving augmented file names.
            replace_str_to: Replacement for ``replace_str_from``.
            shuffle: Whether to shuffle the resulting pairs (uses the global
                numpy RNG state).

        Returns:
            The expanded (samples, targets) lists.

        Raises:
            ValueError: If ``n_aug_to_take`` was not configured.
        """
        if self.n_aug_to_take is None:
            raise ValueError("`n_aug_to_take` is not set. Cannot load augmented samples.")
        aug_samples = []
        aug_labels = []
        for sample, label in zip(samples, targets):
            aug_samples.append(sample)
            aug_labels.append(label)
            final_sample = sample
            if replace_str_from is not None and replace_str_to is not None:
                final_sample = final_sample.replace(replace_str_from, replace_str_to)
            base, ext = os.path.splitext(final_sample)
            for k in range(self.n_aug_to_take):
                # Augmented files are assumed to be numbered from 1.
                aug_samples.append(base + "_" + str(k + 1) + ext)
                aug_labels.append(label)
        samples = aug_samples
        targets = aug_labels
        if shuffle:
            indices = np.arange(len(aug_samples))
            np.random.shuffle(indices)
            samples = np.array(samples)[indices].tolist()
            targets = np.array(targets)[indices].tolist()
        return samples, targets
|