kaiko-eva 0.3.3__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kaiko-eva might be problematic. Click here for more details.

Files changed (98) hide show
  1. eva/core/callbacks/config.py +4 -0
  2. eva/core/cli/setup.py +1 -1
  3. eva/core/data/dataloaders/__init__.py +1 -2
  4. eva/core/data/samplers/random.py +17 -10
  5. eva/core/interface/interface.py +21 -0
  6. eva/core/models/modules/module.py +2 -2
  7. eva/core/models/wrappers/base.py +2 -2
  8. eva/core/models/wrappers/from_function.py +3 -3
  9. eva/core/models/wrappers/from_torchhub.py +9 -7
  10. eva/core/models/wrappers/huggingface.py +4 -5
  11. eva/core/models/wrappers/onnx.py +5 -5
  12. eva/core/trainers/trainer.py +2 -0
  13. eva/language/__init__.py +2 -1
  14. eva/language/callbacks/__init__.py +5 -0
  15. eva/language/callbacks/writers/__init__.py +5 -0
  16. eva/language/callbacks/writers/prediction.py +176 -0
  17. eva/language/data/dataloaders/__init__.py +5 -0
  18. eva/language/data/dataloaders/collate_fn/__init__.py +5 -0
  19. eva/language/data/dataloaders/collate_fn/text.py +57 -0
  20. eva/language/data/datasets/__init__.py +3 -1
  21. eva/language/data/datasets/{language.py → base.py} +1 -1
  22. eva/language/data/datasets/classification/base.py +3 -43
  23. eva/language/data/datasets/classification/pubmedqa.py +36 -4
  24. eva/language/data/datasets/prediction.py +151 -0
  25. eva/language/data/datasets/schemas.py +18 -0
  26. eva/language/data/datasets/text.py +92 -0
  27. eva/language/data/datasets/typings.py +39 -0
  28. eva/language/data/messages.py +60 -0
  29. eva/language/models/__init__.py +15 -11
  30. eva/language/models/modules/__init__.py +2 -2
  31. eva/language/models/modules/language.py +93 -0
  32. eva/language/models/networks/__init__.py +12 -0
  33. eva/language/models/networks/alibaba.py +26 -0
  34. eva/language/models/networks/api/__init__.py +11 -0
  35. eva/language/models/networks/api/anthropic.py +34 -0
  36. eva/language/models/networks/registry.py +5 -0
  37. eva/language/models/typings.py +39 -0
  38. eva/language/models/wrappers/__init__.py +13 -5
  39. eva/language/models/wrappers/base.py +47 -0
  40. eva/language/models/wrappers/from_registry.py +54 -0
  41. eva/language/models/wrappers/huggingface.py +44 -8
  42. eva/language/models/wrappers/litellm.py +81 -46
  43. eva/language/models/wrappers/vllm.py +37 -13
  44. eva/language/utils/__init__.py +2 -1
  45. eva/language/utils/str_to_int_tensor.py +20 -12
  46. eva/language/utils/text/__init__.py +5 -0
  47. eva/language/utils/text/messages.py +113 -0
  48. eva/multimodal/__init__.py +6 -0
  49. eva/multimodal/callbacks/__init__.py +5 -0
  50. eva/multimodal/callbacks/writers/__init__.py +5 -0
  51. eva/multimodal/callbacks/writers/prediction.py +39 -0
  52. eva/multimodal/data/__init__.py +5 -0
  53. eva/multimodal/data/dataloaders/__init__.py +5 -0
  54. eva/multimodal/data/dataloaders/collate_fn/__init__.py +5 -0
  55. eva/multimodal/data/dataloaders/collate_fn/text_image.py +28 -0
  56. eva/multimodal/data/datasets/__init__.py +6 -0
  57. eva/multimodal/data/datasets/base.py +13 -0
  58. eva/multimodal/data/datasets/multiple_choice/__init__.py +5 -0
  59. eva/multimodal/data/datasets/multiple_choice/patch_camelyon.py +80 -0
  60. eva/multimodal/data/datasets/schemas.py +14 -0
  61. eva/multimodal/data/datasets/text_image.py +77 -0
  62. eva/multimodal/data/datasets/typings.py +27 -0
  63. eva/multimodal/models/__init__.py +8 -0
  64. eva/multimodal/models/modules/__init__.py +5 -0
  65. eva/multimodal/models/modules/vision_language.py +55 -0
  66. eva/multimodal/models/networks/__init__.py +14 -0
  67. eva/multimodal/models/networks/alibaba.py +39 -0
  68. eva/multimodal/models/networks/api/__init__.py +11 -0
  69. eva/multimodal/models/networks/api/anthropic.py +34 -0
  70. eva/multimodal/models/networks/others.py +47 -0
  71. eva/multimodal/models/networks/registry.py +5 -0
  72. eva/multimodal/models/typings.py +27 -0
  73. eva/multimodal/models/wrappers/__init__.py +13 -0
  74. eva/multimodal/models/wrappers/base.py +47 -0
  75. eva/multimodal/models/wrappers/from_registry.py +54 -0
  76. eva/multimodal/models/wrappers/huggingface.py +180 -0
  77. eva/multimodal/models/wrappers/litellm.py +56 -0
  78. eva/multimodal/utils/__init__.py +1 -0
  79. eva/multimodal/utils/image/__init__.py +5 -0
  80. eva/multimodal/utils/image/encode.py +28 -0
  81. eva/multimodal/utils/text/__init__.py +1 -0
  82. eva/multimodal/utils/text/messages.py +79 -0
  83. eva/vision/data/datasets/classification/patch_camelyon.py +8 -6
  84. eva/vision/data/transforms/__init__.py +2 -1
  85. eva/vision/data/transforms/spatial/__init__.py +2 -1
  86. eva/vision/data/transforms/spatial/functional/__init__.py +5 -0
  87. eva/vision/data/transforms/spatial/functional/resize.py +26 -0
  88. eva/vision/data/transforms/spatial/resize.py +62 -0
  89. eva/vision/models/wrappers/from_registry.py +6 -5
  90. eva/vision/models/wrappers/from_timm.py +6 -4
  91. {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.0.dist-info}/METADATA +10 -2
  92. {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.0.dist-info}/RECORD +95 -38
  93. eva/core/data/dataloaders/collate_fn/__init__.py +0 -5
  94. eva/core/data/dataloaders/collate_fn/collate.py +0 -24
  95. eva/language/models/modules/text.py +0 -85
  96. {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.0.dist-info}/WHEEL +0 -0
  97. {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.0.dist-info}/entry_points.txt +0 -0
  98. {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.0.dist-info}/licenses/LICENSE +0 -0
@@ -51,6 +51,10 @@ class ConfigurationLogger(pl.Callback):
51
51
 
52
52
  save_as = os.path.join(log_dir, self._save_as)
53
53
  fs = cloud_io.get_filesystem(log_dir)
54
+
55
+ if not fs.exists(log_dir):
56
+ fs.makedirs(log_dir)
57
+
54
58
  with fs.open(save_as, "w") as output_file:
55
59
  yaml.dump(configuration, output_file, sort_keys=False)
56
60
 
eva/core/cli/setup.py CHANGED
@@ -59,7 +59,7 @@ def _initialize_logger() -> None:
59
59
  " :: <bold><level>{level}</level></bold>"
60
60
  " :: {message}",
61
61
  colorize=True,
62
- level="INFO",
62
+ level=os.getenv("LOGURU_LEVEL", "INFO"),
63
63
  )
64
64
 
65
65
 
@@ -1,6 +1,5 @@
1
1
  """Dataloaders API."""
2
2
 
3
- from eva.core.data.dataloaders.collate_fn import text_collate_fn
4
3
  from eva.core.data.dataloaders.dataloader import DataLoader
5
4
 
6
- __all__ = ["text_collate_fn", "DataLoader"]
5
+ __all__ = ["DataLoader"]
@@ -2,6 +2,7 @@
2
2
 
3
3
  from typing import Optional
4
4
 
5
+ import torch
5
6
  from torch.utils import data
6
7
  from typing_extensions import override
7
8
 
@@ -10,30 +11,36 @@ from eva.core.data.samplers.sampler import SamplerWithDataSource
10
11
 
11
12
 
12
13
  class RandomSampler(data.RandomSampler, SamplerWithDataSource[int]):
13
- """Samples elements randomly."""
14
+ """Samples elements randomly from a MapDataset."""
14
15
 
15
16
  data_source: datasets.MapDataset # type: ignore
16
17
 
17
18
  def __init__(
18
- self, replacement: bool = False, num_samples: Optional[int] = None, generator=None
19
+ self,
20
+ replacement: bool = False,
21
+ num_samples: Optional[int] = None,
22
+ seed: Optional[int] = None,
19
23
  ) -> None:
20
- """Initializes the random sampler.
24
+ """Initialize the random sampler.
21
25
 
22
26
  Args:
23
- data_source: dataset to sample from
24
- replacement: samples are drawn on-demand with replacement if ``True``, default=``False``
25
- num_samples: number of samples to draw, default=`len(dataset)`.
26
- generator: Generator used in sampling.
27
+ replacement: Samples are drawn on-demand with replacement if ``True``, default=``False``
28
+ num_samples: Number of samples to draw, default=``len(dataset)``.
29
+ seed: Optional seed for the random number generator.
27
30
  """
28
31
  self.replacement = replacement
29
32
  self._num_samples = num_samples
30
- self.generator = generator
33
+ self._generator = None
34
+
35
+ if seed is not None:
36
+ self._generator = torch.Generator()
37
+ self._generator.manual_seed(seed)
31
38
 
32
39
  @override
33
40
  def set_dataset(self, data_source: datasets.MapDataset) -> None:
34
41
  super().__init__(
35
42
  data_source,
36
43
  replacement=self.replacement,
37
- num_samples=self.num_samples,
38
- generator=self.generator,
44
+ num_samples=self._num_samples,
45
+ generator=self._generator,
39
46
  )
@@ -132,3 +132,24 @@ class Interface:
132
132
  n_runs=trainer.n_runs,
133
133
  verbose=trainer.n_runs > 1,
134
134
  )
135
+
136
def validate_test(
    self,
    trainer: eva_trainer.Trainer,
    model: modules.ModelModule,
    data: datamodules.DataModule,
) -> None:
    """Runs the validation stage followed by the test stage.

    Args:
        trainer: The base trainer; its ``n_runs`` sets the number of runs and
            turns on verbose output when it is greater than one.
        model: The model module to evaluate.
        data: The data module; ``data.datasets`` must provide both a ``val``
            and a ``test`` dataset.

    Raises:
        ValueError: If ``data.datasets`` has no ``val`` or no ``test`` dataset.
    """
    # Checked in this order so a missing validation split is reported first.
    missing_split_errors = {
        "val": "The provided data module does not contain a validation dataset.",
        "test": "The provided data module does not contain a test dataset.",
    }
    for split_name, error_message in missing_split_errors.items():
        if getattr(data.datasets, split_name, None) is None:
            raise ValueError(error_message)

    session_options = {
        "base_trainer": trainer,
        "base_model": model,
        "datamodule": data,
        "stages": ["validate", "test"],
        "n_runs": trainer.n_runs,
        "verbose": trainer.n_runs > 1,
    }
    eva_trainer.run_evaluation_session(**session_options)
@@ -33,8 +33,8 @@ class ModelModule(pl.LightningModule):
33
33
  super().__init__()
34
34
 
35
35
  self._metrics = metrics or self.default_metrics
36
- self._postprocess = postprocess or self.default_postprocess
37
36
 
37
+ self.postprocess = postprocess or self.default_postprocess
38
38
  self.metrics = metrics_lib.MetricModule.from_schema(self._metrics)
39
39
 
40
40
  @property
@@ -133,7 +133,7 @@ class ModelModule(pl.LightningModule):
133
133
  Returns:
134
134
  The updated outputs.
135
135
  """
136
- self._postprocess(outputs)
136
+ self.postprocess(outputs)
137
137
  return memory.recursive_detach(outputs, to_cpu=self.metrics_device.type == "cpu")
138
138
 
139
139
  def _forward_and_log_metrics(
@@ -25,7 +25,7 @@ class BaseModel(nn.Module, Generic[InputType, OutputType]):
25
25
 
26
26
  self._output_transforms = transforms
27
27
 
28
- self._model: Callable[..., OutputType] | nn.Module
28
+ self.model: Callable[..., OutputType] | nn.Module
29
29
 
30
30
  @override
31
31
  def forward(self, tensor: InputType) -> OutputType:
@@ -43,7 +43,7 @@ class BaseModel(nn.Module, Generic[InputType, OutputType]):
43
43
  Args:
44
44
  tensor: The input tensor to the model.
45
45
  """
46
- return self._model(tensor)
46
+ return self.model(tensor)
47
47
 
48
48
  def _apply_transforms(self, tensor: OutputType) -> OutputType:
49
49
  if self._output_transforms is not None:
@@ -41,12 +41,12 @@ class ModelFromFunction(base.BaseModel[torch.Tensor, torch.Tensor]):
41
41
  self._arguments = arguments
42
42
  self._checkpoint_path = checkpoint_path
43
43
 
44
- self.load_model()
44
+ self.model = self.load_model()
45
45
 
46
46
  @override
47
- def load_model(self) -> None:
47
+ def load_model(self) -> nn.Module:
48
48
  class_path = jsonargparse.class_from_function(self._path, func_return=nn.Module)
49
49
  model = class_path(**self._arguments or {})
50
50
  if self._checkpoint_path is not None:
51
51
  _utils.load_model_weights(model, self._checkpoint_path)
52
- self._model = model
52
+ return model
@@ -52,12 +52,12 @@ class TorchHubModel(base.BaseModel[torch.Tensor, torch.Tensor]):
52
52
  self._trust_repo = trust_repo
53
53
  self._model_kwargs = model_kwargs or {}
54
54
 
55
- self.load_model()
55
+ self.model = self.load_model()
56
56
 
57
57
  @override
58
- def load_model(self) -> None:
58
+ def load_model(self) -> nn.Module:
59
59
  """Builds and loads the torch.hub model."""
60
- self._model: nn.Module = torch.hub.load(
60
+ model: nn.Module = torch.hub.load(
61
61
  repo_or_dir=self._repo_or_dir,
62
62
  model=self._model_name,
63
63
  trust_repo=self._trust_repo,
@@ -66,21 +66,23 @@ class TorchHubModel(base.BaseModel[torch.Tensor, torch.Tensor]):
66
66
  ) # type: ignore
67
67
 
68
68
  if self._checkpoint_path:
69
- _utils.load_model_weights(self._model, self._checkpoint_path)
69
+ _utils.load_model_weights(model, self._checkpoint_path)
70
70
 
71
71
  TorchHubModel.__name__ = self._model_name
72
72
 
73
+ return model
74
+
73
75
  @override
74
76
  def model_forward(self, tensor: torch.Tensor) -> torch.Tensor | List[torch.Tensor]:
75
77
  if self._out_indices is not None:
76
- if not hasattr(self._model, "get_intermediate_layers"):
78
+ if not hasattr(self.model, "get_intermediate_layers"):
77
79
  raise ValueError(
78
80
  "Only models with `get_intermediate_layers` are supported "
79
81
  "when using `out_indices`."
80
82
  )
81
83
 
82
84
  return list(
83
- self._model.get_intermediate_layers(
85
+ self.model.get_intermediate_layers( # type: ignore
84
86
  tensor,
85
87
  self._out_indices,
86
88
  reshape=True,
@@ -89,4 +91,4 @@ class TorchHubModel(base.BaseModel[torch.Tensor, torch.Tensor]):
89
91
  )
90
92
  )
91
93
 
92
- return self._model(tensor)
94
+ return self.model(tensor)
@@ -4,6 +4,7 @@ from typing import Any, Callable, Dict
4
4
 
5
5
  import torch
6
6
  import transformers
7
+ from torch import nn
7
8
  from typing_extensions import override
8
9
 
9
10
  from eva.core.models.wrappers import base
@@ -33,12 +34,10 @@ class HuggingFaceModel(base.BaseModel[torch.Tensor, torch.Tensor]):
33
34
  self._model_name_or_path = model_name_or_path
34
35
  self._model_kwargs = model_kwargs or {}
35
36
 
36
- self.load_model()
37
+ self.model = self.load_model()
37
38
 
38
39
  @override
39
- def load_model(self) -> None:
40
+ def load_model(self) -> nn.Module:
40
41
  # Use safetensors to avoid torch.load security vulnerability
41
42
  model_kwargs = {"use_safetensors": True, **self._model_kwargs}
42
- self._model = transformers.AutoModel.from_pretrained(
43
- self._model_name_or_path, **model_kwargs
44
- )
43
+ return transformers.AutoModel.from_pretrained(self._model_name_or_path, **model_kwargs)
@@ -30,21 +30,21 @@ class ONNXModel(base.BaseModel[torch.Tensor, torch.Tensor]):
30
30
  self._path = path
31
31
  self._device = device
32
32
 
33
- self.load_model()
33
+ self.model = self.load_model()
34
34
 
35
35
  @override
36
36
  def load_model(self) -> Any:
37
37
  if self._device == "cuda" and not torch.cuda.is_available():
38
38
  raise ValueError("Device is set to 'cuda', but CUDA is not available.")
39
39
  provider = "CUDAExecutionProvider" if self._device == "cuda" else "CPUExecutionProvider"
40
- self._model = ort.InferenceSession(self._path, providers=[provider]) # type: ignore
40
+ return ort.InferenceSession(self._path, providers=[provider]) # type: ignore
41
41
 
42
42
  @override
43
43
  def model_forward(self, tensor: torch.Tensor) -> torch.Tensor:
44
44
  # TODO: Use IO binding to avoid copying the tensor to CPU.
45
45
  # https://onnxruntime.ai/docs/api/python/api_summary.html#data-on-device
46
- if not isinstance(self._model, ort.InferenceSession):
46
+ if not isinstance(self.model, ort.InferenceSession):
47
47
  raise ValueError("Model is not loaded.")
48
- inputs = {self._model.get_inputs()[0].name: tensor.detach().cpu().numpy()}
49
- outputs = self._model.run(None, inputs)[0]
48
+ inputs = {self.model.get_inputs()[0].name: tensor.detach().cpu().numpy()}
49
+ outputs = self.model.run(None, inputs)[0]
50
50
  return torch.from_numpy(outputs).float().to(tensor.device)
@@ -8,6 +8,7 @@ from lightning.pytorch import loggers as pl_loggers
8
8
  from lightning.pytorch import trainer as pl_trainer
9
9
  from lightning.pytorch.utilities import argparse
10
10
  from lightning_fabric.utilities import cloud_io
11
+ from lightning_utilities.core.rank_zero import rank_zero_only
11
12
  from typing_extensions import override
12
13
 
13
14
  from eva.core import loggers as eva_loggers
@@ -66,6 +67,7 @@ class Trainer(pl_trainer.Trainer):
66
67
  def log_dir(self) -> str | None:
67
68
  return self.strategy.broadcast(self._log_dir)
68
69
 
70
+ @rank_zero_only
69
71
  def init_logger_run(self, run_id: int | None) -> None:
70
72
  """Setup the loggers & log directories when starting a new run.
71
73
 
eva/language/__init__.py CHANGED
@@ -1,6 +1,7 @@
1
1
  """eva language API."""
2
2
 
3
3
  try:
4
+ from eva.language import models
4
5
  from eva.language.data import datasets
5
6
  except ImportError as e:
6
7
  msg = (
@@ -10,4 +11,4 @@ except ImportError as e:
10
11
  )
11
12
  raise ImportError(str(e) + "\n\n" + msg) from e
12
13
 
13
- __all__ = ["datasets"]
14
+ __all__ = ["models", "datasets"]
@@ -0,0 +1,5 @@
1
+ """Language callbacks API."""
2
+
3
+ from eva.language.callbacks.writers import TextPredictionWriter
4
+
5
+ __all__ = ["TextPredictionWriter"]
@@ -0,0 +1,5 @@
1
+ """Language writers callbacks API."""
2
+
3
+ from eva.language.callbacks.writers.prediction import TextPredictionWriter
4
+
5
+ __all__ = ["TextPredictionWriter"]
@@ -0,0 +1,176 @@
1
+ """Text prediction writer callbacks."""
2
+
3
+ import abc
4
+ import os
5
+ from typing import Any, Dict, List, Literal, Sequence, Tuple, TypedDict
6
+
7
+ import lightning.pytorch as pl
8
+ import pandas as pd
9
+ import torch
10
+ from lightning.pytorch import callbacks
11
+ from torch import nn
12
+ from typing_extensions import NotRequired, override
13
+
14
+ from eva.core.models.modules import utils as module_utils
15
+ from eva.language.models.typings import TextBatch
16
+ from eva.language.utils.text import messages as message_utils
17
+
18
+
19
+ class ManifestEntry(TypedDict):
20
+ """A single entry in the manifest file."""
21
+
22
+ prediction: str
23
+ """The predicted text."""
24
+
25
+ target: str
26
+ """The ground truth text."""
27
+
28
+ text: NotRequired[str]
29
+ """The input text data."""
30
+
31
+ split: NotRequired[str]
32
+ """The dataset split (e.g. train, val, test)."""
33
+
34
+
35
+ class TextPredictionWriter(callbacks.BasePredictionWriter, abc.ABC):
36
+ """Callback for writing generated text predictions to disk."""
37
+
38
+ def __init__(
39
+ self,
40
+ output_dir: str,
41
+ model: nn.Module,
42
+ dataloader_idx_map: Dict[int, str] | None = None,
43
+ metadata_keys: List[str] | None = None,
44
+ include_input: bool = True,
45
+ overwrite: bool = False,
46
+ apply_postprocess: bool = False,
47
+ save_format: Literal["jsonl", "parquet", "csv"] = "jsonl",
48
+ ) -> None:
49
+ """Initializes a new callback.
50
+
51
+ Args:
52
+ output_dir: The directory where the embeddings will be saved.
53
+ model: The model instance used to generate the predictions.
54
+ dataloader_idx_map: A dictionary mapping dataloader indices to their respective
55
+ names (e.g. train, val, test).
56
+ metadata_keys: An optional list of keys to extract from the batch metadata and store
57
+ as additional columns in the manifest file.
58
+ include_input: Whether to include the original input text messages in the output.
59
+ overwrite: Whether to overwrite if embeddings are already present in the specified
60
+ output directory. If set to `False`, an error will be raised if embeddings are
61
+ already present (recommended).
62
+ apply_postprocess: Whether to apply the postprocesses specified in the model module.
63
+ save_format: The file format to use for saving the manifest file with the predictions.
64
+ """
65
+ super().__init__()
66
+ self.output_dir = output_dir
67
+ self.model = model
68
+ self.dataloader_idx_map = dataloader_idx_map or {}
69
+ self.metadata_keys = metadata_keys
70
+ self.include_input = include_input
71
+ self.overwrite = overwrite
72
+ self.apply_postprocess = apply_postprocess
73
+ self.save_format = save_format
74
+
75
+ self._manifest_path = os.path.join(self.output_dir, f"manifest.{self.save_format}")
76
+ self._data: List[ManifestEntry] = []
77
+
78
+ @override
79
+ def on_predict_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
80
+ self._check_if_exists()
81
+
82
+ self.model = self.model.to(pl_module.device)
83
+ self.model.eval()
84
+
85
+ @override
86
+ def write_on_batch_end(
87
+ self,
88
+ trainer: pl.Trainer,
89
+ pl_module: pl.LightningModule,
90
+ prediction: Any,
91
+ batch_indices: Sequence[int],
92
+ batch: TextBatch,
93
+ batch_idx: int,
94
+ dataloader_idx: int,
95
+ ) -> None:
96
+ text_batch, target_batch, metadata_batch = self._unpack_batch(batch)
97
+ has_target = target_batch is not None
98
+ split = self.dataloader_idx_map.get(dataloader_idx, "")
99
+
100
+ prediction_batch = self._get_predictions(batch)
101
+
102
+ target_batch, prediction_batch = self._apply_postprocess(
103
+ pl_module, target_batch, prediction_batch
104
+ )
105
+
106
+ for i in range(len(batch_indices)):
107
+ entry: ManifestEntry = {
108
+ "text": message_utils.serialize(text_batch[i]),
109
+ "prediction": str(prediction_batch[i]),
110
+ "target": str(target_batch[i]) if has_target else "",
111
+ "split": split if split else "",
112
+ }
113
+
114
+ if self.metadata_keys is not None and metadata_batch is not None:
115
+ for key in self.metadata_keys:
116
+ entry[key] = metadata_batch[key][i]
117
+
118
+ self._data.append(entry)
119
+
120
+ @override
121
+ def on_predict_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
122
+ """Saves the gathered predictions to a manifest file."""
123
+ df = pd.DataFrame(self._data)
124
+
125
+ match self.save_format:
126
+ case "jsonl":
127
+ df.to_json(self._manifest_path, orient="records", lines=True)
128
+ case "parquet":
129
+ df.to_parquet(self._manifest_path, index=False)
130
+ case "csv":
131
+ df.to_csv(self._manifest_path, index=False)
132
+ case _:
133
+ raise ValueError(f"Unsupported save format: {self.save_format}")
134
+
135
+ def _get_predictions(self, batch: TextBatch) -> List[str]:
136
+ with torch.no_grad():
137
+ predictions = self.model(batch)
138
+
139
+ if not isinstance(predictions, list) or not all(isinstance(p, str) for p in predictions):
140
+ raise ValueError("The model's output should be a list of strings.")
141
+
142
+ return predictions
143
+
144
+ def _check_if_exists(self) -> None:
145
+ """Checks if the output directory already exists and if it should be overwritten."""
146
+ os.makedirs(self.output_dir, exist_ok=True)
147
+ if os.path.exists(self._manifest_path) and not self.overwrite:
148
+ raise FileExistsError(
149
+ f"The specified output directory already exists: {self.output_dir}. This "
150
+ "either means that the predictions have been computed before or that a "
151
+ "wrong output directory is being used."
152
+ )
153
+ os.makedirs(self.output_dir, exist_ok=True)
154
+
155
+ def _apply_postprocess(
156
+ self, pl_module: pl.LightningModule, targets: Any, predictions: Any
157
+ ) -> Tuple[List[Any], List[Any]]:
158
+ def _to_list(data: Any) -> List[Any]:
159
+ if isinstance(data, torch.Tensor):
160
+ return data.cpu().tolist()
161
+ return data
162
+
163
+ if self.apply_postprocess and hasattr(pl_module, "postprocess"):
164
+ if (
165
+ isinstance(pl_module.postprocess, module_utils.BatchPostProcess)
166
+ and pl_module.postprocess.predictions_transforms is not None
167
+ ):
168
+ outputs = {"targets": targets, "predictions": predictions}
169
+ pl_module.postprocess(outputs)
170
+ targets, predictions = outputs["targets"], outputs["predictions"]
171
+
172
+ return _to_list(targets), _to_list(predictions)
173
+
174
+ def _unpack_batch(self, batch: TextBatch) -> Tuple[list, list | None, dict | None]:
175
+ text_batch, target_batch, metadata_batch = TextBatch(*batch)
176
+ return text_batch, target_batch, metadata_batch
@@ -0,0 +1,5 @@
1
+ """Language Dataloaders API."""
2
+
3
+ from eva.language.data.dataloaders.collate_fn import prediction_collate, text_collate
4
+
5
+ __all__ = ["text_collate", "prediction_collate"]
@@ -0,0 +1,5 @@
1
+ """Collate functions API."""
2
+
3
+ from eva.language.data.dataloaders.collate_fn.text import prediction_collate, text_collate
4
+
5
+ __all__ = ["text_collate", "prediction_collate"]
@@ -0,0 +1,57 @@
1
+ """Collate functions for text data."""
2
+
3
+ from typing import List
4
+
5
+ from torch.utils.data._utils.collate import default_collate
6
+
7
+ from eva.language.data.datasets.typings import PredictionSample, TextSample
8
+ from eva.language.models.typings import PredictionBatch, TextBatch
9
+
10
+
11
def text_collate(batch: List[TextSample]) -> TextBatch:
    """Collate function for text data that keeps texts as separate strings.

    The metadata keys are taken from the first sample. Every metadata list has
    one entry per sample, so that position ``i`` always refers to sample ``i``;
    samples without metadata contribute ``None``.

    Args:
        batch: Non-empty list of tuples containing (text, target, metadata) from the dataset

    Returns:
        A batch of text samples with targets and metadata. The target is `None`
        when the first sample has no target, and the metadata is `None` when the
        first sample has no metadata.
    """
    texts, targets, _ = zip(*batch, strict=False)
    first_sample = batch[0]
    metadata = None
    if first_sample.metadata is not None:
        metadata = {
            # Pad instead of skipping, to keep the lists aligned with the batch.
            key: [sample.metadata[key] if sample.metadata else None for sample in batch]
            for key in first_sample.metadata
        }
    return TextBatch(
        text=list(texts),
        target=default_collate(targets) if targets[0] is not None else None,
        metadata=metadata,
    )
33
+
34
+
35
def prediction_collate(batch: List[PredictionSample]) -> PredictionBatch:
    """Collate function for text prediction data.

    The metadata keys are taken from the first sample. Every metadata list has
    one entry per sample, so that position ``i`` always refers to sample ``i``;
    samples without metadata contribute ``None``.

    Args:
        batch: Non-empty list of tuples containing (prediction, target, text, metadata)
            from the dataset

    Returns:
        A batch of prediction samples. The prediction, target, text and metadata
        fields are each `None` when the first sample does not provide them.
    """
    predictions, targets, texts, _ = zip(*batch, strict=False)
    first_sample = batch[0]
    metadata = None
    if first_sample.metadata is not None:
        metadata = {
            # Pad instead of skipping, to keep the lists aligned with the batch.
            key: [sample.metadata[key] if sample.metadata else None for sample in batch]
            for key in first_sample.metadata
        }
    return PredictionBatch(
        prediction=default_collate(predictions) if predictions[0] is not None else None,
        target=default_collate(targets) if targets[0] is not None else None,
        text=list(texts) if first_sample.text is not None else None,
        metadata=metadata,
    )
@@ -1,9 +1,11 @@
1
1
  """Language Datasets API."""
2
2
 
3
+ from eva.language.data.datasets.base import LanguageDataset
3
4
  from eva.language.data.datasets.classification import PubMedQA
4
- from eva.language.data.datasets.language import LanguageDataset
5
+ from eva.language.data.datasets.prediction import TextPredictionDataset
5
6
 
6
7
  __all__ = [
7
8
  "PubMedQA",
8
9
  "LanguageDataset",
10
+ "TextPredictionDataset",
9
11
  ]
@@ -10,4 +10,4 @@ DataSample = TypeVar("DataSample")
10
10
 
11
11
 
12
12
  class LanguageDataset(base.MapDataset, abc.ABC, Generic[DataSample]):
13
- """Base dataset class for text tasks."""
13
+ """Base dataset class for language tasks."""
@@ -1,15 +1,13 @@
1
1
  """Base for text classification datasets."""
2
2
 
3
- import abc
4
- from typing import Any, Dict, List, Tuple
3
+ from typing import Dict, List
5
4
 
6
5
  import torch
7
- from typing_extensions import override
8
6
 
9
- from eva.language.data.datasets.language import LanguageDataset
7
+ from eva.language.data.datasets.text import TextDataset
10
8
 
11
9
 
12
- class TextClassification(LanguageDataset[Tuple[str, torch.Tensor, Dict[str, Any]]], abc.ABC):
10
+ class TextClassification(TextDataset[torch.Tensor]):
13
11
  """Text classification abstract dataset."""
14
12
 
15
13
  def __init__(self) -> None:
@@ -23,41 +21,3 @@ class TextClassification(LanguageDataset[Tuple[str, torch.Tensor, Dict[str, Any]
23
21
  @property
24
22
  def class_to_idx(self) -> Dict[str, int] | None:
25
23
  """Returns class name to index mapping."""
26
-
27
- def load_metadata(self, index: int) -> Dict[str, Any] | None:
28
- """Returns the dataset metadata.
29
-
30
- Args:
31
- index: The index of the data sample.
32
-
33
- Returns:
34
- The sample metadata.
35
- """
36
-
37
- @abc.abstractmethod
38
- def load_text(self, index: int) -> str:
39
- """Returns the text content.
40
-
41
- Args:
42
- index: The index of the data sample.
43
-
44
- Returns:
45
- The text content.
46
- """
47
- raise NotImplementedError
48
-
49
- @abc.abstractmethod
50
- def load_target(self, index: int) -> torch.Tensor:
51
- """Returns the target label.
52
-
53
- Args:
54
- index: The index of the data sample.
55
-
56
- Returns:
57
- The target label.
58
- """
59
- raise NotImplementedError
60
-
61
- @override
62
- def __getitem__(self, index: int) -> Tuple[str, torch.Tensor, Dict[str, Any]]:
63
- return (self.load_text(index), self.load_target(index), self.load_metadata(index) or {})