huggingface-api-haystack 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- haystack_integrations/components/common/huggingface_api/__init__.py +3 -0
- haystack_integrations/components/common/huggingface_api/utils.py +112 -0
- haystack_integrations/components/common/py.typed +0 -0
- haystack_integrations/components/embedders/huggingface_api/__init__.py +7 -0
- haystack_integrations/components/embedders/huggingface_api/document_embedder.py +382 -0
- haystack_integrations/components/embedders/huggingface_api/text_embedder.py +262 -0
- haystack_integrations/components/embedders/py.typed +0 -0
- haystack_integrations/components/generators/huggingface_api/__init__.py +6 -0
- haystack_integrations/components/generators/huggingface_api/chat/__init__.py +3 -0
- haystack_integrations/components/generators/huggingface_api/chat/chat_generator.py +738 -0
- haystack_integrations/components/generators/py.typed +0 -0
- huggingface_api_haystack-0.1.0.dist-info/METADATA +40 -0
- huggingface_api_haystack-0.1.0.dist-info/RECORD +15 -0
- huggingface_api_haystack-0.1.0.dist-info/WHEEL +4 -0
- huggingface_api_haystack-0.1.0.dist-info/licenses/LICENSE.txt +201 -0
|
@@ -0,0 +1,738 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
|
2
|
+
#
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
from collections.abc import AsyncIterable, Iterable
|
|
7
|
+
from datetime import datetime, timezone
|
|
8
|
+
from typing import Any, Union
|
|
9
|
+
|
|
10
|
+
from haystack import component, default_from_dict, default_to_dict, logging
|
|
11
|
+
from haystack.components.generators.utils import _convert_streaming_chunks_to_chat_message, _normalize_messages
|
|
12
|
+
from haystack.dataclasses import (
|
|
13
|
+
AsyncStreamingCallbackT,
|
|
14
|
+
ChatMessage,
|
|
15
|
+
ComponentInfo,
|
|
16
|
+
ReasoningContent,
|
|
17
|
+
StreamingCallbackT,
|
|
18
|
+
StreamingChunk,
|
|
19
|
+
SyncStreamingCallbackT,
|
|
20
|
+
ToolCall,
|
|
21
|
+
select_streaming_callback,
|
|
22
|
+
)
|
|
23
|
+
from haystack.dataclasses.streaming_chunk import FinishReason
|
|
24
|
+
from haystack.tools import (
|
|
25
|
+
ToolsType,
|
|
26
|
+
_check_duplicate_tool_names,
|
|
27
|
+
deserialize_tools_or_toolset_inplace,
|
|
28
|
+
flatten_tools_or_toolsets,
|
|
29
|
+
serialize_tools_or_toolset,
|
|
30
|
+
warm_up_tools,
|
|
31
|
+
)
|
|
32
|
+
from haystack.utils import Secret, deserialize_callable, serialize_callable
|
|
33
|
+
from haystack.utils.hf import convert_message_to_hf_format
|
|
34
|
+
from haystack.utils.url_validation import is_valid_http_url
|
|
35
|
+
from huggingface_hub import (
|
|
36
|
+
AsyncInferenceClient,
|
|
37
|
+
ChatCompletionInputFunctionDefinition,
|
|
38
|
+
ChatCompletionInputStreamOptions,
|
|
39
|
+
ChatCompletionInputTool,
|
|
40
|
+
ChatCompletionOutput,
|
|
41
|
+
ChatCompletionOutputComplete,
|
|
42
|
+
ChatCompletionOutputToolCall,
|
|
43
|
+
ChatCompletionStreamOutput,
|
|
44
|
+
ChatCompletionStreamOutputChoice,
|
|
45
|
+
InferenceClient,
|
|
46
|
+
)
|
|
47
|
+
|
|
48
|
+
from haystack_integrations.components.common.huggingface_api.utils import (
|
|
49
|
+
HFGenerationAPIType,
|
|
50
|
+
HFModelType,
|
|
51
|
+
_check_valid_model,
|
|
52
|
+
)
|
|
53
|
+
|
|
54
|
+
logger = logging.getLogger(__name__)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def _convert_hfapi_tool_calls(hfapi_tool_calls: list["ChatCompletionOutputToolCall"] | None) -> list[ToolCall]:
    """
    Convert HuggingFace API tool calls to a list of Haystack ToolCall.

    Arguments may arrive either as an already-decoded dict or as a JSON string. A tool call is skipped, with a
    warning, when its arguments are malformed JSON, of an unsupported type, or do not decode to a JSON object.
    A tool call whose arguments are an empty object (``{}``, e.g. for a tool without parameters) is kept.

    :param hfapi_tool_calls: The HuggingFace API tool calls to convert. `None` or an empty list yields `[]`.
    :returns: A list of ToolCall objects.
    """
    if not hfapi_tool_calls:
        return []

    tool_calls = []

    for hfapi_tc in hfapi_tool_calls:
        hf_arguments = hfapi_tc.function.arguments

        if isinstance(hf_arguments, dict):
            arguments = hf_arguments
        elif isinstance(hf_arguments, str):
            try:
                arguments = json.loads(hf_arguments)
            except json.JSONDecodeError:
                logger.warning(
                    "HuggingFace API returned a malformed JSON string for tool call arguments. This tool call "
                    "will be skipped. Tool call ID: {_id}, Tool name: {_name}, Arguments: {_arguments}",
                    _id=hfapi_tc.id,
                    _name=hfapi_tc.function.name,
                    _arguments=hf_arguments,
                )
                continue
        else:
            logger.warning(
                "HuggingFace API returned tool call arguments of type {_type}. Valid types are dict and str. This tool "
                "call will be skipped. Tool call ID: {_id}, Tool name: {_name}, Arguments: {_arguments}",
                # `_type` must be supplied, otherwise the message cannot be interpolated at all
                _type=type(hf_arguments).__name__,
                _id=hfapi_tc.id,
                _name=hfapi_tc.function.name,
                _arguments=hf_arguments,
            )
            continue

        # A JSON string can decode to a non-object (e.g. "null" or "[1, 2]"). ToolCall arguments must be a dict,
        # so such calls are skipped. An empty dict is a valid "no arguments" call and is deliberately kept.
        if not isinstance(arguments, dict):
            logger.warning(
                "HuggingFace API returned tool call arguments that do not decode to a JSON object. This tool call "
                "will be skipped. Tool call ID: {_id}, Tool name: {_name}, Arguments: {_arguments}",
                _id=hfapi_tc.id,
                _name=hfapi_tc.function.name,
                _arguments=hf_arguments,
            )
            continue

        tool_calls.append(ToolCall(tool_name=hfapi_tc.function.name, arguments=arguments, id=hfapi_tc.id))

    return tool_calls
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def _extract_reasoning_content(message_or_delta: Any) -> ReasoningContent | None:
    """
    Build a ReasoningContent from the `reasoning` attribute of a HuggingFace API message or stream delta.

    :param message_or_delta: A HuggingFace message (non-streaming) or delta (streaming) object.
    :returns: A ReasoningContent wrapping the reasoning text when the object has a truthy `reasoning`
        attribute, otherwise None.
    """
    reasoning_text = getattr(message_or_delta, "reasoning", None)
    if not reasoning_text:
        return None
    return ReasoningContent(reasoning_text=reasoning_text)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def _resolve_schema_refs(schema: dict[str, Any]) -> dict[str, Any]:
|
|
115
|
+
"""
|
|
116
|
+
Resolve ``$ref`` references in a JSON schema by inlining ``$defs`` definitions.
|
|
117
|
+
|
|
118
|
+
The HuggingFace API does not support ``$defs`` and ``$ref`` in tool parameter schemas.
|
|
119
|
+
This function expands all ``$ref`` pointers and removes the ``$defs`` section.
|
|
120
|
+
|
|
121
|
+
:param schema: A JSON schema dict potentially containing ``$defs`` and ``$ref``.
|
|
122
|
+
:returns: A new schema dict with all references resolved and ``$defs`` removed.
|
|
123
|
+
"""
|
|
124
|
+
defs = schema.get("$defs", {})
|
|
125
|
+
if not defs:
|
|
126
|
+
return schema
|
|
127
|
+
|
|
128
|
+
def _resolve(obj: Any, resolving: set[str] | None = None) -> Any:
|
|
129
|
+
if resolving is None:
|
|
130
|
+
resolving = set()
|
|
131
|
+
if isinstance(obj, dict):
|
|
132
|
+
if "$ref" in obj:
|
|
133
|
+
ref_path = obj["$ref"]
|
|
134
|
+
ref_prefix = "#/$defs/"
|
|
135
|
+
if isinstance(ref_path, str) and ref_path.startswith(ref_prefix):
|
|
136
|
+
def_name = ref_path.removeprefix(ref_prefix)
|
|
137
|
+
if def_name in defs and def_name not in resolving:
|
|
138
|
+
return _resolve(defs[def_name], resolving | {def_name})
|
|
139
|
+
return {k: _resolve(v, resolving) for k, v in obj.items() if k != "$defs"}
|
|
140
|
+
if isinstance(obj, list):
|
|
141
|
+
return [_resolve(item, resolving) for item in obj]
|
|
142
|
+
return obj
|
|
143
|
+
|
|
144
|
+
return _resolve(schema)
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def _convert_tools_to_hfapi_tools(tools: ToolsType | None) -> list["ChatCompletionInputTool"] | None:
    """
    Convert Haystack tools and toolsets into HuggingFace API tool definitions.

    Each tool's parameter schema goes through `_resolve_schema_refs`, because the HuggingFace API does not
    accept ``$defs``/``$ref`` in tool schemas.

    :param tools: A list of Tool and/or Toolset objects, a single Toolset, or None.
    :returns: A list of `ChatCompletionInputTool` objects (one per flattened tool), or None when `tools` is falsy.
    """
    if not tools:
        return None

    return [
        ChatCompletionInputTool(
            function=ChatCompletionInputFunctionDefinition(
                name=tool.name,
                description=tool.description,
                parameters=_resolve_schema_refs(tool.parameters),
            ),
            type="function",
        )
        for tool in flatten_tools_or_toolsets(tools)
    ]
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def _map_hf_finish_reason_to_haystack(
    choice: Union["ChatCompletionStreamOutputChoice", "ChatCompletionOutputComplete"],
) -> FinishReason | None:
    """
    Map HuggingFace finish reasons to Haystack FinishReason literals.

    Uses the full choice object to detect tool calls and provide accurate mapping.

    HuggingFace finish reasons (can be found here https://huggingface.github.io/text-generation-inference/ under
    FinishReason):
    - "length": number of generated tokens == `max_new_tokens`
    - "eos_token": the model generated its end of sequence token
    - "stop_sequence": the model generated a text included in `stop_sequences`

    OpenAI-compatible reasons ("stop", "content_filter") are mapped as well. Any other reason maps to "stop".
    If the choice carries tool calls (a non-empty `tool_calls` or `tool_call_id` on its delta or message),
    "tool_calls" is returned regardless of the reported reason.

    :param choice: A streaming `ChatCompletionStreamOutputChoice` or a non-streaming `ChatCompletionOutputComplete`.
    :returns: The corresponding Haystack FinishReason, or None if the choice has no finish reason.
    """
    if choice.finish_reason is None:
        return None

    # Streaming choices carry a `delta`; non-streaming choices carry a `message`.
    payload = choice.delta if isinstance(choice, ChatCompletionStreamOutputChoice) else choice.message

    # Use truthiness rather than `is not None`. Otherwise an empty tool call list (or empty id) would mark the
    # reply as "tool_calls" even though it carries no tool calls.
    if payload.tool_calls or payload.tool_call_id:
        return "tool_calls"

    mapping: dict[str, FinishReason] = {
        # TGI finish reasons
        "length": "length",
        "eos_token": "stop",  # EOS token means natural stop
        "stop_sequence": "stop",  # Stop sequence means natural stop
        # OpenAI-compatible finish reasons returned by some Inference Providers
        "stop": "stop",
        "content_filter": "content_filter",
    }

    return mapping.get(choice.finish_reason, "stop")  # Default to "stop" for unknown reasons
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def _convert_chat_completion_stream_output_to_streaming_chunk(
    chunk: "ChatCompletionStreamOutput",
    previous_chunks: list[StreamingChunk],
    component_info: ComponentInfo | None = None,
) -> StreamingChunk:
    """
    Converts the Hugging Face API ChatCompletionStreamOutput to a StreamingChunk.

    :param chunk: The raw streamed chunk returned by the Hugging Face client.
    :param previous_chunks: The StreamingChunks already produced for this response. Only its length is used,
        to flag the first chunk with `start=True`.
    :param component_info: Optional information about the emitting component, attached to the chunk.
    :returns: The converted StreamingChunk. When the API reports token usage, it is stored under `meta["usage"]`
        as a dict with `prompt_tokens` and `completion_tokens`.
    """
    usage = None
    if chunk.usage:
        usage = {"prompt_tokens": chunk.usage.prompt_tokens, "completion_tokens": chunk.usage.completion_tokens}

    # Choices is empty if include_usage is set to True where the usage information is returned.
    if not chunk.choices:
        return StreamingChunk(
            content="",
            meta={"model": chunk.model, "received_at": datetime.now(timezone.utc).isoformat(), "usage": usage},
            component_info=component_info,
        )

    # n is unused, so the API always returns only one choice
    # the argument is probably allowed for compatibility with OpenAI
    # see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n
    choice = chunk.choices[0]
    mapped_finish_reason = _map_hf_finish_reason_to_haystack(choice) if choice.finish_reason else None

    meta: dict[str, Any] = {
        "model": chunk.model,
        "received_at": datetime.now(timezone.utc).isoformat(),
        "finish_reason": choice.finish_reason,
    }
    # Some providers attach usage to the final content chunk instead of sending a separate usage-only chunk.
    # Keep it so it is not lost.
    if usage is not None:
        meta["usage"] = usage

    return StreamingChunk(
        content=choice.delta.content or "",
        meta=meta,
        component_info=component_info,
        # Index must always be 0 since we don't allow tool calls in streaming mode.
        index=0 if choice.finish_reason is None else None,
        # start is True at the very beginning since first chunk contains role information + first part of the answer.
        start=len(previous_chunks) == 0,
        finish_reason=mapped_finish_reason,
        reasoning=_extract_reasoning_content(choice.delta),
    )
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
@component
|
|
255
|
+
class HuggingFaceAPIChatGenerator:
|
|
256
|
+
"""
|
|
257
|
+
Completes chats using Hugging Face APIs.
|
|
258
|
+
|
|
259
|
+
HuggingFaceAPIChatGenerator uses the [ChatMessage](https://docs.haystack.deepset.ai/docs/chatmessage)
|
|
260
|
+
format for input and output. Use it to generate text with Hugging Face APIs:
|
|
261
|
+
- [Serverless Inference API (Inference Providers)](https://huggingface.co/docs/inference-providers)
|
|
262
|
+
- [Paid Inference Endpoints](https://huggingface.co/inference-endpoints)
|
|
263
|
+
- [Self-hosted Text Generation Inference](https://github.com/huggingface/text-generation-inference)
|
|
264
|
+
|
|
265
|
+
### Usage examples
|
|
266
|
+
|
|
267
|
+
#### With the serverless inference API (Inference Providers) - free tier available
|
|
268
|
+
|
|
269
|
+
```python
|
|
270
|
+
from haystack_integrations.components.generators.huggingface_api import HuggingFaceAPIChatGenerator
|
|
271
|
+
from haystack.dataclasses import ChatMessage
|
|
272
|
+
from haystack.utils import Secret
|
|
273
|
+
from haystack_integrations.components.common.huggingface_api.utils import HFGenerationAPIType
|
|
274
|
+
|
|
275
|
+
messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant"),
|
|
276
|
+
ChatMessage.from_user("What's Natural Language Processing?")]
|
|
277
|
+
|
|
278
|
+
# the api_type can be expressed using the HFGenerationAPIType enum or as a string
|
|
279
|
+
api_type = HFGenerationAPIType.SERVERLESS_INFERENCE_API
|
|
280
|
+
api_type = "serverless_inference_api" # this is equivalent to the above
|
|
281
|
+
|
|
282
|
+
generator = HuggingFaceAPIChatGenerator(api_type=api_type,
|
|
283
|
+
api_params={"model": "Qwen/Qwen2.5-7B-Instruct",
|
|
284
|
+
"provider": "together"},
|
|
285
|
+
token=Secret.from_token("<your-api-key>"))
|
|
286
|
+
|
|
287
|
+
result = generator.run(messages)
|
|
288
|
+
print(result)
|
|
289
|
+
```
|
|
290
|
+
|
|
291
|
+
#### With the serverless inference API (Inference Providers) and text+image input
|
|
292
|
+
|
|
293
|
+
```python
|
|
294
|
+
from haystack_integrations.components.generators.huggingface_api import HuggingFaceAPIChatGenerator
|
|
295
|
+
from haystack.dataclasses import ChatMessage, ImageContent
|
|
296
|
+
from haystack.utils import Secret
|
|
297
|
+
from haystack_integrations.components.common.huggingface_api.utils import HFGenerationAPIType
|
|
298
|
+
|
|
299
|
+
# Create an image from file path, URL, or base64
|
|
300
|
+
image = ImageContent.from_file_path("path/to/your/image.jpg")
|
|
301
|
+
|
|
302
|
+
# Create a multimodal message with both text and image
|
|
303
|
+
messages = [ChatMessage.from_user(content_parts=["Describe this image in detail", image])]
|
|
304
|
+
|
|
305
|
+
generator = HuggingFaceAPIChatGenerator(
|
|
306
|
+
api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
|
|
307
|
+
api_params={
|
|
308
|
+
"model": "Qwen/Qwen2.5-VL-7B-Instruct", # Vision Language Model
|
|
309
|
+
"provider": "hyperbolic"
|
|
310
|
+
},
|
|
311
|
+
token=Secret.from_token("<your-api-key>")
|
|
312
|
+
)
|
|
313
|
+
|
|
314
|
+
result = generator.run(messages)
|
|
315
|
+
print(result)
|
|
316
|
+
```
|
|
317
|
+
|
|
318
|
+
#### With paid inference endpoints
|
|
319
|
+
|
|
320
|
+
```python
|
|
321
|
+
from haystack_integrations.components.generators.huggingface_api import HuggingFaceAPIChatGenerator
|
|
322
|
+
from haystack.dataclasses import ChatMessage
|
|
323
|
+
from haystack.utils import Secret
|
|
324
|
+
|
|
325
|
+
messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant"),
|
|
326
|
+
ChatMessage.from_user("What's Natural Language Processing?")]
|
|
327
|
+
|
|
328
|
+
generator = HuggingFaceAPIChatGenerator(api_type="inference_endpoints",
|
|
329
|
+
api_params={"url": "<your-inference-endpoint-url>"},
|
|
330
|
+
token=Secret.from_token("<your-api-key>"))
|
|
331
|
+
|
|
332
|
+
result = generator.run(messages)
|
|
333
|
+
print(result)
|
|
334
|
+
```
|
|
335
|
+
|
|
336
|
+
#### With self-hosted text generation inference
|
|
337
|
+
|
|
338
|
+
```python
|
|
339
|
+
from haystack_integrations.components.generators.huggingface_api import HuggingFaceAPIChatGenerator
|
|
340
|
+
from haystack.dataclasses import ChatMessage
|
|
341
|
+
|
|
342
|
+
messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant"),
|
|
343
|
+
ChatMessage.from_user("What's Natural Language Processing?")]
|
|
344
|
+
|
|
345
|
+
generator = HuggingFaceAPIChatGenerator(api_type="text_generation_inference",
|
|
346
|
+
api_params={"url": "http://localhost:8080"})
|
|
347
|
+
|
|
348
|
+
result = generator.run(messages)
|
|
349
|
+
print(result)
|
|
350
|
+
```
|
|
351
|
+
"""
|
|
352
|
+
|
|
353
|
+
def __init__(
    self,
    api_type: HFGenerationAPIType | str,
    api_params: dict[str, Any],
    token: Secret | None = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
    generation_kwargs: dict[str, Any] | None = None,
    stop_words: list[str] | None = None,
    streaming_callback: StreamingCallbackT | None = None,
    tools: ToolsType | None = None,
) -> None:
    """
    Initialize the HuggingFaceAPIChatGenerator instance.

    :param api_type:
        The type of Hugging Face API to use. Available types:
        - `text_generation_inference`: See [TGI](https://github.com/huggingface/text-generation-inference).
        - `inference_endpoints`: See [Inference Endpoints](https://huggingface.co/inference-endpoints).
        - `serverless_inference_api`: See
          [Serverless Inference API - Inference Providers](https://huggingface.co/docs/inference-providers).
    :param api_params:
        A dictionary with the following keys:
        - `model`: Hugging Face model ID. Required when `api_type` is `SERVERLESS_INFERENCE_API`.
        - `provider`: Provider name. Recommended when `api_type` is `SERVERLESS_INFERENCE_API`.
        - `url`: URL of the inference endpoint. Required when `api_type` is `INFERENCE_ENDPOINTS` or
        `TEXT_GENERATION_INFERENCE`.
        - Other parameters specific to the chosen API type, such as `timeout`, `headers`, etc.
    :param token:
        The Hugging Face token to use as HTTP bearer authorization.
        Check your HF token in your [account settings](https://huggingface.co/settings/tokens).
    :param generation_kwargs:
        A dictionary with keyword arguments to customize text generation.
        Some examples: `max_tokens`, `temperature`, `top_p`.
        For details, see [Hugging Face chat_completion documentation](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion).
        The dictionary is copied, so the caller's dict and its `stop` list are never modified.
        `max_tokens` defaults to 512 when not provided.
    :param stop_words:
        An optional list of strings representing the stop words.
        They are appended to any `stop` entries already present in `generation_kwargs`.
    :param streaming_callback:
        An optional callable for handling streaming responses.
    :param tools:
        A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
        The chosen model should support tool/function calling, according to the model card.
        Support for tools in the Hugging Face API and TGI is not yet fully refined and you may experience
        unexpected behavior.
    :raises ValueError:
        If a required `model`/`url` entry is missing from `api_params`, the URL is invalid, the `api_type` is
        unknown, or both `tools` and `streaming_callback` are set.
    """
    if isinstance(api_type, str):
        api_type = HFGenerationAPIType.from_str(api_type)

    if api_type == HFGenerationAPIType.SERVERLESS_INFERENCE_API:
        model = api_params.get("model")
        if model is None:
            msg = "To use the Serverless Inference API, you need to specify the `model` parameter in `api_params`."
            raise ValueError(msg)
        _check_valid_model(model, HFModelType.GENERATION, token)
        model_or_url = model
    elif api_type in [HFGenerationAPIType.INFERENCE_ENDPOINTS, HFGenerationAPIType.TEXT_GENERATION_INFERENCE]:
        url = api_params.get("url")
        if url is None:
            msg = (
                "To use Text Generation Inference or Inference Endpoints, you need to specify the `url` parameter "
                "in `api_params`."
            )
            raise ValueError(msg)
        if not is_valid_http_url(url):
            msg = f"Invalid URL: {url}"
            raise ValueError(msg)
        model_or_url = url
    else:
        msg = f"Unknown api_type {api_type}"
        raise ValueError(msg)

    if tools and streaming_callback is not None:
        msg = "Using tools and streaming at the same time is not supported. Please choose one."
        raise ValueError(msg)
    _check_duplicate_tool_names(flatten_tools_or_toolsets(tools))

    # handle generation kwargs setup
    # The copy is shallow, so build a brand-new "stop" list. Extending the caller's list in place would leak
    # stop words into their dict and accumulate duplicates whenever the same dict is reused.
    generation_kwargs = generation_kwargs.copy() if generation_kwargs else {}
    existing_stop = generation_kwargs.get("stop") or []
    if isinstance(existing_stop, str):
        # a single stop string would otherwise be split into characters below
        existing_stop = [existing_stop]
    generation_kwargs["stop"] = [*existing_stop, *(stop_words or [])]
    generation_kwargs.setdefault("max_tokens", 512)

    self.api_type = api_type
    self.api_params = api_params
    self.token = token
    self.generation_kwargs = generation_kwargs
    self.streaming_callback = streaming_callback

    # `model` and `url` are passed positionally as `model_or_url`; everything else is forwarded to the clients.
    resolved_api_params: dict[str, Any] = {k: v for k, v in api_params.items() if k not in ("model", "url")}
    resolved_token = token.resolve_value() if token else None
    self._client = InferenceClient(model_or_url, token=resolved_token, **resolved_api_params)
    self._async_client = AsyncInferenceClient(model_or_url, token=resolved_token, **resolved_api_params)
    self.tools = tools
    self._is_warmed_up = False
|
|
448
|
+
|
|
449
|
+
def warm_up(self) -> None:
    """
    Warm up the tools registered in this chat generator.

    Repeated calls are safe: the tools are warmed up on the first call only, and later calls return immediately.
    """
    if self._is_warmed_up:
        return
    warm_up_tools(self.tools)
    self._is_warmed_up = True
|
|
459
|
+
|
|
460
|
+
def to_dict(self) -> dict[str, Any]:
    """
    Serialize this component to a dictionary.

    The streaming callback is serialized with `serialize_callable` (when set), and the tools with
    `serialize_tools_or_toolset`. `stop_words` has no entry of its own: the constructor already merged the
    stop words into `generation_kwargs["stop"]`.

    :returns:
        A dictionary containing the serialized component.
    """
    serialized_callback = None
    if self.streaming_callback:
        serialized_callback = serialize_callable(self.streaming_callback)
    serialized_tools = serialize_tools_or_toolset(self.tools)

    return default_to_dict(
        self,
        api_type=str(self.api_type),
        api_params=self.api_params,
        token=self.token,
        generation_kwargs=self.generation_kwargs,
        streaming_callback=serialized_callback,
        tools=serialized_tools,
    )
|
|
477
|
+
|
|
478
|
+
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "HuggingFaceAPIChatGenerator":
    """
    Deserialize this component from a dictionary.

    The tools and the streaming callback are deserialized in place inside `data["init_parameters"]`.
    The component is then built with `default_from_dict`.

    :param data:
        The dictionary to deserialize from; it must contain an `init_parameters` entry.
    :returns:
        The deserialized component.
    """
    init_params = data["init_parameters"]
    deserialize_tools_or_toolset_inplace(init_params, key="tools")

    callback_path = init_params.get("streaming_callback")
    if callback_path:
        init_params["streaming_callback"] = deserialize_callable(callback_path)

    return default_from_dict(cls, data)
|
|
489
|
+
|
|
490
|
+
@component.output_types(replies=list[ChatMessage])
def run(
    self,
    messages: list[ChatMessage] | str,
    generation_kwargs: dict[str, Any] | None = None,
    tools: ToolsType | None = None,
    streaming_callback: StreamingCallbackT | None = None,
) -> dict[str, list[ChatMessage]]:
    """
    Invoke the text generation inference based on the provided messages and generation parameters.

    :param messages:
        A list of ChatMessage objects representing the input messages. If a string is provided, it is converted
        to a list containing a ChatMessage with user role.
    :param generation_kwargs:
        Additional keyword arguments for text generation.
    :param tools:
        A list of tools or a Toolset for which the model can prepare calls. If set, it will override
        the `tools` parameter set during component initialization. This parameter can accept either a
        list of `Tool` objects or a `Toolset` instance.
    :param streaming_callback:
        An optional callable for handling streaming responses. If set, it will override the `streaming_callback`
        parameter set during component initialization.
    :returns: A dictionary with the following keys:
        - `replies`: A list containing the generated responses as ChatMessage objects.
    :raises ValueError:
        If tools are set and a streaming callback (from initialization or from this call) is also in effect.
    """
    if not self._is_warmed_up:
        self.warm_up()

    messages = _normalize_messages(messages)

    # update generation kwargs by merging with the default ones
    generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

    formatted_messages = [convert_message_to_hf_format(message) for message in messages]

    tools = tools or self.tools

    # Select the effective streaming callback before the conflict check. A callback passed to this call is then
    # rejected together with tools, instead of the streaming path silently ignoring the tools.
    streaming_callback = select_streaming_callback(
        self.streaming_callback, streaming_callback, requires_async=False
    )

    if tools and streaming_callback:
        msg = "Using tools and streaming at the same time is not supported. Please choose one."
        raise ValueError(msg)
    _check_duplicate_tool_names(flatten_tools_or_toolsets(tools))

    if streaming_callback:
        return self._run_streaming(formatted_messages, generation_kwargs, streaming_callback)

    hf_tools = _convert_tools_to_hfapi_tools(tools)

    return self._run_non_streaming(formatted_messages, generation_kwargs, hf_tools)
|
|
544
|
+
|
|
545
|
+
@component.output_types(replies=list[ChatMessage])
async def run_async(
    self,
    messages: list[ChatMessage] | str,
    generation_kwargs: dict[str, Any] | None = None,
    tools: ToolsType | None = None,
    streaming_callback: StreamingCallbackT | None = None,
) -> dict[str, list[ChatMessage]]:
    """
    Asynchronously invokes the text generation inference based on the provided messages and generation parameters.

    This is the asynchronous version of the `run` method. It has the same parameters
    and return values but can be used with `await` in an async code.

    :param messages:
        A list of ChatMessage objects representing the input messages. If a string is provided, it is converted
        to a list containing a ChatMessage with user role.
    :param generation_kwargs:
        Additional keyword arguments for text generation.
    :param tools:
        A list of tools or a Toolset for which the model can prepare calls. If set, it will override the `tools`
        parameter set during component initialization. This parameter can accept either a list of `Tool` objects
        or a `Toolset` instance.
    :param streaming_callback:
        An optional callable for handling streaming responses. If set, it will override the `streaming_callback`
        parameter set during component initialization.
    :returns: A dictionary with the following keys:
        - `replies`: A list containing the generated responses as ChatMessage objects.
    :raises ValueError:
        If tools are set and a streaming callback (from initialization or from this call) is also in effect.
    """
    if not self._is_warmed_up:
        self.warm_up()

    messages = _normalize_messages(messages)

    # update generation kwargs by merging with the default ones
    generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

    formatted_messages = [convert_message_to_hf_format(message) for message in messages]

    tools = tools or self.tools

    # Select the effective streaming callback before the conflict check. A callback passed to this call is then
    # rejected together with tools, instead of the streaming path silently ignoring the tools.
    streaming_callback = select_streaming_callback(self.streaming_callback, streaming_callback, requires_async=True)

    if tools and streaming_callback:
        msg = "Using tools and streaming at the same time is not supported. Please choose one."
        raise ValueError(msg)
    _check_duplicate_tool_names(flatten_tools_or_toolsets(tools))

    if streaming_callback:
        return await self._run_streaming_async(formatted_messages, generation_kwargs, streaming_callback)

    hf_tools = _convert_tools_to_hfapi_tools(tools)

    return await self._run_non_streaming_async(formatted_messages, generation_kwargs, hf_tools)
|
|
600
|
+
|
|
601
|
+
def _run_streaming(
    self,
    messages: list[dict[str, str]],
    generation_kwargs: dict[str, Any],
    streaming_callback: SyncStreamingCallbackT,
) -> dict[str, list[ChatMessage]]:
    """
    Call the chat completion API in streaming mode and assemble the streamed chunks into a single reply.

    Each raw chunk is converted to a StreamingChunk and handed to `streaming_callback` as soon as it arrives.
    Usage reporting is requested from the API. If the assembled message carries no usage, zero token counts
    are filled in.

    :param messages: Messages already converted to the HuggingFace chat format.
    :param generation_kwargs: Keyword arguments forwarded to `chat_completion`.
    :param streaming_callback: Synchronous callback invoked with every StreamingChunk.
    :returns: A dictionary whose `replies` entry is a one-element list with the assembled ChatMessage.
    """
    stream: Iterable[ChatCompletionStreamOutput] = self._client.chat_completion(
        messages,
        stream=True,
        stream_options=ChatCompletionInputStreamOptions(include_usage=True),
        **generation_kwargs,
    )

    info = ComponentInfo.from_component(self)
    collected: list[StreamingChunk] = []
    for raw_chunk in stream:
        converted = _convert_chat_completion_stream_output_to_streaming_chunk(
            chunk=raw_chunk, previous_chunks=collected, component_info=info
        )
        collected.append(converted)
        streaming_callback(converted)

    reply = _convert_streaming_chunks_to_chat_message(chunks=collected)
    if reply.meta.get("usage") is None:
        reply.meta["usage"] = {"prompt_tokens": 0, "completion_tokens": 0}

    return {"replies": [reply]}
|
|
628
|
+
|
|
629
|
+
def _run_non_streaming(
    self,
    messages: list[dict[str, str]],
    generation_kwargs: dict[str, Any],
    tools: list["ChatCompletionInputTool"] | None = None,
) -> dict[str, list[ChatMessage]]:
    """
    Call the synchronous chat completion endpoint without streaming and wrap the answer in a `ChatMessage`.

    :param messages: Chat messages already formatted for the Hugging Face API.
    :param generation_kwargs: Extra keyword arguments forwarded unchanged to `chat_completion`.
    :param tools: Optional tool definitions in the Hugging Face API format.
    :returns: A dictionary with a `replies` list. The list is empty when the API returns no choices;
        otherwise it holds one `ChatMessage` built from the first choice.
    """
    response: ChatCompletionOutput = self._client.chat_completion(
        messages=messages, tools=tools, **generation_kwargs
    )

    # Covers both a missing (None) and an empty list of choices.
    if not response.choices:
        return {"replies": []}

    # n is unused, so the API always returns only one choice
    # the argument is probably allowed for compatibility with OpenAI
    # see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n
    first = response.choices[0]

    content = first.message.content
    converted_tool_calls = _convert_hfapi_tool_calls(first.message.tool_calls)
    # Reasoning content, if the model produced any.
    reasoning_content = _extract_reasoning_content(first.message)

    token_counts = response.usage
    meta: dict[str, Any] = {
        "model": self._client.model,
        "finish_reason": _map_hf_finish_reason_to_haystack(first) if first.finish_reason else None,
        "index": first.index,
        # Zero counts are reported when the response carries no usage object.
        "usage": (
            {
                "prompt_tokens": token_counts.prompt_tokens,
                "completion_tokens": token_counts.completion_tokens,
            }
            if token_counts
            else {"prompt_tokens": 0, "completion_tokens": 0}
        ),
    }

    reply = ChatMessage.from_assistant(
        text=content, tool_calls=converted_tool_calls, reasoning=reasoning_content, meta=meta
    )
    return {"replies": [reply]}
|
|
671
|
+
|
|
672
|
+
async def _run_streaming_async(
    self,
    messages: list[dict[str, str]],
    generation_kwargs: dict[str, Any],
    streaming_callback: AsyncStreamingCallbackT,
) -> dict[str, list[ChatMessage]]:
    """
    Call the asynchronous chat completion endpoint in streaming mode and build the final reply.

    Every chunk yielded by the async stream is converted into a `StreamingChunk`, stored, and then
    awaited through `streaming_callback`. Once the stream is exhausted, the stored chunks are merged
    into a single `ChatMessage`.

    :param messages: Chat messages already formatted for the Hugging Face API.
    :param generation_kwargs: Extra keyword arguments forwarded unchanged to `chat_completion`.
    :param streaming_callback: Async callable awaited once per converted chunk.
    :returns: A dictionary whose `replies` list holds exactly one `ChatMessage`.
    """
    # include_usage=True asks the API to report token usage as part of the stream.
    stream: AsyncIterable[ChatCompletionStreamOutput] = await self._async_client.chat_completion(
        messages,
        stream=True,
        stream_options=ChatCompletionInputStreamOptions(include_usage=True),
        **generation_kwargs,
    )

    info = ComponentInfo.from_component(self)
    collected: list[StreamingChunk] = []
    async for raw_chunk in stream:
        # The converter receives the chunks gathered so far, before the current one is appended.
        converted = _convert_chat_completion_stream_output_to_streaming_chunk(
            chunk=raw_chunk, previous_chunks=collected, component_info=info
        )
        collected.append(converted)
        await streaming_callback(converted)

    reply = _convert_streaming_chunks_to_chat_message(chunks=collected)
    # If no usage information ended up in the merged message, record zero counts instead.
    if reply.meta.get("usage") is None:
        reply.meta["usage"] = {"prompt_tokens": 0, "completion_tokens": 0}

    return {"replies": [reply]}
|
|
699
|
+
|
|
700
|
+
async def _run_non_streaming_async(
    self,
    messages: list[dict[str, str]],
    generation_kwargs: dict[str, Any],
    tools: list["ChatCompletionInputTool"] | None = None,
) -> dict[str, list[ChatMessage]]:
    """
    Call the asynchronous chat completion endpoint without streaming and wrap the answer in a `ChatMessage`.

    :param messages: Chat messages already formatted for the Hugging Face API.
    :param generation_kwargs: Extra keyword arguments forwarded unchanged to `chat_completion`.
    :param tools: Optional tool definitions in the Hugging Face API format.
    :returns: A dictionary with a `replies` list. The list is empty when the API returns no choices;
        otherwise it holds one `ChatMessage` built from the first choice.
    """
    response: ChatCompletionOutput = await self._async_client.chat_completion(
        messages=messages, tools=tools, **generation_kwargs
    )

    # Covers both a missing (None) and an empty list of choices.
    if not response.choices:
        return {"replies": []}

    # Only the first choice is used (the sync variant notes that `n` is ignored by the API).
    first = response.choices[0]

    content = first.message.content
    converted_tool_calls = _convert_hfapi_tool_calls(first.message.tool_calls)
    # Reasoning content, if the model produced any.
    reasoning_content = _extract_reasoning_content(first.message)

    token_counts = response.usage
    meta: dict[str, Any] = {
        "model": self._async_client.model,
        "finish_reason": _map_hf_finish_reason_to_haystack(first) if first.finish_reason else None,
        "index": first.index,
        # Zero counts are reported when the response carries no usage object.
        "usage": (
            {
                "prompt_tokens": token_counts.prompt_tokens,
                "completion_tokens": token_counts.completion_tokens,
            }
            if token_counts
            else {"prompt_tokens": 0, "completion_tokens": 0}
        ),
    }

    reply = ChatMessage.from_assistant(
        text=content, tool_calls=converted_tool_calls, reasoning=reasoning_content, meta=meta
    )
    return {"replies": [reply]}
|