kaiko-eva 0.3.3__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (131)
  1. eva/core/callbacks/config.py +15 -6
  2. eva/core/callbacks/writers/embeddings/base.py +44 -10
  3. eva/core/cli/setup.py +1 -1
  4. eva/core/data/dataloaders/__init__.py +1 -2
  5. eva/core/data/samplers/classification/balanced.py +24 -12
  6. eva/core/data/samplers/random.py +17 -10
  7. eva/core/interface/interface.py +21 -0
  8. eva/core/loggers/utils/wandb.py +4 -1
  9. eva/core/models/modules/module.py +2 -2
  10. eva/core/models/wrappers/base.py +2 -2
  11. eva/core/models/wrappers/from_function.py +3 -3
  12. eva/core/models/wrappers/from_torchhub.py +9 -7
  13. eva/core/models/wrappers/huggingface.py +4 -5
  14. eva/core/models/wrappers/onnx.py +5 -5
  15. eva/core/trainers/trainer.py +13 -1
  16. eva/core/utils/__init__.py +2 -1
  17. eva/core/utils/distributed.py +12 -0
  18. eva/core/utils/paths.py +14 -0
  19. eva/core/utils/requirements.py +52 -6
  20. eva/language/__init__.py +2 -1
  21. eva/language/callbacks/__init__.py +5 -0
  22. eva/language/callbacks/writers/__init__.py +5 -0
  23. eva/language/callbacks/writers/prediction.py +201 -0
  24. eva/language/data/dataloaders/__init__.py +5 -0
  25. eva/language/data/dataloaders/collate_fn/__init__.py +5 -0
  26. eva/language/data/dataloaders/collate_fn/text.py +57 -0
  27. eva/language/data/datasets/__init__.py +3 -1
  28. eva/language/data/datasets/{language.py → base.py} +1 -1
  29. eva/language/data/datasets/classification/base.py +3 -43
  30. eva/language/data/datasets/classification/pubmedqa.py +36 -4
  31. eva/language/data/datasets/prediction.py +151 -0
  32. eva/language/data/datasets/schemas.py +18 -0
  33. eva/language/data/datasets/text.py +92 -0
  34. eva/language/data/datasets/typings.py +39 -0
  35. eva/language/data/messages.py +60 -0
  36. eva/language/models/__init__.py +15 -11
  37. eva/language/models/modules/__init__.py +2 -2
  38. eva/language/models/modules/language.py +94 -0
  39. eva/language/models/networks/__init__.py +12 -0
  40. eva/language/models/networks/alibaba.py +26 -0
  41. eva/language/models/networks/api/__init__.py +11 -0
  42. eva/language/models/networks/api/anthropic.py +34 -0
  43. eva/language/models/networks/registry.py +5 -0
  44. eva/language/models/typings.py +56 -0
  45. eva/language/models/wrappers/__init__.py +13 -5
  46. eva/language/models/wrappers/base.py +47 -0
  47. eva/language/models/wrappers/from_registry.py +54 -0
  48. eva/language/models/wrappers/huggingface.py +57 -11
  49. eva/language/models/wrappers/litellm.py +91 -46
  50. eva/language/models/wrappers/vllm.py +37 -13
  51. eva/language/utils/__init__.py +2 -1
  52. eva/language/utils/str_to_int_tensor.py +20 -12
  53. eva/language/utils/text/__init__.py +5 -0
  54. eva/language/utils/text/messages.py +113 -0
  55. eva/multimodal/__init__.py +6 -0
  56. eva/multimodal/callbacks/__init__.py +5 -0
  57. eva/multimodal/callbacks/writers/__init__.py +5 -0
  58. eva/multimodal/callbacks/writers/prediction.py +39 -0
  59. eva/multimodal/data/__init__.py +5 -0
  60. eva/multimodal/data/dataloaders/__init__.py +5 -0
  61. eva/multimodal/data/dataloaders/collate_fn/__init__.py +5 -0
  62. eva/multimodal/data/dataloaders/collate_fn/text_image.py +28 -0
  63. eva/multimodal/data/datasets/__init__.py +6 -0
  64. eva/multimodal/data/datasets/base.py +13 -0
  65. eva/multimodal/data/datasets/multiple_choice/__init__.py +5 -0
  66. eva/multimodal/data/datasets/multiple_choice/patch_camelyon.py +80 -0
  67. eva/multimodal/data/datasets/schemas.py +14 -0
  68. eva/multimodal/data/datasets/text_image.py +77 -0
  69. eva/multimodal/data/datasets/typings.py +27 -0
  70. eva/multimodal/models/__init__.py +8 -0
  71. eva/multimodal/models/modules/__init__.py +5 -0
  72. eva/multimodal/models/modules/vision_language.py +56 -0
  73. eva/multimodal/models/networks/__init__.py +14 -0
  74. eva/multimodal/models/networks/alibaba.py +40 -0
  75. eva/multimodal/models/networks/api/__init__.py +11 -0
  76. eva/multimodal/models/networks/api/anthropic.py +34 -0
  77. eva/multimodal/models/networks/others.py +48 -0
  78. eva/multimodal/models/networks/registry.py +5 -0
  79. eva/multimodal/models/typings.py +27 -0
  80. eva/multimodal/models/wrappers/__init__.py +13 -0
  81. eva/multimodal/models/wrappers/base.py +48 -0
  82. eva/multimodal/models/wrappers/from_registry.py +54 -0
  83. eva/multimodal/models/wrappers/huggingface.py +193 -0
  84. eva/multimodal/models/wrappers/litellm.py +58 -0
  85. eva/multimodal/utils/__init__.py +1 -0
  86. eva/multimodal/utils/batch/__init__.py +5 -0
  87. eva/multimodal/utils/batch/unpack.py +11 -0
  88. eva/multimodal/utils/image/__init__.py +5 -0
  89. eva/multimodal/utils/image/encode.py +28 -0
  90. eva/multimodal/utils/text/__init__.py +1 -0
  91. eva/multimodal/utils/text/messages.py +79 -0
  92. eva/vision/data/datasets/classification/breakhis.py +5 -8
  93. eva/vision/data/datasets/classification/panda.py +12 -5
  94. eva/vision/data/datasets/classification/patch_camelyon.py +8 -6
  95. eva/vision/data/datasets/segmentation/btcv.py +1 -1
  96. eva/vision/data/datasets/segmentation/consep.py +1 -1
  97. eva/vision/data/datasets/segmentation/lits17.py +1 -1
  98. eva/vision/data/datasets/segmentation/monusac.py +15 -6
  99. eva/vision/data/datasets/segmentation/msd_task7_pancreas.py +1 -1
  100. eva/vision/data/transforms/__init__.py +2 -1
  101. eva/vision/data/transforms/base/__init__.py +2 -1
  102. eva/vision/data/transforms/base/monai.py +2 -2
  103. eva/vision/data/transforms/base/torchvision.py +33 -0
  104. eva/vision/data/transforms/common/squeeze.py +6 -3
  105. eva/vision/data/transforms/croppad/crop_foreground.py +8 -7
  106. eva/vision/data/transforms/croppad/rand_crop_by_label_classes.py +6 -5
  107. eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py +6 -5
  108. eva/vision/data/transforms/croppad/rand_spatial_crop.py +8 -7
  109. eva/vision/data/transforms/croppad/spatial_pad.py +6 -6
  110. eva/vision/data/transforms/intensity/rand_scale_intensity.py +3 -3
  111. eva/vision/data/transforms/intensity/rand_shift_intensity.py +3 -3
  112. eva/vision/data/transforms/intensity/scale_intensity_ranged.py +5 -5
  113. eva/vision/data/transforms/spatial/__init__.py +2 -1
  114. eva/vision/data/transforms/spatial/flip.py +8 -7
  115. eva/vision/data/transforms/spatial/functional/__init__.py +5 -0
  116. eva/vision/data/transforms/spatial/functional/resize.py +26 -0
  117. eva/vision/data/transforms/spatial/resize.py +63 -0
  118. eva/vision/data/transforms/spatial/rotate.py +8 -7
  119. eva/vision/data/transforms/spatial/spacing.py +7 -6
  120. eva/vision/data/transforms/utility/ensure_channel_first.py +6 -6
  121. eva/vision/models/networks/backbones/universal/vit.py +24 -0
  122. eva/vision/models/wrappers/from_registry.py +6 -5
  123. eva/vision/models/wrappers/from_timm.py +6 -4
  124. {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/METADATA +17 -3
  125. {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/RECORD +128 -66
  126. eva/core/data/dataloaders/collate_fn/__init__.py +0 -5
  127. eva/core/data/dataloaders/collate_fn/collate.py +0 -24
  128. eva/language/models/modules/text.py +0 -85
  129. {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/WHEEL +0 -0
  130. {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/entry_points.txt +0 -0
  131. {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/licenses/LICENSE +0 -0
--- a/eva/language/models/wrappers/huggingface.py
+++ b/eva/language/models/wrappers/huggingface.py
@@ -1,22 +1,34 @@
 """LLM wrapper for HuggingFace `transformers` models."""

-from typing import Any, Dict, List, Literal
+from typing import Any, Callable, Dict, List, Literal

 from transformers.pipelines import pipeline
 from typing_extensions import override

-from eva.core.models.wrappers import base
+from eva.language.models.typings import ModelOutput, TextBatch
+from eva.language.models.wrappers import base
+from eva.language.utils.text import messages as message_utils


-class HuggingFaceTextModel(base.BaseModel[List[str], List[str]]):
+class HuggingFaceModel(base.LanguageModel):
     """Wrapper class for loading HuggingFace `transformers` models using pipelines."""

+    _default_generation_kwargs = {
+        "temperature": 0.0,
+        "max_new_tokens": 1024,
+        "do_sample": False,
+        "top_p": 1.0,
+    }
+    """Default HF model parameters for evaluation."""
+
     def __init__(
         self,
         model_name_or_path: str,
         task: Literal["text-generation"] = "text-generation",
         model_kwargs: Dict[str, Any] | None = None,
+        system_prompt: str | None = None,
         generation_kwargs: Dict[str, Any] | None = None,
+        chat_mode: bool = True,
     ) -> None:
         """Initializes the model.

@@ -26,21 +38,26 @@ class HuggingFaceTextModel(base.BaseModel[List[str], List[str]]):
                 model hub.
             task: The pipeline task. Defaults to "text-generation".
             model_kwargs: Additional arguments for configuring the pipeline.
+            system_prompt: System prompt to use.
             generation_kwargs: Additional generation parameters (temperature, max_length, etc.).
+            chat_mode: Whether the specified model expects chat style messages. If set to False
+                the model is assumed to be a standard text completion model and will expect
+                plain text string inputs.
         """
-        super().__init__()
+        super().__init__(system_prompt=system_prompt)

         self._model_name_or_path = model_name_or_path
         self._task = task
         self._model_kwargs = model_kwargs or {}
-        self._generation_kwargs = generation_kwargs or {}
+        self._generation_kwargs = self._default_generation_kwargs | (generation_kwargs or {})
+        self._chat_mode = chat_mode

-        self.load_model()
+        self.model = self.load_model()

     @override
-    def load_model(self) -> None:
+    def load_model(self) -> Callable:
         """Loads the model as a Hugging Face pipeline."""
-        self._pipeline = pipeline(
+        return pipeline(
             task=self._task,
             model=self._model_name_or_path,
             trust_remote_code=True,
@@ -48,7 +65,34 @@ class HuggingFaceTextModel(base.BaseModel[List[str], List[str]]):
         )

     @override
-    def model_forward(self, prompts: List[str]) -> List[str]:
+    def format_inputs(self, batch: TextBatch) -> List[List[Dict[str, Any]]] | List[str]:
+        """Formats inputs for HuggingFace models.
+
+        Note: If multiple system messages are present, they will be combined
+        into a single message, given that many models only support a single
+        system prompt.
+
+        Args:
+            batch: A batch of text and image inputs.
+
+        Returns:
+            When in chat mode, returns a batch of message series following
+            OpenAI's API format {"role": "user", "content": "..."}, for non-chat
+            models returns a list of plain text strings.
+        """
+        message_batch, _, _ = TextBatch(*batch)
+        message_batch = message_utils.batch_insert_system_message(
+            message_batch, self.system_message
+        )
+        message_batch = list(map(message_utils.combine_system_messages, message_batch))
+
+        if self._chat_mode:
+            return list(map(message_utils.format_chat_message, message_batch))
+        else:
+            return list(map(message_utils.merge_message_contents, message_batch))
+
+    @override
+    def model_forward(self, prompts: List[str]) -> ModelOutput:
         """Generates text using the pipeline.

         Args:
@@ -57,13 +101,15 @@ class HuggingFaceTextModel(base.BaseModel[List[str], List[str]]):
         Returns:
             The generated text as a string.
         """
-        outputs = self._pipeline(prompts, return_full_text=False, **self._generation_kwargs)
+        outputs = self.model(prompts, return_full_text=False, **self._generation_kwargs)
         if outputs is None:
             raise ValueError("Outputs from the model are None.")
+
         results = []
         for output in outputs:
             if isinstance(output, list):
                 results.append(output[0]["generated_text"])  # type: ignore
             else:
                 results.append(output["generated_text"])  # type: ignore
-        return results
+
+        return ModelOutput(generated_text=results)
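
For orientation, a minimal usage sketch of the reworked wrapper. The checkpoint name and the batch construction are illustrative assumptions: TextBatch is unpacked positionally as in format_inputs above, and UserMessage(content=...) is inferred from the deserialize() helper further down this page.

    from eva.language.data.messages import UserMessage
    from eva.language.models.typings import TextBatch
    from eva.language.models.wrappers.huggingface import HuggingFaceModel

    model = HuggingFaceModel(
        model_name_or_path="Qwen/Qwen2.5-0.5B-Instruct",  # placeholder checkpoint
        system_prompt="Answer with a single word.",
        generation_kwargs={"max_new_tokens": 64},  # merged over the new defaults
        chat_mode=True,  # False yields plain-string completion inputs instead
    )
    batch = TextBatch([[UserMessage(content="Is the sky blue?")]], None, None)
    inputs = model.format_inputs(batch)   # chat messages (or plain strings)
    output = model.model_forward(inputs)  # ModelOutput(generated_text=[...])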
--- a/eva/language/models/wrappers/litellm.py
+++ b/eva/language/models/wrappers/litellm.py
@@ -1,77 +1,122 @@
-"""LLM wrapper for litellm models."""
+"""LiteLLM language model wrapper."""

+import logging
 from typing import Any, Dict, List

-from litellm import batch_completion  # type: ignore
+import backoff
+import litellm
+from litellm import batch_completion
+from litellm.exceptions import (
+    APIConnectionError,
+    InternalServerError,
+    RateLimitError,
+    ServiceUnavailableError,
+    Timeout,
+)
 from loguru import logger
 from typing_extensions import override

-from eva.core.models.wrappers import base
+from eva.language.models.typings import ModelOutput, TextBatch
+from eva.language.models.wrappers import base
+from eva.language.utils.text import messages as message_utils

+RETRYABLE_ERRORS = (
+    RateLimitError,
+    Timeout,
+    InternalServerError,
+    APIConnectionError,
+    ServiceUnavailableError,
+)

-class LiteLLMTextModel(base.BaseModel[List[str], List[str]]):
-    """Wrapper class for using litellm for chat-based text generation.

-    This wrapper uses litellm's `completion` function which accepts a list of
-    message dicts. The `forward` method converts a string prompt into a chat
-    message with a default "user" role, optionally prepends a system message,
-    and includes an API key if provided.
-    """
+class LiteLLMModel(base.LanguageModel):
+    """Wrapper class for LiteLLM language models."""
+
+    _default_model_kwargs = {
+        "temperature": 0.0,
+        "max_completion_tokens": 1024,
+        "top_p": 1.0,
+        "seed": 42,
+    }
+    """Default API model parameters for evaluation."""

     def __init__(
         self,
-        model_name_or_path: str,
+        model_name: str,
         model_kwargs: Dict[str, Any] | None = None,
-    ) -> None:
-        """Initializes the litellm chat model wrapper.
+        system_prompt: str | None = None,
+        log_level: int | None = logging.INFO,
+    ):
+        """Initialize the LiteLLM Wrapper.

         Args:
-            model_name_or_path: The model identifier (or name) for litellm
-                (e.g.,"openai/gpt-4o" or "anthropic/claude-3-sonnet-20240229").
+            model_name: The name of the model to use.
             model_kwargs: Additional keyword arguments to pass during
                 generation (e.g., `temperature`, `max_tokens`).
+            system_prompt: The system prompt to use (optional).
+            log_level: Optional logging level for LiteLLM. Defaults to WARNING.
         """
-        super().__init__()
-        self._model_name_or_path = model_name_or_path
-        self._model_kwargs = model_kwargs or {}
-        self.load_model()
+        super().__init__(system_prompt=system_prompt)

-    @override
-    def load_model(self) -> None:
-        """Prepares the litellm model.
+        self.model_name = model_name
+        self.model_kwargs = self._default_model_kwargs | (model_kwargs or {})

-        Note:
-            litellm doesn't require an explicit loading step; models are called
-            directly during generation. This method exists for API consistency.
-        """
-        pass
+        litellm.suppress_debug_info = True
+        litellm.drop_params = True
+
+        if log_level is not None:
+            logging.getLogger("LiteLLM").setLevel(log_level)

     @override
-    def model_forward(self, prompts: List[str]) -> List[str]:
-        """Generates text using litellm.
+    def format_inputs(self, batch: TextBatch) -> List[List[Dict[str, Any]]]:
+        """Formats inputs for LiteLLM.

         Args:
-            prompts: A list of prompts to be converted into a "user" message.
+            batch: A batch of text inputs.

         Returns:
-            A list of generated text responses. Failed generations will contain
-            error messages instead of generated text.
+            A list of messages in the following format:
+            [
+                {
+                    "role": ...
+                    "content": ...
+                },
+                ...
+            ]
         """
-        messages = [[{"role": "user", "content": prompt}] for prompt in prompts]
+        message_batch, _, _ = TextBatch(*batch)

-        responses = batch_completion(
-            model=self._model_name_or_path,
-            messages=messages,
-            **self._model_kwargs,
+        message_batch = message_utils.batch_insert_system_message(
+            message_batch, self.system_message
         )
+        message_batch = list(map(message_utils.combine_system_messages, message_batch))

-        results = []
-        for i, response in enumerate(responses):
-            if isinstance(response, Exception):
-                error_msg = f"Error generating text for prompt {i}: {response}"
-                logger.error(error_msg)
-                raise RuntimeError(error_msg)
-            else:
-                results.append(response["choices"][0]["message"]["content"])
+        return list(map(message_utils.format_chat_message, message_batch))

-        return results
+    @override
+    @backoff.on_exception(
+        backoff.expo,
+        RETRYABLE_ERRORS,
+        max_tries=20,
+        jitter=backoff.full_jitter,
+        on_backoff=lambda details: logger.warning(
+            f"Retrying due to {details.get('exception') or 'Unknown error'}"
+        ),
+    )
+    def model_forward(self, batch: List[List[Dict[str, Any]]]) -> ModelOutput:
+        """Generates output text through API calls via LiteLLM's batch completion functionality."""
+        outputs = batch_completion(model=self.model_name, messages=batch, **self.model_kwargs)
+        self._raise_exceptions(outputs)
+
+        generated_text = [
+            output["choices"][0]["message"]["content"]
+            for output in outputs
+            if output["choices"][0]["message"]["role"] == "assistant"
+        ]
+        return ModelOutput(generated_text=generated_text)
+
+    def _raise_exceptions(self, outputs: list):
+        for output in outputs:
+            if isinstance(output, Exception):
+                logger.error(f"Model {self.model_name} encountered an error: {output}")
+                raise output
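
The notable behavioural addition here is the retry policy: model_forward is now wrapped in exponential backoff with full jitter over transient API errors, retrying up to 20 times before the exception propagates. A standalone sketch of the same pattern (the model name is illustrative):

    import backoff
    from litellm import batch_completion
    from litellm.exceptions import RateLimitError, Timeout

    @backoff.on_exception(
        backoff.expo,               # exponential wait between attempts
        (RateLimitError, Timeout),  # subset of RETRYABLE_ERRORS above
        max_tries=20,
        jitter=backoff.full_jitter,
    )
    def complete(messages):
        # batch_completion returns per-prompt failures as Exception objects in
        # the output list; the wrapper re-raises them so the decorator can retry.
        return batch_completion(model="openai/gpt-4o", messages=messages)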
--- a/eva/language/models/wrappers/vllm.py
+++ b/eva/language/models/wrappers/vllm.py
@@ -1,6 +1,6 @@
 """LLM wrapper for vLLM models."""

-from typing import Any, Dict, List, Sequence
+from typing import Any, Dict, List

 from loguru import logger
 from typing_extensions import override
@@ -11,17 +11,20 @@ try:
     from vllm.transformers_utils.tokenizer import AnyTokenizer  # type: ignore
 except ImportError as e:
     raise ImportError(
-        "vLLM is required for VLLMTextModel but not installed. "
+        "vLLM is required for VllmModel but not installed. "
         "vLLM must be installed manually as it requires CUDA and is not included in dependencies. "
         "Install with: pip install vllm "
        "Note: vLLM requires Linux with CUDA support for optimal performance. "
-        "For alternatives, consider using HuggingFaceTextModel or LiteLLMTextModel."
+        "For alternatives, consider using HuggingFaceModel or LiteLLMModel."
     ) from e

-from eva.core.models.wrappers import base
+from eva.language.data.messages import MessageSeries
+from eva.language.models.typings import TextBatch
+from eva.language.models.wrappers import base
+from eva.language.utils.text import messages as message_utils


-class VLLMTextModel(base.BaseModel):
+class VllmModel(base.LanguageModel):
     """Wrapper class for using vLLM for text generation.

     This wrapper loads a vLLM model, sets up the tokenizer and sampling
@@ -34,6 +37,7 @@ class VLLMTextModel(base.BaseModel):
         self,
         model_name_or_path: str,
         model_kwargs: Dict[str, Any] | None = None,
+        system_prompt: str | None = None,
         generation_kwargs: Dict[str, Any] | None = None,
     ) -> None:
         """Initializes the vLLM model wrapper.
@@ -44,12 +48,13 @@
             model_kwargs: Arguments required to initialize the vLLM model,
                 see [link](https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py)
                 for more information.
+            system_prompt: System prompt to use.
             generation_kwargs: Arguments required to generate the output,
                 need to align with the arguments of
                 [vllm.SamplingParams](https://github.com/vllm-project/vllm/blob/main/vllm/sampling_params.py).

         """
-        super().__init__()
+        super().__init__(system_prompt=system_prompt)
         self._model_name_or_path = model_name_or_path
         self._model_kwargs = model_kwargs or {}
         self._generation_kwargs = generation_kwargs or {}
@@ -71,11 +76,11 @@
             raise RuntimeError("Model not initialized")
         self._llm_tokenizer = self._llm_model.get_tokenizer()

-    def _apply_chat_template(self, prompts: Sequence[str]) -> list[TokensPrompt]:
+    def _tokenize_messages(self, messages: List[MessageSeries]) -> List[TokensPrompt]:
         """Apply chat template to the messages.

         Args:
-            prompts: List of raw user strings.
+            messages: List of raw user strings.

         Returns:
             List of encoded messages.
@@ -90,7 +95,8 @@
         if not hasattr(self._llm_tokenizer, "chat_template"):
             raise ValueError("Tokenizer does not have a chat template.")

-        chat_messages = [[{"role": "user", "content": p}] for p in prompts]
+        chat_messages = list(map(message_utils.format_chat_message, messages))
+
         encoded_messages = self._llm_tokenizer.apply_chat_template(
             chat_messages,  # type: ignore
             tokenize=True,
@@ -131,11 +137,30 @@

         return result

-    def generate(self, prompts: List[str]) -> List[str]:
+    @override
+    def format_inputs(self, batch: TextBatch) -> List[TokensPrompt]:
+        """Formats inputs for vLLM models.
+
+        Args:
+            batch: A batch of text and image inputs.
+
+        Returns:
+            List of formatted prompts.
+        """
+        message_batch, _, _ = TextBatch(*batch)
+        message_batch = message_utils.batch_insert_system_message(
+            message_batch, self.system_message
+        )
+        message_batch = list(map(message_utils.combine_system_messages, message_batch))
+
+        return self._tokenize_messages(message_batch)
+
+    @override
+    def model_forward(self, batch: List[TokensPrompt]) -> List[str]:
         """Generates text for the given prompt using the vLLM model.

         Args:
-            prompts: A list of string prompts for generation.
+            batch: A list encoded / tokenized messages (TokensPrompt objects).

         Returns:
             The generated text response.
@@ -144,6 +169,5 @@
         if self._llm_model is None:
             raise RuntimeError("Model not initialized")

-        prompt_tokens = self._apply_chat_template(prompts)
-        outputs = self._llm_model.generate(prompt_tokens, SamplingParams(**self._generation_kwargs))
+        outputs = self._llm_model.generate(batch, SamplingParams(**self._generation_kwargs))
         return [output.outputs[0].text for output in outputs]
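
Prompt preparation has moved out of generate() into format_inputs(), so generation itself now operates on pre-tokenized prompts. A sketch of the resulting flow, assuming a TextBatch named batch as in the other wrappers (checkpoint name illustrative; requires a working vLLM/CUDA install):

    model = VllmModel(
        model_name_or_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder
        system_prompt="You are a concise assistant.",
        generation_kwargs={"max_tokens": 128},  # forwarded to vllm.SamplingParams
    )
    tokens = model.format_inputs(batch)  # List[TokensPrompt], system prompt inserted
    texts = model.model_forward(tokens)  # List[str], one completion per prompt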
--- a/eva/language/utils/__init__.py
+++ b/eva/language/utils/__init__.py
@@ -1,5 +1,6 @@
 """Language utilities and helper functions."""

 from eva.language.utils.str_to_int_tensor import CastStrToIntTensor
+from eva.language.utils.text.messages import format_chat_message

-__all__ = ["CastStrToIntTensor"]
+__all__ = ["CastStrToIntTensor", "format_chat_message"]
--- a/eva/language/utils/str_to_int_tensor.py
+++ b/eva/language/utils/str_to_int_tensor.py
@@ -16,11 +16,11 @@ class CastStrToIntTensor:
     Supports single values, lists of strings, or lists of integers.

     Example:
-        >>> # Default mapping for yes/no/maybe classification
-        >>> transform = CastStrToIntTensor()
-        >>> transform(['yes', 'no', 'maybe'])
+        >>> # Default mapping for A/B/C classification
+        >>> transform = CastStrToIntTensor(mapping={"A": 0, "B": 1, "C": 2})
+        >>> transform(['B', 'A', 'C'])
         tensor([1, 0, 2])
-        >>> transform('yes')
+        >>> transform('B')
         tensor([1])

         >>> # Custom mapping
@@ -29,20 +29,25 @@
         tensor([1, 0])
     """

-    def __init__(self, mapping: Dict[str, int] | None = None):
-        """Initialize the transform with a regex-to-integer mapping.
+    def __init__(
+        self, mapping: Dict[str, int], standalone_words: bool = True, case_sensitive: bool = True
+    ) -> None:
+        r"""Initialize the transform with a regex-to-integer mapping.

         Args:
             mapping: Dictionary mapping regex patterns to integers. If None, uses default
                 yes/no/maybe mapping: {'no': 0, 'yes': 1, 'maybe': 2}
+            standalone_words: If True, patterns are treated as standalone words (e.g., '\bno\b').
+            case_sensitive: If True, regex patterns are case-sensitive.
         """
-        if mapping is None:
-            self.mapping = {r"\bno\b": 0, r"\byes\b": 1, r"\bmaybe\b": 2}
-        else:
-            self.mapping = mapping
+        self.mapping = mapping
+
+        if standalone_words:
+            self.mapping = {rf"\b{k}\b": v for k, v in mapping.items()}

         self.compiled_patterns = [
-            (re.compile(pattern, re.IGNORECASE), value) for pattern, value in self.mapping.items()
+            (re.compile(pattern, 0 if case_sensitive else re.IGNORECASE), value)
+            for pattern, value in self.mapping.items()
         ]

     def __call__(self, values: Union[str, List[str], List[int]]) -> torch.Tensor:
@@ -58,7 +63,10 @@
             ValueError: If any value cannot be mapped to an integer.
         """
         return torch.tensor(
-            [self._cast_single(v) for v in (values if isinstance(values, list) else [values])],
+            [
+                self._cast_single(v)
+                for v in (values if isinstance(values, list | tuple) else [values])
+            ],
             dtype=torch.int,
         )

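
This is a breaking change: the mapping argument is now required (the implicit yes/no/maybe default is gone), matching is word-bounded by default, and patterns are case-sensitive unless disabled. A small sketch of a migrated call (import path taken from the utils __init__ diff above; the assert reflects the documented behaviour):

    import torch
    from eva.language.utils import CastStrToIntTensor

    transform = CastStrToIntTensor(
        mapping={"no": 0, "yes": 1, "maybe": 2},  # plain words, wrapped as \b...\b
        standalone_words=True,
        case_sensitive=False,  # pre-0.4.1 matching was always case-insensitive
    )
    assert torch.equal(transform(["Yes", "no"]), torch.tensor([1, 0], dtype=torch.int))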
--- /dev/null
+++ b/eva/language/utils/text/__init__.py
@@ -0,0 +1,5 @@
+"""Text utilities for language models."""
+
+from eva.language.utils.text.messages import format_chat_message
+
+__all__ = ["format_chat_message"]
--- /dev/null
+++ b/eva/language/utils/text/messages.py
@@ -0,0 +1,113 @@
+"""Message formatting utilities for language models."""
+
+import functools
+import json
+from typing import Any, Dict, List
+
+from eva.language.data.messages import (
+    AssistantMessage,
+    MessageSeries,
+    Role,
+    SystemMessage,
+    UserMessage,
+)
+
+
+def format_chat_message(message: MessageSeries) -> List[Dict[str, Any]]:
+    """Formats a message series into a format following OpenAI's API specification."""
+    return [{"role": item.role, "content": item.content} for item in message]
+
+
+def combine_system_messages(message: MessageSeries, join_char: str = "\n") -> MessageSeries:
+    """Combine system messages into a single message.
+
+    This is useful when the MessageSeries contains multiple system messages such
+    as `ModelSystemMessage` and `TaskSystemMessage`. But given that most models / apis
+    expect a single system message, this function can be used to combines them into one.
+
+    Args:
+        message: The message series containing one or multiple messages.
+        join_char: The character to use to join the system messages. Default is newline.
+
+    Returns:
+        A new message series with system messages combined into one and the
+        remaining messages unchanged.
+    """
+    system_messages = list(filter(lambda item: item.role == Role.SYSTEM, message))
+    if len(system_messages) == 0:
+        return message
+
+    non_system_messages = list(filter(lambda item: item.role != Role.SYSTEM, message))
+    return [
+        SystemMessage(content=merge_message_contents(system_messages, join_char=join_char))
+    ] + non_system_messages
+
+
+def merge_message_contents(message: MessageSeries, join_char: str = "\n") -> str:
+    """Merges the all contents within a message series into a string.
+
+    Args:
+        message: The message series to combine.
+        join_char: The character to use to join the message contents. Default is newline.
+
+    Returns:
+        A string containing the combined message contents.
+    """
+    return join_char.join(item.content for item in message)
+
+
+def insert_system_message(
+    message: MessageSeries, system_message: SystemMessage | None
+) -> MessageSeries:
+    """Insert a system message at the beginning of the message series."""
+    if system_message is None:
+        return message
+    return [system_message] + message
+
+
+def batch_insert_system_message(
+    messages: List[MessageSeries], system_message: SystemMessage | None
+) -> List[MessageSeries]:
+    """Insert a system message at the beginning of each message series in a batch."""
+    return list(
+        map(functools.partial(insert_system_message, system_message=system_message), messages)
+    )
+
+
+def serialize(messages: MessageSeries) -> str:
+    """Serialize a MessageSeries object into a JSON string.
+
+    Args:
+        messages: A list of message objects (MessagesSeries).
+
+    Returns:
+        A JSON string representing the message series, with the following format:
+        [{"role": "user", "content": "Hello"}, ...]
+    """
+    serialized_messages = format_chat_message(messages)
+    return json.dumps(serialized_messages)
+
+
+def deserialize(messages: str) -> MessageSeries:
+    """Convert a json string to a MessageSeries object.
+
+    Format: [{"role": "user", "content": "Hello"}, {"role": "assistant", "content": "Hi there!"}]
+    """
+    message_dicts = json.loads(messages)
+
+    message_series = []
+    for message_dict in message_dicts:
+        if "role" not in message_dict or "content" not in message_dict:
+            raise ValueError("`role` or `content` keys are missing.")
+
+        match message_dict["role"]:
+            case Role.USER:
+                message_series.append(UserMessage(**message_dict))
+            case Role.ASSISTANT:
+                message_series.append(AssistantMessage(**message_dict))
+            case Role.SYSTEM:
+                message_series.append(SystemMessage(**message_dict))
+            case _:
+                raise ValueError(f"Unknown role: {message_dict['role']}")
+
+    return message_series
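
Taken together, these helpers normalize a message series before it reaches a wrapper: insert the configured system prompt, collapse multiple system messages into one, then emit OpenAI-style dicts or a JSON round-trip. A sketch (message constructors assumed from the deserialize() implementation above):

    from eva.language.data.messages import SystemMessage, UserMessage
    from eva.language.utils.text import messages as message_utils

    series = [SystemMessage(content="Be brief."), UserMessage(content="Hi")]
    series = message_utils.combine_system_messages(series)  # single system message
    chat = message_utils.format_chat_message(series)  # [{"role": ..., "content": ...}, ...]
    restored = message_utils.deserialize(message_utils.serialize(series))  # JSON round-trip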
--- /dev/null
+++ b/eva/multimodal/__init__.py
@@ -0,0 +1,6 @@
+"""Multimodal API."""
+
+from eva.multimodal import models
+from eva.multimodal.data import datasets
+
+__all__ = ["models", "datasets"]
--- /dev/null
+++ b/eva/multimodal/callbacks/__init__.py
@@ -0,0 +1,5 @@
+"""Multimodal callbacks API."""
+
+from eva.multimodal.callbacks.writers import TextPredictionWriter
+
+__all__ = ["TextPredictionWriter"]
--- /dev/null
+++ b/eva/multimodal/callbacks/writers/__init__.py
@@ -0,0 +1,5 @@
+"""Multimodal writers callbacks API."""
+
+from eva.multimodal.callbacks.writers.prediction import TextPredictionWriter
+
+__all__ = ["TextPredictionWriter"]
--- /dev/null
+++ b/eva/multimodal/callbacks/writers/prediction.py
@@ -0,0 +1,39 @@
+"""Text prediction writer callbacks."""
+
+from typing import Dict, List, Literal, Tuple
+
+from torch import nn
+from typing_extensions import override
+
+from eva.language.callbacks import writers
+from eva.multimodal.models.typings import TextImageBatch
+
+
+class TextPredictionWriter(writers.TextPredictionWriter):
+    """Callback for writing generated text predictions to disk."""
+
+    def __init__(
+        self,
+        output_dir: str,
+        model: nn.Module,
+        dataloader_idx_map: Dict[int, str] | None = None,
+        metadata_keys: List[str] | None = None,
+        include_input: bool = True,
+        overwrite: bool = False,
+        save_format: Literal["jsonl", "parquet", "csv"] = "jsonl",
+    ) -> None:
+        """See docstring of base class."""
+        super().__init__(
+            output_dir=output_dir,
+            model=model,
+            dataloader_idx_map=dataloader_idx_map,
+            metadata_keys=metadata_keys,
+            include_input=include_input,
+            overwrite=overwrite,
+            save_format=save_format,
+        )
+
+    @override
+    def _unpack_batch(self, batch: TextImageBatch) -> Tuple[list, list | None, dict | None]:  # type: ignore
+        text_batch, _, target_batch, metadata_batch = TextImageBatch(*batch)
+        return text_batch, target_batch, metadata_batch
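
The multimodal writer reuses the language writer wholesale and only overrides batch unpacking (a TextImageBatch carries an extra image field, which is dropped before writing). A hypothetical configuration, with all argument values illustrative:

    writer = TextPredictionWriter(
        output_dir="./predictions",
        model=my_model,  # placeholder: any nn.Module producing generations
        dataloader_idx_map={0: "val", 1: "test"},
        metadata_keys=["sample_id"],  # hypothetical metadata column
        save_format="jsonl",  # or "parquet" / "csv"
    )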
--- /dev/null
+++ b/eva/multimodal/data/__init__.py
@@ -0,0 +1,5 @@
+"""Data components for multimodal learning."""
+
+from eva.multimodal.data import datasets
+
+__all__ = ["datasets"]