kaiko_eva-0.3.3-py3-none-any.whl → kaiko_eva-0.4.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of kaiko-eva has been flagged by the registry; further details are available on the package page.
- eva/core/callbacks/config.py +15 -6
- eva/core/callbacks/writers/embeddings/base.py +44 -10
- eva/core/cli/setup.py +1 -1
- eva/core/data/dataloaders/__init__.py +1 -2
- eva/core/data/samplers/classification/balanced.py +24 -12
- eva/core/data/samplers/random.py +17 -10
- eva/core/interface/interface.py +21 -0
- eva/core/loggers/utils/wandb.py +4 -1
- eva/core/models/modules/module.py +2 -2
- eva/core/models/wrappers/base.py +2 -2
- eva/core/models/wrappers/from_function.py +3 -3
- eva/core/models/wrappers/from_torchhub.py +9 -7
- eva/core/models/wrappers/huggingface.py +4 -5
- eva/core/models/wrappers/onnx.py +5 -5
- eva/core/trainers/trainer.py +13 -1
- eva/core/utils/__init__.py +2 -1
- eva/core/utils/distributed.py +12 -0
- eva/core/utils/paths.py +14 -0
- eva/core/utils/requirements.py +52 -6
- eva/language/__init__.py +2 -1
- eva/language/callbacks/__init__.py +5 -0
- eva/language/callbacks/writers/__init__.py +5 -0
- eva/language/callbacks/writers/prediction.py +201 -0
- eva/language/data/dataloaders/__init__.py +5 -0
- eva/language/data/dataloaders/collate_fn/__init__.py +5 -0
- eva/language/data/dataloaders/collate_fn/text.py +57 -0
- eva/language/data/datasets/__init__.py +3 -1
- eva/language/data/datasets/{language.py → base.py} +1 -1
- eva/language/data/datasets/classification/base.py +3 -43
- eva/language/data/datasets/classification/pubmedqa.py +36 -4
- eva/language/data/datasets/prediction.py +151 -0
- eva/language/data/datasets/schemas.py +18 -0
- eva/language/data/datasets/text.py +92 -0
- eva/language/data/datasets/typings.py +39 -0
- eva/language/data/messages.py +60 -0
- eva/language/models/__init__.py +15 -11
- eva/language/models/modules/__init__.py +2 -2
- eva/language/models/modules/language.py +94 -0
- eva/language/models/networks/__init__.py +12 -0
- eva/language/models/networks/alibaba.py +26 -0
- eva/language/models/networks/api/__init__.py +11 -0
- eva/language/models/networks/api/anthropic.py +34 -0
- eva/language/models/networks/registry.py +5 -0
- eva/language/models/typings.py +56 -0
- eva/language/models/wrappers/__init__.py +13 -5
- eva/language/models/wrappers/base.py +47 -0
- eva/language/models/wrappers/from_registry.py +54 -0
- eva/language/models/wrappers/huggingface.py +57 -11
- eva/language/models/wrappers/litellm.py +91 -46
- eva/language/models/wrappers/vllm.py +37 -13
- eva/language/utils/__init__.py +2 -1
- eva/language/utils/str_to_int_tensor.py +20 -12
- eva/language/utils/text/__init__.py +5 -0
- eva/language/utils/text/messages.py +113 -0
- eva/multimodal/__init__.py +6 -0
- eva/multimodal/callbacks/__init__.py +5 -0
- eva/multimodal/callbacks/writers/__init__.py +5 -0
- eva/multimodal/callbacks/writers/prediction.py +39 -0
- eva/multimodal/data/__init__.py +5 -0
- eva/multimodal/data/dataloaders/__init__.py +5 -0
- eva/multimodal/data/dataloaders/collate_fn/__init__.py +5 -0
- eva/multimodal/data/dataloaders/collate_fn/text_image.py +28 -0
- eva/multimodal/data/datasets/__init__.py +6 -0
- eva/multimodal/data/datasets/base.py +13 -0
- eva/multimodal/data/datasets/multiple_choice/__init__.py +5 -0
- eva/multimodal/data/datasets/multiple_choice/patch_camelyon.py +80 -0
- eva/multimodal/data/datasets/schemas.py +14 -0
- eva/multimodal/data/datasets/text_image.py +77 -0
- eva/multimodal/data/datasets/typings.py +27 -0
- eva/multimodal/models/__init__.py +8 -0
- eva/multimodal/models/modules/__init__.py +5 -0
- eva/multimodal/models/modules/vision_language.py +56 -0
- eva/multimodal/models/networks/__init__.py +14 -0
- eva/multimodal/models/networks/alibaba.py +40 -0
- eva/multimodal/models/networks/api/__init__.py +11 -0
- eva/multimodal/models/networks/api/anthropic.py +34 -0
- eva/multimodal/models/networks/others.py +48 -0
- eva/multimodal/models/networks/registry.py +5 -0
- eva/multimodal/models/typings.py +27 -0
- eva/multimodal/models/wrappers/__init__.py +13 -0
- eva/multimodal/models/wrappers/base.py +48 -0
- eva/multimodal/models/wrappers/from_registry.py +54 -0
- eva/multimodal/models/wrappers/huggingface.py +193 -0
- eva/multimodal/models/wrappers/litellm.py +58 -0
- eva/multimodal/utils/__init__.py +1 -0
- eva/multimodal/utils/batch/__init__.py +5 -0
- eva/multimodal/utils/batch/unpack.py +11 -0
- eva/multimodal/utils/image/__init__.py +5 -0
- eva/multimodal/utils/image/encode.py +28 -0
- eva/multimodal/utils/text/__init__.py +1 -0
- eva/multimodal/utils/text/messages.py +79 -0
- eva/vision/data/datasets/classification/breakhis.py +5 -8
- eva/vision/data/datasets/classification/panda.py +12 -5
- eva/vision/data/datasets/classification/patch_camelyon.py +8 -6
- eva/vision/data/datasets/segmentation/btcv.py +1 -1
- eva/vision/data/datasets/segmentation/consep.py +1 -1
- eva/vision/data/datasets/segmentation/lits17.py +1 -1
- eva/vision/data/datasets/segmentation/monusac.py +15 -6
- eva/vision/data/datasets/segmentation/msd_task7_pancreas.py +1 -1
- eva/vision/data/transforms/__init__.py +2 -1
- eva/vision/data/transforms/base/__init__.py +2 -1
- eva/vision/data/transforms/base/monai.py +2 -2
- eva/vision/data/transforms/base/torchvision.py +33 -0
- eva/vision/data/transforms/common/squeeze.py +6 -3
- eva/vision/data/transforms/croppad/crop_foreground.py +8 -7
- eva/vision/data/transforms/croppad/rand_crop_by_label_classes.py +6 -5
- eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py +6 -5
- eva/vision/data/transforms/croppad/rand_spatial_crop.py +8 -7
- eva/vision/data/transforms/croppad/spatial_pad.py +6 -6
- eva/vision/data/transforms/intensity/rand_scale_intensity.py +3 -3
- eva/vision/data/transforms/intensity/rand_shift_intensity.py +3 -3
- eva/vision/data/transforms/intensity/scale_intensity_ranged.py +5 -5
- eva/vision/data/transforms/spatial/__init__.py +2 -1
- eva/vision/data/transforms/spatial/flip.py +8 -7
- eva/vision/data/transforms/spatial/functional/__init__.py +5 -0
- eva/vision/data/transforms/spatial/functional/resize.py +26 -0
- eva/vision/data/transforms/spatial/resize.py +63 -0
- eva/vision/data/transforms/spatial/rotate.py +8 -7
- eva/vision/data/transforms/spatial/spacing.py +7 -6
- eva/vision/data/transforms/utility/ensure_channel_first.py +6 -6
- eva/vision/models/networks/backbones/universal/vit.py +24 -0
- eva/vision/models/wrappers/from_registry.py +6 -5
- eva/vision/models/wrappers/from_timm.py +6 -4
- {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/METADATA +17 -3
- {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/RECORD +128 -66
- eva/core/data/dataloaders/collate_fn/__init__.py +0 -5
- eva/core/data/dataloaders/collate_fn/collate.py +0 -24
- eva/language/models/modules/text.py +0 -85
- {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/WHEEL +0 -0
- {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/entry_points.txt +0 -0
- {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/licenses/LICENSE +0 -0
eva/language/callbacks/writers/prediction.py
@@ -0,0 +1,201 @@
+"""Text prediction writer callbacks."""
+
+import abc
+import os
+from typing import Any, Dict, List, Literal, Sequence, Tuple, TypedDict
+
+import lightning.pytorch as pl
+import pandas as pd
+import torch
+import torch.distributed as dist
+from lightning.pytorch import callbacks
+from torch import nn
+from typing_extensions import NotRequired, override
+
+from eva.core.models.modules import utils as module_utils
+from eva.core.utils import distributed as dist_utils
+from eva.language.models.typings import TextBatch
+from eva.language.utils.text import messages as message_utils
+
+
+class ManifestEntry(TypedDict):
+    """A single entry in the manifest file."""
+
+    prediction: str
+    """The predicted text."""
+
+    target: str
+    """The ground truth text."""
+
+    text: NotRequired[str]
+    """The input text data."""
+
+    split: NotRequired[str]
+    """The dataset split (e.g. train, val, test)."""
+
+
+class TextPredictionWriter(callbacks.BasePredictionWriter, abc.ABC):
+    """Callback for writing generated text predictions to disk."""
+
+    def __init__(
+        self,
+        output_dir: str,
+        model: nn.Module,
+        dataloader_idx_map: Dict[int, str] | None = None,
+        metadata_keys: List[str] | None = None,
+        include_input: bool = True,
+        overwrite: bool = False,
+        apply_postprocess: bool = False,
+        save_format: Literal["jsonl", "parquet", "csv"] = "jsonl",
+    ) -> None:
+        """Initializes a new callback.
+
+        Args:
+            output_dir: The directory where the predictions will be saved.
+            model: The model instance used to generate the predictions.
+            dataloader_idx_map: A dictionary mapping dataloader indices to their respective
+                names (e.g. train, val, test).
+            metadata_keys: An optional list of keys to extract from the batch metadata and store
+                as additional columns in the manifest file.
+            include_input: Whether to include the original input text messages in the output.
+            overwrite: Whether to overwrite if predictions are already present in the specified
+                output directory. If set to `False`, an error will be raised if predictions are
+                already present (recommended).
+            apply_postprocess: Whether to apply the postprocesses specified in the model module.
+            save_format: The file format to use for saving the manifest file with the predictions.
+        """
+        super().__init__()
+        self.output_dir = output_dir
+        self.model = model
+        self.dataloader_idx_map = dataloader_idx_map or {}
+        self.metadata_keys = metadata_keys
+        self.include_input = include_input
+        self.overwrite = overwrite
+        self.apply_postprocess = apply_postprocess
+        self.save_format = save_format
+
+        self._manifest_path = os.path.join(self.output_dir, f"manifest.{self.save_format}")
+        self._data: List[ManifestEntry] = []
+        self._is_rank_zero: bool = False
+
+    @override
+    def on_predict_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
+        self._is_rank_zero = trainer.is_global_zero
+
+        if self._is_rank_zero:
+            self._check_if_exists()
+
+        self.model = self.model.to(pl_module.device)
+        self.model.eval()
+
+    @override
+    def write_on_batch_end(
+        self,
+        trainer: pl.Trainer,
+        pl_module: pl.LightningModule,
+        prediction: Any,
+        batch_indices: Sequence[int],
+        batch: TextBatch,
+        batch_idx: int,
+        dataloader_idx: int,
+    ) -> None:
+        text_batch, target_batch, metadata_batch = self._unpack_batch(batch)
+        has_target = target_batch is not None
+        split = self.dataloader_idx_map.get(dataloader_idx, "")
+
+        prediction_batch = self._get_predictions(batch)
+
+        target_batch, prediction_batch = self._apply_postprocess(
+            pl_module, target_batch, prediction_batch
+        )
+
+        for i in range(len(batch_indices)):
+            entry: ManifestEntry = {
+                "prediction": str(prediction_batch[i]),
+                "target": str(target_batch[i]) if has_target else "",
+                "split": split if split else "",
+            }
+            if self.include_input:
+                entry["text"] = message_utils.serialize(text_batch[i])
+
+            if self.metadata_keys is not None and metadata_batch is not None:
+                for key in self.metadata_keys:
+                    entry[key] = metadata_batch[key][i]
+
+            self._data.append(entry)
+
+    @override
+    def on_predict_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
+        """Saves the gathered predictions to a manifest file."""
+        if dist_utils.is_distributed():
+            dist.barrier()
+            data = self._gather_data_from_ranks()
+        else:
+            data = self._data
+
+        if self._is_rank_zero:
+            df = pd.DataFrame(data)
+
+            match self.save_format:
+                case "jsonl":
+                    df.to_json(self._manifest_path, orient="records", lines=True)
+                case "parquet":
+                    df.to_parquet(self._manifest_path, index=False)
+                case "csv":
+                    df.to_csv(self._manifest_path, index=False)
+                case _:
+                    raise ValueError(f"Unsupported save format: {self.save_format}")
+
+    def _gather_data_from_ranks(self) -> List[ManifestEntry]:
+        world_size = dist.get_world_size()
+        gathered: List[List[ManifestEntry] | None] = [None] * world_size
+        dist.all_gather_object(gathered, self._data)
+        return [row for shard in gathered for row in (shard or [])]
+
+    def _get_predictions(self, batch: TextBatch) -> List[str]:
+        with torch.no_grad():
+            output = self.model(batch)
+
+        if (
+            not isinstance(output, dict)
+            or "generated_text" not in output
+            or not all(isinstance(p, str) for p in output["generated_text"])
+        ):
+            raise ValueError(
+                f"A dictionary with a 'generated_text' key is expected, got {type(output)}"
+            )
+
+        return output["generated_text"]
+
+    def _check_if_exists(self) -> None:
+        """Checks if the output directory already exists and if it should be overwritten."""
+        os.makedirs(self.output_dir, exist_ok=True)
+        if os.path.exists(self._manifest_path) and not self.overwrite:
+            raise FileExistsError(
+                f"The specified output directory already exists: {self.output_dir}. This "
+                "either means that the predictions have been computed before or that a "
+                "wrong output directory is being used."
+            )
+
+    def _apply_postprocess(
+        self, pl_module: pl.LightningModule, targets: Any, predictions: Any
+    ) -> Tuple[List[Any], List[Any]]:
+        def _to_list(data: Any) -> List[Any]:
+            if isinstance(data, torch.Tensor):
+                return data.cpu().tolist()
+            return data
+
+        if self.apply_postprocess and hasattr(pl_module, "postprocess"):
+            if (
+                isinstance(pl_module.postprocess, module_utils.BatchPostProcess)
+                and pl_module.postprocess.predictions_transforms is not None
+            ):
+                outputs = {"targets": targets, "predictions": predictions}
+                pl_module.postprocess(outputs)
+                targets, predictions = outputs["targets"], outputs["predictions"]
+
+        return _to_list(targets), _to_list(predictions)
+
+    def _unpack_batch(self, batch: TextBatch) -> Tuple[list, list | None, dict | None]:
+        text_batch, target_batch, metadata_batch = TextBatch(*batch)
+        return text_batch, target_batch, metadata_batch
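For reference, the writer above is attached to a Lightning `Trainer` like any other callback. A minimal usage sketch, assuming a language-model wrapper whose forward pass returns a dict with a `generated_text` key; `my_model`, `my_module`, and `my_dataloaders` are illustrative placeholders, not names from this package:

import lightning.pytorch as pl

from eva.language.callbacks.writers.prediction import TextPredictionWriter

# Collects generations from all ranks and writes one manifest.jsonl with
# prediction/target/text/split columns (plus any requested metadata keys).
writer = TextPredictionWriter(
    output_dir="./predictions/run-01",
    model=my_model,  # placeholder: any nn.Module returning {"generated_text": List[str]}
    dataloader_idx_map={0: "val", 1: "test"},
    save_format="jsonl",
)

trainer = pl.Trainer(callbacks=[writer])
trainer.predict(my_module, dataloaders=my_dataloaders, return_predictions=False)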
eva/language/data/dataloaders/collate_fn/text.py
@@ -0,0 +1,57 @@
+"""Collate functions for text data."""
+
+from typing import List
+
+from torch.utils.data._utils.collate import default_collate
+
+from eva.language.data.datasets.typings import PredictionSample, TextSample
+from eva.language.models.typings import PredictionBatch, TextBatch
+
+
+def text_collate(batch: List[TextSample]) -> TextBatch:
+    """Collate function for text data that keeps texts as separate strings.
+
+    Args:
+        batch: List of tuples containing (text, target, metadata) from the dataset.
+
+    Returns:
+        A batch of text samples with targets and metadata.
+    """
+    texts, targets, metadata = zip(*batch, strict=False)
+    first_sample = batch[0]
+    metadata = None
+    if first_sample.metadata is not None:
+        metadata = {
+            k: [sample.metadata[k] for sample in batch if sample.metadata]
+            for k in first_sample.metadata.keys()
+        }
+    return TextBatch(
+        text=list(texts),
+        target=default_collate(targets) if targets[0] is not None else None,
+        metadata=metadata,
+    )
+
+
+def prediction_collate(batch: List[PredictionSample]) -> PredictionBatch:
+    """Collate function for text prediction data.
+
+    Args:
+        batch: List of tuples containing (prediction, target, text, metadata) from the dataset.
+
+    Returns:
+        A batch of prediction samples.
+    """
+    predictions, targets, texts, metadata = zip(*batch, strict=False)
+    first_sample = batch[0]
+    metadata = None
+    if first_sample.metadata is not None:
+        metadata = {
+            k: [sample.metadata[k] for sample in batch if sample.metadata]
+            for k in first_sample.metadata.keys()
+        }
+    return PredictionBatch(
+        prediction=default_collate(predictions) if predictions[0] is not None else None,
+        target=default_collate(targets) if targets[0] is not None else None,
+        text=list(texts) if first_sample.text is not None else None,
+        metadata=metadata,
+    )
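A custom collate function is needed here because `default_collate` would try to stack raw strings and per-sample message lists tensor-style; keeping them as plain Python lists sidesteps that. A sketch of typical wiring, where `text_dataset` is a placeholder for any dataset yielding `TextSample` items:

from torch.utils.data import DataLoader

from eva.language.data.dataloaders.collate_fn.text import text_collate

loader = DataLoader(text_dataset, batch_size=8, collate_fn=text_collate)

for batch in loader:
    # batch.text is a list of per-sample texts; batch.target is the
    # default-collated targets, or None for unlabeled data.
    print(len(batch.text), batch.target)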
eva/language/data/datasets/__init__.py
@@ -1,9 +1,11 @@
 """Language Datasets API."""
 
+from eva.language.data.datasets.base import LanguageDataset
 from eva.language.data.datasets.classification import PubMedQA
-from eva.language.data.datasets.language import LanguageDataset
+from eva.language.data.datasets.prediction import TextPredictionDataset
 
 __all__ = [
     "PubMedQA",
     "LanguageDataset",
+    "TextPredictionDataset",
 ]
eva/language/data/datasets/classification/base.py
@@ -1,15 +1,13 @@
 """Base for text classification datasets."""
 
-import abc
-from typing import Any, Dict, List, Tuple
+from typing import Dict, List
 
 import torch
-from typing_extensions import override
 
-from eva.language.data.datasets.language import LanguageDataset
+from eva.language.data.datasets.text import TextDataset
 
 
-class TextClassification(LanguageDataset[Tuple[str, torch.Tensor, Dict[str, Any]]]):
+class TextClassification(TextDataset[torch.Tensor]):
     """Text classification abstract dataset."""
 
     def __init__(self) -> None:
@@ -23,41 +21,3 @@ class TextClassification(LanguageDataset[Tuple[str, torch.Tensor, Dict[str, Any]]]):
     @property
     def class_to_idx(self) -> Dict[str, int] | None:
         """Returns class name to index mapping."""
-
-    def load_metadata(self, index: int) -> Dict[str, Any] | None:
-        """Returns the dataset metadata.
-
-        Args:
-            index: The index of the data sample.
-
-        Returns:
-            The sample metadata.
-        """
-
-    @abc.abstractmethod
-    def load_text(self, index: int) -> str:
-        """Returns the text content.
-
-        Args:
-            index: The index of the data sample.
-
-        Returns:
-            The text content.
-        """
-        raise NotImplementedError
-
-    @abc.abstractmethod
-    def load_target(self, index: int) -> torch.Tensor:
-        """Returns the target label.
-
-        Args:
-            index: The index of the data sample.
-
-        Returns:
-            The target label.
-        """
-        raise NotImplementedError
-
-    @override
-    def __getitem__(self, index: int) -> Tuple[str, torch.Tensor, Dict[str, Any]]:
-        return (self.load_text(index), self.load_target(index), self.load_metadata(index) or {})
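The loaders removed here move to the new `TextDataset` base (`eva/language/data/datasets/text.py` in the file list above), so concrete classification datasets now only implement the per-sample hooks. A minimal sketch of such a subclass, under the assumption that `TextDataset` declares `load_text`/`load_target` as abstract and defines `__getitem__` in terms of them; the toy class and its data are invented for illustration:

import torch

from eva.language.data.datasets.classification.base import TextClassification
from eva.language.data.messages import MessageSeries, UserMessage


class ToyClassification(TextClassification):
    """Illustrative two-sample dataset (not part of the package)."""

    _samples = [("the study shows a benefit", 1), ("no effect was observed", 0)]

    def __len__(self) -> int:
        return len(self._samples)

    def load_text(self, index: int) -> MessageSeries:
        # Each sample is rendered as a single user message.
        return [UserMessage(content=self._samples[index][0])]

    def load_target(self, index: int) -> torch.Tensor:
        return torch.tensor(self._samples[index][1])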
eva/language/data/datasets/classification/pubmedqa.py
@@ -10,11 +10,20 @@ from loguru import logger
 from typing_extensions import override
 
 from eva.language.data.datasets.classification import base
+from eva.language.data.messages import MessageSeries, UserMessage
 
 
 class PubMedQA(base.TextClassification):
     """Dataset class for PubMedQA question answering task."""
 
+    _expected_dataset_lengths: Dict[str | None, int] = {
+        "train": 450,
+        "val": 50,
+        "test": 500,
+        None: 500,
+    }
+    """Expected dataset lengths for the splits and complete dataset."""
+
     _license: str = "MIT License (https://github.com/pubmedqa/pubmedqa/blob/master/LICENSE)"
     """Dataset license."""
 
@@ -52,7 +61,14 @@ class PubMedQA(base.TextClassification):
         """
         dataset_name = "bigbio/pubmed_qa"
         config_name = "pubmed_qa_labeled_fold0_source"
-
+
+        match self._split:
+            case "val":
+                split = "validation"
+            case None:
+                split = "train+test+validation"
+            case _:
+                split = self._split
 
         if self._download:
             logger.info("Downloading dataset from HuggingFace Hub")
@@ -88,7 +104,7 @@ class PubMedQA(base.TextClassification):
         dataset_path = None
 
         if self._root:
-            dataset_path = self._root
+            dataset_path = os.path.join(self._root, self._split) if self._split else self._root
             os.makedirs(self._root, exist_ok=True)
 
         try:
@@ -103,6 +119,15 @@ class PubMedQA(base.TextClassification):
         except Exception as e:
             raise RuntimeError(f"Failed to prepare dataset: {e}") from e
 
+    @override
+    def validate(self) -> None:
+        if len(self) != (self._max_samples or self._expected_dataset_lengths[self._split]):
+            raise ValueError(
+                f"Dataset length mismatch for split '{self._split}': "
+                f"expected {self._expected_dataset_lengths[self._split]}, "
+                f"but got {len(self)}"
+            )
+
     @property
     @override
     def classes(self) -> List[str]:
@@ -114,11 +139,18 @@ class PubMedQA(base.TextClassification):
         return {"no": 0, "yes": 1, "maybe": 2}
 
     @override
-    def load_text(self, index: int) -> str:
+    def load_text(self, index: int) -> MessageSeries:
         if index < 0 or index >= len(self.dataset):
             raise IndexError(f"Index {index} out of range for dataset of size {len(self.dataset)}")
         sample = dict(self.dataset[index])
-        return …
+        return [
+            UserMessage(
+                content=f"Question: {sample['QUESTION']}\nContext: "
+                + " ".join(sample["CONTEXTS"])
+                + "\nInstruction: Carefully read the question and the provided context. "
+                + "Answer with one word: 'yes', 'no', or 'maybe'. Answer: "
+            )
+        ]
 
     @override
     def load_target(self, index: int) -> torch.Tensor:
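Put together, each PubMedQA sample is now rendered as a single `UserMessage` prompt rather than a bare string. A hedged sketch of loading one sample; the constructor arguments and lifecycle calls are inferred from the attributes referenced above (`_root`, `_split`, `_download`, `_max_samples`) and may not match the exact signature:

from eva.language.data.datasets.classification.pubmedqa import PubMedQA

dataset = PubMedQA(root="./data/pubmedqa", split="val", download=True)  # assumed signature
dataset.prepare_data()  # assumed hook: downloads/loads the HuggingFace split
dataset.configure()
dataset.validate()      # per the code above, checks len(dataset) against the split length

messages = dataset.load_text(0)   # [UserMessage(content="Question: ...\nContext: ...")]
target = dataset.load_target(0)   # tensor label, with {"no": 0, "yes": 1, "maybe": 2}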
eva/language/data/datasets/prediction.py
@@ -0,0 +1,151 @@
+"""Dataset class for loading pre-generated text predictions."""
+
+import abc
+from pathlib import Path
+from typing import Any, Dict, Generic, Literal
+
+import pandas as pd
+from typing_extensions import override
+
+from eva.language.data.datasets.base import LanguageDataset
+from eva.language.data.datasets.schemas import TransformsSchema
+from eva.language.data.datasets.typings import PredictionSample, TargetType
+from eva.language.data.messages import MessageSeries, UserMessage
+from eva.language.utils.text import messages as message_utils
+
+
+class TextPredictionDataset(
+    LanguageDataset[PredictionSample[TargetType]], abc.ABC, Generic[TargetType]
+):
+    """Dataset class for loading pre-generated text predictions."""
+
+    def __init__(
+        self,
+        path: str,
+        prediction_column: str = "prediction",
+        target_column: str = "target",
+        text_column: str | None = None,
+        metadata_columns: list[str] | None = None,
+        split: Literal["train", "val", "test"] | None = None,
+        transforms: TransformsSchema | None = None,
+    ):
+        """Initialize the dataset.
+
+        Args:
+            path: The path to the manifest file holding the predictions & targets.
+            prediction_column: The name of the prediction column.
+            target_column: The name of the label column.
+            text_column: The name of the column with the text inputs that were used
+                to generate the predictions. If the text column contains chat messages
+                in JSON format ([{"role": ..., "content": ...}]), it will be deserialized
+                into a list of Message objects. Otherwise, the content is interpreted as
+                a single user message.
+            metadata_columns: List of column names to include in metadata.
+            split: The dataset split to use (train, val, test). If not specified,
+                the entire dataset will be used.
+            transforms: The transforms to apply to the text and target when
+                loading the samples.
+        """
+        super().__init__()
+
+        self.path = path
+        self.prediction_column = prediction_column
+        self.target_column = target_column
+        self.text_column = text_column
+        self.metadata_columns = metadata_columns
+        self.split = split
+        self.transforms = transforms
+
+        self._data: pd.DataFrame
+
+    @override
+    def __len__(self) -> int:
+        return len(self._data)
+
+    @override
+    def __getitem__(self, index: int) -> PredictionSample[TargetType]:
+        item = PredictionSample(
+            prediction=self.load_prediction(index),
+            target=self.load_target(index),
+            text=self.load_text(index),
+            metadata=self.load_metadata(index) or {},
+        )
+        return self._apply_transforms(item)
+
+    @override
+    def configure(self) -> None:
+        extension = Path(self.path).suffix
+
+        match extension:
+            case ".jsonl":
+                self._data = pd.read_json(self.path, lines=True)
+            case ".csv":
+                self._data = pd.read_csv(self.path)
+            case ".parquet":
+                self._data = pd.read_parquet(self.path)
+            case _:
+                raise ValueError(f"Unsupported file extension: {extension}")
+
+        if self.split is not None:
+            self._data = self._data[self._data["split"] == self.split].reset_index(drop=True)  # type: ignore
+
+    @override
+    def validate(self) -> None:
+        if self.prediction_column not in self._data.columns:
+            raise ValueError(f"Prediction column '{self.prediction_column}' not found.")
+        if self.target_column not in self._data.columns:
+            raise ValueError(f"Target column '{self.target_column}' not found.")
+        if self.metadata_columns:
+            missing_columns = set(self.metadata_columns) - set(self._data.columns)
+            if missing_columns:
+                raise ValueError(f"Metadata columns {missing_columns} not found.")
+
+    def load_prediction(self, index: int) -> TargetType:
+        """Returns the prediction for the given index."""
+        return self._data.iloc[index][self.prediction_column]
+
+    def load_target(self, index: int) -> TargetType:
+        """Returns the target for the given index."""
+        return self._data.iloc[index][self.target_column]
+
+    def load_text(self, index: int) -> MessageSeries | None:
+        """Returns the text for the given index."""
+        if self.text_column is None:
+            return None
+
+        text = self._data.iloc[index][self.text_column]
+
+        try:
+            return message_utils.deserialize(text)
+        except Exception:
+            return [UserMessage(content=text)]
+
+    def load_metadata(self, index: int) -> Dict[str, Any] | None:
+        """Returns the metadata for the given index."""
+        if self.metadata_columns is None:
+            return None
+
+        row = self._data.iloc[index]
+        return {col: row[col] for col in self.metadata_columns}
+
+    def _apply_transforms(
+        self, sample: PredictionSample[TargetType]
+    ) -> PredictionSample[TargetType]:
+        """Applies the dataset transforms to the prediction and target."""
+        if self.transforms:
+            text = self.transforms.text(sample.text) if self.transforms.text else sample.text
+            prediction = (
+                self.transforms.prediction(sample.prediction)
+                if self.transforms.prediction
+                else sample.prediction
+            )
+            target = (
+                self.transforms.target(sample.target) if self.transforms.target else sample.target
+            )
+            return PredictionSample(
+                prediction=prediction,
+                target=target,
+                text=text,
+                metadata=sample.metadata,
+            )
+        return sample
eva/language/data/datasets/schemas.py
@@ -0,0 +1,18 @@
+"""Schema definitions for dataset classes."""
+
+import dataclasses
+from typing import Callable
+
+
+@dataclasses.dataclass(frozen=True)
+class TransformsSchema:
+    """Schema for dataset transforms."""
+
+    text: Callable | None = None
+    """Text transformation."""
+
+    target: Callable | None = None
+    """Target transformation."""
+
+    prediction: Callable | None = None
+    """Prediction transformation."""