kaiko-eva 0.3.2__py3-none-any.whl → 0.4.0__py3-none-any.whl
- eva/core/callbacks/config.py +4 -0
- eva/core/cli/setup.py +1 -1
- eva/core/data/dataloaders/__init__.py +1 -2
- eva/core/data/dataloaders/dataloader.py +3 -1
- eva/core/data/samplers/random.py +17 -10
- eva/core/interface/interface.py +21 -0
- eva/core/loggers/log/__init__.py +2 -1
- eva/core/loggers/log/table.py +73 -0
- eva/core/models/modules/module.py +2 -2
- eva/core/models/wrappers/base.py +2 -2
- eva/core/models/wrappers/from_function.py +3 -3
- eva/core/models/wrappers/from_torchhub.py +9 -7
- eva/core/models/wrappers/huggingface.py +4 -5
- eva/core/models/wrappers/onnx.py +5 -5
- eva/core/trainers/trainer.py +2 -0
- eva/language/__init__.py +2 -1
- eva/language/callbacks/__init__.py +5 -0
- eva/language/callbacks/writers/__init__.py +5 -0
- eva/language/callbacks/writers/prediction.py +176 -0
- eva/language/data/dataloaders/__init__.py +5 -0
- eva/language/data/dataloaders/collate_fn/__init__.py +5 -0
- eva/language/data/dataloaders/collate_fn/text.py +57 -0
- eva/language/data/datasets/__init__.py +3 -1
- eva/language/data/datasets/{language.py → base.py} +1 -1
- eva/language/data/datasets/classification/base.py +3 -43
- eva/language/data/datasets/classification/pubmedqa.py +36 -4
- eva/language/data/datasets/prediction.py +151 -0
- eva/language/data/datasets/schemas.py +18 -0
- eva/language/data/datasets/text.py +92 -0
- eva/language/data/datasets/typings.py +39 -0
- eva/language/data/messages.py +60 -0
- eva/language/models/__init__.py +15 -11
- eva/language/models/modules/__init__.py +2 -2
- eva/language/models/modules/language.py +93 -0
- eva/language/models/networks/__init__.py +12 -0
- eva/language/models/networks/alibaba.py +26 -0
- eva/language/models/networks/api/__init__.py +11 -0
- eva/language/models/networks/api/anthropic.py +34 -0
- eva/language/models/networks/registry.py +5 -0
- eva/language/models/typings.py +39 -0
- eva/language/models/wrappers/__init__.py +13 -5
- eva/language/models/wrappers/base.py +47 -0
- eva/language/models/wrappers/from_registry.py +54 -0
- eva/language/models/wrappers/huggingface.py +44 -8
- eva/language/models/wrappers/litellm.py +81 -46
- eva/language/models/wrappers/vllm.py +37 -13
- eva/language/utils/__init__.py +2 -1
- eva/language/utils/str_to_int_tensor.py +20 -12
- eva/language/utils/text/__init__.py +5 -0
- eva/language/utils/text/messages.py +113 -0
- eva/multimodal/__init__.py +6 -0
- eva/multimodal/callbacks/__init__.py +5 -0
- eva/multimodal/callbacks/writers/__init__.py +5 -0
- eva/multimodal/callbacks/writers/prediction.py +39 -0
- eva/multimodal/data/__init__.py +5 -0
- eva/multimodal/data/dataloaders/__init__.py +5 -0
- eva/multimodal/data/dataloaders/collate_fn/__init__.py +5 -0
- eva/multimodal/data/dataloaders/collate_fn/text_image.py +28 -0
- eva/multimodal/data/datasets/__init__.py +6 -0
- eva/multimodal/data/datasets/base.py +13 -0
- eva/multimodal/data/datasets/multiple_choice/__init__.py +5 -0
- eva/multimodal/data/datasets/multiple_choice/patch_camelyon.py +80 -0
- eva/multimodal/data/datasets/schemas.py +14 -0
- eva/multimodal/data/datasets/text_image.py +77 -0
- eva/multimodal/data/datasets/typings.py +27 -0
- eva/multimodal/models/__init__.py +8 -0
- eva/multimodal/models/modules/__init__.py +5 -0
- eva/multimodal/models/modules/vision_language.py +55 -0
- eva/multimodal/models/networks/__init__.py +14 -0
- eva/multimodal/models/networks/alibaba.py +39 -0
- eva/multimodal/models/networks/api/__init__.py +11 -0
- eva/multimodal/models/networks/api/anthropic.py +34 -0
- eva/multimodal/models/networks/others.py +47 -0
- eva/multimodal/models/networks/registry.py +5 -0
- eva/multimodal/models/typings.py +27 -0
- eva/multimodal/models/wrappers/__init__.py +13 -0
- eva/multimodal/models/wrappers/base.py +47 -0
- eva/multimodal/models/wrappers/from_registry.py +54 -0
- eva/multimodal/models/wrappers/huggingface.py +180 -0
- eva/multimodal/models/wrappers/litellm.py +56 -0
- eva/multimodal/utils/__init__.py +1 -0
- eva/multimodal/utils/image/__init__.py +5 -0
- eva/multimodal/utils/image/encode.py +28 -0
- eva/multimodal/utils/text/__init__.py +1 -0
- eva/multimodal/utils/text/messages.py +79 -0
- eva/vision/data/datasets/classification/patch_camelyon.py +8 -6
- eva/vision/data/transforms/__init__.py +2 -1
- eva/vision/data/transforms/spatial/__init__.py +2 -1
- eva/vision/data/transforms/spatial/functional/__init__.py +5 -0
- eva/vision/data/transforms/spatial/functional/resize.py +26 -0
- eva/vision/data/transforms/spatial/resize.py +62 -0
- eva/vision/models/wrappers/from_registry.py +6 -5
- eva/vision/models/wrappers/from_timm.py +6 -4
- {kaiko_eva-0.3.2.dist-info → kaiko_eva-0.4.0.dist-info}/METADATA +10 -2
- {kaiko_eva-0.3.2.dist-info → kaiko_eva-0.4.0.dist-info}/RECORD +98 -40
- eva/core/data/dataloaders/collate_fn/__init__.py +0 -5
- eva/core/data/dataloaders/collate_fn/collate.py +0 -24
- eva/language/models/modules/text.py +0 -85
- {kaiko_eva-0.3.2.dist-info → kaiko_eva-0.4.0.dist-info}/WHEEL +0 -0
- {kaiko_eva-0.3.2.dist-info → kaiko_eva-0.4.0.dist-info}/entry_points.txt +0 -0
- {kaiko_eva-0.3.2.dist-info → kaiko_eva-0.4.0.dist-info}/licenses/LICENSE +0 -0
eva/language/models/networks/alibaba.py
ADDED
@@ -0,0 +1,26 @@
+"""Models from Alibaba."""
+
+import torch
+
+from eva.language.models import wrappers
+from eva.language.models.networks.registry import model_registry
+
+
+@model_registry.register("alibaba/qwen2-0-5b-instruct")
+class Qwen205BInstruct(wrappers.HuggingFaceModel):
+    """Qwen2 0.5B Instruct model."""
+
+    def __init__(self, system_prompt: str | None = None, cache_dir: str | None = None):
+        """Initialize the model."""
+        super().__init__(
+            model_name_or_path="Qwen/Qwen2-0.5B-Instruct",
+            model_kwargs={
+                "torch_dtype": torch.bfloat16,
+                "cache_dir": cache_dir,
+            },
+            generation_kwargs={
+                "max_new_tokens": 512,
+            },
+            system_prompt=system_prompt,
+            chat_mode=True,
+        )
eva/language/models/networks/api/anthropic.py
ADDED
@@ -0,0 +1,34 @@
+"""Models from Anthropic."""
+
+import os
+
+from eva.language.models import wrappers
+from eva.language.models.networks.registry import model_registry
+
+
+class _Claude(wrappers.LiteLLMModel):
+    """Base class for Claude models."""
+
+    def __init__(self, model_name: str, system_prompt: str | None = None):
+        if not os.getenv("ANTHROPIC_API_KEY"):
+            raise ValueError("ANTHROPIC_API_KEY env variable must be set.")
+
+        super().__init__(model_name=model_name, system_prompt=system_prompt)
+
+
+@model_registry.register("anthropic/claude-3-5-sonnet-20240620")
+class Claude35Sonnet20240620(_Claude):
+    """Claude 3.5 Sonnet (June 2024) model."""
+
+    def __init__(self, system_prompt: str | None = None):
+        """Initialize the model."""
+        super().__init__(model_name="claude-3-5-sonnet-20240620", system_prompt=system_prompt)
+
+
+@model_registry.register("anthropic/claude-3-7-sonnet-20250219")
+class Claude37Sonnet20250219(_Claude):
+    """Claude 3.7 Sonnet (February 2025) model."""
+
+    def __init__(self, system_prompt: str | None = None):
+        """Initialize the model."""
+        super().__init__(model_name="claude-3-7-sonnet-20250219", system_prompt=system_prompt)
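The registrations above only bind names in the registry; instantiation requires the Anthropic API key, otherwise `__init__` raises a ValueError. A hypothetical usage sketch (the placeholder key value is an assumption, not part of the diff):

import os

from eva.language.models.networks.api.anthropic import Claude35Sonnet20240620

os.environ["ANTHROPIC_API_KEY"] = "sk-ant-..."  # placeholder; a real key is required
model = Claude35Sonnet20240620(system_prompt="Answer with yes, no or maybe.")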
eva/language/models/typings.py
ADDED
@@ -0,0 +1,39 @@
+"""Type definitions for language models."""
+
+from typing import Any, Dict, Generic, List, TypeVar
+
+from typing_extensions import NamedTuple
+
+from eva.language.data.messages import MessageSeries
+
+TargetType = TypeVar("TargetType")
+"""The target data type."""
+
+
+class TextBatch(NamedTuple, Generic[TargetType]):
+    """Text sample with target and metadata."""
+
+    text: List[MessageSeries]
+    """Text content."""
+
+    target: TargetType | None
+    """Target data."""
+
+    metadata: Dict[str, Any] | None
+    """Additional metadata."""
+
+
+class PredictionBatch(NamedTuple, Generic[TargetType]):
+    """Text sample with target and metadata."""
+
+    prediction: TargetType
+    """Prediction data."""
+
+    target: TargetType
+    """Target data."""
+
+    text: List[MessageSeries] | None
+    """Conversation messages that were used as input."""
+
+    metadata: Dict[str, Any] | None
+    """Additional metadata."""
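Both named tuples are consumed positionally by the wrappers later in this diff (`message_batch, _, _ = TextBatch(*batch)`). A minimal sketch of that unpacking, using a stand-in message type (the concrete MessageSeries lives in eva/language/data/messages.py, which is not shown here):

from typing import Any, Dict, List

Message = Dict[str, Any]  # stand-in; the real type comes from eva.language.data.messages

sample: List[Message] = [{"role": "user", "content": "Does TP53 suppress tumours?"}]
batch = ([sample], None, None)  # (text, target, metadata), mirroring TextBatch

text, target, metadata = batch  # NamedTuple fields unpack positionally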
eva/language/models/wrappers/__init__.py
CHANGED
@@ -1,11 +1,19 @@
 """Language Model Wrappers API."""
 
-from eva.language.models.wrappers.huggingface import HuggingFaceTextModel
-from eva.language.models.wrappers.litellm import LiteLLMTextModel
+from eva.language.models.wrappers.base import LanguageModel
+from eva.language.models.wrappers.from_registry import ModelFromRegistry
+from eva.language.models.wrappers.huggingface import HuggingFaceModel
+from eva.language.models.wrappers.litellm import LiteLLMModel
 
 try:
-    from eva.language.models.wrappers.vllm import VLLMTextModel
+    from eva.language.models.wrappers.vllm import VllmModel
 
-    __all__ = ["HuggingFaceTextModel", "LiteLLMTextModel", "VLLMTextModel"]
+    __all__ = [
+        "LanguageModel",
+        "HuggingFaceModel",
+        "LiteLLMModel",
+        "VllmModel",
+        "ModelFromRegistry",
+    ]
 except ImportError:
-    __all__ = ["HuggingFaceTextModel", "LiteLLMTextModel"]
+    __all__ = ["LanguageModel", "HuggingFaceModel", "LiteLLMModel", "ModelFromRegistry"]
eva/language/models/wrappers/base.py
ADDED
@@ -0,0 +1,47 @@
+"""Base class for language model wrappers."""
+
+import abc
+from typing import Any, Callable, List
+
+from typing_extensions import override
+
+from eva.core.models.wrappers import base
+from eva.language.data.messages import ModelSystemMessage
+from eva.language.models.typings import TextBatch
+
+
+class LanguageModel(base.BaseModel[TextBatch, List[str]]):
+    """Base class for language models.
+
+    Classes that inherit from this should implement the following methods:
+    - `load_model`: Loads & instantiates the model.
+    - `model_forward`: Implements the forward pass of the model. For API models,
+        this can be an API call.
+    - `format_inputs`: Preprocesses and converts the input batch into the format
+        expected by the `model_forward` method.
+    """
+
+    def __init__(
+        self, system_prompt: str | None, output_transforms: Callable | None = None
+    ) -> None:
+        """Creates a new model instance.
+
+        Args:
+            system_prompt: The system prompt to use for the model (optional).
+            output_transforms: Optional transforms to apply to the output of
+                the model's forward pass.
+        """
+        super().__init__(transforms=output_transforms)
+
+        self.system_message = ModelSystemMessage(content=system_prompt) if system_prompt else None
+
+    @override
+    def forward(self, batch: TextBatch) -> List[str]:
+        """Forward pass of the model."""
+        inputs = self.format_inputs(batch)
+        return super().forward(inputs)
+
+    @abc.abstractmethod
+    def format_inputs(self, batch: TextBatch) -> Any:
+        """Converts the inputs into the format expected by the model."""
+        raise NotImplementedError
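To make the contract concrete, here is a minimal sketch of a conforming subclass, assuming only what the docstring above states; the `.content` attribute on message objects is an assumption inferred from `ModelSystemMessage(content=...)`:

from typing import List

from eva.language.models.wrappers import base


class EchoModel(base.LanguageModel):
    """Toy subclass: echoes each sample's last message (illustration only)."""

    def __init__(self) -> None:
        super().__init__(system_prompt=None)

    def load_model(self) -> None:
        return None  # nothing to load for a toy model

    def format_inputs(self, batch) -> List[str]:
        message_batch, _, _ = batch  # TextBatch unpacks positionally
        return [series[-1].content for series in message_batch]  # assumes a .content attribute

    def model_forward(self, prompts: List[str]) -> List[str]:
        return prompts  # a real model would generate completions here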
eva/language/models/wrappers/from_registry.py
ADDED
@@ -0,0 +1,54 @@
+"""Vision backbone helper class."""
+
+from typing import Any, Callable, Dict, List
+
+from torch import nn
+from typing_extensions import override
+
+from eva.core.models.wrappers import base
+from eva.core.utils import factory
+from eva.language.models.networks.registry import model_registry
+from eva.language.models.typings import TextBatch
+
+
+class ModelFromRegistry(base.BaseModel[TextBatch, List[str]]):
+    """Wrapper class for vision backbone models.
+
+    This class can be used by load backbones available in eva's
+    model registry by name. New backbones can be registered by using
+    the `@backbone_registry.register(model_name)` decorator.
+    """
+
+    def __init__(
+        self,
+        model_name: str,
+        model_kwargs: Dict[str, Any] | None = None,
+        model_extra_kwargs: Dict[str, Any] | None = None,
+        transforms: Callable | None = None,
+    ) -> None:
+        """Initializes the model.
+
+        Args:
+            model_name: The name of the model to load.
+            model_kwargs: The arguments used for instantiating the model.
+            model_extra_kwargs: Extra arguments used for instantiating the model.
+            transforms: The transforms to apply to the output tensor
+                produced by the model.
+        """
+        super().__init__(transforms=transforms)
+
+        self._model_name = model_name
+        self._model_kwargs = model_kwargs or {}
+        self._model_extra_kwargs = model_extra_kwargs or {}
+
+        self.model = self.load_model()
+
+    @override
+    def load_model(self) -> nn.Module:
+        ModelFromRegistry.__name__ = self._model_name
+
+        return factory.ModuleFactory(
+            registry=model_registry,
+            name=self._model_name,
+            init_args=self._model_kwargs | self._model_extra_kwargs,
+        )
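A hypothetical usage sketch tying this wrapper to the registry entries added earlier in this release; the kwargs are forwarded to the registered class (here Qwen205BInstruct, which accepts system_prompt and cache_dir):

from eva.language.models.wrappers import ModelFromRegistry

model = ModelFromRegistry(
    model_name="alibaba/qwen2-0-5b-instruct",
    model_kwargs={"system_prompt": "Answer with yes, no or maybe."},
)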
eva/language/models/wrappers/huggingface.py
CHANGED
@@ -1,14 +1,16 @@
 """LLM wrapper for HuggingFace `transformers` models."""
 
-from typing import Any, Dict, List, Literal
+from typing import Any, Callable, Dict, List, Literal
 
 from transformers.pipelines import pipeline
 from typing_extensions import override
 
-from eva.core.models.wrappers import base
+from eva.language.models.typings import TextBatch
+from eva.language.models.wrappers import base
+from eva.language.utils.text import messages as message_utils
 
 
-class HuggingFaceTextModel(base.BaseModel[List[str], List[str]]):
+class HuggingFaceModel(base.LanguageModel):
     """Wrapper class for loading HuggingFace `transformers` models using pipelines."""
 
     def __init__(
@@ -16,7 +18,9 @@ class HuggingFaceTextModel(base.BaseModel[List[str], List[str]]):
         model_name_or_path: str,
         task: Literal["text-generation"] = "text-generation",
         model_kwargs: Dict[str, Any] | None = None,
+        system_prompt: str | None = None,
         generation_kwargs: Dict[str, Any] | None = None,
+        chat_mode: bool = True,
     ) -> None:
         """Initializes the model.
 
@@ -26,27 +30,59 @@ class HuggingFaceTextModel(base.BaseModel[List[str], List[str]]):
                 model hub.
             task: The pipeline task. Defaults to "text-generation".
             model_kwargs: Additional arguments for configuring the pipeline.
+            system_prompt: System prompt to use.
             generation_kwargs: Additional generation parameters (temperature, max_length, etc.).
+            chat_mode: Whether the specified model expects chat style messages. If set to False
+                the model is assumed to be a standard text completion model and will expect
+                plain text string inputs.
         """
-        super().__init__()
+        super().__init__(system_prompt=system_prompt)
 
         self._model_name_or_path = model_name_or_path
         self._task = task
         self._model_kwargs = model_kwargs or {}
         self._generation_kwargs = generation_kwargs or {}
+        self._chat_mode = chat_mode
 
-        self.load_model()
+        self.model = self.load_model()
 
     @override
-    def load_model(self) -> None:
+    def load_model(self) -> Callable:
         """Loads the model as a Hugging Face pipeline."""
-        self._pipeline = pipeline(
+        return pipeline(
             task=self._task,
             model=self._model_name_or_path,
             trust_remote_code=True,
            **self._model_kwargs,
         )
 
+    @override
+    def format_inputs(self, batch: TextBatch) -> List[List[Dict[str, Any]]] | List[str]:
+        """Formats inputs for HuggingFace models.
+
+        Note: If multiple system messages are present, they will be combined
+        into a single message, given that many models only support a single
+        system prompt.
+
+        Args:
+            batch: A batch of text and image inputs.
+
+        Returns:
+            When in chat mode, returns a batch of message series following
+            OpenAI's API format {"role": "user", "content": "..."}, for non-chat
+            models returns a list of plain text strings.
+        """
+        message_batch, _, _ = TextBatch(*batch)
+        message_batch = message_utils.batch_insert_system_message(
+            message_batch, self.system_message
+        )
+        message_batch = list(map(message_utils.combine_system_messages, message_batch))
+
+        if self._chat_mode:
+            return list(map(message_utils.format_chat_message, message_batch))
+        else:
+            return list(map(message_utils.merge_message_contents, message_batch))
+
     @override
     def model_forward(self, prompts: List[str]) -> List[str]:
         """Generates text using the pipeline.
@@ -57,7 +93,7 @@ class HuggingFaceTextModel(base.BaseModel[List[str], List[str]]):
         Returns:
             The generated text as a string.
         """
-        outputs = self._pipeline(prompts, return_full_text=False, **self._generation_kwargs)
+        outputs = self.model(prompts, return_full_text=False, **self._generation_kwargs)
         if outputs is None:
             raise ValueError("Outputs from the model are None.")
         results = []
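For reference, these are the two input shapes `format_inputs` can return, per its docstring above (the example content is illustrative):

chat_inputs = [  # chat_mode=True: OpenAI-style message dicts, one series per sample
    [
        {"role": "system", "content": "You are a concise medical QA assistant."},
        {"role": "user", "content": "Is TP53 a tumour suppressor gene?"},
    ],
]
plain_inputs = ["Is TP53 a tumour suppressor gene?"]  # chat_mode=False: plain strings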
eva/language/models/wrappers/litellm.py
CHANGED
@@ -1,77 +1,112 @@
-"""…"""
+"""LiteLLM language model wrapper."""
 
+import logging
 from typing import Any, Dict, List
 
-…
+import backoff
+import litellm
+from litellm import batch_completion
+from litellm.exceptions import (
+    APIConnectionError,
+    InternalServerError,
+    RateLimitError,
+    ServiceUnavailableError,
+    Timeout,
+)
 from loguru import logger
 from typing_extensions import override
 
-from eva.core.models.wrappers import base
+from eva.language.models.typings import TextBatch
+from eva.language.models.wrappers import base
+from eva.language.utils.text import messages as message_utils
 
+RETRYABLE_ERRORS = (
+    RateLimitError,
+    Timeout,
+    InternalServerError,
+    APIConnectionError,
+    ServiceUnavailableError,
+)
 
-class LiteLLMTextModel(base.BaseModel[List[str], List[str]]):
-    """Wrapper class for using litellm for chat-based text generation.
-
-    …
-    message with a default "user" role, optionally prepends a system message,
-    and includes an API key if provided.
-    """
+class LiteLLMModel(base.LanguageModel):
+    """Wrapper class for LiteLLM language models."""
 
     def __init__(
         self,
-        model_name_or_path: str,
+        model_name: str,
         model_kwargs: Dict[str, Any] | None = None,
-        …
+        system_prompt: str | None = None,
+        log_level: int | None = logging.INFO,
+    ):
+        """Initialize the LiteLLM Wrapper.
 
         Args:
-            model_name_or_path: …
-                (e.g.,"openai/gpt-4o" or "anthropic/claude-3-sonnet-20240229").
+            model_name: The name of the model to use.
             model_kwargs: Additional keyword arguments to pass during
                 generation (e.g., `temperature`, `max_tokens`).
+            system_prompt: The system prompt to use (optional).
+            log_level: Optional logging level for LiteLLM. Defaults to WARNING.
         """
-        super().__init__()
-        self._model_name_or_path = model_name_or_path
-        self._model_kwargs = model_kwargs or {}
-        self.load_model()
+        super().__init__(system_prompt=system_prompt)
 
-    …
-        """Prepares the litellm model.
+        self.model_name = model_name
+        self.model_kwargs = model_kwargs or {}
 
-        …
-        pass
+        litellm.suppress_debug_info = True
+
+        if log_level is not None:
+            logging.getLogger("LiteLLM").setLevel(log_level)
 
     @override
-    def …
-        """…
+    def format_inputs(self, batch: TextBatch) -> List[List[Dict[str, Any]]]:
+        """Formats inputs for LiteLLM.
 
         Args:
-            …
+            batch: A batch of text inputs.
 
         Returns:
-            A list of …
+            A list of messages in the following format:
+            [
+                {
+                    "role": ...
+                    "content": ...
+                },
+                ...
+            ]
         """
-        …
+        message_batch, _, _ = TextBatch(*batch)
 
-        …
-            messages=messages,
-            **self._model_kwargs,
+        message_batch = message_utils.batch_insert_system_message(
+            message_batch, self.system_message
         )
+        message_batch = list(map(message_utils.combine_system_messages, message_batch))
 
-        …
-        for i, response in enumerate(responses):
-            if isinstance(response, Exception):
-                error_msg = f"Error generating text for prompt {i}: {response}"
-                logger.error(error_msg)
-                raise RuntimeError(error_msg)
-            else:
-                results.append(response["choices"][0]["message"]["content"])
+        return list(map(message_utils.format_chat_message, message_batch))
 
+    @override
+    @backoff.on_exception(
+        backoff.expo,
+        RETRYABLE_ERRORS,
+        max_tries=20,
+        jitter=backoff.full_jitter,
+        on_backoff=lambda details: logger.warning(
+            f"Retrying due to {details.get('exception') or 'Unknown error'}"
+        ),
+    )
+    def model_forward(self, batch: List[List[Dict[str, Any]]]) -> List[str]:
+        """Generates output text through API calls via LiteLLM's batch completion functionality."""
+        outputs = batch_completion(model=self.model_name, messages=batch, **self.model_kwargs)
+        self._raise_exceptions(outputs)
+
+        return [
+            output["choices"][0]["message"]["content"]
+            for output in outputs
+            if output["choices"][0]["message"]["role"] == "assistant"
+        ]
+
+    def _raise_exceptions(self, outputs: list):
+        for output in outputs:
+            if isinstance(output, Exception):
+                logger.error(f"Model {self.model_name} encountered an error: {output}")
+                raise output

(Removed lines marked "…" were truncated by the diff viewer and could not be recovered.)
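The `backoff` decorator above retries the listed transient API errors with exponential backoff and full jitter, up to 20 attempts, logging a warning on each retry. A hypothetical instantiation (the model string and kwargs are illustrative; LiteLLM routes requests by provider prefix):

from eva.language.models.wrappers import LiteLLMModel

model = LiteLLMModel(
    model_name="openai/gpt-4o",  # illustrative; any LiteLLM-supported model string
    model_kwargs={"temperature": 0.0, "max_tokens": 256},
    system_prompt="Answer with yes, no or maybe.",
)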
eva/language/models/wrappers/vllm.py
CHANGED
@@ -1,6 +1,6 @@
 """LLM wrapper for vLLM models."""
 
-from typing import Any, Dict, List
+from typing import Any, Dict, List
 
 from loguru import logger
 from typing_extensions import override
@@ -11,17 +11,20 @@ try:
     from vllm.transformers_utils.tokenizer import AnyTokenizer  # type: ignore
 except ImportError as e:
     raise ImportError(
-        "vLLM is required for VLLMTextModel but not installed. "
+        "vLLM is required for VllmModel but not installed. "
         "vLLM must be installed manually as it requires CUDA and is not included in dependencies. "
         "Install with: pip install vllm "
         "Note: vLLM requires Linux with CUDA support for optimal performance. "
-        "For alternatives, consider using HuggingFaceTextModel or LiteLLMTextModel."
+        "For alternatives, consider using HuggingFaceModel or LiteLLMModel."
     ) from e
 
-from eva.core.models.wrappers import base
+from eva.language.data.messages import MessageSeries
+from eva.language.models.typings import TextBatch
+from eva.language.models.wrappers import base
+from eva.language.utils.text import messages as message_utils
 
 
-class VLLMTextModel(base.BaseModel):
+class VllmModel(base.LanguageModel):
     """Wrapper class for using vLLM for text generation.
 
     This wrapper loads a vLLM model, sets up the tokenizer and sampling
@@ -34,6 +37,7 @@ class VLLMTextModel(base.BaseModel):
         self,
         model_name_or_path: str,
         model_kwargs: Dict[str, Any] | None = None,
+        system_prompt: str | None = None,
         generation_kwargs: Dict[str, Any] | None = None,
     ) -> None:
         """Initializes the vLLM model wrapper.
@@ -44,12 +48,13 @@ class VLLMTextModel(base.BaseModel):
             model_kwargs: Arguments required to initialize the vLLM model,
                 see [link](https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py)
                 for more information.
+            system_prompt: System prompt to use.
             generation_kwargs: Arguments required to generate the output,
                 need to align with the arguments of
                 [vllm.SamplingParams](https://github.com/vllm-project/vllm/blob/main/vllm/sampling_params.py).
 
         """
-        super().__init__()
+        super().__init__(system_prompt=system_prompt)
         self._model_name_or_path = model_name_or_path
         self._model_kwargs = model_kwargs or {}
         self._generation_kwargs = generation_kwargs or {}
@@ -71,11 +76,11 @@ class VLLMTextModel(base.BaseModel):
             raise RuntimeError("Model not initialized")
         self._llm_tokenizer = self._llm_model.get_tokenizer()
 
-    def …
+    def _tokenize_messages(self, messages: List[MessageSeries]) -> List[TokensPrompt]:
         """Apply chat template to the messages.
 
         Args:
-            …
+            messages: List of raw user strings.
 
         Returns:
             List of encoded messages.
@@ -90,7 +95,8 @@ class VLLMTextModel(base.BaseModel):
         if not hasattr(self._llm_tokenizer, "chat_template"):
             raise ValueError("Tokenizer does not have a chat template.")
 
-        chat_messages = …
+        chat_messages = list(map(message_utils.format_chat_message, messages))
+
         encoded_messages = self._llm_tokenizer.apply_chat_template(
             chat_messages,  # type: ignore
             tokenize=True,
@@ -131,11 +137,30 @@ class VLLMTextModel(base.BaseModel):
 
         return result
 
-    def …
+    @override
+    def format_inputs(self, batch: TextBatch) -> List[TokensPrompt]:
+        """Formats inputs for vLLM models.
+
+        Args:
+            batch: A batch of text and image inputs.
+
+        Returns:
+            List of formatted prompts.
+        """
+        message_batch, _, _ = TextBatch(*batch)
+        message_batch = message_utils.batch_insert_system_message(
+            message_batch, self.system_message
+        )
+        message_batch = list(map(message_utils.combine_system_messages, message_batch))
+
+        return self._tokenize_messages(message_batch)
+
+    @override
+    def model_forward(self, batch: List[TokensPrompt]) -> List[str]:
         """Generates text for the given prompt using the vLLM model.
 
         Args:
-            …
+            batch: A list encoded / tokenized messages (TokensPrompt objects).
 
         Returns:
             The generated text response.
@@ -144,6 +169,5 @@ class VLLMTextModel(base.BaseModel):
         if self._llm_model is None:
             raise RuntimeError("Model not initialized")
 
-        …
-        outputs = self._llm_model.generate(prompt_tokens, SamplingParams(**self._generation_kwargs))
+        outputs = self._llm_model.generate(batch, SamplingParams(**self._generation_kwargs))
         return [output.outputs[0].text for output in outputs]

(Removed lines marked "…" were truncated by the diff viewer and could not be recovered.)
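A hypothetical instantiation, assuming vLLM is installed as the ImportError message above describes; `generation_kwargs` must match vllm.SamplingParams fields:

from eva.language.models.wrappers import VllmModel

model = VllmModel(
    model_name_or_path="Qwen/Qwen2-0.5B-Instruct",  # illustrative checkpoint
    generation_kwargs={"max_tokens": 256, "temperature": 0.0},
)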
eva/language/utils/__init__.py
CHANGED
@@ -1,5 +1,6 @@
 """Language utilities and helper functions."""
 
 from eva.language.utils.str_to_int_tensor import CastStrToIntTensor
+from eva.language.utils.text.messages import format_chat_message
 
-__all__ = ["CastStrToIntTensor"]
+__all__ = ["CastStrToIntTensor", "format_chat_message"]