transformers-haystack 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- haystack_integrations/components/classifiers/py.typed +0 -0
- haystack_integrations/components/classifiers/transformers/__init__.py +6 -0
- haystack_integrations/components/classifiers/transformers/zero_shot_document_classifier.py +247 -0
- haystack_integrations/components/common/py.typed +0 -0
- haystack_integrations/components/common/transformers/__init__.py +3 -0
- haystack_integrations/components/common/transformers/utils.py +234 -0
- haystack_integrations/components/extractors/py.typed +0 -0
- haystack_integrations/components/extractors/transformers/__init__.py +6 -0
- haystack_integrations/components/extractors/transformers/named_entity_extractor.py +262 -0
- haystack_integrations/components/generators/py.typed +0 -0
- haystack_integrations/components/generators/transformers/__init__.py +6 -0
- haystack_integrations/components/generators/transformers/chat/__init__.py +3 -0
- haystack_integrations/components/generators/transformers/chat/chat_generator.py +666 -0
- haystack_integrations/components/readers/py.typed +0 -0
- haystack_integrations/components/readers/transformers/__init__.py +6 -0
- haystack_integrations/components/readers/transformers/extractive_reader.py +662 -0
- haystack_integrations/components/routers/py.typed +0 -0
- haystack_integrations/components/routers/transformers/__init__.py +7 -0
- haystack_integrations/components/routers/transformers/text_router.py +196 -0
- haystack_integrations/components/routers/transformers/zero_shot_text_router.py +205 -0
- transformers_haystack-0.1.0.dist-info/METADATA +38 -0
- transformers_haystack-0.1.0.dist-info/RECORD +24 -0
- transformers_haystack-0.1.0.dist-info/WHEEL +4 -0
- transformers_haystack-0.1.0.dist-info/licenses/LICENSE.txt +201 -0
|
@@ -0,0 +1,666 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
|
2
|
+
#
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import json
|
|
7
|
+
import re
|
|
8
|
+
import sys
|
|
9
|
+
from collections.abc import AsyncIterator, Callable
|
|
10
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
11
|
+
from contextlib import asynccontextmanager, suppress
|
|
12
|
+
from typing import Any, Literal, Union
|
|
13
|
+
|
|
14
|
+
from haystack import component, default_from_dict, default_to_dict, logging
|
|
15
|
+
from haystack.components.generators.utils import _normalize_messages
|
|
16
|
+
from haystack.dataclasses import ChatMessage, ComponentInfo, StreamingCallbackT, ToolCall
|
|
17
|
+
from haystack.dataclasses.streaming_chunk import select_streaming_callback
|
|
18
|
+
from haystack.tools import (
|
|
19
|
+
Tool,
|
|
20
|
+
Toolset,
|
|
21
|
+
ToolsType,
|
|
22
|
+
_check_duplicate_tool_names,
|
|
23
|
+
deserialize_tools_or_toolset_inplace,
|
|
24
|
+
flatten_tools_or_toolsets,
|
|
25
|
+
serialize_tools_or_toolset,
|
|
26
|
+
)
|
|
27
|
+
from haystack.tools.utils import warm_up_tools
|
|
28
|
+
from haystack.utils import ComponentDevice, Secret, deserialize_callable, serialize_callable
|
|
29
|
+
from haystack.utils.hf import convert_message_to_hf_format, deserialize_hf_model_kwargs, serialize_hf_model_kwargs
|
|
30
|
+
from huggingface_hub import model_info
|
|
31
|
+
from packaging.version import Version
|
|
32
|
+
|
|
33
|
+
import transformers
|
|
34
|
+
from haystack_integrations.components.common.transformers.utils import (
|
|
35
|
+
_AsyncHFTokenStreamingHandler,
|
|
36
|
+
_HFTokenStreamingHandler,
|
|
37
|
+
_StopWordsCriteria,
|
|
38
|
+
)
|
|
39
|
+
from transformers import Pipeline as HfPipeline
|
|
40
|
+
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, StoppingCriteriaList, pipeline
|
|
41
|
+
|
|
42
|
+
logger = logging.getLogger(__name__)
|
|
43
|
+
|
|
44
|
+
PIPELINE_SUPPORTED_TASKS = ["text-generation", "text2text-generation", "image-text-to-text"]
|
|
45
|
+
|
|
46
|
+
# Locates a JSON tool call in model output, optionally preceded by a `<tool_call>` tag.
# Two layouts are recognised:
#   {"name": "...", "arguments": {...}}                  -> groups 1 (name) and 2 (arguments)
#   {"function": {"name": "...", "arguments": {...}}}    -> groups 3 (name) and 4 (arguments)
# The arguments groups accept an empty object (`{}`) so that parameterless tool calls match.
# They stop at the first `}`; `default_tool_parser` therefore only uses them to find where
# the arguments object starts and lets the JSON decoder find where it really ends.
DEFAULT_TOOL_PATTERN = (
    r"(?:<tool_call>)?"
    r'(?:\s*\{.*?"name"\s*:\s*"([^"]+)".*?"arguments"\s*:\s*(\{[^}]*\}).*?\}'
    r'|\{.*?"function"\s*:\s*\{.*?"name"\s*:\s*"([^"]+)".*?"arguments"\s*:\s*(\{[^}]*\}).*?\})'
)


def default_tool_parser(text: str) -> list[ToolCall] | None:
    """
    Default implementation for parsing tool calls from model output text.

    Uses DEFAULT_TOOL_PATTERN to locate the first tool call, then decodes its `arguments`
    object with the JSON decoder starting at the opening brace. Decoding from the brace
    (instead of parsing the regex capture) keeps nested objects and `}` characters inside
    string values intact.

    :param text: The text to parse for tool calls.
    :returns: A list containing a single ToolCall if a valid tool call is found, None otherwise.
    """
    try:
        match = re.search(DEFAULT_TOOL_PATTERN, text, re.DOTALL)
    except re.error:
        logger.warning("Invalid regex pattern for tool parsing: {pattern}", pattern=DEFAULT_TOOL_PATTERN)
        return None

    if not match:
        return None

    # Exactly one alternative of the pattern matched; pick its (name, arguments) groups.
    name_group, args_group = (1, 2) if match.group(1) is not None else (3, 4)
    name = match.group(name_group)
    args_str = match.group(args_group)

    try:
        # raw_decode parses one complete JSON value from the start of the string and ignores
        # whatever follows it, so it consumes the whole arguments object however deeply nested.
        arguments, _ = json.JSONDecoder().raw_decode(text[match.start(args_group) :])
    except json.JSONDecodeError:
        logger.warning("Failed to parse tool call arguments: {args_str}", args_str=args_str)
        return None

    return [ToolCall(tool_name=name, arguments=arguments)]
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
@component
|
|
83
|
+
class TransformersChatGenerator:
|
|
84
|
+
"""
|
|
85
|
+
Generates chat responses using models from Hugging Face that run locally.
|
|
86
|
+
|
|
87
|
+
Use this component with chat-based models,
|
|
88
|
+
such as `Qwen/Qwen3-0.6B` or `meta-llama/Llama-2-7b-chat-hf`.
|
|
89
|
+
LLMs running locally may need powerful hardware.
|
|
90
|
+
|
|
91
|
+
### Usage example
|
|
92
|
+
```python
|
|
93
|
+
from haystack.dataclasses import ChatMessage
|
|
94
|
+
|
|
95
|
+
from haystack_integrations.components.generators.transformers import TransformersChatGenerator
|
|
96
|
+
|
|
97
|
+
generator = TransformersChatGenerator(model="Qwen/Qwen3-0.6B")
|
|
98
|
+
messages = [ChatMessage.from_user("What's Natural Language Processing? Be brief.")]
|
|
99
|
+
print(generator.run(messages))
|
|
100
|
+
```
|
|
101
|
+
|
|
102
|
+
```
|
|
103
|
+
{'replies':
|
|
104
|
+
[ChatMessage(_role=<ChatRole.ASSISTANT: 'assistant'>, _content=[TextContent(text=
|
|
105
|
+
"Natural Language Processing (NLP) is a subfield of artificial intelligence that deals
|
|
106
|
+
with the interaction between computers and human language. It enables computers to understand, interpret, and
|
|
107
|
+
generate human language in a valuable way. NLP involves various techniques such as speech recognition, text
|
|
108
|
+
analysis, sentiment analysis, and machine translation. The ultimate goal is to make it easier for computers to
|
|
109
|
+
process and derive meaning from human language, improving communication between humans and machines.")],
|
|
110
|
+
_name=None,
|
|
111
|
+
_meta={'finish_reason': 'stop', 'index': 0, 'model':
|
|
112
|
+
'Qwen/Qwen3-0.6B',
|
|
113
|
+
'usage': {'completion_tokens': 90, 'prompt_tokens': 19, 'total_tokens': 109}})
|
|
114
|
+
]
|
|
115
|
+
}
|
|
116
|
+
```
|
|
117
|
+
"""
|
|
118
|
+
|
|
119
|
+
def __init__(
    self,
    model: str = "Qwen/Qwen3-0.6B",
    task: Literal["text-generation", "text2text-generation", "image-text-to-text"] | None = None,
    device: ComponentDevice | None = None,
    token: Secret | None = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
    chat_template: str | None = None,
    generation_kwargs: dict[str, Any] | None = None,
    huggingface_pipeline_kwargs: dict[str, Any] | None = None,
    stop_words: list[str] | None = None,
    streaming_callback: StreamingCallbackT | None = None,
    tools: ToolsType | None = None,
    tool_parsing_function: Callable[[str], list[ToolCall] | None] | None = None,
    async_executor: ThreadPoolExecutor | None = None,
    *,
    enable_thinking: bool = False,
) -> None:
    """
    Initializes the TransformersChatGenerator component.

    :param model: The Hugging Face text generation model name or path,
        for example, `mistralai/Mistral-7B-Instruct-v0.2` or `TheBloke/OpenHermes-2.5-Mistral-7B-16k-AWQ`.
        The model must be a chat model supporting the ChatML messaging
        format.
        If the model is specified in `huggingface_pipeline_kwargs`, this parameter is ignored.
    :param task: The task for the Hugging Face pipeline. Possible options:
        - `text-generation`: Supported by decoder models, like GPT.
        - `text2text-generation`: Deprecated as of Transformers v5; use `text-generation` instead.
          Previously supported by encoder-decoder models such as T5.
        - `image-text-to-text`: Supported by vision-language models.
        If the task is specified in `huggingface_pipeline_kwargs`, this parameter is ignored.
        If not specified, the component calls the Hugging Face API to infer the task from the model name.
    :param device: The device for loading the model. If `None`, automatically selects the default device.
        If a device or device map is specified in `huggingface_pipeline_kwargs`, it overrides this parameter.
    :param token: The token to use as HTTP bearer authorization for remote files.
        If the token is specified in `huggingface_pipeline_kwargs`, this parameter is ignored.
    :param chat_template: Specifies an optional Jinja template for formatting chat
        messages. Most high-quality chat models have their own templates, but for models without this
        feature or if you prefer a custom template, use this parameter.
    :param generation_kwargs: A dictionary with keyword arguments to customize text generation.
        Some examples: `max_length`, `max_new_tokens`, `temperature`, `top_k`, `top_p`.
        See Hugging Face's documentation for more information:
        - [customize-text-generation](https://huggingface.co/docs/transformers/main/en/generation_strategies#customize-text-generation)
        - [GenerationConfig](https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationConfig)
        The only `generation_kwargs` set by default is `max_new_tokens`, which is set to 512 tokens.
        The dictionary passed in is copied; it is not modified by this component.
    :param huggingface_pipeline_kwargs: Dictionary with keyword arguments to initialize the
        Hugging Face pipeline for text generation.
        These keyword arguments provide fine-grained control over the Hugging Face pipeline.
        In case of duplication, these kwargs override `model`, `task`, `device`, and `token` init parameters.
        For kwargs, see [Hugging Face documentation](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.pipeline.task).
        In this dictionary, you can also include `model_kwargs` to specify the kwargs for [model initialization](https://huggingface.co/docs/transformers/en/main_classes/model#transformers.PreTrainedModel.from_pretrained)
        The dictionary passed in is copied; it is not modified by this component.
    :param stop_words: A list of stop words. If the model generates a stop word, the generation stops.
        If you provide this parameter, don't specify the `stopping_criteria` in `generation_kwargs`.
        For some chat models, the output includes both the new text and the original prompt.
        In these cases, make sure your prompt has no stop words.
    :param streaming_callback: An optional callable for handling streaming responses.
    :param tools: A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
    :param tool_parsing_function:
        A callable that takes a string and returns a list of ToolCall objects or None.
        If None, the default_tool_parser will be used which extracts tool calls using a predefined pattern.
    :param async_executor:
        Optional ThreadPoolExecutor to use for async calls. If not provided, a single-threaded executor will be
        initialized and used.
    :param enable_thinking:
        Whether to enable thinking mode in the chat template for thinking-capable models.
        When enabled, the model generates intermediate reasoning before the final response. Defaults to False.
    :raises ValueError: If tools and a streaming callback are both given, if the resolved task is not
        supported, or if both `stop_words` and a `stopping_criteria` generation kwarg are given.
    """
    if tools and streaming_callback is not None:
        msg = "Using tools and streaming at the same time is not supported. Please choose one."
        raise ValueError(msg)
    _check_duplicate_tool_names(flatten_tools_or_toolsets(tools))

    # Work on shallow copies: the code below fills in defaults and appends stop words, and
    # doing that on the caller's own dictionaries would leak state into every other
    # component built from the same arguments.
    huggingface_pipeline_kwargs = dict(huggingface_pipeline_kwargs or {})
    generation_kwargs = dict(generation_kwargs or {})

    self.token = token
    token = token.resolve_value() if token else None

    # check if the huggingface_pipeline_kwargs contain the essential parameters
    # otherwise, populate them with values from other init parameters
    huggingface_pipeline_kwargs.setdefault("model", model)
    huggingface_pipeline_kwargs.setdefault("token", token)

    device = ComponentDevice.resolve_device(device)
    device.update_hf_kwargs(huggingface_pipeline_kwargs, overwrite=False)

    # task identification and validation
    if task is None:
        if "task" in huggingface_pipeline_kwargs:
            task = huggingface_pipeline_kwargs["task"]
        elif isinstance(huggingface_pipeline_kwargs["model"], str):
            task = model_info(
                huggingface_pipeline_kwargs["model"], token=huggingface_pipeline_kwargs["token"]
            ).pipeline_tag  # type: ignore[assignment]  # we'll check below if task is in supported tasks

    if task not in PIPELINE_SUPPORTED_TASKS:
        msg = f"Task '{task}' is not supported. The supported tasks are: {', '.join(PIPELINE_SUPPORTED_TASKS)}."
        raise ValueError(msg)
    if task == "text2text-generation" and Version(transformers.__version__) >= Version("5.0.0"):
        msg = "Task 'text2text-generation' is not supported with transformers v5 or higher."
        raise ValueError(msg)
    huggingface_pipeline_kwargs["task"] = task

    # if not specified, set return_full_text to False for text-generation
    # only generated text is returned (excluding prompt)
    if task == "text-generation":
        generation_kwargs.setdefault("return_full_text", False)

    if stop_words and "stopping_criteria" in generation_kwargs:
        msg = (
            "Found both the `stop_words` init parameter and the `stopping_criteria` key in `generation_kwargs`. "
            "Please specify only one of them."
        )
        raise ValueError(msg)
    generation_kwargs.setdefault("max_new_tokens", 512)
    # Build a fresh list so the caller's `stop_sequences` list is never extended in place.
    generation_kwargs["stop_sequences"] = [*generation_kwargs.get("stop_sequences", []), *(stop_words or [])]

    self.tool_parsing_function = tool_parsing_function or default_tool_parser
    self.huggingface_pipeline_kwargs = huggingface_pipeline_kwargs
    self.generation_kwargs = generation_kwargs
    self.chat_template = chat_template
    self.streaming_callback = streaming_callback
    self.pipeline: HfPipeline | None = None
    self.tools = tools
    self.enable_thinking = enable_thinking

    # Remember whether the executor is ours, so that only an executor created here is shut down later.
    self._owns_executor = async_executor is None
    self.executor = (
        ThreadPoolExecutor(thread_name_prefix=f"async-TransformersChatGenerator-executor-{id(self)}", max_workers=1)
        if async_executor is None
        else async_executor
    )
    self._is_warmed_up = False
|
|
253
|
+
|
|
254
|
+
def __del__(self) -> None:
|
|
255
|
+
"""
|
|
256
|
+
Cleanup when the instance is being destroyed.
|
|
257
|
+
"""
|
|
258
|
+
if hasattr(self, "_owns_executor") and self._owns_executor and hasattr(self, "executor"):
|
|
259
|
+
self.executor.shutdown(wait=True)
|
|
260
|
+
|
|
261
|
+
def shutdown(self) -> None:
    """
    Shut down the executor, but only when this component created it.

    An executor supplied by the caller is left running.
    """
    if not self._owns_executor:
        return
    self.executor.shutdown(wait=True)
|
|
267
|
+
|
|
268
|
+
def _get_telemetry_data(self) -> dict[str, Any]:
|
|
269
|
+
"""
|
|
270
|
+
Data that is sent to Posthog for usage analytics.
|
|
271
|
+
"""
|
|
272
|
+
if isinstance(self.huggingface_pipeline_kwargs["model"], str):
|
|
273
|
+
return {"model": self.huggingface_pipeline_kwargs["model"]}
|
|
274
|
+
return {"model": f"[object of type {type(self.huggingface_pipeline_kwargs['model'])}]"}
|
|
275
|
+
|
|
276
|
+
def warm_up(self) -> None:
    """
    Builds the Hugging Face pipeline and warms up the configured tools.

    The work is done once; later calls return immediately.
    """
    if not self._is_warmed_up:
        # Create the pipeline from the kwargs assembled at init time, unless one is already set.
        if self.pipeline is None:
            self.pipeline = pipeline(**self.huggingface_pipeline_kwargs)

        # Tools are only warmed up when some were configured.
        if self.tools:
            warm_up_tools(self.tools)

        self._is_warmed_up = True
|
|
292
|
+
|
|
293
|
+
def to_dict(self) -> dict[str, Any]:
    """
    Serializes the component to a dictionary.

    The resolved token is removed from the serialized pipeline kwargs; the `token` init
    parameter is serialized separately. Serialization works on a copy of the pipeline
    kwargs, so the component's own configuration is left untouched.

    :returns:
        Dictionary with serialized data.
    """

    def _copy_dicts(value: Any) -> Any:
        # Copy dictionaries at every nesting level; all other values are shared as-is.
        if isinstance(value, dict):
            return {key: _copy_dicts(item) for key, item in value.items()}
        return value

    callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
    serialization_dict = default_to_dict(
        self,
        # The steps below pop the token and rewrite values in place; run them on a copy so
        # the live pipeline kwargs keep their token and original values.
        huggingface_pipeline_kwargs=_copy_dicts(self.huggingface_pipeline_kwargs),
        generation_kwargs=self.generation_kwargs,
        streaming_callback=callback_name,
        token=self.token,
        chat_template=self.chat_template,
        tools=serialize_tools_or_toolset(self.tools),
        tool_parsing_function=serialize_callable(self.tool_parsing_function),
        enable_thinking=self.enable_thinking,
    )

    huggingface_pipeline_kwargs = serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"]
    # Never write the resolved token value into the serialized form.
    huggingface_pipeline_kwargs.pop("token", None)

    serialize_hf_model_kwargs(huggingface_pipeline_kwargs)
    return serialization_dict
|
|
318
|
+
|
|
319
|
+
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "TransformersChatGenerator":
    """
    Deserializes the component from a dictionary.

    Tools, the streaming callback, the tool parsing function and the Hugging Face model
    kwargs are restored in place inside `data["init_parameters"]` before the component is
    built. A dictionary without `init_parameters` is accepted.

    :param data:
        The dictionary to deserialize from.
    :returns:
        The deserialized component.
    """
    # One guarded lookup, used for every read and write below.
    init_params = data.get("init_parameters", {})
    deserialize_tools_or_toolset_inplace(init_params, key="tools")

    serialized_callback_handler = init_params.get("streaming_callback")
    if serialized_callback_handler:
        init_params["streaming_callback"] = deserialize_callable(serialized_callback_handler)

    tool_parsing_function = init_params.get("tool_parsing_function")
    if tool_parsing_function:
        init_params["tool_parsing_function"] = deserialize_callable(tool_parsing_function)

    huggingface_pipeline_kwargs = init_params.get("huggingface_pipeline_kwargs", {})
    deserialize_hf_model_kwargs(huggingface_pipeline_kwargs)
    return default_from_dict(cls, data)
|
|
342
|
+
|
|
343
|
+
@component.output_types(replies=list[ChatMessage])
def run(
    self,
    messages: list[ChatMessage] | str,
    generation_kwargs: dict[str, Any] | None = None,
    streaming_callback: StreamingCallbackT | None = None,
    tools: ToolsType | None = None,
) -> dict[str, list[ChatMessage]]:
    """
    Invoke text generation inference based on the provided messages and generation parameters.

    :param messages: A list of ChatMessage objects representing the input messages. If a string is provided,
        it is converted to a list containing a ChatMessage with user role.
    :param generation_kwargs: Additional keyword arguments for text generation.
    :param streaming_callback: An optional callable for handling streaming responses.
    :param tools: A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
        If set, it will override the `tools` parameter provided during initialization.
    :returns: A dictionary with the following keys:
        - `replies`: A list containing the generated responses as ChatMessage instances.
    """
    # Load the pipeline on first use.
    if self.pipeline is None:
        self.warm_up()

    # NOTE(review): only the run-time callback is forwarded here, so an init-time callback
    # is not visible to the checks inside _prepare_inputs — confirm this is intended.
    prepared_inputs = self._prepare_inputs(
        messages=_normalize_messages(messages),
        generation_kwargs=generation_kwargs,
        streaming_callback=streaming_callback,
        tools=tools,
    )

    selected_callback = select_streaming_callback(self.streaming_callback, streaming_callback, requires_async=False)
    if selected_callback:
        # The `streamer` generation kwarg hooks into HF streaming; the handler adapts it to our callback.
        token_streamer = _HFTokenStreamingHandler(
            tokenizer=prepared_inputs["tokenizer"],
            stream_handler=selected_callback,
            stop_words=prepared_inputs["stop_words"],
            component_info=ComponentInfo.from_component(self),
        )
        prepared_inputs["generation_kwargs"]["streamer"] = token_streamer

    # The pipeline was created by warm_up() above; the assert narrows the type for checkers.
    assert self.pipeline is not None  # noqa: S101
    hf_output = self.pipeline(prepared_inputs["prepared_prompt"], **prepared_inputs["generation_kwargs"])

    replies = self._convert_hf_output_to_chat_messages(hf_pipeline_output=hf_output, **prepared_inputs)
    return {"replies": replies}
|
|
392
|
+
|
|
393
|
+
def create_message(
    self,
    text: str,
    index: int,
    tokenizer: Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"],
    prompt: str,
    generation_kwargs: dict[str, Any],
    parse_tool_calls: bool = False,
) -> ChatMessage:
    """
    Build an assistant ChatMessage for one generated text, including usage metadata.

    :param text: The generated text.
    :param index: The index of the generated text.
    :param tokenizer: The tokenizer used for generation; it is used here to count tokens.
    :param prompt: The prompt used for generation.
    :param generation_kwargs: The generation parameters; `max_new_tokens` is read to detect truncation.
    :param parse_tool_calls: Whether to attempt parsing tool calls from the text.
    :returns: A ChatMessage instance carrying either the parsed tool calls or the text.
    """
    # Token accounting; special tokens are excluded from both counts.
    n_completion = len(tokenizer.encode(text, add_special_tokens=False))
    n_prompt = len(tokenizer.encode(prompt, add_special_tokens=False))

    parsed_calls = None
    if parse_tool_calls:
        parsed_calls = self.tool_parsing_function(text)

    # Reaching the token budget takes precedence over a detected tool call.
    token_budget = generation_kwargs.get("max_new_tokens", sys.maxsize)
    if n_completion >= token_budget:
        reason = "length"
    else:
        reason = "tool_calls" if parsed_calls else "stop"

    usage = {
        "completion_tokens": n_completion,
        "prompt_tokens": n_prompt,
        "total_tokens": n_prompt + n_completion,
    }
    meta = {
        "finish_reason": reason,
        "index": index,
        "model": self.huggingface_pipeline_kwargs["model"],
        "usage": usage,
    }

    # When tool calls were parsed, the text is just their raw encoding, so it is left out.
    reply_text = None if parsed_calls else text
    return ChatMessage.from_assistant(tool_calls=parsed_calls, text=reply_text, meta=meta)
|
|
441
|
+
|
|
442
|
+
@staticmethod
|
|
443
|
+
def _validate_stop_words(stop_words: list[str] | None) -> list[str] | None:
|
|
444
|
+
"""
|
|
445
|
+
Validates the provided stop words.
|
|
446
|
+
|
|
447
|
+
:param stop_words: A list of stop words to validate.
|
|
448
|
+
:return: A sanitized list of stop words or None if validation fails.
|
|
449
|
+
"""
|
|
450
|
+
if stop_words and not all(isinstance(word, str) for word in stop_words):
|
|
451
|
+
logger.warning(
|
|
452
|
+
"Invalid stop words provided. Stop words must be specified as a list of strings. "
|
|
453
|
+
"Ignoring stop words: {stop_words}",
|
|
454
|
+
stop_words=stop_words,
|
|
455
|
+
)
|
|
456
|
+
return None
|
|
457
|
+
|
|
458
|
+
return list(set(stop_words or []))
|
|
459
|
+
|
|
460
|
+
@component.output_types(replies=list[ChatMessage])
async def run_async(
    self,
    messages: list[ChatMessage] | str,
    generation_kwargs: dict[str, Any] | None = None,
    streaming_callback: StreamingCallbackT | None = None,
    tools: ToolsType | None = None,
) -> dict[str, list[ChatMessage]]:
    """
    Asynchronously invokes text generation inference based on the provided messages and generation parameters.

    This is the asynchronous version of the `run` method. It has the same parameters
    and return values but can be used with `await` in an async code. The pipeline call
    itself runs on the component's executor.

    :param messages: A list of ChatMessage objects representing the input messages.
    :param generation_kwargs: Additional keyword arguments for text generation.
    :param streaming_callback: An optional callable for handling streaming responses.
    :param tools: A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
        If set, it will override the `tools` parameter provided during initialization.
    :returns: A dictionary with the following keys:
        - `replies`: A list containing the generated responses as ChatMessage instances.
    """
    # Load the pipeline on first use.
    if self.pipeline is None:
        self.warm_up()

    prepared_inputs = self._prepare_inputs(
        messages=_normalize_messages(messages),
        generation_kwargs=generation_kwargs,
        streaming_callback=streaming_callback,
        tools=tools,
    )

    # Validate and select the streaming callback
    selected_callback = select_streaming_callback(self.streaming_callback, streaming_callback, requires_async=True)

    def _generate() -> Any:
        # Executed on the executor thread; inputs are read at call time, after any streamer was attached.
        # have to ignore since assert self.pipeline is not None doesn't work
        return self.pipeline(  # type: ignore[misc]
            prepared_inputs["prepared_prompt"], **prepared_inputs["generation_kwargs"]
        )

    if not selected_callback:
        hf_output = await asyncio.get_running_loop().run_in_executor(self.executor, _generate)
    else:
        queue_handler = _AsyncHFTokenStreamingHandler(
            tokenizer=prepared_inputs["tokenizer"],
            stream_handler=selected_callback,
            stop_words=prepared_inputs["stop_words"],
            component_info=ComponentInfo.from_component(self),
        )
        prepared_inputs["generation_kwargs"]["streamer"] = queue_handler

        # The context manager starts the handler's queue-processing task and cleans it up on exit.
        async with self._manage_queue_processor(queue_handler):
            hf_output = await asyncio.get_running_loop().run_in_executor(self.executor, _generate)

    replies = self._convert_hf_output_to_chat_messages(hf_pipeline_output=hf_output, **prepared_inputs)
    return {"replies": replies}
|
|
523
|
+
|
|
524
|
+
@asynccontextmanager
|
|
525
|
+
async def _manage_queue_processor(
|
|
526
|
+
self, async_handler: "_AsyncHFTokenStreamingHandler"
|
|
527
|
+
) -> AsyncIterator["asyncio.Task[None]"]:
|
|
528
|
+
"""Context manager for proper queue processor lifecycle management."""
|
|
529
|
+
queue_processor = asyncio.create_task(async_handler.process_queue())
|
|
530
|
+
try:
|
|
531
|
+
yield queue_processor
|
|
532
|
+
finally:
|
|
533
|
+
# Ensure the queue processor is cleaned up properly
|
|
534
|
+
try:
|
|
535
|
+
await asyncio.wait_for(queue_processor, timeout=0.1)
|
|
536
|
+
except asyncio.TimeoutError:
|
|
537
|
+
queue_processor.cancel()
|
|
538
|
+
with suppress(asyncio.CancelledError):
|
|
539
|
+
await queue_processor
|
|
540
|
+
|
|
541
|
+
def _prepare_inputs(
    self,
    messages: list[ChatMessage],
    generation_kwargs: dict[str, Any] | None = None,
    streaming_callback: StreamingCallbackT | None = None,
    tools: ToolsType | None = None,
) -> dict[str, Any]:
    """
    Prepares the inputs for the Hugging Face pipeline.

    :param messages: A list of ChatMessage objects representing the input messages.
    :param generation_kwargs: Additional keyword arguments for text generation. They are merged on top of
        the component-level `generation_kwargs`; the `stop_words` and `stop_sequences` entries are removed
        from the merged dictionary and combined into a single stop-word list.
    :param streaming_callback: An optional callable for handling streaming responses.
    :param tools: A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
        Falls back to the tools configured on the component when not provided.
    :returns: A dictionary with the keys `prepared_prompt`, `tokenizer`, `generation_kwargs`, `tools`
        (the flattened tools) and `stop_words`.
    :raises ValueError: If both tools and streaming_callback are provided.
    """
    tools = tools or self.tools
    if tools and streaming_callback is not None:
        msg = "Using tools and streaming at the same time is not supported. Please choose one."
        raise ValueError(msg)
    flat_tools = flatten_tools_or_toolsets(tools)
    _check_duplicate_tool_names(flat_tools)

    # mypy doesn't know this is set in warm_up
    tokenizer = self.pipeline.tokenizer  # type: ignore[union-attr]

    # Merge into a fresh dict so neither the component defaults nor the caller's dict are mutated below
    generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

    # If streaming_callback is provided, ensure that num_return_sequences is set to 1.
    # Uses the same `is not None` test as the tools/streaming guard above so both agree.
    if streaming_callback is not None:
        num_responses = generation_kwargs.get("num_return_sequences", 1)
        if num_responses > 1:
            # The `{num_responses}` placeholder is intentionally not an f-string: the value is handed to
            # the logger as a keyword argument (presumably interpolated there - confirm against the logger).
            msg = (
                "Streaming is enabled, but the number of responses is set to {num_responses}. "
                "Streaming is only supported for single response generation. "
                "Setting the number of responses to 1."
            )
            logger.warning(msg, num_responses=num_responses)
            generation_kwargs["num_return_sequences"] = 1

    # `or []` tolerates an explicit `stop_words=None` / `stop_sequences=None`, which would otherwise
    # fail on the list concatenation.
    stop_words = (generation_kwargs.pop("stop_words", None) or []) + (
        generation_kwargs.pop("stop_sequences", None) or []
    )
    stop_words = self._validate_stop_words(stop_words)

    # Set up stop words criteria if stop words exist
    if stop_words:
        stop_words_criteria = _StopWordsCriteria(
            tokenizer,  # type: ignore[arg-type]
            stop_words,
            self.pipeline.device,  # type: ignore[union-attr]
        )
        generation_kwargs["stopping_criteria"] = StoppingCriteriaList([stop_words_criteria])

    # convert messages to HF format
    hf_messages = [convert_message_to_hf_format(message) for message in messages]

    # mypy doesn't know tokenizer is set in warm_up
    prepared_prompt = tokenizer.apply_chat_template(  # type: ignore[union-attr]
        hf_messages,
        tokenize=False,
        chat_template=self.chat_template,
        add_generation_prompt=True,
        tools=[tc.tool_spec for tc in flat_tools] if flat_tools else None,
        enable_thinking=self.enable_thinking,
    )
    # prepared_prompt is a string since we set tokenize=False https://hf.co/docs/transformers/main/chat_templating
    assert isinstance(prepared_prompt, str)  # noqa: S101

    # Avoid some unnecessary warnings in the generation pipeline call.
    # Compare against None explicitly: 0 is a valid token id and must not be replaced by the EOS id.
    # mypy doesn't know tokenizer is set in warm_up
    pad_token_id = generation_kwargs.get("pad_token_id", tokenizer.pad_token_id)  # type: ignore[union-attr]
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id  # type: ignore[union-attr]
    generation_kwargs["pad_token_id"] = pad_token_id

    return {
        "prepared_prompt": prepared_prompt,
        "tokenizer": tokenizer,
        "generation_kwargs": generation_kwargs,
        "tools": flat_tools,
        "stop_words": stop_words,
    }
|
|
627
|
+
|
|
628
|
+
def _convert_hf_output_to_chat_messages(
    self,
    *,
    hf_pipeline_output: list[dict[str, Any]],
    prepared_prompt: str,
    tokenizer: Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"],
    generation_kwargs: dict[str, Any],
    stop_words: list[str] | None,
    tools: list[Tool] | Toolset | None = None,
) -> list[ChatMessage]:
    """
    Builds one chat message per entry of the HuggingFace pipeline output.

    :param hf_pipeline_output: The output from the HuggingFace pipeline. The `generated_text` value of
        each entry is used as the reply text; entries without that key yield an empty reply.
    :param prepared_prompt: The prompt used for generation.
    :param tokenizer: The tokenizer used for generation.
    :param generation_kwargs: The generation parameters.
    :param stop_words: Stop words to strip from the replies. Every occurrence is removed, and trailing
        whitespace is trimmed after each removal. When empty or None the replies are left untouched.
    :param tools: A list of `Tool` objects or a `Toolset` instance. Only its truthiness is used here, to
        decide whether tool calls should be parsed from the replies.
    :returns: The messages produced by `create_message`, in the same order as `hf_pipeline_output`.
    """
    # Phase 1: extract the raw text of every reply and strip the stop words from it.
    cleaned_texts = []
    for entry in hf_pipeline_output:
        text = entry.get("generated_text", "")
        for word in stop_words or []:
            # Stop words are applied in order; the rstrip happens after each one, not once at the end.
            text = text.replace(word, "").rstrip()
        cleaned_texts.append(text)

    # Phase 2: wrap every cleaned text in a chat message.
    chat_messages = []
    for position, text in enumerate(cleaned_texts):
        chat_messages.append(
            self.create_message(
                text=text,
                index=position,
                tokenizer=tokenizer,
                prompt=prepared_prompt,
                generation_kwargs=generation_kwargs,
                parse_tool_calls=bool(tools),
            )
        )
    return chat_messages
|
|
File without changes
|