huggingface-hub 0.34.4__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl
This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Potentially problematic release: this version of huggingface-hub might be problematic.
- huggingface_hub/__init__.py +46 -45
- huggingface_hub/_commit_api.py +28 -28
- huggingface_hub/_commit_scheduler.py +11 -8
- huggingface_hub/_inference_endpoints.py +8 -8
- huggingface_hub/_jobs_api.py +167 -10
- huggingface_hub/_login.py +13 -39
- huggingface_hub/_oauth.py +8 -8
- huggingface_hub/_snapshot_download.py +14 -28
- huggingface_hub/_space_api.py +4 -4
- huggingface_hub/_tensorboard_logger.py +13 -14
- huggingface_hub/_upload_large_folder.py +15 -15
- huggingface_hub/_webhooks_payload.py +3 -3
- huggingface_hub/_webhooks_server.py +2 -2
- huggingface_hub/cli/_cli_utils.py +2 -2
- huggingface_hub/cli/auth.py +5 -6
- huggingface_hub/cli/cache.py +14 -20
- huggingface_hub/cli/download.py +4 -4
- huggingface_hub/cli/jobs.py +560 -11
- huggingface_hub/cli/lfs.py +4 -4
- huggingface_hub/cli/repo.py +7 -7
- huggingface_hub/cli/repo_files.py +2 -2
- huggingface_hub/cli/upload.py +4 -4
- huggingface_hub/cli/upload_large_folder.py +3 -3
- huggingface_hub/commands/_cli_utils.py +2 -2
- huggingface_hub/commands/delete_cache.py +13 -13
- huggingface_hub/commands/download.py +4 -13
- huggingface_hub/commands/lfs.py +4 -4
- huggingface_hub/commands/repo_files.py +2 -2
- huggingface_hub/commands/scan_cache.py +1 -1
- huggingface_hub/commands/tag.py +1 -3
- huggingface_hub/commands/upload.py +4 -4
- huggingface_hub/commands/upload_large_folder.py +3 -3
- huggingface_hub/commands/user.py +5 -6
- huggingface_hub/community.py +5 -5
- huggingface_hub/constants.py +3 -41
- huggingface_hub/dataclasses.py +16 -19
- huggingface_hub/errors.py +42 -29
- huggingface_hub/fastai_utils.py +8 -9
- huggingface_hub/file_download.py +153 -252
- huggingface_hub/hf_api.py +815 -600
- huggingface_hub/hf_file_system.py +98 -62
- huggingface_hub/hub_mixin.py +37 -57
- huggingface_hub/inference/_client.py +177 -325
- huggingface_hub/inference/_common.py +110 -124
- huggingface_hub/inference/_generated/_async_client.py +226 -432
- huggingface_hub/inference/_generated/types/automatic_speech_recognition.py +3 -3
- huggingface_hub/inference/_generated/types/base.py +10 -7
- huggingface_hub/inference/_generated/types/chat_completion.py +18 -16
- huggingface_hub/inference/_generated/types/depth_estimation.py +2 -2
- huggingface_hub/inference/_generated/types/document_question_answering.py +2 -2
- huggingface_hub/inference/_generated/types/feature_extraction.py +2 -2
- huggingface_hub/inference/_generated/types/fill_mask.py +2 -2
- huggingface_hub/inference/_generated/types/sentence_similarity.py +3 -3
- huggingface_hub/inference/_generated/types/summarization.py +2 -2
- huggingface_hub/inference/_generated/types/table_question_answering.py +4 -4
- huggingface_hub/inference/_generated/types/text2text_generation.py +2 -2
- huggingface_hub/inference/_generated/types/text_generation.py +10 -10
- huggingface_hub/inference/_generated/types/text_to_video.py +2 -2
- huggingface_hub/inference/_generated/types/token_classification.py +2 -2
- huggingface_hub/inference/_generated/types/translation.py +2 -2
- huggingface_hub/inference/_generated/types/zero_shot_classification.py +2 -2
- huggingface_hub/inference/_generated/types/zero_shot_image_classification.py +2 -2
- huggingface_hub/inference/_generated/types/zero_shot_object_detection.py +1 -3
- huggingface_hub/inference/_mcp/_cli_hacks.py +3 -3
- huggingface_hub/inference/_mcp/agent.py +3 -3
- huggingface_hub/inference/_mcp/cli.py +1 -1
- huggingface_hub/inference/_mcp/constants.py +2 -3
- huggingface_hub/inference/_mcp/mcp_client.py +58 -30
- huggingface_hub/inference/_mcp/types.py +10 -7
- huggingface_hub/inference/_mcp/utils.py +11 -7
- huggingface_hub/inference/_providers/__init__.py +2 -2
- huggingface_hub/inference/_providers/_common.py +49 -25
- huggingface_hub/inference/_providers/black_forest_labs.py +6 -6
- huggingface_hub/inference/_providers/cohere.py +3 -3
- huggingface_hub/inference/_providers/fal_ai.py +25 -25
- huggingface_hub/inference/_providers/featherless_ai.py +4 -4
- huggingface_hub/inference/_providers/fireworks_ai.py +3 -3
- huggingface_hub/inference/_providers/hf_inference.py +28 -20
- huggingface_hub/inference/_providers/hyperbolic.py +4 -4
- huggingface_hub/inference/_providers/nebius.py +10 -10
- huggingface_hub/inference/_providers/novita.py +5 -5
- huggingface_hub/inference/_providers/nscale.py +4 -4
- huggingface_hub/inference/_providers/replicate.py +15 -15
- huggingface_hub/inference/_providers/sambanova.py +6 -6
- huggingface_hub/inference/_providers/together.py +7 -7
- huggingface_hub/lfs.py +20 -31
- huggingface_hub/repocard.py +18 -18
- huggingface_hub/repocard_data.py +56 -56
- huggingface_hub/serialization/__init__.py +0 -1
- huggingface_hub/serialization/_base.py +9 -9
- huggingface_hub/serialization/_dduf.py +7 -7
- huggingface_hub/serialization/_torch.py +28 -28
- huggingface_hub/utils/__init__.py +10 -4
- huggingface_hub/utils/_auth.py +5 -5
- huggingface_hub/utils/_cache_manager.py +31 -31
- huggingface_hub/utils/_deprecation.py +1 -1
- huggingface_hub/utils/_dotenv.py +3 -3
- huggingface_hub/utils/_fixes.py +0 -10
- huggingface_hub/utils/_git_credential.py +4 -4
- huggingface_hub/utils/_headers.py +7 -29
- huggingface_hub/utils/_http.py +366 -208
- huggingface_hub/utils/_pagination.py +4 -4
- huggingface_hub/utils/_paths.py +5 -5
- huggingface_hub/utils/_runtime.py +15 -13
- huggingface_hub/utils/_safetensors.py +21 -21
- huggingface_hub/utils/_subprocess.py +9 -9
- huggingface_hub/utils/_telemetry.py +3 -3
- huggingface_hub/utils/_typing.py +25 -5
- huggingface_hub/utils/_validators.py +53 -72
- huggingface_hub/utils/_xet.py +16 -16
- huggingface_hub/utils/_xet_progress_reporting.py +32 -11
- huggingface_hub/utils/insecure_hashlib.py +3 -9
- huggingface_hub/utils/tqdm.py +3 -3
- {huggingface_hub-0.34.4.dist-info → huggingface_hub-1.0.0rc0.dist-info}/METADATA +18 -29
- huggingface_hub-1.0.0rc0.dist-info/RECORD +161 -0
- huggingface_hub/inference_api.py +0 -217
- huggingface_hub/keras_mixin.py +0 -500
- huggingface_hub/repository.py +0 -1477
- huggingface_hub/serialization/_tensorflow.py +0 -95
- huggingface_hub/utils/_hf_folder.py +0 -68
- huggingface_hub-0.34.4.dist-info/RECORD +0 -166
- {huggingface_hub-0.34.4.dist-info → huggingface_hub-1.0.0rc0.dist-info}/LICENSE +0 -0
- {huggingface_hub-0.34.4.dist-info → huggingface_hub-1.0.0rc0.dist-info}/WHEEL +0 -0
- {huggingface_hub-0.34.4.dist-info → huggingface_hub-1.0.0rc0.dist-info}/entry_points.txt +0 -0
- {huggingface_hub-0.34.4.dist-info → huggingface_hub-1.0.0rc0.dist-info}/top_level.txt +0 -0
huggingface_hub/inference/_mcp/mcp_client.py

@@ -3,9 +3,9 @@ import logging
 from contextlib import AsyncExitStack
 from datetime import timedelta
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, AsyncIterable, …
+from typing import TYPE_CHECKING, Any, AsyncIterable, Literal, Optional, TypedDict, Union, overload

-from typing_extensions import NotRequired, TypeAlias, …
+from typing_extensions import NotRequired, TypeAlias, Unpack

 from ...utils._runtime import get_hf_hub_version
 from .._generated._async_client import AsyncInferenceClient
@@ -32,14 +32,14 @@ ServerType: TypeAlias = Literal["stdio", "sse", "http"]

 class StdioServerParameters_T(TypedDict):
     command: str
-    args: NotRequired[…
-    env: NotRequired[…
+    args: NotRequired[list[str]]
+    env: NotRequired[dict[str, str]]
     cwd: NotRequired[Union[str, Path, None]]


 class SSEServerParameters_T(TypedDict):
     url: str
-    headers: NotRequired[…
+    headers: NotRequired[dict[str, Any]]
     timeout: NotRequired[float]
     sse_read_timeout: NotRequired[float]
@@ -84,9 +84,9 @@ class MCPClient:
         api_key: Optional[str] = None,
     ):
         # Initialize MCP sessions as a dictionary of ClientSession objects
-        self.sessions: …
+        self.sessions: dict[ToolName, "ClientSession"] = {}
         self.exit_stack = AsyncExitStack()
-        self.available_tools: …
+        self.available_tools: list[ChatCompletionInputTool] = []
         # To be able to send the model in the payload if `base_url` is provided
         if model is None and base_url is None:
             raise ValueError("At least one of `model` or `base_url` should be set in `MCPClient`.")
@@ -132,28 +132,34 @@ class MCPClient:
                 - "stdio": Standard input/output server (local)
                 - "sse": Server-sent events (SSE) server
                 - "http": StreamableHTTP server
-            **params (`…
+            **params (`dict[str, Any]`):
                 Server parameters that can be either:
                 - For stdio servers:
                     - command (str): The command to run the MCP server
-                    - args (…
-                    - env (…
+                    - args (list[str], optional): Arguments for the command
+                    - env (dict[str, str], optional): Environment variables for the command
                     - cwd (Union[str, Path, None], optional): Working directory for the command
+                    - allowed_tools (list[str], optional): List of tool names to allow from this server
                 - For SSE servers:
                     - url (str): The URL of the SSE server
-                    - headers (…
+                    - headers (dict[str, Any], optional): Headers for the SSE connection
                     - timeout (float, optional): Connection timeout
                     - sse_read_timeout (float, optional): SSE read timeout
+                    - allowed_tools (list[str], optional): List of tool names to allow from this server
                 - For StreamableHTTP servers:
                     - url (str): The URL of the StreamableHTTP server
-                    - headers (…
+                    - headers (dict[str, Any], optional): Headers for the StreamableHTTP connection
                     - timeout (timedelta, optional): Connection timeout
                     - sse_read_timeout (timedelta, optional): SSE read timeout
                     - terminate_on_close (bool, optional): Whether to terminate on close
+                    - allowed_tools (list[str], optional): List of tool names to allow from this server
         """
         from mcp import ClientSession, StdioServerParameters
         from mcp import types as mcp_types

+        # Extract allowed_tools configuration if provided
+        allowed_tools = params.pop("allowed_tools", [])
+
         # Determine server type and create appropriate parameters
         if type == "stdio":
             # Handle stdio server
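The new `allowed_tools` entry is popped from `params` before the remaining parameters are forwarded to the MCP SDK, so filtering is configured per server. A minimal usage sketch, assuming this hunk belongs to `MCPClient.add_mcp_server`; the model, server command, and tool names are illustrative:

import asyncio
from huggingface_hub import MCPClient  # requires the `mcp` extra

async def main():
    client = MCPClient(model="meta-llama/Llama-3.3-70B-Instruct")
    # Expose only two tools from this stdio server; all others are filtered out.
    await client.add_mcp_server(
        type="stdio",
        command="npx",
        args=["-y", "@modelcontextprotocol/server-filesystem", "/tmp"],
        allowed_tools=["read_file", "list_directory"],  # new in 1.0.0rc0
    )

asyncio.run(main())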
@@ -211,7 +217,15 @@ class MCPClient:
         response = await session.list_tools()
         logger.debug("Connected to server with tools:", [tool.name for tool in response.tools])

-        …
+        # Filter tools based on allowed_tools configuration
+        filtered_tools = [tool for tool in response.tools if tool.name in allowed_tools]
+
+        if allowed_tools:
+            logger.debug(
+                f"Tool filtering applied. Using {len(filtered_tools)} of {len(response.tools)} available tools: {[tool.name for tool in filtered_tools]}"
+            )
+
+        for tool in filtered_tools:
             if tool.name in self.sessions:
                 logger.warning(f"Tool '{tool.name}' already defined by another server. Skipping.")
                 continue
@@ -235,16 +249,16 @@ class MCPClient:

     async def process_single_turn_with_tools(
         self,
-        messages: …
-        exit_loop_tools: Optional[…
+        messages: list[Union[dict, ChatCompletionInputMessage]],
+        exit_loop_tools: Optional[list[ChatCompletionInputTool]] = None,
         exit_if_first_chunk_no_tool: bool = False,
     ) -> AsyncIterable[Union[ChatCompletionStreamOutput, ChatCompletionInputMessage]]:
         """Process a query using `self.model` and available tools, yielding chunks and tool outputs.

         Args:
-            messages (`…
+            messages (`list[dict]`):
                 List of message objects representing the conversation history
-            exit_loop_tools (`…
+            exit_loop_tools (`list[ChatCompletionInputTool]`, *optional*):
                 List of tools that should exit the generator when called
             exit_if_first_chunk_no_tool (`bool`, *optional*):
                 Exit if no tool is present in the first chunks. Default to False.
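A sketch of how this async generator might be consumed; `client` and `messages` are placeholders, and the item types come from the return annotation above:

async def run_turn(client, messages: list[dict]) -> None:
    # Items are either raw stream chunks or complete tool-result messages,
    # per the AsyncIterable[Union[...]] annotation in the signature above.
    async for item in client.process_single_turn_with_tools(
        messages,
        exit_if_first_chunk_no_tool=True,
    ):
        print(type(item).__name__)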
@@ -266,8 +280,8 @@ class MCPClient:
             stream=True,
         )

-        message: …
-        final_tool_calls: …
+        message: dict[str, Any] = {"role": "unknown", "content": ""}
+        final_tool_calls: dict[int, ChatCompletionStreamOutputDeltaToolCall] = {}
         num_of_chunks = 0

         # Read from stream
@@ -286,16 +300,19 @@ class MCPClient:
             # Process tool calls
             if delta.tool_calls:
                 for tool_call in delta.tool_calls:
-                    …
-                    final_tool_calls[…
+                    idx = tool_call.index
+                    # first chunk for this tool call
+                    if idx not in final_tool_calls:
+                        final_tool_calls[idx] = tool_call
+                        if final_tool_calls[idx].function.arguments is None:
+                            final_tool_calls[idx].function.arguments = ""
+                        continue
+                    # safety before concatenating text to .function.arguments
+                    if final_tool_calls[idx].function.arguments is None:
+                        final_tool_calls[idx].function.arguments = ""
+
+                    if tool_call.function.arguments:
+                        final_tool_calls[idx].function.arguments += tool_call.function.arguments

             # Optionally exit early if no tools in first chunks
             if exit_if_first_chunk_no_tool and num_of_chunks <= 2 and len(final_tool_calls) == 0:
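The rewritten loop keys partial tool calls by `tool_call.index` and concatenates streamed argument fragments. The same pattern, shown standalone with plain dicts instead of ChatCompletionStreamOutputDeltaToolCall (illustrative only):

final_tool_calls: dict[int, dict] = {}

def accumulate(delta_tool_calls: list[dict]) -> None:
    for tool_call in delta_tool_calls:
        idx = tool_call["index"]
        if idx not in final_tool_calls:
            # first chunk for this tool call: store it and ensure arguments exist
            final_tool_calls[idx] = tool_call
            final_tool_calls[idx].setdefault("arguments", "")
            continue
        if tool_call.get("arguments"):
            final_tool_calls[idx]["arguments"] += tool_call["arguments"]

# One tool call arriving over two streamed deltas:
accumulate([{"index": 0, "name": "get_weather", "arguments": '{"city": '}])
accumulate([{"index": 0, "arguments": '"Paris"}'}])
assert final_tool_calls[0]["arguments"] == '{"city": "Paris"}'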
@@ -311,7 +328,7 @@ class MCPClient:
         message["role"] = "assistant"
         # Convert final_tool_calls to the format expected by OpenAI
         if final_tool_calls:
-            tool_calls_list: …
+            tool_calls_list: list[dict[str, Any]] = []
             for tc in final_tool_calls.values():
                 tool_calls_list.append(
                     {
@@ -329,6 +346,17 @@ class MCPClient:
         # Process tool calls one by one
         for tool_call in final_tool_calls.values():
             function_name = tool_call.function.name
+            if function_name is None:
+                message = ChatCompletionInputMessage.parse_obj_as_instance(
+                    {
+                        "role": "tool",
+                        "tool_call_id": tool_call.id,
+                        "content": "Invalid tool call with no function name.",
+                    }
+                )
+                messages.append(message)
+                yield message
+                continue  # move to next tool call
             try:
                 function_args = json.loads(tool_call.function.arguments or "{}")
             except json.JSONDecodeError as err:
huggingface_hub/inference/_mcp/types.py

@@ -1,4 +1,4 @@
-from typing import …
+from typing import Literal, TypedDict, Union

 from typing_extensions import NotRequired
@@ -13,21 +13,24 @@ class InputConfig(TypedDict, total=False):
 class StdioServerConfig(TypedDict):
     type: Literal["stdio"]
     command: str
-    args: …
-    env: …
+    args: list[str]
+    env: dict[str, str]
     cwd: str
+    allowed_tools: NotRequired[list[str]]


 class HTTPServerConfig(TypedDict):
     type: Literal["http"]
     url: str
-    headers: …
+    headers: dict[str, str]
+    allowed_tools: NotRequired[list[str]]


 class SSEServerConfig(TypedDict):
     type: Literal["sse"]
     url: str
-    headers: …
+    headers: dict[str, str]
+    allowed_tools: NotRequired[list[str]]


 ServerConfig = Union[StdioServerConfig, HTTPServerConfig, SSEServerConfig]
@@ -38,5 +41,5 @@ class AgentConfig(TypedDict):
     model: str
     provider: str
     apiKey: NotRequired[str]
-    inputs: …
-    servers: …
+    inputs: list[InputConfig]
+    servers: list[ServerConfig]
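These TypedDicts describe the agent configuration schema, with `allowed_tools` now accepted on every server type. A minimal sketch of a config using only the keys visible in these hunks (the model, provider, server command, and tool names are illustrative):

from huggingface_hub.inference._mcp.types import AgentConfig, StdioServerConfig

server: StdioServerConfig = {
    "type": "stdio",
    "command": "npx",
    "args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"],
    "env": {},
    "cwd": "/tmp",
    "allowed_tools": ["read_file"],  # new optional key in 1.0.0rc0
}

config: AgentConfig = {
    "model": "meta-llama/Llama-3.3-70B-Instruct",
    "provider": "nebius",
    "inputs": [],
    "servers": [server],
}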
huggingface_hub/inference/_mcp/utils.py

@@ -6,12 +6,12 @@ Formatting utilities taken from the JS SDK: https://github.com/huggingface/huggi…

 import json
 from pathlib import Path
-from typing import TYPE_CHECKING, …
+from typing import TYPE_CHECKING, Optional

 from huggingface_hub import snapshot_download
 from huggingface_hub.errors import EntryNotFoundError

-from .constants import DEFAULT_AGENT, DEFAULT_REPO_ID, FILENAME_CONFIG, …
+from .constants import DEFAULT_AGENT, DEFAULT_REPO_ID, FILENAME_CONFIG, PROMPT_FILENAMES
 from .types import AgentConfig
@@ -36,7 +36,7 @@ def format_result(result: "mcp_types.CallToolResult") -> str:
     if len(content) == 0:
         return "[No content]"

-    formatted_parts: …
+    formatted_parts: list[str] = []

     for item in content:
         if item.type == "text":
@@ -84,17 +84,21 @@ def _get_base64_size(base64_str: str) -> int:
     return (len(base64_str) * 3) // 4 - padding


-def _load_agent_config(agent_path: Optional[str]) -> …
+def _load_agent_config(agent_path: Optional[str]) -> tuple[AgentConfig, Optional[str]]:
     """Load server config and prompt."""

-    def _read_dir(directory: Path) -> …
+    def _read_dir(directory: Path) -> tuple[AgentConfig, Optional[str]]:
         cfg_file = directory / FILENAME_CONFIG
         if not cfg_file.exists():
             raise FileNotFoundError(f" Config file not found in {directory}! Please make sure it exists locally")

         config: AgentConfig = json.loads(cfg_file.read_text(encoding="utf-8"))
-        …
-        …
+        prompt: Optional[str] = None
+        for filename in PROMPT_FILENAMES:
+            prompt_file = directory / filename
+            if prompt_file.exists():
+                prompt = prompt_file.read_text(encoding="utf-8")
+                break
         return config, prompt

     if agent_path is None:
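The new lookup scans `PROMPT_FILENAMES` in order and stops at the first match. A sketch of a local directory this loader could read, assuming `FILENAME_CONFIG` resolves to "agent.json" and "PROMPT.md" is among the prompt candidates (both filenames are assumptions here, not confirmed constants):

import json
from pathlib import Path

agent_dir = Path("/tmp/demo-agent")
agent_dir.mkdir(parents=True, exist_ok=True)
# agent.json is parsed into an AgentConfig TypedDict
(agent_dir / "agent.json").write_text(
    json.dumps({"model": "some/model", "provider": "nebius", "inputs": [], "servers": []}),
    encoding="utf-8",
)
# The first existing PROMPT_FILENAMES entry becomes the returned prompt
(agent_dir / "PROMPT.md").write_text("You are a helpful agent.", encoding="utf-8")
# _load_agent_config(str(agent_dir)) would then return (config, prompt)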
huggingface_hub/inference/_providers/__init__.py

@@ -1,4 +1,4 @@
-from typing import …
+from typing import Literal, Optional, Union

 from huggingface_hub.inference._providers.featherless_ai import (
     FeatherlessConversationalTask,
@@ -65,7 +65,7 @@ PROVIDER_T = Literal[

 PROVIDER_OR_POLICY_T = Union[PROVIDER_T, Literal["auto"]]

-PROVIDERS: …
+PROVIDERS: dict[PROVIDER_T, dict[str, TaskProviderHelper]] = {
     "black-forest-labs": {
         "text-to-image": BlackForestLabsTextToImageTask(),
     },
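With the explicit annotation, the registry maps a provider name to its per-task helpers. A lookup uses exactly the two keys visible in this hunk:

from huggingface_hub.inference._providers import PROVIDERS

helper = PROVIDERS["black-forest-labs"]["text-to-image"]
print(type(helper).__name__)  # BlackForestLabsTextToImageTask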
huggingface_hub/inference/_providers/_common.py

@@ -1,9 +1,9 @@
 from functools import lru_cache
-from typing import Any, …
+from typing import Any, Optional, Union, overload

 from huggingface_hub import constants
 from huggingface_hub.hf_api import InferenceProviderMapping
-from huggingface_hub.inference._common import RequestParameters
+from huggingface_hub.inference._common import MimeBytes, RequestParameters
 from huggingface_hub.inference._generated.types.chat_completion import ChatCompletionInputMessage
 from huggingface_hub.utils import build_hf_headers, get_token, logging
@@ -14,7 +14,7 @@ logger = logging.get_logger(__name__)
 # Dev purposes only.
 # If you want to try to run inference for a new model locally before it's registered on huggingface.co
 # for a given Inference Provider, you can add it to the following dictionary.
-HARDCODED_MODEL_INFERENCE_MAPPING: …
+HARDCODED_MODEL_INFERENCE_MAPPING: dict[str, dict[str, InferenceProviderMapping]] = {
     # "HF model ID" => InferenceProviderMapping object initialized with "Model ID on Inference Provider's side"
     #
     # Example:
@@ -38,14 +38,14 @@ HARDCODED_MODEL_INFERENCE_MAPPING: Dict[str, Dict[str, InferenceProviderMapping]…


 @overload
-def filter_none(obj: …
+def filter_none(obj: dict[str, Any]) -> dict[str, Any]: ...
 @overload
-def filter_none(obj: …
+def filter_none(obj: list[Any]) -> list[Any]: ...


-def filter_none(obj: Union[…
+def filter_none(obj: Union[dict[str, Any], list[Any]]) -> Union[dict[str, Any], list[Any]]:
     if isinstance(obj, dict):
-        cleaned: …
+        cleaned: dict[str, Any] = {}
         for k, v in obj.items():
             if v is None:
                 continue
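The visible part of the body shows that `None` values are dropped from dicts; per the overloads, lists are accepted too. A quick illustration of the dict case (the rest of the body is truncated in this view, so only top-level behavior is shown):

from huggingface_hub.inference._providers._common import filter_none

payload = filter_none({"prompt": "hi", "temperature": None, "max_tokens": 128})
assert payload == {"prompt": "hi", "max_tokens": 128}  # None entries removed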
@@ -72,11 +72,11 @@ class TaskProviderHelper:
         self,
         *,
         inputs: Any,
-        parameters: …
-        headers: …
+        parameters: dict[str, Any],
+        headers: dict,
         model: Optional[str],
         api_key: Optional[str],
-        extra_payload: Optional[…
+        extra_payload: Optional[dict[str, Any]] = None,
     ) -> RequestParameters:
         """
         Prepare the request to be sent to the provider.
@@ -108,13 +108,22 @@ class TaskProviderHelper:
             raise ValueError("Both payload and data cannot be set in the same request.")
         if payload is None and data is None:
             raise ValueError("Either payload or data must be set in the request.")
+
+        # normalize headers to lowercase and add content-type if not present
+        normalized_headers = self._normalize_headers(headers, payload, data)
+
         return RequestParameters(
-            url=url, …
+            url=url,
+            task=self.task,
+            model=provider_mapping_info.provider_id,
+            json=payload,
+            data=data,
+            headers=normalized_headers,
         )

     def get_response(
         self,
-        response: Union[bytes, …
+        response: Union[bytes, dict],
         request_params: Optional[RequestParameters] = None,
     ) -> Any:
         """
@@ -172,7 +181,22 @@ class TaskProviderHelper:
         )
         return provider_mapping

-    def …
+    def _normalize_headers(
+        self, headers: dict[str, Any], payload: Optional[dict[str, Any]], data: Optional[MimeBytes]
+    ) -> dict[str, Any]:
+        """Normalize the headers to use for the request.
+
+        Override this method in subclasses for customized headers.
+        """
+        normalized_headers = {key.lower(): value for key, value in headers.items() if value is not None}
+        if normalized_headers.get("content-type") is None:
+            if data is not None and data.mime_type is not None:
+                normalized_headers["content-type"] = data.mime_type
+            elif payload is not None:
+                normalized_headers["content-type"] = "application/json"
+        return normalized_headers
+
+    def _prepare_headers(self, headers: dict, api_key: str) -> dict[str, Any]:
         """Return the headers to use for the request.

         Override this method in subclasses for customized headers.
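The normalization itself is a lowercase-and-default pass. The same transform in isolation, without calling the private method (header values are illustrative):

# Standalone version of the lowercase pass from _normalize_headers above.
headers = {"Authorization": "Bearer hf_xxx", "X-Debug": None}
normalized = {k.lower(): v for k, v in headers.items() if v is not None}
assert normalized == {"authorization": "Bearer hf_xxx"}
# _normalize_headers then fills "content-type" from data.mime_type for raw
# bodies, or "application/json" when a JSON payload is present.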
@@ -207,8 +231,8 @@ class TaskProviderHelper:
         return ""

     def _prepare_payload_as_dict(
-        self, inputs: Any, parameters: …
-    ) -> Optional[…
+        self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping
+    ) -> Optional[dict]:
         """Return the payload to use for the request, as a dict.

         Override this method in subclasses for customized payloads.
@@ -219,10 +243,10 @@ class TaskProviderHelper:
     def _prepare_payload_as_bytes(
         self,
         inputs: Any,
-        parameters: …
+        parameters: dict,
         provider_mapping_info: InferenceProviderMapping,
-        extra_payload: Optional[…
-    ) -> Optional[…
+        extra_payload: Optional[dict],
+    ) -> Optional[MimeBytes]:
         """Return the body to use for the request, as bytes.

         Override this method in subclasses for customized body data.
@@ -245,10 +269,10 @@ class BaseConversationalTask(TaskProviderHelper):

     def _prepare_payload_as_dict(
         self,
-        inputs: …
-        parameters: …
+        inputs: list[Union[dict, ChatCompletionInputMessage]],
+        parameters: dict,
         provider_mapping_info: InferenceProviderMapping,
-    ) -> Optional[…
+    ) -> Optional[dict]:
         return filter_none({"messages": inputs, **parameters, "model": provider_mapping_info.provider_id})
@@ -265,13 +289,13 @@ class BaseTextGenerationTask(TaskProviderHelper):
         return "/v1/completions"

     def _prepare_payload_as_dict(
-        self, inputs: Any, parameters: …
-    ) -> Optional[…
+        self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping
+    ) -> Optional[dict]:
         return filter_none({"prompt": inputs, **parameters, "model": provider_mapping_info.provider_id})


 @lru_cache(maxsize=None)
-def _fetch_inference_provider_mapping(model: str) -> …
+def _fetch_inference_provider_mapping(model: str) -> list["InferenceProviderMapping"]:
     """
     Fetch provider mappings for a model from the Hub.
     """
@@ -284,7 +308,7 @@ def _fetch_inference_provider_mapping(model: str) -> List["InferenceProviderMapp…
     return provider_mapping


-def recursive_merge(dict1: …
+def recursive_merge(dict1: dict, dict2: dict) -> dict:
     return {
         **dict1,
         **{
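The body of `recursive_merge` is truncated in this view; the name and the visible `**dict1` spread suggest it merges nested dicts recursively with `dict2` taking precedence. An illustration under that assumption:

from huggingface_hub.inference._providers._common import recursive_merge

merged = recursive_merge({"a": {"x": 1}, "b": 1}, {"a": {"y": 2}, "b": 2})
# expected (assumed behavior): {"a": {"x": 1, "y": 2}, "b": 2}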
huggingface_hub/inference/_providers/black_forest_labs.py

@@ -1,5 +1,5 @@
 import time
-from typing import Any, …
+from typing import Any, Optional, Union

 from huggingface_hub.hf_api import InferenceProviderMapping
 from huggingface_hub.inference._common import RequestParameters, _as_dict
@@ -18,7 +18,7 @@ class BlackForestLabsTextToImageTask(TaskProviderHelper):
     def __init__(self):
         super().__init__(provider="black-forest-labs", base_url="https://api.us1.bfl.ai", task="text-to-image")

-    def _prepare_headers(self, headers: …
+    def _prepare_headers(self, headers: dict, api_key: str) -> dict[str, Any]:
         headers = super()._prepare_headers(headers, api_key)
         if not api_key.startswith("hf_"):
             _ = headers.pop("authorization")
@@ -29,8 +29,8 @@ class BlackForestLabsTextToImageTask(TaskProviderHelper):
         return f"/v1/{mapped_model}"

     def _prepare_payload_as_dict(
-        self, inputs: Any, parameters: …
-    ) -> Optional[…
+        self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping
+    ) -> Optional[dict]:
         parameters = filter_none(parameters)
         if "num_inference_steps" in parameters:
             parameters["steps"] = parameters.pop("num_inference_steps")
@@ -39,7 +39,7 @@ class BlackForestLabsTextToImageTask(TaskProviderHelper):

         return {"prompt": inputs, **parameters}

-    def get_response(self, response: Union[bytes, …
+    def get_response(self, response: Union[bytes, dict], request_params: Optional[RequestParameters] = None) -> Any:
         """
         Polling mechanism for Black Forest Labs since the API is asynchronous.
         """
@@ -50,7 +50,7 @@ class BlackForestLabsTextToImageTask(TaskProviderHelper):

         response = session.get(url, headers={"Content-Type": "application/json"})  # type: ignore
         response.raise_for_status()  # type: ignore
-        response_json: …
+        response_json: dict = response.json()  # type: ignore
         status = response_json.get("status")
         logger.info(
             f"Polling generation result from {url}. Current status: {status}. "
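Since the BFL API is asynchronous, `get_response` polls a result URL until the generation finishes. A generic sketch of the same pattern with `requests`; the terminal status value, interval, and retry budget are assumptions, not the exact BFL protocol:

import time
import requests

def poll_until_ready(polling_url: str, interval: float = 1.0, max_tries: int = 60) -> dict:
    for _ in range(max_tries):
        resp = requests.get(polling_url, headers={"Content-Type": "application/json"})
        resp.raise_for_status()
        body = resp.json()
        if body.get("status") == "Ready":  # assumed terminal status
            return body
        time.sleep(interval)  # keep polling until the job finishes
    raise TimeoutError(f"Generation at {polling_url} did not finish in time.")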
huggingface_hub/inference/_providers/cohere.py

@@ -1,4 +1,4 @@
-from typing import Any, …
+from typing import Any, Optional

 from huggingface_hub.hf_api import InferenceProviderMapping
@@ -17,8 +17,8 @@ class CohereConversationalTask(BaseConversationalTask):
         return "/compatibility/v1/chat/completions"

     def _prepare_payload_as_dict(
-        self, inputs: Any, parameters: …
-    ) -> Optional[…
+        self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping
+    ) -> Optional[dict]:
         payload = super()._prepare_payload_as_dict(inputs, parameters, provider_mapping_info)
         response_format = parameters.get("response_format")
         if isinstance(response_format, dict) and response_format.get("type") == "json_schema":