huggingface-hub 0.35.0rc0__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of huggingface-hub might be problematic.

Files changed (127)
  1. huggingface_hub/__init__.py +46 -45
  2. huggingface_hub/_commit_api.py +28 -28
  3. huggingface_hub/_commit_scheduler.py +11 -8
  4. huggingface_hub/_inference_endpoints.py +8 -8
  5. huggingface_hub/_jobs_api.py +176 -20
  6. huggingface_hub/_local_folder.py +1 -1
  7. huggingface_hub/_login.py +13 -39
  8. huggingface_hub/_oauth.py +10 -14
  9. huggingface_hub/_snapshot_download.py +14 -28
  10. huggingface_hub/_space_api.py +4 -4
  11. huggingface_hub/_tensorboard_logger.py +13 -14
  12. huggingface_hub/_upload_large_folder.py +120 -13
  13. huggingface_hub/_webhooks_payload.py +3 -3
  14. huggingface_hub/_webhooks_server.py +2 -2
  15. huggingface_hub/cli/_cli_utils.py +2 -2
  16. huggingface_hub/cli/auth.py +8 -6
  17. huggingface_hub/cli/cache.py +18 -20
  18. huggingface_hub/cli/download.py +4 -4
  19. huggingface_hub/cli/hf.py +2 -5
  20. huggingface_hub/cli/jobs.py +599 -22
  21. huggingface_hub/cli/lfs.py +4 -4
  22. huggingface_hub/cli/repo.py +11 -7
  23. huggingface_hub/cli/repo_files.py +2 -2
  24. huggingface_hub/cli/upload.py +4 -4
  25. huggingface_hub/cli/upload_large_folder.py +3 -3
  26. huggingface_hub/commands/_cli_utils.py +2 -2
  27. huggingface_hub/commands/delete_cache.py +13 -13
  28. huggingface_hub/commands/download.py +4 -13
  29. huggingface_hub/commands/lfs.py +4 -4
  30. huggingface_hub/commands/repo_files.py +2 -2
  31. huggingface_hub/commands/scan_cache.py +1 -1
  32. huggingface_hub/commands/tag.py +1 -3
  33. huggingface_hub/commands/upload.py +4 -4
  34. huggingface_hub/commands/upload_large_folder.py +3 -3
  35. huggingface_hub/commands/user.py +4 -5
  36. huggingface_hub/community.py +5 -5
  37. huggingface_hub/constants.py +3 -41
  38. huggingface_hub/dataclasses.py +16 -19
  39. huggingface_hub/errors.py +42 -29
  40. huggingface_hub/fastai_utils.py +8 -9
  41. huggingface_hub/file_download.py +162 -259
  42. huggingface_hub/hf_api.py +841 -616
  43. huggingface_hub/hf_file_system.py +98 -62
  44. huggingface_hub/hub_mixin.py +37 -57
  45. huggingface_hub/inference/_client.py +257 -325
  46. huggingface_hub/inference/_common.py +110 -124
  47. huggingface_hub/inference/_generated/_async_client.py +307 -432
  48. huggingface_hub/inference/_generated/types/automatic_speech_recognition.py +3 -3
  49. huggingface_hub/inference/_generated/types/base.py +10 -7
  50. huggingface_hub/inference/_generated/types/chat_completion.py +18 -16
  51. huggingface_hub/inference/_generated/types/depth_estimation.py +2 -2
  52. huggingface_hub/inference/_generated/types/document_question_answering.py +2 -2
  53. huggingface_hub/inference/_generated/types/feature_extraction.py +2 -2
  54. huggingface_hub/inference/_generated/types/fill_mask.py +2 -2
  55. huggingface_hub/inference/_generated/types/sentence_similarity.py +3 -3
  56. huggingface_hub/inference/_generated/types/summarization.py +2 -2
  57. huggingface_hub/inference/_generated/types/table_question_answering.py +4 -4
  58. huggingface_hub/inference/_generated/types/text2text_generation.py +2 -2
  59. huggingface_hub/inference/_generated/types/text_generation.py +10 -10
  60. huggingface_hub/inference/_generated/types/text_to_video.py +2 -2
  61. huggingface_hub/inference/_generated/types/token_classification.py +2 -2
  62. huggingface_hub/inference/_generated/types/translation.py +2 -2
  63. huggingface_hub/inference/_generated/types/zero_shot_classification.py +2 -2
  64. huggingface_hub/inference/_generated/types/zero_shot_image_classification.py +2 -2
  65. huggingface_hub/inference/_generated/types/zero_shot_object_detection.py +1 -3
  66. huggingface_hub/inference/_mcp/_cli_hacks.py +3 -3
  67. huggingface_hub/inference/_mcp/agent.py +3 -3
  68. huggingface_hub/inference/_mcp/cli.py +1 -1
  69. huggingface_hub/inference/_mcp/constants.py +2 -3
  70. huggingface_hub/inference/_mcp/mcp_client.py +58 -30
  71. huggingface_hub/inference/_mcp/types.py +10 -7
  72. huggingface_hub/inference/_mcp/utils.py +11 -7
  73. huggingface_hub/inference/_providers/__init__.py +4 -2
  74. huggingface_hub/inference/_providers/_common.py +49 -25
  75. huggingface_hub/inference/_providers/black_forest_labs.py +6 -6
  76. huggingface_hub/inference/_providers/cohere.py +3 -3
  77. huggingface_hub/inference/_providers/fal_ai.py +52 -21
  78. huggingface_hub/inference/_providers/featherless_ai.py +4 -4
  79. huggingface_hub/inference/_providers/fireworks_ai.py +3 -3
  80. huggingface_hub/inference/_providers/hf_inference.py +28 -20
  81. huggingface_hub/inference/_providers/hyperbolic.py +4 -4
  82. huggingface_hub/inference/_providers/nebius.py +10 -10
  83. huggingface_hub/inference/_providers/novita.py +5 -5
  84. huggingface_hub/inference/_providers/nscale.py +4 -4
  85. huggingface_hub/inference/_providers/replicate.py +15 -15
  86. huggingface_hub/inference/_providers/sambanova.py +6 -6
  87. huggingface_hub/inference/_providers/together.py +7 -7
  88. huggingface_hub/lfs.py +20 -31
  89. huggingface_hub/repocard.py +18 -18
  90. huggingface_hub/repocard_data.py +56 -56
  91. huggingface_hub/serialization/__init__.py +0 -1
  92. huggingface_hub/serialization/_base.py +9 -9
  93. huggingface_hub/serialization/_dduf.py +7 -7
  94. huggingface_hub/serialization/_torch.py +28 -28
  95. huggingface_hub/utils/__init__.py +10 -4
  96. huggingface_hub/utils/_auth.py +5 -5
  97. huggingface_hub/utils/_cache_manager.py +31 -31
  98. huggingface_hub/utils/_deprecation.py +1 -1
  99. huggingface_hub/utils/_dotenv.py +25 -21
  100. huggingface_hub/utils/_fixes.py +0 -10
  101. huggingface_hub/utils/_git_credential.py +4 -4
  102. huggingface_hub/utils/_headers.py +7 -29
  103. huggingface_hub/utils/_http.py +366 -208
  104. huggingface_hub/utils/_pagination.py +4 -4
  105. huggingface_hub/utils/_paths.py +5 -5
  106. huggingface_hub/utils/_runtime.py +16 -13
  107. huggingface_hub/utils/_safetensors.py +21 -21
  108. huggingface_hub/utils/_subprocess.py +9 -9
  109. huggingface_hub/utils/_telemetry.py +3 -3
  110. huggingface_hub/utils/_typing.py +25 -5
  111. huggingface_hub/utils/_validators.py +53 -72
  112. huggingface_hub/utils/_xet.py +16 -16
  113. huggingface_hub/utils/_xet_progress_reporting.py +32 -11
  114. huggingface_hub/utils/insecure_hashlib.py +3 -9
  115. huggingface_hub/utils/tqdm.py +3 -3
  116. {huggingface_hub-0.35.0rc0.dist-info → huggingface_hub-1.0.0rc0.dist-info}/METADATA +18 -29
  117. huggingface_hub-1.0.0rc0.dist-info/RECORD +161 -0
  118. huggingface_hub/inference_api.py +0 -217
  119. huggingface_hub/keras_mixin.py +0 -500
  120. huggingface_hub/repository.py +0 -1477
  121. huggingface_hub/serialization/_tensorflow.py +0 -95
  122. huggingface_hub/utils/_hf_folder.py +0 -68
  123. huggingface_hub-0.35.0rc0.dist-info/RECORD +0 -166
  124. {huggingface_hub-0.35.0rc0.dist-info → huggingface_hub-1.0.0rc0.dist-info}/LICENSE +0 -0
  125. {huggingface_hub-0.35.0rc0.dist-info → huggingface_hub-1.0.0rc0.dist-info}/WHEEL +0 -0
  126. {huggingface_hub-0.35.0rc0.dist-info → huggingface_hub-1.0.0rc0.dist-info}/entry_points.txt +0 -0
  127. {huggingface_hub-0.35.0rc0.dist-info → huggingface_hub-1.0.0rc0.dist-info}/top_level.txt +0 -0
huggingface_hub/inference/_mcp/mcp_client.py

@@ -3,9 +3,9 @@ import logging
 from contextlib import AsyncExitStack
 from datetime import timedelta
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, AsyncIterable, Dict, List, Literal, Optional, Union, overload
+from typing import TYPE_CHECKING, Any, AsyncIterable, Literal, Optional, TypedDict, Union, overload

-from typing_extensions import NotRequired, TypeAlias, TypedDict, Unpack
+from typing_extensions import NotRequired, TypeAlias, Unpack

 from ...utils._runtime import get_hf_hub_version
 from .._generated._async_client import AsyncInferenceClient

@@ -32,14 +32,14 @@ ServerType: TypeAlias = Literal["stdio", "sse", "http"]

 class StdioServerParameters_T(TypedDict):
     command: str
-    args: NotRequired[List[str]]
-    env: NotRequired[Dict[str, str]]
+    args: NotRequired[list[str]]
+    env: NotRequired[dict[str, str]]
     cwd: NotRequired[Union[str, Path, None]]


 class SSEServerParameters_T(TypedDict):
     url: str
-    headers: NotRequired[Dict[str, Any]]
+    headers: NotRequired[dict[str, Any]]
     timeout: NotRequired[float]
     sse_read_timeout: NotRequired[float]

@@ -84,9 +84,9 @@ class MCPClient:
         api_key: Optional[str] = None,
     ):
         # Initialize MCP sessions as a dictionary of ClientSession objects
-        self.sessions: Dict[ToolName, "ClientSession"] = {}
+        self.sessions: dict[ToolName, "ClientSession"] = {}
         self.exit_stack = AsyncExitStack()
-        self.available_tools: List[ChatCompletionInputTool] = []
+        self.available_tools: list[ChatCompletionInputTool] = []
         # To be able to send the model in the payload if `base_url` is provided
         if model is None and base_url is None:
             raise ValueError("At least one of `model` or `base_url` should be set in `MCPClient`.")

@@ -132,28 +132,34 @@ class MCPClient:
                 - "stdio": Standard input/output server (local)
                 - "sse": Server-sent events (SSE) server
                 - "http": StreamableHTTP server
-            **params (`Dict[str, Any]`):
+            **params (`dict[str, Any]`):
                 Server parameters that can be either:
                 - For stdio servers:
                     - command (str): The command to run the MCP server
-                    - args (List[str], optional): Arguments for the command
-                    - env (Dict[str, str], optional): Environment variables for the command
+                    - args (list[str], optional): Arguments for the command
+                    - env (dict[str, str], optional): Environment variables for the command
                     - cwd (Union[str, Path, None], optional): Working directory for the command
+                    - allowed_tools (list[str], optional): List of tool names to allow from this server
                 - For SSE servers:
                     - url (str): The URL of the SSE server
-                    - headers (Dict[str, Any], optional): Headers for the SSE connection
+                    - headers (dict[str, Any], optional): Headers for the SSE connection
                     - timeout (float, optional): Connection timeout
                     - sse_read_timeout (float, optional): SSE read timeout
+                    - allowed_tools (list[str], optional): List of tool names to allow from this server
                 - For StreamableHTTP servers:
                     - url (str): The URL of the StreamableHTTP server
-                    - headers (Dict[str, Any], optional): Headers for the StreamableHTTP connection
+                    - headers (dict[str, Any], optional): Headers for the StreamableHTTP connection
                     - timeout (timedelta, optional): Connection timeout
                     - sse_read_timeout (timedelta, optional): SSE read timeout
                     - terminate_on_close (bool, optional): Whether to terminate on close
+                    - allowed_tools (list[str], optional): List of tool names to allow from this server
         """
         from mcp import ClientSession, StdioServerParameters
         from mcp import types as mcp_types

+        # Extract allowed_tools configuration if provided
+        allowed_tools = params.pop("allowed_tools", [])
+
         # Determine server type and create appropriate parameters
         if type == "stdio":
             # Handle stdio server

@@ -211,7 +217,15 @@ class MCPClient:
         response = await session.list_tools()
         logger.debug("Connected to server with tools:", [tool.name for tool in response.tools])

-        for tool in response.tools:
+        # Filter tools based on allowed_tools configuration
+        filtered_tools = [tool for tool in response.tools if tool.name in allowed_tools]
+
+        if allowed_tools:
+            logger.debug(
+                f"Tool filtering applied. Using {len(filtered_tools)} of {len(response.tools)} available tools: {[tool.name for tool in filtered_tools]}"
+            )
+
+        for tool in filtered_tools:
             if tool.name in self.sessions:
                 logger.warning(f"Tool '{tool.name}' already defined by another server. Skipping.")
                 continue
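For context, a minimal sketch of how the new allowed_tools option flows through add_mcp_server, assuming the top-level MCPClient export; the server command and tool names below are hypothetical, and only the listed names survive the filtering step shown above:

    import asyncio

    from huggingface_hub import MCPClient


    async def main():
        client = MCPClient(model="Qwen/Qwen2.5-72B-Instruct")  # any conversational model
        # Register a local stdio MCP server, keeping only two of its tools.
        await client.add_mcp_server(
            type="stdio",
            command="npx",  # hypothetical filesystem MCP server
            args=["-y", "@modelcontextprotocol/server-filesystem", "."],
            allowed_tools=["read_file", "list_directory"],  # hypothetical tool names
        )


    asyncio.run(main())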
@@ -235,16 +249,16 @@ class MCPClient:

     async def process_single_turn_with_tools(
         self,
-        messages: List[Union[Dict, ChatCompletionInputMessage]],
-        exit_loop_tools: Optional[List[ChatCompletionInputTool]] = None,
+        messages: list[Union[dict, ChatCompletionInputMessage]],
+        exit_loop_tools: Optional[list[ChatCompletionInputTool]] = None,
         exit_if_first_chunk_no_tool: bool = False,
     ) -> AsyncIterable[Union[ChatCompletionStreamOutput, ChatCompletionInputMessage]]:
         """Process a query using `self.model` and available tools, yielding chunks and tool outputs.

         Args:
-            messages (`List[Dict]`):
+            messages (`list[dict]`):
                 List of message objects representing the conversation history
-            exit_loop_tools (`List[ChatCompletionInputTool]`, *optional*):
+            exit_loop_tools (`list[ChatCompletionInputTool]`, *optional*):
                 List of tools that should exit the generator when called
             exit_if_first_chunk_no_tool (`bool`, *optional*):
                 Exit if no tool is present in the first chunks. Default to False.

@@ -266,8 +280,8 @@ class MCPClient:
             stream=True,
         )

-        message: Dict[str, Any] = {"role": "unknown", "content": ""}
-        final_tool_calls: Dict[int, ChatCompletionStreamOutputDeltaToolCall] = {}
+        message: dict[str, Any] = {"role": "unknown", "content": ""}
+        final_tool_calls: dict[int, ChatCompletionStreamOutputDeltaToolCall] = {}
         num_of_chunks = 0

         # Read from stream

@@ -286,16 +300,19 @@ class MCPClient:
             # Process tool calls
             if delta.tool_calls:
                 for tool_call in delta.tool_calls:
-                    # Aggregate chunks into tool calls
-                    if tool_call.index not in final_tool_calls:
-                        if (
-                            tool_call.function.arguments is None or tool_call.function.arguments == "{}"
-                        ):  # Corner case (depends on provider)
-                            tool_call.function.arguments = ""
-                        final_tool_calls[tool_call.index] = tool_call
-
-                    elif tool_call.function.arguments:
-                        final_tool_calls[tool_call.index].function.arguments += tool_call.function.arguments
+                    idx = tool_call.index
+                    # first chunk for this tool call
+                    if idx not in final_tool_calls:
+                        final_tool_calls[idx] = tool_call
+                        if final_tool_calls[idx].function.arguments is None:
+                            final_tool_calls[idx].function.arguments = ""
+                        continue
+                    # safety before concatenating text to .function.arguments
+                    if final_tool_calls[idx].function.arguments is None:
+                        final_tool_calls[idx].function.arguments = ""
+
+                    if tool_call.function.arguments:
+                        final_tool_calls[idx].function.arguments += tool_call.function.arguments

             # Optionally exit early if no tools in first chunks
             if exit_if_first_chunk_no_tool and num_of_chunks <= 2 and len(final_tool_calls) == 0:
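To see what the rewritten aggregation handles, keep in mind how providers stream a tool call: a first delta carries the call's index (and usually its name), and later deltas for the same index carry argument fragments that must be concatenated. A toy illustration of the same logic, with plain dicts standing in for the delta objects:

    # Stand-ins for ChatCompletionStreamOutputDeltaToolCall chunks from a stream.
    chunks = [
        {"index": 0, "name": "get_weather", "arguments": None},  # first chunk, args not started
        {"index": 0, "name": None, "arguments": '{"locati'},
        {"index": 0, "name": None, "arguments": 'on": "Paris"}'},
    ]

    final: dict[int, dict] = {}
    for tc in chunks:
        idx = tc["index"]
        if idx not in final:  # first chunk for this tool call
            final[idx] = tc
            if final[idx]["arguments"] is None:
                final[idx]["arguments"] = ""
            continue
        if tc["arguments"]:  # concatenate argument fragments
            final[idx]["arguments"] += tc["arguments"]

    assert final[0]["arguments"] == '{"location": "Paris"}'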
@@ -311,7 +328,7 @@ class MCPClient:
         message["role"] = "assistant"
         # Convert final_tool_calls to the format expected by OpenAI
         if final_tool_calls:
-            tool_calls_list: List[Dict[str, Any]] = []
+            tool_calls_list: list[dict[str, Any]] = []
             for tc in final_tool_calls.values():
                 tool_calls_list.append(
                     {

@@ -329,6 +346,17 @@ class MCPClient:
         # Process tool calls one by one
         for tool_call in final_tool_calls.values():
             function_name = tool_call.function.name
+            if function_name is None:
+                message = ChatCompletionInputMessage.parse_obj_as_instance(
+                    {
+                        "role": "tool",
+                        "tool_call_id": tool_call.id,
+                        "content": "Invalid tool call with no function name.",
+                    }
+                )
+                messages.append(message)
+                yield message
+                continue  # move to next tool call
             try:
                 function_args = json.loads(tool_call.function.arguments or "{}")
             except json.JSONDecodeError as err:
huggingface_hub/inference/_mcp/types.py

@@ -1,4 +1,4 @@
-from typing import Dict, List, Literal, TypedDict, Union
+from typing import Literal, TypedDict, Union

 from typing_extensions import NotRequired

@@ -13,21 +13,24 @@ class InputConfig(TypedDict, total=False):
 class StdioServerConfig(TypedDict):
     type: Literal["stdio"]
     command: str
-    args: List[str]
-    env: Dict[str, str]
+    args: list[str]
+    env: dict[str, str]
     cwd: str
+    allowed_tools: NotRequired[list[str]]


 class HTTPServerConfig(TypedDict):
     type: Literal["http"]
     url: str
-    headers: Dict[str, str]
+    headers: dict[str, str]
+    allowed_tools: NotRequired[list[str]]


 class SSEServerConfig(TypedDict):
     type: Literal["sse"]
     url: str
-    headers: Dict[str, str]
+    headers: dict[str, str]
+    allowed_tools: NotRequired[list[str]]


 ServerConfig = Union[StdioServerConfig, HTTPServerConfig, SSEServerConfig]

@@ -38,5 +41,5 @@ class AgentConfig(TypedDict):
     model: str
     provider: str
     apiKey: NotRequired[str]
-    inputs: List[InputConfig]
-    servers: List[ServerConfig]
+    inputs: list[InputConfig]
+    servers: list[ServerConfig]
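A hedged example of a server entry using the new allowed_tools key (all values below are illustrative):

    from huggingface_hub.inference._mcp.types import StdioServerConfig

    server: StdioServerConfig = {
        "type": "stdio",
        "command": "uv",
        "args": ["run", "my_mcp_server.py"],  # hypothetical server script
        "env": {},
        "cwd": ".",
        "allowed_tools": ["search_docs"],  # NotRequired: restricts which tools get registered
    }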
huggingface_hub/inference/_mcp/utils.py

@@ -6,12 +6,12 @@ Formatting utilities taken from the JS SDK: https://github.com/huggingface/huggi

 import json
 from pathlib import Path
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, Optional

 from huggingface_hub import snapshot_download
 from huggingface_hub.errors import EntryNotFoundError

-from .constants import DEFAULT_AGENT, DEFAULT_REPO_ID, FILENAME_CONFIG, FILENAME_PROMPT
+from .constants import DEFAULT_AGENT, DEFAULT_REPO_ID, FILENAME_CONFIG, PROMPT_FILENAMES
 from .types import AgentConfig


@@ -36,7 +36,7 @@ def format_result(result: "mcp_types.CallToolResult") -> str:
     if len(content) == 0:
         return "[No content]"

-    formatted_parts: List[str] = []
+    formatted_parts: list[str] = []

     for item in content:
         if item.type == "text":

@@ -84,17 +84,21 @@ def _get_base64_size(base64_str: str) -> int:
     return (len(base64_str) * 3) // 4 - padding


-def _load_agent_config(agent_path: Optional[str]) -> Tuple[AgentConfig, Optional[str]]:
+def _load_agent_config(agent_path: Optional[str]) -> tuple[AgentConfig, Optional[str]]:
     """Load server config and prompt."""

-    def _read_dir(directory: Path) -> Tuple[AgentConfig, Optional[str]]:
+    def _read_dir(directory: Path) -> tuple[AgentConfig, Optional[str]]:
         cfg_file = directory / FILENAME_CONFIG
         if not cfg_file.exists():
             raise FileNotFoundError(f" Config file not found in {directory}! Please make sure it exists locally")

         config: AgentConfig = json.loads(cfg_file.read_text(encoding="utf-8"))
-        prompt_file = directory / FILENAME_PROMPT
-        prompt: Optional[str] = prompt_file.read_text(encoding="utf-8") if prompt_file.exists() else None
+        prompt: Optional[str] = None
+        for filename in PROMPT_FILENAMES:
+            prompt_file = directory / filename
+            if prompt_file.exists():
+                prompt = prompt_file.read_text(encoding="utf-8")
+                break
         return config, prompt

     if agent_path is None:
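As a quick sanity check of the _get_base64_size arithmetic visible above (decoded size is (len * 3) // 4 minus the padding characters):

    import base64

    encoded = base64.b64encode(b"hello").decode()  # "aGVsbG8=": length 8, one "=" pad
    padding = encoded.count("=")
    assert (len(encoded) * 3) // 4 - padding == len(b"hello")  # 8 * 3 // 4 - 1 == 5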
huggingface_hub/inference/_providers/__init__.py

@@ -1,4 +1,4 @@
-from typing import Dict, Literal, Optional, Union
+from typing import Literal, Optional, Union

 from huggingface_hub.inference._providers.featherless_ai import (
     FeatherlessConversationalTask,

@@ -13,6 +13,7 @@ from .cohere import CohereConversationalTask
 from .fal_ai import (
     FalAIAutomaticSpeechRecognitionTask,
     FalAIImageToImageTask,
+    FalAIImageToVideoTask,
     FalAITextToImageTask,
     FalAITextToSpeechTask,
     FalAITextToVideoTask,

@@ -64,7 +65,7 @@ PROVIDER_T = Literal[

 PROVIDER_OR_POLICY_T = Union[PROVIDER_T, Literal["auto"]]

-PROVIDERS: Dict[PROVIDER_T, Dict[str, TaskProviderHelper]] = {
+PROVIDERS: dict[PROVIDER_T, dict[str, TaskProviderHelper]] = {
     "black-forest-labs": {
         "text-to-image": BlackForestLabsTextToImageTask(),
     },

@@ -79,6 +80,7 @@ PROVIDERS: Dict[PROVIDER_T, Dict[str, TaskProviderHelper]] = {
         "text-to-image": FalAITextToImageTask(),
         "text-to-speech": FalAITextToSpeechTask(),
         "text-to-video": FalAITextToVideoTask(),
+        "image-to-video": FalAIImageToVideoTask(),
         "image-to-image": FalAIImageToImageTask(),
     },
     "featherless-ai": {
huggingface_hub/inference/_providers/_common.py

@@ -1,9 +1,9 @@
 from functools import lru_cache
-from typing import Any, Dict, List, Optional, Union, overload
+from typing import Any, Optional, Union, overload

 from huggingface_hub import constants
 from huggingface_hub.hf_api import InferenceProviderMapping
-from huggingface_hub.inference._common import RequestParameters
+from huggingface_hub.inference._common import MimeBytes, RequestParameters
 from huggingface_hub.inference._generated.types.chat_completion import ChatCompletionInputMessage
 from huggingface_hub.utils import build_hf_headers, get_token, logging

@@ -14,7 +14,7 @@ logger = logging.get_logger(__name__)
 # Dev purposes only.
 # If you want to try to run inference for a new model locally before it's registered on huggingface.co
 # for a given Inference Provider, you can add it to the following dictionary.
-HARDCODED_MODEL_INFERENCE_MAPPING: Dict[str, Dict[str, InferenceProviderMapping]] = {
+HARDCODED_MODEL_INFERENCE_MAPPING: dict[str, dict[str, InferenceProviderMapping]] = {
     # "HF model ID" => InferenceProviderMapping object initialized with "Model ID on Inference Provider's side"
     #
     # Example:

@@ -38,14 +38,14 @@ HARDCODED_MODEL_INFERENCE_MAPPING: Dict[str, Dict[str, InferenceProviderMapping]


 @overload
-def filter_none(obj: Dict[str, Any]) -> Dict[str, Any]: ...
+def filter_none(obj: dict[str, Any]) -> dict[str, Any]: ...
 @overload
-def filter_none(obj: List[Any]) -> List[Any]: ...
+def filter_none(obj: list[Any]) -> list[Any]: ...


-def filter_none(obj: Union[Dict[str, Any], List[Any]]) -> Union[Dict[str, Any], List[Any]]:
+def filter_none(obj: Union[dict[str, Any], list[Any]]) -> Union[dict[str, Any], list[Any]]:
     if isinstance(obj, dict):
-        cleaned: Dict[str, Any] = {}
+        cleaned: dict[str, Any] = {}
         for k, v in obj.items():
             if v is None:
                 continue

@@ -72,11 +72,11 @@ class TaskProviderHelper:
         self,
         *,
         inputs: Any,
-        parameters: Dict[str, Any],
-        headers: Dict,
+        parameters: dict[str, Any],
+        headers: dict,
         model: Optional[str],
         api_key: Optional[str],
-        extra_payload: Optional[Dict[str, Any]] = None,
+        extra_payload: Optional[dict[str, Any]] = None,
     ) -> RequestParameters:
         """
         Prepare the request to be sent to the provider.

@@ -108,13 +108,22 @@ class TaskProviderHelper:
             raise ValueError("Both payload and data cannot be set in the same request.")
         if payload is None and data is None:
             raise ValueError("Either payload or data must be set in the request.")
+
+        # normalize headers to lowercase and add content-type if not present
+        normalized_headers = self._normalize_headers(headers, payload, data)
+
         return RequestParameters(
-            url=url, task=self.task, model=provider_mapping_info.provider_id, json=payload, data=data, headers=headers
+            url=url,
+            task=self.task,
+            model=provider_mapping_info.provider_id,
+            json=payload,
+            data=data,
+            headers=normalized_headers,
         )

     def get_response(
         self,
-        response: Union[bytes, Dict],
+        response: Union[bytes, dict],
         request_params: Optional[RequestParameters] = None,
     ) -> Any:
         """

@@ -172,7 +181,22 @@ class TaskProviderHelper:
         )
         return provider_mapping

-    def _prepare_headers(self, headers: Dict, api_key: str) -> Dict:
+    def _normalize_headers(
+        self, headers: dict[str, Any], payload: Optional[dict[str, Any]], data: Optional[MimeBytes]
+    ) -> dict[str, Any]:
+        """Normalize the headers to use for the request.
+
+        Override this method in subclasses for customized headers.
+        """
+        normalized_headers = {key.lower(): value for key, value in headers.items() if value is not None}
+        if normalized_headers.get("content-type") is None:
+            if data is not None and data.mime_type is not None:
+                normalized_headers["content-type"] = data.mime_type
+            elif payload is not None:
+                normalized_headers["content-type"] = "application/json"
+        return normalized_headers
+
+    def _prepare_headers(self, headers: dict, api_key: str) -> dict[str, Any]:
         """Return the headers to use for the request.

         Override this method in subclasses for customized headers.

@@ -207,8 +231,8 @@ class TaskProviderHelper:
         return ""

     def _prepare_payload_as_dict(
-        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
-    ) -> Optional[Dict]:
+        self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping
+    ) -> Optional[dict]:
         """Return the payload to use for the request, as a dict.

         Override this method in subclasses for customized payloads.

@@ -219,10 +243,10 @@ class TaskProviderHelper:
     def _prepare_payload_as_bytes(
         self,
         inputs: Any,
-        parameters: Dict,
+        parameters: dict,
         provider_mapping_info: InferenceProviderMapping,
-        extra_payload: Optional[Dict],
-    ) -> Optional[bytes]:
+        extra_payload: Optional[dict],
+    ) -> Optional[MimeBytes]:
         """Return the body to use for the request, as bytes.

         Override this method in subclasses for customized body data.

@@ -245,10 +269,10 @@ class BaseConversationalTask(TaskProviderHelper):

     def _prepare_payload_as_dict(
         self,
-        inputs: List[Union[Dict, ChatCompletionInputMessage]],
-        parameters: Dict,
+        inputs: list[Union[dict, ChatCompletionInputMessage]],
+        parameters: dict,
         provider_mapping_info: InferenceProviderMapping,
-    ) -> Optional[Dict]:
+    ) -> Optional[dict]:
         return filter_none({"messages": inputs, **parameters, "model": provider_mapping_info.provider_id})


@@ -265,13 +289,13 @@ class BaseTextGenerationTask(TaskProviderHelper):
         return "/v1/completions"

     def _prepare_payload_as_dict(
-        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
-    ) -> Optional[Dict]:
+        self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping
+    ) -> Optional[dict]:
         return filter_none({"prompt": inputs, **parameters, "model": provider_mapping_info.provider_id})


 @lru_cache(maxsize=None)
-def _fetch_inference_provider_mapping(model: str) -> List["InferenceProviderMapping"]:
+def _fetch_inference_provider_mapping(model: str) -> list["InferenceProviderMapping"]:
     """
     Fetch provider mappings for a model from the Hub.
     """

@@ -284,7 +308,7 @@ def _fetch_inference_provider_mapping(model: str) -> List["InferenceProviderMapp
     return provider_mapping


-def recursive_merge(dict1: Dict, dict2: Dict) -> Dict:
+def recursive_merge(dict1: dict, dict2: dict) -> dict:
     return {
         **dict1,
         **{
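A small sketch of what the new _normalize_headers guarantees, assuming the base TaskProviderHelper can be instantiated directly (provider values below are placeholders): keys are lowercased, None values dropped, and a content-type is inferred from the JSON payload or the MimeBytes data when missing:

    from huggingface_hub.inference._providers._common import TaskProviderHelper

    helper = TaskProviderHelper(provider="example", base_url="https://api.example.test", task="text-to-image")
    headers = helper._normalize_headers(
        {"Authorization": "Bearer hf_xxx", "X-Debug": None},  # mixed case, one None value
        payload={"inputs": "a cat"},  # JSON payload -> content-type application/json
        data=None,
    )
    assert headers == {"authorization": "Bearer hf_xxx", "content-type": "application/json"}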
huggingface_hub/inference/_providers/black_forest_labs.py

@@ -1,5 +1,5 @@
 import time
-from typing import Any, Dict, Optional, Union
+from typing import Any, Optional, Union

 from huggingface_hub.hf_api import InferenceProviderMapping
 from huggingface_hub.inference._common import RequestParameters, _as_dict

@@ -18,7 +18,7 @@ class BlackForestLabsTextToImageTask(TaskProviderHelper):
     def __init__(self):
         super().__init__(provider="black-forest-labs", base_url="https://api.us1.bfl.ai", task="text-to-image")

-    def _prepare_headers(self, headers: Dict, api_key: str) -> Dict:
+    def _prepare_headers(self, headers: dict, api_key: str) -> dict[str, Any]:
         headers = super()._prepare_headers(headers, api_key)
         if not api_key.startswith("hf_"):
             _ = headers.pop("authorization")

@@ -29,8 +29,8 @@ class BlackForestLabsTextToImageTask(TaskProviderHelper):
         return f"/v1/{mapped_model}"

     def _prepare_payload_as_dict(
-        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
-    ) -> Optional[Dict]:
+        self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping
+    ) -> Optional[dict]:
         parameters = filter_none(parameters)
         if "num_inference_steps" in parameters:
             parameters["steps"] = parameters.pop("num_inference_steps")

@@ -39,7 +39,7 @@ class BlackForestLabsTextToImageTask(TaskProviderHelper):

         return {"prompt": inputs, **parameters}

-    def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any:
+    def get_response(self, response: Union[bytes, dict], request_params: Optional[RequestParameters] = None) -> Any:
         """
         Polling mechanism for Black Forest Labs since the API is asynchronous.
         """

@@ -50,7 +50,7 @@ class BlackForestLabsTextToImageTask(TaskProviderHelper):

             response = session.get(url, headers={"Content-Type": "application/json"})  # type: ignore
             response.raise_for_status()  # type: ignore
-            response_json: Dict = response.json()  # type: ignore
+            response_json: dict = response.json()  # type: ignore
             status = response_json.get("status")
             logger.info(
                 f"Polling generation result from {url}. Current status: {status}. "
huggingface_hub/inference/_providers/cohere.py

@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional
+from typing import Any, Optional

 from huggingface_hub.hf_api import InferenceProviderMapping

@@ -17,8 +17,8 @@ class CohereConversationalTask(BaseConversationalTask):
         return "/compatibility/v1/chat/completions"

     def _prepare_payload_as_dict(
-        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
-    ) -> Optional[Dict]:
+        self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping
+    ) -> Optional[dict]:
         payload = super()._prepare_payload_as_dict(inputs, parameters, provider_mapping_info)
         response_format = parameters.get("response_format")
         if isinstance(response_format, dict) and response_format.get("type") == "json_schema":
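The branch above special-cases structured-output requests. For orientation, an OpenAI-style response_format payload of the kind this check matches typically looks like the following (the schema itself is illustrative):

    response_format = {
        "type": "json_schema",
        "json_schema": {
            "name": "recipe",
            "schema": {
                "type": "object",
                "properties": {"title": {"type": "string"}},
                "required": ["title"],
            },
        },
    }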