kaiko-eva 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eva/core/callbacks/config.py +11 -6
- eva/core/callbacks/writers/embeddings/base.py +44 -10
- eva/core/data/samplers/classification/balanced.py +24 -12
- eva/core/loggers/utils/wandb.py +4 -1
- eva/core/trainers/trainer.py +11 -1
- eva/core/utils/__init__.py +2 -1
- eva/core/utils/distributed.py +12 -0
- eva/core/utils/paths.py +14 -0
- eva/core/utils/requirements.py +52 -6
- eva/language/callbacks/writers/prediction.py +44 -19
- eva/language/data/datasets/classification/pubmedqa.py +1 -1
- eva/language/models/modules/language.py +7 -6
- eva/language/models/typings.py +19 -2
- eva/language/models/wrappers/base.py +4 -4
- eva/language/models/wrappers/huggingface.py +14 -4
- eva/language/models/wrappers/litellm.py +14 -4
- eva/multimodal/models/modules/vision_language.py +6 -5
- eva/multimodal/models/networks/alibaba.py +1 -0
- eva/multimodal/models/networks/others.py +2 -1
- eva/multimodal/models/wrappers/base.py +4 -3
- eva/multimodal/models/wrappers/huggingface.py +26 -13
- eva/multimodal/models/wrappers/litellm.py +4 -2
- eva/multimodal/utils/batch/__init__.py +5 -0
- eva/multimodal/utils/batch/unpack.py +11 -0
- eva/vision/data/datasets/classification/breakhis.py +5 -8
- eva/vision/data/datasets/classification/panda.py +12 -5
- eva/vision/data/datasets/segmentation/btcv.py +1 -1
- eva/vision/data/datasets/segmentation/consep.py +1 -1
- eva/vision/data/datasets/segmentation/lits17.py +1 -1
- eva/vision/data/datasets/segmentation/monusac.py +15 -6
- eva/vision/data/datasets/segmentation/msd_task7_pancreas.py +1 -1
- eva/vision/data/transforms/base/__init__.py +2 -1
- eva/vision/data/transforms/base/monai.py +2 -2
- eva/vision/data/transforms/base/torchvision.py +33 -0
- eva/vision/data/transforms/common/squeeze.py +6 -3
- eva/vision/data/transforms/croppad/crop_foreground.py +8 -7
- eva/vision/data/transforms/croppad/rand_crop_by_label_classes.py +6 -5
- eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py +6 -5
- eva/vision/data/transforms/croppad/rand_spatial_crop.py +8 -7
- eva/vision/data/transforms/croppad/spatial_pad.py +6 -6
- eva/vision/data/transforms/intensity/rand_scale_intensity.py +3 -3
- eva/vision/data/transforms/intensity/rand_shift_intensity.py +3 -3
- eva/vision/data/transforms/intensity/scale_intensity_ranged.py +5 -5
- eva/vision/data/transforms/spatial/flip.py +8 -7
- eva/vision/data/transforms/spatial/resize.py +5 -4
- eva/vision/data/transforms/spatial/rotate.py +8 -7
- eva/vision/data/transforms/spatial/spacing.py +7 -6
- eva/vision/data/transforms/utility/ensure_channel_first.py +6 -6
- eva/vision/models/networks/backbones/universal/vit.py +24 -0
- {kaiko_eva-0.4.0.dist-info → kaiko_eva-0.4.1.dist-info}/METADATA +8 -2
- {kaiko_eva-0.4.0.dist-info → kaiko_eva-0.4.1.dist-info}/RECORD +54 -49
- {kaiko_eva-0.4.0.dist-info → kaiko_eva-0.4.1.dist-info}/WHEEL +0 -0
- {kaiko_eva-0.4.0.dist-info → kaiko_eva-0.4.1.dist-info}/entry_points.txt +0 -0
- {kaiko_eva-0.4.0.dist-info → kaiko_eva-0.4.1.dist-info}/licenses/LICENSE +0 -0
eva/core/callbacks/config.py
CHANGED

@@ -9,11 +9,13 @@ from typing import Any, Dict, List
 import lightning.pytorch as pl
 import yaml
 from lightning_fabric.utilities import cloud_io
+from loguru import logger
 from loguru import logger as cli_logger
 from omegaconf import OmegaConf
 from typing_extensions import TypeGuard, override

 from eva.core import loggers
+from eva.core.utils import distributed as dist_utils


 class ConfigurationLogger(pl.Callback):
@@ -39,8 +41,14 @@ class ConfigurationLogger(pl.Callback):
         pl_module: pl.LightningModule,
         stage: str | None = None,
     ) -> None:
-
-
+        if dist_utils.is_distributed():
+            logger.info("ConfigurationLogger skipped as not supported in distributed mode.")
+            # TODO: Enabling leads to deadlocks in DDP mode, but I could not yet figure out why.
+            return
+
+        if not trainer.is_global_zero or not _logdir_exists(
+            log_dir := trainer.log_dir, self._verbose
+        ):
             return

         configuration = _load_submitted_config()
@@ -130,7 +138,7 @@ def _type_resolver(mapping: Dict[str, Any]) -> Dict[str, Any]:
     for key, value in mapping.items():
         if isinstance(value, dict):
             formatted_value = _type_resolver(value)
-        elif isinstance(value, list) and isinstance(value[0], dict):
+        elif isinstance(value, list) and value and isinstance(value[0], dict):
             formatted_value = [_type_resolver(subvalue) for subvalue in value]
         else:
             try:
@@ -138,10 +146,7 @@ def _type_resolver(mapping: Dict[str, Any]) -> Dict[str, Any]:
                 formatted_value = (
                     value if isinstance(parsed_value, BuiltinFunctionType) else parsed_value
                 )
-
             except Exception:
                 formatted_value = value
-
         mapping[key] = formatted_value
-
     return mapping
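The added `value and` operand in `_type_resolver` guards against empty lists. A minimal standalone illustration (not eva code) of the failure mode it fixes:

    # Before the fix, an empty list reached value[0] and raised IndexError.
    def is_list_of_dicts(value):
        return isinstance(value, list) and value and isinstance(value[0], dict)

    assert is_list_of_dicts([{"lr": 0.1}])
    assert not is_list_of_dicts([])  # previously: IndexError: list index out of range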
eva/core/callbacks/writers/embeddings/base.py
CHANGED

@@ -7,6 +7,7 @@ from typing import Any, Dict, List, Sequence

 import lightning.pytorch as pl
 import torch
+import torch.distributed as dist
 from lightning.pytorch import callbacks
 from loguru import logger
 from torch import multiprocessing, nn
@@ -15,6 +16,7 @@ from typing_extensions import override
 from eva.core import utils
 from eva.core.callbacks.writers.embeddings.typings import QUEUE_ITEM
 from eva.core.models.modules.typings import INPUT_BATCH
+from eva.core.utils import distributed as dist_utils
 from eva.core.utils import multiprocessing as eva_multiprocessing


@@ -58,8 +60,9 @@ class EmbeddingsWriter(callbacks.BasePredictionWriter, abc.ABC):
         self._save_every_n = save_every_n
         self._metadata_keys = metadata_keys or []

-        self._write_queue: multiprocessing.Queue
-        self._write_process: eva_multiprocessing.Process
+        self._write_queue: multiprocessing.Queue | None = None
+        self._write_process: eva_multiprocessing.Process | None = None
+        self._is_rank_zero: bool = False

     @staticmethod
     @abc.abstractmethod
@@ -78,9 +81,13 @@ class EmbeddingsWriter(callbacks.BasePredictionWriter, abc.ABC):

     @override
     def on_predict_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
-        self._check_if_exists()
-        self._initialize_write_process()
-        self._write_process.start()
+        self._is_rank_zero = trainer.is_global_zero
+        if self._is_rank_zero:
+            self._check_if_exists()
+            self._initialize_write_process()
+            if self._write_process is None or self._write_queue is None:
+                raise RuntimeError("Failed to initialize embedding writer process.")
+            self._write_process.start()

         if self._backbone is not None:
             self._backbone = self._backbone.to(pl_module.device)
@@ -106,6 +113,7 @@ class EmbeddingsWriter(callbacks.BasePredictionWriter, abc.ABC):
         with torch.no_grad():
             embeddings = self._get_embeddings(prediction)

+        queue_items: List[QUEUE_ITEM] = []
         for local_idx, global_idx in enumerate(batch_indices[: len(embeddings)]):
             data_name = dataset.filename(global_idx)
             save_name = os.path.splitext(data_name)[0] + ".pt"
@@ -121,15 +129,41 @@ class EmbeddingsWriter(callbacks.BasePredictionWriter, abc.ABC):
                 split=split,
                 metadata=item_metadata,
             )
-            self._write_queue.put(item)
+            queue_items.append(item)

-        self._write_process.check_exceptions()
+        gathered_items = self._gather_queue_items(queue_items)
+        if self._is_rank_zero:
+            for item in gathered_items:
+                self._write_queue.put(item)  # type: ignore
+            self._write_process.check_exceptions()  # type: ignore

     @override
     def on_predict_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
-        self._write_queue.put(None)
-        self._write_process.join()
-        logger.info(f"Predictions and manifest saved to {self._output_dir}")
+        if dist_utils.is_distributed():
+            dist.barrier()
+
+        if self._is_rank_zero and self._write_queue is not None:
+            self._write_queue.put(None)
+            if self._write_process is not None:
+                self._write_process.join()
+            logger.info(f"Predictions and manifest saved to {self._output_dir}")
+
+    def _gather_queue_items(self, items: List[QUEUE_ITEM]) -> List[QUEUE_ITEM]:
+        """Gather queue items across distributed ranks, returning only on rank zero."""
+        if not dist_utils.is_distributed():
+            return items
+
+        world_size = dist.get_world_size()
+        object_list: List[List[QUEUE_ITEM]] = [[] for _ in range(world_size)]
+        dist.all_gather_object(object_list, items)
+
+        if self._is_rank_zero:
+            gathered: List[QUEUE_ITEM] = []
+            for rank_items in object_list:
+                gathered.extend(rank_items)
+            return gathered
+
+        return []

     def _initialize_write_process(self) -> None:
         self._write_queue = multiprocessing.Queue()
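For readers unfamiliar with the collective used in `_gather_queue_items`: `torch.distributed.all_gather_object` pickles each rank's payload into a pre-sized output list on every rank. A standalone sketch of the pattern (falls back to a no-op when no process group is initialized):

    import torch.distributed as dist

    def gather_to_rank_zero(items: list) -> list:
        """Collect every rank's items on rank zero; other ranks receive []."""
        if not (dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1):
            return items  # single-process fallback
        shards: list = [[] for _ in range(dist.get_world_size())]
        dist.all_gather_object(shards, items)  # every rank fills all shards
        if dist.get_rank() != 0:
            return []
        return [item for shard in shards for item in shard]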
eva/core/data/samplers/classification/balanced.py
CHANGED

@@ -1,9 +1,10 @@
 """Random class sampler for data loading."""

 from collections import defaultdict
-from typing import Dict, Iterator, List
+from typing import Dict, Iterator, List, Union

 import numpy as np
+import torch
 from loguru import logger
 from typing_extensions import override

@@ -32,7 +33,7 @@ class BalancedSampler(SamplerWithDataSource[int]):
         """
         self._num_samples = num_samples
         self._replacement = replacement
-        self._class_indices: Dict[int, List[int]] = defaultdict(list)
+        self._class_indices: Dict[Union[int, str], List[int]] = defaultdict(list)
         self._random_generator = np.random.default_rng(seed)
         self._indices: List[int] = []

@@ -62,20 +63,31 @@ class BalancedSampler(SamplerWithDataSource[int]):
         super().set_dataset(data_source)
         self._make_indices()

+    def _get_class_idx(self, idx):
+        """Load and validate the class index for a given sample index."""
+        if hasattr(self.data_source, "load_target"):
+            target = self.data_source.load_target(idx)  # type: ignore
+        else:
+            _, target, _ = DataSample(*self.data_source[idx])
+
+        if target is None:
+            raise ValueError("The dataset must return non-empty targets.")
+
+        if isinstance(target, str):
+            return target
+
+        if isinstance(target, torch.Tensor):
+            if target.numel() != 1:
+                raise ValueError("The dataset must return a single & scalar target.")
+            return int(target.item())
+
+        raise ValueError("Unsupported target type. Expected str or tensor-like object.")
+
     def _make_indices(self):
         """Samples the indices for each class in the dataset."""
        self._class_indices.clear()
         for idx in tqdm(range(len(self.data_source)), desc="Fetching class indices for sampler"):
-            if hasattr(self.data_source, "load_target"):
-                target = self.data_source.load_target(idx)  # type: ignore
-            else:
-                _, target, _ = DataSample(*self.data_source[idx])
-            if target is None:
-                raise ValueError("The dataset must return non-empty targets.")
-            if target.numel() != 1:
-                raise ValueError("The dataset must return a single & scalar target.")
-
-            class_idx = int(target.item())
+            class_idx = self._get_class_idx(idx)
             self._class_indices[class_idx].append(idx)

         if not self._replacement:
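Illustrative only (hypothetical toy dataset, not part of eva): the new `_get_class_idx` accepts plain-string targets as class keys in addition to scalar tensors, which previously failed on `target.numel()`:

    import torch

    class ToyDataset:
        """Toy dataset exposing load_target, the hook the sampler probes for."""

        _targets = ["benign", "malignant", "benign"]

        def __len__(self) -> int:
            return len(self._targets)

        def load_target(self, idx: int) -> str:
            return self._targets[idx]  # str targets are now used as-is

    # Scalar tensor targets are still normalized to int:
    assert int(torch.tensor(1).item()) == 1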
eva/core/loggers/utils/wandb.py
CHANGED

@@ -5,6 +5,8 @@ from typing import Any, Dict

 from loguru import logger

+from eva.core.utils import requirements
+

 def rename_active_run(name: str) -> None:
     """Renames the current run."""
@@ -12,7 +14,8 @@ def rename_active_run(name: str) -> None:

     if wandb.run:
         wandb.run.name = name
-        wandb.run.save()
+        if requirements.below("wandb", "0.21.0"):
+            wandb.run.save()
     else:
         logger.warning("No active wandb run found that could be renamed.")

eva/core/trainers/trainer.py
CHANGED

@@ -31,6 +31,8 @@ class Trainer(pl_trainer.Trainer):
         default_root_dir: str = "logs",
         n_runs: int = 1,
         checkpoint_type: Literal["best", "last"] = "best",
+        accelerator: str = "auto",
+        devices: int = 1,
         **kwargs: Any,
     ) -> None:
         """Initializes the trainer.
@@ -45,9 +47,17 @@ class Trainer(pl_trainer.Trainer):
             n_runs: The amount of runs (fit and evaluate) to perform in an evaluation session.
             checkpoint_type: Whether to load the "best" or "last" checkpoint saved by the checkpoint
                 callback for evaluations on validation & test sets.
+            accelerator: The accelerator to use for training (e.g. "cpu", "gpu").
+            devices: The number of devices (GPUs) to use for training.
             kwargs: Keyword arguments of :class:`lightning.pytorch.Trainer`.
         """
-        super().__init__(*args, default_root_dir=default_root_dir, **kwargs)
+        super().__init__(
+            *args,
+            default_root_dir=default_root_dir,
+            accelerator=accelerator,
+            devices=devices,
+            **kwargs,
+        )

         self.checkpoint_type = checkpoint_type
         self.n_runs = n_runs
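A minimal usage sketch (import path inferred from the file location in this diff): the two new arguments are forwarded verbatim to lightning.pytorch.Trainer, so they no longer need to travel through **kwargs:

    from eva.core.trainers.trainer import Trainer

    # devices now defaults to 1; override explicitly for multi-GPU runs.
    trainer = Trainer(accelerator="gpu", devices=2)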
eva/core/utils/__init__.py
CHANGED

@@ -3,5 +3,6 @@
 from eva.core.utils.clone import clone
 from eva.core.utils.memory import to_cpu
 from eva.core.utils.operations import numeric_sort
+from eva.core.utils.paths import home_dir

-__all__ = ["clone", "to_cpu", "numeric_sort"]
+__all__ = ["clone", "to_cpu", "numeric_sort", "home_dir"]

eva/core/utils/distributed.py
ADDED

@@ -0,0 +1,12 @@
+"""Utility functions for distributed training."""
+
+import torch.distributed as dist
+
+
+def is_distributed() -> bool:
+    """Check if current environment is distributed.
+
+    Returns:
+        bool: True if distributed environment (e.g. multiple gpu processes).
+    """
+    return dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1
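The helper is deliberately conservative: it returns False on single-process runs even when torch.distributed is available. A typical guard, matching how the callbacks in this release use it (a sketch, assuming a DDP context):

    import torch.distributed as dist
    from eva.core.utils import distributed as dist_utils

    def sync_point() -> None:
        # Only issue the collective when a multi-process group actually exists;
        # calling dist.barrier() without an initialized group raises an error.
        if dist_utils.is_distributed():
            dist.barrier()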
eva/core/utils/paths.py
ADDED

@@ -0,0 +1,14 @@
+"""Utility functions for handling paths."""
+
+import os
+
+
+def home_dir():
+    """Get eva's home directory for caching."""
+    torch_home = os.path.expanduser(
+        os.getenv(
+            "EVA_HOME",
+            os.path.join("~/.cache", "eva"),
+        )
+    )
+    return torch_home
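`home_dir` resolves the EVA_HOME environment variable on each call, falling back to ~/.cache/eva. For example:

    import os
    from eva.core.utils import home_dir  # re-exported via eva.core.utils.__init__

    os.environ["EVA_HOME"] = "/tmp/eva-cache"
    assert home_dir() == "/tmp/eva-cache"  # env var is read at call time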
eva/core/utils/requirements.py
CHANGED

@@ -3,10 +3,58 @@
 import importlib
 from typing import Dict

-from packaging import version
+import packaging.version


-def check_dependencies(requirements: Dict[str, str]) -> None:
+def fetch_version(name: str) -> str | None:
+    """Fetch the installed version of a package.
+
+    Args:
+        name: The name of the package.
+
+    Returns:
+        A string representing the installed version of the package, or None if not found.
+    """
+    try:
+        module = importlib.import_module(name)
+        return getattr(module, "__version__", None)
+    except ImportError:
+        return None
+
+
+def below(name: str, version: str) -> bool:
+    """Check if the installed version of a package is below a certain version.
+
+    Args:
+        name: The name of the package.
+        version: The version to compare against.
+
+    Returns:
+        True if the installed version is below the specified version, False otherwise.
+    """
+    actual = fetch_version(name)
+    if actual:
+        return packaging.version.parse(actual) < packaging.version.parse(version)
+    return False
+
+
+def above_or_equal(name: str, version: str) -> bool:
+    """Check if the installed version of a package is above a certain version.
+
+    Args:
+        name: The name of the package.
+        version: The version to compare against.
+
+    Returns:
+        True if the installed version is above the specified version, False otherwise.
+    """
+    actual = fetch_version(name)
+    if actual:
+        return packaging.version.parse(actual) >= packaging.version.parse(version)
+    return False
+
+
+def check_min_versions(requirements: Dict[str, str]) -> None:
     """Check installed package versions against requirements dict.

     Args:
@@ -17,10 +65,8 @@ def check_dependencies(requirements: Dict[str, str]) -> None:
         ImportError: If any package does not meet the minimum required version.
     """
     for package, min_version in requirements.items():
-        module = importlib.import_module(package)
-        actual = getattr(module, "__version__", None)
-        if actual and not (version.parse(actual) >= version.parse(min_version)):
+        if below(package, min_version):
             raise ImportError(
-                f"Package '{package}' version {actual} does not meet "
+                f"Package '{package}' version {fetch_version(package)} does not meet "
                 f"the minimum required version {min_version}."
             )
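Usage sketch for the new helpers (the wandb guard earlier in this diff is a real call site):

    from eva.core.utils import requirements

    if requirements.below("wandb", "0.21.0"):
        ...  # legacy code path for older wandb releases

    # Raises ImportError when an installed package is older than required:
    requirements.check_min_versions({"torch": "2.0.0"})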
eva/language/callbacks/writers/prediction.py
CHANGED

@@ -7,11 +7,13 @@ from typing import Any, Dict, List, Literal, Sequence, Tuple, TypedDict
 import lightning.pytorch as pl
 import pandas as pd
 import torch
+import torch.distributed as dist
 from lightning.pytorch import callbacks
 from torch import nn
 from typing_extensions import NotRequired, override

 from eva.core.models.modules import utils as module_utils
+from eva.core.utils import distributed as dist_utils
 from eva.language.models.typings import TextBatch
 from eva.language.utils.text import messages as message_utils

@@ -74,10 +76,14 @@ class TextPredictionWriter(callbacks.BasePredictionWriter, abc.ABC):

         self._manifest_path = os.path.join(self.output_dir, f"manifest.{self.save_format}")
         self._data: List[ManifestEntry] = []
+        self._is_rank_zero: bool = False

     @override
     def on_predict_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
-        self._check_if_exists()
+        self._is_rank_zero = trainer.is_global_zero
+
+        if self._is_rank_zero:
+            self._check_if_exists()

         self.model = self.model.to(pl_module.device)
         self.model.eval()
@@ -105,11 +111,12 @@ class TextPredictionWriter(callbacks.BasePredictionWriter, abc.ABC):

         for i in range(len(batch_indices)):
             entry: ManifestEntry = {
-                "text": message_utils.serialize(text_batch[i]),
                 "prediction": str(prediction_batch[i]),
                 "target": str(target_batch[i]) if has_target else "",
                 "split": split if split else "",
             }
+            if self.include_input:
+                entry["text"] = message_utils.serialize(text_batch[i])

             if self.metadata_keys is not None and metadata_batch is not None:
                 for key in self.metadata_keys:
@@ -120,26 +127,45 @@ class TextPredictionWriter(callbacks.BasePredictionWriter, abc.ABC):
     @override
     def on_predict_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
         """Saves the gathered predictions to a manifest file."""
-        df = pd.DataFrame(self._data)
-
-        match self.save_format:
-            case "jsonl":
-                df.to_json(self._manifest_path, orient="records", lines=True)
-            case "parquet":
-                df.to_parquet(self._manifest_path, index=False)
-            case "csv":
-                df.to_csv(self._manifest_path, index=False)
-            case _:
-                raise ValueError(f"Unsupported save format: {self.save_format}")
+        if dist_utils.is_distributed():
+            dist.barrier()
+            data = self._gather_data_from_ranks()
+        else:
+            data = self._data
+
+        if self._is_rank_zero:
+            df = pd.DataFrame(data)
+
+            match self.save_format:
+                case "jsonl":
+                    df.to_json(self._manifest_path, orient="records", lines=True)
+                case "parquet":
+                    df.to_parquet(self._manifest_path, index=False)
+                case "csv":
+                    df.to_csv(self._manifest_path, index=False)
+                case _:
+                    raise ValueError(f"Unsupported save format: {self.save_format}")
+
+    def _gather_data_from_ranks(self) -> List[ManifestEntry]:
+        world_size = dist.get_world_size()
+        gathered: List[List[ManifestEntry] | None] = [None] * world_size
+        dist.all_gather_object(gathered, self._data)
+        return [row for shard in gathered for row in (shard or [])]

     def _get_predictions(self, batch: TextBatch) -> List[str]:
         with torch.no_grad():
-
-
-            if
-
+            output = self.model(batch)
+
+            if (
+                not isinstance(output, dict)
+                or "generated_text" not in output
+                or not all(isinstance(p, str) for p in output["generated_text"])
+            ):
+                raise ValueError(
+                    f"A dictionary with 'generated_text' key is expected, got {type(output)}"
+                )

-        return
+        return output["generated_text"]

     def _check_if_exists(self) -> None:
         """Checks if the output directory already exists and if it should be overwritten."""
@@ -150,7 +176,6 @@ class TextPredictionWriter(callbacks.BasePredictionWriter, abc.ABC):
                 "either means that the predictions have been computed before or that a "
                 "wrong output directory is being used."
             )
-        os.makedirs(self.output_dir, exist_ok=True)

     def _apply_postprocess(
         self, pl_module: pl.LightningModule, targets: Any, predictions: Any
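For reference, a minimal standalone sketch of the three manifest formats the writer supports (pandas API as used above; parquet additionally needs an engine such as pyarrow):

    import pandas as pd

    df = pd.DataFrame([{"prediction": "yes", "target": "yes", "split": "test"}])
    df.to_json("manifest.jsonl", orient="records", lines=True)  # one JSON object per line
    df.to_csv("manifest.csv", index=False)
    df.to_parquet("manifest.parquet", index=False)  # requires pyarrow or fastparquet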
eva/language/data/datasets/classification/pubmedqa.py
CHANGED

@@ -121,7 +121,7 @@ class PubMedQA(base.TextClassification):

     @override
     def validate(self) -> None:
-        if len(self) != self._expected_dataset_lengths[self._split]:
+        if len(self) != (self._max_samples or self._expected_dataset_lengths[self._split]):
             raise ValueError(
                 f"Dataset length mismatch for split '{self._split}': "
                 f"expected {self._expected_dataset_lengths[self._split]}, "
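The fix leans on `or` short-circuiting: when `_max_samples` is unset (None), validation falls back to the full expected split length. Illustrative:

    max_samples = None
    expected_length = 1000
    assert (max_samples or expected_length) == 1000  # unset cap -> full length
    assert (500 or expected_length) == 500           # capped dataset -> compare against cap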
eva/language/models/modules/language.py
CHANGED

@@ -1,6 +1,6 @@
 """Model module for language models."""

-from typing import Any, List
+from typing import Any

 from lightning.pytorch.utilities.types import STEP_OUTPUT
 from torch import nn
@@ -9,7 +9,7 @@ from typing_extensions import override
 from eva.core.metrics import structs as metrics_lib
 from eva.core.models.modules import module
 from eva.core.models.modules.utils import batch_postprocess
-from eva.language.models.typings import PredictionBatch, TextBatch
+from eva.language.models.typings import ModelOutput, PredictionBatch, TextBatch


 class LanguageModule(module.ModelModule):
@@ -33,7 +33,7 @@ class LanguageModule(module.ModelModule):
         self.model = model

     @override
-    def forward(self, batch: TextBatch, *args: Any, **kwargs: Any) -> List[str]:
+    def forward(self, batch: TextBatch, *args: Any, **kwargs: Any) -> ModelOutput:
         return self.model(batch)

     @override
@@ -46,13 +46,14 @@ class LanguageModule(module.ModelModule):

     def _batch_step(self, batch: TextBatch) -> STEP_OUTPUT:
         text, targets, metadata = TextBatch(*batch)
-
+        output = self.forward(batch)
+
         return {
             "inputs": text,
-            "predictions":
+            "predictions": output.pop("generated_text"),  # type: ignore
             "targets": targets,
             "metadata": metadata,
-        }
+        } | output


 class OfflineLanguageModule(module.ModelModule):
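The `{...} | output` expression uses the PEP 584 dict union (Python 3.9+): after `generated_text` is popped into `predictions`, any remaining model-output keys (token ids, attention mask) are merged into the step output. A standalone illustration:

    step = {"inputs": ["q"], "predictions": ["a"], "targets": ["a"], "metadata": None}
    extras = {"input_ids": None, "attention_mask": None}
    merged = step | extras  # right-hand operand wins on key collisions
    assert "attention_mask" in merged and merged["predictions"] == ["a"]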
eva/language/models/typings.py
CHANGED

@@ -1,8 +1,9 @@
 """Type definitions for language models."""

-from typing import Any, Dict, Generic, List, TypeVar
+from typing import Any, Dict, Generic, List, TypedDict, TypeVar

-from typing_extensions import NamedTuple
+import torch
+from typing_extensions import NamedTuple, NotRequired

 from eva.language.data.messages import MessageSeries

@@ -37,3 +38,19 @@ class PredictionBatch(NamedTuple, Generic[TargetType]):

     metadata: Dict[str, Any] | None
     """Additional metadata."""
+
+
+class ModelOutput(TypedDict):
+    """The output batch produced by the model forward pass."""
+
+    generated_text: List[str]
+    """The text generated by the model."""
+
+    input_ids: NotRequired[torch.Tensor | None]
+    """The token ids of the input text."""
+
+    output_ids: NotRequired[torch.Tensor | None]
+    """The token ids of the model output (usually containing both input and prediction)."""
+
+    attention_mask: NotRequired[torch.Tensor | None]
+    """The attention mask for the input tokens."""
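TypedDict fields marked NotRequired may simply be omitted, so the minimal valid ModelOutput is just the generated text. A sketch that type-checks under the definition above:

    from eva.language.models.typings import ModelOutput

    out: ModelOutput = {"generated_text": ["yes", "no"]}  # NotRequired keys omitted
    full = ModelOutput(generated_text=["yes"], input_ids=None, attention_mask=None)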
eva/language/models/wrappers/base.py
CHANGED

@@ -1,16 +1,16 @@
 """Base class for language model wrappers."""

 import abc
-from typing import Any, Callable, List
+from typing import Any, Callable

 from typing_extensions import override

 from eva.core.models.wrappers import base
 from eva.language.data.messages import ModelSystemMessage
-from eva.language.models.typings import TextBatch
+from eva.language.models.typings import ModelOutput, TextBatch


-class LanguageModel(base.BaseModel[TextBatch, List[str]]):
+class LanguageModel(base.BaseModel[TextBatch, ModelOutput]):
     """Base class for language models.

     Classes that inherit from this should implement the following methods:
@@ -36,7 +36,7 @@ class LanguageModel(base.BaseModel[TextBatch, List[str]]):
         self.system_message = ModelSystemMessage(content=system_prompt) if system_prompt else None

     @override
-    def forward(self, batch: TextBatch) -> List[str]:
+    def forward(self, batch: TextBatch) -> ModelOutput:
         """Forward pass of the model."""
         inputs = self.format_inputs(batch)
         return super().forward(inputs)
eva/language/models/wrappers/huggingface.py
CHANGED

@@ -5,7 +5,7 @@ from typing import Any, Callable, Dict, List, Literal
 from transformers.pipelines import pipeline
 from typing_extensions import override

-from eva.language.models.typings import TextBatch
+from eva.language.models.typings import ModelOutput, TextBatch
 from eva.language.models.wrappers import base
 from eva.language.utils.text import messages as message_utils

@@ -13,6 +13,14 @@ from eva.language.utils.text import messages as message_utils
 class HuggingFaceModel(base.LanguageModel):
     """Wrapper class for loading HuggingFace `transformers` models using pipelines."""

+    _default_generation_kwargs = {
+        "temperature": 0.0,
+        "max_new_tokens": 1024,
+        "do_sample": False,
+        "top_p": 1.0,
+    }
+    """Default HF model parameters for evaluation."""
+
     def __init__(
         self,
         model_name_or_path: str,
@@ -41,7 +49,7 @@ class HuggingFaceModel(base.LanguageModel):
         self._model_name_or_path = model_name_or_path
         self._task = task
         self._model_kwargs = model_kwargs or {}
-        self._generation_kwargs = generation_kwargs or {}
+        self._generation_kwargs = self._default_generation_kwargs | (generation_kwargs or {})
         self._chat_mode = chat_mode

         self.model = self.load_model()
@@ -84,7 +92,7 @@ class HuggingFaceModel(base.LanguageModel):
         return list(map(message_utils.merge_message_contents, message_batch))

     @override
-    def model_forward(self, prompts: List[str]) -> List[str]:
+    def model_forward(self, prompts: List[str]) -> ModelOutput:
         """Generates text using the pipeline.

         Args:
@@ -96,10 +104,12 @@ class HuggingFaceModel(base.LanguageModel):
         outputs = self.model(prompts, return_full_text=False, **self._generation_kwargs)
         if outputs is None:
             raise ValueError("Outputs from the model are None.")
+
         results = []
         for output in outputs:
             if isinstance(output, list):
                 results.append(output[0]["generated_text"])  # type: ignore
             else:
                 results.append(output["generated_text"])  # type: ignore
-        return results
+
+        return ModelOutput(generated_text=results)
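The union with `(generation_kwargs or {})` means user-supplied values override the deterministic evaluation defaults. Standalone illustration:

    defaults = {"temperature": 0.0, "max_new_tokens": 1024, "do_sample": False, "top_p": 1.0}
    user = {"max_new_tokens": 256}
    merged = defaults | user  # user settings take precedence
    assert merged["max_new_tokens"] == 256 and merged["do_sample"] is False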