kaiko-eva 0.3.2__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (101)
  1. eva/core/callbacks/config.py +4 -0
  2. eva/core/cli/setup.py +1 -1
  3. eva/core/data/dataloaders/__init__.py +1 -2
  4. eva/core/data/dataloaders/dataloader.py +3 -1
  5. eva/core/data/samplers/random.py +17 -10
  6. eva/core/interface/interface.py +21 -0
  7. eva/core/loggers/log/__init__.py +2 -1
  8. eva/core/loggers/log/table.py +73 -0
  9. eva/core/models/modules/module.py +2 -2
  10. eva/core/models/wrappers/base.py +2 -2
  11. eva/core/models/wrappers/from_function.py +3 -3
  12. eva/core/models/wrappers/from_torchhub.py +9 -7
  13. eva/core/models/wrappers/huggingface.py +4 -5
  14. eva/core/models/wrappers/onnx.py +5 -5
  15. eva/core/trainers/trainer.py +2 -0
  16. eva/language/__init__.py +2 -1
  17. eva/language/callbacks/__init__.py +5 -0
  18. eva/language/callbacks/writers/__init__.py +5 -0
  19. eva/language/callbacks/writers/prediction.py +176 -0
  20. eva/language/data/dataloaders/__init__.py +5 -0
  21. eva/language/data/dataloaders/collate_fn/__init__.py +5 -0
  22. eva/language/data/dataloaders/collate_fn/text.py +57 -0
  23. eva/language/data/datasets/__init__.py +3 -1
  24. eva/language/data/datasets/{language.py → base.py} +1 -1
  25. eva/language/data/datasets/classification/base.py +3 -43
  26. eva/language/data/datasets/classification/pubmedqa.py +36 -4
  27. eva/language/data/datasets/prediction.py +151 -0
  28. eva/language/data/datasets/schemas.py +18 -0
  29. eva/language/data/datasets/text.py +92 -0
  30. eva/language/data/datasets/typings.py +39 -0
  31. eva/language/data/messages.py +60 -0
  32. eva/language/models/__init__.py +15 -11
  33. eva/language/models/modules/__init__.py +2 -2
  34. eva/language/models/modules/language.py +93 -0
  35. eva/language/models/networks/__init__.py +12 -0
  36. eva/language/models/networks/alibaba.py +26 -0
  37. eva/language/models/networks/api/__init__.py +11 -0
  38. eva/language/models/networks/api/anthropic.py +34 -0
  39. eva/language/models/networks/registry.py +5 -0
  40. eva/language/models/typings.py +39 -0
  41. eva/language/models/wrappers/__init__.py +13 -5
  42. eva/language/models/wrappers/base.py +47 -0
  43. eva/language/models/wrappers/from_registry.py +54 -0
  44. eva/language/models/wrappers/huggingface.py +44 -8
  45. eva/language/models/wrappers/litellm.py +81 -46
  46. eva/language/models/wrappers/vllm.py +37 -13
  47. eva/language/utils/__init__.py +2 -1
  48. eva/language/utils/str_to_int_tensor.py +20 -12
  49. eva/language/utils/text/__init__.py +5 -0
  50. eva/language/utils/text/messages.py +113 -0
  51. eva/multimodal/__init__.py +6 -0
  52. eva/multimodal/callbacks/__init__.py +5 -0
  53. eva/multimodal/callbacks/writers/__init__.py +5 -0
  54. eva/multimodal/callbacks/writers/prediction.py +39 -0
  55. eva/multimodal/data/__init__.py +5 -0
  56. eva/multimodal/data/dataloaders/__init__.py +5 -0
  57. eva/multimodal/data/dataloaders/collate_fn/__init__.py +5 -0
  58. eva/multimodal/data/dataloaders/collate_fn/text_image.py +28 -0
  59. eva/multimodal/data/datasets/__init__.py +6 -0
  60. eva/multimodal/data/datasets/base.py +13 -0
  61. eva/multimodal/data/datasets/multiple_choice/__init__.py +5 -0
  62. eva/multimodal/data/datasets/multiple_choice/patch_camelyon.py +80 -0
  63. eva/multimodal/data/datasets/schemas.py +14 -0
  64. eva/multimodal/data/datasets/text_image.py +77 -0
  65. eva/multimodal/data/datasets/typings.py +27 -0
  66. eva/multimodal/models/__init__.py +8 -0
  67. eva/multimodal/models/modules/__init__.py +5 -0
  68. eva/multimodal/models/modules/vision_language.py +55 -0
  69. eva/multimodal/models/networks/__init__.py +14 -0
  70. eva/multimodal/models/networks/alibaba.py +39 -0
  71. eva/multimodal/models/networks/api/__init__.py +11 -0
  72. eva/multimodal/models/networks/api/anthropic.py +34 -0
  73. eva/multimodal/models/networks/others.py +47 -0
  74. eva/multimodal/models/networks/registry.py +5 -0
  75. eva/multimodal/models/typings.py +27 -0
  76. eva/multimodal/models/wrappers/__init__.py +13 -0
  77. eva/multimodal/models/wrappers/base.py +47 -0
  78. eva/multimodal/models/wrappers/from_registry.py +54 -0
  79. eva/multimodal/models/wrappers/huggingface.py +180 -0
  80. eva/multimodal/models/wrappers/litellm.py +56 -0
  81. eva/multimodal/utils/__init__.py +1 -0
  82. eva/multimodal/utils/image/__init__.py +5 -0
  83. eva/multimodal/utils/image/encode.py +28 -0
  84. eva/multimodal/utils/text/__init__.py +1 -0
  85. eva/multimodal/utils/text/messages.py +79 -0
  86. eva/vision/data/datasets/classification/patch_camelyon.py +8 -6
  87. eva/vision/data/transforms/__init__.py +2 -1
  88. eva/vision/data/transforms/spatial/__init__.py +2 -1
  89. eva/vision/data/transforms/spatial/functional/__init__.py +5 -0
  90. eva/vision/data/transforms/spatial/functional/resize.py +26 -0
  91. eva/vision/data/transforms/spatial/resize.py +62 -0
  92. eva/vision/models/wrappers/from_registry.py +6 -5
  93. eva/vision/models/wrappers/from_timm.py +6 -4
  94. {kaiko_eva-0.3.2.dist-info → kaiko_eva-0.4.0.dist-info}/METADATA +10 -2
  95. {kaiko_eva-0.3.2.dist-info → kaiko_eva-0.4.0.dist-info}/RECORD +98 -40
  96. eva/core/data/dataloaders/collate_fn/__init__.py +0 -5
  97. eva/core/data/dataloaders/collate_fn/collate.py +0 -24
  98. eva/language/models/modules/text.py +0 -85
  99. {kaiko_eva-0.3.2.dist-info → kaiko_eva-0.4.0.dist-info}/WHEEL +0 -0
  100. {kaiko_eva-0.3.2.dist-info → kaiko_eva-0.4.0.dist-info}/entry_points.txt +0 -0
  101. {kaiko_eva-0.3.2.dist-info → kaiko_eva-0.4.0.dist-info}/licenses/LICENSE +0 -0
eva/language/models/networks/alibaba.py
@@ -0,0 +1,26 @@
+ """Models from Alibaba."""
+
+ import torch
+
+ from eva.language.models import wrappers
+ from eva.language.models.networks.registry import model_registry
+
+
+ @model_registry.register("alibaba/qwen2-0-5b-instruct")
+ class Qwen205BInstruct(wrappers.HuggingFaceModel):
+     """Qwen2 0.5B Instruct model."""
+
+     def __init__(self, system_prompt: str | None = None, cache_dir: str | None = None):
+         """Initialize the model."""
+         super().__init__(
+             model_name_or_path="Qwen/Qwen2-0.5B-Instruct",
+             model_kwargs={
+                 "torch_dtype": torch.bfloat16,
+                 "cache_dir": cache_dir,
+             },
+             generation_kwargs={
+                 "max_new_tokens": 512,
+             },
+             system_prompt=system_prompt,
+             chat_mode=True,
+         )
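
The decorator registers the wrapper under the config-friendly name `alibaba/qwen2-0-5b-instruct`, but the class also stays directly importable. A minimal usage sketch (hypothetical calling code, not part of this diff):

    from eva.language.models.networks.alibaba import Qwen205BInstruct

    # Builds the underlying HuggingFace pipeline; weights are fetched
    # from the Qwen/Qwen2-0.5B-Instruct hub repo on first use.
    model = Qwen205BInstruct(system_prompt="Answer with yes, no, or maybe.")
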
eva/language/models/networks/api/__init__.py
@@ -0,0 +1,11 @@
+ """Multimodal API networks."""
+
+ from eva.language.models.networks.api.anthropic import (
+     Claude35Sonnet20240620,
+     Claude37Sonnet20250219,
+ )
+
+ __all__ = [
+     "Claude35Sonnet20240620",
+     "Claude37Sonnet20250219",
+ ]
eva/language/models/networks/api/anthropic.py
@@ -0,0 +1,34 @@
+ """Models from Anthropic."""
+
+ import os
+
+ from eva.language.models import wrappers
+ from eva.language.models.networks.registry import model_registry
+
+
+ class _Claude(wrappers.LiteLLMModel):
+     """Base class for Claude models."""
+
+     def __init__(self, model_name: str, system_prompt: str | None = None):
+         if not os.getenv("ANTHROPIC_API_KEY"):
+             raise ValueError("ANTHROPIC_API_KEY env variable must be set.")
+
+         super().__init__(model_name=model_name, system_prompt=system_prompt)
+
+
+ @model_registry.register("anthropic/claude-3-5-sonnet-20240620")
+ class Claude35Sonnet20240620(_Claude):
+     """Claude 3.5 Sonnet (June 2024) model."""
+
+     def __init__(self, system_prompt: str | None = None):
+         """Initialize the model."""
+         super().__init__(model_name="claude-3-5-sonnet-20240620", system_prompt=system_prompt)
+
+
+ @model_registry.register("anthropic/claude-3-7-sonnet-20250219")
+ class Claude37Sonnet20250219(_Claude):
+     """Claude 3.7 Sonnet (February 2025) model."""
+
+     def __init__(self, system_prompt: str | None = None):
+         """Initialize the model."""
+         super().__init__(model_name="claude-3-7-sonnet-20250219", system_prompt=system_prompt)
eva/language/models/networks/registry.py
@@ -0,0 +1,5 @@
+ """Language Model Registry."""
+
+ from eva.core.utils.registry import Registry
+
+ model_registry = Registry()
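
`Registry` itself lives in `eva.core.utils.registry` and is not part of this diff; a minimal dict-backed sketch of the pattern the `@model_registry.register(...)` decorators above rely on (an assumption, not eva's actual implementation):

    from typing import Callable, Dict, Type


    class Registry:
        """Maps string names to classes via a decorator (sketch only)."""

        def __init__(self) -> None:
            self._entries: Dict[str, Type] = {}

        def register(self, name: str) -> Callable[[Type], Type]:
            def decorator(cls: Type) -> Type:
                self._entries[name] = cls  # e.g. "alibaba/qwen2-0-5b-instruct"
                return cls

            return decorator

        def get(self, name: str) -> Type:
            return self._entries[name]
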
eva/language/models/typings.py
@@ -0,0 +1,39 @@
+ """Type definitions for language models."""
+
+ from typing import Any, Dict, Generic, List, TypeVar
+
+ from typing_extensions import NamedTuple
+
+ from eva.language.data.messages import MessageSeries
+
+ TargetType = TypeVar("TargetType")
+ """The target data type."""
+
+
+ class TextBatch(NamedTuple, Generic[TargetType]):
+     """Text sample with target and metadata."""
+
+     text: List[MessageSeries]
+     """Text content."""
+
+     target: TargetType | None
+     """Target data."""
+
+     metadata: Dict[str, Any] | None
+     """Additional metadata."""
+
+
+ class PredictionBatch(NamedTuple, Generic[TargetType]):
+     """Text sample with target and metadata."""
+
+     prediction: TargetType
+     """Prediction data."""
+
+     target: TargetType
+     """Target data."""
+
+     text: List[MessageSeries] | None
+     """Conversation messages that were used as input."""
+
+     metadata: Dict[str, Any] | None
+     """Additional metadata."""
eva/language/models/wrappers/__init__.py
@@ -1,11 +1,19 @@
  """Language Model Wrappers API."""

- from eva.language.models.wrappers.huggingface import HuggingFaceTextModel
- from eva.language.models.wrappers.litellm import LiteLLMTextModel
+ from eva.language.models.wrappers.base import LanguageModel
+ from eva.language.models.wrappers.from_registry import ModelFromRegistry
+ from eva.language.models.wrappers.huggingface import HuggingFaceModel
+ from eva.language.models.wrappers.litellm import LiteLLMModel

  try:
-     from eva.language.models.wrappers.vllm import VLLMTextModel
+     from eva.language.models.wrappers.vllm import VllmModel

-     __all__ = ["HuggingFaceTextModel", "LiteLLMTextModel", "VLLMTextModel"]
+     __all__ = [
+         "LanguageModel",
+         "HuggingFaceModel",
+         "LiteLLMModel",
+         "VllmModel",
+         "ModelFromRegistry",
+     ]
  except ImportError:
-     __all__ = ["HuggingFaceTextModel", "LiteLLMTextModel"]
+     __all__ = ["LanguageModel", "HuggingFaceModel", "LiteLLMModel", "ModelFromRegistry"]
eva/language/models/wrappers/base.py
@@ -0,0 +1,47 @@
+ """Base class for language model wrappers."""
+
+ import abc
+ from typing import Any, Callable, List
+
+ from typing_extensions import override
+
+ from eva.core.models.wrappers import base
+ from eva.language.data.messages import ModelSystemMessage
+ from eva.language.models.typings import TextBatch
+
+
+ class LanguageModel(base.BaseModel[TextBatch, List[str]]):
+     """Base class for language models.
+
+     Classes that inherit from this should implement the following methods:
+     - `load_model`: Loads & instantiates the model.
+     - `model_forward`: Implements the forward pass of the model. For API models,
+         this can be an API call.
+     - `format_inputs`: Preprocesses and converts the input batch into the format
+         expected by the `model_forward` method.
+     """
+
+     def __init__(
+         self, system_prompt: str | None, output_transforms: Callable | None = None
+     ) -> None:
+         """Creates a new model instance.
+
+         Args:
+             system_prompt: The system prompt to use for the model (optional).
+             output_transforms: Optional transforms to apply to the output of
+                 the model's forward pass.
+         """
+         super().__init__(transforms=output_transforms)
+
+         self.system_message = ModelSystemMessage(content=system_prompt) if system_prompt else None
+
+     @override
+     def forward(self, batch: TextBatch) -> List[str]:
+         """Forward pass of the model."""
+         inputs = self.format_inputs(batch)
+         return super().forward(inputs)
+
+     @abc.abstractmethod
+     def format_inputs(self, batch: TextBatch) -> Any:
+         """Converts the inputs into the format expected by the model."""
+         raise NotImplementedError
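
The contract is small: subclasses provide `load_model`, `format_inputs`, and `model_forward`, and the base `forward` chains input formatting into the inherited forward pass. A toy echo model as a sketch of the minimal subclass (hypothetical, for illustration only):

    from typing import List

    from typing_extensions import override

    from eva.language.models.typings import TextBatch
    from eva.language.models.wrappers.base import LanguageModel


    class EchoModel(LanguageModel):
        """Returns its formatted inputs unchanged (illustration only)."""

        @override
        def load_model(self) -> None:
            pass  # nothing to load for a toy model

        @override
        def format_inputs(self, batch: TextBatch) -> List[str]:
            text, _, _ = TextBatch(*batch)
            return [str(messages) for messages in text]

        @override
        def model_forward(self, prompts: List[str]) -> List[str]:
            return prompts  # echo the formatted inputs back
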
eva/language/models/wrappers/from_registry.py
@@ -0,0 +1,54 @@
+ """Vision backbone helper class."""
+
+ from typing import Any, Callable, Dict, List
+
+ from torch import nn
+ from typing_extensions import override
+
+ from eva.core.models.wrappers import base
+ from eva.core.utils import factory
+ from eva.language.models.networks.registry import model_registry
+ from eva.language.models.typings import TextBatch
+
+
+ class ModelFromRegistry(base.BaseModel[TextBatch, List[str]]):
+     """Wrapper class for vision backbone models.
+
+     This class can be used by load backbones available in eva's
+     model registry by name. New backbones can be registered by using
+     the `@backbone_registry.register(model_name)` decorator.
+     """
+
+     def __init__(
+         self,
+         model_name: str,
+         model_kwargs: Dict[str, Any] | None = None,
+         model_extra_kwargs: Dict[str, Any] | None = None,
+         transforms: Callable | None = None,
+     ) -> None:
+         """Initializes the model.
+
+         Args:
+             model_name: The name of the model to load.
+             model_kwargs: The arguments used for instantiating the model.
+             model_extra_kwargs: Extra arguments used for instantiating the model.
+             transforms: The transforms to apply to the output tensor
+                 produced by the model.
+         """
+         super().__init__(transforms=transforms)
+
+         self._model_name = model_name
+         self._model_kwargs = model_kwargs or {}
+         self._model_extra_kwargs = model_extra_kwargs or {}
+
+         self.model = self.load_model()
+
+     @override
+     def load_model(self) -> nn.Module:
+         ModelFromRegistry.__name__ = self._model_name
+
+         return factory.ModuleFactory(
+             registry=model_registry,
+             name=self._model_name,
+             init_args=self._model_kwargs | self._model_extra_kwargs,
+         )
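
With the registry populated, `ModelFromRegistry` resolves a name to a configured wrapper, so models can be referenced as plain strings. A usage sketch tying this to the Qwen entry registered earlier in this diff (the kwargs are illustrative):

    from eva.language.models.wrappers.from_registry import ModelFromRegistry

    model = ModelFromRegistry(
        model_name="alibaba/qwen2-0-5b-instruct",
        model_kwargs={"system_prompt": "Answer with yes, no, or maybe."},
    )
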
eva/language/models/wrappers/huggingface.py
@@ -1,14 +1,16 @@
  """LLM wrapper for HuggingFace `transformers` models."""

- from typing import Any, Dict, List, Literal
+ from typing import Any, Callable, Dict, List, Literal

  from transformers.pipelines import pipeline
  from typing_extensions import override

- from eva.core.models.wrappers import base
+ from eva.language.models.typings import TextBatch
+ from eva.language.models.wrappers import base
+ from eva.language.utils.text import messages as message_utils


- class HuggingFaceTextModel(base.BaseModel[List[str], List[str]]):
+ class HuggingFaceModel(base.LanguageModel):
      """Wrapper class for loading HuggingFace `transformers` models using pipelines."""

      def __init__(
@@ -16,7 +18,9 @@ class HuggingFaceTextModel(base.BaseModel[List[str], List[str]]):
          model_name_or_path: str,
          task: Literal["text-generation"] = "text-generation",
          model_kwargs: Dict[str, Any] | None = None,
+         system_prompt: str | None = None,
          generation_kwargs: Dict[str, Any] | None = None,
+         chat_mode: bool = True,
      ) -> None:
          """Initializes the model.

@@ -26,27 +30,59 @@ class HuggingFaceTextModel(base.BaseModel[List[str], List[str]]):
                  model hub.
              task: The pipeline task. Defaults to "text-generation".
              model_kwargs: Additional arguments for configuring the pipeline.
+             system_prompt: System prompt to use.
              generation_kwargs: Additional generation parameters (temperature, max_length, etc.).
+             chat_mode: Whether the specified model expects chat style messages. If set to False
+                 the model is assumed to be a standard text completion model and will expect
+                 plain text string inputs.
          """
-         super().__init__()
+         super().__init__(system_prompt=system_prompt)

          self._model_name_or_path = model_name_or_path
          self._task = task
          self._model_kwargs = model_kwargs or {}
          self._generation_kwargs = generation_kwargs or {}
+         self._chat_mode = chat_mode

-         self.load_model()
+         self.model = self.load_model()

      @override
-     def load_model(self) -> None:
+     def load_model(self) -> Callable:
          """Loads the model as a Hugging Face pipeline."""
-         self._pipeline = pipeline(
+         return pipeline(
              task=self._task,
              model=self._model_name_or_path,
              trust_remote_code=True,
              **self._model_kwargs,
          )

+     @override
+     def format_inputs(self, batch: TextBatch) -> List[List[Dict[str, Any]]] | List[str]:
+         """Formats inputs for HuggingFace models.
+
+         Note: If multiple system messages are present, they will be combined
+         into a single message, given that many models only support a single
+         system prompt.
+
+         Args:
+             batch: A batch of text and image inputs.
+
+         Returns:
+             When in chat mode, returns a batch of message series following
+             OpenAI's API format {"role": "user", "content": "..."}, for non-chat
+             models returns a list of plain text strings.
+         """
+         message_batch, _, _ = TextBatch(*batch)
+         message_batch = message_utils.batch_insert_system_message(
+             message_batch, self.system_message
+         )
+         message_batch = list(map(message_utils.combine_system_messages, message_batch))
+
+         if self._chat_mode:
+             return list(map(message_utils.format_chat_message, message_batch))
+         else:
+             return list(map(message_utils.merge_message_contents, message_batch))
+
      @override
      def model_forward(self, prompts: List[str]) -> List[str]:
          """Generates text using the pipeline.
@@ -57,7 +93,7 @@ class HuggingFaceTextModel(base.BaseModel[List[str], List[str]]):
          Returns:
              The generated text as a string.
          """
-         outputs = self._pipeline(prompts, return_full_text=False, **self._generation_kwargs)
+         outputs = self.model(prompts, return_full_text=False, **self._generation_kwargs)
          if outputs is None:
              raise ValueError("Outputs from the model are None.")
          results = []
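
Direct instantiation mirrors what the registered Qwen network above does; a sketch with the same settings as the alibaba.py hunk:

    import torch

    from eva.language.models.wrappers.huggingface import HuggingFaceModel

    model = HuggingFaceModel(
        model_name_or_path="Qwen/Qwen2-0.5B-Instruct",
        model_kwargs={"torch_dtype": torch.bfloat16},
        generation_kwargs={"max_new_tokens": 512},
        system_prompt="Answer with yes, no, or maybe.",
        chat_mode=True,  # format inputs as OpenAI-style chat messages
    )
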
eva/language/models/wrappers/litellm.py
@@ -1,77 +1,112 @@
- """LLM wrapper for litellm models."""
+ """LiteLLM language model wrapper."""

+ import logging
  from typing import Any, Dict, List

- from litellm import batch_completion  # type: ignore
+ import backoff
+ import litellm
+ from litellm import batch_completion
+ from litellm.exceptions import (
+     APIConnectionError,
+     InternalServerError,
+     RateLimitError,
+     ServiceUnavailableError,
+     Timeout,
+ )
  from loguru import logger
  from typing_extensions import override

- from eva.core.models.wrappers import base
+ from eva.language.models.typings import TextBatch
+ from eva.language.models.wrappers import base
+ from eva.language.utils.text import messages as message_utils

+ RETRYABLE_ERRORS = (
+     RateLimitError,
+     Timeout,
+     InternalServerError,
+     APIConnectionError,
+     ServiceUnavailableError,
+ )

- class LiteLLMTextModel(base.BaseModel[List[str], List[str]]):
-     """Wrapper class for using litellm for chat-based text generation.

-     This wrapper uses litellm's `completion` function which accepts a list of
-     message dicts. The `forward` method converts a string prompt into a chat
-     message with a default "user" role, optionally prepends a system message,
-     and includes an API key if provided.
-     """
+ class LiteLLMModel(base.LanguageModel):
+     """Wrapper class for LiteLLM language models."""

      def __init__(
          self,
-         model_name_or_path: str,
+         model_name: str,
          model_kwargs: Dict[str, Any] | None = None,
-     ) -> None:
-         """Initializes the litellm chat model wrapper.
+         system_prompt: str | None = None,
+         log_level: int | None = logging.INFO,
+     ):
+         """Initialize the LiteLLM Wrapper.

          Args:
-             model_name_or_path: The model identifier (or name) for litellm
-                 (e.g.,"openai/gpt-4o" or "anthropic/claude-3-sonnet-20240229").
+             model_name: The name of the model to use.
              model_kwargs: Additional keyword arguments to pass during
                  generation (e.g., `temperature`, `max_tokens`).
+             system_prompt: The system prompt to use (optional).
+             log_level: Optional logging level for LiteLLM. Defaults to WARNING.
          """
-         super().__init__()
-         self._model_name_or_path = model_name_or_path
-         self._model_kwargs = model_kwargs or {}
-         self.load_model()
+         super().__init__(system_prompt=system_prompt)

-     @override
-     def load_model(self) -> None:
-         """Prepares the litellm model.
+         self.model_name = model_name
+         self.model_kwargs = model_kwargs or {}

-         Note:
-             litellm doesn't require an explicit loading step; models are called
-             directly during generation. This method exists for API consistency.
-         """
-         pass
+         litellm.suppress_debug_info = True
+
+         if log_level is not None:
+             logging.getLogger("LiteLLM").setLevel(log_level)

      @override
-     def model_forward(self, prompts: List[str]) -> List[str]:
-         """Generates text using litellm.
+     def format_inputs(self, batch: TextBatch) -> List[List[Dict[str, Any]]]:
+         """Formats inputs for LiteLLM.

          Args:
-             prompts: A list of prompts to be converted into a "user" message.
+             batch: A batch of text inputs.

          Returns:
-             A list of generated text responses. Failed generations will contain
-                 error messages instead of generated text.
+             A list of messages in the following format:
+             [
+                 {
+                     "role": ...
+                     "content": ...
+                 },
+                 ...
+             ]
          """
-         messages = [[{"role": "user", "content": prompt}] for prompt in prompts]
+         message_batch, _, _ = TextBatch(*batch)

-         responses = batch_completion(
-             model=self._model_name_or_path,
-             messages=messages,
-             **self._model_kwargs,
+         message_batch = message_utils.batch_insert_system_message(
+             message_batch, self.system_message
          )
+         message_batch = list(map(message_utils.combine_system_messages, message_batch))

-         results = []
-         for i, response in enumerate(responses):
-             if isinstance(response, Exception):
-                 error_msg = f"Error generating text for prompt {i}: {response}"
-                 logger.error(error_msg)
-                 raise RuntimeError(error_msg)
-             else:
-                 results.append(response["choices"][0]["message"]["content"])
+         return list(map(message_utils.format_chat_message, message_batch))

-         return results
+     @override
+     @backoff.on_exception(
+         backoff.expo,
+         RETRYABLE_ERRORS,
+         max_tries=20,
+         jitter=backoff.full_jitter,
+         on_backoff=lambda details: logger.warning(
+             f"Retrying due to {details.get('exception') or 'Unknown error'}"
+         ),
+     )
+     def model_forward(self, batch: List[List[Dict[str, Any]]]) -> List[str]:
+         """Generates output text through API calls via LiteLLM's batch completion functionality."""
+         outputs = batch_completion(model=self.model_name, messages=batch, **self.model_kwargs)
+         self._raise_exceptions(outputs)
+
+         return [
+             output["choices"][0]["message"]["content"]
+             for output in outputs
+             if output["choices"][0]["message"]["role"] == "assistant"
+         ]
+
+     def _raise_exceptions(self, outputs: list):
+         for output in outputs:
+             if isinstance(output, Exception):
+                 logger.error(f"Model {self.model_name} encountered an error: {output}")
+                 raise output
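
The retry policy deserves a note: `backoff.expo` with `full_jitter` spreads up to 20 attempts over an exponentially growing window, and only the transient `RETRYABLE_ERRORS` trigger a retry; other exceptions surface immediately through `_raise_exceptions`. A standalone sketch of the same pattern (hypothetical function, not from this package):

    import backoff
    from litellm.exceptions import RateLimitError, Timeout


    @backoff.on_exception(
        backoff.expo,
        (RateLimitError, Timeout),
        max_tries=5,
        jitter=backoff.full_jitter,
    )
    def call_api() -> str:
        # Transient failures raise; backoff re-invokes the function with
        # exponentially increasing, jittered sleeps between attempts.
        ...
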
eva/language/models/wrappers/vllm.py
@@ -1,6 +1,6 @@
  """LLM wrapper for vLLM models."""

- from typing import Any, Dict, List, Sequence
+ from typing import Any, Dict, List

  from loguru import logger
  from typing_extensions import override
@@ -11,17 +11,20 @@ try:
      from vllm.transformers_utils.tokenizer import AnyTokenizer  # type: ignore
  except ImportError as e:
      raise ImportError(
-         "vLLM is required for VLLMTextModel but not installed. "
+         "vLLM is required for VllmModel but not installed. "
          "vLLM must be installed manually as it requires CUDA and is not included in dependencies. "
          "Install with: pip install vllm "
          "Note: vLLM requires Linux with CUDA support for optimal performance. "
-         "For alternatives, consider using HuggingFaceTextModel or LiteLLMTextModel."
+         "For alternatives, consider using HuggingFaceModel or LiteLLMModel."
      ) from e

- from eva.core.models.wrappers import base
+ from eva.language.data.messages import MessageSeries
+ from eva.language.models.typings import TextBatch
+ from eva.language.models.wrappers import base
+ from eva.language.utils.text import messages as message_utils


- class VLLMTextModel(base.BaseModel):
+ class VllmModel(base.LanguageModel):
      """Wrapper class for using vLLM for text generation.

      This wrapper loads a vLLM model, sets up the tokenizer and sampling
@@ -34,6 +37,7 @@ class VLLMTextModel(base.BaseModel):
          self,
          model_name_or_path: str,
          model_kwargs: Dict[str, Any] | None = None,
+         system_prompt: str | None = None,
          generation_kwargs: Dict[str, Any] | None = None,
      ) -> None:
          """Initializes the vLLM model wrapper.
@@ -44,12 +48,13 @@ class VLLMTextModel(base.BaseModel):
              model_kwargs: Arguments required to initialize the vLLM model,
                  see [link](https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py)
                  for more information.
+             system_prompt: System prompt to use.
              generation_kwargs: Arguments required to generate the output,
                  need to align with the arguments of
                  [vllm.SamplingParams](https://github.com/vllm-project/vllm/blob/main/vllm/sampling_params.py).

          """
-         super().__init__()
+         super().__init__(system_prompt=system_prompt)
          self._model_name_or_path = model_name_or_path
          self._model_kwargs = model_kwargs or {}
          self._generation_kwargs = generation_kwargs or {}
@@ -71,11 +76,11 @@ class VLLMTextModel(base.BaseModel):
              raise RuntimeError("Model not initialized")
          self._llm_tokenizer = self._llm_model.get_tokenizer()

-     def _apply_chat_template(self, prompts: Sequence[str]) -> list[TokensPrompt]:
+     def _tokenize_messages(self, messages: List[MessageSeries]) -> List[TokensPrompt]:
          """Apply chat template to the messages.

          Args:
-             prompts: List of raw user strings.
+             messages: List of raw user strings.

          Returns:
              List of encoded messages.
@@ -90,7 +95,8 @@ class VLLMTextModel(base.BaseModel):
          if not hasattr(self._llm_tokenizer, "chat_template"):
              raise ValueError("Tokenizer does not have a chat template.")

-         chat_messages = [[{"role": "user", "content": p}] for p in prompts]
+         chat_messages = list(map(message_utils.format_chat_message, messages))
+
          encoded_messages = self._llm_tokenizer.apply_chat_template(
              chat_messages,  # type: ignore
              tokenize=True,
@@ -131,11 +137,30 @@ class VLLMTextModel(base.BaseModel):

          return result

-     def generate(self, prompts: List[str]) -> List[str]:
+     @override
+     def format_inputs(self, batch: TextBatch) -> List[TokensPrompt]:
+         """Formats inputs for vLLM models.
+
+         Args:
+             batch: A batch of text and image inputs.
+
+         Returns:
+             List of formatted prompts.
+         """
+         message_batch, _, _ = TextBatch(*batch)
+         message_batch = message_utils.batch_insert_system_message(
+             message_batch, self.system_message
+         )
+         message_batch = list(map(message_utils.combine_system_messages, message_batch))
+
+         return self._tokenize_messages(message_batch)
+
+     @override
+     def model_forward(self, batch: List[TokensPrompt]) -> List[str]:
          """Generates text for the given prompt using the vLLM model.

          Args:
-             prompts: A list of string prompts for generation.
+             batch: A list encoded / tokenized messages (TokensPrompt objects).

          Returns:
              The generated text response.
@@ -144,6 +169,5 @@ class VLLMTextModel(base.BaseModel):
          """
          if self._llm_model is None:
              raise RuntimeError("Model not initialized")
-         prompt_tokens = self._apply_chat_template(prompts)
-         outputs = self._llm_model.generate(prompt_tokens, SamplingParams(**self._generation_kwargs))
+         outputs = self._llm_model.generate(batch, SamplingParams(**self._generation_kwargs))
          return [output.outputs[0].text for output in outputs]
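
Usage follows the same pattern as the other wrappers, with `generation_kwargs` mapping onto `vllm.SamplingParams` fields. A sketch (model choice and parameters are illustrative; requires a CUDA-capable Linux host):

    from eva.language.models.wrappers.vllm import VllmModel

    model = VllmModel(
        model_name_or_path="Qwen/Qwen2-0.5B-Instruct",
        system_prompt="Answer with yes, no, or maybe.",
        generation_kwargs={"temperature": 0.0, "max_tokens": 512},
    )
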
eva/language/utils/__init__.py
@@ -1,5 +1,6 @@
  """Language utilities and helper functions."""

  from eva.language.utils.str_to_int_tensor import CastStrToIntTensor
+ from eva.language.utils.text.messages import format_chat_message

- __all__ = ["CastStrToIntTensor"]
+ __all__ = ["CastStrToIntTensor", "format_chat_message"]