huggingface-api-haystack 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,738 @@
1
+ # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
+ #
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ import json
6
+ from collections.abc import AsyncIterable, Iterable
7
+ from datetime import datetime, timezone
8
+ from typing import Any, Union
9
+
10
+ from haystack import component, default_from_dict, default_to_dict, logging
11
+ from haystack.components.generators.utils import _convert_streaming_chunks_to_chat_message, _normalize_messages
12
+ from haystack.dataclasses import (
13
+ AsyncStreamingCallbackT,
14
+ ChatMessage,
15
+ ComponentInfo,
16
+ ReasoningContent,
17
+ StreamingCallbackT,
18
+ StreamingChunk,
19
+ SyncStreamingCallbackT,
20
+ ToolCall,
21
+ select_streaming_callback,
22
+ )
23
+ from haystack.dataclasses.streaming_chunk import FinishReason
24
+ from haystack.tools import (
25
+ ToolsType,
26
+ _check_duplicate_tool_names,
27
+ deserialize_tools_or_toolset_inplace,
28
+ flatten_tools_or_toolsets,
29
+ serialize_tools_or_toolset,
30
+ warm_up_tools,
31
+ )
32
+ from haystack.utils import Secret, deserialize_callable, serialize_callable
33
+ from haystack.utils.hf import convert_message_to_hf_format
34
+ from haystack.utils.url_validation import is_valid_http_url
35
+ from huggingface_hub import (
36
+ AsyncInferenceClient,
37
+ ChatCompletionInputFunctionDefinition,
38
+ ChatCompletionInputStreamOptions,
39
+ ChatCompletionInputTool,
40
+ ChatCompletionOutput,
41
+ ChatCompletionOutputComplete,
42
+ ChatCompletionOutputToolCall,
43
+ ChatCompletionStreamOutput,
44
+ ChatCompletionStreamOutputChoice,
45
+ InferenceClient,
46
+ )
47
+
48
+ from haystack_integrations.components.common.huggingface_api.utils import (
49
+ HFGenerationAPIType,
50
+ HFModelType,
51
+ _check_valid_model,
52
+ )
53
+
54
+ logger = logging.getLogger(__name__)
55
+
56
+
57
def _convert_hfapi_tool_calls(hfapi_tool_calls: list["ChatCompletionOutputToolCall"] | None) -> list[ToolCall]:
    """
    Convert HuggingFace API tool calls to a list of Haystack ToolCall.

    Arguments may arrive either as a dict or as a JSON string. Tool calls whose arguments cannot be turned into a
    JSON object are skipped and a warning is logged. A tool call with an empty argument object (a tool that takes
    no parameters) is kept.

    :param hfapi_tool_calls: The HuggingFace API tool calls to convert. `None` or an empty list yields `[]`.
    :returns: A list of ToolCall objects, in the same order as the input.
    """
    if not hfapi_tool_calls:
        return []

    tool_calls = []
    for hfapi_tc in hfapi_tool_calls:
        hf_arguments = hfapi_tc.function.arguments

        if isinstance(hf_arguments, dict):
            arguments = hf_arguments
        elif isinstance(hf_arguments, str):
            try:
                arguments = json.loads(hf_arguments)
            except json.JSONDecodeError:
                logger.warning(
                    "HuggingFace API returned a malformed JSON string for tool call arguments. This tool call "
                    "will be skipped. Tool call ID: {_id}, Tool name: {_name}, Arguments: {_arguments}",
                    _id=hfapi_tc.id,
                    _name=hfapi_tc.function.name,
                    _arguments=hf_arguments,
                )
                continue
        else:
            logger.warning(
                "HuggingFace API returned tool call arguments of type {_type}. Valid types are dict and str. This tool "
                "call will be skipped. Tool call ID: {_id}, Tool name: {_name}, Arguments: {_arguments}",
                # `_type` must be passed: a missing placeholder key makes the whole message fall back to
                # its un-interpolated form.
                _type=type(hf_arguments).__name__,
                _id=hfapi_tc.id,
                _name=hfapi_tc.function.name,
                _arguments=hf_arguments,
            )
            continue

        # Valid JSON is not necessarily an object (e.g. "null" or "[1, 2]"); ToolCall arguments must be a dict.
        if not isinstance(arguments, dict):
            logger.warning(
                "HuggingFace API returned tool call arguments that are not a JSON object. This tool call will be "
                "skipped. Tool call ID: {_id}, Tool name: {_name}, Arguments: {_arguments}",
                _id=hfapi_tc.id,
                _name=hfapi_tc.function.name,
                _arguments=hf_arguments,
            )
            continue

        # An empty dict is valid here: it is how a call to a parameterless tool is represented.
        tool_calls.append(ToolCall(tool_name=hfapi_tc.function.name, arguments=arguments, id=hfapi_tc.id))

    return tool_calls
100
+
101
+
102
def _extract_reasoning_content(message_or_delta: Any) -> ReasoningContent | None:
    """
    Pull the model's reasoning text, if any, out of a HuggingFace message or streaming delta.

    :param message_or_delta: A HuggingFace API message (non-streaming) or delta (streaming) object. It does not
        need to define a `reasoning` attribute.
    :returns: A ReasoningContent wrapping the reasoning text, or None when the attribute is missing or empty.
    """
    reasoning_text = getattr(message_or_delta, "reasoning", None)
    if not reasoning_text:
        return None
    return ReasoningContent(reasoning_text=reasoning_text)
112
+
113
+
114
+ def _resolve_schema_refs(schema: dict[str, Any]) -> dict[str, Any]:
115
+ """
116
+ Resolve ``$ref`` references in a JSON schema by inlining ``$defs`` definitions.
117
+
118
+ The HuggingFace API does not support ``$defs`` and ``$ref`` in tool parameter schemas.
119
+ This function expands all ``$ref`` pointers and removes the ``$defs`` section.
120
+
121
+ :param schema: A JSON schema dict potentially containing ``$defs`` and ``$ref``.
122
+ :returns: A new schema dict with all references resolved and ``$defs`` removed.
123
+ """
124
+ defs = schema.get("$defs", {})
125
+ if not defs:
126
+ return schema
127
+
128
+ def _resolve(obj: Any, resolving: set[str] | None = None) -> Any:
129
+ if resolving is None:
130
+ resolving = set()
131
+ if isinstance(obj, dict):
132
+ if "$ref" in obj:
133
+ ref_path = obj["$ref"]
134
+ ref_prefix = "#/$defs/"
135
+ if isinstance(ref_path, str) and ref_path.startswith(ref_prefix):
136
+ def_name = ref_path.removeprefix(ref_prefix)
137
+ if def_name in defs and def_name not in resolving:
138
+ return _resolve(defs[def_name], resolving | {def_name})
139
+ return {k: _resolve(v, resolving) for k, v in obj.items() if k != "$defs"}
140
+ if isinstance(obj, list):
141
+ return [_resolve(item, resolving) for item in obj]
142
+ return obj
143
+
144
+ return _resolve(schema)
145
+
146
+
147
def _convert_tools_to_hfapi_tools(tools: ToolsType | None) -> list["ChatCompletionInputTool"] | None:
    """
    Translate Haystack tools into the tool definitions expected by the HuggingFace chat completion API.

    :param tools: A list of Tool and/or Toolset objects, or a single Toolset. May be None.
    :returns: None when `tools` is falsy. Otherwise, one `ChatCompletionInputTool` of type "function" per
        flattened tool, whose parameter schema has `$defs`/`$ref` inlined by `_resolve_schema_refs`.
    """
    if not tools:
        return None

    return [
        ChatCompletionInputTool(
            function=ChatCompletionInputFunctionDefinition(
                name=tool.name,
                description=tool.description,
                parameters=_resolve_schema_refs(tool.parameters),
            ),
            type="function",
        )
        for tool in flatten_tools_or_toolsets(tools)
    ]
165
+
166
+
167
def _map_hf_finish_reason_to_haystack(
    choice: Union["ChatCompletionStreamOutputChoice", "ChatCompletionOutputComplete"],
) -> FinishReason | None:
    """
    Map a HuggingFace finish reason to a Haystack FinishReason literal.

    Uses the full choice object, so tool calls are detected even when the reported finish reason does not
    mention them.

    Text Generation Inference finish reasons (see FinishReason at https://huggingface.github.io/text-generation-inference/):
    - "length": number of generated tokens == `max_new_tokens`
    - "eos_token": the model generated its end of sequence token
    - "stop_sequence": the model generated a text included in `stop_sequences`

    Inference Providers expose an OpenAI-compatible API and may instead report "stop", "length", "tool_calls",
    "function_call" or "content_filter". These are mapped to their Haystack equivalents.

    :param choice: Either a streaming `ChatCompletionStreamOutputChoice`, whose tool calls are read from
        `choice.delta`, or a non-streaming `ChatCompletionOutputComplete`, whose tool calls are read from
        `choice.message`.
    :returns: None if the choice has no finish reason. "tool_calls" if the choice carries tool calls or a tool
        call ID. Otherwise the mapped reason, defaulting to "stop" for unrecognized values.
    """
    if choice.finish_reason is None:
        return None

    message_or_delta = choice.delta if isinstance(choice, ChatCompletionStreamOutputChoice) else choice.message

    # Truthiness instead of `is not None`: an empty `tool_calls` list (sent by some OpenAI-compatible providers
    # on plain text replies) must not be reported as a tool call.
    if message_or_delta.tool_calls or message_or_delta.tool_call_id:
        return "tool_calls"

    mapping: dict[str, FinishReason] = {
        # Text Generation Inference values
        "length": "length",
        "eos_token": "stop",  # EOS token means natural stop
        "stop_sequence": "stop",  # Stop sequence means natural stop
        # OpenAI-compatible values returned by Inference Providers
        "stop": "stop",
        "tool_calls": "tool_calls",
        "function_call": "tool_calls",
        "content_filter": "content_filter",
    }

    return mapping.get(choice.finish_reason, "stop")  # Default to "stop" for unknown reasons
207
+
208
+
209
def _convert_chat_completion_stream_output_to_streaming_chunk(
    chunk: "ChatCompletionStreamOutput",
    previous_chunks: list[StreamingChunk],
    component_info: ComponentInfo | None = None,
) -> StreamingChunk:
    """
    Convert one Hugging Face API ChatCompletionStreamOutput event into a Haystack StreamingChunk.

    :param chunk: One event from `chat_completion(..., stream=True)`.
    :param previous_chunks: Chunks already converted for this response. Only its length is used, to mark the
        first chunk with `start=True`.
    :param component_info: Information about the component emitting the chunk.
    :returns: The converted StreamingChunk. When the event reports token usage, it is stored in `meta["usage"]`.
    """
    received_at = datetime.now(timezone.utc).isoformat()

    usage: dict[str, Any] | None = None
    if chunk.usage:
        usage = {"prompt_tokens": chunk.usage.prompt_tokens, "completion_tokens": chunk.usage.completion_tokens}

    # Choices is empty if include_usage is set to True where the usage information is returned.
    if not chunk.choices:
        return StreamingChunk(
            content="",
            meta={"model": chunk.model, "received_at": received_at, "usage": usage},
            component_info=component_info,
        )

    # n is unused, so the API always returns only one choice
    # the argument is probably allowed for compatibility with OpenAI
    # see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n
    choice = chunk.choices[0]

    meta: dict[str, Any] = {
        "model": chunk.model,
        "received_at": received_at,
        "finish_reason": choice.finish_reason,
    }
    # Some providers attach usage to the last content event instead of sending a separate usage-only event.
    # Keep it so it is not dropped. NOTE(review): this assumes `_convert_streaming_chunks_to_chat_message` reads
    # usage from the chunk meta; confirm.
    if usage is not None:
        meta["usage"] = usage

    return StreamingChunk(
        content=choice.delta.content or "",
        meta=meta,
        component_info=component_info,
        # Index must always be 0 since we don't allow tool calls in streaming mode.
        index=0 if choice.finish_reason is None else None,
        # start is True at the very beginning since first chunk contains role information + first part of the answer.
        start=len(previous_chunks) == 0,
        finish_reason=_map_hf_finish_reason_to_haystack(choice) if choice.finish_reason else None,
        reasoning=_extract_reasoning_content(choice.delta),
    )
252
+
253
+
254
@component
class HuggingFaceAPIChatGenerator:
    """
    Completes chats using Hugging Face APIs.

    HuggingFaceAPIChatGenerator uses the [ChatMessage](https://docs.haystack.deepset.ai/docs/chatmessage)
    format for input and output. Use it to generate text with Hugging Face APIs:
    - [Serverless Inference API (Inference Providers)](https://huggingface.co/docs/inference-providers)
    - [Paid Inference Endpoints](https://huggingface.co/inference-endpoints)
    - [Self-hosted Text Generation Inference](https://github.com/huggingface/text-generation-inference)

    ### Usage examples

    #### With the serverless inference API (Inference Providers) - free tier available

    ```python
    from haystack_integrations.components.generators.huggingface_api import HuggingFaceAPIChatGenerator
    from haystack.dataclasses import ChatMessage
    from haystack.utils import Secret
    from haystack_integrations.components.common.huggingface_api.utils import HFGenerationAPIType

    messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant"),
                ChatMessage.from_user("What's Natural Language Processing?")]

    # the api_type can be expressed using the HFGenerationAPIType enum or as a string
    api_type = HFGenerationAPIType.SERVERLESS_INFERENCE_API
    api_type = "serverless_inference_api" # this is equivalent to the above

    generator = HuggingFaceAPIChatGenerator(api_type=api_type,
                                            api_params={"model": "Qwen/Qwen2.5-7B-Instruct",
                                                        "provider": "together"},
                                            token=Secret.from_token("<your-api-key>"))

    result = generator.run(messages)
    print(result)
    ```

    #### With the serverless inference API (Inference Providers) and text+image input

    ```python
    from haystack_integrations.components.generators.huggingface_api import HuggingFaceAPIChatGenerator
    from haystack.dataclasses import ChatMessage, ImageContent
    from haystack.utils import Secret
    from haystack_integrations.components.common.huggingface_api.utils import HFGenerationAPIType

    # Create an image from file path, URL, or base64
    image = ImageContent.from_file_path("path/to/your/image.jpg")

    # Create a multimodal message with both text and image
    messages = [ChatMessage.from_user(content_parts=["Describe this image in detail", image])]

    generator = HuggingFaceAPIChatGenerator(
        api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
        api_params={
            "model": "Qwen/Qwen2.5-VL-7B-Instruct",  # Vision Language Model
            "provider": "hyperbolic"
        },
        token=Secret.from_token("<your-api-key>")
    )

    result = generator.run(messages)
    print(result)
    ```

    #### With paid inference endpoints

    ```python
    from haystack_integrations.components.generators.huggingface_api import HuggingFaceAPIChatGenerator
    from haystack.dataclasses import ChatMessage
    from haystack.utils import Secret

    messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant"),
                ChatMessage.from_user("What's Natural Language Processing?")]

    generator = HuggingFaceAPIChatGenerator(api_type="inference_endpoints",
                                            api_params={"url": "<your-inference-endpoint-url>"},
                                            token=Secret.from_token("<your-api-key>"))

    result = generator.run(messages)
    print(result)
    ```

    #### With self-hosted text generation inference

    ```python
    from haystack_integrations.components.generators.huggingface_api import HuggingFaceAPIChatGenerator
    from haystack.dataclasses import ChatMessage

    messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant"),
                ChatMessage.from_user("What's Natural Language Processing?")]

    generator = HuggingFaceAPIChatGenerator(api_type="text_generation_inference",
                                            api_params={"url": "http://localhost:8080"})

    result = generator.run(messages)
    print(result)
    ```
    """

    def __init__(
        self,
        api_type: HFGenerationAPIType | str,
        api_params: dict[str, Any],
        token: Secret | None = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
        generation_kwargs: dict[str, Any] | None = None,
        stop_words: list[str] | None = None,
        streaming_callback: StreamingCallbackT | None = None,
        tools: ToolsType | None = None,
    ) -> None:
        """
        Initialize the HuggingFaceAPIChatGenerator instance.

        :param api_type:
            The type of Hugging Face API to use. Available types:
            - `text_generation_inference`: See [TGI](https://github.com/huggingface/text-generation-inference).
            - `inference_endpoints`: See [Inference Endpoints](https://huggingface.co/inference-endpoints).
            - `serverless_inference_api`: See
            [Serverless Inference API - Inference Providers](https://huggingface.co/docs/inference-providers).
        :param api_params:
            A dictionary with the following keys:
            - `model`: Hugging Face model ID. Required when `api_type` is `SERVERLESS_INFERENCE_API`.
            - `provider`: Provider name. Recommended when `api_type` is `SERVERLESS_INFERENCE_API`.
            - `url`: URL of the inference endpoint. Required when `api_type` is `INFERENCE_ENDPOINTS` or
            `TEXT_GENERATION_INFERENCE`.
            - Other parameters specific to the chosen API type, such as `timeout`, `headers`, etc.
            All keys except `model` and `url` are forwarded as keyword arguments to the Hugging Face clients.
        :param token:
            The Hugging Face token to use as HTTP bearer authorization.
            Check your HF token in your [account settings](https://huggingface.co/settings/tokens).
        :param generation_kwargs:
            A dictionary with keyword arguments to customize text generation.
            Some examples: `max_tokens`, `temperature`, `top_p`.
            For details, see [Hugging Face chat_completion documentation](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion).
            `max_tokens` defaults to 512 if not given. The dictionary passed in is not modified.
        :param stop_words:
            An optional list of strings representing the stop words. They are appended to `generation_kwargs["stop"]`.
        :param streaming_callback:
            An optional callable for handling streaming responses.
        :param tools:
            A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
            The chosen model should support tool/function calling, according to the model card.
            Support for tools in the Hugging Face API and TGI is not yet fully refined and you may experience
            unexpected behavior.
        :raises ValueError:
            If the `model`/`url` required by `api_type` is missing, the URL is invalid, `api_type` is unknown,
            or both `tools` and `streaming_callback` are given.
        """
        if isinstance(api_type, str):
            api_type = HFGenerationAPIType.from_str(api_type)

        if api_type == HFGenerationAPIType.SERVERLESS_INFERENCE_API:
            model = api_params.get("model")
            if model is None:
                msg = "To use the Serverless Inference API, you need to specify the `model` parameter in `api_params`."
                raise ValueError(msg)
            _check_valid_model(model, HFModelType.GENERATION, token)
            model_or_url = model
        elif api_type in [HFGenerationAPIType.INFERENCE_ENDPOINTS, HFGenerationAPIType.TEXT_GENERATION_INFERENCE]:
            url = api_params.get("url")
            if url is None:
                msg = (
                    "To use Text Generation Inference or Inference Endpoints, you need to specify the `url` parameter "
                    "in `api_params`."
                )
                raise ValueError(msg)
            if not is_valid_http_url(url):
                msg = f"Invalid URL: {url}"
                raise ValueError(msg)
            model_or_url = url
        else:
            msg = f"Unknown api_type {api_type}"
            raise ValueError(msg)

        if tools and streaming_callback is not None:
            msg = "Using tools and streaming at the same time is not supported. Please choose one."
            raise ValueError(msg)
        _check_duplicate_tool_names(flatten_tools_or_toolsets(tools))

        # Build a brand-new `stop` list. Copying the dict is shallow, so extending the caller's list in place
        # would leak `stop_words` into the caller's `generation_kwargs` (accumulating across instances).
        generation_kwargs = dict(generation_kwargs) if generation_kwargs else {}
        existing_stop = generation_kwargs.get("stop") or []
        if isinstance(existing_stop, str):
            existing_stop = [existing_stop]
        generation_kwargs["stop"] = [*existing_stop, *(stop_words or [])]
        generation_kwargs.setdefault("max_tokens", 512)

        self.api_type = api_type
        self.api_params = api_params
        self.token = token
        self.generation_kwargs = generation_kwargs
        self.streaming_callback = streaming_callback
        self.tools = tools
        self._is_warmed_up = False

        # `model`/`url` is passed positionally; every other entry (e.g. `provider`, `timeout`, `headers`) is
        # forwarded to both clients unchanged.
        client_kwargs: dict[str, Any] = {k: v for k, v in api_params.items() if k not in ("model", "url")}
        resolved_token = token.resolve_value() if token else None
        self._client = InferenceClient(model_or_url, token=resolved_token, **client_kwargs)
        self._async_client = AsyncInferenceClient(model_or_url, token=resolved_token, **client_kwargs)

    def warm_up(self) -> None:
        """
        Warm up the Hugging Face API chat generator.

        This will warm up the tools registered in the chat generator.
        This method is idempotent and will only warm up the tools once.
        """
        if not self._is_warmed_up:
            warm_up_tools(self.tools)
            self._is_warmed_up = True

    def to_dict(self) -> dict[str, Any]:
        """
        Serialize this component to a dictionary.

        :returns:
            A dictionary containing the serialized component.
        """
        callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
        return default_to_dict(
            self,
            api_type=str(self.api_type),
            api_params=self.api_params,
            token=self.token,
            generation_kwargs=self.generation_kwargs,
            streaming_callback=callback_name,
            tools=serialize_tools_or_toolset(self.tools),
        )

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "HuggingFaceAPIChatGenerator":
        """
        Deserialize this component from a dictionary.

        :param data: The dictionary produced by `to_dict`.
        :returns: The deserialized component.
        """
        deserialize_tools_or_toolset_inplace(data["init_parameters"], key="tools")
        init_params = data.get("init_parameters", {})
        serialized_callback_handler = init_params.get("streaming_callback")
        if serialized_callback_handler:
            data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)
        return default_from_dict(cls, data)

    @component.output_types(replies=list[ChatMessage])
    def run(
        self,
        messages: list[ChatMessage] | str,
        generation_kwargs: dict[str, Any] | None = None,
        tools: ToolsType | None = None,
        streaming_callback: StreamingCallbackT | None = None,
    ) -> dict[str, list[ChatMessage]]:
        """
        Invoke the text generation inference based on the provided messages and generation parameters.

        :param messages:
            A list of ChatMessage objects representing the input messages. If a string is provided, it is converted
            to a list containing a ChatMessage with user role.
        :param generation_kwargs:
            Additional keyword arguments for text generation.
        :param tools:
            A list of tools or a Toolset for which the model can prepare calls. If set, it will override
            the `tools` parameter set during component initialization. This parameter can accept either a
            list of `Tool` objects or a `Toolset` instance.
        :param streaming_callback:
            An optional callable for handling streaming responses. If set, it will override the `streaming_callback`
            parameter set during component initialization.
        :returns: A dictionary with the following keys:
            - `replies`: A list containing the generated responses as ChatMessage objects.
        :raises ValueError: If tools would be used together with a streaming callback (set at init or at runtime).
        """
        if not self._is_warmed_up:
            self.warm_up()

        formatted_messages, merged_kwargs, resolved_tools, resolved_callback = self._prepare_request(
            messages, generation_kwargs, tools, streaming_callback, requires_async=False
        )

        if resolved_callback:
            return self._run_streaming(formatted_messages, merged_kwargs, resolved_callback)

        hf_tools = _convert_tools_to_hfapi_tools(resolved_tools)
        return self._run_non_streaming(formatted_messages, merged_kwargs, hf_tools)

    @component.output_types(replies=list[ChatMessage])
    async def run_async(
        self,
        messages: list[ChatMessage] | str,
        generation_kwargs: dict[str, Any] | None = None,
        tools: ToolsType | None = None,
        streaming_callback: StreamingCallbackT | None = None,
    ) -> dict[str, list[ChatMessage]]:
        """
        Asynchronously invokes the text generation inference based on the provided messages and generation parameters.

        This is the asynchronous version of the `run` method. It has the same parameters
        and return values but can be used with `await` in an async code.

        :param messages:
            A list of ChatMessage objects representing the input messages. If a string is provided, it is converted
            to a list containing a ChatMessage with user role.
        :param generation_kwargs:
            Additional keyword arguments for text generation.
        :param tools:
            A list of tools or a Toolset for which the model can prepare calls. If set, it will override the `tools`
            parameter set during component initialization. This parameter can accept either a list of `Tool` objects
            or a `Toolset` instance.
        :param streaming_callback:
            An optional callable for handling streaming responses. If set, it will override the `streaming_callback`
            parameter set during component initialization.
        :returns: A dictionary with the following keys:
            - `replies`: A list containing the generated responses as ChatMessage objects.
        :raises ValueError: If tools would be used together with a streaming callback (set at init or at runtime).
        """
        if not self._is_warmed_up:
            self.warm_up()

        formatted_messages, merged_kwargs, resolved_tools, resolved_callback = self._prepare_request(
            messages, generation_kwargs, tools, streaming_callback, requires_async=True
        )

        if resolved_callback:
            return await self._run_streaming_async(formatted_messages, merged_kwargs, resolved_callback)

        hf_tools = _convert_tools_to_hfapi_tools(resolved_tools)
        return await self._run_non_streaming_async(formatted_messages, merged_kwargs, hf_tools)

    def _prepare_request(
        self,
        messages: list[ChatMessage] | str,
        generation_kwargs: dict[str, Any] | None,
        tools: ToolsType | None,
        streaming_callback: StreamingCallbackT | None,
        *,
        requires_async: bool,
    ) -> tuple[list[dict[str, Any]], dict[str, Any], ToolsType | None, StreamingCallbackT | None]:
        """
        Normalize and validate the inputs shared by `run` and `run_async`.

        :param messages: Input messages, or a string that is converted to a single user message.
        :param generation_kwargs: Runtime generation kwargs; they override the ones set at init time.
        :param tools: Runtime tools; if falsy, the tools set at init time are used.
        :param streaming_callback: Runtime streaming callback, passed to `select_streaming_callback` together with
            the init-time one.
        :param requires_async: Forwarded to `select_streaming_callback`.
        :returns: The messages in Hugging Face format, the merged generation kwargs, the tools to use, and the
            selected streaming callback (or None).
        :raises ValueError: If tools would be used together with the selected streaming callback. Errors raised by
            `select_streaming_callback` and `_check_duplicate_tool_names` propagate.
        """
        normalized_messages = _normalize_messages(messages)

        # update generation kwargs by merging with the default ones
        merged_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

        formatted_messages = [convert_message_to_hf_format(message) for message in normalized_messages]

        resolved_tools = tools or self.tools
        resolved_callback = select_streaming_callback(
            self.streaming_callback, streaming_callback, requires_async=requires_async
        )
        # Check against the *selected* callback, not only the init-time one. Otherwise a runtime
        # `streaming_callback` combined with tools would take the streaming path and silently ignore the tools.
        if resolved_tools and resolved_callback:
            msg = "Using tools and streaming at the same time is not supported. Please choose one."
            raise ValueError(msg)
        _check_duplicate_tool_names(flatten_tools_or_toolsets(resolved_tools))

        return formatted_messages, merged_kwargs, resolved_tools, resolved_callback

    @staticmethod
    def _build_streaming_reply(streaming_chunks: list[StreamingChunk]) -> dict[str, list[ChatMessage]]:
        """
        Aggregate the streamed chunks into the component's output dictionary.

        :param streaming_chunks: All chunks received for one response, in order.
        :returns: `{"replies": [message]}`, where usage defaults to zero token counts if no chunk reported it.
        """
        message = _convert_streaming_chunks_to_chat_message(chunks=streaming_chunks)
        if message.meta.get("usage") is None:
            message.meta["usage"] = {"prompt_tokens": 0, "completion_tokens": 0}
        return {"replies": [message]}

    @staticmethod
    def _build_non_streaming_reply(
        api_chat_output: ChatCompletionOutput, model: str | None
    ) -> dict[str, list[ChatMessage]]:
        """
        Convert a non-streaming ChatCompletionOutput into the component's output dictionary.

        :param api_chat_output: The response returned by `chat_completion`.
        :param model: The model ID or URL to report in the reply metadata.
        :returns: `{"replies": []}` if the response has no choices, otherwise a single assistant ChatMessage.
        """
        if not api_chat_output.choices:
            return {"replies": []}

        # n is unused, so the API always returns only one choice
        # the argument is probably allowed for compatibility with OpenAI
        # see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n
        choice = api_chat_output.choices[0]

        usage = {"prompt_tokens": 0, "completion_tokens": 0}
        if api_chat_output.usage:
            usage = {
                "prompt_tokens": api_chat_output.usage.prompt_tokens,
                "completion_tokens": api_chat_output.usage.completion_tokens,
            }

        meta: dict[str, Any] = {
            "model": model,
            "finish_reason": _map_hf_finish_reason_to_haystack(choice) if choice.finish_reason else None,
            "index": choice.index,
            "usage": usage,
        }

        message = ChatMessage.from_assistant(
            text=choice.message.content,
            tool_calls=_convert_hfapi_tool_calls(choice.message.tool_calls),
            reasoning=_extract_reasoning_content(choice.message),
            meta=meta,
        )
        return {"replies": [message]}

    def _run_streaming(
        self,
        messages: list[dict[str, Any]],
        generation_kwargs: dict[str, Any],
        streaming_callback: SyncStreamingCallbackT,
    ) -> dict[str, list[ChatMessage]]:
        """
        Stream a chat completion with the sync client, forwarding each chunk to `streaming_callback`.

        Usage reporting is requested via `stream_options=ChatCompletionInputStreamOptions(include_usage=True)`.
        """
        api_output: Iterable[ChatCompletionStreamOutput] = self._client.chat_completion(
            messages,
            stream=True,
            stream_options=ChatCompletionInputStreamOptions(include_usage=True),
            **generation_kwargs,
        )

        component_info = ComponentInfo.from_component(self)
        streaming_chunks: list[StreamingChunk] = []
        for chunk in api_output:
            streaming_chunk = _convert_chat_completion_stream_output_to_streaming_chunk(
                chunk=chunk, previous_chunks=streaming_chunks, component_info=component_info
            )
            streaming_chunks.append(streaming_chunk)
            streaming_callback(streaming_chunk)

        return self._build_streaming_reply(streaming_chunks)

    def _run_non_streaming(
        self,
        messages: list[dict[str, Any]],
        generation_kwargs: dict[str, Any],
        tools: list["ChatCompletionInputTool"] | None = None,
    ) -> dict[str, list[ChatMessage]]:
        """Run a single (non-streaming) chat completion with the sync client."""
        api_chat_output: ChatCompletionOutput = self._client.chat_completion(
            messages=messages, tools=tools, **generation_kwargs
        )
        return self._build_non_streaming_reply(api_chat_output, self._client.model)

    async def _run_streaming_async(
        self,
        messages: list[dict[str, Any]],
        generation_kwargs: dict[str, Any],
        streaming_callback: AsyncStreamingCallbackT,
    ) -> dict[str, list[ChatMessage]]:
        """
        Stream a chat completion with the async client, awaiting `streaming_callback` for each chunk.

        Usage reporting is requested via `stream_options=ChatCompletionInputStreamOptions(include_usage=True)`.
        """
        api_output: AsyncIterable[ChatCompletionStreamOutput] = await self._async_client.chat_completion(
            messages,
            stream=True,
            stream_options=ChatCompletionInputStreamOptions(include_usage=True),
            **generation_kwargs,
        )

        component_info = ComponentInfo.from_component(self)
        streaming_chunks: list[StreamingChunk] = []
        async for chunk in api_output:
            stream_chunk = _convert_chat_completion_stream_output_to_streaming_chunk(
                chunk=chunk, previous_chunks=streaming_chunks, component_info=component_info
            )
            streaming_chunks.append(stream_chunk)
            await streaming_callback(stream_chunk)

        return self._build_streaming_reply(streaming_chunks)

    async def _run_non_streaming_async(
        self,
        messages: list[dict[str, Any]],
        generation_kwargs: dict[str, Any],
        tools: list["ChatCompletionInputTool"] | None = None,
    ) -> dict[str, list[ChatMessage]]:
        """Run a single (non-streaming) chat completion with the async client."""
        api_chat_output: ChatCompletionOutput = await self._async_client.chat_completion(
            messages=messages, tools=tools, **generation_kwargs
        )
        return self._build_non_streaming_reply(api_chat_output, self._async_client.model)