huggingface_hub-0.34.6-py3-none-any.whl → huggingface_hub-0.35.0-py3-none-any.whl

This diff shows the content of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.


Files changed (38)
  1. huggingface_hub/__init__.py +19 -1
  2. huggingface_hub/_jobs_api.py +159 -2
  3. huggingface_hub/_tensorboard_logger.py +9 -10
  4. huggingface_hub/cli/auth.py +1 -1
  5. huggingface_hub/cli/cache.py +3 -9
  6. huggingface_hub/cli/jobs.py +551 -1
  7. huggingface_hub/cli/repo.py +6 -4
  8. huggingface_hub/commands/delete_cache.py +2 -2
  9. huggingface_hub/commands/scan_cache.py +1 -1
  10. huggingface_hub/commands/user.py +1 -1
  11. huggingface_hub/hf_api.py +522 -78
  12. huggingface_hub/hf_file_system.py +3 -1
  13. huggingface_hub/hub_mixin.py +5 -3
  14. huggingface_hub/inference/_client.py +17 -180
  15. huggingface_hub/inference/_common.py +72 -70
  16. huggingface_hub/inference/_generated/_async_client.py +34 -200
  17. huggingface_hub/inference/_generated/types/chat_completion.py +2 -0
  18. huggingface_hub/inference/_mcp/_cli_hacks.py +3 -3
  19. huggingface_hub/inference/_mcp/cli.py +1 -1
  20. huggingface_hub/inference/_mcp/constants.py +1 -1
  21. huggingface_hub/inference/_mcp/mcp_client.py +28 -11
  22. huggingface_hub/inference/_mcp/types.py +3 -0
  23. huggingface_hub/inference/_mcp/utils.py +7 -3
  24. huggingface_hub/inference/_providers/_common.py +28 -4
  25. huggingface_hub/inference/_providers/black_forest_labs.py +1 -1
  26. huggingface_hub/inference/_providers/fal_ai.py +2 -2
  27. huggingface_hub/inference/_providers/hf_inference.py +15 -7
  28. huggingface_hub/inference/_providers/replicate.py +1 -1
  29. huggingface_hub/repocard.py +2 -1
  30. huggingface_hub/utils/_git_credential.py +1 -1
  31. huggingface_hub/utils/_typing.py +24 -4
  32. huggingface_hub/utils/_xet_progress_reporting.py +31 -10
  33. {huggingface_hub-0.34.6.dist-info → huggingface_hub-0.35.0.dist-info}/METADATA +7 -4
  34. {huggingface_hub-0.34.6.dist-info → huggingface_hub-0.35.0.dist-info}/RECORD +38 -38
  35. {huggingface_hub-0.34.6.dist-info → huggingface_hub-0.35.0.dist-info}/LICENSE +0 -0
  36. {huggingface_hub-0.34.6.dist-info → huggingface_hub-0.35.0.dist-info}/WHEEL +0 -0
  37. {huggingface_hub-0.34.6.dist-info → huggingface_hub-0.35.0.dist-info}/entry_points.txt +0 -0
  38. {huggingface_hub-0.34.6.dist-info → huggingface_hub-0.35.0.dist-info}/top_level.txt +0 -0
huggingface_hub/inference/_generated/_async_client.py

@@ -30,7 +30,6 @@ from huggingface_hub.errors import InferenceTimeoutError
 from huggingface_hub.inference._common import (
     TASKS_EXPECTING_IMAGES,
     ContentT,
-    ModelStatus,
     RequestParameters,
     _async_stream_chat_completion_response,
     _async_stream_text_generation_response,
@@ -41,7 +40,6 @@ from huggingface_hub.inference._common import (
     _bytes_to_list,
     _get_unsupported_text_generation_kwargs,
     _import_numpy,
-    _open_as_binary,
     _set_unsupported_text_generation_kwargs,
     raise_text_generation_error,
 )
@@ -88,9 +86,8 @@ from huggingface_hub.inference._generated.types import (
     ZeroShotImageClassificationOutputElement,
 )
 from huggingface_hub.inference._providers import PROVIDER_OR_POLICY_T, get_provider_helper
-from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status
+from huggingface_hub.utils import build_hf_headers
 from huggingface_hub.utils._auth import get_token
-from huggingface_hub.utils._deprecation import _deprecate_method

 from .._common import _async_yield_from, _import_aiohttp

@@ -257,39 +254,38 @@ class AsyncInferenceClient:
         if request_parameters.task in TASKS_EXPECTING_IMAGES and "Accept" not in request_parameters.headers:
             request_parameters.headers["Accept"] = "image/png"

-        with _open_as_binary(request_parameters.data) as data_as_binary:
-            # Do not use context manager as we don't want to close the connection immediately when returning
-            # a stream
-            session = self._get_client_session(headers=request_parameters.headers)
+        # Do not use context manager as we don't want to close the connection immediately when returning
+        # a stream
+        session = self._get_client_session(headers=request_parameters.headers)

-            try:
-                response = await session.post(
-                    request_parameters.url, json=request_parameters.json, data=data_as_binary, proxy=self.proxies
-                )
-                response_error_payload = None
-                if response.status != 200:
-                    try:
-                        response_error_payload = await response.json()  # get payload before connection closed
-                    except Exception:
-                        pass
-                response.raise_for_status()
-                if stream:
-                    return _async_yield_from(session, response)
-                else:
-                    content = await response.read()
-                    await session.close()
-                    return content
-            except asyncio.TimeoutError as error:
-                await session.close()
-                # Convert any `TimeoutError` to a `InferenceTimeoutError`
-                raise InferenceTimeoutError(f"Inference call timed out: {request_parameters.url}") from error  # type: ignore
-            except aiohttp.ClientResponseError as error:
-                error.response_error_payload = response_error_payload
-                await session.close()
-                raise error
-            except Exception:
+        try:
+            response = await session.post(
+                request_parameters.url, json=request_parameters.json, data=request_parameters.data, proxy=self.proxies
+            )
+            response_error_payload = None
+            if response.status != 200:
+                try:
+                    response_error_payload = await response.json()  # get payload before connection closed
+                except Exception:
+                    pass
+            response.raise_for_status()
+            if stream:
+                return _async_yield_from(session, response)
+            else:
+                content = await response.read()
                 await session.close()
-                raise
+                return content
+        except asyncio.TimeoutError as error:
+            await session.close()
+            # Convert any `TimeoutError` to a `InferenceTimeoutError`
+            raise InferenceTimeoutError(f"Inference call timed out: {request_parameters.url}") from error  # type: ignore
+        except aiohttp.ClientResponseError as error:
+            error.response_error_payload = response_error_payload
+            await session.close()
+            raise error
+        except Exception:
+            await session.close()
+            raise

     async def __aenter__(self):
         return self
@@ -1510,8 +1506,8 @@ class AsyncInferenceClient:
             api_key=self.token,
         )
         response = await self._inner_post(request_parameters)
-        output = ImageToTextOutput.parse_obj(response)
-        return output[0] if isinstance(output, list) else output
+        output_list: List[ImageToTextOutput] = ImageToTextOutput.parse_obj_as_list(response)
+        return output_list[0]

     async def object_detection(
         self, image: ContentT, *, model: Optional[str] = None, threshold: Optional[float] = None
@@ -3338,102 +3334,6 @@ class AsyncInferenceClient:
         response = await self._inner_post(request_parameters)
         return ZeroShotImageClassificationOutputElement.parse_obj_as_list(response)

-    @_deprecate_method(
-        version="0.35.0",
-        message=(
-            "HF Inference API is getting revamped and will only support warm models in the future (no cold start allowed)."
-            " Use `HfApi.list_models(..., inference_provider='...')` to list warm models per provider."
-        ),
-    )
-    async def list_deployed_models(
-        self, frameworks: Union[None, str, Literal["all"], List[str]] = None
-    ) -> Dict[str, List[str]]:
-        """
-        List models deployed on the HF Serverless Inference API service.
-
-        This helper checks deployed models framework by framework. By default, it will check the 4 main frameworks that
-        are supported and account for 95% of the hosted models. However, if you want a complete list of models you can
-        specify `frameworks="all"` as input. Alternatively, if you know before-hand which framework you are interested
-        in, you can also restrict to search to this one (e.g. `frameworks="text-generation-inference"`). The more
-        frameworks are checked, the more time it will take.
-
-        <Tip warning={true}>
-
-        This endpoint method does not return a live list of all models available for the HF Inference API service.
-        It searches over a cached list of models that were recently available and the list may not be up to date.
-        If you want to know the live status of a specific model, use [`~InferenceClient.get_model_status`].
-
-        </Tip>
-
-        <Tip>
-
-        This endpoint method is mostly useful for discoverability. If you already know which model you want to use and want to
-        check its availability, you can directly use [`~InferenceClient.get_model_status`].
-
-        </Tip>
-
-        Args:
-            frameworks (`Literal["all"]` or `List[str]` or `str`, *optional*):
-                The frameworks to filter on. By default only a subset of the available frameworks are tested. If set to
-                "all", all available frameworks will be tested. It is also possible to provide a single framework or a
-                custom set of frameworks to check.
-
-        Returns:
-            `Dict[str, List[str]]`: A dictionary mapping task names to a sorted list of model IDs.
-
-        Example:
-        ```py
-        # Must be run in an async context
-        >>> from huggingface_hub import AsyncInferenceClient
-        >>> client = AsyncInferenceClient()
-
-        # Discover zero-shot-classification models currently deployed
-        >>> models = await client.list_deployed_models()
-        >>> models["zero-shot-classification"]
-        ['Narsil/deberta-large-mnli-zero-cls', 'facebook/bart-large-mnli', ...]
-
-        # List from only 1 framework
-        >>> await client.list_deployed_models("text-generation-inference")
-        {'text-generation': ['bigcode/starcoder', 'meta-llama/Llama-2-70b-chat-hf', ...], ...}
-        ```
-        """
-        if self.provider != "hf-inference":
-            raise ValueError(f"Listing deployed models is not supported on '{self.provider}'.")
-
-        # Resolve which frameworks to check
-        if frameworks is None:
-            frameworks = constants.MAIN_INFERENCE_API_FRAMEWORKS
-        elif frameworks == "all":
-            frameworks = constants.ALL_INFERENCE_API_FRAMEWORKS
-        elif isinstance(frameworks, str):
-            frameworks = [frameworks]
-        frameworks = list(set(frameworks))
-
-        # Fetch them iteratively
-        models_by_task: Dict[str, List[str]] = {}
-
-        def _unpack_response(framework: str, items: List[Dict]) -> None:
-            for model in items:
-                if framework == "sentence-transformers":
-                    # Model running with the `sentence-transformers` framework can work with both tasks even if not
-                    # branded as such in the API response
-                    models_by_task.setdefault("feature-extraction", []).append(model["model_id"])
-                    models_by_task.setdefault("sentence-similarity", []).append(model["model_id"])
-                else:
-                    models_by_task.setdefault(model["task"], []).append(model["model_id"])
-
-        for framework in frameworks:
-            response = get_session().get(
-                f"{constants.INFERENCE_ENDPOINT}/framework/{framework}", headers=build_hf_headers(token=self.token)
-            )
-            hf_raise_for_status(response)
-            _unpack_response(framework, response.json())
-
-        # Sort alphabetically for discoverability and return
-        for task, models in models_by_task.items():
-            models_by_task[task] = sorted(set(models), key=lambda x: x.lower())
-        return models_by_task
-
     def _get_client_session(self, headers: Optional[Dict] = None) -> "ClientSession":
         aiohttp = _import_aiohttp()
         client_headers = self.headers.copy()
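Note: the deprecation message above points to `HfApi.list_models(..., inference_provider='...')` as the replacement for `list_deployed_models`. A minimal sketch of that usage; the provider name and `limit` value are illustrative, not taken from this diff:

```py
from huggingface_hub import HfApi

api = HfApi()
# List models currently served ("warm") by a given provider.
for model in api.list_models(inference_provider="hf-inference", limit=5):
    print(model.id)
```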
@@ -3540,7 +3440,6 @@ class AsyncInferenceClient:
         Check the health of the deployed endpoint.

         Health check is only available with Inference Endpoints powered by Text-Generation-Inference (TGI) or Text-Embedding-Inference (TEI).
-        For Inference API, please use [`InferenceClient.get_model_status`] instead.

         Args:
             model (`str`, *optional*):
@@ -3565,78 +3464,13 @@ class AsyncInferenceClient:
         if model is None:
             raise ValueError("Model id not provided.")
         if not model.startswith(("http://", "https://")):
-            raise ValueError(
-                "Model must be an Inference Endpoint URL. For serverless Inference API, please use `InferenceClient.get_model_status`."
-            )
+            raise ValueError("Model must be an Inference Endpoint URL.")
         url = model.rstrip("/") + "/health"

         async with self._get_client_session(headers=build_hf_headers(token=self.token)) as client:
             response = await client.get(url, proxy=self.proxies)
             return response.status == 200

-    @_deprecate_method(
-        version="0.35.0",
-        message=(
-            "HF Inference API is getting revamped and will only support warm models in the future (no cold start allowed)."
-            " Use `HfApi.model_info` to get the model status both with HF Inference API and external providers."
-        ),
-    )
-    async def get_model_status(self, model: Optional[str] = None) -> ModelStatus:
-        """
-        Get the status of a model hosted on the HF Inference API.
-
-        <Tip>
-
-        This endpoint is mostly useful when you already know which model you want to use and want to check its
-        availability. If you want to discover already deployed models, you should rather use [`~InferenceClient.list_deployed_models`].
-
-        </Tip>
-
-        Args:
-            model (`str`, *optional*):
-                Identifier of the model for witch the status gonna be checked. If model is not provided,
-                the model associated with this instance of [`InferenceClient`] will be used. Only HF Inference API service can be checked so the
-                identifier cannot be a URL.
-
-
-        Returns:
-            [`ModelStatus`]: An instance of ModelStatus dataclass, containing information,
-            about the state of the model: load, state, compute type and framework.
-
-        Example:
-        ```py
-        # Must be run in an async context
-        >>> from huggingface_hub import AsyncInferenceClient
-        >>> client = AsyncInferenceClient()
-        >>> await client.get_model_status("meta-llama/Meta-Llama-3-8B-Instruct")
-        ModelStatus(loaded=True, state='Loaded', compute_type='gpu', framework='text-generation-inference')
-        ```
-        """
-        if self.provider != "hf-inference":
-            raise ValueError(f"Getting model status is not supported on '{self.provider}'.")
-
-        model = model or self.model
-        if model is None:
-            raise ValueError("Model id not provided.")
-        if model.startswith("https://"):
-            raise NotImplementedError("Model status is only available for Inference API endpoints.")
-        url = f"{constants.INFERENCE_ENDPOINT}/status/{model}"
-
-        async with self._get_client_session(headers=build_hf_headers(token=self.token)) as client:
-            response = await client.get(url, proxy=self.proxies)
-            response.raise_for_status()
-            response_data = await response.json()
-
-        if "error" in response_data:
-            raise ValueError(response_data["error"])
-
-        return ModelStatus(
-            loaded=response_data["loaded"],
-            state=response_data["state"],
-            compute_type=response_data["compute_type"],
-            framework=response_data["framework"],
-        )
-
     @property
     def chat(self) -> "ProxyClientChat":
         return ProxyClientChat(self)
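Note: the removed `get_model_status` points to `HfApi.model_info` instead. A hedged sketch of that replacement, assuming the `expand` parameter accepts the `"inference"` and `"inferenceProviderMapping"` fields and that the returned `ModelInfo` exposes them as attributes:

```py
from huggingface_hub import HfApi

info = HfApi().model_info(
    "meta-llama/Meta-Llama-3-8B-Instruct",
    expand=["inference", "inferenceProviderMapping"],
)
print(info.inference)                   # e.g. "warm" when the model is deployed
print(info.inference_provider_mapping)  # which providers currently serve the model
```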
huggingface_hub/inference/_generated/types/chat_completion.py

@@ -239,6 +239,7 @@ class ChatCompletionOutputToolCall(BaseInferenceType):
 class ChatCompletionOutputMessage(BaseInferenceType):
     role: str
     content: Optional[str] = None
+    reasoning: Optional[str] = None
     tool_call_id: Optional[str] = None
     tool_calls: Optional[List[ChatCompletionOutputToolCall]] = None

@@ -292,6 +293,7 @@ class ChatCompletionStreamOutputDeltaToolCall(BaseInferenceType):
 class ChatCompletionStreamOutputDelta(BaseInferenceType):
     role: str
     content: Optional[str] = None
+    reasoning: Optional[str] = None
     tool_call_id: Optional[str] = None
     tool_calls: Optional[List[ChatCompletionStreamOutputDeltaToolCall]] = None

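Note: a minimal sketch (not from the diff) of where the new `reasoning` field surfaces. The model id is illustrative, and `reasoning` is simply `None` when the provider does not return one:

```py
from huggingface_hub import InferenceClient

client = InferenceClient()
response = client.chat_completion(
    model="deepseek-ai/DeepSeek-R1",  # illustrative reasoning-capable model
    messages=[{"role": "user", "content": "What is 2 + 2?"}],
)
message = response.choices[0].message
print(message.reasoning)  # intermediate reasoning text, if any
print(message.content)    # final answer

# In streaming mode the same field appears on each ChatCompletionStreamOutputDelta:
for chunk in client.chat_completion(
    model="deepseek-ai/DeepSeek-R1",
    messages=[{"role": "user", "content": "What is 2 + 2?"}],
    stream=True,
):
    delta = chunk.choices[0].delta
    if delta.reasoning:
        print(delta.reasoning, end="")
```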
huggingface_hub/inference/_mcp/_cli_hacks.py

@@ -17,7 +17,7 @@ def _patch_anyio_open_process():

    if getattr(anyio, "_tiny_agents_patched", False):
        return
-    anyio._tiny_agents_patched = True
+    anyio._tiny_agents_patched = True  # ty: ignore[invalid-assignment]

    original_open_process = anyio.open_process

@@ -32,7 +32,7 @@ def _patch_anyio_open_process():
            kwargs.setdefault("creationflags", subprocess.CREATE_NEW_PROCESS_GROUP)
            return await original_open_process(*args, **kwargs)

-        anyio.open_process = open_process_in_new_group
+        anyio.open_process = open_process_in_new_group  # ty: ignore[invalid-assignment]
    else:
        # For Unix-like systems, we can use setsid to create a new session
        async def open_process_in_new_group(*args, **kwargs):
@@ -42,7 +42,7 @@ def _patch_anyio_open_process():
            kwargs.setdefault("start_new_session", True)
            return await original_open_process(*args, **kwargs)

-        anyio.open_process = open_process_in_new_group
+        anyio.open_process = open_process_in_new_group  # ty: ignore[invalid-assignment]


 async def _async_prompt(exit_event: asyncio.Event, prompt: str = "» ") -> str:
huggingface_hub/inference/_mcp/cli.py

@@ -33,7 +33,7 @@ async def run_agent(

    Args:
        agent_path (`str`, *optional*):
-            Path to a local folder containing an `agent.json` and optionally a custom `PROMPT.md` file or a built-in agent stored in a Hugging Face dataset.
+            Path to a local folder containing an `agent.json` and optionally a custom `PROMPT.md` or `AGENTS.md` file or a built-in agent stored in a Hugging Face dataset.

    """
    _patch_anyio_open_process()  # Hacky way to prevent stdio connections to be stopped by Ctrl+C
huggingface_hub/inference/_mcp/constants.py

@@ -8,7 +8,7 @@ from huggingface_hub import ChatCompletionInputTool


 FILENAME_CONFIG = "agent.json"
-FILENAME_PROMPT = "PROMPT.md"
+PROMPT_FILENAMES = ("PROMPT.md", "AGENTS.md")

 DEFAULT_AGENT = {
     "model": "Qwen/Qwen2.5-72B-Instruct",
huggingface_hub/inference/_mcp/mcp_client.py

@@ -139,21 +139,27 @@ class MCPClient:
                    - args (List[str], optional): Arguments for the command
                    - env (Dict[str, str], optional): Environment variables for the command
                    - cwd (Union[str, Path, None], optional): Working directory for the command
+                    - allowed_tools (List[str], optional): List of tool names to allow from this server
                - For SSE servers:
                    - url (str): The URL of the SSE server
                    - headers (Dict[str, Any], optional): Headers for the SSE connection
                    - timeout (float, optional): Connection timeout
                    - sse_read_timeout (float, optional): SSE read timeout
+                    - allowed_tools (List[str], optional): List of tool names to allow from this server
                - For StreamableHTTP servers:
                    - url (str): The URL of the StreamableHTTP server
                    - headers (Dict[str, Any], optional): Headers for the StreamableHTTP connection
                    - timeout (timedelta, optional): Connection timeout
                    - sse_read_timeout (timedelta, optional): SSE read timeout
                    - terminate_on_close (bool, optional): Whether to terminate on close
+                    - allowed_tools (List[str], optional): List of tool names to allow from this server
        """
        from mcp import ClientSession, StdioServerParameters
        from mcp import types as mcp_types

+        # Extract allowed_tools configuration if provided
+        allowed_tools = params.pop("allowed_tools", [])
+
        # Determine server type and create appropriate parameters
        if type == "stdio":
            # Handle stdio server
@@ -211,7 +217,15 @@ class MCPClient:
        response = await session.list_tools()
        logger.debug("Connected to server with tools:", [tool.name for tool in response.tools])

-        for tool in response.tools:
+        # Filter tools based on allowed_tools configuration
+        filtered_tools = [tool for tool in response.tools if tool.name in allowed_tools]
+
+        if allowed_tools:
+            logger.debug(
+                f"Tool filtering applied. Using {len(filtered_tools)} of {len(response.tools)} available tools: {[tool.name for tool in filtered_tools]}"
+            )
+
+        for tool in filtered_tools:
            if tool.name in self.sessions:
                logger.warning(f"Tool '{tool.name}' already defined by another server. Skipping.")
                continue
@@ -286,16 +300,19 @@ class MCPClient:
            # Process tool calls
            if delta.tool_calls:
                for tool_call in delta.tool_calls:
-                    # Aggregate chunks into tool calls
-                    if tool_call.index not in final_tool_calls:
-                        if (
-                            tool_call.function.arguments is None or tool_call.function.arguments == "{}"
-                        ):  # Corner case (depends on provider)
-                            tool_call.function.arguments = ""
-                        final_tool_calls[tool_call.index] = tool_call
-
-                    elif tool_call.function.arguments:
-                        final_tool_calls[tool_call.index].function.arguments += tool_call.function.arguments
+                    idx = tool_call.index
+                    # first chunk for this tool call
+                    if idx not in final_tool_calls:
+                        final_tool_calls[idx] = tool_call
+                        if final_tool_calls[idx].function.arguments is None:
+                            final_tool_calls[idx].function.arguments = ""
+                        continue
+                    # safety before concatenating text to .function.arguments
+                    if final_tool_calls[idx].function.arguments is None:
+                        final_tool_calls[idx].function.arguments = ""
+
+                    if tool_call.function.arguments:
+                        final_tool_calls[idx].function.arguments += tool_call.function.arguments

            # Optionally exit early if no tools in first chunks
            if exit_if_first_chunk_no_tool and num_of_chunks <= 2 and len(final_tool_calls) == 0:
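Note: the rewritten loop above keys partial tool calls by `index` and coerces `arguments` to a string before concatenating. A standalone illustration of the same pattern with plain dicts instead of the generated delta types (the chunk contents are made up):

```py
# Aggregate streamed tool-call fragments by index, as in the hunk above.
final_tool_calls: dict[int, dict] = {}

streamed_chunks = [
    {"index": 0, "name": "get_weather", "arguments": None},
    {"index": 0, "name": None, "arguments": '{"city": "Par'},
    {"index": 0, "name": None, "arguments": 'is"}'},
]

for chunk in streamed_chunks:
    idx = chunk["index"]
    if idx not in final_tool_calls:
        final_tool_calls[idx] = dict(chunk)
        if final_tool_calls[idx]["arguments"] is None:
            final_tool_calls[idx]["arguments"] = ""
        continue
    if chunk["arguments"]:
        final_tool_calls[idx]["arguments"] += chunk["arguments"]

print(final_tool_calls[0]["arguments"])  # {"city": "Paris"}
```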
huggingface_hub/inference/_mcp/types.py

@@ -16,18 +16,21 @@ class StdioServerConfig(TypedDict):
     args: List[str]
     env: Dict[str, str]
     cwd: str
+    allowed_tools: NotRequired[List[str]]


 class HTTPServerConfig(TypedDict):
     type: Literal["http"]
     url: str
     headers: Dict[str, str]
+    allowed_tools: NotRequired[List[str]]


 class SSEServerConfig(TypedDict):
     type: Literal["sse"]
     url: str
     headers: Dict[str, str]
+    allowed_tools: NotRequired[List[str]]


 ServerConfig = Union[StdioServerConfig, HTTPServerConfig, SSEServerConfig]
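Note: a hedged example of a server entry matching the TypedDicts above. The command, paths and tool names are placeholders, and it assumes the remaining required keys of `StdioServerConfig` (not shown in this hunk) are `type` and `command`:

```py
from huggingface_hub.inference._mcp.types import StdioServerConfig

server: StdioServerConfig = {
    "type": "stdio",
    "command": "npx",
    "args": ["@modelcontextprotocol/server-filesystem", "/tmp"],
    "env": {},
    "cwd": "/tmp",
    # Only these tools are registered; everything else the server exposes is skipped.
    "allowed_tools": ["read_file", "list_directory"],
}
```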
huggingface_hub/inference/_mcp/utils.py

@@ -11,7 +11,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple
 from huggingface_hub import snapshot_download
 from huggingface_hub.errors import EntryNotFoundError

-from .constants import DEFAULT_AGENT, DEFAULT_REPO_ID, FILENAME_CONFIG, FILENAME_PROMPT
+from .constants import DEFAULT_AGENT, DEFAULT_REPO_ID, FILENAME_CONFIG, PROMPT_FILENAMES
 from .types import AgentConfig


@@ -93,8 +93,12 @@ def _load_agent_config(agent_path: Optional[str]) -> Tuple[AgentConfig, Optional
            raise FileNotFoundError(f" Config file not found in {directory}! Please make sure it exists locally")

        config: AgentConfig = json.loads(cfg_file.read_text(encoding="utf-8"))
-        prompt_file = directory / FILENAME_PROMPT
-        prompt: Optional[str] = prompt_file.read_text(encoding="utf-8") if prompt_file.exists() else None
+        prompt: Optional[str] = None
+        for filename in PROMPT_FILENAMES:
+            prompt_file = directory / filename
+            if prompt_file.exists():
+                prompt = prompt_file.read_text(encoding="utf-8")
+                break
        return config, prompt

    if agent_path is None:
huggingface_hub/inference/_providers/_common.py

@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Union, overload

 from huggingface_hub import constants
 from huggingface_hub.hf_api import InferenceProviderMapping
-from huggingface_hub.inference._common import RequestParameters
+from huggingface_hub.inference._common import MimeBytes, RequestParameters
 from huggingface_hub.inference._generated.types.chat_completion import ChatCompletionInputMessage
 from huggingface_hub.utils import build_hf_headers, get_token, logging

@@ -109,8 +109,17 @@ class TaskProviderHelper:
            raise ValueError("Both payload and data cannot be set in the same request.")
        if payload is None and data is None:
            raise ValueError("Either payload or data must be set in the request.")
+
+        # normalize headers to lowercase and add content-type if not present
+        normalized_headers = self._normalize_headers(headers, payload, data)
+
        return RequestParameters(
-            url=url, task=self.task, model=provider_mapping_info.provider_id, json=payload, data=data, headers=headers
+            url=url,
+            task=self.task,
+            model=provider_mapping_info.provider_id,
+            json=payload,
+            data=data,
+            headers=normalized_headers,
        )

    def get_response(
@@ -173,7 +182,22 @@ class TaskProviderHelper:
            )
        return provider_mapping

-    def _prepare_headers(self, headers: Dict, api_key: str) -> Dict:
+    def _normalize_headers(
+        self, headers: Dict[str, Any], payload: Optional[Dict[str, Any]], data: Optional[MimeBytes]
+    ) -> Dict[str, Any]:
+        """Normalize the headers to use for the request.
+
+        Override this method in subclasses for customized headers.
+        """
+        normalized_headers = {key.lower(): value for key, value in headers.items() if value is not None}
+        if normalized_headers.get("content-type") is None:
+            if data is not None and data.mime_type is not None:
+                normalized_headers["content-type"] = data.mime_type
+            elif payload is not None:
+                normalized_headers["content-type"] = "application/json"
+        return normalized_headers
+
+    def _prepare_headers(self, headers: Dict, api_key: str) -> Dict[str, Any]:
        """Return the headers to use for the request.

        Override this method in subclasses for customized headers.
@@ -223,7 +247,7 @@ class TaskProviderHelper:
        parameters: Dict,
        provider_mapping_info: InferenceProviderMapping,
        extra_payload: Optional[Dict],
-    ) -> Optional[bytes]:
+    ) -> Optional[MimeBytes]:
        """Return the body to use for the request, as bytes.

        Override this method in subclasses for customized body data.
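Note: `MimeBytes` itself is not shown in this diff; from its usage it behaves like a bytes payload that carries a `mime_type`. A minimal stand-in with that interface, plus a worked example of the header normalization above (all names here are illustrative, not the library's internals):

```py
from typing import Optional


class MimeBytesSketch(bytes):
    """Hypothetical stand-in: raw bytes tagged with an optional mime type."""

    def __new__(cls, data: bytes, mime_type: Optional[str] = None):
        obj = super().__new__(cls, data)
        obj.mime_type = mime_type
        return obj


data = MimeBytesSketch(b"\x89PNG...", mime_type="image/png")

# Mirroring _normalize_headers: keys are lowercased, None values dropped, and a
# missing content-type is inherited from the payload's mime type.
headers = {"Authorization": "Bearer hf_xxx", "Content-Type": None}
normalized = {k.lower(): v for k, v in headers.items() if v is not None}
if normalized.get("content-type") is None:
    normalized["content-type"] = data.mime_type
print(normalized)  # {'authorization': 'Bearer hf_xxx', 'content-type': 'image/png'}
```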
huggingface_hub/inference/_providers/black_forest_labs.py

@@ -18,7 +18,7 @@ class BlackForestLabsTextToImageTask(TaskProviderHelper):
     def __init__(self):
         super().__init__(provider="black-forest-labs", base_url="https://api.us1.bfl.ai", task="text-to-image")

-    def _prepare_headers(self, headers: Dict, api_key: str) -> Dict:
+    def _prepare_headers(self, headers: Dict, api_key: str) -> Dict[str, Any]:
         headers = super()._prepare_headers(headers, api_key)
         if not api_key.startswith("hf_"):
             _ = headers.pop("authorization")
huggingface_hub/inference/_providers/fal_ai.py

@@ -22,7 +22,7 @@ class FalAITask(TaskProviderHelper, ABC):
     def __init__(self, task: str):
         super().__init__(provider="fal-ai", base_url="https://fal.run", task=task)

-    def _prepare_headers(self, headers: Dict, api_key: str) -> Dict:
+    def _prepare_headers(self, headers: Dict, api_key: str) -> Dict[str, Any]:
         headers = super()._prepare_headers(headers, api_key)
         if not api_key.startswith("hf_"):
             headers["authorization"] = f"Key {api_key}"
@@ -36,7 +36,7 @@ class FalAIQueueTask(TaskProviderHelper, ABC):
     def __init__(self, task: str):
         super().__init__(provider="fal-ai", base_url="https://queue.fal.run", task=task)

-    def _prepare_headers(self, headers: Dict, api_key: str) -> Dict:
+    def _prepare_headers(self, headers: Dict, api_key: str) -> Dict[str, Any]:
         headers = super()._prepare_headers(headers, api_key)
         if not api_key.startswith("hf_"):
             headers["authorization"] = f"Key {api_key}"
huggingface_hub/inference/_providers/hf_inference.py

@@ -6,7 +6,13 @@ from urllib.parse import urlparse, urlunparse

 from huggingface_hub import constants
 from huggingface_hub.hf_api import InferenceProviderMapping
-from huggingface_hub.inference._common import RequestParameters, _b64_encode, _bytes_to_dict, _open_as_binary
+from huggingface_hub.inference._common import (
+    MimeBytes,
+    RequestParameters,
+    _b64_encode,
+    _bytes_to_dict,
+    _open_as_mime_bytes,
+)
 from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none
 from huggingface_hub.utils import build_hf_headers, get_session, get_token, hf_raise_for_status

@@ -75,7 +81,7 @@ class HFInferenceBinaryInputTask(HFInferenceTask):
        parameters: Dict,
        provider_mapping_info: InferenceProviderMapping,
        extra_payload: Optional[Dict],
-    ) -> Optional[bytes]:
+    ) -> Optional[MimeBytes]:
        parameters = filter_none(parameters)
        extra_payload = extra_payload or {}
        has_parameters = len(parameters) > 0 or len(extra_payload) > 0
@@ -86,12 +92,13 @@ class HFInferenceBinaryInputTask(HFInferenceTask):

        # Send inputs as raw content when no parameters are provided
        if not has_parameters:
-            with _open_as_binary(inputs) as data:
-                data_as_bytes = data if isinstance(data, bytes) else data.read()
-                return data_as_bytes
+            return _open_as_mime_bytes(inputs)

        # Otherwise encode as b64
-        return json.dumps({"inputs": _b64_encode(inputs), "parameters": parameters, **extra_payload}).encode("utf-8")
+        return MimeBytes(
+            json.dumps({"inputs": _b64_encode(inputs), "parameters": parameters, **extra_payload}).encode("utf-8"),
+            mime_type="application/json",
+        )


 class HFInferenceConversational(HFInferenceTask):
@@ -144,7 +151,8 @@ def _build_chat_completion_url(model_url: str) -> str:
     new_path = path + "/v1/chat/completions"

     # Reconstruct the URL with the new path and original query parameters.
-    return urlunparse(parsed._replace(path=new_path))
+    new_parsed = parsed._replace(path=new_path)
+    return str(urlunparse(new_parsed))


 @lru_cache(maxsize=1)
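Note: an illustration (not library code) of the two body shapes the binary-input hunk above produces: raw bytes when there are no parameters, otherwise a JSON envelope with the input base64-encoded and sent as `application/json`. The input bytes and parameter values are placeholders:

```py
import base64
import json

image_bytes = b"<raw image bytes>"  # placeholder input

# 1. No parameters -> the raw bytes are sent as-is, tagged with their mime type.
body_without_params = image_bytes

# 2. With parameters -> base64-encoded inputs wrapped in JSON.
body_with_params = json.dumps(
    {"inputs": base64.b64encode(image_bytes).decode(), "parameters": {"top_k": 3}}
).encode("utf-8")
```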
huggingface_hub/inference/_providers/replicate.py

@@ -14,7 +14,7 @@ class ReplicateTask(TaskProviderHelper):
     def __init__(self, task: str):
         super().__init__(provider=_PROVIDER, base_url=_BASE_URL, task=task)

-    def _prepare_headers(self, headers: Dict, api_key: str) -> Dict:
+    def _prepare_headers(self, headers: Dict, api_key: str) -> Dict[str, Any]:
         headers = super()._prepare_headers(headers, api_key)
         headers["Prefer"] = "wait"
         return headers
huggingface_hub/repocard.py

@@ -771,7 +771,8 @@ def metadata_update(
        raise ValueError("Cannot update metadata on a Space that doesn't contain a `README.md` file.")

    # Initialize a ModelCard or DatasetCard from default template and no data.
-    card = card_class.from_template(CardData())
+    # Cast to the concrete expected card type to satisfy type checkers.
+    card = card_class.from_template(CardData())  # type: ignore[return-value]

    for key, value in metadata.items():
        if key == "model-index":
huggingface_hub/utils/_git_credential.py

@@ -27,7 +27,7 @@ GIT_CREDENTIAL_REGEX = re.compile(
    ^\s* # start of line
    credential\.helper # credential.helper value
    \s*=\s* # separator
-    (\w+) # the helper name (group 1)
+    ([\w\-\/]+) # the helper name or absolute path (group 1)
    (\s|$) # whitespace or end of line
    """,
    flags=re.MULTILINE | re.IGNORECASE | re.VERBOSE,
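Note: a quick standalone check (not part of the package) that the widened character class above matches helpers configured as absolute paths as well as bare names; the sample `git config` output is made up:

```py
import re

GIT_CREDENTIAL_REGEX = re.compile(
    r"""
    ^\s* # start of line
    credential\.helper # credential.helper value
    \s*=\s* # separator
    ([\w\-\/]+) # the helper name or absolute path (group 1)
    (\s|$) # whitespace or end of line
    """,
    flags=re.MULTILINE | re.IGNORECASE | re.VERBOSE,
)

output = """\
credential.helper=store
credential.helper=/usr/local/bin/git-credential-manager
"""
print(GIT_CREDENTIAL_REGEX.findall(output))
# [('store', '\n'), ('/usr/local/bin/git-credential-manager', '\n')]
```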