kaiko-eva 0.0.0.dev6 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kaiko-eva might be problematic.
- eva/.DS_Store +0 -0
- eva/__init__.py +33 -0
- eva/__main__.py +18 -0
- eva/__version__.py +25 -0
- eva/core/__init__.py +19 -0
- eva/core/callbacks/__init__.py +5 -0
- eva/core/callbacks/writers/__init__.py +5 -0
- eva/core/callbacks/writers/embeddings.py +169 -0
- eva/core/callbacks/writers/typings.py +23 -0
- eva/core/cli/__init__.py +5 -0
- eva/core/cli/cli.py +19 -0
- eva/core/cli/logo.py +38 -0
- eva/core/cli/setup.py +89 -0
- eva/core/data/__init__.py +14 -0
- eva/core/data/dataloaders/__init__.py +5 -0
- eva/core/data/dataloaders/dataloader.py +80 -0
- eva/core/data/datamodules/__init__.py +6 -0
- eva/core/data/datamodules/call.py +33 -0
- eva/core/data/datamodules/datamodule.py +108 -0
- eva/core/data/datamodules/schemas.py +62 -0
- eva/core/data/datasets/__init__.py +7 -0
- eva/core/data/datasets/base.py +53 -0
- eva/core/data/datasets/classification/__init__.py +5 -0
- eva/core/data/datasets/classification/embeddings.py +154 -0
- eva/core/data/datasets/dataset.py +6 -0
- eva/core/data/samplers/__init__.py +5 -0
- eva/core/data/samplers/sampler.py +6 -0
- eva/core/data/transforms/__init__.py +5 -0
- eva/core/data/transforms/dtype/__init__.py +5 -0
- eva/core/data/transforms/dtype/array.py +28 -0
- eva/core/interface/__init__.py +5 -0
- eva/core/interface/interface.py +79 -0
- eva/core/metrics/__init__.py +17 -0
- eva/core/metrics/average_loss.py +47 -0
- eva/core/metrics/binary_balanced_accuracy.py +22 -0
- eva/core/metrics/defaults/__init__.py +6 -0
- eva/core/metrics/defaults/classification/__init__.py +6 -0
- eva/core/metrics/defaults/classification/binary.py +76 -0
- eva/core/metrics/defaults/classification/multiclass.py +80 -0
- eva/core/metrics/structs/__init__.py +9 -0
- eva/core/metrics/structs/collection.py +6 -0
- eva/core/metrics/structs/metric.py +6 -0
- eva/core/metrics/structs/module.py +115 -0
- eva/core/metrics/structs/schemas.py +47 -0
- eva/core/metrics/structs/typings.py +15 -0
- eva/core/models/__init__.py +13 -0
- eva/core/models/modules/__init__.py +7 -0
- eva/core/models/modules/head.py +113 -0
- eva/core/models/modules/inference.py +37 -0
- eva/core/models/modules/module.py +190 -0
- eva/core/models/modules/typings.py +23 -0
- eva/core/models/modules/utils/__init__.py +6 -0
- eva/core/models/modules/utils/batch_postprocess.py +57 -0
- eva/core/models/modules/utils/grad.py +23 -0
- eva/core/models/networks/__init__.py +6 -0
- eva/core/models/networks/_utils.py +25 -0
- eva/core/models/networks/mlp.py +69 -0
- eva/core/models/networks/transforms/__init__.py +5 -0
- eva/core/models/networks/transforms/extract_cls_features.py +25 -0
- eva/core/models/networks/wrappers/__init__.py +8 -0
- eva/core/models/networks/wrappers/base.py +47 -0
- eva/core/models/networks/wrappers/from_function.py +58 -0
- eva/core/models/networks/wrappers/huggingface.py +37 -0
- eva/core/models/networks/wrappers/onnx.py +47 -0
- eva/core/trainers/__init__.py +6 -0
- eva/core/trainers/_logging.py +81 -0
- eva/core/trainers/_recorder.py +149 -0
- eva/core/trainers/_utils.py +12 -0
- eva/core/trainers/functional.py +113 -0
- eva/core/trainers/trainer.py +97 -0
- eva/core/utils/__init__.py +1 -0
- eva/core/utils/io/__init__.py +5 -0
- eva/core/utils/io/dataframe.py +21 -0
- eva/core/utils/multiprocessing.py +44 -0
- eva/core/utils/workers.py +21 -0
- eva/vision/__init__.py +14 -0
- eva/vision/data/__init__.py +5 -0
- eva/vision/data/datasets/__init__.py +22 -0
- eva/vision/data/datasets/_utils.py +50 -0
- eva/vision/data/datasets/_validators.py +44 -0
- eva/vision/data/datasets/classification/__init__.py +15 -0
- eva/vision/data/datasets/classification/bach.py +174 -0
- eva/vision/data/datasets/classification/base.py +103 -0
- eva/vision/data/datasets/classification/crc.py +176 -0
- eva/vision/data/datasets/classification/mhist.py +106 -0
- eva/vision/data/datasets/classification/patch_camelyon.py +203 -0
- eva/vision/data/datasets/classification/total_segmentator.py +212 -0
- eva/vision/data/datasets/segmentation/__init__.py +6 -0
- eva/vision/data/datasets/segmentation/base.py +112 -0
- eva/vision/data/datasets/segmentation/total_segmentator.py +212 -0
- eva/vision/data/datasets/structs.py +17 -0
- eva/vision/data/datasets/vision.py +43 -0
- eva/vision/data/transforms/__init__.py +5 -0
- eva/vision/data/transforms/common/__init__.py +5 -0
- eva/vision/data/transforms/common/resize_and_crop.py +44 -0
- eva/vision/models/__init__.py +5 -0
- eva/vision/models/networks/__init__.py +6 -0
- eva/vision/models/networks/abmil.py +176 -0
- eva/vision/models/networks/postprocesses/__init__.py +5 -0
- eva/vision/models/networks/postprocesses/cls.py +25 -0
- eva/vision/utils/__init__.py +5 -0
- eva/vision/utils/io/__init__.py +12 -0
- eva/vision/utils/io/_utils.py +29 -0
- eva/vision/utils/io/image.py +54 -0
- eva/vision/utils/io/nifti.py +50 -0
- eva/vision/utils/io/text.py +18 -0
- kaiko_eva-0.0.0.dev6.dist-info/METADATA +393 -0
- kaiko_eva-0.0.0.dev6.dist-info/RECORD +111 -0
- kaiko_eva-0.0.0.dev6.dist-info/WHEEL +4 -0
- kaiko_eva-0.0.0.dev6.dist-info/entry_points.txt +4 -0
- kaiko_eva-0.0.0.dev6.dist-info/licenses/LICENSE +201 -0
eva/.DS_Store
ADDED
Binary file

eva/__init__.py
ADDED
@@ -0,0 +1,33 @@
+"""eva public API."""
+
+from eva.core import (
+    DataLoader,
+    DataloadersSchema,
+    DataModule,
+    DatasetsSchema,
+    HeadModule,
+    InferenceModule,
+    Interface,
+    Trainer,
+    callbacks,
+    data,
+    metrics,
+    models,
+)
+from eva.core.data import datasets
+
+__all__ = [
+    "DataLoader",
+    "DataloadersSchema",
+    "DataModule",
+    "DatasetsSchema",
+    "HeadModule",
+    "InferenceModule",
+    "Interface",
+    "Trainer",
+    "callbacks",
+    "data",
+    "metrics",
+    "models",
+    "datasets",
+]

eva/__main__.py
ADDED
@@ -0,0 +1,18 @@
+"""Main entry-point module."""
+
+from eva.core.cli import cli
+
+
+def main() -> None:
+    """Main entry-point.
+
+    The CLI fetches the input arguments from `sys.argv`.
+
+    For usage information, execute:
+    $ eva --help
+    """
+    cli()
+
+
+if __name__ == "__main__":
+    main()

eva/__version__.py
ADDED
@@ -0,0 +1,25 @@
+"""Fetches the version of the library."""
+
+from importlib import metadata
+
+
+def _fetch_version(package_name: str) -> str:
+    """Fetches the version of an installed package.
+
+    If it fails to do so, it returns a "*", indicating
+    that the package has been installed as editable.
+
+    Args:
+        package_name: The name of the package to fetch
+            the version of.
+
+    Returns:
+        A string representing the version of the library.
+    """
+    try:
+        return metadata.version(package_name)
+    except metadata.PackageNotFoundError:
+        return "*"
+
+
+__version__ = _fetch_version("eva")

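For reference, a minimal sketch of the fallback behavior documented above; `some-not-installed-package` is a placeholder distribution name assumed not to be present in the environment:

    from eva.__version__ import __version__, _fetch_version

    print(__version__)                                    # resolved version string, or "*" if unresolved
    print(_fetch_version("some-not-installed-package"))   # "*" -- PackageNotFoundError is swallowed
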
eva/core/__init__.py
ADDED
@@ -0,0 +1,19 @@
+"""eva core API."""
+
+from eva.core.cli import cli
+from eva.core.data import DataLoader, DataloadersSchema, DataModule, DatasetsSchema
+from eva.core.interface import Interface
+from eva.core.models import HeadModule, InferenceModule
+from eva.core.trainers import Trainer
+
+__all__ = [
+    "cli",
+    "DataLoader",
+    "DataloadersSchema",
+    "DataModule",
+    "DatasetsSchema",
+    "Interface",
+    "HeadModule",
+    "InferenceModule",
+    "Trainer",
+]

eva/core/callbacks/writers/embeddings.py
ADDED
@@ -0,0 +1,169 @@
+"""Embeddings writer."""
+
+import csv
+import io
+import os
+from typing import Any, Dict, Sequence
+
+import lightning.pytorch as pl
+import torch
+from lightning.pytorch import callbacks
+from loguru import logger
+from torch import multiprocessing, nn
+from typing_extensions import override
+
+from eva.core.callbacks.writers.typings import QUEUE_ITEM
+from eva.core.models.modules.typings import INPUT_BATCH
+from eva.core.utils import multiprocessing as eva_multiprocessing
+
+
+class EmbeddingsWriter(callbacks.BasePredictionWriter):
+    """Callback for writing generated embeddings to disk."""
+
+    def __init__(
+        self,
+        output_dir: str,
+        backbone: nn.Module | None = None,
+        dataloader_idx_map: Dict[int, str] | None = None,
+        group_key: str | None = None,
+        overwrite: bool = True,
+    ) -> None:
+        """Initializes a new EmbeddingsWriter instance.
+
+        This callback writes the embedding files in a separate process to avoid blocking the
+        main process where the model forward pass is executed.
+
+        Args:
+            output_dir: The directory where the embeddings will be saved.
+            backbone: A model to be used as feature extractor. If `None`,
+                it will be expected that the input batch returns the features directly.
+            dataloader_idx_map: A dictionary mapping dataloader indices to their respective
+                names (e.g. train, val, test).
+            group_key: The metadata key to group the embeddings by. If specified, the
+                embedding files will be saved in subdirectories named after the group_key.
+                If specified, the key must be present in the metadata of the input batch.
+            overwrite: Whether to overwrite the output directory. Defaults to True.
+        """
+        super().__init__(write_interval="batch")
+
+        self._output_dir = output_dir
+        self._backbone = backbone
+        self._dataloader_idx_map = dataloader_idx_map or {}
+        self._group_key = group_key
+        self._overwrite = overwrite
+
+        self._write_queue: multiprocessing.Queue
+        self._write_process: eva_multiprocessing.Process
+
+    @override
+    def on_predict_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
+        os.makedirs(self._output_dir, exist_ok=self._overwrite)
+        self._initialize_write_process()
+        self._write_process.start()
+
+        if self._backbone is not None:
+            self._backbone = self._backbone.to(pl_module.device)
+            self._backbone.eval()
+
+    @override
+    def write_on_batch_end(
+        self,
+        trainer: pl.Trainer,
+        pl_module: pl.LightningModule,
+        prediction: Any,
+        batch_indices: Sequence[int],
+        batch: INPUT_BATCH,
+        batch_idx: int,
+        dataloader_idx: int,
+    ) -> None:
+        dataset = trainer.predict_dataloaders[dataloader_idx].dataset  # type: ignore
+        _, targets, metadata = INPUT_BATCH(*batch)
+        split = self._dataloader_idx_map.get(dataloader_idx)
+
+        embeddings = self._get_embeddings(prediction)
+        for local_idx, global_idx in enumerate(batch_indices[: len(embeddings)]):
+            input_name, save_name = self._construct_save_name(
+                dataset.filename(global_idx), metadata, local_idx
+            )
+            embeddings_buffer, target_buffer = io.BytesIO(), io.BytesIO()
+            torch.save(embeddings[local_idx].clone(), embeddings_buffer)
+            torch.save(targets[local_idx], target_buffer)  # type: ignore
+            item = QUEUE_ITEM(embeddings_buffer, target_buffer, input_name, save_name, split)
+            self._write_queue.put(item)
+
+        self._write_process.check_exceptions()
+
+    @override
+    def on_predict_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
+        self._write_queue.put(None)
+        self._write_process.join()
+        logger.info(f"Predictions and manifest saved to {self._output_dir}")
+
+    def _initialize_write_process(self) -> None:
+        self._write_queue = multiprocessing.Queue()
+        self._write_process = eva_multiprocessing.Process(
+            target=_process_write_queue, args=(self._write_queue, self._output_dir, self._overwrite)
+        )
+
+    def _get_embeddings(self, prediction: torch.Tensor) -> torch.Tensor:
+        """Returns the embeddings from predictions."""
+        if self._backbone is None:
+            return prediction
+
+        with torch.no_grad():
+            return self._backbone(prediction)
+
+    def _construct_save_name(self, input_name, metadata, local_idx):
+        group_name = metadata[self._group_key][local_idx] if self._group_key else None
+        save_name = os.path.splitext(input_name)[0] + ".pt"
+        if group_name:
+            save_name = os.path.join(group_name, save_name)
+        return input_name, save_name
+
+
+def _process_write_queue(
+    write_queue: multiprocessing.Queue, output_dir: str, overwrite: bool = False
+) -> None:
+    manifest_file, manifest_writer = _init_manifest(output_dir, overwrite)
+    while True:
+        item = write_queue.get()
+        if item is None:
+            break
+
+        prediction_buffer, target_buffer, input_name, save_name, split = QUEUE_ITEM(*item)
+        _save_prediction(prediction_buffer, save_name, output_dir)
+        _update_manifest(target_buffer, input_name, save_name, split, manifest_writer)
+
+    manifest_file.close()
+
+
+def _save_prediction(prediction_buffer: io.BytesIO, save_name: str, output_dir: str) -> None:
+    save_path = os.path.join(output_dir, save_name)
+    prediction = torch.load(io.BytesIO(prediction_buffer.getbuffer()), map_location="cpu")
+    os.makedirs(os.path.dirname(save_path), exist_ok=True)
+    torch.save(prediction, save_path)
+
+
+def _init_manifest(output_dir: str, overwrite: bool = False) -> tuple[io.TextIOWrapper, Any]:
+    manifest_path = os.path.join(output_dir, "manifest.csv")
+    if os.path.exists(manifest_path) and not overwrite:
+        raise FileExistsError(
+            f"Manifest file already exists at {manifest_path}. This likely means that the "
+            "embeddings have been computed before. Consider using `eva fit` instead "
+            "of `eva predict_fit` or `eva predict`."
+        )
+    manifest_file = open(manifest_path, "w", newline="")
+    manifest_writer = csv.writer(manifest_file)
+    manifest_writer.writerow(["origin", "embeddings", "target", "split"])
+    return manifest_file, manifest_writer


+def _update_manifest(
+    target_buffer: io.BytesIO,
+    input_name: str,
+    save_name: str,
+    split: str | None,
+    manifest_writer,
+) -> None:
+    target = torch.load(io.BytesIO(target_buffer.getbuffer()), map_location="cpu")
+    manifest_writer.writerow([input_name, save_name, target.item(), split])

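For orientation, a minimal sketch of how this callback could be registered with a `lightning.pytorch.Trainer`; the backbone, output path and dataloader mapping below are illustrative placeholders, and the model/datamodule wiring is assumed to be configured elsewhere:

    import lightning.pytorch as pl
    from torch import nn

    from eva.core.callbacks.writers.embeddings import EmbeddingsWriter

    # Hypothetical feature extractor; in practice this would be a pretrained backbone.
    backbone = nn.Flatten()

    writer = EmbeddingsWriter(
        output_dir="./embeddings",
        backbone=backbone,
        dataloader_idx_map={0: "train", 1: "val", 2: "test"},
        overwrite=True,
    )

    # Registered like any other Lightning callback; during `trainer.predict(...)`
    # the writer streams each embedding to disk and builds `manifest.csv`.
    trainer = pl.Trainer(callbacks=[writer], accelerator="cpu", devices=1)
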
eva/core/callbacks/writers/typings.py
ADDED
@@ -0,0 +1,23 @@
+"""Typing definitions for the writer callback functions."""
+
+import io
+from typing import NamedTuple
+
+
+class QUEUE_ITEM(NamedTuple):
+    """The default input batch data scheme."""
+
+    prediction_buffer: io.BytesIO
+    """IO buffer containing the prediction tensor"""
+
+    target_buffer: io.BytesIO
+    """IO buffer containing the target tensor"""
+
+    input_name: str
+    """Name of the original input file that was used to generate the embedding."""
+
+    save_name: str
+    """Name to store the generated embedding"""
+
+    split: str | None
+    """The dataset split the item belongs to (e.g. train, val, test)."""

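As a small illustration of the tuple's shape (all values below are made up), this mirrors how the embeddings writer enqueues one item per sample:

    import io

    import torch

    from eva.core.callbacks.writers.typings import QUEUE_ITEM

    prediction_buffer, target_buffer = io.BytesIO(), io.BytesIO()
    torch.save(torch.rand(384), prediction_buffer)  # a dummy embedding
    torch.save(torch.tensor(1), target_buffer)      # a dummy target

    item = QUEUE_ITEM(
        prediction_buffer=prediction_buffer,
        target_buffer=target_buffer,
        input_name="patches/image_000.png",
        save_name="patches/image_000.pt",
        split="train",
    )
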
eva/core/cli/__init__.py
ADDED
eva/core/cli/cli.py
ADDED
@@ -0,0 +1,19 @@
+"""eva's main cli manager."""
+
+import jsonargparse
+
+from eva.__version__ import __version__
+from eva.core import interface
+from eva.core.cli import logo, setup
+
+
+def cli() -> object:
+    """Main CLI factory."""
+    logo.print_cli_logo()
+    setup.setup()
+    return jsonargparse.CLI(
+        interface.Interface,
+        parser_mode="omegaconf",
+        fail_untyped=False,
+        version=f"version {__version__}",
+    )

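Because `jsonargparse.CLI` reads its arguments from `sys.argv`, the same factory can be exercised programmatically; a minimal sketch (note that `--help` prints the usage text derived from `Interface` and then raises `SystemExit`):

    import sys

    from eva.core.cli import cli

    # Equivalent to running `eva --help` from a shell.
    sys.argv = ["eva", "--help"]
    cli()
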
eva/core/cli/logo.py
ADDED
@@ -0,0 +1,38 @@
+"""CLI logos."""
+
+_EVA_LOGO: str = r"""
+  _____ ____ _
+ / _ \ \ / / _` |
+|  __/\ V / (_| |
+ \___| \_/ \__,_|
+
+kaiko.ai
+"""
+
+
+ANSI_LOGO_COLOR = "\33[0;34m"
+"""ANSI main logo color."""
+
+ANSI_COLOR_RESET = "\33[0m"
+"""ANSI color reset code."""
+
+
+def _print_logo(
+    logo: str,
+    prefix: str = "",
+    suffix: str = "",
+) -> None:
+    r"""Prints an ASCII terminal art logo in terminal.
+
+    Args:
+        logo: The desired art logo to print.
+        prefix: Characters to add before the logo.
+        suffix: Characters to add after the logo.
+    """
+    colored_logo = f"{ANSI_LOGO_COLOR}{logo}{ANSI_COLOR_RESET}"
+    print(prefix + colored_logo + suffix)
+
+
+def print_cli_logo() -> None:
+    """Prints the CLI logo."""
+    _print_logo(_EVA_LOGO, suffix="\n")

eva/core/cli/setup.py
ADDED
@@ -0,0 +1,89 @@
+"""Operations which are executed when the CLI is triggered."""
+
+import os
+import sys
+import warnings
+
+import jsonargparse
+import yaml
+from lightning_fabric.utilities import seed as pl_seed
+from loguru import logger
+
+from eva.core.utils import workers
+
+
+def _configure_random_seed(seed: int | None = None) -> None:
+    """Sets the global random seed.
+
+    Args:
+        seed: The seed number to use. If `None`, it will read the seed from
+            `EVA_GLOBAL_SEED` env variable. If `None` and the `EVA_GLOBAL_SEED`
+            env variable is not set, then the seed defaults to `42`. If `None`
+            and the `EVA_GLOBAL_SEED` is set to `False`, it will not set the seed.
+    """
+    effective_seed = seed or os.environ.get("EVA_GLOBAL_SEED", default=42)
+    if isinstance(effective_seed, str):
+        effective_seed = yaml.safe_load(effective_seed)
+    if not isinstance(effective_seed, (bool, int)):
+        raise ValueError(
+            f"Invalid 'EVA_GLOBAL_SEED' value '{effective_seed}'. "
+            "It should be an integer or a boolean value."
+        )
+
+    if isinstance(effective_seed, bool) and effective_seed is False:
+        return
+
+    pl_seed.seed_everything(seed=int(effective_seed), workers=True)
+
+
+def _configure_jsonargparse() -> None:
+    """Configures the `jsonargparse` library."""
+    jsonargparse.set_config_read_mode(
+        urls_enabled=True,
+        fsspec_enabled=True,
+    )
+
+
+def _initialize_logger() -> None:
+    """Initializes, manipulates and customizes the logger.
+
+    This customizable logger can be used by just importing `loguru`
+    from everywhere as follows:
+    >>> from loguru import logger
+    >>> logger.info(...)
+    """
+    logger.remove()
+    logger.add(
+        sys.stderr,
+        format="<blue>{time:HH:mm:ss}</blue>"
+        " :: <bold><level>{level}</level></bold>"
+        " :: {message}",
+        colorize=True,
+        level="INFO",
+    )
+
+
+def _suppress_warnings() -> None:
+    """Suppress all warnings from all subprocesses."""
+    if not sys.warnoptions:
+        warnings.simplefilter("ignore")
+        os.environ["PYTHONWARNINGS"] = "ignore"
+
+
+def _enable_mps_fallback() -> None:
+    """It enables the MPS fallback in torch.
+
+    Note that this action has to take place before importing torch.
+    """
+    if os.environ.get("PYTORCH_ENABLE_MPS_FALLBACK") is None:
+        os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+
+
+@workers.main_worker_only
+def setup() -> None:
+    """Sets up the environment before the module is imported."""
+    _configure_random_seed()
+    _configure_jsonargparse()
+    _initialize_logger()
+    _suppress_warnings()
+    _enable_mps_fallback()

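The seed handling above can be exercised directly; a short sketch covering the three documented `EVA_GLOBAL_SEED` cases (the concrete values are illustrative):

    import os

    from eva.core.cli.setup import _configure_random_seed

    # An explicit argument takes precedence over the environment variable.
    _configure_random_seed(seed=123)

    # Values read from the environment are parsed with `yaml.safe_load`,
    # so "7" becomes the integer 7 ...
    os.environ["EVA_GLOBAL_SEED"] = "7"
    _configure_random_seed()

    # ... and "false" becomes the boolean False, which skips seeding entirely.
    os.environ["EVA_GLOBAL_SEED"] = "false"
    _configure_random_seed()
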
eva/core/data/__init__.py
ADDED
@@ -0,0 +1,14 @@
+"""Data API."""
+
+from eva.core.data.dataloaders import DataLoader
+from eva.core.data.datamodules import DataloadersSchema, DataModule, DatasetsSchema
+from eva.core.data.datasets import Dataset, TorchDataset
+
+__all__ = [
+    "DataLoader",
+    "DataloadersSchema",
+    "DataModule",
+    "DatasetsSchema",
+    "Dataset",
+    "TorchDataset",
+]

eva/core/data/dataloaders/dataloader.py
ADDED
@@ -0,0 +1,80 @@
+"""Core Dataloader module."""
+
+import dataclasses
+import multiprocessing
+from typing import Callable
+
+from torch.utils.data import dataloader
+
+from eva.core.data import datasets, samplers
+
+
+@dataclasses.dataclass
+class DataLoader:
+    """The `DataLoader` combines a dataset and a sampler.
+
+    It provides an iterable over the given dataset.
+    """
+
+    batch_size: int | None = 1
+    """How many samples per batch to load.
+
+    Set to `None` for iterable dataset where dataset produces batches.
+    """
+
+    shuffle: bool = False
+    """Whether to shuffle the data at every epoch."""
+
+    sampler: samplers.Sampler | None = None
+    """Defines the strategy to draw samples from the dataset.
+
+    Can be any Iterable with `__len__` implemented. If specified, shuffle must
+    not be specified.
+    """
+
+    batch_sampler: samplers.Sampler | None = None
+    """Like `sampler`, but returns a batch of indices at a time.
+
+    Mutually exclusive with `batch_size`, `shuffle`, `sampler` and `drop_last`.
+    """
+
+    num_workers: int = multiprocessing.cpu_count()
+    """How many workers to use for loading the data.
+
+    By default, it will use the number of CPUs available.
+    """
+
+    collate_fn: Callable | None = None
+    """The batching process."""
+
+    pin_memory: bool = True
+    """Will copy Tensors into CUDA pinned memory before returning them."""
+
+    drop_last: bool = False
+    """Drops the last incomplete batch."""
+
+    persistent_workers: bool = True
+    """Will keep the worker processes after a dataset has been consumed once."""
+
+    prefetch_factor: int | None = 2
+    """Number of batches loaded in advance by each worker."""
+
+    def __call__(self, dataset: datasets.TorchDataset) -> dataloader.DataLoader:
+        """Returns the dataloader on the provided dataset.
+
+        Args:
+            dataset: dataset from which to load the data.
+        """
+        return dataloader.DataLoader(
+            dataset=dataset,
+            batch_size=self.batch_size,
+            shuffle=self.shuffle,
+            sampler=self.sampler,
+            batch_sampler=self.batch_sampler,
+            num_workers=self.num_workers,
+            collate_fn=self.collate_fn,
+            pin_memory=self.pin_memory,
+            drop_last=self.drop_last,
+            persistent_workers=self.persistent_workers,
+            prefetch_factor=self.prefetch_factor,
+        )

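The dataclass is a declarative wrapper: it holds the loader settings and, when called with a dataset, builds the actual `torch.utils.data.DataLoader`. A minimal sketch with a toy `TensorDataset` standing in for an eva dataset; the multiprocessing-related defaults are relaxed so the demo runs single-process:

    import torch
    from torch.utils.data import TensorDataset

    from eva.core.data.dataloaders import DataLoader

    # Toy dataset: 10 samples with 4 features each and a binary target.
    dataset = TensorDataset(torch.rand(10, 4), torch.randint(0, 2, (10,)))

    loader_config = DataLoader(
        batch_size=4,
        shuffle=True,
        num_workers=0,
        pin_memory=False,
        persistent_workers=False,
        prefetch_factor=None,
    )
    torch_loader = loader_config(dataset)  # a regular torch.utils.data.DataLoader

    for features, targets in torch_loader:
        print(features.shape, targets.shape)
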
eva/core/data/datamodules/call.py
ADDED
@@ -0,0 +1,33 @@
+"""Helper dataset calling methods."""
+
+from typing import Any, Iterable
+
+from eva.core.data import datasets as datasets_lib
+
+
+def call_method_if_exists(objects: Iterable[Any], /, method: str) -> None:
+    """Calls a desired `method` from the datasets if it exists.
+
+    Args:
+        objects: An iterable of objects.
+        method: The dataset method name to call if it exists.
+    """
+    for _object in _recursive_iter(objects):
+        if hasattr(_object, method):
+            fn = getattr(_object, method)
+            fn()
+
+
+def _recursive_iter(objects: Iterable[Any], /) -> Iterable[datasets_lib.TorchDataset]:
+    """Iterates through an iterable of objects and their respective iterable values.
+
+    Args:
+        objects: The objects to iterate from.
+
+    Yields:
+        The individual object class.
+    """
+    for _object in objects:
+        if not isinstance(_object, list):
+            _object = [_object]
+        yield from _object

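These helpers walk a possibly nested collection of objects and invoke a method on each one that defines it; a sketch using plain Python stand-ins, where the `setup` method name is only an example:

    from eva.core.data.datamodules.call import call_method_if_exists


    class DummyDataset:
        """Stand-in object exposing the method we want to trigger."""

        def __init__(self, name: str) -> None:
            self.name = name

        def setup(self) -> None:
            print(f"setup called on {self.name}")


    # Nested one level deep, mirroring a datamodule holding lists of datasets per split.
    objects = [DummyDataset("train"), [DummyDataset("val"), DummyDataset("test")]]
    call_method_if_exists(objects, "setup")
    # Objects without a `setup` attribute would simply be skipped.
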