huggingface-hub 0.36.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registries. It is provided for informational purposes only.
Potentially problematic release: this version of huggingface-hub might be problematic.
- huggingface_hub/__init__.py +33 -45
- huggingface_hub/_commit_api.py +39 -43
- huggingface_hub/_commit_scheduler.py +11 -8
- huggingface_hub/_inference_endpoints.py +8 -8
- huggingface_hub/_jobs_api.py +20 -20
- huggingface_hub/_login.py +17 -43
- huggingface_hub/_oauth.py +8 -8
- huggingface_hub/_snapshot_download.py +135 -50
- huggingface_hub/_space_api.py +4 -4
- huggingface_hub/_tensorboard_logger.py +5 -5
- huggingface_hub/_upload_large_folder.py +18 -32
- huggingface_hub/_webhooks_payload.py +3 -3
- huggingface_hub/_webhooks_server.py +2 -2
- huggingface_hub/cli/__init__.py +0 -14
- huggingface_hub/cli/_cli_utils.py +143 -39
- huggingface_hub/cli/auth.py +105 -171
- huggingface_hub/cli/cache.py +594 -361
- huggingface_hub/cli/download.py +120 -112
- huggingface_hub/cli/hf.py +38 -41
- huggingface_hub/cli/jobs.py +689 -1017
- huggingface_hub/cli/lfs.py +120 -143
- huggingface_hub/cli/repo.py +282 -216
- huggingface_hub/cli/repo_files.py +50 -84
- huggingface_hub/cli/system.py +6 -25
- huggingface_hub/cli/upload.py +198 -220
- huggingface_hub/cli/upload_large_folder.py +91 -106
- huggingface_hub/community.py +5 -5
- huggingface_hub/constants.py +17 -52
- huggingface_hub/dataclasses.py +135 -21
- huggingface_hub/errors.py +47 -30
- huggingface_hub/fastai_utils.py +8 -9
- huggingface_hub/file_download.py +351 -303
- huggingface_hub/hf_api.py +398 -570
- huggingface_hub/hf_file_system.py +101 -66
- huggingface_hub/hub_mixin.py +32 -54
- huggingface_hub/inference/_client.py +177 -162
- huggingface_hub/inference/_common.py +38 -54
- huggingface_hub/inference/_generated/_async_client.py +218 -258
- huggingface_hub/inference/_generated/types/automatic_speech_recognition.py +3 -3
- huggingface_hub/inference/_generated/types/base.py +10 -7
- huggingface_hub/inference/_generated/types/chat_completion.py +16 -16
- huggingface_hub/inference/_generated/types/depth_estimation.py +2 -2
- huggingface_hub/inference/_generated/types/document_question_answering.py +2 -2
- huggingface_hub/inference/_generated/types/feature_extraction.py +2 -2
- huggingface_hub/inference/_generated/types/fill_mask.py +2 -2
- huggingface_hub/inference/_generated/types/sentence_similarity.py +3 -3
- huggingface_hub/inference/_generated/types/summarization.py +2 -2
- huggingface_hub/inference/_generated/types/table_question_answering.py +4 -4
- huggingface_hub/inference/_generated/types/text2text_generation.py +2 -2
- huggingface_hub/inference/_generated/types/text_generation.py +10 -10
- huggingface_hub/inference/_generated/types/text_to_video.py +2 -2
- huggingface_hub/inference/_generated/types/token_classification.py +2 -2
- huggingface_hub/inference/_generated/types/translation.py +2 -2
- huggingface_hub/inference/_generated/types/zero_shot_classification.py +2 -2
- huggingface_hub/inference/_generated/types/zero_shot_image_classification.py +2 -2
- huggingface_hub/inference/_generated/types/zero_shot_object_detection.py +1 -3
- huggingface_hub/inference/_mcp/agent.py +3 -3
- huggingface_hub/inference/_mcp/constants.py +1 -2
- huggingface_hub/inference/_mcp/mcp_client.py +33 -22
- huggingface_hub/inference/_mcp/types.py +10 -10
- huggingface_hub/inference/_mcp/utils.py +4 -4
- huggingface_hub/inference/_providers/__init__.py +12 -4
- huggingface_hub/inference/_providers/_common.py +62 -24
- huggingface_hub/inference/_providers/black_forest_labs.py +6 -6
- huggingface_hub/inference/_providers/cohere.py +3 -3
- huggingface_hub/inference/_providers/fal_ai.py +25 -25
- huggingface_hub/inference/_providers/featherless_ai.py +4 -4
- huggingface_hub/inference/_providers/fireworks_ai.py +3 -3
- huggingface_hub/inference/_providers/hf_inference.py +13 -13
- huggingface_hub/inference/_providers/hyperbolic.py +4 -4
- huggingface_hub/inference/_providers/nebius.py +10 -10
- huggingface_hub/inference/_providers/novita.py +5 -5
- huggingface_hub/inference/_providers/nscale.py +4 -4
- huggingface_hub/inference/_providers/replicate.py +15 -15
- huggingface_hub/inference/_providers/sambanova.py +6 -6
- huggingface_hub/inference/_providers/together.py +7 -7
- huggingface_hub/lfs.py +21 -94
- huggingface_hub/repocard.py +15 -16
- huggingface_hub/repocard_data.py +57 -57
- huggingface_hub/serialization/__init__.py +0 -1
- huggingface_hub/serialization/_base.py +9 -9
- huggingface_hub/serialization/_dduf.py +7 -7
- huggingface_hub/serialization/_torch.py +28 -28
- huggingface_hub/utils/__init__.py +11 -6
- huggingface_hub/utils/_auth.py +5 -5
- huggingface_hub/utils/_cache_manager.py +49 -74
- huggingface_hub/utils/_deprecation.py +1 -1
- huggingface_hub/utils/_dotenv.py +3 -3
- huggingface_hub/utils/_fixes.py +0 -10
- huggingface_hub/utils/_git_credential.py +3 -3
- huggingface_hub/utils/_headers.py +7 -29
- huggingface_hub/utils/_http.py +371 -208
- huggingface_hub/utils/_pagination.py +4 -4
- huggingface_hub/utils/_parsing.py +98 -0
- huggingface_hub/utils/_paths.py +5 -5
- huggingface_hub/utils/_runtime.py +59 -23
- huggingface_hub/utils/_safetensors.py +21 -21
- huggingface_hub/utils/_subprocess.py +9 -9
- huggingface_hub/utils/_telemetry.py +3 -3
- huggingface_hub/{commands/_cli_utils.py → utils/_terminal.py} +4 -9
- huggingface_hub/utils/_typing.py +3 -3
- huggingface_hub/utils/_validators.py +53 -72
- huggingface_hub/utils/_xet.py +16 -16
- huggingface_hub/utils/_xet_progress_reporting.py +1 -1
- huggingface_hub/utils/insecure_hashlib.py +3 -9
- huggingface_hub/utils/tqdm.py +3 -3
- {huggingface_hub-0.36.0.dist-info → huggingface_hub-1.0.0.dist-info}/METADATA +16 -35
- huggingface_hub-1.0.0.dist-info/RECORD +152 -0
- {huggingface_hub-0.36.0.dist-info → huggingface_hub-1.0.0.dist-info}/entry_points.txt +0 -1
- huggingface_hub/commands/__init__.py +0 -27
- huggingface_hub/commands/delete_cache.py +0 -476
- huggingface_hub/commands/download.py +0 -204
- huggingface_hub/commands/env.py +0 -39
- huggingface_hub/commands/huggingface_cli.py +0 -65
- huggingface_hub/commands/lfs.py +0 -200
- huggingface_hub/commands/repo.py +0 -151
- huggingface_hub/commands/repo_files.py +0 -132
- huggingface_hub/commands/scan_cache.py +0 -183
- huggingface_hub/commands/tag.py +0 -161
- huggingface_hub/commands/upload.py +0 -318
- huggingface_hub/commands/upload_large_folder.py +0 -131
- huggingface_hub/commands/user.py +0 -208
- huggingface_hub/commands/version.py +0 -40
- huggingface_hub/inference_api.py +0 -217
- huggingface_hub/keras_mixin.py +0 -497
- huggingface_hub/repository.py +0 -1471
- huggingface_hub/serialization/_tensorflow.py +0 -92
- huggingface_hub/utils/_hf_folder.py +0 -68
- huggingface_hub-0.36.0.dist-info/RECORD +0 -170
- {huggingface_hub-0.36.0.dist-info → huggingface_hub-1.0.0.dist-info}/LICENSE +0 -0
- {huggingface_hub-0.36.0.dist-info → huggingface_hub-1.0.0.dist-info}/WHEEL +0 -0
- {huggingface_hub-0.36.0.dist-info → huggingface_hub-1.0.0.dist-info}/top_level.txt +0 -0
huggingface_hub/inference/_mcp/mcp_client.py

@@ -3,9 +3,9 @@ import logging
 from contextlib import AsyncExitStack
 from datetime import timedelta
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, AsyncIterable, Dict, List, Literal, Optional, Union, overload
+from typing import TYPE_CHECKING, Any, AsyncIterable, Literal, Optional, TypedDict, Union, overload

-from typing_extensions import NotRequired, TypeAlias, TypedDict, Unpack
+from typing_extensions import NotRequired, TypeAlias, Unpack

 from ...utils._runtime import get_hf_hub_version
 from .._generated._async_client import AsyncInferenceClient

@@ -32,14 +32,14 @@ ServerType: TypeAlias = Literal["stdio", "sse", "http"]

 class StdioServerParameters_T(TypedDict):
     command: str
-    args: NotRequired[List[str]]
-    env: NotRequired[Dict[str, str]]
+    args: NotRequired[list[str]]
+    env: NotRequired[dict[str, str]]
     cwd: NotRequired[Union[str, Path, None]]


 class SSEServerParameters_T(TypedDict):
     url: str
-    headers: NotRequired[Dict[str, Any]]
+    headers: NotRequired[dict[str, Any]]
     timeout: NotRequired[float]
     sse_read_timeout: NotRequired[float]

@@ -81,9 +81,9 @@ class MCPClient:
         api_key: Optional[str] = None,
     ):
         # Initialize MCP sessions as a dictionary of ClientSession objects
-        self.sessions: Dict[ToolName, "ClientSession"] = {}
+        self.sessions: dict[ToolName, "ClientSession"] = {}
         self.exit_stack = AsyncExitStack()
-        self.available_tools: List[ChatCompletionInputTool] = []
+        self.available_tools: list[ChatCompletionInputTool] = []
         # To be able to send the model in the payload if `base_url` is provided
         if model is None and base_url is None:
             raise ValueError("At least one of `model` or `base_url` should be set in `MCPClient`.")

@@ -129,27 +129,27 @@ class MCPClient:
                 - "stdio": Standard input/output server (local)
                 - "sse": Server-sent events (SSE) server
                 - "http": StreamableHTTP server
-            **params (`Dict[str, Any]`):
+            **params (`dict[str, Any]`):
                 Server parameters that can be either:
                 - For stdio servers:
                     - command (str): The command to run the MCP server
-                    - args (List[str], optional): Arguments for the command
-                    - env (Dict[str, str], optional): Environment variables for the command
+                    - args (list[str], optional): Arguments for the command
+                    - env (dict[str, str], optional): Environment variables for the command
                     - cwd (Union[str, Path, None], optional): Working directory for the command
-                    - allowed_tools (List[str], optional): List of tool names to allow from this server
+                    - allowed_tools (list[str], optional): List of tool names to allow from this server
                 - For SSE servers:
                     - url (str): The URL of the SSE server
-                    - headers (Dict[str, Any], optional): Headers for the SSE connection
+                    - headers (dict[str, Any], optional): Headers for the SSE connection
                     - timeout (float, optional): Connection timeout
                     - sse_read_timeout (float, optional): SSE read timeout
-                    - allowed_tools (List[str], optional): List of tool names to allow from this server
+                    - allowed_tools (list[str], optional): List of tool names to allow from this server
                 - For StreamableHTTP servers:
                     - url (str): The URL of the StreamableHTTP server
-                    - headers (Dict[str, Any], optional): Headers for the StreamableHTTP connection
+                    - headers (dict[str, Any], optional): Headers for the StreamableHTTP connection
                     - timeout (timedelta, optional): Connection timeout
                     - sse_read_timeout (timedelta, optional): SSE read timeout
                     - terminate_on_close (bool, optional): Whether to terminate on close
-                    - allowed_tools (List[str], optional): List of tool names to allow from this server
+                    - allowed_tools (list[str], optional): List of tool names to allow from this server
         """
         from mcp import ClientSession, StdioServerParameters
         from mcp import types as mcp_types

@@ -247,16 +247,16 @@ class MCPClient:

     async def process_single_turn_with_tools(
         self,
-        messages: List[Union[Dict, ChatCompletionInputMessage]],
-        exit_loop_tools: Optional[List[ChatCompletionInputTool]] = None,
+        messages: list[Union[dict, ChatCompletionInputMessage]],
+        exit_loop_tools: Optional[list[ChatCompletionInputTool]] = None,
         exit_if_first_chunk_no_tool: bool = False,
     ) -> AsyncIterable[Union[ChatCompletionStreamOutput, ChatCompletionInputMessage]]:
         """Process a query using `self.model` and available tools, yielding chunks and tool outputs.

         Args:
-            messages (`List[Dict]`):
+            messages (`list[dict]`):
                 List of message objects representing the conversation history
-            exit_loop_tools (`List[ChatCompletionInputTool]`, *optional*):
+            exit_loop_tools (`list[ChatCompletionInputTool]`, *optional*):
                 List of tools that should exit the generator when called
             exit_if_first_chunk_no_tool (`bool`, *optional*):
                 Exit if no tool is present in the first chunks. Default to False.

@@ -278,8 +278,8 @@ class MCPClient:
             stream=True,
         )

-        message: Dict[str, Any] = {"role": "unknown", "content": ""}
-        final_tool_calls: Dict[int, ChatCompletionStreamOutputDeltaToolCall] = {}
+        message: dict[str, Any] = {"role": "unknown", "content": ""}
+        final_tool_calls: dict[int, ChatCompletionStreamOutputDeltaToolCall] = {}
         num_of_chunks = 0

         # Read from stream

@@ -326,7 +326,7 @@ class MCPClient:
         message["role"] = "assistant"
         # Convert final_tool_calls to the format expected by OpenAI
         if final_tool_calls:
-            tool_calls_list: List[Dict[str, Any]] = []
+            tool_calls_list: list[dict[str, Any]] = []
             for tc in final_tool_calls.values():
                 tool_calls_list.append(
                     {

@@ -344,6 +344,17 @@ class MCPClient:
         # Process tool calls one by one
         for tool_call in final_tool_calls.values():
             function_name = tool_call.function.name
+            if function_name is None:
+                message = ChatCompletionInputMessage.parse_obj_as_instance(
+                    {
+                        "role": "tool",
+                        "tool_call_id": tool_call.id,
+                        "content": "Invalid tool call with no function name.",
+                    }
+                )
+                messages.append(message)
+                yield message
+                continue  # move to next tool call
             try:
                 function_args = json.loads(tool_call.function.arguments or "{}")
             except json.JSONDecodeError as err:
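For orientation, the typed stdio parameters above are what `MCPClient.add_mcp_server` consumes. A minimal sketch of that usage; the model ID and server command are illustrative placeholders, not values from this diff:

import asyncio

from huggingface_hub import MCPClient  # re-exported at the package top level


async def main() -> None:
    # Placeholder model ID; any conversational model on the Hub should work.
    client = MCPClient(model="Qwen/Qwen2.5-72B-Instruct", provider="auto")
    # Keyword arguments match StdioServerParameters_T above:
    # args is list[str] and env is dict[str, str] after the 1.0 typing migration.
    await client.add_mcp_server(
        type="stdio",
        command="npx",
        args=["@playwright/mcp@latest"],
        env={},
    )


asyncio.run(main())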
huggingface_hub/inference/_mcp/types.py

@@ -1,4 +1,4 @@
-from typing import Dict, List, Literal, TypedDict, Union
+from typing import Literal, TypedDict, Union

 from typing_extensions import NotRequired

@@ -13,24 +13,24 @@ class InputConfig(TypedDict, total=False):
 class StdioServerConfig(TypedDict):
     type: Literal["stdio"]
     command: str
-    args: List[str]
-    env: Dict[str, str]
+    args: list[str]
+    env: dict[str, str]
     cwd: str
-    allowed_tools: NotRequired[List[str]]
+    allowed_tools: NotRequired[list[str]]


 class HTTPServerConfig(TypedDict):
     type: Literal["http"]
     url: str
-    headers: Dict[str, str]
-    allowed_tools: NotRequired[List[str]]
+    headers: dict[str, str]
+    allowed_tools: NotRequired[list[str]]


 class SSEServerConfig(TypedDict):
     type: Literal["sse"]
     url: str
-    headers: Dict[str, str]
-    allowed_tools: NotRequired[List[str]]
+    headers: dict[str, str]
+    allowed_tools: NotRequired[list[str]]


 ServerConfig = Union[StdioServerConfig, HTTPServerConfig, SSEServerConfig]

@@ -41,5 +41,5 @@ class AgentConfig(TypedDict):
     model: str
     provider: str
     apiKey: NotRequired[str]
-    inputs: List[InputConfig]
-    servers: List[ServerConfig]
+    inputs: list[InputConfig]
+    servers: list[ServerConfig]
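A config literal that satisfies these TypedDicts might look as follows; all values are illustrative placeholders, and the import path is the internal module shown above:

from huggingface_hub.inference._mcp.types import AgentConfig

config: AgentConfig = {
    "model": "Qwen/Qwen2.5-72B-Instruct",  # placeholder model ID
    "provider": "auto",
    "inputs": [],
    "servers": [
        {
            "type": "stdio",
            "command": "uvx",             # placeholder MCP server command
            "args": ["mcp-server-fetch"],
            "env": {},
            "cwd": ".",
        }
    ],
}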
huggingface_hub/inference/_mcp/utils.py

@@ -6,7 +6,7 @@ Formatting utilities taken from the JS SDK: https://github.com/huggingface/huggi

 import json
 from pathlib import Path
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, Optional

 from huggingface_hub import snapshot_download
 from huggingface_hub.errors import EntryNotFoundError

@@ -36,7 +36,7 @@ def format_result(result: "mcp_types.CallToolResult") -> str:
     if len(content) == 0:
         return "[No content]"

-    formatted_parts: List[str] = []
+    formatted_parts: list[str] = []

     for item in content:
         if item.type == "text":

@@ -84,10 +84,10 @@ def _get_base64_size(base64_str: str) -> int:
     return (len(base64_str) * 3) // 4 - padding


-def _load_agent_config(agent_path: Optional[str]) -> Tuple[AgentConfig, Optional[str]]:
+def _load_agent_config(agent_path: Optional[str]) -> tuple[AgentConfig, Optional[str]]:
     """Load server config and prompt."""

-    def _read_dir(directory: Path) -> Tuple[AgentConfig, Optional[str]]:
+    def _read_dir(directory: Path) -> tuple[AgentConfig, Optional[str]]:
         cfg_file = directory / FILENAME_CONFIG
         if not cfg_file.exists():
             raise FileNotFoundError(f" Config file not found in {directory}! Please make sure it exists locally")
huggingface_hub/inference/_providers/__init__.py

@@ -1,4 +1,4 @@
-from typing import Dict, Literal, Optional, Union
+from typing import Literal, Optional, Union

 from huggingface_hub.inference._providers.featherless_ai import (
     FeatherlessConversationalTask,

@@ -6,7 +6,7 @@ from huggingface_hub.inference._providers.featherless_ai import (
 )
 from huggingface_hub.utils import logging

-from ._common import TaskProviderHelper, _fetch_inference_provider_mapping
+from ._common import AutoRouterConversationalTask, TaskProviderHelper, _fetch_inference_provider_mapping
 from .black_forest_labs import BlackForestLabsTextToImageTask
 from .cerebras import CerebrasConversationalTask
 from .clarifai import ClarifaiConversationalTask

@@ -73,7 +73,9 @@ PROVIDER_T = Literal[

 PROVIDER_OR_POLICY_T = Union[PROVIDER_T, Literal["auto"]]

-PROVIDERS: Dict[PROVIDER_T, Dict[str, TaskProviderHelper]] = {
+CONVERSATIONAL_AUTO_ROUTER = AutoRouterConversationalTask()
+
+PROVIDERS: dict[PROVIDER_T, dict[str, TaskProviderHelper]] = {
     "black-forest-labs": {
         "text-to-image": BlackForestLabsTextToImageTask(),
     },

@@ -206,13 +208,19 @@ def get_provider_helper(

     if provider is None:
         logger.info(
-            "Defaulting to 'auto' which will select the first provider available for the model, sorted by the user's order in https://hf.co/settings/inference-providers."
+            "No provider specified for task `conversational`. Defaulting to server-side auto routing."
+            if task == "conversational"
+            else "Defaulting to 'auto' which will select the first provider available for the model, sorted by the user's order in https://hf.co/settings/inference-providers."
        )
        provider = "auto"

     if provider == "auto":
         if model is None:
             raise ValueError("Specifying a model is required when provider is 'auto'")
+        if task == "conversational":
+            # Special case: we have a dedicated auto-router for conversational models. No need to fetch provider mapping.
+            return CONVERSATIONAL_AUTO_ROUTER
+
         provider_mapping = _fetch_inference_provider_mapping(model)
         provider = next(iter(provider_mapping)).provider
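In practice this changes what happens when a caller omits `provider`. A minimal sketch, assuming the public `InferenceClient` API and a placeholder model ID:

from huggingface_hub import InferenceClient

# With no `provider` argument the client falls back to "auto". For chat
# completion, the request now goes straight to the server-side auto-router
# instead of first fetching the model's provider mapping from the Hub.
client = InferenceClient()
response = client.chat_completion(
    model="Qwen/Qwen2.5-72B-Instruct",  # placeholder model ID
    messages=[{"role": "user", "content": "Hello!"}],
)
print(response.choices[0].message.content)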
huggingface_hub/inference/_providers/_common.py

@@ -1,5 +1,5 @@
 from functools import lru_cache
-from typing import Any, Dict, List, Optional, Union, overload
+from typing import Any, Optional, Union, overload

 from huggingface_hub import constants
 from huggingface_hub.hf_api import InferenceProviderMapping

@@ -14,7 +14,7 @@ logger = logging.get_logger(__name__)
 # Dev purposes only.
 # If you want to try to run inference for a new model locally before it's registered on huggingface.co
 # for a given Inference Provider, you can add it to the following dictionary.
-HARDCODED_MODEL_INFERENCE_MAPPING: Dict[str, Dict[str, InferenceProviderMapping]] = {
+HARDCODED_MODEL_INFERENCE_MAPPING: dict[str, dict[str, InferenceProviderMapping]] = {
     # "HF model ID" => InferenceProviderMapping object initialized with "Model ID on Inference Provider's side"
     #
     # Example:

@@ -41,14 +41,14 @@ HARDCODED_MODEL_INFERENCE_MAPPING: Dict[str, Dict[str, InferenceProviderMapping]


 @overload
-def filter_none(obj: Dict[str, Any]) -> Dict[str, Any]: ...
+def filter_none(obj: dict[str, Any]) -> dict[str, Any]: ...
 @overload
-def filter_none(obj: List[Any]) -> List[Any]: ...
+def filter_none(obj: list[Any]) -> list[Any]: ...


-def filter_none(obj: Union[Dict[str, Any], List[Any]]) -> Union[Dict[str, Any], List[Any]]:
+def filter_none(obj: Union[dict[str, Any], list[Any]]) -> Union[dict[str, Any], list[Any]]:
     if isinstance(obj, dict):
-        cleaned: Dict[str, Any] = {}
+        cleaned: dict[str, Any] = {}
         for k, v in obj.items():
             if v is None:
                 continue
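The overloads above only retype the helper; its behavior is unchanged. A hedged re-implementation sketch of what the surviving lines imply (recursive removal of `None` values), not the library's exact code:

from typing import Any, Union


def filter_none_sketch(obj: Union[dict[str, Any], list[Any]]) -> Union[dict[str, Any], list[Any]]:
    # Drop None values from dicts, recursing into nested containers.
    if isinstance(obj, dict):
        return {
            k: filter_none_sketch(v) if isinstance(v, (dict, list)) else v
            for k, v in obj.items()
            if v is not None
        }
    return [filter_none_sketch(v) if isinstance(v, (dict, list)) else v for v in obj]


assert filter_none_sketch({"a": 1, "b": None, "c": {"d": None}}) == {"a": 1, "c": {}}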
@@ -75,11 +75,11 @@ class TaskProviderHelper:
         self,
         *,
         inputs: Any,
-        parameters: Dict[str, Any],
-        headers: Dict,
+        parameters: dict[str, Any],
+        headers: dict,
         model: Optional[str],
         api_key: Optional[str],
-        extra_payload: Optional[Dict[str, Any]] = None,
+        extra_payload: Optional[dict[str, Any]] = None,
     ) -> RequestParameters:
         """
         Prepare the request to be sent to the provider.

@@ -126,7 +126,7 @@ class TaskProviderHelper:

     def get_response(
         self,
-        response: Union[bytes, Dict],
+        response: Union[bytes, dict],
         request_params: Optional[RequestParameters] = None,
     ) -> Any:
         """

@@ -185,8 +185,8 @@ class TaskProviderHelper:
         return provider_mapping

     def _normalize_headers(
-        self, headers: Dict[str, Any], payload: Optional[Dict[str, Any]], data: Optional[MimeBytes]
-    ) -> Dict[str, Any]:
+        self, headers: dict[str, Any], payload: Optional[dict[str, Any]], data: Optional[MimeBytes]
+    ) -> dict[str, Any]:
         """Normalize the headers to use for the request.

         Override this method in subclasses for customized headers.

@@ -199,7 +199,7 @@ class TaskProviderHelper:
             normalized_headers["content-type"] = "application/json"
         return normalized_headers

-    def _prepare_headers(self, headers: Dict, api_key: str) -> Dict[str, Any]:
+    def _prepare_headers(self, headers: dict, api_key: str) -> dict[str, Any]:
         """Return the headers to use for the request.

         Override this method in subclasses for customized headers.

@@ -234,8 +234,8 @@ class TaskProviderHelper:
         return ""

     def _prepare_payload_as_dict(
-        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
-    ) -> Optional[Dict]:
+        self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping
+    ) -> Optional[dict]:
         """Return the payload to use for the request, as a dict.

         Override this method in subclasses for customized payloads.

@@ -246,9 +246,9 @@ class TaskProviderHelper:
     def _prepare_payload_as_bytes(
         self,
         inputs: Any,
-        parameters: Dict,
+        parameters: dict,
         provider_mapping_info: InferenceProviderMapping,
-        extra_payload: Optional[Dict],
+        extra_payload: Optional[dict],
     ) -> Optional[MimeBytes]:
         """Return the body to use for the request, as bytes.

@@ -272,13 +272,51 @@ class BaseConversationalTask(TaskProviderHelper):

     def _prepare_payload_as_dict(
         self,
-        inputs: List[Union[Dict, ChatCompletionInputMessage]],
-        parameters: Dict,
+        inputs: list[Union[dict, ChatCompletionInputMessage]],
+        parameters: dict,
         provider_mapping_info: InferenceProviderMapping,
-    ) -> Optional[Dict]:
+    ) -> Optional[dict]:
         return filter_none({"messages": inputs, **parameters, "model": provider_mapping_info.provider_id})


+class AutoRouterConversationalTask(BaseConversationalTask):
+    """
+    Auto-router for conversational tasks.
+
+    We let the Hugging Face router select the best provider for the model, based on availability and user preferences.
+    This is a special case since the selection is done server-side (avoid 1 API call to fetch provider mapping).
+    """
+
+    def __init__(self):
+        super().__init__(provider="auto", base_url="https://router.huggingface.co")
+
+    def _prepare_base_url(self, api_key: str) -> str:
+        """Return the base URL to use for the request.
+
+        Usually not overwritten in subclasses."""
+        # Route to the proxy if the api_key is a HF TOKEN
+        if not api_key.startswith("hf_"):
+            raise ValueError("Cannot select auto-router when using non-Hugging Face API key.")
+        else:
+            return self.base_url  # No `/auto` suffix in the URL
+
+    def _prepare_mapping_info(self, model: Optional[str]) -> InferenceProviderMapping:
+        """
+        In auto-router, we don't need to fetch provider mapping info.
+        We just return a dummy mapping info with provider_id set to the HF model ID.
+        """
+        if model is None:
+            raise ValueError("Please provide an HF model ID.")
+
+        return InferenceProviderMapping(
+            provider="auto",
+            hf_model_id=model,
+            providerId=model,
+            status="live",
+            task="conversational",
+        )
+
+
 class BaseTextGenerationTask(TaskProviderHelper):
     """
     Base class for text-generation (completion) tasks.
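One consequence of the `_prepare_base_url` guard above: server-side auto-routing only works with a Hugging Face token (an `hf_` prefix); a direct provider key requires naming that provider explicitly. Illustrative placeholder values only:

from huggingface_hub import InferenceClient

InferenceClient(provider="auto", api_key="hf_xxx")               # routed via router.huggingface.co
InferenceClient(provider="together", api_key="<provider-key>")   # calls the provider directly
# provider="auto" with a non-HF key raises ValueError once a request is prepared.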
@@ -292,13 +330,13 @@ class BaseTextGenerationTask(TaskProviderHelper):
         return "/v1/completions"

     def _prepare_payload_as_dict(
-        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
-    ) -> Optional[Dict]:
+        self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping
+    ) -> Optional[dict]:
         return filter_none({"prompt": inputs, **parameters, "model": provider_mapping_info.provider_id})


 @lru_cache(maxsize=None)
-def _fetch_inference_provider_mapping(model: str) -> List["InferenceProviderMapping"]:
+def _fetch_inference_provider_mapping(model: str) -> list["InferenceProviderMapping"]:
     """
     Fetch provider mappings for a model from the Hub.
     """

@@ -311,7 +349,7 @@ def _fetch_inference_provider_mapping(model: str) -> List["InferenceProviderMapp
     return provider_mapping


-def recursive_merge(dict1: Dict, dict2: Dict) -> Dict:
+def recursive_merge(dict1: dict, dict2: dict) -> dict:
     return {
         **dict1,
         **{
huggingface_hub/inference/_providers/black_forest_labs.py

@@ -1,5 +1,5 @@
 import time
-from typing import Any, Dict, Optional, Union
+from typing import Any, Optional, Union

 from huggingface_hub.hf_api import InferenceProviderMapping
 from huggingface_hub.inference._common import RequestParameters, _as_dict

@@ -18,7 +18,7 @@ class BlackForestLabsTextToImageTask(TaskProviderHelper):
     def __init__(self):
         super().__init__(provider="black-forest-labs", base_url="https://api.us1.bfl.ai", task="text-to-image")

-    def _prepare_headers(self, headers: Dict, api_key: str) -> Dict[str, Any]:
+    def _prepare_headers(self, headers: dict, api_key: str) -> dict[str, Any]:
         headers = super()._prepare_headers(headers, api_key)
         if not api_key.startswith("hf_"):
             _ = headers.pop("authorization")

@@ -29,8 +29,8 @@ class BlackForestLabsTextToImageTask(TaskProviderHelper):
         return f"/v1/{mapped_model}"

     def _prepare_payload_as_dict(
-        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
-    ) -> Optional[Dict]:
+        self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping
+    ) -> Optional[dict]:
         parameters = filter_none(parameters)
         if "num_inference_steps" in parameters:
             parameters["steps"] = parameters.pop("num_inference_steps")

@@ -39,7 +39,7 @@ class BlackForestLabsTextToImageTask(TaskProviderHelper):

         return {"prompt": inputs, **parameters}

-    def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any:
+    def get_response(self, response: Union[bytes, dict], request_params: Optional[RequestParameters] = None) -> Any:
         """
         Polling mechanism for Black Forest Labs since the API is asynchronous.
         """

@@ -50,7 +50,7 @@ class BlackForestLabsTextToImageTask(TaskProviderHelper):

         response = session.get(url, headers={"Content-Type": "application/json"})  # type: ignore
         response.raise_for_status()  # type: ignore
-        response_json: Dict = response.json()  # type: ignore
+        response_json: dict = response.json()  # type: ignore
         status = response_json.get("status")
         logger.info(
             f"Polling generation result from {url}. Current status: {status}. "
huggingface_hub/inference/_providers/cohere.py

@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional
+from typing import Any, Optional

 from huggingface_hub.hf_api import InferenceProviderMapping

@@ -17,8 +17,8 @@ class CohereConversationalTask(BaseConversationalTask):
         return "/compatibility/v1/chat/completions"

     def _prepare_payload_as_dict(
-        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
-    ) -> Optional[Dict]:
+        self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping
+    ) -> Optional[dict]:
         payload = super()._prepare_payload_as_dict(inputs, parameters, provider_mapping_info)
         response_format = parameters.get("response_format")
         if isinstance(response_format, dict) and response_format.get("type") == "json_schema":
huggingface_hub/inference/_providers/fal_ai.py

@@ -1,7 +1,7 @@
 import base64
 import time
 from abc import ABC
-from typing import Any, Dict, Optional, Union
+from typing import Any, Optional, Union
 from urllib.parse import urlparse

 from huggingface_hub import constants

@@ -22,7 +22,7 @@ class FalAITask(TaskProviderHelper, ABC):
     def __init__(self, task: str):
         super().__init__(provider="fal-ai", base_url="https://fal.run", task=task)

-    def _prepare_headers(self, headers: Dict, api_key: str) -> Dict[str, Any]:
+    def _prepare_headers(self, headers: dict, api_key: str) -> dict[str, Any]:
         headers = super()._prepare_headers(headers, api_key)
         if not api_key.startswith("hf_"):
             headers["authorization"] = f"Key {api_key}"

@@ -36,7 +36,7 @@ class FalAIQueueTask(TaskProviderHelper, ABC):
     def __init__(self, task: str):
         super().__init__(provider="fal-ai", base_url="https://queue.fal.run", task=task)

-    def _prepare_headers(self, headers: Dict, api_key: str) -> Dict[str, Any]:
+    def _prepare_headers(self, headers: dict, api_key: str) -> dict[str, Any]:
         headers = super()._prepare_headers(headers, api_key)
         if not api_key.startswith("hf_"):
             headers["authorization"] = f"Key {api_key}"

@@ -50,7 +50,7 @@ class FalAIQueueTask(TaskProviderHelper, ABC):

     def get_response(
         self,
-        response: Union[bytes, Dict],
+        response: Union[bytes, dict],
         request_params: Optional[RequestParameters] = None,
     ) -> Any:
         response_dict = _as_dict(response)

@@ -91,8 +91,8 @@ class FalAIAutomaticSpeechRecognitionTask(FalAITask):
         super().__init__("automatic-speech-recognition")

     def _prepare_payload_as_dict(
-        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
-    ) -> Optional[Dict]:
+        self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping
+    ) -> Optional[dict]:
         if isinstance(inputs, str) and inputs.startswith(("http://", "https://")):
             # If input is a URL, pass it directly
             audio_url = inputs

@@ -108,7 +108,7 @@ class FalAIAutomaticSpeechRecognitionTask(FalAITask):

         return {"audio_url": audio_url, **filter_none(parameters)}

-    def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any:
+    def get_response(self, response: Union[bytes, dict], request_params: Optional[RequestParameters] = None) -> Any:
         text = _as_dict(response)["text"]
         if not isinstance(text, str):
             raise ValueError(f"Unexpected output format from FalAI API. Expected string, got {type(text)}.")

@@ -120,9 +120,9 @@ class FalAITextToImageTask(FalAITask):
         super().__init__("text-to-image")

     def _prepare_payload_as_dict(
-        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
-    ) -> Optional[Dict]:
-        payload: Dict[str, Any] = {
+        self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping
+    ) -> Optional[dict]:
+        payload: dict[str, Any] = {
             "prompt": inputs,
             **filter_none(parameters),
         }

@@ -145,7 +145,7 @@ class FalAITextToImageTask(FalAITask):

         return payload

-    def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any:
+    def get_response(self, response: Union[bytes, dict], request_params: Optional[RequestParameters] = None) -> Any:
         url = _as_dict(response)["images"][0]["url"]
         return get_session().get(url).content

@@ -155,11 +155,11 @@ class FalAITextToSpeechTask(FalAITask):
         super().__init__("text-to-speech")

     def _prepare_payload_as_dict(
-        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
-    ) -> Optional[Dict]:
+        self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping
+    ) -> Optional[dict]:
         return {"text": inputs, **filter_none(parameters)}

-    def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any:
+    def get_response(self, response: Union[bytes, dict], request_params: Optional[RequestParameters] = None) -> Any:
         url = _as_dict(response)["audio"]["url"]
         return get_session().get(url).content

@@ -169,13 +169,13 @@ class FalAITextToVideoTask(FalAIQueueTask):
         super().__init__("text-to-video")

     def _prepare_payload_as_dict(
-        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
-    ) -> Optional[Dict]:
+        self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping
+    ) -> Optional[dict]:
         return {"prompt": inputs, **filter_none(parameters)}

     def get_response(
         self,
-        response: Union[bytes, Dict],
+        response: Union[bytes, dict],
         request_params: Optional[RequestParameters] = None,
     ) -> Any:
         output = super().get_response(response, request_params)

@@ -188,12 +188,12 @@ class FalAIImageToImageTask(FalAIQueueTask):
         super().__init__("image-to-image")

     def _prepare_payload_as_dict(
-        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
-    ) -> Optional[Dict]:
+        self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping
+    ) -> Optional[dict]:
         image_url = _as_url(inputs, default_mime_type="image/jpeg")
         if "target_size" in parameters:
             parameters["image_size"] = parameters.pop("target_size")
-        payload: Dict[str, Any] = {
+        payload: dict[str, Any] = {
             "image_url": image_url,
             **filter_none(parameters),
         }

@@ -209,7 +209,7 @@ class FalAIImageToImageTask(FalAIQueueTask):

     def get_response(
         self,
-        response: Union[bytes, Dict],
+        response: Union[bytes, dict],
         request_params: Optional[RequestParameters] = None,
     ) -> Any:
         output = super().get_response(response, request_params)

@@ -222,10 +222,10 @@ class FalAIImageToVideoTask(FalAIQueueTask):
         super().__init__("image-to-video")

     def _prepare_payload_as_dict(
-        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
-    ) -> Optional[Dict]:
+        self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping
+    ) -> Optional[dict]:
         image_url = _as_url(inputs, default_mime_type="image/jpeg")
-        payload: Dict[str, Any] = {
+        payload: dict[str, Any] = {
             "image_url": image_url,
             **filter_none(parameters),
         }

@@ -240,7 +240,7 @@ class FalAIImageToVideoTask(FalAIQueueTask):

     def get_response(
         self,
-        response: Union[bytes, Dict],
+        response: Union[bytes, dict],
         request_params: Optional[RequestParameters] = None,
     ) -> Any:
         output = super().get_response(response, request_params)