transformers-haystack 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24):
  1. haystack_integrations/components/classifiers/py.typed +0 -0
  2. haystack_integrations/components/classifiers/transformers/__init__.py +6 -0
  3. haystack_integrations/components/classifiers/transformers/zero_shot_document_classifier.py +247 -0
  4. haystack_integrations/components/common/py.typed +0 -0
  5. haystack_integrations/components/common/transformers/__init__.py +3 -0
  6. haystack_integrations/components/common/transformers/utils.py +234 -0
  7. haystack_integrations/components/extractors/py.typed +0 -0
  8. haystack_integrations/components/extractors/transformers/__init__.py +6 -0
  9. haystack_integrations/components/extractors/transformers/named_entity_extractor.py +262 -0
  10. haystack_integrations/components/generators/py.typed +0 -0
  11. haystack_integrations/components/generators/transformers/__init__.py +6 -0
  12. haystack_integrations/components/generators/transformers/chat/__init__.py +3 -0
  13. haystack_integrations/components/generators/transformers/chat/chat_generator.py +666 -0
  14. haystack_integrations/components/readers/py.typed +0 -0
  15. haystack_integrations/components/readers/transformers/__init__.py +6 -0
  16. haystack_integrations/components/readers/transformers/extractive_reader.py +662 -0
  17. haystack_integrations/components/routers/py.typed +0 -0
  18. haystack_integrations/components/routers/transformers/__init__.py +7 -0
  19. haystack_integrations/components/routers/transformers/text_router.py +196 -0
  20. haystack_integrations/components/routers/transformers/zero_shot_text_router.py +205 -0
  21. transformers_haystack-0.1.0.dist-info/METADATA +38 -0
  22. transformers_haystack-0.1.0.dist-info/RECORD +24 -0
  23. transformers_haystack-0.1.0.dist-info/WHEEL +4 -0
  24. transformers_haystack-0.1.0.dist-info/licenses/LICENSE.txt +201 -0
@@ -0,0 +1,666 @@
1
+ # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
+ #
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ import asyncio
6
+ import json
7
+ import re
8
+ import sys
9
+ from collections.abc import AsyncIterator, Callable
10
+ from concurrent.futures import ThreadPoolExecutor
11
+ from contextlib import asynccontextmanager, suppress
12
+ from typing import Any, Literal, Union
13
+
14
+ from haystack import component, default_from_dict, default_to_dict, logging
15
+ from haystack.components.generators.utils import _normalize_messages
16
+ from haystack.dataclasses import ChatMessage, ComponentInfo, StreamingCallbackT, ToolCall
17
+ from haystack.dataclasses.streaming_chunk import select_streaming_callback
18
+ from haystack.tools import (
19
+ Tool,
20
+ Toolset,
21
+ ToolsType,
22
+ _check_duplicate_tool_names,
23
+ deserialize_tools_or_toolset_inplace,
24
+ flatten_tools_or_toolsets,
25
+ serialize_tools_or_toolset,
26
+ )
27
+ from haystack.tools.utils import warm_up_tools
28
+ from haystack.utils import ComponentDevice, Secret, deserialize_callable, serialize_callable
29
+ from haystack.utils.hf import convert_message_to_hf_format, deserialize_hf_model_kwargs, serialize_hf_model_kwargs
30
+ from huggingface_hub import model_info
31
+ from packaging.version import Version
32
+
33
+ import transformers
34
+ from haystack_integrations.components.common.transformers.utils import (
35
+ _AsyncHFTokenStreamingHandler,
36
+ _HFTokenStreamingHandler,
37
+ _StopWordsCriteria,
38
+ )
39
+ from transformers import Pipeline as HfPipeline
40
+ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, StoppingCriteriaList, pipeline
41
+
42
logger = logging.getLogger(__name__)

# Hugging Face pipeline tasks this chat generator can drive.
PIPELINE_SUPPORTED_TASKS = ["text-generation", "text2text-generation", "image-text-to-text"]

# Matches one tool call in model output, optionally preceded by a ``<tool_call>`` tag.
# Capture groups (use with re.DOTALL so the JSON may span several lines):
#   1 / 2 -> tool name / arguments object for the form {"name": ..., "arguments": {...}}
#   3 / 4 -> tool name / arguments object for the form {"function": {"name": ..., "arguments": {...}}}
# The arguments object may be empty (``{}``) so parameterless tools can be called. It must not contain
# nested braces, because ``[^}]*`` stops at the first closing brace.
DEFAULT_TOOL_PATTERN = (
    r"(?:<tool_call>)?"
    r'(?:\s*\{.*?"name"\s*:\s*"([^"]+)".*?"arguments"\s*:\s*(\{[^}]*\}).*?\}'
    r'|\{.*?"function"\s*:\s*\{.*?"name"\s*:\s*"([^"]+)".*?"arguments"\s*:\s*(\{[^}]*\}).*?\})'
)
51
+
52
+
53
def default_tool_parser(text: str) -> list[ToolCall] | None:
    """
    Extract a single tool call from raw model output using ``DEFAULT_TOOL_PATTERN``.

    Only the first match in ``text`` is considered.

    :param text: Model output to scan for a tool call.
    :returns: A one-element list holding the parsed ToolCall. Returns None if nothing matches, if the
        pattern fails to compile, or if the captured arguments are not valid JSON.
    """
    try:
        found = re.search(DEFAULT_TOOL_PATTERN, text, re.DOTALL)
    except re.error:
        logger.warning("Invalid regex pattern for tool parsing: {pattern}", pattern=DEFAULT_TOOL_PATTERN)
        return None

    if found is None:
        return None

    # Groups 1/2 come from the first alternative of the pattern, groups 3/4 from the second;
    # only one alternative can match, so exactly one pair is populated.
    first_name, first_args, second_name, second_args = found.groups()
    tool_name = first_name or second_name
    raw_arguments = first_args or second_args

    try:
        parsed_arguments = json.loads(raw_arguments)
    except json.JSONDecodeError:
        logger.warning("Failed to parse tool call arguments: {args_str}", args_str=raw_arguments)
        return None
    return [ToolCall(tool_name=tool_name, arguments=parsed_arguments)]
80
+
81
+
82
@component
class TransformersChatGenerator:
    """
    Generates chat responses using models from Hugging Face that run locally.

    Use this component with chat-based models,
    such as `Qwen/Qwen3-0.6B` or `meta-llama/Llama-2-7b-chat-hf`.
    LLMs running locally may need powerful hardware.

    ### Usage example
    ```python
    from haystack.dataclasses import ChatMessage

    from haystack_integrations.components.generators.transformers import TransformersChatGenerator

    generator = TransformersChatGenerator(model="Qwen/Qwen3-0.6B")
    messages = [ChatMessage.from_user("What's Natural Language Processing? Be brief.")]
    print(generator.run(messages))
    ```

    Example output (generated text and token counts vary by model and run):
    ```
    {'replies':
        [ChatMessage(_role=<ChatRole.ASSISTANT: 'assistant'>, _content=[TextContent(text=
        "Natural Language Processing (NLP) is a subfield of artificial intelligence that deals
        with the interaction between computers and human language. It enables computers to understand, interpret, and
        generate human language in a valuable way. NLP involves various techniques such as speech recognition, text
        analysis, sentiment analysis, and machine translation. The ultimate goal is to make it easier for computers to
        process and derive meaning from human language, improving communication between humans and machines.")],
        _name=None,
        _meta={'finish_reason': 'stop', 'index': 0, 'model':
        'Qwen/Qwen3-0.6B',
        'usage': {'completion_tokens': 90, 'prompt_tokens': 19, 'total_tokens': 109}})
        ]
    }
    ```
    """

    def __init__(
        self,
        model: str = "Qwen/Qwen3-0.6B",
        task: Literal["text-generation", "text2text-generation", "image-text-to-text"] | None = None,
        device: ComponentDevice | None = None,
        token: Secret | None = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
        chat_template: str | None = None,
        generation_kwargs: dict[str, Any] | None = None,
        huggingface_pipeline_kwargs: dict[str, Any] | None = None,
        stop_words: list[str] | None = None,
        streaming_callback: StreamingCallbackT | None = None,
        tools: ToolsType | None = None,
        tool_parsing_function: Callable[[str], list[ToolCall] | None] | None = None,
        async_executor: ThreadPoolExecutor | None = None,
        *,
        enable_thinking: bool = False,
    ) -> None:
        """
        Initializes the TransformersChatGenerator component.

        :param model: The Hugging Face text generation model name or path,
            for example, `mistralai/Mistral-7B-Instruct-v0.2` or `TheBloke/OpenHermes-2.5-Mistral-7B-16k-AWQ`.
            The model must be a chat model supporting the ChatML messaging
            format.
            If the model is specified in `huggingface_pipeline_kwargs`, this parameter is ignored.
        :param task: The task for the Hugging Face pipeline. Possible options:
            - `text-generation`: Supported by decoder models, like GPT.
            - `text2text-generation`: Deprecated as of Transformers v5; use `text-generation` instead.
                Previously supported by encoder-decoder models such as T5.
            - `image-text-to-text`: Supported by vision-language models.
            If the task is specified in `huggingface_pipeline_kwargs`, this parameter is ignored.
            If not specified, the component calls the Hugging Face API to infer the task from the model name.
        :param device: The device for loading the model. If `None`, automatically selects the default device.
            If a device or device map is specified in `huggingface_pipeline_kwargs`, it overrides this parameter.
        :param token: The token to use as HTTP bearer authorization for remote files.
            If the token is specified in `huggingface_pipeline_kwargs`, this parameter is ignored.
        :param chat_template: Specifies an optional Jinja template for formatting chat
            messages. Most high-quality chat models have their own templates, but for models without this
            feature or if you prefer a custom template, use this parameter.
        :param generation_kwargs: A dictionary with keyword arguments to customize text generation.
            Some examples: `max_length`, `max_new_tokens`, `temperature`, `top_k`, `top_p`.
            See Hugging Face's documentation for more information:
            - - [customize-text-generation](https://huggingface.co/docs/transformers/main/en/generation_strategies#customize-text-generation)
            - - [GenerationConfig](https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationConfig)
            The only `generation_kwargs` set by default is `max_new_tokens`, which is set to 512 tokens.
            The dictionary is copied; the caller's object is not modified.
        :param huggingface_pipeline_kwargs: Dictionary with keyword arguments to initialize the
            Hugging Face pipeline for text generation.
            These keyword arguments provide fine-grained control over the Hugging Face pipeline.
            In case of duplication, these kwargs override `model`, `task`, `device`, and `token` init parameters.
            For kwargs, see [Hugging Face documentation](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.pipeline.task).
            In this dictionary, you can also include `model_kwargs` to specify the kwargs for [model initialization](https://huggingface.co/docs/transformers/en/main_classes/model#transformers.PreTrainedModel.from_pretrained)
            The dictionary is copied; the caller's object is not modified.
        :param stop_words: A list of stop words. If the model generates a stop word, the generation stops.
            If you provide this parameter, don't specify the `stopping_criteria` in `generation_kwargs`.
            For some chat models, the output includes both the new text and the original prompt.
            In these cases, make sure your prompt has no stop words.
        :param streaming_callback: An optional callable for handling streaming responses.
        :param tools: A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
        :param tool_parsing_function:
            A callable that takes a string and returns a list of ToolCall objects or None.
            If None, the default_tool_parser will be used which extracts tool calls using a predefined pattern.
        :param async_executor:
            Optional ThreadPoolExecutor to use for async calls. If not provided, a single-threaded executor will be
            initialized and used.
        :param enable_thinking:
            Whether to enable thinking mode in the chat template for thinking-capable models.
            When enabled, the model generates intermediate reasoning before the final response. Defaults to False.
        :raises ValueError: If tools and a streaming callback are both given, if the task is unsupported,
            if `text2text-generation` is requested with transformers v5+, or if both `stop_words` and
            `generation_kwargs["stopping_criteria"]` are given.
        """
        if tools and streaming_callback is not None:
            msg = "Using tools and streaming at the same time is not supported. Please choose one."
            raise ValueError(msg)
        _check_duplicate_tool_names(flatten_tools_or_toolsets(tools))

        # Copy the caller's dictionaries: the defaults filled in below must not leak back into
        # objects the caller owns (for example a kwargs dict shared by several generators).
        huggingface_pipeline_kwargs = dict(huggingface_pipeline_kwargs or {})
        generation_kwargs = dict(generation_kwargs or {})

        self.token = token
        token = token.resolve_value() if token else None

        # check if the huggingface_pipeline_kwargs contain the essential parameters
        # otherwise, populate them with values from other init parameters
        huggingface_pipeline_kwargs.setdefault("model", model)
        huggingface_pipeline_kwargs.setdefault("token", token)

        device = ComponentDevice.resolve_device(device)
        device.update_hf_kwargs(huggingface_pipeline_kwargs, overwrite=False)

        # task identification and validation
        if task is None:
            if "task" in huggingface_pipeline_kwargs:
                task = huggingface_pipeline_kwargs["task"]
            elif isinstance(huggingface_pipeline_kwargs["model"], str):
                task = model_info(
                    huggingface_pipeline_kwargs["model"], token=huggingface_pipeline_kwargs["token"]
                ).pipeline_tag  # type: ignore[assignment] # we'll check below if task is in supported tasks

        if task not in PIPELINE_SUPPORTED_TASKS:
            msg = f"Task '{task}' is not supported. The supported tasks are: {', '.join(PIPELINE_SUPPORTED_TASKS)}."
            raise ValueError(msg)
        if task == "text2text-generation" and Version(transformers.__version__) >= Version("5.0.0"):
            msg = "Task 'text2text-generation' is not supported with transformers v5 or higher."
            raise ValueError(msg)
        huggingface_pipeline_kwargs["task"] = task

        # if not specified, set return_full_text to False for text-generation
        # only generated text is returned (excluding prompt)
        if task == "text-generation":
            generation_kwargs.setdefault("return_full_text", False)

        if stop_words and "stopping_criteria" in generation_kwargs:
            msg = (
                "Found both the `stop_words` init parameter and the `stopping_criteria` key in `generation_kwargs`. "
                "Please specify only one of them."
            )
            raise ValueError(msg)
        generation_kwargs.setdefault("max_new_tokens", 512)
        # Build a new list rather than extending a `stop_sequences` list owned by the caller.
        generation_kwargs["stop_sequences"] = [*(generation_kwargs.get("stop_sequences") or []), *(stop_words or [])]

        self.tool_parsing_function = tool_parsing_function or default_tool_parser
        self.huggingface_pipeline_kwargs = huggingface_pipeline_kwargs
        self.generation_kwargs = generation_kwargs
        self.chat_template = chat_template
        self.streaming_callback = streaming_callback
        self.pipeline: HfPipeline | None = None
        self.tools = tools
        self.enable_thinking = enable_thinking

        self._owns_executor = async_executor is None
        self.executor = (
            ThreadPoolExecutor(thread_name_prefix=f"async-TransformersChatGenerator-executor-{id(self)}", max_workers=1)
            if async_executor is None
            else async_executor
        )
        self._is_warmed_up = False

    def __del__(self) -> None:
        """
        Cleanup when the instance is being destroyed.

        Shuts down the executor only if this instance created it. The `hasattr` checks cover the case
        where `__init__` raised before those attributes were set.
        """
        if hasattr(self, "_owns_executor") and self._owns_executor and hasattr(self, "executor"):
            self.executor.shutdown(wait=True)

    def shutdown(self) -> None:
        """
        Explicitly shutdown the executor if we own it.
        """
        if self._owns_executor:
            self.executor.shutdown(wait=True)

    def _get_telemetry_data(self) -> dict[str, Any]:
        """
        Data that is sent to Posthog for usage analytics.
        """
        if isinstance(self.huggingface_pipeline_kwargs["model"], str):
            return {"model": self.huggingface_pipeline_kwargs["model"]}
        return {"model": f"[object of type {type(self.huggingface_pipeline_kwargs['model'])}]"}

    def warm_up(self) -> None:
        """
        Initializes the component and warms up tools if provided.

        Idempotent: after the first successful call, later calls return immediately.
        """
        if self._is_warmed_up:
            return

        # Initialize the pipeline
        if self.pipeline is None:
            self.pipeline = pipeline(**self.huggingface_pipeline_kwargs)

        # Warm up tools
        if self.tools:
            warm_up_tools(self.tools)

        self._is_warmed_up = True

    @staticmethod
    def _copy_nested_dicts(kwargs: dict[str, Any]) -> dict[str, Any]:
        """
        Copy `kwargs` and every dictionary nested inside it; non-dict values are shared, not copied.

        Used so in-place serialization helpers cannot mutate the component's own configuration.
        """
        return {
            key: TransformersChatGenerator._copy_nested_dicts(value) if isinstance(value, dict) else value
            for key, value in kwargs.items()
        }

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the component to a dictionary.

        The resolved token is removed from the serialized pipeline kwargs; it is represented by the
        `token` Secret instead.

        :returns:
            Dictionary with serialized data.
        """
        callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
        serialization_dict = default_to_dict(
            self,
            # Serialize a copy: the token removal and dtype conversion below work in place and would
            # otherwise strip the token from (and stringify dtypes in) the kwargs warm_up() later uses.
            huggingface_pipeline_kwargs=self._copy_nested_dicts(self.huggingface_pipeline_kwargs),
            generation_kwargs=self.generation_kwargs,
            streaming_callback=callback_name,
            token=self.token,
            chat_template=self.chat_template,
            tools=serialize_tools_or_toolset(self.tools),
            tool_parsing_function=serialize_callable(self.tool_parsing_function),
            enable_thinking=self.enable_thinking,
        )

        huggingface_pipeline_kwargs = serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"]
        huggingface_pipeline_kwargs.pop("token", None)

        serialize_hf_model_kwargs(huggingface_pipeline_kwargs)
        return serialization_dict

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "TransformersChatGenerator":
        """
        Deserializes the component from a dictionary.

        :param data:
            The dictionary to deserialize from. Its `init_parameters` entry is modified in place.
        :returns:
            The deserialized component.
        """
        deserialize_tools_or_toolset_inplace(data["init_parameters"], key="tools")
        init_params = data.get("init_parameters", {})
        serialized_callback_handler = init_params.get("streaming_callback")
        if serialized_callback_handler:
            data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)

        tool_parsing_function = init_params.get("tool_parsing_function")
        if tool_parsing_function:
            init_params["tool_parsing_function"] = deserialize_callable(tool_parsing_function)

        huggingface_pipeline_kwargs = init_params.get("huggingface_pipeline_kwargs", {})
        deserialize_hf_model_kwargs(huggingface_pipeline_kwargs)
        return default_from_dict(cls, data)

    @component.output_types(replies=list[ChatMessage])
    def run(
        self,
        messages: list[ChatMessage] | str,
        generation_kwargs: dict[str, Any] | None = None,
        streaming_callback: StreamingCallbackT | None = None,
        tools: ToolsType | None = None,
    ) -> dict[str, list[ChatMessage]]:
        """
        Invoke text generation inference based on the provided messages and generation parameters.

        :param messages: A list of ChatMessage objects representing the input messages. If a string is provided,
            it is converted to a list containing a ChatMessage with user role.
        :param generation_kwargs: Additional keyword arguments for text generation.
        :param streaming_callback: An optional callable for handling streaming responses.
            If not given, the callback passed at initialization (if any) is used.
        :param tools: A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
            If set, it will override the `tools` parameter provided during initialization.
        :returns: A dictionary with the following keys:
            - `replies`: A list containing the generated responses as ChatMessage instances.
        :raises ValueError: If tools are used together with a streaming callback (runtime or init-time).
        """
        if self.pipeline is None:
            self.warm_up()

        messages = _normalize_messages(messages)

        # Resolve the effective callback first, so the tools/streaming check and the
        # num_return_sequences clamp in _prepare_inputs also apply to an init-time callback.
        streaming_callback = select_streaming_callback(
            self.streaming_callback, streaming_callback, requires_async=False
        )

        prepared_inputs = self._prepare_inputs(
            messages=messages, generation_kwargs=generation_kwargs, streaming_callback=streaming_callback, tools=tools
        )

        if streaming_callback:
            # streamer parameter hooks into HF streaming, _HFTokenStreamingHandler is an adapter to our streaming
            prepared_inputs["generation_kwargs"]["streamer"] = _HFTokenStreamingHandler(
                tokenizer=prepared_inputs["tokenizer"],
                stream_handler=streaming_callback,
                stop_words=prepared_inputs["stop_words"],
                component_info=ComponentInfo.from_component(self),
            )

        # warm_up() above guarantees the pipeline has been created
        assert self.pipeline is not None  # noqa: S101
        # Generate responses
        output = self.pipeline(prepared_inputs["prepared_prompt"], **prepared_inputs["generation_kwargs"])

        chat_messages = self._convert_hf_output_to_chat_messages(hf_pipeline_output=output, **prepared_inputs)

        return {"replies": chat_messages}

    def create_message(
        self,
        text: str,
        index: int,
        tokenizer: Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"],
        prompt: str,
        generation_kwargs: dict[str, Any],
        parse_tool_calls: bool = False,
    ) -> ChatMessage:
        """
        Create a ChatMessage instance from the provided text, populated with metadata.

        The finish reason is `length` when the re-encoded completion reaches `max_new_tokens`.
        Otherwise it is `tool_calls` if a tool call was parsed, and `stop` in all remaining cases.

        :param text: The generated text.
        :param index: The index of the generated text.
        :param tokenizer: The tokenizer used for generation.
        :param prompt: The prompt used for generation.
        :param generation_kwargs: The generation parameters.
        :param parse_tool_calls: Whether to attempt parsing tool calls from the text.
        :returns: A ChatMessage instance.
        """
        completion_tokens = len(tokenizer.encode(text, add_special_tokens=False))
        prompt_token_count = len(tokenizer.encode(prompt, add_special_tokens=False))
        total_tokens = prompt_token_count + completion_tokens

        tool_calls = self.tool_parsing_function(text) if parse_tool_calls else None

        # Determine finish reason based on context
        if completion_tokens >= generation_kwargs.get("max_new_tokens", sys.maxsize):
            finish_reason = "length"
        elif tool_calls:
            finish_reason = "tool_calls"
        else:
            finish_reason = "stop"

        meta = {
            "finish_reason": finish_reason,
            "index": index,
            "model": self.huggingface_pipeline_kwargs["model"],
            "usage": {
                "completion_tokens": completion_tokens,
                "prompt_tokens": prompt_token_count,
                "total_tokens": total_tokens,
            },
        }

        # If tool calls are detected, don't include the text content since it contains the raw tool call format
        return ChatMessage.from_assistant(tool_calls=tool_calls, text=None if tool_calls else text, meta=meta)

    @staticmethod
    def _validate_stop_words(stop_words: list[str] | None) -> list[str] | None:
        """
        Validates the provided stop words.

        :param stop_words: A list of stop words to validate.
        :return: The stop words without duplicates, in first-seen order, or None if any of them is not a
            string. The order is kept deterministic because stop words are later removed from replies
            one after another.
        """
        if stop_words and not all(isinstance(word, str) for word in stop_words):
            logger.warning(
                "Invalid stop words provided. Stop words must be specified as a list of strings. "
                "Ignoring stop words: {stop_words}",
                stop_words=stop_words,
            )
            return None

        # dict.fromkeys de-duplicates while preserving order (a set would randomize it per process).
        return list(dict.fromkeys(stop_words or []))

    @component.output_types(replies=list[ChatMessage])
    async def run_async(
        self,
        messages: list[ChatMessage] | str,
        generation_kwargs: dict[str, Any] | None = None,
        streaming_callback: StreamingCallbackT | None = None,
        tools: ToolsType | None = None,
    ) -> dict[str, list[ChatMessage]]:
        """
        Asynchronously invokes text generation inference based on the provided messages and generation parameters.

        This is the asynchronous version of the `run` method. It has the same parameters
        and return values but can be used with `await` in an async code.
        The blocking pipeline call runs in `self.executor`.

        :param messages: A list of ChatMessage objects representing the input messages.
        :param generation_kwargs: Additional keyword arguments for text generation.
        :param streaming_callback: An optional callable for handling streaming responses.
            If not given, the callback passed at initialization (if any) is used.
        :param tools: A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
            If set, it will override the `tools` parameter provided during initialization.
        :returns: A dictionary with the following keys:
            - `replies`: A list containing the generated responses as ChatMessage instances.
        :raises ValueError: If tools are used together with a streaming callback (runtime or init-time).
        """
        if self.pipeline is None:
            self.warm_up()

        messages = _normalize_messages(messages)

        # Resolve the effective callback first, so the tools/streaming check and the
        # num_return_sequences clamp in _prepare_inputs also apply to an init-time callback.
        streaming_callback = select_streaming_callback(self.streaming_callback, streaming_callback, requires_async=True)

        prepared_inputs = self._prepare_inputs(
            messages=messages, generation_kwargs=generation_kwargs, streaming_callback=streaming_callback, tools=tools
        )

        def _generate() -> Any:
            # Runs in the executor thread; warm_up() above guarantees the pipeline exists.
            return self.pipeline(  # type: ignore[misc]
                prepared_inputs["prepared_prompt"], **prepared_inputs["generation_kwargs"]
            )

        loop = asyncio.get_running_loop()
        if streaming_callback:
            async_handler = _AsyncHFTokenStreamingHandler(
                tokenizer=prepared_inputs["tokenizer"],
                stream_handler=streaming_callback,
                stop_words=prepared_inputs["stop_words"],
                component_info=ComponentInfo.from_component(self),
            )
            prepared_inputs["generation_kwargs"]["streamer"] = async_handler

            # Use async context manager for proper resource cleanup
            async with self._manage_queue_processor(async_handler):
                output = await loop.run_in_executor(self.executor, _generate)
        else:
            output = await loop.run_in_executor(self.executor, _generate)

        chat_messages = self._convert_hf_output_to_chat_messages(hf_pipeline_output=output, **prepared_inputs)
        return {"replies": chat_messages}

    @asynccontextmanager
    async def _manage_queue_processor(
        self, async_handler: "_AsyncHFTokenStreamingHandler"
    ) -> AsyncIterator["asyncio.Task[None]"]:
        """
        Run `async_handler.process_queue()` as a task for the duration of the context.

        On exit, the task gets up to 0.1 seconds to finish on its own. If it does not finish in time,
        it is cancelled and awaited.
        """
        queue_processor = asyncio.create_task(async_handler.process_queue())
        try:
            yield queue_processor
        finally:
            # Ensure the queue processor is cleaned up properly
            try:
                await asyncio.wait_for(queue_processor, timeout=0.1)
            except asyncio.TimeoutError:
                queue_processor.cancel()
                with suppress(asyncio.CancelledError):
                    await queue_processor

    def _prepare_inputs(
        self,
        messages: list[ChatMessage],
        generation_kwargs: dict[str, Any] | None = None,
        streaming_callback: StreamingCallbackT | None = None,
        tools: ToolsType | None = None,
    ) -> dict[str, Any]:
        """
        Prepares the inputs for the Hugging Face pipeline.

        :param messages: A list of ChatMessage objects representing the input messages.
        :param generation_kwargs: Additional keyword arguments for text generation.
        :param streaming_callback: The streaming callback that will be used for this call, if any.
        :param tools: A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
        :returns: A dictionary with the keys `prepared_prompt`, `tokenizer`, `generation_kwargs`,
            `tools` (flattened list), and `stop_words`.
        :raises ValueError: If both tools and streaming_callback are provided.
        """
        tools = tools or self.tools
        if tools and streaming_callback is not None:
            msg = "Using tools and streaming at the same time is not supported. Please choose one."
            raise ValueError(msg)
        flat_tools = flatten_tools_or_toolsets(tools)
        _check_duplicate_tool_names(flat_tools)

        # mypy doesn't know this is set in warm_up
        tokenizer = self.pipeline.tokenizer  # type: ignore[union-attr]

        # Merge into a new dict so the per-call changes below never touch self.generation_kwargs
        generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

        # If streaming_callback is provided, ensure that num_return_sequences is set to 1
        if streaming_callback:
            num_responses = generation_kwargs.get("num_return_sequences", 1)
            if num_responses > 1:
                msg = (
                    "Streaming is enabled, but the number of responses is set to {num_responses}. "
                    "Streaming is only supported for single response generation. "
                    "Setting the number of responses to 1."
                )
                logger.warning(msg, num_responses=num_responses)
                generation_kwargs["num_return_sequences"] = 1

        # `or []` tolerates keys explicitly set to None
        stop_words = (generation_kwargs.pop("stop_words", None) or []) + (
            generation_kwargs.pop("stop_sequences", None) or []
        )
        stop_words = self._validate_stop_words(stop_words)

        # Set up stop words criteria if stop words exist
        stop_words_criteria = (
            _StopWordsCriteria(
                tokenizer,  # type: ignore[arg-type]
                stop_words,
                self.pipeline.device,  # type: ignore[union-attr]
            )
            if stop_words
            else None
        )
        if stop_words_criteria:
            generation_kwargs["stopping_criteria"] = StoppingCriteriaList([stop_words_criteria])

        # convert messages to HF format
        hf_messages = [convert_message_to_hf_format(message) for message in messages]

        # mypy doesn't know tokenizer is set in warm_up
        prepared_prompt = tokenizer.apply_chat_template(  # type: ignore[union-attr]
            hf_messages,
            tokenize=False,
            chat_template=self.chat_template,
            add_generation_prompt=True,
            tools=[tc.tool_spec for tc in flat_tools] if flat_tools else None,
            enable_thinking=self.enable_thinking,
        )
        # prepared_prompt is a string since we set tokenize=False https://hf.co/docs/transformers/main/chat_templating
        assert isinstance(prepared_prompt, str)  # noqa: S101

        # Avoid some unnecessary warnings in the generation pipeline call.
        # Fall back to EOS only when no pad id is available: 0 is a valid pad token id and must be kept.
        pad_token_id = generation_kwargs.get("pad_token_id", tokenizer.pad_token_id)  # type: ignore[union-attr]
        if pad_token_id is None:
            pad_token_id = tokenizer.eos_token_id  # type: ignore[union-attr]
        generation_kwargs["pad_token_id"] = pad_token_id

        return {
            "prepared_prompt": prepared_prompt,
            "tokenizer": tokenizer,
            "generation_kwargs": generation_kwargs,
            "tools": flat_tools,
            "stop_words": stop_words,
        }

    def _convert_hf_output_to_chat_messages(
        self,
        *,
        hf_pipeline_output: list[dict[str, Any]],
        prepared_prompt: str,
        tokenizer: Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"],
        generation_kwargs: dict[str, Any],
        stop_words: list[str] | None,
        tools: list[Tool] | Toolset | None = None,
    ) -> list[ChatMessage]:
        """
        Converts the HuggingFace pipeline output into a List of ChatMessages

        :param hf_pipeline_output: The output from the HuggingFace pipeline.
        :param prepared_prompt: The prompt used for generation.
        :param tokenizer: The tokenizer used for generation.
        :param generation_kwargs: The generation parameters.
        :param stop_words: A list of stop words to remove from the replies.
        :param tools: A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
            This parameter can accept either a list of `Tool` objects or a `Toolset` instance.
            Tool calls are only parsed from the replies when this is non-empty.
        :returns: One ChatMessage per generated sequence, in pipeline output order.
        """
        replies = [o.get("generated_text", "") for o in hf_pipeline_output]

        # Remove stop words from replies if present (every occurrence, then trailing whitespace)
        if stop_words:
            for stop_word in stop_words:
                replies = [reply.replace(stop_word, "").rstrip() for reply in replies]

        return [
            self.create_message(
                text=reply,
                index=r_index,
                tokenizer=tokenizer,
                prompt=prepared_prompt,
                generation_kwargs=generation_kwargs,
                parse_tool_calls=bool(tools),
            )
            for r_index, reply in enumerate(replies)
        ]
File without changes
@@ -0,0 +1,6 @@
1
+ # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
+ #
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ from .extractive_reader import TransformersExtractiveReader
5
+
6
+ # Public API of this subpackage: the extractive reader component imported above.
+ __all__ = ["TransformersExtractiveReader"]