kaiko-eva 0.0.0.dev6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of kaiko-eva has been flagged as potentially problematic.

Files changed (111)
  1. eva/.DS_Store +0 -0
  2. eva/__init__.py +33 -0
  3. eva/__main__.py +18 -0
  4. eva/__version__.py +25 -0
  5. eva/core/__init__.py +19 -0
  6. eva/core/callbacks/__init__.py +5 -0
  7. eva/core/callbacks/writers/__init__.py +5 -0
  8. eva/core/callbacks/writers/embeddings.py +169 -0
  9. eva/core/callbacks/writers/typings.py +23 -0
  10. eva/core/cli/__init__.py +5 -0
  11. eva/core/cli/cli.py +19 -0
  12. eva/core/cli/logo.py +38 -0
  13. eva/core/cli/setup.py +89 -0
  14. eva/core/data/__init__.py +14 -0
  15. eva/core/data/dataloaders/__init__.py +5 -0
  16. eva/core/data/dataloaders/dataloader.py +80 -0
  17. eva/core/data/datamodules/__init__.py +6 -0
  18. eva/core/data/datamodules/call.py +33 -0
  19. eva/core/data/datamodules/datamodule.py +108 -0
  20. eva/core/data/datamodules/schemas.py +62 -0
  21. eva/core/data/datasets/__init__.py +7 -0
  22. eva/core/data/datasets/base.py +53 -0
  23. eva/core/data/datasets/classification/__init__.py +5 -0
  24. eva/core/data/datasets/classification/embeddings.py +154 -0
  25. eva/core/data/datasets/dataset.py +6 -0
  26. eva/core/data/samplers/__init__.py +5 -0
  27. eva/core/data/samplers/sampler.py +6 -0
  28. eva/core/data/transforms/__init__.py +5 -0
  29. eva/core/data/transforms/dtype/__init__.py +5 -0
  30. eva/core/data/transforms/dtype/array.py +28 -0
  31. eva/core/interface/__init__.py +5 -0
  32. eva/core/interface/interface.py +79 -0
  33. eva/core/metrics/__init__.py +17 -0
  34. eva/core/metrics/average_loss.py +47 -0
  35. eva/core/metrics/binary_balanced_accuracy.py +22 -0
  36. eva/core/metrics/defaults/__init__.py +6 -0
  37. eva/core/metrics/defaults/classification/__init__.py +6 -0
  38. eva/core/metrics/defaults/classification/binary.py +76 -0
  39. eva/core/metrics/defaults/classification/multiclass.py +80 -0
  40. eva/core/metrics/structs/__init__.py +9 -0
  41. eva/core/metrics/structs/collection.py +6 -0
  42. eva/core/metrics/structs/metric.py +6 -0
  43. eva/core/metrics/structs/module.py +115 -0
  44. eva/core/metrics/structs/schemas.py +47 -0
  45. eva/core/metrics/structs/typings.py +15 -0
  46. eva/core/models/__init__.py +13 -0
  47. eva/core/models/modules/__init__.py +7 -0
  48. eva/core/models/modules/head.py +113 -0
  49. eva/core/models/modules/inference.py +37 -0
  50. eva/core/models/modules/module.py +190 -0
  51. eva/core/models/modules/typings.py +23 -0
  52. eva/core/models/modules/utils/__init__.py +6 -0
  53. eva/core/models/modules/utils/batch_postprocess.py +57 -0
  54. eva/core/models/modules/utils/grad.py +23 -0
  55. eva/core/models/networks/__init__.py +6 -0
  56. eva/core/models/networks/_utils.py +25 -0
  57. eva/core/models/networks/mlp.py +69 -0
  58. eva/core/models/networks/transforms/__init__.py +5 -0
  59. eva/core/models/networks/transforms/extract_cls_features.py +25 -0
  60. eva/core/models/networks/wrappers/__init__.py +8 -0
  61. eva/core/models/networks/wrappers/base.py +47 -0
  62. eva/core/models/networks/wrappers/from_function.py +58 -0
  63. eva/core/models/networks/wrappers/huggingface.py +37 -0
  64. eva/core/models/networks/wrappers/onnx.py +47 -0
  65. eva/core/trainers/__init__.py +6 -0
  66. eva/core/trainers/_logging.py +81 -0
  67. eva/core/trainers/_recorder.py +149 -0
  68. eva/core/trainers/_utils.py +12 -0
  69. eva/core/trainers/functional.py +113 -0
  70. eva/core/trainers/trainer.py +97 -0
  71. eva/core/utils/__init__.py +1 -0
  72. eva/core/utils/io/__init__.py +5 -0
  73. eva/core/utils/io/dataframe.py +21 -0
  74. eva/core/utils/multiprocessing.py +44 -0
  75. eva/core/utils/workers.py +21 -0
  76. eva/vision/__init__.py +14 -0
  77. eva/vision/data/__init__.py +5 -0
  78. eva/vision/data/datasets/__init__.py +22 -0
  79. eva/vision/data/datasets/_utils.py +50 -0
  80. eva/vision/data/datasets/_validators.py +44 -0
  81. eva/vision/data/datasets/classification/__init__.py +15 -0
  82. eva/vision/data/datasets/classification/bach.py +174 -0
  83. eva/vision/data/datasets/classification/base.py +103 -0
  84. eva/vision/data/datasets/classification/crc.py +176 -0
  85. eva/vision/data/datasets/classification/mhist.py +106 -0
  86. eva/vision/data/datasets/classification/patch_camelyon.py +203 -0
  87. eva/vision/data/datasets/classification/total_segmentator.py +212 -0
  88. eva/vision/data/datasets/segmentation/__init__.py +6 -0
  89. eva/vision/data/datasets/segmentation/base.py +112 -0
  90. eva/vision/data/datasets/segmentation/total_segmentator.py +212 -0
  91. eva/vision/data/datasets/structs.py +17 -0
  92. eva/vision/data/datasets/vision.py +43 -0
  93. eva/vision/data/transforms/__init__.py +5 -0
  94. eva/vision/data/transforms/common/__init__.py +5 -0
  95. eva/vision/data/transforms/common/resize_and_crop.py +44 -0
  96. eva/vision/models/__init__.py +5 -0
  97. eva/vision/models/networks/__init__.py +6 -0
  98. eva/vision/models/networks/abmil.py +176 -0
  99. eva/vision/models/networks/postprocesses/__init__.py +5 -0
  100. eva/vision/models/networks/postprocesses/cls.py +25 -0
  101. eva/vision/utils/__init__.py +5 -0
  102. eva/vision/utils/io/__init__.py +12 -0
  103. eva/vision/utils/io/_utils.py +29 -0
  104. eva/vision/utils/io/image.py +54 -0
  105. eva/vision/utils/io/nifti.py +50 -0
  106. eva/vision/utils/io/text.py +18 -0
  107. kaiko_eva-0.0.0.dev6.dist-info/METADATA +393 -0
  108. kaiko_eva-0.0.0.dev6.dist-info/RECORD +111 -0
  109. kaiko_eva-0.0.0.dev6.dist-info/WHEEL +4 -0
  110. kaiko_eva-0.0.0.dev6.dist-info/entry_points.txt +4 -0
  111. kaiko_eva-0.0.0.dev6.dist-info/licenses/LICENSE +201 -0
eva/.DS_Store ADDED
Binary file
eva/__init__.py ADDED
@@ -0,0 +1,33 @@
+"""eva public API."""
+
+from eva.core import (
+    DataLoader,
+    DataloadersSchema,
+    DataModule,
+    DatasetsSchema,
+    HeadModule,
+    InferenceModule,
+    Interface,
+    Trainer,
+    callbacks,
+    data,
+    metrics,
+    models,
+)
+from eva.core.data import datasets
+
+__all__ = [
+    "DataLoader",
+    "DataloadersSchema",
+    "DataModule",
+    "DatasetsSchema",
+    "HeadModule",
+    "InferenceModule",
+    "Interface",
+    "Trainer",
+    "callbacks",
+    "data",
+    "metrics",
+    "models",
+    "datasets",
+]
eva/__main__.py ADDED
@@ -0,0 +1,18 @@
+"""Main entry-point module."""
+
+from eva.core.cli import cli
+
+
+def main() -> None:
+    """Main entry-point.
+
+    The CLI fetches the input arguments from `sys.argv`.
+
+    For usage information, execute:
+    $ eva --help
+    """
+    cli()
+
+
+if __name__ == "__main__":
+    main()
eva/__version__.py ADDED
@@ -0,0 +1,25 @@
+"""Fetches the version of the library."""
+
+from importlib import metadata
+
+
+def _fetch_version(package_name: str) -> str:
+    """Fetches the version of an installed package.
+
+    If it fails to do so, it returns a "*", indicating
+    that the package has been installed as editable.
+
+    Args:
+        package_name: The name of the package to fetch
+            the version of.
+
+    Returns:
+        A string representing the version of the library.
+    """
+    try:
+        return metadata.version(package_name)
+    except metadata.PackageNotFoundError:
+        return "*"
+
+
+__version__ = _fetch_version("eva")
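
As a quick illustration of the fallback described in the docstring, a sketch querying a present and an absent distribution (the second package name is deliberately nonexistent):

from eva.__version__ import _fetch_version

print(_fetch_version("pip"))  # a version string such as "24.0", when pip is installed
print(_fetch_version("definitely-not-a-package"))  # "*" (no installed metadata found)
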
eva/core/__init__.py ADDED
@@ -0,0 +1,19 @@
+"""eva core API."""
+
+from eva.core.cli import cli
+from eva.core.data import DataLoader, DataloadersSchema, DataModule, DatasetsSchema
+from eva.core.interface import Interface
+from eva.core.models import HeadModule, InferenceModule
+from eva.core.trainers import Trainer
+
+__all__ = [
+    "cli",
+    "DataLoader",
+    "DataloadersSchema",
+    "DataModule",
+    "DatasetsSchema",
+    "Interface",
+    "HeadModule",
+    "InferenceModule",
+    "Trainer",
+]
eva/core/callbacks/__init__.py ADDED
@@ -0,0 +1,5 @@
+"""Callbacks API."""
+
+from eva.core.callbacks.writers import EmbeddingsWriter
+
+__all__ = ["EmbeddingsWriter"]
eva/core/callbacks/writers/__init__.py ADDED
@@ -0,0 +1,5 @@
+"""Callbacks API."""
+
+from eva.core.callbacks.writers.embeddings import EmbeddingsWriter
+
+__all__ = ["EmbeddingsWriter"]
eva/core/callbacks/writers/embeddings.py ADDED
@@ -0,0 +1,169 @@
+"""Embeddings writer."""
+
+import csv
+import io
+import os
+from typing import Any, Dict, Sequence
+
+import lightning.pytorch as pl
+import torch
+from lightning.pytorch import callbacks
+from loguru import logger
+from torch import multiprocessing, nn
+from typing_extensions import override
+
+from eva.core.callbacks.writers.typings import QUEUE_ITEM
+from eva.core.models.modules.typings import INPUT_BATCH
+from eva.core.utils import multiprocessing as eva_multiprocessing
+
+
+class EmbeddingsWriter(callbacks.BasePredictionWriter):
+    """Callback for writing generated embeddings to disk."""
+
+    def __init__(
+        self,
+        output_dir: str,
+        backbone: nn.Module | None = None,
+        dataloader_idx_map: Dict[int, str] | None = None,
+        group_key: str | None = None,
+        overwrite: bool = True,
+    ) -> None:
+        """Initializes a new EmbeddingsWriter instance.
+
+        This callback writes the embedding files in a separate process to avoid blocking the
+        main process where the model forward pass is executed.
+
+        Args:
+            output_dir: The directory where the embeddings will be saved.
+            backbone: A model to be used as feature extractor. If `None`,
+                it will be expected that the input batch returns the features directly.
+            dataloader_idx_map: A dictionary mapping dataloader indices to their respective
+                names (e.g. train, val, test).
+            group_key: The metadata key to group the embeddings by. If specified, the
+                embedding files will be saved in subdirectories named after the group_key.
+                If specified, the key must be present in the metadata of the input batch.
+            overwrite: Whether to overwrite the output directory. Defaults to True.
+        """
+        super().__init__(write_interval="batch")
+
+        self._output_dir = output_dir
+        self._backbone = backbone
+        self._dataloader_idx_map = dataloader_idx_map or {}
+        self._group_key = group_key
+        self._overwrite = overwrite
+
+        self._write_queue: multiprocessing.Queue
+        self._write_process: eva_multiprocessing.Process
+
+    @override
+    def on_predict_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
+        os.makedirs(self._output_dir, exist_ok=self._overwrite)
+        self._initialize_write_process()
+        self._write_process.start()
+
+        if self._backbone is not None:
+            self._backbone = self._backbone.to(pl_module.device)
+            self._backbone.eval()
+
+    @override
+    def write_on_batch_end(
+        self,
+        trainer: pl.Trainer,
+        pl_module: pl.LightningModule,
+        prediction: Any,
+        batch_indices: Sequence[int],
+        batch: INPUT_BATCH,
+        batch_idx: int,
+        dataloader_idx: int,
+    ) -> None:
+        dataset = trainer.predict_dataloaders[dataloader_idx].dataset  # type: ignore
+        _, targets, metadata = INPUT_BATCH(*batch)
+        split = self._dataloader_idx_map.get(dataloader_idx)
+
+        embeddings = self._get_embeddings(prediction)
+        for local_idx, global_idx in enumerate(batch_indices[: len(embeddings)]):
+            input_name, save_name = self._construct_save_name(
+                dataset.filename(global_idx), metadata, local_idx
+            )
+            embeddings_buffer, target_buffer = io.BytesIO(), io.BytesIO()
+            torch.save(embeddings[local_idx].clone(), embeddings_buffer)
+            torch.save(targets[local_idx], target_buffer)  # type: ignore
+            item = QUEUE_ITEM(embeddings_buffer, target_buffer, input_name, save_name, split)
+            self._write_queue.put(item)
+
+        self._write_process.check_exceptions()
+
+    @override
+    def on_predict_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
+        self._write_queue.put(None)
+        self._write_process.join()
+        logger.info(f"Predictions and manifest saved to {self._output_dir}")
+
+    def _initialize_write_process(self) -> None:
+        self._write_queue = multiprocessing.Queue()
+        self._write_process = eva_multiprocessing.Process(
+            target=_process_write_queue, args=(self._write_queue, self._output_dir, self._overwrite)
+        )
+
+    def _get_embeddings(self, prediction: torch.Tensor) -> torch.Tensor:
+        """Returns the embeddings from predictions."""
+        if self._backbone is None:
+            return prediction
+
+        with torch.no_grad():
+            return self._backbone(prediction)
+
+    def _construct_save_name(self, input_name, metadata, local_idx):
+        group_name = metadata[self._group_key][local_idx] if self._group_key else None
+        save_name = os.path.splitext(input_name)[0] + ".pt"
+        if group_name:
+            save_name = os.path.join(group_name, save_name)
+        return input_name, save_name
+
+
+def _process_write_queue(
+    write_queue: multiprocessing.Queue, output_dir: str, overwrite: bool = False
+) -> None:
+    manifest_file, manifest_writer = _init_manifest(output_dir, overwrite)
+    while True:
+        item = write_queue.get()
+        if item is None:
+            break
+
+        prediction_buffer, target_buffer, input_name, save_name, split = QUEUE_ITEM(*item)
+        _save_prediction(prediction_buffer, save_name, output_dir)
+        _update_manifest(target_buffer, input_name, save_name, split, manifest_writer)
+
+    manifest_file.close()
+
+
+def _save_prediction(prediction_buffer: io.BytesIO, save_name: str, output_dir: str) -> None:
+    save_path = os.path.join(output_dir, save_name)
+    prediction = torch.load(io.BytesIO(prediction_buffer.getbuffer()), map_location="cpu")
+    os.makedirs(os.path.dirname(save_path), exist_ok=True)
+    torch.save(prediction, save_path)
+
+
+def _init_manifest(output_dir: str, overwrite: bool = False) -> tuple[io.TextIOWrapper, Any]:
+    manifest_path = os.path.join(output_dir, "manifest.csv")
+    if os.path.exists(manifest_path) and not overwrite:
+        raise FileExistsError(
+            f"Manifest file already exists at {manifest_path}. This likely means that the "
+            "embeddings have been computed before. Consider using `eva fit` instead "
+            "of `eva predict_fit` or `eva predict`."
+        )
+    manifest_file = open(manifest_path, "w", newline="")
+    manifest_writer = csv.writer(manifest_file)
+    manifest_writer.writerow(["origin", "embeddings", "target", "split"])
+    return manifest_file, manifest_writer
+
+
+def _update_manifest(
+    target_buffer: io.BytesIO,
+    input_name: str,
+    save_name: str,
+    split: str | None,
+    manifest_writer,
+) -> None:
+    target = torch.load(io.BytesIO(target_buffer.getbuffer()), map_location="cpu")
+    manifest_writer.writerow([input_name, save_name, target.item(), split])
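
For orientation, a minimal sketch of how this callback might be attached to a prediction run. The `model` and `datamodule` objects are hypothetical placeholders: any LightningModule whose `predict_step` returns feature tensors, paired with datasets exposing the `filename(index)` method that `write_on_batch_end` relies on.

import lightning.pytorch as pl

from eva.core.callbacks.writers import EmbeddingsWriter

model: pl.LightningModule = ...  # hypothetical predict module
datamodule: pl.LightningDataModule = ...  # hypothetical datamodule with filename-aware datasets

writer = EmbeddingsWriter(
    output_dir="./embeddings",
    dataloader_idx_map={0: "train", 1: "val", 2: "test"},
)
trainer = pl.Trainer(callbacks=[writer])
trainer.predict(model, datamodule=datamodule, return_predictions=False)
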
eva/core/callbacks/writers/typings.py ADDED
@@ -0,0 +1,23 @@
+"""Typing definitions for the writer callback functions."""
+
+import io
+from typing import NamedTuple
+
+
+class QUEUE_ITEM(NamedTuple):
+    """An item of the embeddings writer queue."""
+
+    prediction_buffer: io.BytesIO
+    """IO buffer containing the prediction tensor."""
+
+    target_buffer: io.BytesIO
+    """IO buffer containing the target tensor."""
+
+    input_name: str
+    """Name of the original input file that was used to generate the embedding."""
+
+    save_name: str
+    """Name to store the generated embedding."""
+
+    split: str | None
+    """The dataset split the item belongs to (e.g. train, val, test)."""
eva/core/cli/__init__.py ADDED
@@ -0,0 +1,5 @@
+"""CLI API."""
+
+from eva.core.cli.cli import cli
+
+__all__ = ["cli"]
eva/core/cli/cli.py ADDED
@@ -0,0 +1,19 @@
+"""eva's main cli manager."""
+
+import jsonargparse
+
+from eva.__version__ import __version__
+from eva.core import interface
+from eva.core.cli import logo, setup
+
+
+def cli() -> object:
+    """Main CLI factory."""
+    logo.print_cli_logo()
+    setup.setup()
+    return jsonargparse.CLI(
+        interface.Interface,
+        parser_mode="omegaconf",
+        fail_untyped=False,
+        version=f"version {__version__}",
+    )
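
For context, `jsonargparse.CLI` exposes the public methods of the given class as subcommands, which is how the methods of `Interface` become commands such as `eva fit` and `eva predict`. A standalone sketch of the same pattern (the `Greeter` class is purely illustrative):

import jsonargparse


class Greeter:
    """Toy interface: each public method becomes a CLI subcommand."""

    def hello(self, name: str = "world") -> None:
        print(f"hello, {name}")

    def bye(self, name: str = "world") -> None:
        print(f"bye, {name}")


if __name__ == "__main__":
    jsonargparse.CLI(Greeter)  # e.g. `python greeter.py hello --name eva`
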
eva/core/cli/logo.py ADDED
@@ -0,0 +1,38 @@
+"""CLI logos."""
+
+_EVA_LOGO: str = r"""
+ _____ ____ _
+/ _ \ \ / / _` |
+| __/\ V / (_| |
+\___| \_/ \__,_|
+
+kaiko.ai
+"""
+
+
+ANSI_LOGO_COLOR = "\33[0;34m"
+"""ANSI main logo color."""
+
+ANSI_COLOR_RESET = "\33[0m"
+"""ANSI color reset code."""
+
+
+def _print_logo(
+    logo: str,
+    prefix: str = "",
+    suffix: str = "",
+) -> None:
+    r"""Prints an ASCII art logo in the terminal.
+
+    Args:
+        logo: The desired art logo to print.
+        prefix: Characters to add before the logo.
+        suffix: Characters to add after the logo.
+    """
+    colored_logo = f"{ANSI_LOGO_COLOR}{logo}{ANSI_COLOR_RESET}"
+    print(prefix + colored_logo + suffix)
+
+
+def print_cli_logo() -> None:
+    """Prints the CLI logo."""
+    _print_logo(_EVA_LOGO, suffix="\n")
eva/core/cli/setup.py ADDED
@@ -0,0 +1,89 @@
+"""Operations which are executed when the CLI is triggered."""
+
+import os
+import sys
+import warnings
+
+import jsonargparse
+import yaml
+from lightning_fabric.utilities import seed as pl_seed
+from loguru import logger
+
+from eva.core.utils import workers
+
+
+def _configure_random_seed(seed: int | None = None) -> None:
+    """Sets the global random seed.
+
+    Args:
+        seed: The seed number to use. If `None`, it will read the seed from the
+            `EVA_GLOBAL_SEED` env variable. If `None` and the `EVA_GLOBAL_SEED`
+            env variable is not set, then the seed defaults to `42`. If `None`
+            and `EVA_GLOBAL_SEED` is set to `False`, it will not set the seed.
+    """
+    effective_seed = seed or os.environ.get("EVA_GLOBAL_SEED", default=42)
+    if isinstance(effective_seed, str):
+        effective_seed = yaml.safe_load(effective_seed)
+    if not isinstance(effective_seed, (bool, int)):
+        raise ValueError(
+            f"Invalid 'EVA_GLOBAL_SEED' value '{effective_seed}'. "
+            "It should be an integer or a boolean value."
+        )
+
+    if isinstance(effective_seed, bool) and effective_seed is False:
+        return
+
+    pl_seed.seed_everything(seed=int(effective_seed), workers=True)
+
+
+def _configure_jsonargparse() -> None:
+    """Configures the `jsonargparse` library."""
+    jsonargparse.set_config_read_mode(
+        urls_enabled=True,
+        fsspec_enabled=True,
+    )
+
+
+def _initialize_logger() -> None:
+    """Initializes, manipulates and customizes the logger.
+
+    This customized logger can be used by just importing `loguru`
+    from anywhere as follows:
+    >>> from loguru import logger
+    >>> logger.info(...)
+    """
+    logger.remove()
+    logger.add(
+        sys.stderr,
+        format="<blue>{time:HH:mm:ss}</blue>"
+        " :: <bold><level>{level}</level></bold>"
+        " :: {message}",
+        colorize=True,
+        level="INFO",
+    )
+
+
+def _suppress_warnings() -> None:
+    """Suppresses all warnings from all subprocesses."""
+    if not sys.warnoptions:
+        warnings.simplefilter("ignore")
+        os.environ["PYTHONWARNINGS"] = "ignore"
+
+
+def _enable_mps_fallback() -> None:
+    """Enables the MPS fallback in torch.
+
+    Note that this action has to take place before importing torch.
+    """
+    if os.environ.get("PYTORCH_ENABLE_MPS_FALLBACK") is None:
+        os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+
+
+@workers.main_worker_only
+def setup() -> None:
+    """Sets up the environment before the module is imported."""
+    _configure_random_seed()
+    _configure_jsonargparse()
+    _initialize_logger()
+    _suppress_warnings()
+    _enable_mps_fallback()
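
A quick sketch of how the `EVA_GLOBAL_SEED` value is interpreted: the raw environment string is parsed with `yaml.safe_load`, so integers and booleans behave differently (the values below are illustrative):

import yaml

for raw in ("13", "false", "true"):
    value = yaml.safe_load(raw)  # "13" -> 13, "false" -> False, "true" -> True
    print(f"{raw!r} parses to {value!r} ({type(value).__name__})")

# Per _configure_random_seed: 13 seeds everything with 13; False skips seeding
# entirely; True is cast with int(True) == 1, so the effective seed becomes 1.
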
eva/core/data/__init__.py ADDED
@@ -0,0 +1,14 @@
+"""Data API."""
+
+from eva.core.data.dataloaders import DataLoader
+from eva.core.data.datamodules import DataloadersSchema, DataModule, DatasetsSchema
+from eva.core.data.datasets import Dataset, TorchDataset
+
+__all__ = [
+    "DataLoader",
+    "DataloadersSchema",
+    "DataModule",
+    "DatasetsSchema",
+    "Dataset",
+    "TorchDataset",
+]
eva/core/data/dataloaders/__init__.py ADDED
@@ -0,0 +1,5 @@
+"""Dataloaders API."""
+
+from eva.core.data.dataloaders.dataloader import DataLoader
+
+__all__ = ["DataLoader"]
eva/core/data/dataloaders/dataloader.py ADDED
@@ -0,0 +1,80 @@
+"""Core Dataloader module."""
+
+import dataclasses
+import multiprocessing
+from typing import Callable
+
+from torch.utils.data import dataloader
+
+from eva.core.data import datasets, samplers
+
+
+@dataclasses.dataclass
+class DataLoader:
+    """The `DataLoader` combines a dataset and a sampler.
+
+    It provides an iterable over the given dataset.
+    """
+
+    batch_size: int | None = 1
+    """How many samples per batch to load.
+
+    Set to `None` for iterable datasets where the dataset itself produces batches.
+    """
+
+    shuffle: bool = False
+    """Whether to shuffle the data at every epoch."""
+
+    sampler: samplers.Sampler | None = None
+    """Defines the strategy to draw samples from the dataset.
+
+    Can be any Iterable with `__len__` implemented. If specified, shuffle must
+    not be specified.
+    """
+
+    batch_sampler: samplers.Sampler | None = None
+    """Like `sampler`, but returns a batch of indices at a time.
+
+    Mutually exclusive with `batch_size`, `shuffle`, `sampler` and `drop_last`.
+    """
+
+    num_workers: int = multiprocessing.cpu_count()
+    """How many workers to use for loading the data.
+
+    By default, it will use the number of CPUs available.
+    """
+
+    collate_fn: Callable | None = None
+    """The batching process."""
+
+    pin_memory: bool = True
+    """Will copy Tensors into CUDA pinned memory before returning them."""
+
+    drop_last: bool = False
+    """Drops the last incomplete batch."""
+
+    persistent_workers: bool = True
+    """Will keep the worker processes alive after a dataset has been consumed once."""
+
+    prefetch_factor: int | None = 2
+    """Number of batches loaded in advance by each worker."""
+
+    def __call__(self, dataset: datasets.TorchDataset) -> dataloader.DataLoader:
+        """Returns the dataloader on the provided dataset.
+
+        Args:
+            dataset: The dataset from which to load the data.
+        """
+        return dataloader.DataLoader(
+            dataset=dataset,
+            batch_size=self.batch_size,
+            shuffle=self.shuffle,
+            sampler=self.sampler,
+            batch_sampler=self.batch_sampler,
+            num_workers=self.num_workers,
+            collate_fn=self.collate_fn,
+            pin_memory=self.pin_memory,
+            drop_last=self.drop_last,
+            persistent_workers=self.persistent_workers,
+            prefetch_factor=self.prefetch_factor,
+        )
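
A minimal usage sketch of the dataclass above: configure it once, then call it on a dataset to materialize a `torch.utils.data.DataLoader` (the `TensorDataset` stands in for any map-style dataset; sizes are arbitrary):

import torch
from torch.utils.data import TensorDataset

from eva.core.data.dataloaders import DataLoader

dataset = TensorDataset(torch.rand(100, 8), torch.randint(0, 2, (100,)))

loader_config = DataLoader(batch_size=16, shuffle=True)
torch_loader = loader_config(dataset)

for features, targets in torch_loader:
    print(features.shape, targets.shape)  # torch.Size([16, 8]) torch.Size([16])
    break
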
eva/core/data/datamodules/__init__.py ADDED
@@ -0,0 +1,6 @@
+"""Datamodules API."""
+
+from eva.core.data.datamodules.datamodule import DataModule
+from eva.core.data.datamodules.schemas import DataloadersSchema, DatasetsSchema
+
+__all__ = ["DataModule", "DataloadersSchema", "DatasetsSchema"]
eva/core/data/datamodules/call.py ADDED
@@ -0,0 +1,33 @@
+"""Helper dataset calling methods."""
+
+from typing import Any, Iterable
+
+from eva.core.data import datasets as datasets_lib
+
+
+def call_method_if_exists(objects: Iterable[Any], /, method: str) -> None:
+    """Calls a desired `method` on the datasets if it exists.
+
+    Args:
+        objects: An iterable of objects.
+        method: The dataset method name to call if it exists.
+    """
+    for _object in _recursive_iter(objects):
+        if hasattr(_object, method):
+            fn = getattr(_object, method)
+            fn()
+
+
+def _recursive_iter(objects: Iterable[Any], /) -> Iterable[datasets_lib.TorchDataset]:
+    """Iterates through an iterable of objects and their respective iterable values.
+
+    Args:
+        objects: The objects to iterate from.
+
+    Yields:
+        The individual objects.
+    """
+    for _object in objects:
+        if not isinstance(_object, list):
+            _object = [_object]
+        yield from _object
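
To make the traversal concrete, a small self-contained sketch (the `_Dummy` class and its `prepare_data` hook are purely illustrative): nested lists are flattened one level, and objects lacking the method are skipped.

from eva.core.data.datamodules.call import call_method_if_exists


class _Dummy:
    """Stand-in for a dataset exposing an optional hook."""

    def __init__(self, name: str) -> None:
        self.name = name

    def prepare_data(self) -> None:
        print(f"prepare_data called on {self.name}")


objects = [_Dummy("train"), [_Dummy("val"), _Dummy("test")], object()]
call_method_if_exists(objects, method="prepare_data")
# -> prepare_data called on train, val and test; the bare object() is skipped
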