kaiko-eva 0.3.3__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of kaiko-eva might be problematic.
- eva/core/callbacks/config.py +15 -6
- eva/core/callbacks/writers/embeddings/base.py +44 -10
- eva/core/cli/setup.py +1 -1
- eva/core/data/dataloaders/__init__.py +1 -2
- eva/core/data/samplers/classification/balanced.py +24 -12
- eva/core/data/samplers/random.py +17 -10
- eva/core/interface/interface.py +21 -0
- eva/core/loggers/utils/wandb.py +4 -1
- eva/core/models/modules/module.py +2 -2
- eva/core/models/wrappers/base.py +2 -2
- eva/core/models/wrappers/from_function.py +3 -3
- eva/core/models/wrappers/from_torchhub.py +9 -7
- eva/core/models/wrappers/huggingface.py +4 -5
- eva/core/models/wrappers/onnx.py +5 -5
- eva/core/trainers/trainer.py +13 -1
- eva/core/utils/__init__.py +2 -1
- eva/core/utils/distributed.py +12 -0
- eva/core/utils/paths.py +14 -0
- eva/core/utils/requirements.py +52 -6
- eva/language/__init__.py +2 -1
- eva/language/callbacks/__init__.py +5 -0
- eva/language/callbacks/writers/__init__.py +5 -0
- eva/language/callbacks/writers/prediction.py +201 -0
- eva/language/data/dataloaders/__init__.py +5 -0
- eva/language/data/dataloaders/collate_fn/__init__.py +5 -0
- eva/language/data/dataloaders/collate_fn/text.py +57 -0
- eva/language/data/datasets/__init__.py +3 -1
- eva/language/data/datasets/{language.py → base.py} +1 -1
- eva/language/data/datasets/classification/base.py +3 -43
- eva/language/data/datasets/classification/pubmedqa.py +36 -4
- eva/language/data/datasets/prediction.py +151 -0
- eva/language/data/datasets/schemas.py +18 -0
- eva/language/data/datasets/text.py +92 -0
- eva/language/data/datasets/typings.py +39 -0
- eva/language/data/messages.py +60 -0
- eva/language/models/__init__.py +15 -11
- eva/language/models/modules/__init__.py +2 -2
- eva/language/models/modules/language.py +94 -0
- eva/language/models/networks/__init__.py +12 -0
- eva/language/models/networks/alibaba.py +26 -0
- eva/language/models/networks/api/__init__.py +11 -0
- eva/language/models/networks/api/anthropic.py +34 -0
- eva/language/models/networks/registry.py +5 -0
- eva/language/models/typings.py +56 -0
- eva/language/models/wrappers/__init__.py +13 -5
- eva/language/models/wrappers/base.py +47 -0
- eva/language/models/wrappers/from_registry.py +54 -0
- eva/language/models/wrappers/huggingface.py +57 -11
- eva/language/models/wrappers/litellm.py +91 -46
- eva/language/models/wrappers/vllm.py +37 -13
- eva/language/utils/__init__.py +2 -1
- eva/language/utils/str_to_int_tensor.py +20 -12
- eva/language/utils/text/__init__.py +5 -0
- eva/language/utils/text/messages.py +113 -0
- eva/multimodal/__init__.py +6 -0
- eva/multimodal/callbacks/__init__.py +5 -0
- eva/multimodal/callbacks/writers/__init__.py +5 -0
- eva/multimodal/callbacks/writers/prediction.py +39 -0
- eva/multimodal/data/__init__.py +5 -0
- eva/multimodal/data/dataloaders/__init__.py +5 -0
- eva/multimodal/data/dataloaders/collate_fn/__init__.py +5 -0
- eva/multimodal/data/dataloaders/collate_fn/text_image.py +28 -0
- eva/multimodal/data/datasets/__init__.py +6 -0
- eva/multimodal/data/datasets/base.py +13 -0
- eva/multimodal/data/datasets/multiple_choice/__init__.py +5 -0
- eva/multimodal/data/datasets/multiple_choice/patch_camelyon.py +80 -0
- eva/multimodal/data/datasets/schemas.py +14 -0
- eva/multimodal/data/datasets/text_image.py +77 -0
- eva/multimodal/data/datasets/typings.py +27 -0
- eva/multimodal/models/__init__.py +8 -0
- eva/multimodal/models/modules/__init__.py +5 -0
- eva/multimodal/models/modules/vision_language.py +56 -0
- eva/multimodal/models/networks/__init__.py +14 -0
- eva/multimodal/models/networks/alibaba.py +40 -0
- eva/multimodal/models/networks/api/__init__.py +11 -0
- eva/multimodal/models/networks/api/anthropic.py +34 -0
- eva/multimodal/models/networks/others.py +48 -0
- eva/multimodal/models/networks/registry.py +5 -0
- eva/multimodal/models/typings.py +27 -0
- eva/multimodal/models/wrappers/__init__.py +13 -0
- eva/multimodal/models/wrappers/base.py +48 -0
- eva/multimodal/models/wrappers/from_registry.py +54 -0
- eva/multimodal/models/wrappers/huggingface.py +193 -0
- eva/multimodal/models/wrappers/litellm.py +58 -0
- eva/multimodal/utils/__init__.py +1 -0
- eva/multimodal/utils/batch/__init__.py +5 -0
- eva/multimodal/utils/batch/unpack.py +11 -0
- eva/multimodal/utils/image/__init__.py +5 -0
- eva/multimodal/utils/image/encode.py +28 -0
- eva/multimodal/utils/text/__init__.py +1 -0
- eva/multimodal/utils/text/messages.py +79 -0
- eva/vision/data/datasets/classification/breakhis.py +5 -8
- eva/vision/data/datasets/classification/panda.py +12 -5
- eva/vision/data/datasets/classification/patch_camelyon.py +8 -6
- eva/vision/data/datasets/segmentation/btcv.py +1 -1
- eva/vision/data/datasets/segmentation/consep.py +1 -1
- eva/vision/data/datasets/segmentation/lits17.py +1 -1
- eva/vision/data/datasets/segmentation/monusac.py +15 -6
- eva/vision/data/datasets/segmentation/msd_task7_pancreas.py +1 -1
- eva/vision/data/transforms/__init__.py +2 -1
- eva/vision/data/transforms/base/__init__.py +2 -1
- eva/vision/data/transforms/base/monai.py +2 -2
- eva/vision/data/transforms/base/torchvision.py +33 -0
- eva/vision/data/transforms/common/squeeze.py +6 -3
- eva/vision/data/transforms/croppad/crop_foreground.py +8 -7
- eva/vision/data/transforms/croppad/rand_crop_by_label_classes.py +6 -5
- eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py +6 -5
- eva/vision/data/transforms/croppad/rand_spatial_crop.py +8 -7
- eva/vision/data/transforms/croppad/spatial_pad.py +6 -6
- eva/vision/data/transforms/intensity/rand_scale_intensity.py +3 -3
- eva/vision/data/transforms/intensity/rand_shift_intensity.py +3 -3
- eva/vision/data/transforms/intensity/scale_intensity_ranged.py +5 -5
- eva/vision/data/transforms/spatial/__init__.py +2 -1
- eva/vision/data/transforms/spatial/flip.py +8 -7
- eva/vision/data/transforms/spatial/functional/__init__.py +5 -0
- eva/vision/data/transforms/spatial/functional/resize.py +26 -0
- eva/vision/data/transforms/spatial/resize.py +63 -0
- eva/vision/data/transforms/spatial/rotate.py +8 -7
- eva/vision/data/transforms/spatial/spacing.py +7 -6
- eva/vision/data/transforms/utility/ensure_channel_first.py +6 -6
- eva/vision/models/networks/backbones/universal/vit.py +24 -0
- eva/vision/models/wrappers/from_registry.py +6 -5
- eva/vision/models/wrappers/from_timm.py +6 -4
- {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/METADATA +17 -3
- {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/RECORD +128 -66
- eva/core/data/dataloaders/collate_fn/__init__.py +0 -5
- eva/core/data/dataloaders/collate_fn/collate.py +0 -24
- eva/language/models/modules/text.py +0 -85
- {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/WHEEL +0 -0
- {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/entry_points.txt +0 -0
- {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/licenses/LICENSE +0 -0

eva/language/models/wrappers/huggingface.py CHANGED

@@ -1,22 +1,34 @@
 """LLM wrapper for HuggingFace `transformers` models."""
 
-from typing import Any, Dict, List, Literal
+from typing import Any, Callable, Dict, List, Literal
 
 from transformers.pipelines import pipeline
 from typing_extensions import override
 
-from eva.
+from eva.language.models.typings import ModelOutput, TextBatch
+from eva.language.models.wrappers import base
+from eva.language.utils.text import messages as message_utils
 
 
-class HuggingFaceTextModel(base.BaseModel[List[str], List[str]]):
+class HuggingFaceModel(base.LanguageModel):
     """Wrapper class for loading HuggingFace `transformers` models using pipelines."""
 
+    _default_generation_kwargs = {
+        "temperature": 0.0,
+        "max_new_tokens": 1024,
+        "do_sample": False,
+        "top_p": 1.0,
+    }
+    """Default HF model parameters for evaluation."""
+
     def __init__(
         self,
         model_name_or_path: str,
         task: Literal["text-generation"] = "text-generation",
         model_kwargs: Dict[str, Any] | None = None,
+        system_prompt: str | None = None,
         generation_kwargs: Dict[str, Any] | None = None,
+        chat_mode: bool = True,
     ) -> None:
         """Initializes the model.
 
@@ -26,21 +38,26 @@ class HuggingFaceTextModel(base.BaseModel[List[str], List[str]]):
             model hub.
             task: The pipeline task. Defaults to "text-generation".
             model_kwargs: Additional arguments for configuring the pipeline.
+            system_prompt: System prompt to use.
             generation_kwargs: Additional generation parameters (temperature, max_length, etc.).
+            chat_mode: Whether the specified model expects chat style messages. If set to False
+                the model is assumed to be a standard text completion model and will expect
+                plain text string inputs.
         """
-        super().__init__()
+        super().__init__(system_prompt=system_prompt)
 
         self._model_name_or_path = model_name_or_path
         self._task = task
         self._model_kwargs = model_kwargs or {}
-        self._generation_kwargs = generation_kwargs or {}
+        self._generation_kwargs = self._default_generation_kwargs | (generation_kwargs or {})
+        self._chat_mode = chat_mode
 
-        self.load_model()
+        self.model = self.load_model()
 
     @override
-    def load_model(self) ->
+    def load_model(self) -> Callable:
         """Loads the model as a Hugging Face pipeline."""
-
+        return pipeline(
             task=self._task,
             model=self._model_name_or_path,
             trust_remote_code=True,
@@ -48,7 +65,34 @@ class HuggingFaceTextModel(base.BaseModel[List[str], List[str]]):
         )
 
     @override
-    def
+    def format_inputs(self, batch: TextBatch) -> List[List[Dict[str, Any]]] | List[str]:
+        """Formats inputs for HuggingFace models.
+
+        Note: If multiple system messages are present, they will be combined
+        into a single message, given that many models only support a single
+        system prompt.
+
+        Args:
+            batch: A batch of text and image inputs.
+
+        Returns:
+            When in chat mode, returns a batch of message series following
+            OpenAI's API format {"role": "user", "content": "..."}, for non-chat
+            models returns a list of plain text strings.
+        """
+        message_batch, _, _ = TextBatch(*batch)
+        message_batch = message_utils.batch_insert_system_message(
+            message_batch, self.system_message
+        )
+        message_batch = list(map(message_utils.combine_system_messages, message_batch))
+
+        if self._chat_mode:
+            return list(map(message_utils.format_chat_message, message_batch))
+        else:
+            return list(map(message_utils.merge_message_contents, message_batch))
+
+    @override
+    def model_forward(self, prompts: List[str]) -> ModelOutput:
         """Generates text using the pipeline.
 
         Args:
@@ -57,13 +101,15 @@ class HuggingFaceTextModel(base.BaseModel[List[str], List[str]]):
         Returns:
             The generated text as a string.
         """
-        outputs = self.
+        outputs = self.model(prompts, return_full_text=False, **self._generation_kwargs)
         if outputs is None:
             raise ValueError("Outputs from the model are None.")
+
         results = []
         for output in outputs:
             if isinstance(output, list):
                 results.append(output[0]["generated_text"])  # type: ignore
             else:
                 results.append(output["generated_text"])  # type: ignore
-
+
+        return ModelOutput(generated_text=results)
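
To illustrate the renamed wrapper, a minimal construction sketch based on the new signature above. The checkpoint name is a placeholder (not taken from this diff), and user-supplied generation_kwargs are merged over _default_generation_kwargs via the dict `|` operator:

from eva.language.models.wrappers.huggingface import HuggingFaceModel

# Placeholder checkpoint; any "text-generation" checkpoint should work here.
model = HuggingFaceModel(
    model_name_or_path="Qwen/Qwen2.5-0.5B-Instruct",
    system_prompt="Answer with yes, no or maybe.",
    generation_kwargs={"max_new_tokens": 64},  # overrides the 1024-token default
    chat_mode=True,  # format inputs as OpenAI-style chat messages
)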

eva/language/models/wrappers/litellm.py CHANGED

@@ -1,77 +1,122 @@
-"""
+"""LiteLLM language model wrapper."""
 
+import logging
 from typing import Any, Dict, List
 
-
+import backoff
+import litellm
+from litellm import batch_completion
+from litellm.exceptions import (
+    APIConnectionError,
+    InternalServerError,
+    RateLimitError,
+    ServiceUnavailableError,
+    Timeout,
+)
 from loguru import logger
 from typing_extensions import override
 
-from eva.
+from eva.language.models.typings import ModelOutput, TextBatch
+from eva.language.models.wrappers import base
+from eva.language.utils.text import messages as message_utils
 
+RETRYABLE_ERRORS = (
+    RateLimitError,
+    Timeout,
+    InternalServerError,
+    APIConnectionError,
+    ServiceUnavailableError,
+)
 
-class LiteLLMTextModel(base.BaseModel[List[str], List[str]]):
-    """Wrapper class for using litellm for chat-based text generation.
-
-
-
-
-
+class LiteLLMModel(base.LanguageModel):
+    """Wrapper class for LiteLLM language models."""
+
+    _default_model_kwargs = {
+        "temperature": 0.0,
+        "max_completion_tokens": 1024,
+        "top_p": 1.0,
+        "seed": 42,
+    }
+    """Default API model parameters for evaluation."""
 
     def __init__(
         self,
-
+        model_name: str,
        model_kwargs: Dict[str, Any] | None = None,
-
-
+        system_prompt: str | None = None,
+        log_level: int | None = logging.INFO,
+    ):
+        """Initialize the LiteLLM Wrapper.
 
         Args:
-
-            (e.g.,"openai/gpt-4o" or "anthropic/claude-3-sonnet-20240229").
+            model_name: The name of the model to use.
             model_kwargs: Additional keyword arguments to pass during
                 generation (e.g., `temperature`, `max_tokens`).
+            system_prompt: The system prompt to use (optional).
+            log_level: Optional logging level for LiteLLM. Defaults to WARNING.
         """
-        super().__init__()
-        self._model_name_or_path = model_name_or_path
-        self._model_kwargs = model_kwargs or {}
-        self.load_model()
+        super().__init__(system_prompt=system_prompt)
 
-
-
-        """Prepares the litellm model.
+        self.model_name = model_name
+        self.model_kwargs = self._default_model_kwargs | (model_kwargs or {})
 
-
-
-
-
-
+        litellm.suppress_debug_info = True
+        litellm.drop_params = True
+
+        if log_level is not None:
+            logging.getLogger("LiteLLM").setLevel(log_level)
 
     @override
-    def
-        """
+    def format_inputs(self, batch: TextBatch) -> List[List[Dict[str, Any]]]:
+        """Formats inputs for LiteLLM.
 
         Args:
-
+            batch: A batch of text inputs.
 
         Returns:
-            A list of
-
+            A list of messages in the following format:
+            [
+                {
+                    "role": ...
+                    "content": ...
+                },
+                ...
+            ]
         """
-
+        message_batch, _, _ = TextBatch(*batch)
 
-
-
-            messages=messages,
-            **self._model_kwargs,
+        message_batch = message_utils.batch_insert_system_message(
+            message_batch, self.system_message
         )
+        message_batch = list(map(message_utils.combine_system_messages, message_batch))
 
-
-        for i, response in enumerate(responses):
-            if isinstance(response, Exception):
-                error_msg = f"Error generating text for prompt {i}: {response}"
-                logger.error(error_msg)
-                raise RuntimeError(error_msg)
-            else:
-                results.append(response["choices"][0]["message"]["content"])
+        return list(map(message_utils.format_chat_message, message_batch))
 
-
+    @override
+    @backoff.on_exception(
+        backoff.expo,
+        RETRYABLE_ERRORS,
+        max_tries=20,
+        jitter=backoff.full_jitter,
+        on_backoff=lambda details: logger.warning(
+            f"Retrying due to {details.get('exception') or 'Unknown error'}"
+        ),
+    )
+    def model_forward(self, batch: List[List[Dict[str, Any]]]) -> ModelOutput:
+        """Generates output text through API calls via LiteLLM's batch completion functionality."""
+        outputs = batch_completion(model=self.model_name, messages=batch, **self.model_kwargs)
+        self._raise_exceptions(outputs)
+
+        generated_text = [
+            output["choices"][0]["message"]["content"]
+            for output in outputs
+            if output["choices"][0]["message"]["role"] == "assistant"
+        ]
+        return ModelOutput(generated_text=generated_text)
+
+    def _raise_exceptions(self, outputs: list):
+        for output in outputs:
+            if isinstance(output, Exception):
+                logger.error(f"Model {self.model_name} encountered an error: {output}")
+                raise output
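
A similar sketch for the rewritten LiteLLM wrapper, assuming valid provider credentials in the environment (the model name is illustrative). Note that model_forward now retries the RETRYABLE_ERRORS listed above with exponential backoff and full jitter, giving up after 20 tries:

from eva.language.models.wrappers.litellm import LiteLLMModel

# Requires e.g. OPENAI_API_KEY to be set; the model name is a placeholder.
model = LiteLLMModel(
    model_name="openai/gpt-4o",
    system_prompt="You are a concise assistant.",
    model_kwargs={"max_completion_tokens": 256},  # merged over _default_model_kwargs
)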

eva/language/models/wrappers/vllm.py CHANGED

@@ -1,6 +1,6 @@
 """LLM wrapper for vLLM models."""
 
-from typing import Any, Dict, List
+from typing import Any, Dict, List
 
 from loguru import logger
 from typing_extensions import override
@@ -11,17 +11,20 @@ try:
     from vllm.transformers_utils.tokenizer import AnyTokenizer  # type: ignore
 except ImportError as e:
     raise ImportError(
-        "vLLM is required for
+        "vLLM is required for VllmModel but not installed. "
         "vLLM must be installed manually as it requires CUDA and is not included in dependencies. "
         "Install with: pip install vllm "
         "Note: vLLM requires Linux with CUDA support for optimal performance. "
-        "For alternatives, consider using
+        "For alternatives, consider using HuggingFaceModel or LiteLLMModel."
     ) from e
 
-from eva.
+from eva.language.data.messages import MessageSeries
+from eva.language.models.typings import TextBatch
+from eva.language.models.wrappers import base
+from eva.language.utils.text import messages as message_utils
 
 
-class
+class VllmModel(base.LanguageModel):
     """Wrapper class for using vLLM for text generation.
 
     This wrapper loads a vLLM model, sets up the tokenizer and sampling
@@ -34,6 +37,7 @@ class VLLMTextModel(base.BaseModel):
         self,
         model_name_or_path: str,
         model_kwargs: Dict[str, Any] | None = None,
+        system_prompt: str | None = None,
         generation_kwargs: Dict[str, Any] | None = None,
     ) -> None:
         """Initializes the vLLM model wrapper.
@@ -44,12 +48,13 @@ class VLLMTextModel(base.BaseModel):
             model_kwargs: Arguments required to initialize the vLLM model,
                 see [link](https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py)
                 for more information.
+            system_prompt: System prompt to use.
             generation_kwargs: Arguments required to generate the output,
                 need to align with the arguments of
                 [vllm.SamplingParams](https://github.com/vllm-project/vllm/blob/main/vllm/sampling_params.py).
 
         """
-        super().__init__()
+        super().__init__(system_prompt=system_prompt)
         self._model_name_or_path = model_name_or_path
         self._model_kwargs = model_kwargs or {}
         self._generation_kwargs = generation_kwargs or {}
@@ -71,11 +76,11 @@ class VLLMTextModel(base.BaseModel):
             raise RuntimeError("Model not initialized")
         self._llm_tokenizer = self._llm_model.get_tokenizer()
 
-    def
+    def _tokenize_messages(self, messages: List[MessageSeries]) -> List[TokensPrompt]:
         """Apply chat template to the messages.
 
         Args:
-
+            messages: List of raw user strings.
 
         Returns:
             List of encoded messages.
@@ -90,7 +95,8 @@ class VLLMTextModel(base.BaseModel):
         if not hasattr(self._llm_tokenizer, "chat_template"):
             raise ValueError("Tokenizer does not have a chat template.")
 
-        chat_messages =
+        chat_messages = list(map(message_utils.format_chat_message, messages))
+
         encoded_messages = self._llm_tokenizer.apply_chat_template(
             chat_messages,  # type: ignore
             tokenize=True,
@@ -131,11 +137,30 @@ class VLLMTextModel(base.BaseModel):
 
         return result
 
-
+    @override
+    def format_inputs(self, batch: TextBatch) -> List[TokensPrompt]:
+        """Formats inputs for vLLM models.
+
+        Args:
+            batch: A batch of text and image inputs.
+
+        Returns:
+            List of formatted prompts.
+        """
+        message_batch, _, _ = TextBatch(*batch)
+        message_batch = message_utils.batch_insert_system_message(
+            message_batch, self.system_message
+        )
+        message_batch = list(map(message_utils.combine_system_messages, message_batch))
+
+        return self._tokenize_messages(message_batch)
+
+    @override
+    def model_forward(self, batch: List[TokensPrompt]) -> List[str]:
         """Generates text for the given prompt using the vLLM model.
 
         Args:
-
+            batch: A list encoded / tokenized messages (TokensPrompt objects).
 
         Returns:
             The generated text response.
@@ -144,6 +169,5 @@ class VLLMTextModel(base.BaseModel):
         if self._llm_model is None:
             raise RuntimeError("Model not initialized")
 
-
-        outputs = self._llm_model.generate(prompt_tokens, SamplingParams(**self._generation_kwargs))
+        outputs = self._llm_model.generate(batch, SamplingParams(**self._generation_kwargs))
         return [output.outputs[0].text for output in outputs]

eva/language/utils/__init__.py CHANGED

@@ -1,5 +1,6 @@
 """Language utilities and helper functions."""
 
 from eva.language.utils.str_to_int_tensor import CastStrToIntTensor
+from eva.language.utils.text.messages import format_chat_message
 
-__all__ = ["CastStrToIntTensor"]
+__all__ = ["CastStrToIntTensor", "format_chat_message"]

eva/language/utils/str_to_int_tensor.py CHANGED

@@ -16,11 +16,11 @@ class CastStrToIntTensor:
     Supports single values, lists of strings, or lists of integers.
 
     Example:
-        >>> # Default mapping for
-        >>> transform = CastStrToIntTensor()
-        >>> transform(['
+        >>> # Default mapping for A/B/C classification
+        >>> transform = CastStrToIntTensor(mapping={"A": 0, "B": 1, "C": 2})
+        >>> transform(['B', 'A', 'C'])
         tensor([1, 0, 2])
-        >>> transform('
+        >>> transform('B')
         tensor([1])
 
         >>> # Custom mapping
@@ -29,20 +29,25 @@ class CastStrToIntTensor:
         tensor([1, 0])
     """
 
-    def __init__(
-
+    def __init__(
+        self, mapping: Dict[str, int], standalone_words: bool = True, case_sensitive: bool = True
+    ) -> None:
+        r"""Initialize the transform with a regex-to-integer mapping.
 
         Args:
             mapping: Dictionary mapping regex patterns to integers. If None, uses default
                 yes/no/maybe mapping: {'no': 0, 'yes': 1, 'maybe': 2}
+            standalone_words: If True, patterns are treated as standalone words (e.g., '\bno\b').
+            case_sensitive: If True, regex patterns are case-sensitive.
         """
-
-
-
-        self.mapping = mapping
+        self.mapping = mapping
+
+        if standalone_words:
+            self.mapping = {rf"\b{k}\b": v for k, v in mapping.items()}
 
         self.compiled_patterns = [
-            (re.compile(pattern, re.IGNORECASE), value)
+            (re.compile(pattern, 0 if case_sensitive else re.IGNORECASE), value)
+            for pattern, value in self.mapping.items()
         ]
 
     def __call__(self, values: Union[str, List[str], List[int]]) -> torch.Tensor:
@@ -58,7 +63,10 @@ class CastStrToIntTensor:
             ValueError: If any value cannot be mapped to an integer.
         """
         return torch.tensor(
-            [
+            [
+                self._cast_single(v)
+                for v in (values if isinstance(values, list | tuple) else [values])
+            ],
             dtype=torch.int,
         )
 
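
A short sketch of how the new flags change matching behavior, assuming the class is importable as exported above (the mapping mirrors the yes/no/maybe default mentioned in the docstring):

from eva.language.utils import CastStrToIntTensor

transform = CastStrToIntTensor(
    mapping={"no": 0, "yes": 1, "maybe": 2},
    standalone_words=True,   # wraps each pattern as r"\b...\b"
    case_sensitive=False,    # compiles with re.IGNORECASE, matching the old behavior
)
transform(["Yes", "no"])  # expected: tensor([1, 0], dtype=torch.int32)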

eva/language/utils/text/messages.py ADDED

@@ -0,0 +1,113 @@
+"""Message formatting utilities for language models."""
+
+import functools
+import json
+from typing import Any, Dict, List
+
+from eva.language.data.messages import (
+    AssistantMessage,
+    MessageSeries,
+    Role,
+    SystemMessage,
+    UserMessage,
+)
+
+
+def format_chat_message(message: MessageSeries) -> List[Dict[str, Any]]:
+    """Formats a message series into a format following OpenAI's API specification."""
+    return [{"role": item.role, "content": item.content} for item in message]
+
+
+def combine_system_messages(message: MessageSeries, join_char: str = "\n") -> MessageSeries:
+    """Combine system messages into a single message.
+
+    This is useful when the MessageSeries contains multiple system messages such
+    as `ModelSystemMessage` and `TaskSystemMessage`. But given that most models / apis
+    expect a single system message, this function can be used to combines them into one.
+
+    Args:
+        message: The message series containing one or multiple messages.
+        join_char: The character to use to join the system messages. Default is newline.
+
+    Returns:
+        A new message series with system messages combined into one and the
+        remaining messages unchanged.
+    """
+    system_messages = list(filter(lambda item: item.role == Role.SYSTEM, message))
+    if len(system_messages) == 0:
+        return message
+
+    non_system_messages = list(filter(lambda item: item.role != Role.SYSTEM, message))
+    return [
+        SystemMessage(content=merge_message_contents(system_messages, join_char=join_char))
+    ] + non_system_messages
+
+
+def merge_message_contents(message: MessageSeries, join_char: str = "\n") -> str:
+    """Merges the all contents within a message series into a string.
+
+    Args:
+        message: The message series to combine.
+        join_char: The character to use to join the message contents. Default is newline.
+
+    Returns:
+        A string containing the combined message contents.
+    """
+    return join_char.join(item.content for item in message)
+
+
+def insert_system_message(
+    message: MessageSeries, system_message: SystemMessage | None
+) -> MessageSeries:
+    """Insert a system message at the beginning of the message series."""
+    if system_message is None:
+        return message
+    return [system_message] + message
+
+
+def batch_insert_system_message(
+    messages: List[MessageSeries], system_message: SystemMessage | None
+) -> List[MessageSeries]:
+    """Insert a system message at the beginning of each message series in a batch."""
+    return list(
+        map(functools.partial(insert_system_message, system_message=system_message), messages)
+    )
+
+
+def serialize(messages: MessageSeries) -> str:
+    """Serialize a MessageSeries object into a JSON string.
+
+    Args:
+        messages: A list of message objects (MessagesSeries).
+
+    Returns:
+        A JSON string representing the message series, with the following format:
+        [{"role": "user", "content": "Hello"}, ...]
+    """
+    serialized_messages = format_chat_message(messages)
+    return json.dumps(serialized_messages)
+
+
+def deserialize(messages: str) -> MessageSeries:
+    """Convert a json string to a MessageSeries object.
+
+    Format: [{"role": "user", "content": "Hello"}, {"role": "assistant", "content": "Hi there!"}]
+    """
+    message_dicts = json.loads(messages)
+
+    message_series = []
+    for message_dict in message_dicts:
+        if "role" not in message_dict or "content" not in message_dict:
+            raise ValueError("`role` or `content` keys are missing.")
+
+        match message_dict["role"]:
+            case Role.USER:
+                message_series.append(UserMessage(**message_dict))
+            case Role.ASSISTANT:
+                message_series.append(AssistantMessage(**message_dict))
+            case Role.SYSTEM:
+                message_series.append(SystemMessage(**message_dict))
+            case _:
+                raise ValueError(f"Unknown role: {message_dict['role']}")
+
+    return message_series
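
A round-trip sketch of the new serialization helpers, assuming the message classes construct as shown in the imports above (the exact serialized role strings depend on the Role enum values):

from eva.language.data.messages import SystemMessage, UserMessage
from eva.language.utils.text import messages as message_utils

series = [SystemMessage(content="Be terse."), UserMessage(content="Hello")]
payload = message_utils.serialize(series)      # JSON list of {"role", "content"} dicts
restored = message_utils.deserialize(payload)  # back to a MessageSeries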

eva/multimodal/callbacks/writers/prediction.py ADDED

@@ -0,0 +1,39 @@
+"""Text prediction writer callbacks."""
+
+from typing import Dict, List, Literal, Tuple
+
+from torch import nn
+from typing_extensions import override
+
+from eva.language.callbacks import writers
+from eva.multimodal.models.typings import TextImageBatch
+
+
+class TextPredictionWriter(writers.TextPredictionWriter):
+    """Callback for writing generated text predictions to disk."""
+
+    def __init__(
+        self,
+        output_dir: str,
+        model: nn.Module,
+        dataloader_idx_map: Dict[int, str] | None = None,
+        metadata_keys: List[str] | None = None,
+        include_input: bool = True,
+        overwrite: bool = False,
+        save_format: Literal["jsonl", "parquet", "csv"] = "jsonl",
+    ) -> None:
+        """See docstring of base class."""
+        super().__init__(
+            output_dir=output_dir,
+            model=model,
+            dataloader_idx_map=dataloader_idx_map,
+            metadata_keys=metadata_keys,
+            include_input=include_input,
+            overwrite=overwrite,
+            save_format=save_format,
+        )
+
+    @override
+    def _unpack_batch(self, batch: TextImageBatch) -> Tuple[list, list | None, dict | None]:  # type: ignore
+        text_batch, _, target_batch, metadata_batch = TextImageBatch(*batch)
+        return text_batch, target_batch, metadata_batch