huggingface-hub 0.35.0rc0 → 0.35.1 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of huggingface-hub might be problematic.
Files changed (50)
  1. huggingface_hub/__init__.py +19 -1
  2. huggingface_hub/_jobs_api.py +168 -12
  3. huggingface_hub/_local_folder.py +1 -1
  4. huggingface_hub/_oauth.py +5 -9
  5. huggingface_hub/_tensorboard_logger.py +9 -10
  6. huggingface_hub/_upload_large_folder.py +108 -1
  7. huggingface_hub/cli/auth.py +4 -1
  8. huggingface_hub/cli/cache.py +7 -9
  9. huggingface_hub/cli/hf.py +2 -5
  10. huggingface_hub/cli/jobs.py +591 -13
  11. huggingface_hub/cli/repo.py +10 -4
  12. huggingface_hub/commands/delete_cache.py +2 -2
  13. huggingface_hub/commands/scan_cache.py +1 -1
  14. huggingface_hub/dataclasses.py +3 -0
  15. huggingface_hub/file_download.py +12 -10
  16. huggingface_hub/hf_api.py +549 -95
  17. huggingface_hub/hf_file_system.py +4 -10
  18. huggingface_hub/hub_mixin.py +5 -3
  19. huggingface_hub/inference/_client.py +98 -181
  20. huggingface_hub/inference/_common.py +72 -70
  21. huggingface_hub/inference/_generated/_async_client.py +116 -201
  22. huggingface_hub/inference/_generated/types/chat_completion.py +2 -0
  23. huggingface_hub/inference/_mcp/_cli_hacks.py +3 -3
  24. huggingface_hub/inference/_mcp/cli.py +1 -1
  25. huggingface_hub/inference/_mcp/constants.py +1 -1
  26. huggingface_hub/inference/_mcp/mcp_client.py +28 -11
  27. huggingface_hub/inference/_mcp/types.py +3 -0
  28. huggingface_hub/inference/_mcp/utils.py +7 -3
  29. huggingface_hub/inference/_providers/__init__.py +13 -0
  30. huggingface_hub/inference/_providers/_common.py +29 -4
  31. huggingface_hub/inference/_providers/black_forest_labs.py +1 -1
  32. huggingface_hub/inference/_providers/fal_ai.py +33 -2
  33. huggingface_hub/inference/_providers/hf_inference.py +15 -7
  34. huggingface_hub/inference/_providers/publicai.py +6 -0
  35. huggingface_hub/inference/_providers/replicate.py +1 -1
  36. huggingface_hub/inference/_providers/scaleway.py +28 -0
  37. huggingface_hub/lfs.py +2 -4
  38. huggingface_hub/repocard.py +2 -1
  39. huggingface_hub/utils/_dotenv.py +24 -20
  40. huggingface_hub/utils/_git_credential.py +1 -1
  41. huggingface_hub/utils/_http.py +3 -5
  42. huggingface_hub/utils/_runtime.py +1 -0
  43. huggingface_hub/utils/_typing.py +24 -4
  44. huggingface_hub/utils/_xet_progress_reporting.py +31 -10
  45. {huggingface_hub-0.35.0rc0.dist-info → huggingface_hub-0.35.1.dist-info}/METADATA +7 -4
  46. {huggingface_hub-0.35.0rc0.dist-info → huggingface_hub-0.35.1.dist-info}/RECORD +50 -48
  47. {huggingface_hub-0.35.0rc0.dist-info → huggingface_hub-0.35.1.dist-info}/LICENSE +0 -0
  48. {huggingface_hub-0.35.0rc0.dist-info → huggingface_hub-0.35.1.dist-info}/WHEEL +0 -0
  49. {huggingface_hub-0.35.0rc0.dist-info → huggingface_hub-0.35.1.dist-info}/entry_points.txt +0 -0
  50. {huggingface_hub-0.35.0rc0.dist-info → huggingface_hub-0.35.1.dist-info}/top_level.txt +0 -0
huggingface_hub/inference/_generated/_async_client.py

@@ -30,7 +30,6 @@ from huggingface_hub.errors import InferenceTimeoutError
 from huggingface_hub.inference._common import (
     TASKS_EXPECTING_IMAGES,
     ContentT,
-    ModelStatus,
     RequestParameters,
     _async_stream_chat_completion_response,
     _async_stream_text_generation_response,
@@ -41,7 +40,6 @@ from huggingface_hub.inference._common import (
     _bytes_to_list,
     _get_unsupported_text_generation_kwargs,
     _import_numpy,
-    _open_as_binary,
     _set_unsupported_text_generation_kwargs,
     raise_text_generation_error,
 )
@@ -66,6 +64,7 @@ from huggingface_hub.inference._generated.types import (
     ImageSegmentationSubtask,
     ImageToImageTargetSize,
     ImageToTextOutput,
+    ImageToVideoTargetSize,
     ObjectDetectionOutputElement,
     Padding,
     QuestionAnsweringOutputElement,
@@ -87,9 +86,8 @@ from huggingface_hub.inference._generated.types import (
     ZeroShotImageClassificationOutputElement,
 )
 from huggingface_hub.inference._providers import PROVIDER_OR_POLICY_T, get_provider_helper
-from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status
+from huggingface_hub.utils import build_hf_headers
 from huggingface_hub.utils._auth import get_token
-from huggingface_hub.utils._deprecation import _deprecate_method
 
 from .._common import _async_yield_from, _import_aiohttp
 
@@ -120,7 +118,7 @@ class AsyncInferenceClient:
             Note: for better compatibility with OpenAI's client, `model` has been aliased as `base_url`. Those 2
             arguments are mutually exclusive. If a URL is passed as `model` or `base_url` for chat completion, the `(/v1)/chat/completions` suffix path will be appended to the URL.
         provider (`str`, *optional*):
-            Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"replicate"`, "sambanova"` or `"together"`.
+            Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `publicai`, `"replicate"`, `"sambanova"`, `"scaleway"` or `"together"`.
             Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
             If model is a URL or `base_url` is passed, then `provider` is not used.
         token (`str`, *optional*):
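
The docstring change above documents two new inference providers, `publicai` and `scaleway`, which are registered in `huggingface_hub/inference/_providers/__init__.py` further down. A minimal usage sketch; the model ID below is a placeholder, not taken from this diff:

```py
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    # Route the request through one of the newly documented providers.
    client = AsyncInferenceClient(provider="scaleway")  # or provider="publicai"
    response = await client.chat_completion(
        messages=[{"role": "user", "content": "Say hello"}],
        model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model ID
    )
    print(response.choices[0].message.content)


asyncio.run(main())
```
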
@@ -256,39 +254,38 @@ class AsyncInferenceClient:
         if request_parameters.task in TASKS_EXPECTING_IMAGES and "Accept" not in request_parameters.headers:
             request_parameters.headers["Accept"] = "image/png"
 
-        with _open_as_binary(request_parameters.data) as data_as_binary:
-            # Do not use context manager as we don't want to close the connection immediately when returning
-            # a stream
-            session = self._get_client_session(headers=request_parameters.headers)
+        # Do not use context manager as we don't want to close the connection immediately when returning
+        # a stream
+        session = self._get_client_session(headers=request_parameters.headers)
 
-            try:
-                response = await session.post(
-                    request_parameters.url, json=request_parameters.json, data=data_as_binary, proxy=self.proxies
-                )
-                response_error_payload = None
-                if response.status != 200:
-                    try:
-                        response_error_payload = await response.json()  # get payload before connection closed
-                    except Exception:
-                        pass
-                response.raise_for_status()
-                if stream:
-                    return _async_yield_from(session, response)
-                else:
-                    content = await response.read()
-                    await session.close()
-                    return content
-            except asyncio.TimeoutError as error:
-                await session.close()
-                # Convert any `TimeoutError` to a `InferenceTimeoutError`
-                raise InferenceTimeoutError(f"Inference call timed out: {request_parameters.url}") from error  # type: ignore
-            except aiohttp.ClientResponseError as error:
-                error.response_error_payload = response_error_payload
-                await session.close()
-                raise error
-            except Exception:
+        try:
+            response = await session.post(
+                request_parameters.url, json=request_parameters.json, data=request_parameters.data, proxy=self.proxies
+            )
+            response_error_payload = None
+            if response.status != 200:
+                try:
+                    response_error_payload = await response.json()  # get payload before connection closed
+                except Exception:
+                    pass
+            response.raise_for_status()
+            if stream:
+                return _async_yield_from(session, response)
+            else:
+                content = await response.read()
                 await session.close()
-                raise
+                return content
+        except asyncio.TimeoutError as error:
+            await session.close()
+            # Convert any `TimeoutError` to a `InferenceTimeoutError`
+            raise InferenceTimeoutError(f"Inference call timed out: {request_parameters.url}") from error  # type: ignore
+        except aiohttp.ClientResponseError as error:
+            error.response_error_payload = response_error_payload
+            await session.close()
+            raise error
+        except Exception:
+            await session.close()
+            raise
 
     async def __aenter__(self):
         return self
@@ -1385,6 +1382,86 @@ class AsyncInferenceClient:
         response = provider_helper.get_response(response, request_parameters)
         return _bytes_to_image(response)
 
+    async def image_to_video(
+        self,
+        image: ContentT,
+        *,
+        model: Optional[str] = None,
+        prompt: Optional[str] = None,
+        negative_prompt: Optional[str] = None,
+        num_frames: Optional[float] = None,
+        num_inference_steps: Optional[int] = None,
+        guidance_scale: Optional[float] = None,
+        seed: Optional[int] = None,
+        target_size: Optional[ImageToVideoTargetSize] = None,
+        **kwargs,
+    ) -> bytes:
+        """
+        Generate a video from an input image.
+
+        Args:
+            image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`):
+                The input image to generate a video from. It can be raw bytes, an image file, a URL to an online image, or a PIL Image.
+            model (`str`, *optional*):
+                The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
+                Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
+            prompt (`str`, *optional*):
+                The text prompt to guide the video generation.
+            negative_prompt (`str`, *optional*):
+                One prompt to guide what NOT to include in video generation.
+            num_frames (`float`, *optional*):
+                The num_frames parameter determines how many video frames are generated.
+            num_inference_steps (`int`, *optional*):
+                For diffusion models. The number of denoising steps. More denoising steps usually lead to a higher
+                quality image at the expense of slower inference.
+            guidance_scale (`float`, *optional*):
+                For diffusion models. A higher guidance scale value encourages the model to generate videos closely
+                linked to the text prompt at the expense of lower image quality.
+            seed (`int`, *optional*):
+                The seed to use for the video generation.
+            target_size (`ImageToVideoTargetSize`, *optional*):
+                The size in pixel of the output video frames.
+            num_inference_steps (`int`, *optional*):
+                The number of denoising steps. More denoising steps usually lead to a higher quality video at the
+                expense of slower inference.
+            seed (`int`, *optional*):
+                Seed for the random number generator.
+
+        Returns:
+            `bytes`: The generated video.
+
+        Examples:
+        ```py
+        # Must be run in an async context
+        >>> from huggingface_hub import AsyncInferenceClient
+        >>> client = AsyncInferenceClient()
+        >>> video = await client.image_to_video("cat.jpg", model="Wan-AI/Wan2.2-I2V-A14B", prompt="turn the cat into a tiger")
+        >>> with open("tiger.mp4", "wb") as f:
+        ...     f.write(video)
+        ```
+        """
+        model_id = model or self.model
+        provider_helper = get_provider_helper(self.provider, task="image-to-video", model=model_id)
+        request_parameters = provider_helper.prepare_request(
+            inputs=image,
+            parameters={
+                "prompt": prompt,
+                "negative_prompt": negative_prompt,
+                "num_frames": num_frames,
+                "num_inference_steps": num_inference_steps,
+                "guidance_scale": guidance_scale,
+                "seed": seed,
+                "target_size": target_size,
+                **kwargs,
+            },
+            headers=self.headers,
+            model=model_id,
+            api_key=self.token,
+        )
+        response = await self._inner_post(request_parameters)
+        response = provider_helper.get_response(response, request_parameters)
+        return response
+
     async def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -> ImageToTextOutput:
         """
         Takes an input image and return text.
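
The new `image_to_video` method needs a provider that maps the `image-to-video` task; in this release only fal-ai registers one (see the `_providers/__init__.py` hunks below). A runnable variant of the docstring example, assuming a local `cat.jpg`:

```py
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient(provider="fal-ai")
    # Model ID and prompt are taken from the docstring example above.
    video = await client.image_to_video(
        "cat.jpg",
        model="Wan-AI/Wan2.2-I2V-A14B",
        prompt="turn the cat into a tiger",
    )
    with open("tiger.mp4", "wb") as f:
        f.write(video)


asyncio.run(main())
```
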
@@ -1429,8 +1506,8 @@ class AsyncInferenceClient:
             api_key=self.token,
         )
         response = await self._inner_post(request_parameters)
-        output = ImageToTextOutput.parse_obj(response)
-        return output[0] if isinstance(output, list) else output
+        output_list: List[ImageToTextOutput] = ImageToTextOutput.parse_obj_as_list(response)
+        return output_list[0]
 
     async def object_detection(
         self, image: ContentT, *, model: Optional[str] = None, threshold: Optional[float] = None
@@ -3257,102 +3334,6 @@ class AsyncInferenceClient:
         response = await self._inner_post(request_parameters)
         return ZeroShotImageClassificationOutputElement.parse_obj_as_list(response)
 
-    @_deprecate_method(
-        version="0.35.0",
-        message=(
-            "HF Inference API is getting revamped and will only support warm models in the future (no cold start allowed)."
-            " Use `HfApi.list_models(..., inference_provider='...')` to list warm models per provider."
-        ),
-    )
-    async def list_deployed_models(
-        self, frameworks: Union[None, str, Literal["all"], List[str]] = None
-    ) -> Dict[str, List[str]]:
-        """
-        List models deployed on the HF Serverless Inference API service.
-
-        This helper checks deployed models framework by framework. By default, it will check the 4 main frameworks that
-        are supported and account for 95% of the hosted models. However, if you want a complete list of models you can
-        specify `frameworks="all"` as input. Alternatively, if you know before-hand which framework you are interested
-        in, you can also restrict to search to this one (e.g. `frameworks="text-generation-inference"`). The more
-        frameworks are checked, the more time it will take.
-
-        <Tip warning={true}>
-
-        This endpoint method does not return a live list of all models available for the HF Inference API service.
-        It searches over a cached list of models that were recently available and the list may not be up to date.
-        If you want to know the live status of a specific model, use [`~InferenceClient.get_model_status`].
-
-        </Tip>
-
-        <Tip>
-
-        This endpoint method is mostly useful for discoverability. If you already know which model you want to use and want to
-        check its availability, you can directly use [`~InferenceClient.get_model_status`].
-
-        </Tip>
-
-        Args:
-            frameworks (`Literal["all"]` or `List[str]` or `str`, *optional*):
-                The frameworks to filter on. By default only a subset of the available frameworks are tested. If set to
-                "all", all available frameworks will be tested. It is also possible to provide a single framework or a
-                custom set of frameworks to check.
-
-        Returns:
-            `Dict[str, List[str]]`: A dictionary mapping task names to a sorted list of model IDs.
-
-        Example:
-        ```py
-        # Must be run in an async contextthon
-        >>> from huggingface_hub import AsyncInferenceClient
-        >>> client = AsyncInferenceClient()
-
-        # Discover zero-shot-classification models currently deployed
-        >>> models = await client.list_deployed_models()
-        >>> models["zero-shot-classification"]
-        ['Narsil/deberta-large-mnli-zero-cls', 'facebook/bart-large-mnli', ...]
-
-        # List from only 1 framework
-        >>> await client.list_deployed_models("text-generation-inference")
-        {'text-generation': ['bigcode/starcoder', 'meta-llama/Llama-2-70b-chat-hf', ...], ...}
-        ```
-        """
-        if self.provider != "hf-inference":
-            raise ValueError(f"Listing deployed models is not supported on '{self.provider}'.")
-
-        # Resolve which frameworks to check
-        if frameworks is None:
-            frameworks = constants.MAIN_INFERENCE_API_FRAMEWORKS
-        elif frameworks == "all":
-            frameworks = constants.ALL_INFERENCE_API_FRAMEWORKS
-        elif isinstance(frameworks, str):
-            frameworks = [frameworks]
-        frameworks = list(set(frameworks))
-
-        # Fetch them iteratively
-        models_by_task: Dict[str, List[str]] = {}
-
-        def _unpack_response(framework: str, items: List[Dict]) -> None:
-            for model in items:
-                if framework == "sentence-transformers":
-                    # Model running with the `sentence-transformers` framework can work with both tasks even if not
-                    # branded as such in the API response
-                    models_by_task.setdefault("feature-extraction", []).append(model["model_id"])
-                    models_by_task.setdefault("sentence-similarity", []).append(model["model_id"])
-                else:
-                    models_by_task.setdefault(model["task"], []).append(model["model_id"])
-
-        for framework in frameworks:
-            response = get_session().get(
-                f"{constants.INFERENCE_ENDPOINT}/framework/{framework}", headers=build_hf_headers(token=self.token)
-            )
-            hf_raise_for_status(response)
-            _unpack_response(framework, response.json())
-
-        # Sort alphabetically for discoverability and return
-        for task, models in models_by_task.items():
-            models_by_task[task] = sorted(set(models), key=lambda x: x.lower())
-        return models_by_task
-
     def _get_client_session(self, headers: Optional[Dict] = None) -> "ClientSession":
         aiohttp = _import_aiohttp()
         client_headers = self.headers.copy()
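
The removed deprecation notice pointed to `HfApi.list_models(..., inference_provider='...')` as the replacement for `list_deployed_models`. A minimal sketch of that replacement; the filter values are examples only:

```py
from huggingface_hub import HfApi

api = HfApi()
# List warm models served by a given provider for a given task.
for model in api.list_models(
    inference_provider="hf-inference",        # any provider name works here
    pipeline_tag="zero-shot-classification",  # example task filter
    limit=5,
):
    print(model.id)
```
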
@@ -3459,7 +3440,6 @@ class AsyncInferenceClient:
         Check the health of the deployed endpoint.
 
         Health check is only available with Inference Endpoints powered by Text-Generation-Inference (TGI) or Text-Embedding-Inference (TEI).
-        For Inference API, please use [`InferenceClient.get_model_status`] instead.
 
         Args:
             model (`str`, *optional*):
@@ -3484,78 +3464,13 @@ class AsyncInferenceClient:
         if model is None:
             raise ValueError("Model id not provided.")
         if not model.startswith(("http://", "https://")):
-            raise ValueError(
-                "Model must be an Inference Endpoint URL. For serverless Inference API, please use `InferenceClient.get_model_status`."
-            )
+            raise ValueError("Model must be an Inference Endpoint URL.")
         url = model.rstrip("/") + "/health"
 
         async with self._get_client_session(headers=build_hf_headers(token=self.token)) as client:
             response = await client.get(url, proxy=self.proxies)
             return response.status == 200
 
-    @_deprecate_method(
-        version="0.35.0",
-        message=(
-            "HF Inference API is getting revamped and will only support warm models in the future (no cold start allowed)."
-            " Use `HfApi.model_info` to get the model status both with HF Inference API and external providers."
-        ),
-    )
-    async def get_model_status(self, model: Optional[str] = None) -> ModelStatus:
-        """
-        Get the status of a model hosted on the HF Inference API.
-
-        <Tip>
-
-        This endpoint is mostly useful when you already know which model you want to use and want to check its
-        availability. If you want to discover already deployed models, you should rather use [`~InferenceClient.list_deployed_models`].
-
-        </Tip>
-
-        Args:
-            model (`str`, *optional*):
-                Identifier of the model for witch the status gonna be checked. If model is not provided,
-                the model associated with this instance of [`InferenceClient`] will be used. Only HF Inference API service can be checked so the
-                identifier cannot be a URL.
-
-
-        Returns:
-            [`ModelStatus`]: An instance of ModelStatus dataclass, containing information,
-            about the state of the model: load, state, compute type and framework.
-
-        Example:
-        ```py
-        # Must be run in an async context
-        >>> from huggingface_hub import AsyncInferenceClient
-        >>> client = AsyncInferenceClient()
-        >>> await client.get_model_status("meta-llama/Meta-Llama-3-8B-Instruct")
-        ModelStatus(loaded=True, state='Loaded', compute_type='gpu', framework='text-generation-inference')
-        ```
-        """
-        if self.provider != "hf-inference":
-            raise ValueError(f"Getting model status is not supported on '{self.provider}'.")
-
-        model = model or self.model
-        if model is None:
-            raise ValueError("Model id not provided.")
-        if model.startswith("https://"):
-            raise NotImplementedError("Model status is only available for Inference API endpoints.")
-        url = f"{constants.INFERENCE_ENDPOINT}/status/{model}"
-
-        async with self._get_client_session(headers=build_hf_headers(token=self.token)) as client:
-            response = await client.get(url, proxy=self.proxies)
-            response.raise_for_status()
-            response_data = await response.json()
-
-        if "error" in response_data:
-            raise ValueError(response_data["error"])
-
-        return ModelStatus(
-            loaded=response_data["loaded"],
-            state=response_data["state"],
-            compute_type=response_data["compute_type"],
-            framework=response_data["framework"],
-        )
-
     @property
     def chat(self) -> "ProxyClientChat":
         return ProxyClientChat(self)
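
Likewise, the removed `get_model_status` deprecation notice pointed to `HfApi.model_info`. A minimal sketch, assuming the `expand=["inference"]` option and the `inference` attribute on the returned `ModelInfo`; neither is shown in this diff:

```py
from huggingface_hub import HfApi

# Model ID reused from the removed docstring example; `expand` values and the
# `inference` attribute are assumptions, not part of this diff.
info = HfApi().model_info("meta-llama/Meta-Llama-3-8B-Instruct", expand=["inference"])
print(info.inference)  # e.g. "warm" or "cold"
```
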
huggingface_hub/inference/_generated/types/chat_completion.py

@@ -239,6 +239,7 @@ class ChatCompletionOutputToolCall(BaseInferenceType):
 class ChatCompletionOutputMessage(BaseInferenceType):
     role: str
     content: Optional[str] = None
+    reasoning: Optional[str] = None
     tool_call_id: Optional[str] = None
     tool_calls: Optional[List[ChatCompletionOutputToolCall]] = None
 
@@ -292,6 +293,7 @@ class ChatCompletionStreamOutputDeltaToolCall(BaseInferenceType):
 class ChatCompletionStreamOutputDelta(BaseInferenceType):
     role: str
     content: Optional[str] = None
+    reasoning: Optional[str] = None
     tool_call_id: Optional[str] = None
     tool_calls: Optional[List[ChatCompletionStreamOutputDeltaToolCall]] = None
 
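
With the new `reasoning` field, a reasoning trace returned by a provider can be read next to `content` on both full messages and stream deltas. A minimal sketch; the model ID is a placeholder and the field stays `None` when the provider returns no reasoning:

```py
from huggingface_hub import InferenceClient

client = InferenceClient()
output = client.chat_completion(
    messages=[{"role": "user", "content": "What is 2 + 2?"}],
    model="deepseek-ai/DeepSeek-R1",  # placeholder reasoning-capable model
)
message = output.choices[0].message
print(message.reasoning)  # reasoning trace, or None
print(message.content)
```
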
huggingface_hub/inference/_mcp/_cli_hacks.py

@@ -17,7 +17,7 @@ def _patch_anyio_open_process():
 
     if getattr(anyio, "_tiny_agents_patched", False):
         return
-    anyio._tiny_agents_patched = True
+    anyio._tiny_agents_patched = True  # ty: ignore[invalid-assignment]
 
     original_open_process = anyio.open_process
 
@@ -32,7 +32,7 @@ def _patch_anyio_open_process():
             kwargs.setdefault("creationflags", subprocess.CREATE_NEW_PROCESS_GROUP)
             return await original_open_process(*args, **kwargs)
 
-        anyio.open_process = open_process_in_new_group
+        anyio.open_process = open_process_in_new_group  # ty: ignore[invalid-assignment]
     else:
         # For Unix-like systems, we can use setsid to create a new session
         async def open_process_in_new_group(*args, **kwargs):
@@ -42,7 +42,7 @@ def _patch_anyio_open_process():
             kwargs.setdefault("start_new_session", True)
             return await original_open_process(*args, **kwargs)
 
-        anyio.open_process = open_process_in_new_group
+        anyio.open_process = open_process_in_new_group  # ty: ignore[invalid-assignment]
 
 
 async def _async_prompt(exit_event: asyncio.Event, prompt: str = "» ") -> str:
huggingface_hub/inference/_mcp/cli.py

@@ -33,7 +33,7 @@ async def run_agent(
 
     Args:
         agent_path (`str`, *optional*):
-            Path to a local folder containing an `agent.json` and optionally a custom `PROMPT.md` file or a built-in agent stored in a Hugging Face dataset.
+            Path to a local folder containing an `agent.json` and optionally a custom `PROMPT.md` or `AGENTS.md` file or a built-in agent stored in a Hugging Face dataset.
 
     """
    _patch_anyio_open_process()  # Hacky way to prevent stdio connections to be stopped by Ctrl+C
huggingface_hub/inference/_mcp/constants.py

@@ -8,7 +8,7 @@ from huggingface_hub import ChatCompletionInputTool
 
 
 FILENAME_CONFIG = "agent.json"
-FILENAME_PROMPT = "PROMPT.md"
+PROMPT_FILENAMES = ("PROMPT.md", "AGENTS.md")
 
 DEFAULT_AGENT = {
     "model": "Qwen/Qwen2.5-72B-Instruct",
huggingface_hub/inference/_mcp/mcp_client.py

@@ -139,21 +139,27 @@ class MCPClient:
                 - args (List[str], optional): Arguments for the command
                 - env (Dict[str, str], optional): Environment variables for the command
                 - cwd (Union[str, Path, None], optional): Working directory for the command
+                - allowed_tools (List[str], optional): List of tool names to allow from this server
                 - For SSE servers:
                 - url (str): The URL of the SSE server
                 - headers (Dict[str, Any], optional): Headers for the SSE connection
                 - timeout (float, optional): Connection timeout
                 - sse_read_timeout (float, optional): SSE read timeout
+                - allowed_tools (List[str], optional): List of tool names to allow from this server
                 - For StreamableHTTP servers:
                 - url (str): The URL of the StreamableHTTP server
                 - headers (Dict[str, Any], optional): Headers for the StreamableHTTP connection
                 - timeout (timedelta, optional): Connection timeout
                 - sse_read_timeout (timedelta, optional): SSE read timeout
                 - terminate_on_close (bool, optional): Whether to terminate on close
+                - allowed_tools (List[str], optional): List of tool names to allow from this server
         """
         from mcp import ClientSession, StdioServerParameters
         from mcp import types as mcp_types
 
+        # Extract allowed_tools configuration if provided
+        allowed_tools = params.pop("allowed_tools", [])
+
         # Determine server type and create appropriate parameters
         if type == "stdio":
             # Handle stdio server
@@ -211,7 +217,15 @@ class MCPClient:
         response = await session.list_tools()
         logger.debug("Connected to server with tools:", [tool.name for tool in response.tools])
 
-        for tool in response.tools:
+        # Filter tools based on allowed_tools configuration
+        filtered_tools = [tool for tool in response.tools if tool.name in allowed_tools]
+
+        if allowed_tools:
+            logger.debug(
+                f"Tool filtering applied. Using {len(filtered_tools)} of {len(response.tools)} available tools: {[tool.name for tool in filtered_tools]}"
+            )
+
+        for tool in filtered_tools:
             if tool.name in self.sessions:
                 logger.warning(f"Tool '{tool.name}' already defined by another server. Skipping.")
                 continue
@@ -286,16 +300,19 @@ class MCPClient:
                 # Process tool calls
                 if delta.tool_calls:
                     for tool_call in delta.tool_calls:
-                        # Aggregate chunks into tool calls
-                        if tool_call.index not in final_tool_calls:
-                            if (
-                                tool_call.function.arguments is None or tool_call.function.arguments == "{}"
-                            ):  # Corner case (depends on provider)
-                                tool_call.function.arguments = ""
-                            final_tool_calls[tool_call.index] = tool_call
-
-                        elif tool_call.function.arguments:
-                            final_tool_calls[tool_call.index].function.arguments += tool_call.function.arguments
+                        idx = tool_call.index
+                        # first chunk for this tool call
+                        if idx not in final_tool_calls:
+                            final_tool_calls[idx] = tool_call
+                            if final_tool_calls[idx].function.arguments is None:
+                                final_tool_calls[idx].function.arguments = ""
+                            continue
+                        # safety before concatenating text to .function.arguments
+                        if final_tool_calls[idx].function.arguments is None:
+                            final_tool_calls[idx].function.arguments = ""
+
+                        if tool_call.function.arguments:
+                            final_tool_calls[idx].function.arguments += tool_call.function.arguments
 
                 # Optionally exit early if no tools in first chunks
                 if exit_if_first_chunk_no_tool and num_of_chunks <= 2 and len(final_tool_calls) == 0:
huggingface_hub/inference/_mcp/types.py

@@ -16,18 +16,21 @@ class StdioServerConfig(TypedDict):
     args: List[str]
     env: Dict[str, str]
     cwd: str
+    allowed_tools: NotRequired[List[str]]
 
 
 class HTTPServerConfig(TypedDict):
     type: Literal["http"]
     url: str
     headers: Dict[str, str]
+    allowed_tools: NotRequired[List[str]]
 
 
 class SSEServerConfig(TypedDict):
     type: Literal["sse"]
     url: str
     headers: Dict[str, str]
+    allowed_tools: NotRequired[List[str]]
 
 
 ServerConfig = Union[StdioServerConfig, HTTPServerConfig, SSEServerConfig]
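
A server entry using the new `allowed_tools` key, typed against the `HTTPServerConfig` definition above; the URL and tool names are placeholders:

```py
from huggingface_hub.inference._mcp.types import HTTPServerConfig

server: HTTPServerConfig = {
    "type": "http",
    "url": "http://localhost:8000/mcp",  # placeholder URL
    "headers": {},
    "allowed_tools": ["search", "fetch_page"],  # only these tools get registered
}
```
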
huggingface_hub/inference/_mcp/utils.py

@@ -11,7 +11,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple
 from huggingface_hub import snapshot_download
 from huggingface_hub.errors import EntryNotFoundError
 
-from .constants import DEFAULT_AGENT, DEFAULT_REPO_ID, FILENAME_CONFIG, FILENAME_PROMPT
+from .constants import DEFAULT_AGENT, DEFAULT_REPO_ID, FILENAME_CONFIG, PROMPT_FILENAMES
 from .types import AgentConfig
 
 
@@ -93,8 +93,12 @@ def _load_agent_config(agent_path: Optional[str]) -> Tuple[AgentConfig, Optional
            raise FileNotFoundError(f" Config file not found in {directory}! Please make sure it exists locally")
 
         config: AgentConfig = json.loads(cfg_file.read_text(encoding="utf-8"))
-        prompt_file = directory / FILENAME_PROMPT
-        prompt: Optional[str] = prompt_file.read_text(encoding="utf-8") if prompt_file.exists() else None
+        prompt: Optional[str] = None
+        for filename in PROMPT_FILENAMES:
+            prompt_file = directory / filename
+            if prompt_file.exists():
+                prompt = prompt_file.read_text(encoding="utf-8")
+                break
         return config, prompt
 
     if agent_path is None:
huggingface_hub/inference/_providers/__init__.py

@@ -13,6 +13,7 @@ from .cohere import CohereConversationalTask
 from .fal_ai import (
     FalAIAutomaticSpeechRecognitionTask,
     FalAIImageToImageTask,
+    FalAIImageToVideoTask,
     FalAITextToImageTask,
     FalAITextToSpeechTask,
     FalAITextToVideoTask,
@@ -35,8 +36,10 @@ from .nebius import (
 from .novita import NovitaConversationalTask, NovitaTextGenerationTask, NovitaTextToVideoTask
 from .nscale import NscaleConversationalTask, NscaleTextToImageTask
 from .openai import OpenAIConversationalTask
+from .publicai import PublicAIConversationalTask
 from .replicate import ReplicateImageToImageTask, ReplicateTask, ReplicateTextToImageTask, ReplicateTextToSpeechTask
 from .sambanova import SambanovaConversationalTask, SambanovaFeatureExtractionTask
+from .scaleway import ScalewayConversationalTask, ScalewayFeatureExtractionTask
 from .together import TogetherConversationalTask, TogetherTextGenerationTask, TogetherTextToImageTask
 
 
@@ -57,8 +60,10 @@ PROVIDER_T = Literal[
     "novita",
     "nscale",
     "openai",
+    "publicai",
     "replicate",
     "sambanova",
+    "scaleway",
     "together",
 ]
 
@@ -79,6 +84,7 @@ PROVIDERS: Dict[PROVIDER_T, Dict[str, TaskProviderHelper]] = {
         "text-to-image": FalAITextToImageTask(),
         "text-to-speech": FalAITextToSpeechTask(),
         "text-to-video": FalAITextToVideoTask(),
+        "image-to-video": FalAIImageToVideoTask(),
         "image-to-image": FalAIImageToImageTask(),
     },
     "featherless-ai": {
@@ -142,6 +148,9 @@ PROVIDERS: Dict[PROVIDER_T, Dict[str, TaskProviderHelper]] = {
     "openai": {
         "conversational": OpenAIConversationalTask(),
     },
+    "publicai": {
+        "conversational": PublicAIConversationalTask(),
+    },
     "replicate": {
         "image-to-image": ReplicateImageToImageTask(),
         "text-to-image": ReplicateTextToImageTask(),
@@ -152,6 +161,10 @@ PROVIDERS: Dict[PROVIDER_T, Dict[str, TaskProviderHelper]] = {
         "conversational": SambanovaConversationalTask(),
         "feature-extraction": SambanovaFeatureExtractionTask(),
     },
+    "scaleway": {
+        "conversational": ScalewayConversationalTask(),
+        "feature-extraction": ScalewayFeatureExtractionTask(),
+    },
     "together": {
         "text-to-image": TogetherTextToImageTask(),
         "conversational": TogetherConversationalTask(),