kaiko-eva 0.3.3__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of kaiko-eva might be problematic.

Files changed (131)
  1. eva/core/callbacks/config.py +15 -6
  2. eva/core/callbacks/writers/embeddings/base.py +44 -10
  3. eva/core/cli/setup.py +1 -1
  4. eva/core/data/dataloaders/__init__.py +1 -2
  5. eva/core/data/samplers/classification/balanced.py +24 -12
  6. eva/core/data/samplers/random.py +17 -10
  7. eva/core/interface/interface.py +21 -0
  8. eva/core/loggers/utils/wandb.py +4 -1
  9. eva/core/models/modules/module.py +2 -2
  10. eva/core/models/wrappers/base.py +2 -2
  11. eva/core/models/wrappers/from_function.py +3 -3
  12. eva/core/models/wrappers/from_torchhub.py +9 -7
  13. eva/core/models/wrappers/huggingface.py +4 -5
  14. eva/core/models/wrappers/onnx.py +5 -5
  15. eva/core/trainers/trainer.py +13 -1
  16. eva/core/utils/__init__.py +2 -1
  17. eva/core/utils/distributed.py +12 -0
  18. eva/core/utils/paths.py +14 -0
  19. eva/core/utils/requirements.py +52 -6
  20. eva/language/__init__.py +2 -1
  21. eva/language/callbacks/__init__.py +5 -0
  22. eva/language/callbacks/writers/__init__.py +5 -0
  23. eva/language/callbacks/writers/prediction.py +201 -0
  24. eva/language/data/dataloaders/__init__.py +5 -0
  25. eva/language/data/dataloaders/collate_fn/__init__.py +5 -0
  26. eva/language/data/dataloaders/collate_fn/text.py +57 -0
  27. eva/language/data/datasets/__init__.py +3 -1
  28. eva/language/data/datasets/{language.py → base.py} +1 -1
  29. eva/language/data/datasets/classification/base.py +3 -43
  30. eva/language/data/datasets/classification/pubmedqa.py +36 -4
  31. eva/language/data/datasets/prediction.py +151 -0
  32. eva/language/data/datasets/schemas.py +18 -0
  33. eva/language/data/datasets/text.py +92 -0
  34. eva/language/data/datasets/typings.py +39 -0
  35. eva/language/data/messages.py +60 -0
  36. eva/language/models/__init__.py +15 -11
  37. eva/language/models/modules/__init__.py +2 -2
  38. eva/language/models/modules/language.py +94 -0
  39. eva/language/models/networks/__init__.py +12 -0
  40. eva/language/models/networks/alibaba.py +26 -0
  41. eva/language/models/networks/api/__init__.py +11 -0
  42. eva/language/models/networks/api/anthropic.py +34 -0
  43. eva/language/models/networks/registry.py +5 -0
  44. eva/language/models/typings.py +56 -0
  45. eva/language/models/wrappers/__init__.py +13 -5
  46. eva/language/models/wrappers/base.py +47 -0
  47. eva/language/models/wrappers/from_registry.py +54 -0
  48. eva/language/models/wrappers/huggingface.py +57 -11
  49. eva/language/models/wrappers/litellm.py +91 -46
  50. eva/language/models/wrappers/vllm.py +37 -13
  51. eva/language/utils/__init__.py +2 -1
  52. eva/language/utils/str_to_int_tensor.py +20 -12
  53. eva/language/utils/text/__init__.py +5 -0
  54. eva/language/utils/text/messages.py +113 -0
  55. eva/multimodal/__init__.py +6 -0
  56. eva/multimodal/callbacks/__init__.py +5 -0
  57. eva/multimodal/callbacks/writers/__init__.py +5 -0
  58. eva/multimodal/callbacks/writers/prediction.py +39 -0
  59. eva/multimodal/data/__init__.py +5 -0
  60. eva/multimodal/data/dataloaders/__init__.py +5 -0
  61. eva/multimodal/data/dataloaders/collate_fn/__init__.py +5 -0
  62. eva/multimodal/data/dataloaders/collate_fn/text_image.py +28 -0
  63. eva/multimodal/data/datasets/__init__.py +6 -0
  64. eva/multimodal/data/datasets/base.py +13 -0
  65. eva/multimodal/data/datasets/multiple_choice/__init__.py +5 -0
  66. eva/multimodal/data/datasets/multiple_choice/patch_camelyon.py +80 -0
  67. eva/multimodal/data/datasets/schemas.py +14 -0
  68. eva/multimodal/data/datasets/text_image.py +77 -0
  69. eva/multimodal/data/datasets/typings.py +27 -0
  70. eva/multimodal/models/__init__.py +8 -0
  71. eva/multimodal/models/modules/__init__.py +5 -0
  72. eva/multimodal/models/modules/vision_language.py +56 -0
  73. eva/multimodal/models/networks/__init__.py +14 -0
  74. eva/multimodal/models/networks/alibaba.py +40 -0
  75. eva/multimodal/models/networks/api/__init__.py +11 -0
  76. eva/multimodal/models/networks/api/anthropic.py +34 -0
  77. eva/multimodal/models/networks/others.py +48 -0
  78. eva/multimodal/models/networks/registry.py +5 -0
  79. eva/multimodal/models/typings.py +27 -0
  80. eva/multimodal/models/wrappers/__init__.py +13 -0
  81. eva/multimodal/models/wrappers/base.py +48 -0
  82. eva/multimodal/models/wrappers/from_registry.py +54 -0
  83. eva/multimodal/models/wrappers/huggingface.py +193 -0
  84. eva/multimodal/models/wrappers/litellm.py +58 -0
  85. eva/multimodal/utils/__init__.py +1 -0
  86. eva/multimodal/utils/batch/__init__.py +5 -0
  87. eva/multimodal/utils/batch/unpack.py +11 -0
  88. eva/multimodal/utils/image/__init__.py +5 -0
  89. eva/multimodal/utils/image/encode.py +28 -0
  90. eva/multimodal/utils/text/__init__.py +1 -0
  91. eva/multimodal/utils/text/messages.py +79 -0
  92. eva/vision/data/datasets/classification/breakhis.py +5 -8
  93. eva/vision/data/datasets/classification/panda.py +12 -5
  94. eva/vision/data/datasets/classification/patch_camelyon.py +8 -6
  95. eva/vision/data/datasets/segmentation/btcv.py +1 -1
  96. eva/vision/data/datasets/segmentation/consep.py +1 -1
  97. eva/vision/data/datasets/segmentation/lits17.py +1 -1
  98. eva/vision/data/datasets/segmentation/monusac.py +15 -6
  99. eva/vision/data/datasets/segmentation/msd_task7_pancreas.py +1 -1
  100. eva/vision/data/transforms/__init__.py +2 -1
  101. eva/vision/data/transforms/base/__init__.py +2 -1
  102. eva/vision/data/transforms/base/monai.py +2 -2
  103. eva/vision/data/transforms/base/torchvision.py +33 -0
  104. eva/vision/data/transforms/common/squeeze.py +6 -3
  105. eva/vision/data/transforms/croppad/crop_foreground.py +8 -7
  106. eva/vision/data/transforms/croppad/rand_crop_by_label_classes.py +6 -5
  107. eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py +6 -5
  108. eva/vision/data/transforms/croppad/rand_spatial_crop.py +8 -7
  109. eva/vision/data/transforms/croppad/spatial_pad.py +6 -6
  110. eva/vision/data/transforms/intensity/rand_scale_intensity.py +3 -3
  111. eva/vision/data/transforms/intensity/rand_shift_intensity.py +3 -3
  112. eva/vision/data/transforms/intensity/scale_intensity_ranged.py +5 -5
  113. eva/vision/data/transforms/spatial/__init__.py +2 -1
  114. eva/vision/data/transforms/spatial/flip.py +8 -7
  115. eva/vision/data/transforms/spatial/functional/__init__.py +5 -0
  116. eva/vision/data/transforms/spatial/functional/resize.py +26 -0
  117. eva/vision/data/transforms/spatial/resize.py +63 -0
  118. eva/vision/data/transforms/spatial/rotate.py +8 -7
  119. eva/vision/data/transforms/spatial/spacing.py +7 -6
  120. eva/vision/data/transforms/utility/ensure_channel_first.py +6 -6
  121. eva/vision/models/networks/backbones/universal/vit.py +24 -0
  122. eva/vision/models/wrappers/from_registry.py +6 -5
  123. eva/vision/models/wrappers/from_timm.py +6 -4
  124. {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/METADATA +17 -3
  125. {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/RECORD +128 -66
  126. eva/core/data/dataloaders/collate_fn/__init__.py +0 -5
  127. eva/core/data/dataloaders/collate_fn/collate.py +0 -24
  128. eva/language/models/modules/text.py +0 -85
  129. {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/WHEEL +0 -0
  130. {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/entry_points.txt +0 -0
  131. {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/licenses/LICENSE +0 -0
eva/language/callbacks/__init__.py
@@ -0,0 +1,5 @@
+ """Language callbacks API."""
+
+ from eva.language.callbacks.writers import TextPredictionWriter
+
+ __all__ = ["TextPredictionWriter"]
eva/language/callbacks/writers/__init__.py
@@ -0,0 +1,5 @@
+ """Language writers callbacks API."""
+
+ from eva.language.callbacks.writers.prediction import TextPredictionWriter
+
+ __all__ = ["TextPredictionWriter"]
eva/language/callbacks/writers/prediction.py
@@ -0,0 +1,201 @@
+ """Text prediction writer callbacks."""
+
+ import abc
+ import os
+ from typing import Any, Dict, List, Literal, Sequence, Tuple, TypedDict
+
+ import lightning.pytorch as pl
+ import pandas as pd
+ import torch
+ import torch.distributed as dist
+ from lightning.pytorch import callbacks
+ from torch import nn
+ from typing_extensions import NotRequired, override
+
+ from eva.core.models.modules import utils as module_utils
+ from eva.core.utils import distributed as dist_utils
+ from eva.language.models.typings import TextBatch
+ from eva.language.utils.text import messages as message_utils
+
+
+ class ManifestEntry(TypedDict):
+     """A single entry in the manifest file."""
+
+     prediction: str
+     """The predicted text."""
+
+     target: str
+     """The ground truth text."""
+
+     text: NotRequired[str]
+     """The input text data."""
+
+     split: NotRequired[str]
+     """The dataset split (e.g. train, val, test)."""
+
+
+ class TextPredictionWriter(callbacks.BasePredictionWriter, abc.ABC):
+     """Callback for writing generated text predictions to disk."""
+
+     def __init__(
+         self,
+         output_dir: str,
+         model: nn.Module,
+         dataloader_idx_map: Dict[int, str] | None = None,
+         metadata_keys: List[str] | None = None,
+         include_input: bool = True,
+         overwrite: bool = False,
+         apply_postprocess: bool = False,
+         save_format: Literal["jsonl", "parquet", "csv"] = "jsonl",
+     ) -> None:
+         """Initializes a new callback.
+
+         Args:
+             output_dir: The directory where the predictions will be saved.
+             model: The model instance used to generate the predictions.
+             dataloader_idx_map: A dictionary mapping dataloader indices to their respective
+                 names (e.g. train, val, test).
+             metadata_keys: An optional list of keys to extract from the batch metadata and store
+                 as additional columns in the manifest file.
+             include_input: Whether to include the original input text messages in the output.
+             overwrite: Whether to overwrite predictions that are already present in the specified
+                 output directory. If set to `False`, an error will be raised if predictions are
+                 already present (recommended).
+             apply_postprocess: Whether to apply the postprocesses specified in the model module.
+             save_format: The file format to use for saving the manifest file with the predictions.
+         """
+         super().__init__()
+         self.output_dir = output_dir
+         self.model = model
+         self.dataloader_idx_map = dataloader_idx_map or {}
+         self.metadata_keys = metadata_keys
+         self.include_input = include_input
+         self.overwrite = overwrite
+         self.apply_postprocess = apply_postprocess
+         self.save_format = save_format
+
+         self._manifest_path = os.path.join(self.output_dir, f"manifest.{self.save_format}")
+         self._data: List[ManifestEntry] = []
+         self._is_rank_zero: bool = False
+
+     @override
+     def on_predict_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
+         self._is_rank_zero = trainer.is_global_zero
+
+         if self._is_rank_zero:
+             self._check_if_exists()
+
+         self.model = self.model.to(pl_module.device)
+         self.model.eval()
+
+     @override
+     def write_on_batch_end(
+         self,
+         trainer: pl.Trainer,
+         pl_module: pl.LightningModule,
+         prediction: Any,
+         batch_indices: Sequence[int],
+         batch: TextBatch,
+         batch_idx: int,
+         dataloader_idx: int,
+     ) -> None:
+         text_batch, target_batch, metadata_batch = self._unpack_batch(batch)
+         has_target = target_batch is not None
+         split = self.dataloader_idx_map.get(dataloader_idx, "")
+
+         prediction_batch = self._get_predictions(batch)
+
+         target_batch, prediction_batch = self._apply_postprocess(
+             pl_module, target_batch, prediction_batch
+         )
+
+         for i in range(len(batch_indices)):
+             entry: ManifestEntry = {
+                 "prediction": str(prediction_batch[i]),
+                 "target": str(target_batch[i]) if has_target else "",
+                 "split": split if split else "",
+             }
+             if self.include_input:
+                 entry["text"] = message_utils.serialize(text_batch[i])
+
+             if self.metadata_keys is not None and metadata_batch is not None:
+                 for key in self.metadata_keys:
+                     entry[key] = metadata_batch[key][i]
+
+             self._data.append(entry)
+
+     @override
+     def on_predict_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
+         """Saves the gathered predictions to a manifest file."""
+         if dist_utils.is_distributed():
+             dist.barrier()
+             data = self._gather_data_from_ranks()
+         else:
+             data = self._data
+
+         if self._is_rank_zero:
+             df = pd.DataFrame(data)
+
+             match self.save_format:
+                 case "jsonl":
+                     df.to_json(self._manifest_path, orient="records", lines=True)
+                 case "parquet":
+                     df.to_parquet(self._manifest_path, index=False)
+                 case "csv":
+                     df.to_csv(self._manifest_path, index=False)
+                 case _:
+                     raise ValueError(f"Unsupported save format: {self.save_format}")
+
+     def _gather_data_from_ranks(self) -> List[ManifestEntry]:
+         world_size = dist.get_world_size()
+         gathered: List[List[ManifestEntry] | None] = [None] * world_size
+         dist.all_gather_object(gathered, self._data)
+         return [row for shard in gathered for row in (shard or [])]
+
+     def _get_predictions(self, batch: TextBatch) -> List[str]:
+         with torch.no_grad():
+             output = self.model(batch)
+
+         if (
+             not isinstance(output, dict)
+             or "generated_text" not in output
+             or not all(isinstance(p, str) for p in output["generated_text"])
+         ):
+             raise ValueError(
+                 f"A dictionary with a 'generated_text' key is expected, got {type(output)}"
+             )
+
+         return output["generated_text"]
+
+     def _check_if_exists(self) -> None:
+         """Checks if the output directory already exists and if it should be overwritten."""
+         os.makedirs(self.output_dir, exist_ok=True)
+         if os.path.exists(self._manifest_path) and not self.overwrite:
+             raise FileExistsError(
+                 f"The specified output directory already exists: {self.output_dir}. This "
+                 "either means that the predictions have been computed before or that a "
+                 "wrong output directory is being used."
+             )
+
+     def _apply_postprocess(
+         self, pl_module: pl.LightningModule, targets: Any, predictions: Any
+     ) -> Tuple[List[Any], List[Any]]:
+         def _to_list(data: Any) -> List[Any]:
+             if isinstance(data, torch.Tensor):
+                 return data.cpu().tolist()
+             return data
+
+         if self.apply_postprocess and hasattr(pl_module, "postprocess"):
+             if (
+                 isinstance(pl_module.postprocess, module_utils.BatchPostProcess)
+                 and pl_module.postprocess.predictions_transforms is not None
+             ):
+                 outputs = {"targets": targets, "predictions": predictions}
+                 pl_module.postprocess(outputs)
+                 targets, predictions = outputs["targets"], outputs["predictions"]
+
+         return _to_list(targets), _to_list(predictions)
+
+     def _unpack_batch(self, batch: TextBatch) -> Tuple[list, list | None, dict | None]:
+         text_batch, target_batch, metadata_batch = TextBatch(*batch)
+         return text_batch, target_batch, metadata_batch
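
The writer buffers one ManifestEntry per sample and only rank zero materializes the manifest. As a hedged sketch (not part of the diff), it might be attached to a Lightning prediction run as follows; wrapped_model, lightning_module and test_loader are hypothetical stand-ins, and the wrapped model is assumed to return {"generated_text": [...]} for a TextBatch, as _get_predictions expects:

import lightning.pytorch as pl

from eva.language.callbacks import TextPredictionWriter

# wrapped_model is a hypothetical nn.Module mapping a TextBatch
# to {"generated_text": ["...", ...]}.
writer = TextPredictionWriter(
    output_dir="./predictions",      # manifest.jsonl is created in this directory
    model=wrapped_model,
    dataloader_idx_map={0: "test"},  # fills the "split" column per dataloader
    save_format="jsonl",
)
trainer = pl.Trainer(callbacks=[writer])
trainer.predict(lightning_module, dataloaders=test_loader, return_predictions=False)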
eva/language/data/dataloaders/__init__.py
@@ -0,0 +1,5 @@
+ """Language Dataloaders API."""
+
+ from eva.language.data.dataloaders.collate_fn import prediction_collate, text_collate
+
+ __all__ = ["text_collate", "prediction_collate"]
eva/language/data/dataloaders/collate_fn/__init__.py
@@ -0,0 +1,5 @@
+ """Collate functions API."""
+
+ from eva.language.data.dataloaders.collate_fn.text import prediction_collate, text_collate
+
+ __all__ = ["text_collate", "prediction_collate"]
eva/language/data/dataloaders/collate_fn/text.py
@@ -0,0 +1,57 @@
+ """Collate functions for text data."""
+
+ from typing import List
+
+ from torch.utils.data._utils.collate import default_collate
+
+ from eva.language.data.datasets.typings import PredictionSample, TextSample
+ from eva.language.models.typings import PredictionBatch, TextBatch
+
+
+ def text_collate(batch: List[TextSample]) -> TextBatch:
+     """Collate function for text data that keeps texts as separate strings.
+
+     Args:
+         batch: List of tuples containing (text, target, metadata) from the dataset.
+
+     Returns:
+         A batch of text samples with targets and metadata.
+     """
+     texts, targets, metadata = zip(*batch, strict=False)
+     first_sample = batch[0]
+     metadata = None
+     if first_sample.metadata is not None:
+         metadata = {
+             k: [sample.metadata[k] for sample in batch if sample.metadata]
+             for k in first_sample.metadata.keys()
+         }
+     return TextBatch(
+         text=list(texts),
+         target=default_collate(targets) if targets[0] is not None else None,
+         metadata=metadata,
+     )
+
+
+ def prediction_collate(batch: List[PredictionSample]) -> PredictionBatch:
+     """Collate function for text prediction data.
+
+     Args:
+         batch: List of tuples containing (prediction, target, text, metadata) from the dataset.
+
+     Returns:
+         A batch of prediction samples.
+     """
+     predictions, targets, texts, metadata = zip(*batch, strict=False)
+     first_sample = batch[0]
+     metadata = None
+     if first_sample.metadata is not None:
+         metadata = {
+             k: [sample.metadata[k] for sample in batch if sample.metadata]
+             for k in first_sample.metadata.keys()
+         }
+     return PredictionBatch(
+         prediction=default_collate(predictions) if predictions[0] is not None else None,
+         target=default_collate(targets) if targets[0] is not None else None,
+         text=list(texts) if first_sample.text is not None else None,
+         metadata=metadata,
+     )
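
Both functions return named-tuple batches rather than stacked tensors, so downstream code can address fields by name. A hedged usage sketch, assuming a dataset variable that yields TextSample tuples of (text, target, metadata):

from torch.utils.data import DataLoader

from eva.language.data.dataloaders import text_collate

# text_collate keeps the texts as a plain Python list instead of stacking them.
loader = DataLoader(dataset, batch_size=4, collate_fn=text_collate)
batch = next(iter(loader))
batch.text      # list of per-sample message series
batch.target    # default-collated targets, or None when targets are absent
batch.metadata  # dict mapping each key to a list of per-sample values, or None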
eva/language/data/datasets/__init__.py
@@ -1,9 +1,11 @@
  """Language Datasets API."""
 
+ from eva.language.data.datasets.base import LanguageDataset
  from eva.language.data.datasets.classification import PubMedQA
- from eva.language.data.datasets.language import LanguageDataset
+ from eva.language.data.datasets.prediction import TextPredictionDataset
 
  __all__ = [
      "PubMedQA",
      "LanguageDataset",
+     "TextPredictionDataset",
  ]
eva/language/data/datasets/{language.py → base.py}
@@ -10,4 +10,4 @@ DataSample = TypeVar("DataSample")
 
 
  class LanguageDataset(base.MapDataset, abc.ABC, Generic[DataSample]):
-     """Base dataset class for text tasks."""
+     """Base dataset class for language tasks."""
eva/language/data/datasets/classification/base.py
@@ -1,15 +1,13 @@
  """Base for text classification datasets."""
 
- import abc
- from typing import Any, Dict, List, Tuple
+ from typing import Dict, List
 
  import torch
- from typing_extensions import override
 
- from eva.language.data.datasets.language import LanguageDataset
+ from eva.language.data.datasets.text import TextDataset
 
 
- class TextClassification(LanguageDataset[Tuple[str, torch.Tensor, Dict[str, Any]]], abc.ABC):
+ class TextClassification(TextDataset[torch.Tensor]):
      """Text classification abstract dataset."""
 
      def __init__(self) -> None:
@@ -23,41 +21,3 @@ class TextClassification(LanguageDataset[Tuple[str, torch.Tensor, Dict[str, Any]
      @property
      def class_to_idx(self) -> Dict[str, int] | None:
          """Returns class name to index mapping."""
-
-     def load_metadata(self, index: int) -> Dict[str, Any] | None:
-         """Returns the dataset metadata.
-
-         Args:
-             index: The index of the data sample.
-
-         Returns:
-             The sample metadata.
-         """
-
-     @abc.abstractmethod
-     def load_text(self, index: int) -> str:
-         """Returns the text content.
-
-         Args:
-             index: The index of the data sample.
-
-         Returns:
-             The text content.
-         """
-         raise NotImplementedError
-
-     @abc.abstractmethod
-     def load_target(self, index: int) -> torch.Tensor:
-         """Returns the target label.
-
-         Args:
-             index: The index of the data sample.
-
-         Returns:
-             The target label.
-         """
-         raise NotImplementedError
-
-     @override
-     def __getitem__(self, index: int) -> Tuple[str, torch.Tensor, Dict[str, Any]]:
-         return (self.load_text(index), self.load_target(index), self.load_metadata(index) or {})
eva/language/data/datasets/classification/pubmedqa.py
@@ -10,11 +10,20 @@ from loguru import logger
  from typing_extensions import override
 
  from eva.language.data.datasets.classification import base
+ from eva.language.data.messages import MessageSeries, UserMessage
 
 
  class PubMedQA(base.TextClassification):
      """Dataset class for PubMedQA question answering task."""
 
+     _expected_dataset_lengths: Dict[str | None, int] = {
+         "train": 450,
+         "val": 50,
+         "test": 500,
+         None: 500,
+     }
+     """Expected dataset lengths for the splits and complete dataset."""
+
      _license: str = "MIT License (https://github.com/pubmedqa/pubmedqa/blob/master/LICENSE)"
      """Dataset license."""
 
@@ -52,7 +61,14 @@
          """
          dataset_name = "bigbio/pubmed_qa"
          config_name = "pubmed_qa_labeled_fold0_source"
-         split = (self._split or "train+test+validation") if self._split != "val" else "validation"
+
+         match self._split:
+             case "val":
+                 split = "validation"
+             case None:
+                 split = "train+test+validation"
+             case _:
+                 split = self._split
 
          if self._download:
              logger.info("Downloading dataset from HuggingFace Hub")
@@ -88,7 +104,7 @@
          dataset_path = None
 
          if self._root:
-             dataset_path = self._root
+             dataset_path = os.path.join(self._root, self._split) if self._split else self._root
              os.makedirs(self._root, exist_ok=True)
 
          try:
@@ -103,6 +119,15 @@
          except Exception as e:
              raise RuntimeError(f"Failed to prepare dataset: {e}") from e
 
+     @override
+     def validate(self) -> None:
+         if len(self) != (self._max_samples or self._expected_dataset_lengths[self._split]):
+             raise ValueError(
+                 f"Dataset length mismatch for split '{self._split}': "
+                 f"expected {self._expected_dataset_lengths[self._split]}, "
+                 f"but got {len(self)}"
+             )
+
      @property
      @override
      def classes(self) -> List[str]:
@@ -114,11 +139,18 @@
          return {"no": 0, "yes": 1, "maybe": 2}
 
      @override
-     def load_text(self, index: int) -> str:
+     def load_text(self, index: int) -> MessageSeries:
          if index < 0 or index >= len(self.dataset):
              raise IndexError(f"Index {index} out of range for dataset of size {len(self.dataset)}")
          sample = dict(self.dataset[index])
-         return f"Question: {sample['QUESTION']}\nContext: " + " ".join(sample["CONTEXTS"])
+         return [
+             UserMessage(
+                 content=f"Question: {sample['QUESTION']}\nContext: "
+                 + " ".join(sample["CONTEXTS"])
+                 + "\nInstruction: Carefully read the question and the provided context. "
+                 + "Answer with one word: 'yes', 'no', or 'maybe'. Answer: "
+             )
+         ]
 
      @override
      def load_target(self, index: int) -> torch.Tensor:
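
For illustration only (not from the diff): after this change, load_text yields a chat-style message list rather than a bare prompt string, so model wrappers can consume it directly. Assuming a configured PubMedQA instance ds:

messages = ds.load_text(0)  # [UserMessage(content="Question: ...\nContext: ... Answer: ")]
label = ds.load_target(0)   # tensor label following class_to_idx: no=0, yes=1, maybe=2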
eva/language/data/datasets/prediction.py
@@ -0,0 +1,151 @@
+ """Dataset class for loading pre-generated text predictions."""
+
+ import abc
+ from pathlib import Path
+ from typing import Any, Dict, Generic, Literal
+
+ import pandas as pd
+ from typing_extensions import override
+
+ from eva.language.data.datasets.base import LanguageDataset
+ from eva.language.data.datasets.schemas import TransformsSchema
+ from eva.language.data.datasets.typings import PredictionSample, TargetType
+ from eva.language.data.messages import MessageSeries, UserMessage
+ from eva.language.utils.text import messages as message_utils
+
+
+ class TextPredictionDataset(
+     LanguageDataset[PredictionSample[TargetType]], abc.ABC, Generic[TargetType]
+ ):
+     """Dataset class for loading pre-generated text predictions."""
+
+     def __init__(
+         self,
+         path: str,
+         prediction_column: str = "prediction",
+         target_column: str = "target",
+         text_column: str | None = None,
+         metadata_columns: list[str] | None = None,
+         split: Literal["train", "val", "test"] | None = None,
+         transforms: TransformsSchema | None = None,
+     ):
+         """Initialize the dataset.
+
+         Args:
+             path: The path to the manifest file holding the predictions & targets.
+             prediction_column: The name of the prediction column.
+             target_column: The name of the label column.
+             text_column: The name of the column with the text inputs that were used
+                 to generate the predictions. If the text column contains chat message
+                 json format ([{"role": ..., "content": ...}]), it will be deserialized into
+                 a list of Message objects. Otherwise, the content is interpreted as a
+                 single user message.
+             metadata_columns: List of column names to include in metadata.
+             split: The dataset split to use (train, val, test). If not specified,
+                 the entire dataset will be used.
+             transforms: The transforms to apply to the text and target when
+                 loading the samples.
+         """
+         super().__init__()
+
+         self.path = path
+         self.prediction_column = prediction_column
+         self.target_column = target_column
+         self.text_column = text_column
+         self.metadata_columns = metadata_columns
+         self.split = split
+         self.transforms = transforms
+
+         self._data: pd.DataFrame
+
+     @override
+     def __len__(self) -> int:
+         return len(self._data)
+
+     @override
+     def __getitem__(self, index: int) -> PredictionSample[TargetType]:
+         item = PredictionSample(
+             prediction=self.load_prediction(index),
+             target=self.load_target(index),
+             text=self.load_text(index),
+             metadata=self.load_metadata(index) or {},
+         )
+         return self._apply_transforms(item)
+
+     @override
+     def configure(self) -> None:
+         extension = Path(self.path).suffix
+
+         match extension:
+             case ".jsonl":
+                 self._data = pd.read_json(self.path, lines=True)
+             case ".csv":
+                 self._data = pd.read_csv(self.path)
+             case ".parquet":
+                 self._data = pd.read_parquet(self.path)
+             case _:
+                 raise ValueError(f"Unsupported file extension: {extension}")
+
+         if self.split is not None:
+             self._data = self._data[self._data["split"] == self.split].reset_index(drop=True)  # type: ignore
+
+     @override
+     def validate(self) -> None:
+         if self.prediction_column not in self._data.columns:
+             raise ValueError(f"Prediction column '{self.prediction_column}' not found.")
+         if self.target_column not in self._data.columns:
+             raise ValueError(f"Target column '{self.target_column}' not found.")
+         if self.metadata_columns:
+             missing_columns = set(self.metadata_columns) - set(self._data.columns)
+             if missing_columns:
+                 raise ValueError(f"Metadata columns {missing_columns} not found.")
+
+     def load_prediction(self, index: int) -> TargetType:
+         """Returns the prediction for the given index."""
+         return self._data.iloc[index][self.prediction_column]
+
+     def load_target(self, index: int) -> TargetType:
+         """Returns the target for the given index."""
+         return self._data.iloc[index][self.target_column]
+
+     def load_text(self, index: int) -> MessageSeries | None:
+         """Returns the text for the given index."""
+         if self.text_column is None:
+             return None
+
+         text = self._data.iloc[index][self.text_column]
+
+         try:
+             return message_utils.deserialize(text)
+         except Exception:
+             return [UserMessage(content=text)]
+
+     def load_metadata(self, index: int) -> Dict[str, Any] | None:
+         """Returns the metadata for the given index."""
+         if self.metadata_columns is None:
+             return None
+
+         row = self._data.iloc[index]
+         return {col: row[col] for col in self.metadata_columns}
+
+     def _apply_transforms(
+         self, sample: PredictionSample[TargetType]
+     ) -> PredictionSample[TargetType]:
+         """Applies the dataset transforms to the prediction and target."""
+         if self.transforms:
+             text = self.transforms.text(sample.text) if self.transforms.text else sample.text
+             prediction = (
+                 self.transforms.prediction(sample.prediction)
+                 if self.transforms.prediction
+                 else sample.prediction
+             )
+             target = (
+                 self.transforms.target(sample.target) if self.transforms.target else sample.target
+             )
+             return PredictionSample(
+                 prediction=prediction,
+                 target=target,
+                 text=text,
+                 metadata=sample.metadata,
+             )
+         return sample
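
A minimal sketch of reading back a manifest written by TextPredictionWriter, assuming the default column names. The configure/validate hooks are invoked manually here; in an eva run the data module is assumed to call them during setup:

from eva.language.data.datasets import TextPredictionDataset

ds = TextPredictionDataset(path="./predictions/manifest.jsonl", split="test")
ds.configure()  # loads the .jsonl manifest and filters rows to the "test" split
ds.validate()   # checks that the prediction/target/metadata columns exist
sample = ds[0]  # PredictionSample(prediction=..., target=..., text=..., metadata=...)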
eva/language/data/datasets/schemas.py
@@ -0,0 +1,18 @@
+ """Schema definitions for dataset classes."""
+
+ import dataclasses
+ from typing import Callable
+
+
+ @dataclasses.dataclass(frozen=True)
+ class TransformsSchema:
+     """Schema for dataset transforms."""
+
+     text: Callable | None = None
+     """Text transformation"""
+
+     target: Callable | None = None
+     """Target transformation"""
+
+     prediction: Callable | None = None
+     """Prediction transformation"""