kaiko-eva 0.3.3__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kaiko-eva might be problematic. Click here for more details.
- eva/core/callbacks/config.py +15 -6
- eva/core/callbacks/writers/embeddings/base.py +44 -10
- eva/core/cli/setup.py +1 -1
- eva/core/data/dataloaders/__init__.py +1 -2
- eva/core/data/samplers/classification/balanced.py +24 -12
- eva/core/data/samplers/random.py +17 -10
- eva/core/interface/interface.py +21 -0
- eva/core/loggers/utils/wandb.py +4 -1
- eva/core/models/modules/module.py +2 -2
- eva/core/models/wrappers/base.py +2 -2
- eva/core/models/wrappers/from_function.py +3 -3
- eva/core/models/wrappers/from_torchhub.py +9 -7
- eva/core/models/wrappers/huggingface.py +4 -5
- eva/core/models/wrappers/onnx.py +5 -5
- eva/core/trainers/trainer.py +13 -1
- eva/core/utils/__init__.py +2 -1
- eva/core/utils/distributed.py +12 -0
- eva/core/utils/paths.py +14 -0
- eva/core/utils/requirements.py +52 -6
- eva/language/__init__.py +2 -1
- eva/language/callbacks/__init__.py +5 -0
- eva/language/callbacks/writers/__init__.py +5 -0
- eva/language/callbacks/writers/prediction.py +201 -0
- eva/language/data/dataloaders/__init__.py +5 -0
- eva/language/data/dataloaders/collate_fn/__init__.py +5 -0
- eva/language/data/dataloaders/collate_fn/text.py +57 -0
- eva/language/data/datasets/__init__.py +3 -1
- eva/language/data/datasets/{language.py → base.py} +1 -1
- eva/language/data/datasets/classification/base.py +3 -43
- eva/language/data/datasets/classification/pubmedqa.py +36 -4
- eva/language/data/datasets/prediction.py +151 -0
- eva/language/data/datasets/schemas.py +18 -0
- eva/language/data/datasets/text.py +92 -0
- eva/language/data/datasets/typings.py +39 -0
- eva/language/data/messages.py +60 -0
- eva/language/models/__init__.py +15 -11
- eva/language/models/modules/__init__.py +2 -2
- eva/language/models/modules/language.py +94 -0
- eva/language/models/networks/__init__.py +12 -0
- eva/language/models/networks/alibaba.py +26 -0
- eva/language/models/networks/api/__init__.py +11 -0
- eva/language/models/networks/api/anthropic.py +34 -0
- eva/language/models/networks/registry.py +5 -0
- eva/language/models/typings.py +56 -0
- eva/language/models/wrappers/__init__.py +13 -5
- eva/language/models/wrappers/base.py +47 -0
- eva/language/models/wrappers/from_registry.py +54 -0
- eva/language/models/wrappers/huggingface.py +57 -11
- eva/language/models/wrappers/litellm.py +91 -46
- eva/language/models/wrappers/vllm.py +37 -13
- eva/language/utils/__init__.py +2 -1
- eva/language/utils/str_to_int_tensor.py +20 -12
- eva/language/utils/text/__init__.py +5 -0
- eva/language/utils/text/messages.py +113 -0
- eva/multimodal/__init__.py +6 -0
- eva/multimodal/callbacks/__init__.py +5 -0
- eva/multimodal/callbacks/writers/__init__.py +5 -0
- eva/multimodal/callbacks/writers/prediction.py +39 -0
- eva/multimodal/data/__init__.py +5 -0
- eva/multimodal/data/dataloaders/__init__.py +5 -0
- eva/multimodal/data/dataloaders/collate_fn/__init__.py +5 -0
- eva/multimodal/data/dataloaders/collate_fn/text_image.py +28 -0
- eva/multimodal/data/datasets/__init__.py +6 -0
- eva/multimodal/data/datasets/base.py +13 -0
- eva/multimodal/data/datasets/multiple_choice/__init__.py +5 -0
- eva/multimodal/data/datasets/multiple_choice/patch_camelyon.py +80 -0
- eva/multimodal/data/datasets/schemas.py +14 -0
- eva/multimodal/data/datasets/text_image.py +77 -0
- eva/multimodal/data/datasets/typings.py +27 -0
- eva/multimodal/models/__init__.py +8 -0
- eva/multimodal/models/modules/__init__.py +5 -0
- eva/multimodal/models/modules/vision_language.py +56 -0
- eva/multimodal/models/networks/__init__.py +14 -0
- eva/multimodal/models/networks/alibaba.py +40 -0
- eva/multimodal/models/networks/api/__init__.py +11 -0
- eva/multimodal/models/networks/api/anthropic.py +34 -0
- eva/multimodal/models/networks/others.py +48 -0
- eva/multimodal/models/networks/registry.py +5 -0
- eva/multimodal/models/typings.py +27 -0
- eva/multimodal/models/wrappers/__init__.py +13 -0
- eva/multimodal/models/wrappers/base.py +48 -0
- eva/multimodal/models/wrappers/from_registry.py +54 -0
- eva/multimodal/models/wrappers/huggingface.py +193 -0
- eva/multimodal/models/wrappers/litellm.py +58 -0
- eva/multimodal/utils/__init__.py +1 -0
- eva/multimodal/utils/batch/__init__.py +5 -0
- eva/multimodal/utils/batch/unpack.py +11 -0
- eva/multimodal/utils/image/__init__.py +5 -0
- eva/multimodal/utils/image/encode.py +28 -0
- eva/multimodal/utils/text/__init__.py +1 -0
- eva/multimodal/utils/text/messages.py +79 -0
- eva/vision/data/datasets/classification/breakhis.py +5 -8
- eva/vision/data/datasets/classification/panda.py +12 -5
- eva/vision/data/datasets/classification/patch_camelyon.py +8 -6
- eva/vision/data/datasets/segmentation/btcv.py +1 -1
- eva/vision/data/datasets/segmentation/consep.py +1 -1
- eva/vision/data/datasets/segmentation/lits17.py +1 -1
- eva/vision/data/datasets/segmentation/monusac.py +15 -6
- eva/vision/data/datasets/segmentation/msd_task7_pancreas.py +1 -1
- eva/vision/data/transforms/__init__.py +2 -1
- eva/vision/data/transforms/base/__init__.py +2 -1
- eva/vision/data/transforms/base/monai.py +2 -2
- eva/vision/data/transforms/base/torchvision.py +33 -0
- eva/vision/data/transforms/common/squeeze.py +6 -3
- eva/vision/data/transforms/croppad/crop_foreground.py +8 -7
- eva/vision/data/transforms/croppad/rand_crop_by_label_classes.py +6 -5
- eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py +6 -5
- eva/vision/data/transforms/croppad/rand_spatial_crop.py +8 -7
- eva/vision/data/transforms/croppad/spatial_pad.py +6 -6
- eva/vision/data/transforms/intensity/rand_scale_intensity.py +3 -3
- eva/vision/data/transforms/intensity/rand_shift_intensity.py +3 -3
- eva/vision/data/transforms/intensity/scale_intensity_ranged.py +5 -5
- eva/vision/data/transforms/spatial/__init__.py +2 -1
- eva/vision/data/transforms/spatial/flip.py +8 -7
- eva/vision/data/transforms/spatial/functional/__init__.py +5 -0
- eva/vision/data/transforms/spatial/functional/resize.py +26 -0
- eva/vision/data/transforms/spatial/resize.py +63 -0
- eva/vision/data/transforms/spatial/rotate.py +8 -7
- eva/vision/data/transforms/spatial/spacing.py +7 -6
- eva/vision/data/transforms/utility/ensure_channel_first.py +6 -6
- eva/vision/models/networks/backbones/universal/vit.py +24 -0
- eva/vision/models/wrappers/from_registry.py +6 -5
- eva/vision/models/wrappers/from_timm.py +6 -4
- {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/METADATA +17 -3
- {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/RECORD +128 -66
- eva/core/data/dataloaders/collate_fn/__init__.py +0 -5
- eva/core/data/dataloaders/collate_fn/collate.py +0 -24
- eva/language/models/modules/text.py +0 -85
- {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/WHEEL +0 -0
- {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/entry_points.txt +0 -0
- {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/licenses/LICENSE +0 -0
eva/core/callbacks/config.py
CHANGED
|
@@ -9,11 +9,13 @@ from typing import Any, Dict, List
|
|
|
9
9
|
import lightning.pytorch as pl
|
|
10
10
|
import yaml
|
|
11
11
|
from lightning_fabric.utilities import cloud_io
|
|
12
|
+
from loguru import logger
|
|
12
13
|
from loguru import logger as cli_logger
|
|
13
14
|
from omegaconf import OmegaConf
|
|
14
15
|
from typing_extensions import TypeGuard, override
|
|
15
16
|
|
|
16
17
|
from eva.core import loggers
|
|
18
|
+
from eva.core.utils import distributed as dist_utils
|
|
17
19
|
|
|
18
20
|
|
|
19
21
|
class ConfigurationLogger(pl.Callback):
|
|
@@ -39,8 +41,14 @@ class ConfigurationLogger(pl.Callback):
|
|
|
39
41
|
pl_module: pl.LightningModule,
|
|
40
42
|
stage: str | None = None,
|
|
41
43
|
) -> None:
|
|
42
|
-
|
|
43
|
-
|
|
44
|
+
if dist_utils.is_distributed():
|
|
45
|
+
logger.info("ConfigurationLogger skipped as not supported in distributed mode.")
|
|
46
|
+
# TODO: Enabling leads to deadlocks in DDP mode, but I could not yet figure out why.
|
|
47
|
+
return
|
|
48
|
+
|
|
49
|
+
if not trainer.is_global_zero or not _logdir_exists(
|
|
50
|
+
log_dir := trainer.log_dir, self._verbose
|
|
51
|
+
):
|
|
44
52
|
return
|
|
45
53
|
|
|
46
54
|
configuration = _load_submitted_config()
|
|
@@ -51,6 +59,10 @@ class ConfigurationLogger(pl.Callback):
|
|
|
51
59
|
|
|
52
60
|
save_as = os.path.join(log_dir, self._save_as)
|
|
53
61
|
fs = cloud_io.get_filesystem(log_dir)
|
|
62
|
+
|
|
63
|
+
if not fs.exists(log_dir):
|
|
64
|
+
fs.makedirs(log_dir)
|
|
65
|
+
|
|
54
66
|
with fs.open(save_as, "w") as output_file:
|
|
55
67
|
yaml.dump(configuration, output_file, sort_keys=False)
|
|
56
68
|
|
|
@@ -126,7 +138,7 @@ def _type_resolver(mapping: Dict[str, Any]) -> Dict[str, Any]:
|
|
|
126
138
|
for key, value in mapping.items():
|
|
127
139
|
if isinstance(value, dict):
|
|
128
140
|
formatted_value = _type_resolver(value)
|
|
129
|
-
elif isinstance(value, list) and isinstance(value[0], dict):
|
|
141
|
+
elif isinstance(value, list) and value and isinstance(value[0], dict):
|
|
130
142
|
formatted_value = [_type_resolver(subvalue) for subvalue in value]
|
|
131
143
|
else:
|
|
132
144
|
try:
|
|
@@ -134,10 +146,7 @@ def _type_resolver(mapping: Dict[str, Any]) -> Dict[str, Any]:
|
|
|
134
146
|
formatted_value = (
|
|
135
147
|
value if isinstance(parsed_value, BuiltinFunctionType) else parsed_value
|
|
136
148
|
)
|
|
137
|
-
|
|
138
149
|
except Exception:
|
|
139
150
|
formatted_value = value
|
|
140
|
-
|
|
141
151
|
mapping[key] = formatted_value
|
|
142
|
-
|
|
143
152
|
return mapping
|
|
@@ -7,6 +7,7 @@ from typing import Any, Dict, List, Sequence
|
|
|
7
7
|
|
|
8
8
|
import lightning.pytorch as pl
|
|
9
9
|
import torch
|
|
10
|
+
import torch.distributed as dist
|
|
10
11
|
from lightning.pytorch import callbacks
|
|
11
12
|
from loguru import logger
|
|
12
13
|
from torch import multiprocessing, nn
|
|
@@ -15,6 +16,7 @@ from typing_extensions import override
|
|
|
15
16
|
from eva.core import utils
|
|
16
17
|
from eva.core.callbacks.writers.embeddings.typings import QUEUE_ITEM
|
|
17
18
|
from eva.core.models.modules.typings import INPUT_BATCH
|
|
19
|
+
from eva.core.utils import distributed as dist_utils
|
|
18
20
|
from eva.core.utils import multiprocessing as eva_multiprocessing
|
|
19
21
|
|
|
20
22
|
|
|
@@ -58,8 +60,9 @@ class EmbeddingsWriter(callbacks.BasePredictionWriter, abc.ABC):
|
|
|
58
60
|
self._save_every_n = save_every_n
|
|
59
61
|
self._metadata_keys = metadata_keys or []
|
|
60
62
|
|
|
61
|
-
self._write_queue: multiprocessing.Queue
|
|
62
|
-
self._write_process: eva_multiprocessing.Process
|
|
63
|
+
self._write_queue: multiprocessing.Queue | None = None
|
|
64
|
+
self._write_process: eva_multiprocessing.Process | None = None
|
|
65
|
+
self._is_rank_zero: bool = False
|
|
63
66
|
|
|
64
67
|
@staticmethod
|
|
65
68
|
@abc.abstractmethod
|
|
@@ -78,9 +81,13 @@ class EmbeddingsWriter(callbacks.BasePredictionWriter, abc.ABC):
|
|
|
78
81
|
|
|
79
82
|
@override
|
|
80
83
|
def on_predict_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
|
|
81
|
-
self.
|
|
82
|
-
self.
|
|
83
|
-
|
|
84
|
+
self._is_rank_zero = trainer.is_global_zero
|
|
85
|
+
if self._is_rank_zero:
|
|
86
|
+
self._check_if_exists()
|
|
87
|
+
self._initialize_write_process()
|
|
88
|
+
if self._write_process is None or self._write_queue is None:
|
|
89
|
+
raise RuntimeError("Failed to initialize embedding writer process.")
|
|
90
|
+
self._write_process.start()
|
|
84
91
|
|
|
85
92
|
if self._backbone is not None:
|
|
86
93
|
self._backbone = self._backbone.to(pl_module.device)
|
|
@@ -106,6 +113,7 @@ class EmbeddingsWriter(callbacks.BasePredictionWriter, abc.ABC):
|
|
|
106
113
|
with torch.no_grad():
|
|
107
114
|
embeddings = self._get_embeddings(prediction)
|
|
108
115
|
|
|
116
|
+
queue_items: List[QUEUE_ITEM] = []
|
|
109
117
|
for local_idx, global_idx in enumerate(batch_indices[: len(embeddings)]):
|
|
110
118
|
data_name = dataset.filename(global_idx)
|
|
111
119
|
save_name = os.path.splitext(data_name)[0] + ".pt"
|
|
@@ -121,15 +129,41 @@ class EmbeddingsWriter(callbacks.BasePredictionWriter, abc.ABC):
|
|
|
121
129
|
split=split,
|
|
122
130
|
metadata=item_metadata,
|
|
123
131
|
)
|
|
124
|
-
|
|
132
|
+
queue_items.append(item)
|
|
125
133
|
|
|
126
|
-
self.
|
|
134
|
+
gathered_items = self._gather_queue_items(queue_items)
|
|
135
|
+
if self._is_rank_zero:
|
|
136
|
+
for item in gathered_items:
|
|
137
|
+
self._write_queue.put(item) # type: ignore
|
|
138
|
+
self._write_process.check_exceptions() # type: ignore
|
|
127
139
|
|
|
128
140
|
@override
|
|
129
141
|
def on_predict_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
142
|
+
if dist_utils.is_distributed():
|
|
143
|
+
dist.barrier()
|
|
144
|
+
|
|
145
|
+
if self._is_rank_zero and self._write_queue is not None:
|
|
146
|
+
self._write_queue.put(None)
|
|
147
|
+
if self._write_process is not None:
|
|
148
|
+
self._write_process.join()
|
|
149
|
+
logger.info(f"Predictions and manifest saved to {self._output_dir}")
|
|
150
|
+
|
|
151
|
+
def _gather_queue_items(self, items: List[QUEUE_ITEM]) -> List[QUEUE_ITEM]:
|
|
152
|
+
"""Gather queue items across distributed ranks, returning only on rank zero."""
|
|
153
|
+
if not dist_utils.is_distributed():
|
|
154
|
+
return items
|
|
155
|
+
|
|
156
|
+
world_size = dist.get_world_size()
|
|
157
|
+
object_list: List[List[QUEUE_ITEM]] = [[] for _ in range(world_size)]
|
|
158
|
+
dist.all_gather_object(object_list, items)
|
|
159
|
+
|
|
160
|
+
if self._is_rank_zero:
|
|
161
|
+
gathered: List[QUEUE_ITEM] = []
|
|
162
|
+
for rank_items in object_list:
|
|
163
|
+
gathered.extend(rank_items)
|
|
164
|
+
return gathered
|
|
165
|
+
|
|
166
|
+
return []
|
|
133
167
|
|
|
134
168
|
def _initialize_write_process(self) -> None:
|
|
135
169
|
self._write_queue = multiprocessing.Queue()
|
eva/core/cli/setup.py
CHANGED
|
@@ -1,9 +1,10 @@
|
|
|
1
1
|
"""Random class sampler for data loading."""
|
|
2
2
|
|
|
3
3
|
from collections import defaultdict
|
|
4
|
-
from typing import Dict, Iterator, List
|
|
4
|
+
from typing import Dict, Iterator, List, Union
|
|
5
5
|
|
|
6
6
|
import numpy as np
|
|
7
|
+
import torch
|
|
7
8
|
from loguru import logger
|
|
8
9
|
from typing_extensions import override
|
|
9
10
|
|
|
@@ -32,7 +33,7 @@ class BalancedSampler(SamplerWithDataSource[int]):
|
|
|
32
33
|
"""
|
|
33
34
|
self._num_samples = num_samples
|
|
34
35
|
self._replacement = replacement
|
|
35
|
-
self._class_indices: Dict[int, List[int]] = defaultdict(list)
|
|
36
|
+
self._class_indices: Dict[Union[int, str], List[int]] = defaultdict(list)
|
|
36
37
|
self._random_generator = np.random.default_rng(seed)
|
|
37
38
|
self._indices: List[int] = []
|
|
38
39
|
|
|
@@ -62,20 +63,31 @@ class BalancedSampler(SamplerWithDataSource[int]):
|
|
|
62
63
|
super().set_dataset(data_source)
|
|
63
64
|
self._make_indices()
|
|
64
65
|
|
|
66
|
+
def _get_class_idx(self, idx):
|
|
67
|
+
"""Load and validate the class index for a given sample index."""
|
|
68
|
+
if hasattr(self.data_source, "load_target"):
|
|
69
|
+
target = self.data_source.load_target(idx) # type: ignore
|
|
70
|
+
else:
|
|
71
|
+
_, target, _ = DataSample(*self.data_source[idx])
|
|
72
|
+
|
|
73
|
+
if target is None:
|
|
74
|
+
raise ValueError("The dataset must return non-empty targets.")
|
|
75
|
+
|
|
76
|
+
if isinstance(target, str):
|
|
77
|
+
return target
|
|
78
|
+
|
|
79
|
+
if isinstance(target, torch.Tensor):
|
|
80
|
+
if target.numel() != 1:
|
|
81
|
+
raise ValueError("The dataset must return a single & scalar target.")
|
|
82
|
+
return int(target.item())
|
|
83
|
+
|
|
84
|
+
raise ValueError("Unsupported target type. Expected str or tensor-like object.")
|
|
85
|
+
|
|
65
86
|
def _make_indices(self):
|
|
66
87
|
"""Samples the indices for each class in the dataset."""
|
|
67
88
|
self._class_indices.clear()
|
|
68
89
|
for idx in tqdm(range(len(self.data_source)), desc="Fetching class indices for sampler"):
|
|
69
|
-
|
|
70
|
-
target = self.data_source.load_target(idx) # type: ignore
|
|
71
|
-
else:
|
|
72
|
-
_, target, _ = DataSample(*self.data_source[idx])
|
|
73
|
-
if target is None:
|
|
74
|
-
raise ValueError("The dataset must return non-empty targets.")
|
|
75
|
-
if target.numel() != 1:
|
|
76
|
-
raise ValueError("The dataset must return a single & scalar target.")
|
|
77
|
-
|
|
78
|
-
class_idx = int(target.item())
|
|
90
|
+
class_idx = self._get_class_idx(idx)
|
|
79
91
|
self._class_indices[class_idx].append(idx)
|
|
80
92
|
|
|
81
93
|
if not self._replacement:
|
eva/core/data/samplers/random.py
CHANGED
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
from typing import Optional
|
|
4
4
|
|
|
5
|
+
import torch
|
|
5
6
|
from torch.utils import data
|
|
6
7
|
from typing_extensions import override
|
|
7
8
|
|
|
@@ -10,30 +11,36 @@ from eva.core.data.samplers.sampler import SamplerWithDataSource
|
|
|
10
11
|
|
|
11
12
|
|
|
12
13
|
class RandomSampler(data.RandomSampler, SamplerWithDataSource[int]):
|
|
13
|
-
"""Samples elements randomly."""
|
|
14
|
+
"""Samples elements randomly from a MapDataset."""
|
|
14
15
|
|
|
15
16
|
data_source: datasets.MapDataset # type: ignore
|
|
16
17
|
|
|
17
18
|
def __init__(
|
|
18
|
-
self,
|
|
19
|
+
self,
|
|
20
|
+
replacement: bool = False,
|
|
21
|
+
num_samples: Optional[int] = None,
|
|
22
|
+
seed: Optional[int] = None,
|
|
19
23
|
) -> None:
|
|
20
|
-
"""
|
|
24
|
+
"""Initialize the random sampler.
|
|
21
25
|
|
|
22
26
|
Args:
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
generator: Generator used in sampling.
|
|
27
|
+
replacement: Samples are drawn on-demand with replacement if ``True``, default=``False``
|
|
28
|
+
num_samples: Number of samples to draw, default=``len(dataset)``.
|
|
29
|
+
seed: Optional seed for the random number generator.
|
|
27
30
|
"""
|
|
28
31
|
self.replacement = replacement
|
|
29
32
|
self._num_samples = num_samples
|
|
30
|
-
self.
|
|
33
|
+
self._generator = None
|
|
34
|
+
|
|
35
|
+
if seed is not None:
|
|
36
|
+
self._generator = torch.Generator()
|
|
37
|
+
self._generator.manual_seed(seed)
|
|
31
38
|
|
|
32
39
|
@override
|
|
33
40
|
def set_dataset(self, data_source: datasets.MapDataset) -> None:
|
|
34
41
|
super().__init__(
|
|
35
42
|
data_source,
|
|
36
43
|
replacement=self.replacement,
|
|
37
|
-
num_samples=self.
|
|
38
|
-
generator=self.
|
|
44
|
+
num_samples=self._num_samples,
|
|
45
|
+
generator=self._generator,
|
|
39
46
|
)
|
eva/core/interface/interface.py
CHANGED
|
@@ -132,3 +132,24 @@ class Interface:
|
|
|
132
132
|
n_runs=trainer.n_runs,
|
|
133
133
|
verbose=trainer.n_runs > 1,
|
|
134
134
|
)
|
|
135
|
+
|
|
136
|
+
def validate_test(
|
|
137
|
+
self,
|
|
138
|
+
trainer: eva_trainer.Trainer,
|
|
139
|
+
model: modules.ModelModule,
|
|
140
|
+
data: datamodules.DataModule,
|
|
141
|
+
) -> None:
|
|
142
|
+
"""Runs validation & test stages."""
|
|
143
|
+
if getattr(data.datasets, "val", None) is None:
|
|
144
|
+
raise ValueError("The provided data module does not contain a validation dataset.")
|
|
145
|
+
if getattr(data.datasets, "test", None) is None:
|
|
146
|
+
raise ValueError("The provided data module does not contain a test dataset.")
|
|
147
|
+
|
|
148
|
+
eva_trainer.run_evaluation_session(
|
|
149
|
+
base_trainer=trainer,
|
|
150
|
+
base_model=model,
|
|
151
|
+
datamodule=data,
|
|
152
|
+
stages=["validate", "test"],
|
|
153
|
+
n_runs=trainer.n_runs,
|
|
154
|
+
verbose=trainer.n_runs > 1,
|
|
155
|
+
)
|
eva/core/loggers/utils/wandb.py
CHANGED
|
@@ -5,6 +5,8 @@ from typing import Any, Dict
|
|
|
5
5
|
|
|
6
6
|
from loguru import logger
|
|
7
7
|
|
|
8
|
+
from eva.core.utils import requirements
|
|
9
|
+
|
|
8
10
|
|
|
9
11
|
def rename_active_run(name: str) -> None:
|
|
10
12
|
"""Renames the current run."""
|
|
@@ -12,7 +14,8 @@ def rename_active_run(name: str) -> None:
|
|
|
12
14
|
|
|
13
15
|
if wandb.run:
|
|
14
16
|
wandb.run.name = name
|
|
15
|
-
wandb.
|
|
17
|
+
if requirements.below("wandb", "0.21.0"):
|
|
18
|
+
wandb.run.save()
|
|
16
19
|
else:
|
|
17
20
|
logger.warning("No active wandb run found that could be renamed.")
|
|
18
21
|
|
|
@@ -33,8 +33,8 @@ class ModelModule(pl.LightningModule):
|
|
|
33
33
|
super().__init__()
|
|
34
34
|
|
|
35
35
|
self._metrics = metrics or self.default_metrics
|
|
36
|
-
self._postprocess = postprocess or self.default_postprocess
|
|
37
36
|
|
|
37
|
+
self.postprocess = postprocess or self.default_postprocess
|
|
38
38
|
self.metrics = metrics_lib.MetricModule.from_schema(self._metrics)
|
|
39
39
|
|
|
40
40
|
@property
|
|
@@ -133,7 +133,7 @@ class ModelModule(pl.LightningModule):
|
|
|
133
133
|
Returns:
|
|
134
134
|
The updated outputs.
|
|
135
135
|
"""
|
|
136
|
-
self.
|
|
136
|
+
self.postprocess(outputs)
|
|
137
137
|
return memory.recursive_detach(outputs, to_cpu=self.metrics_device.type == "cpu")
|
|
138
138
|
|
|
139
139
|
def _forward_and_log_metrics(
|
eva/core/models/wrappers/base.py
CHANGED
|
@@ -25,7 +25,7 @@ class BaseModel(nn.Module, Generic[InputType, OutputType]):
|
|
|
25
25
|
|
|
26
26
|
self._output_transforms = transforms
|
|
27
27
|
|
|
28
|
-
self.
|
|
28
|
+
self.model: Callable[..., OutputType] | nn.Module
|
|
29
29
|
|
|
30
30
|
@override
|
|
31
31
|
def forward(self, tensor: InputType) -> OutputType:
|
|
@@ -43,7 +43,7 @@ class BaseModel(nn.Module, Generic[InputType, OutputType]):
|
|
|
43
43
|
Args:
|
|
44
44
|
tensor: The input tensor to the model.
|
|
45
45
|
"""
|
|
46
|
-
return self.
|
|
46
|
+
return self.model(tensor)
|
|
47
47
|
|
|
48
48
|
def _apply_transforms(self, tensor: OutputType) -> OutputType:
|
|
49
49
|
if self._output_transforms is not None:
|
|
@@ -41,12 +41,12 @@ class ModelFromFunction(base.BaseModel[torch.Tensor, torch.Tensor]):
|
|
|
41
41
|
self._arguments = arguments
|
|
42
42
|
self._checkpoint_path = checkpoint_path
|
|
43
43
|
|
|
44
|
-
self.load_model()
|
|
44
|
+
self.model = self.load_model()
|
|
45
45
|
|
|
46
46
|
@override
|
|
47
|
-
def load_model(self) ->
|
|
47
|
+
def load_model(self) -> nn.Module:
|
|
48
48
|
class_path = jsonargparse.class_from_function(self._path, func_return=nn.Module)
|
|
49
49
|
model = class_path(**self._arguments or {})
|
|
50
50
|
if self._checkpoint_path is not None:
|
|
51
51
|
_utils.load_model_weights(model, self._checkpoint_path)
|
|
52
|
-
|
|
52
|
+
return model
|
|
@@ -52,12 +52,12 @@ class TorchHubModel(base.BaseModel[torch.Tensor, torch.Tensor]):
|
|
|
52
52
|
self._trust_repo = trust_repo
|
|
53
53
|
self._model_kwargs = model_kwargs or {}
|
|
54
54
|
|
|
55
|
-
self.load_model()
|
|
55
|
+
self.model = self.load_model()
|
|
56
56
|
|
|
57
57
|
@override
|
|
58
|
-
def load_model(self) ->
|
|
58
|
+
def load_model(self) -> nn.Module:
|
|
59
59
|
"""Builds and loads the torch.hub model."""
|
|
60
|
-
|
|
60
|
+
model: nn.Module = torch.hub.load(
|
|
61
61
|
repo_or_dir=self._repo_or_dir,
|
|
62
62
|
model=self._model_name,
|
|
63
63
|
trust_repo=self._trust_repo,
|
|
@@ -66,21 +66,23 @@ class TorchHubModel(base.BaseModel[torch.Tensor, torch.Tensor]):
|
|
|
66
66
|
) # type: ignore
|
|
67
67
|
|
|
68
68
|
if self._checkpoint_path:
|
|
69
|
-
_utils.load_model_weights(
|
|
69
|
+
_utils.load_model_weights(model, self._checkpoint_path)
|
|
70
70
|
|
|
71
71
|
TorchHubModel.__name__ = self._model_name
|
|
72
72
|
|
|
73
|
+
return model
|
|
74
|
+
|
|
73
75
|
@override
|
|
74
76
|
def model_forward(self, tensor: torch.Tensor) -> torch.Tensor | List[torch.Tensor]:
|
|
75
77
|
if self._out_indices is not None:
|
|
76
|
-
if not hasattr(self.
|
|
78
|
+
if not hasattr(self.model, "get_intermediate_layers"):
|
|
77
79
|
raise ValueError(
|
|
78
80
|
"Only models with `get_intermediate_layers` are supported "
|
|
79
81
|
"when using `out_indices`."
|
|
80
82
|
)
|
|
81
83
|
|
|
82
84
|
return list(
|
|
83
|
-
self.
|
|
85
|
+
self.model.get_intermediate_layers( # type: ignore
|
|
84
86
|
tensor,
|
|
85
87
|
self._out_indices,
|
|
86
88
|
reshape=True,
|
|
@@ -89,4 +91,4 @@ class TorchHubModel(base.BaseModel[torch.Tensor, torch.Tensor]):
|
|
|
89
91
|
)
|
|
90
92
|
)
|
|
91
93
|
|
|
92
|
-
return self.
|
|
94
|
+
return self.model(tensor)
|
|
@@ -4,6 +4,7 @@ from typing import Any, Callable, Dict
|
|
|
4
4
|
|
|
5
5
|
import torch
|
|
6
6
|
import transformers
|
|
7
|
+
from torch import nn
|
|
7
8
|
from typing_extensions import override
|
|
8
9
|
|
|
9
10
|
from eva.core.models.wrappers import base
|
|
@@ -33,12 +34,10 @@ class HuggingFaceModel(base.BaseModel[torch.Tensor, torch.Tensor]):
|
|
|
33
34
|
self._model_name_or_path = model_name_or_path
|
|
34
35
|
self._model_kwargs = model_kwargs or {}
|
|
35
36
|
|
|
36
|
-
self.load_model()
|
|
37
|
+
self.model = self.load_model()
|
|
37
38
|
|
|
38
39
|
@override
|
|
39
|
-
def load_model(self) ->
|
|
40
|
+
def load_model(self) -> nn.Module:
|
|
40
41
|
# Use safetensors to avoid torch.load security vulnerability
|
|
41
42
|
model_kwargs = {"use_safetensors": True, **self._model_kwargs}
|
|
42
|
-
|
|
43
|
-
self._model_name_or_path, **model_kwargs
|
|
44
|
-
)
|
|
43
|
+
return transformers.AutoModel.from_pretrained(self._model_name_or_path, **model_kwargs)
|
eva/core/models/wrappers/onnx.py
CHANGED
|
@@ -30,21 +30,21 @@ class ONNXModel(base.BaseModel[torch.Tensor, torch.Tensor]):
|
|
|
30
30
|
self._path = path
|
|
31
31
|
self._device = device
|
|
32
32
|
|
|
33
|
-
self.load_model()
|
|
33
|
+
self.model = self.load_model()
|
|
34
34
|
|
|
35
35
|
@override
|
|
36
36
|
def load_model(self) -> Any:
|
|
37
37
|
if self._device == "cuda" and not torch.cuda.is_available():
|
|
38
38
|
raise ValueError("Device is set to 'cuda', but CUDA is not available.")
|
|
39
39
|
provider = "CUDAExecutionProvider" if self._device == "cuda" else "CPUExecutionProvider"
|
|
40
|
-
|
|
40
|
+
return ort.InferenceSession(self._path, providers=[provider]) # type: ignore
|
|
41
41
|
|
|
42
42
|
@override
|
|
43
43
|
def model_forward(self, tensor: torch.Tensor) -> torch.Tensor:
|
|
44
44
|
# TODO: Use IO binding to avoid copying the tensor to CPU.
|
|
45
45
|
# https://onnxruntime.ai/docs/api/python/api_summary.html#data-on-device
|
|
46
|
-
if not isinstance(self.
|
|
46
|
+
if not isinstance(self.model, ort.InferenceSession):
|
|
47
47
|
raise ValueError("Model is not loaded.")
|
|
48
|
-
inputs = {self.
|
|
49
|
-
outputs = self.
|
|
48
|
+
inputs = {self.model.get_inputs()[0].name: tensor.detach().cpu().numpy()}
|
|
49
|
+
outputs = self.model.run(None, inputs)[0]
|
|
50
50
|
return torch.from_numpy(outputs).float().to(tensor.device)
|
eva/core/trainers/trainer.py
CHANGED
|
@@ -8,6 +8,7 @@ from lightning.pytorch import loggers as pl_loggers
|
|
|
8
8
|
from lightning.pytorch import trainer as pl_trainer
|
|
9
9
|
from lightning.pytorch.utilities import argparse
|
|
10
10
|
from lightning_fabric.utilities import cloud_io
|
|
11
|
+
from lightning_utilities.core.rank_zero import rank_zero_only
|
|
11
12
|
from typing_extensions import override
|
|
12
13
|
|
|
13
14
|
from eva.core import loggers as eva_loggers
|
|
@@ -30,6 +31,8 @@ class Trainer(pl_trainer.Trainer):
|
|
|
30
31
|
default_root_dir: str = "logs",
|
|
31
32
|
n_runs: int = 1,
|
|
32
33
|
checkpoint_type: Literal["best", "last"] = "best",
|
|
34
|
+
accelerator: str = "auto",
|
|
35
|
+
devices: int = 1,
|
|
33
36
|
**kwargs: Any,
|
|
34
37
|
) -> None:
|
|
35
38
|
"""Initializes the trainer.
|
|
@@ -44,9 +47,17 @@ class Trainer(pl_trainer.Trainer):
|
|
|
44
47
|
n_runs: The amount of runs (fit and evaluate) to perform in an evaluation session.
|
|
45
48
|
checkpoint_type: Wether to load the "best" or "last" checkpoint saved by the checkpoint
|
|
46
49
|
callback for evaluations on validation & test sets.
|
|
50
|
+
accelerator: The accelerator to use for training (e.g. "cpu", "gpu").
|
|
51
|
+
devices: The number of devices (GPUs) to use for training.
|
|
47
52
|
kwargs: Kew-word arguments of ::class::`lightning.pytorch.Trainer`.
|
|
48
53
|
"""
|
|
49
|
-
super().__init__(
|
|
54
|
+
super().__init__(
|
|
55
|
+
*args,
|
|
56
|
+
default_root_dir=default_root_dir,
|
|
57
|
+
accelerator=accelerator,
|
|
58
|
+
devices=devices,
|
|
59
|
+
**kwargs,
|
|
60
|
+
)
|
|
50
61
|
|
|
51
62
|
self.checkpoint_type = checkpoint_type
|
|
52
63
|
self.n_runs = n_runs
|
|
@@ -66,6 +77,7 @@ class Trainer(pl_trainer.Trainer):
|
|
|
66
77
|
def log_dir(self) -> str | None:
|
|
67
78
|
return self.strategy.broadcast(self._log_dir)
|
|
68
79
|
|
|
80
|
+
@rank_zero_only
|
|
69
81
|
def init_logger_run(self, run_id: int | None) -> None:
|
|
70
82
|
"""Setup the loggers & log directories when starting a new run.
|
|
71
83
|
|
eva/core/utils/__init__.py
CHANGED
|
@@ -3,5 +3,6 @@
|
|
|
3
3
|
from eva.core.utils.clone import clone
|
|
4
4
|
from eva.core.utils.memory import to_cpu
|
|
5
5
|
from eva.core.utils.operations import numeric_sort
|
|
6
|
+
from eva.core.utils.paths import home_dir
|
|
6
7
|
|
|
7
|
-
__all__ = ["clone", "to_cpu", "numeric_sort"]
|
|
8
|
+
__all__ = ["clone", "to_cpu", "numeric_sort", "home_dir"]
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
"""Utility functions for distributed training."""
|
|
2
|
+
|
|
3
|
+
import torch.distributed as dist
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def is_distributed() -> bool:
|
|
7
|
+
"""Check if current environment is distributed.
|
|
8
|
+
|
|
9
|
+
Returns:
|
|
10
|
+
bool: True if distributed environment (e.g. multiple gpu processes).
|
|
11
|
+
"""
|
|
12
|
+
return dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1
|
eva/core/utils/paths.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
"""Utility functions for handling paths."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def home_dir():
|
|
7
|
+
"""Get eva's home directory for caching."""
|
|
8
|
+
torch_home = os.path.expanduser(
|
|
9
|
+
os.getenv(
|
|
10
|
+
"EVA_HOME",
|
|
11
|
+
os.path.join("~/.cache", "eva"),
|
|
12
|
+
)
|
|
13
|
+
)
|
|
14
|
+
return torch_home
|
eva/core/utils/requirements.py
CHANGED
|
@@ -3,10 +3,58 @@
|
|
|
3
3
|
import importlib
|
|
4
4
|
from typing import Dict
|
|
5
5
|
|
|
6
|
-
|
|
6
|
+
import packaging.version
|
|
7
7
|
|
|
8
8
|
|
|
9
|
-
def
|
|
9
|
+
def fetch_version(name: str) -> str | None:
|
|
10
|
+
"""Fetch the installed version of a package.
|
|
11
|
+
|
|
12
|
+
Args:
|
|
13
|
+
name: The name of the package.
|
|
14
|
+
|
|
15
|
+
Returns:
|
|
16
|
+
A string representing the installed version of the package, or None if not found.
|
|
17
|
+
"""
|
|
18
|
+
try:
|
|
19
|
+
module = importlib.import_module(name)
|
|
20
|
+
return getattr(module, "__version__", None)
|
|
21
|
+
except ImportError:
|
|
22
|
+
return None
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def below(name: str, version: str) -> bool:
|
|
26
|
+
"""Check if the installed version of a package is below a certain version.
|
|
27
|
+
|
|
28
|
+
Args:
|
|
29
|
+
name: The name of the package.
|
|
30
|
+
version: The version to compare against.
|
|
31
|
+
|
|
32
|
+
Returns:
|
|
33
|
+
True if the installed version is below the specified version, False otherwise.
|
|
34
|
+
"""
|
|
35
|
+
actual = fetch_version(name)
|
|
36
|
+
if actual:
|
|
37
|
+
return packaging.version.parse(actual) < packaging.version.parse(version)
|
|
38
|
+
return False
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def above_or_equal(name: str, version: str) -> bool:
|
|
42
|
+
"""Check if the installed version of a package is above a certain version.
|
|
43
|
+
|
|
44
|
+
Args:
|
|
45
|
+
name: The name of the package.
|
|
46
|
+
version: The version to compare against.
|
|
47
|
+
|
|
48
|
+
Returns:
|
|
49
|
+
True if the installed version is above the specified version, False otherwise.
|
|
50
|
+
"""
|
|
51
|
+
actual = fetch_version(name)
|
|
52
|
+
if actual:
|
|
53
|
+
return packaging.version.parse(actual) >= packaging.version.parse(version)
|
|
54
|
+
return False
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def check_min_versions(requirements: Dict[str, str]) -> None:
|
|
10
58
|
"""Check installed package versions against requirements dict.
|
|
11
59
|
|
|
12
60
|
Args:
|
|
@@ -17,10 +65,8 @@ def check_dependencies(requirements: Dict[str, str]) -> None:
|
|
|
17
65
|
ImportError: If any package does not meet the minimum required version.
|
|
18
66
|
"""
|
|
19
67
|
for package, min_version in requirements.items():
|
|
20
|
-
|
|
21
|
-
actual = getattr(module, "__version__", None)
|
|
22
|
-
if actual and not (version.parse(actual) >= version.parse(min_version)):
|
|
68
|
+
if below(package, min_version):
|
|
23
69
|
raise ImportError(
|
|
24
|
-
f"Package '{package}' version {
|
|
70
|
+
f"Package '{package}' version {fetch_version(package)} does not meet "
|
|
25
71
|
f"the minimum required version {min_version}."
|
|
26
72
|
)
|
eva/language/__init__.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
"""eva language API."""
|
|
2
2
|
|
|
3
3
|
try:
|
|
4
|
+
from eva.language import models
|
|
4
5
|
from eva.language.data import datasets
|
|
5
6
|
except ImportError as e:
|
|
6
7
|
msg = (
|
|
@@ -10,4 +11,4 @@ except ImportError as e:
|
|
|
10
11
|
)
|
|
11
12
|
raise ImportError(str(e) + "\n\n" + msg) from e
|
|
12
13
|
|
|
13
|
-
__all__ = ["datasets"]
|
|
14
|
+
__all__ = ["models", "datasets"]
|