kaiko-eva 0.3.3__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kaiko-eva might be problematic. Click here for more details.
- eva/core/callbacks/config.py +4 -0
- eva/core/cli/setup.py +1 -1
- eva/core/data/dataloaders/__init__.py +1 -2
- eva/core/data/samplers/random.py +17 -10
- eva/core/interface/interface.py +21 -0
- eva/core/models/modules/module.py +2 -2
- eva/core/models/wrappers/base.py +2 -2
- eva/core/models/wrappers/from_function.py +3 -3
- eva/core/models/wrappers/from_torchhub.py +9 -7
- eva/core/models/wrappers/huggingface.py +4 -5
- eva/core/models/wrappers/onnx.py +5 -5
- eva/core/trainers/trainer.py +2 -0
- eva/language/__init__.py +2 -1
- eva/language/callbacks/__init__.py +5 -0
- eva/language/callbacks/writers/__init__.py +5 -0
- eva/language/callbacks/writers/prediction.py +176 -0
- eva/language/data/dataloaders/__init__.py +5 -0
- eva/language/data/dataloaders/collate_fn/__init__.py +5 -0
- eva/language/data/dataloaders/collate_fn/text.py +57 -0
- eva/language/data/datasets/__init__.py +3 -1
- eva/language/data/datasets/{language.py → base.py} +1 -1
- eva/language/data/datasets/classification/base.py +3 -43
- eva/language/data/datasets/classification/pubmedqa.py +36 -4
- eva/language/data/datasets/prediction.py +151 -0
- eva/language/data/datasets/schemas.py +18 -0
- eva/language/data/datasets/text.py +92 -0
- eva/language/data/datasets/typings.py +39 -0
- eva/language/data/messages.py +60 -0
- eva/language/models/__init__.py +15 -11
- eva/language/models/modules/__init__.py +2 -2
- eva/language/models/modules/language.py +93 -0
- eva/language/models/networks/__init__.py +12 -0
- eva/language/models/networks/alibaba.py +26 -0
- eva/language/models/networks/api/__init__.py +11 -0
- eva/language/models/networks/api/anthropic.py +34 -0
- eva/language/models/networks/registry.py +5 -0
- eva/language/models/typings.py +39 -0
- eva/language/models/wrappers/__init__.py +13 -5
- eva/language/models/wrappers/base.py +47 -0
- eva/language/models/wrappers/from_registry.py +54 -0
- eva/language/models/wrappers/huggingface.py +44 -8
- eva/language/models/wrappers/litellm.py +81 -46
- eva/language/models/wrappers/vllm.py +37 -13
- eva/language/utils/__init__.py +2 -1
- eva/language/utils/str_to_int_tensor.py +20 -12
- eva/language/utils/text/__init__.py +5 -0
- eva/language/utils/text/messages.py +113 -0
- eva/multimodal/__init__.py +6 -0
- eva/multimodal/callbacks/__init__.py +5 -0
- eva/multimodal/callbacks/writers/__init__.py +5 -0
- eva/multimodal/callbacks/writers/prediction.py +39 -0
- eva/multimodal/data/__init__.py +5 -0
- eva/multimodal/data/dataloaders/__init__.py +5 -0
- eva/multimodal/data/dataloaders/collate_fn/__init__.py +5 -0
- eva/multimodal/data/dataloaders/collate_fn/text_image.py +28 -0
- eva/multimodal/data/datasets/__init__.py +6 -0
- eva/multimodal/data/datasets/base.py +13 -0
- eva/multimodal/data/datasets/multiple_choice/__init__.py +5 -0
- eva/multimodal/data/datasets/multiple_choice/patch_camelyon.py +80 -0
- eva/multimodal/data/datasets/schemas.py +14 -0
- eva/multimodal/data/datasets/text_image.py +77 -0
- eva/multimodal/data/datasets/typings.py +27 -0
- eva/multimodal/models/__init__.py +8 -0
- eva/multimodal/models/modules/__init__.py +5 -0
- eva/multimodal/models/modules/vision_language.py +55 -0
- eva/multimodal/models/networks/__init__.py +14 -0
- eva/multimodal/models/networks/alibaba.py +39 -0
- eva/multimodal/models/networks/api/__init__.py +11 -0
- eva/multimodal/models/networks/api/anthropic.py +34 -0
- eva/multimodal/models/networks/others.py +47 -0
- eva/multimodal/models/networks/registry.py +5 -0
- eva/multimodal/models/typings.py +27 -0
- eva/multimodal/models/wrappers/__init__.py +13 -0
- eva/multimodal/models/wrappers/base.py +47 -0
- eva/multimodal/models/wrappers/from_registry.py +54 -0
- eva/multimodal/models/wrappers/huggingface.py +180 -0
- eva/multimodal/models/wrappers/litellm.py +56 -0
- eva/multimodal/utils/__init__.py +1 -0
- eva/multimodal/utils/image/__init__.py +5 -0
- eva/multimodal/utils/image/encode.py +28 -0
- eva/multimodal/utils/text/__init__.py +1 -0
- eva/multimodal/utils/text/messages.py +79 -0
- eva/vision/data/datasets/classification/patch_camelyon.py +8 -6
- eva/vision/data/transforms/__init__.py +2 -1
- eva/vision/data/transforms/spatial/__init__.py +2 -1
- eva/vision/data/transforms/spatial/functional/__init__.py +5 -0
- eva/vision/data/transforms/spatial/functional/resize.py +26 -0
- eva/vision/data/transforms/spatial/resize.py +62 -0
- eva/vision/models/wrappers/from_registry.py +6 -5
- eva/vision/models/wrappers/from_timm.py +6 -4
- {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.0.dist-info}/METADATA +10 -2
- {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.0.dist-info}/RECORD +95 -38
- eva/core/data/dataloaders/collate_fn/__init__.py +0 -5
- eva/core/data/dataloaders/collate_fn/collate.py +0 -24
- eva/language/models/modules/text.py +0 -85
- {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.0.dist-info}/WHEEL +0 -0
- {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.0.dist-info}/entry_points.txt +0 -0
- {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.0.dist-info}/licenses/LICENSE +0 -0
eva/core/callbacks/config.py
CHANGED
|
@@ -51,6 +51,10 @@ class ConfigurationLogger(pl.Callback):
|
|
|
51
51
|
|
|
52
52
|
save_as = os.path.join(log_dir, self._save_as)
|
|
53
53
|
fs = cloud_io.get_filesystem(log_dir)
|
|
54
|
+
|
|
55
|
+
if not fs.exists(log_dir):
|
|
56
|
+
fs.makedirs(log_dir)
|
|
57
|
+
|
|
54
58
|
with fs.open(save_as, "w") as output_file:
|
|
55
59
|
yaml.dump(configuration, output_file, sort_keys=False)
|
|
56
60
|
|
eva/core/cli/setup.py
CHANGED
eva/core/data/samplers/random.py
CHANGED
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
from typing import Optional
|
|
4
4
|
|
|
5
|
+
import torch
|
|
5
6
|
from torch.utils import data
|
|
6
7
|
from typing_extensions import override
|
|
7
8
|
|
|
@@ -10,30 +11,36 @@ from eva.core.data.samplers.sampler import SamplerWithDataSource
|
|
|
10
11
|
|
|
11
12
|
|
|
12
13
|
class RandomSampler(data.RandomSampler, SamplerWithDataSource[int]):
|
|
13
|
-
"""Samples elements randomly."""
|
|
14
|
+
"""Samples elements randomly from a MapDataset."""
|
|
14
15
|
|
|
15
16
|
data_source: datasets.MapDataset # type: ignore
|
|
16
17
|
|
|
17
18
|
def __init__(
|
|
18
|
-
self,
|
|
19
|
+
self,
|
|
20
|
+
replacement: bool = False,
|
|
21
|
+
num_samples: Optional[int] = None,
|
|
22
|
+
seed: Optional[int] = None,
|
|
19
23
|
) -> None:
|
|
20
|
-
"""
|
|
24
|
+
"""Initialize the random sampler.
|
|
21
25
|
|
|
22
26
|
Args:
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
generator: Generator used in sampling.
|
|
27
|
+
replacement: Samples are drawn on-demand with replacement if ``True``, default=``False``
|
|
28
|
+
num_samples: Number of samples to draw, default=``len(dataset)``.
|
|
29
|
+
seed: Optional seed for the random number generator.
|
|
27
30
|
"""
|
|
28
31
|
self.replacement = replacement
|
|
29
32
|
self._num_samples = num_samples
|
|
30
|
-
self.
|
|
33
|
+
self._generator = None
|
|
34
|
+
|
|
35
|
+
if seed is not None:
|
|
36
|
+
self._generator = torch.Generator()
|
|
37
|
+
self._generator.manual_seed(seed)
|
|
31
38
|
|
|
32
39
|
@override
|
|
33
40
|
def set_dataset(self, data_source: datasets.MapDataset) -> None:
|
|
34
41
|
super().__init__(
|
|
35
42
|
data_source,
|
|
36
43
|
replacement=self.replacement,
|
|
37
|
-
num_samples=self.
|
|
38
|
-
generator=self.
|
|
44
|
+
num_samples=self._num_samples,
|
|
45
|
+
generator=self._generator,
|
|
39
46
|
)
|
eva/core/interface/interface.py
CHANGED
|
@@ -132,3 +132,24 @@ class Interface:
|
|
|
132
132
|
n_runs=trainer.n_runs,
|
|
133
133
|
verbose=trainer.n_runs > 1,
|
|
134
134
|
)
|
|
135
|
+
|
|
136
|
+
def validate_test(
|
|
137
|
+
self,
|
|
138
|
+
trainer: eva_trainer.Trainer,
|
|
139
|
+
model: modules.ModelModule,
|
|
140
|
+
data: datamodules.DataModule,
|
|
141
|
+
) -> None:
|
|
142
|
+
"""Runs validation & test stages."""
|
|
143
|
+
if getattr(data.datasets, "val", None) is None:
|
|
144
|
+
raise ValueError("The provided data module does not contain a validation dataset.")
|
|
145
|
+
if getattr(data.datasets, "test", None) is None:
|
|
146
|
+
raise ValueError("The provided data module does not contain a test dataset.")
|
|
147
|
+
|
|
148
|
+
eva_trainer.run_evaluation_session(
|
|
149
|
+
base_trainer=trainer,
|
|
150
|
+
base_model=model,
|
|
151
|
+
datamodule=data,
|
|
152
|
+
stages=["validate", "test"],
|
|
153
|
+
n_runs=trainer.n_runs,
|
|
154
|
+
verbose=trainer.n_runs > 1,
|
|
155
|
+
)
|
|
@@ -33,8 +33,8 @@ class ModelModule(pl.LightningModule):
|
|
|
33
33
|
super().__init__()
|
|
34
34
|
|
|
35
35
|
self._metrics = metrics or self.default_metrics
|
|
36
|
-
self._postprocess = postprocess or self.default_postprocess
|
|
37
36
|
|
|
37
|
+
self.postprocess = postprocess or self.default_postprocess
|
|
38
38
|
self.metrics = metrics_lib.MetricModule.from_schema(self._metrics)
|
|
39
39
|
|
|
40
40
|
@property
|
|
@@ -133,7 +133,7 @@ class ModelModule(pl.LightningModule):
|
|
|
133
133
|
Returns:
|
|
134
134
|
The updated outputs.
|
|
135
135
|
"""
|
|
136
|
-
self.
|
|
136
|
+
self.postprocess(outputs)
|
|
137
137
|
return memory.recursive_detach(outputs, to_cpu=self.metrics_device.type == "cpu")
|
|
138
138
|
|
|
139
139
|
def _forward_and_log_metrics(
|
eva/core/models/wrappers/base.py
CHANGED
|
@@ -25,7 +25,7 @@ class BaseModel(nn.Module, Generic[InputType, OutputType]):
|
|
|
25
25
|
|
|
26
26
|
self._output_transforms = transforms
|
|
27
27
|
|
|
28
|
-
self.
|
|
28
|
+
self.model: Callable[..., OutputType] | nn.Module
|
|
29
29
|
|
|
30
30
|
@override
|
|
31
31
|
def forward(self, tensor: InputType) -> OutputType:
|
|
@@ -43,7 +43,7 @@ class BaseModel(nn.Module, Generic[InputType, OutputType]):
|
|
|
43
43
|
Args:
|
|
44
44
|
tensor: The input tensor to the model.
|
|
45
45
|
"""
|
|
46
|
-
return self.
|
|
46
|
+
return self.model(tensor)
|
|
47
47
|
|
|
48
48
|
def _apply_transforms(self, tensor: OutputType) -> OutputType:
|
|
49
49
|
if self._output_transforms is not None:
|
|
@@ -41,12 +41,12 @@ class ModelFromFunction(base.BaseModel[torch.Tensor, torch.Tensor]):
|
|
|
41
41
|
self._arguments = arguments
|
|
42
42
|
self._checkpoint_path = checkpoint_path
|
|
43
43
|
|
|
44
|
-
self.load_model()
|
|
44
|
+
self.model = self.load_model()
|
|
45
45
|
|
|
46
46
|
@override
|
|
47
|
-
def load_model(self) ->
|
|
47
|
+
def load_model(self) -> nn.Module:
|
|
48
48
|
class_path = jsonargparse.class_from_function(self._path, func_return=nn.Module)
|
|
49
49
|
model = class_path(**self._arguments or {})
|
|
50
50
|
if self._checkpoint_path is not None:
|
|
51
51
|
_utils.load_model_weights(model, self._checkpoint_path)
|
|
52
|
-
|
|
52
|
+
return model
|
|
@@ -52,12 +52,12 @@ class TorchHubModel(base.BaseModel[torch.Tensor, torch.Tensor]):
|
|
|
52
52
|
self._trust_repo = trust_repo
|
|
53
53
|
self._model_kwargs = model_kwargs or {}
|
|
54
54
|
|
|
55
|
-
self.load_model()
|
|
55
|
+
self.model = self.load_model()
|
|
56
56
|
|
|
57
57
|
@override
|
|
58
|
-
def load_model(self) ->
|
|
58
|
+
def load_model(self) -> nn.Module:
|
|
59
59
|
"""Builds and loads the torch.hub model."""
|
|
60
|
-
|
|
60
|
+
model: nn.Module = torch.hub.load(
|
|
61
61
|
repo_or_dir=self._repo_or_dir,
|
|
62
62
|
model=self._model_name,
|
|
63
63
|
trust_repo=self._trust_repo,
|
|
@@ -66,21 +66,23 @@ class TorchHubModel(base.BaseModel[torch.Tensor, torch.Tensor]):
|
|
|
66
66
|
) # type: ignore
|
|
67
67
|
|
|
68
68
|
if self._checkpoint_path:
|
|
69
|
-
_utils.load_model_weights(
|
|
69
|
+
_utils.load_model_weights(model, self._checkpoint_path)
|
|
70
70
|
|
|
71
71
|
TorchHubModel.__name__ = self._model_name
|
|
72
72
|
|
|
73
|
+
return model
|
|
74
|
+
|
|
73
75
|
@override
|
|
74
76
|
def model_forward(self, tensor: torch.Tensor) -> torch.Tensor | List[torch.Tensor]:
|
|
75
77
|
if self._out_indices is not None:
|
|
76
|
-
if not hasattr(self.
|
|
78
|
+
if not hasattr(self.model, "get_intermediate_layers"):
|
|
77
79
|
raise ValueError(
|
|
78
80
|
"Only models with `get_intermediate_layers` are supported "
|
|
79
81
|
"when using `out_indices`."
|
|
80
82
|
)
|
|
81
83
|
|
|
82
84
|
return list(
|
|
83
|
-
self.
|
|
85
|
+
self.model.get_intermediate_layers( # type: ignore
|
|
84
86
|
tensor,
|
|
85
87
|
self._out_indices,
|
|
86
88
|
reshape=True,
|
|
@@ -89,4 +91,4 @@ class TorchHubModel(base.BaseModel[torch.Tensor, torch.Tensor]):
|
|
|
89
91
|
)
|
|
90
92
|
)
|
|
91
93
|
|
|
92
|
-
return self.
|
|
94
|
+
return self.model(tensor)
|
|
@@ -4,6 +4,7 @@ from typing import Any, Callable, Dict
|
|
|
4
4
|
|
|
5
5
|
import torch
|
|
6
6
|
import transformers
|
|
7
|
+
from torch import nn
|
|
7
8
|
from typing_extensions import override
|
|
8
9
|
|
|
9
10
|
from eva.core.models.wrappers import base
|
|
@@ -33,12 +34,10 @@ class HuggingFaceModel(base.BaseModel[torch.Tensor, torch.Tensor]):
|
|
|
33
34
|
self._model_name_or_path = model_name_or_path
|
|
34
35
|
self._model_kwargs = model_kwargs or {}
|
|
35
36
|
|
|
36
|
-
self.load_model()
|
|
37
|
+
self.model = self.load_model()
|
|
37
38
|
|
|
38
39
|
@override
|
|
39
|
-
def load_model(self) ->
|
|
40
|
+
def load_model(self) -> nn.Module:
|
|
40
41
|
# Use safetensors to avoid torch.load security vulnerability
|
|
41
42
|
model_kwargs = {"use_safetensors": True, **self._model_kwargs}
|
|
42
|
-
|
|
43
|
-
self._model_name_or_path, **model_kwargs
|
|
44
|
-
)
|
|
43
|
+
return transformers.AutoModel.from_pretrained(self._model_name_or_path, **model_kwargs)
|
eva/core/models/wrappers/onnx.py
CHANGED
|
@@ -30,21 +30,21 @@ class ONNXModel(base.BaseModel[torch.Tensor, torch.Tensor]):
|
|
|
30
30
|
self._path = path
|
|
31
31
|
self._device = device
|
|
32
32
|
|
|
33
|
-
self.load_model()
|
|
33
|
+
self.model = self.load_model()
|
|
34
34
|
|
|
35
35
|
@override
|
|
36
36
|
def load_model(self) -> Any:
|
|
37
37
|
if self._device == "cuda" and not torch.cuda.is_available():
|
|
38
38
|
raise ValueError("Device is set to 'cuda', but CUDA is not available.")
|
|
39
39
|
provider = "CUDAExecutionProvider" if self._device == "cuda" else "CPUExecutionProvider"
|
|
40
|
-
|
|
40
|
+
return ort.InferenceSession(self._path, providers=[provider]) # type: ignore
|
|
41
41
|
|
|
42
42
|
@override
|
|
43
43
|
def model_forward(self, tensor: torch.Tensor) -> torch.Tensor:
|
|
44
44
|
# TODO: Use IO binding to avoid copying the tensor to CPU.
|
|
45
45
|
# https://onnxruntime.ai/docs/api/python/api_summary.html#data-on-device
|
|
46
|
-
if not isinstance(self.
|
|
46
|
+
if not isinstance(self.model, ort.InferenceSession):
|
|
47
47
|
raise ValueError("Model is not loaded.")
|
|
48
|
-
inputs = {self.
|
|
49
|
-
outputs = self.
|
|
48
|
+
inputs = {self.model.get_inputs()[0].name: tensor.detach().cpu().numpy()}
|
|
49
|
+
outputs = self.model.run(None, inputs)[0]
|
|
50
50
|
return torch.from_numpy(outputs).float().to(tensor.device)
|
eva/core/trainers/trainer.py
CHANGED
|
@@ -8,6 +8,7 @@ from lightning.pytorch import loggers as pl_loggers
|
|
|
8
8
|
from lightning.pytorch import trainer as pl_trainer
|
|
9
9
|
from lightning.pytorch.utilities import argparse
|
|
10
10
|
from lightning_fabric.utilities import cloud_io
|
|
11
|
+
from lightning_utilities.core.rank_zero import rank_zero_only
|
|
11
12
|
from typing_extensions import override
|
|
12
13
|
|
|
13
14
|
from eva.core import loggers as eva_loggers
|
|
@@ -66,6 +67,7 @@ class Trainer(pl_trainer.Trainer):
|
|
|
66
67
|
def log_dir(self) -> str | None:
|
|
67
68
|
return self.strategy.broadcast(self._log_dir)
|
|
68
69
|
|
|
70
|
+
@rank_zero_only
|
|
69
71
|
def init_logger_run(self, run_id: int | None) -> None:
|
|
70
72
|
"""Setup the loggers & log directories when starting a new run.
|
|
71
73
|
|
eva/language/__init__.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
"""eva language API."""
|
|
2
2
|
|
|
3
3
|
try:
|
|
4
|
+
from eva.language import models
|
|
4
5
|
from eva.language.data import datasets
|
|
5
6
|
except ImportError as e:
|
|
6
7
|
msg = (
|
|
@@ -10,4 +11,4 @@ except ImportError as e:
|
|
|
10
11
|
)
|
|
11
12
|
raise ImportError(str(e) + "\n\n" + msg) from e
|
|
12
13
|
|
|
13
|
-
__all__ = ["datasets"]
|
|
14
|
+
__all__ = ["models", "datasets"]
|
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
"""Text prediction writer callbacks."""
|
|
2
|
+
|
|
3
|
+
import abc
|
|
4
|
+
import os
|
|
5
|
+
from typing import Any, Dict, List, Literal, Sequence, Tuple, TypedDict
|
|
6
|
+
|
|
7
|
+
import lightning.pytorch as pl
|
|
8
|
+
import pandas as pd
|
|
9
|
+
import torch
|
|
10
|
+
from lightning.pytorch import callbacks
|
|
11
|
+
from torch import nn
|
|
12
|
+
from typing_extensions import NotRequired, override
|
|
13
|
+
|
|
14
|
+
from eva.core.models.modules import utils as module_utils
|
|
15
|
+
from eva.language.models.typings import TextBatch
|
|
16
|
+
from eva.language.utils.text import messages as message_utils
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class ManifestEntry(TypedDict):
|
|
20
|
+
"""A single entry in the manifest file."""
|
|
21
|
+
|
|
22
|
+
prediction: str
|
|
23
|
+
"""The predicted text."""
|
|
24
|
+
|
|
25
|
+
target: str
|
|
26
|
+
"""The ground truth text."""
|
|
27
|
+
|
|
28
|
+
text: NotRequired[str]
|
|
29
|
+
"""The input text data."""
|
|
30
|
+
|
|
31
|
+
split: NotRequired[str]
|
|
32
|
+
"""The dataset split (e.g. train, val, test)."""
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class TextPredictionWriter(callbacks.BasePredictionWriter, abc.ABC):
|
|
36
|
+
"""Callback for writing generated text predictions to disk."""
|
|
37
|
+
|
|
38
|
+
def __init__(
|
|
39
|
+
self,
|
|
40
|
+
output_dir: str,
|
|
41
|
+
model: nn.Module,
|
|
42
|
+
dataloader_idx_map: Dict[int, str] | None = None,
|
|
43
|
+
metadata_keys: List[str] | None = None,
|
|
44
|
+
include_input: bool = True,
|
|
45
|
+
overwrite: bool = False,
|
|
46
|
+
apply_postprocess: bool = False,
|
|
47
|
+
save_format: Literal["jsonl", "parquet", "csv"] = "jsonl",
|
|
48
|
+
) -> None:
|
|
49
|
+
"""Initializes a new callback.
|
|
50
|
+
|
|
51
|
+
Args:
|
|
52
|
+
output_dir: The directory where the embeddings will be saved.
|
|
53
|
+
model: The model instance used to generate the predictions.
|
|
54
|
+
dataloader_idx_map: A dictionary mapping dataloader indices to their respective
|
|
55
|
+
names (e.g. train, val, test).
|
|
56
|
+
metadata_keys: An optional list of keys to extract from the batch metadata and store
|
|
57
|
+
as additional columns in the manifest file.
|
|
58
|
+
include_input: Whether to include the original input text messages in the output.
|
|
59
|
+
overwrite: Whether to overwrite if embeddings are already present in the specified
|
|
60
|
+
output directory. If set to `False`, an error will be raised if embeddings are
|
|
61
|
+
already present (recommended).
|
|
62
|
+
apply_postprocess: Whether to apply the postprocesses specified in the model module.
|
|
63
|
+
save_format: The file format to use for saving the manifest file with the predictions.
|
|
64
|
+
"""
|
|
65
|
+
super().__init__()
|
|
66
|
+
self.output_dir = output_dir
|
|
67
|
+
self.model = model
|
|
68
|
+
self.dataloader_idx_map = dataloader_idx_map or {}
|
|
69
|
+
self.metadata_keys = metadata_keys
|
|
70
|
+
self.include_input = include_input
|
|
71
|
+
self.overwrite = overwrite
|
|
72
|
+
self.apply_postprocess = apply_postprocess
|
|
73
|
+
self.save_format = save_format
|
|
74
|
+
|
|
75
|
+
self._manifest_path = os.path.join(self.output_dir, f"manifest.{self.save_format}")
|
|
76
|
+
self._data: List[ManifestEntry] = []
|
|
77
|
+
|
|
78
|
+
@override
|
|
79
|
+
def on_predict_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
|
|
80
|
+
self._check_if_exists()
|
|
81
|
+
|
|
82
|
+
self.model = self.model.to(pl_module.device)
|
|
83
|
+
self.model.eval()
|
|
84
|
+
|
|
85
|
+
@override
|
|
86
|
+
def write_on_batch_end(
|
|
87
|
+
self,
|
|
88
|
+
trainer: pl.Trainer,
|
|
89
|
+
pl_module: pl.LightningModule,
|
|
90
|
+
prediction: Any,
|
|
91
|
+
batch_indices: Sequence[int],
|
|
92
|
+
batch: TextBatch,
|
|
93
|
+
batch_idx: int,
|
|
94
|
+
dataloader_idx: int,
|
|
95
|
+
) -> None:
|
|
96
|
+
text_batch, target_batch, metadata_batch = self._unpack_batch(batch)
|
|
97
|
+
has_target = target_batch is not None
|
|
98
|
+
split = self.dataloader_idx_map.get(dataloader_idx, "")
|
|
99
|
+
|
|
100
|
+
prediction_batch = self._get_predictions(batch)
|
|
101
|
+
|
|
102
|
+
target_batch, prediction_batch = self._apply_postprocess(
|
|
103
|
+
pl_module, target_batch, prediction_batch
|
|
104
|
+
)
|
|
105
|
+
|
|
106
|
+
for i in range(len(batch_indices)):
|
|
107
|
+
entry: ManifestEntry = {
|
|
108
|
+
"text": message_utils.serialize(text_batch[i]),
|
|
109
|
+
"prediction": str(prediction_batch[i]),
|
|
110
|
+
"target": str(target_batch[i]) if has_target else "",
|
|
111
|
+
"split": split if split else "",
|
|
112
|
+
}
|
|
113
|
+
|
|
114
|
+
if self.metadata_keys is not None and metadata_batch is not None:
|
|
115
|
+
for key in self.metadata_keys:
|
|
116
|
+
entry[key] = metadata_batch[key][i]
|
|
117
|
+
|
|
118
|
+
self._data.append(entry)
|
|
119
|
+
|
|
120
|
+
@override
|
|
121
|
+
def on_predict_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
|
|
122
|
+
"""Saves the gathered predictions to a manifest file."""
|
|
123
|
+
df = pd.DataFrame(self._data)
|
|
124
|
+
|
|
125
|
+
match self.save_format:
|
|
126
|
+
case "jsonl":
|
|
127
|
+
df.to_json(self._manifest_path, orient="records", lines=True)
|
|
128
|
+
case "parquet":
|
|
129
|
+
df.to_parquet(self._manifest_path, index=False)
|
|
130
|
+
case "csv":
|
|
131
|
+
df.to_csv(self._manifest_path, index=False)
|
|
132
|
+
case _:
|
|
133
|
+
raise ValueError(f"Unsupported save format: {self.save_format}")
|
|
134
|
+
|
|
135
|
+
def _get_predictions(self, batch: TextBatch) -> List[str]:
|
|
136
|
+
with torch.no_grad():
|
|
137
|
+
predictions = self.model(batch)
|
|
138
|
+
|
|
139
|
+
if not isinstance(predictions, list) or not all(isinstance(p, str) for p in predictions):
|
|
140
|
+
raise ValueError("The model's output should be a list of strings.")
|
|
141
|
+
|
|
142
|
+
return predictions
|
|
143
|
+
|
|
144
|
+
def _check_if_exists(self) -> None:
|
|
145
|
+
"""Checks if the output directory already exists and if it should be overwritten."""
|
|
146
|
+
os.makedirs(self.output_dir, exist_ok=True)
|
|
147
|
+
if os.path.exists(self._manifest_path) and not self.overwrite:
|
|
148
|
+
raise FileExistsError(
|
|
149
|
+
f"The specified output directory already exists: {self.output_dir}. This "
|
|
150
|
+
"either means that the predictions have been computed before or that a "
|
|
151
|
+
"wrong output directory is being used."
|
|
152
|
+
)
|
|
153
|
+
os.makedirs(self.output_dir, exist_ok=True)
|
|
154
|
+
|
|
155
|
+
def _apply_postprocess(
|
|
156
|
+
self, pl_module: pl.LightningModule, targets: Any, predictions: Any
|
|
157
|
+
) -> Tuple[List[Any], List[Any]]:
|
|
158
|
+
def _to_list(data: Any) -> List[Any]:
|
|
159
|
+
if isinstance(data, torch.Tensor):
|
|
160
|
+
return data.cpu().tolist()
|
|
161
|
+
return data
|
|
162
|
+
|
|
163
|
+
if self.apply_postprocess and hasattr(pl_module, "postprocess"):
|
|
164
|
+
if (
|
|
165
|
+
isinstance(pl_module.postprocess, module_utils.BatchPostProcess)
|
|
166
|
+
and pl_module.postprocess.predictions_transforms is not None
|
|
167
|
+
):
|
|
168
|
+
outputs = {"targets": targets, "predictions": predictions}
|
|
169
|
+
pl_module.postprocess(outputs)
|
|
170
|
+
targets, predictions = outputs["targets"], outputs["predictions"]
|
|
171
|
+
|
|
172
|
+
return _to_list(targets), _to_list(predictions)
|
|
173
|
+
|
|
174
|
+
def _unpack_batch(self, batch: TextBatch) -> Tuple[list, list | None, dict | None]:
|
|
175
|
+
text_batch, target_batch, metadata_batch = TextBatch(*batch)
|
|
176
|
+
return text_batch, target_batch, metadata_batch
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
"""Collate functions for text data."""
|
|
2
|
+
|
|
3
|
+
from typing import List
|
|
4
|
+
|
|
5
|
+
from torch.utils.data._utils.collate import default_collate
|
|
6
|
+
|
|
7
|
+
from eva.language.data.datasets.typings import PredictionSample, TextSample
|
|
8
|
+
from eva.language.models.typings import PredictionBatch, TextBatch
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def text_collate(batch: List[TextSample]) -> TextBatch:
|
|
12
|
+
"""Collate function for text data that keeps texts as separate strings.
|
|
13
|
+
|
|
14
|
+
Args:
|
|
15
|
+
batch: List of tuples containing (text, target, metadata) from the dataset
|
|
16
|
+
|
|
17
|
+
Returns:
|
|
18
|
+
A batch of text samples with targets and metadata.
|
|
19
|
+
"""
|
|
20
|
+
texts, targets, metadata = zip(*batch, strict=False)
|
|
21
|
+
first_sample = batch[0]
|
|
22
|
+
metadata = None
|
|
23
|
+
if first_sample.metadata is not None:
|
|
24
|
+
metadata = {
|
|
25
|
+
k: [sample.metadata[k] for sample in batch if sample.metadata]
|
|
26
|
+
for k in first_sample.metadata.keys()
|
|
27
|
+
}
|
|
28
|
+
return TextBatch(
|
|
29
|
+
text=list(texts),
|
|
30
|
+
target=default_collate(targets) if targets[0] is not None else None,
|
|
31
|
+
metadata=metadata,
|
|
32
|
+
)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def prediction_collate(batch: List[PredictionSample]) -> PredictionBatch:
|
|
36
|
+
"""Collate function for text prediction data.
|
|
37
|
+
|
|
38
|
+
Args:
|
|
39
|
+
batch: List of tuples containing (prediction, target, text, metadata) from the dataset
|
|
40
|
+
|
|
41
|
+
Returns:
|
|
42
|
+
A batch of prediction samples.
|
|
43
|
+
"""
|
|
44
|
+
predictions, targets, texts, metadata = zip(*batch, strict=False)
|
|
45
|
+
first_sample = batch[0]
|
|
46
|
+
metadata = None
|
|
47
|
+
if first_sample.metadata is not None:
|
|
48
|
+
metadata = {
|
|
49
|
+
k: [sample.metadata[k] for sample in batch if sample.metadata]
|
|
50
|
+
for k in first_sample.metadata.keys()
|
|
51
|
+
}
|
|
52
|
+
return PredictionBatch(
|
|
53
|
+
prediction=default_collate(predictions) if predictions[0] is not None else None,
|
|
54
|
+
target=default_collate(targets) if targets[0] is not None else None,
|
|
55
|
+
text=list(texts) if first_sample.text is not None else None,
|
|
56
|
+
metadata=metadata,
|
|
57
|
+
)
|
|
@@ -1,9 +1,11 @@
|
|
|
1
1
|
"""Language Datasets API."""
|
|
2
2
|
|
|
3
|
+
from eva.language.data.datasets.base import LanguageDataset
|
|
3
4
|
from eva.language.data.datasets.classification import PubMedQA
|
|
4
|
-
from eva.language.data.datasets.
|
|
5
|
+
from eva.language.data.datasets.prediction import TextPredictionDataset
|
|
5
6
|
|
|
6
7
|
__all__ = [
|
|
7
8
|
"PubMedQA",
|
|
8
9
|
"LanguageDataset",
|
|
10
|
+
"TextPredictionDataset",
|
|
9
11
|
]
|
|
@@ -1,15 +1,13 @@
|
|
|
1
1
|
"""Base for text classification datasets."""
|
|
2
2
|
|
|
3
|
-
import
|
|
4
|
-
from typing import Any, Dict, List, Tuple
|
|
3
|
+
from typing import Dict, List
|
|
5
4
|
|
|
6
5
|
import torch
|
|
7
|
-
from typing_extensions import override
|
|
8
6
|
|
|
9
|
-
from eva.language.data.datasets.
|
|
7
|
+
from eva.language.data.datasets.text import TextDataset
|
|
10
8
|
|
|
11
9
|
|
|
12
|
-
class TextClassification(
|
|
10
|
+
class TextClassification(TextDataset[torch.Tensor]):
|
|
13
11
|
"""Text classification abstract dataset."""
|
|
14
12
|
|
|
15
13
|
def __init__(self) -> None:
|
|
@@ -23,41 +21,3 @@ class TextClassification(LanguageDataset[Tuple[str, torch.Tensor, Dict[str, Any]
|
|
|
23
21
|
@property
|
|
24
22
|
def class_to_idx(self) -> Dict[str, int] | None:
|
|
25
23
|
"""Returns class name to index mapping."""
|
|
26
|
-
|
|
27
|
-
def load_metadata(self, index: int) -> Dict[str, Any] | None:
|
|
28
|
-
"""Returns the dataset metadata.
|
|
29
|
-
|
|
30
|
-
Args:
|
|
31
|
-
index: The index of the data sample.
|
|
32
|
-
|
|
33
|
-
Returns:
|
|
34
|
-
The sample metadata.
|
|
35
|
-
"""
|
|
36
|
-
|
|
37
|
-
@abc.abstractmethod
|
|
38
|
-
def load_text(self, index: int) -> str:
|
|
39
|
-
"""Returns the text content.
|
|
40
|
-
|
|
41
|
-
Args:
|
|
42
|
-
index: The index of the data sample.
|
|
43
|
-
|
|
44
|
-
Returns:
|
|
45
|
-
The text content.
|
|
46
|
-
"""
|
|
47
|
-
raise NotImplementedError
|
|
48
|
-
|
|
49
|
-
@abc.abstractmethod
|
|
50
|
-
def load_target(self, index: int) -> torch.Tensor:
|
|
51
|
-
"""Returns the target label.
|
|
52
|
-
|
|
53
|
-
Args:
|
|
54
|
-
index: The index of the data sample.
|
|
55
|
-
|
|
56
|
-
Returns:
|
|
57
|
-
The target label.
|
|
58
|
-
"""
|
|
59
|
-
raise NotImplementedError
|
|
60
|
-
|
|
61
|
-
@override
|
|
62
|
-
def __getitem__(self, index: int) -> Tuple[str, torch.Tensor, Dict[str, Any]]:
|
|
63
|
-
return (self.load_text(index), self.load_target(index), self.load_metadata(index) or {})
|