huggingface-hub 0.34.6__py3-none-any.whl → 0.35.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- huggingface_hub/__init__.py +19 -1
- huggingface_hub/_jobs_api.py +159 -2
- huggingface_hub/_tensorboard_logger.py +9 -10
- huggingface_hub/cli/auth.py +1 -1
- huggingface_hub/cli/cache.py +3 -9
- huggingface_hub/cli/jobs.py +551 -1
- huggingface_hub/cli/repo.py +6 -4
- huggingface_hub/commands/delete_cache.py +2 -2
- huggingface_hub/commands/scan_cache.py +1 -1
- huggingface_hub/commands/user.py +1 -1
- huggingface_hub/hf_api.py +522 -78
- huggingface_hub/hf_file_system.py +3 -1
- huggingface_hub/hub_mixin.py +5 -3
- huggingface_hub/inference/_client.py +17 -180
- huggingface_hub/inference/_common.py +72 -70
- huggingface_hub/inference/_generated/_async_client.py +34 -200
- huggingface_hub/inference/_generated/types/chat_completion.py +2 -0
- huggingface_hub/inference/_mcp/_cli_hacks.py +3 -3
- huggingface_hub/inference/_mcp/cli.py +1 -1
- huggingface_hub/inference/_mcp/constants.py +1 -1
- huggingface_hub/inference/_mcp/mcp_client.py +28 -11
- huggingface_hub/inference/_mcp/types.py +3 -0
- huggingface_hub/inference/_mcp/utils.py +7 -3
- huggingface_hub/inference/_providers/_common.py +28 -4
- huggingface_hub/inference/_providers/black_forest_labs.py +1 -1
- huggingface_hub/inference/_providers/fal_ai.py +2 -2
- huggingface_hub/inference/_providers/hf_inference.py +15 -7
- huggingface_hub/inference/_providers/replicate.py +1 -1
- huggingface_hub/repocard.py +2 -1
- huggingface_hub/utils/_git_credential.py +1 -1
- huggingface_hub/utils/_typing.py +24 -4
- huggingface_hub/utils/_xet_progress_reporting.py +31 -10
- {huggingface_hub-0.34.6.dist-info → huggingface_hub-0.35.0.dist-info}/METADATA +7 -4
- {huggingface_hub-0.34.6.dist-info → huggingface_hub-0.35.0.dist-info}/RECORD +38 -38
- {huggingface_hub-0.34.6.dist-info → huggingface_hub-0.35.0.dist-info}/LICENSE +0 -0
- {huggingface_hub-0.34.6.dist-info → huggingface_hub-0.35.0.dist-info}/WHEEL +0 -0
- {huggingface_hub-0.34.6.dist-info → huggingface_hub-0.35.0.dist-info}/entry_points.txt +0 -0
- {huggingface_hub-0.34.6.dist-info → huggingface_hub-0.35.0.dist-info}/top_level.txt +0 -0
huggingface_hub/inference/_generated/_async_client.py
CHANGED

@@ -30,7 +30,6 @@ from huggingface_hub.errors import InferenceTimeoutError
 from huggingface_hub.inference._common import (
     TASKS_EXPECTING_IMAGES,
     ContentT,
-    ModelStatus,
     RequestParameters,
     _async_stream_chat_completion_response,
     _async_stream_text_generation_response,
@@ -41,7 +40,6 @@ from huggingface_hub.inference._common import (
     _bytes_to_list,
     _get_unsupported_text_generation_kwargs,
     _import_numpy,
-    _open_as_binary,
     _set_unsupported_text_generation_kwargs,
     raise_text_generation_error,
 )
@@ -88,9 +86,8 @@ from huggingface_hub.inference._generated.types import (
     ZeroShotImageClassificationOutputElement,
 )
 from huggingface_hub.inference._providers import PROVIDER_OR_POLICY_T, get_provider_helper
-from huggingface_hub.utils import build_hf_headers
+from huggingface_hub.utils import build_hf_headers
 from huggingface_hub.utils._auth import get_token
-from huggingface_hub.utils._deprecation import _deprecate_method

 from .._common import _async_yield_from, _import_aiohttp

@@ -257,39 +254,38 @@ class AsyncInferenceClient:
         if request_parameters.task in TASKS_EXPECTING_IMAGES and "Accept" not in request_parameters.headers:
             request_parameters.headers["Accept"] = "image/png"

-
-
-
-        session = self._get_client_session(headers=request_parameters.headers)
+        # Do not use context manager as we don't want to close the connection immediately when returning
+        # a stream
+        session = self._get_client_session(headers=request_parameters.headers)

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            await session.close()
-            return content
-        except asyncio.TimeoutError as error:
-            await session.close()
-            # Convert any `TimeoutError` to a `InferenceTimeoutError`
-            raise InferenceTimeoutError(f"Inference call timed out: {request_parameters.url}") from error  # type: ignore
-        except aiohttp.ClientResponseError as error:
-            error.response_error_payload = response_error_payload
-            await session.close()
-            raise error
-        except Exception:
+        try:
+            response = await session.post(
+                request_parameters.url, json=request_parameters.json, data=request_parameters.data, proxy=self.proxies
+            )
+            response_error_payload = None
+            if response.status != 200:
+                try:
+                    response_error_payload = await response.json()  # get payload before connection closed
+                except Exception:
+                    pass
+                response.raise_for_status()
+            if stream:
+                return _async_yield_from(session, response)
+            else:
+                content = await response.read()
             await session.close()
-
+            return content
+        except asyncio.TimeoutError as error:
+            await session.close()
+            # Convert any `TimeoutError` to a `InferenceTimeoutError`
+            raise InferenceTimeoutError(f"Inference call timed out: {request_parameters.url}") from error  # type: ignore
+        except aiohttp.ClientResponseError as error:
+            error.response_error_payload = response_error_payload
+            await session.close()
+            raise error
+        except Exception:
+            await session.close()
+            raise

     async def __aenter__(self):
         return self
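The rewritten request code above deliberately avoids an `async with` around the session: when `stream=True` the response is handed to `_async_yield_from`, so the connection must stay open until the caller finishes iterating. A minimal sketch of the same pattern with plain aiohttp (function name and URL handling are illustrative, not the library's own helper):

```py
import aiohttp


async def fetch(url: str, stream: bool):
    # Session is closed manually instead of via `async with`, because a streamed
    # response is consumed only after this function returns.
    session = aiohttp.ClientSession()
    try:
        response = await session.post(url)
        response.raise_for_status()
        if stream:
            # Caller is responsible for closing the session once iteration ends.
            return session, response.content.iter_any()
        data = await response.read()
        await session.close()
        return data
    except Exception:
        await session.close()
        raise
```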
@@ -1510,8 +1506,8 @@ class AsyncInferenceClient:
             api_key=self.token,
         )
         response = await self._inner_post(request_parameters)
-
-        return
+        output_list: List[ImageToTextOutput] = ImageToTextOutput.parse_obj_as_list(response)
+        return output_list[0]

     async def object_detection(
         self, image: ContentT, *, model: Optional[str] = None, threshold: Optional[float] = None
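With this change `image_to_text` once again parses the response and returns the first `ImageToTextOutput` instead of falling through to a bare `return`. A hedged usage sketch (the file path and model id are illustrative):

```py
import asyncio
from huggingface_hub import AsyncInferenceClient


async def main():
    client = AsyncInferenceClient()
    # Returns an ImageToTextOutput; `generated_text` holds the caption.
    output = await client.image_to_text("cat.png", model="Salesforce/blip-image-captioning-base")
    print(output.generated_text)


asyncio.run(main())
```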
@@ -3338,102 +3334,6 @@ class AsyncInferenceClient:
         response = await self._inner_post(request_parameters)
         return ZeroShotImageClassificationOutputElement.parse_obj_as_list(response)

-    @_deprecate_method(
-        version="0.35.0",
-        message=(
-            "HF Inference API is getting revamped and will only support warm models in the future (no cold start allowed)."
-            " Use `HfApi.list_models(..., inference_provider='...')` to list warm models per provider."
-        ),
-    )
-    async def list_deployed_models(
-        self, frameworks: Union[None, str, Literal["all"], List[str]] = None
-    ) -> Dict[str, List[str]]:
-        """
-        List models deployed on the HF Serverless Inference API service.
-
-        This helper checks deployed models framework by framework. By default, it will check the 4 main frameworks that
-        are supported and account for 95% of the hosted models. However, if you want a complete list of models you can
-        specify `frameworks="all"` as input. Alternatively, if you know before-hand which framework you are interested
-        in, you can also restrict to search to this one (e.g. `frameworks="text-generation-inference"`). The more
-        frameworks are checked, the more time it will take.
-
-        <Tip warning={true}>
-
-        This endpoint method does not return a live list of all models available for the HF Inference API service.
-        It searches over a cached list of models that were recently available and the list may not be up to date.
-        If you want to know the live status of a specific model, use [`~InferenceClient.get_model_status`].
-
-        </Tip>
-
-        <Tip>
-
-        This endpoint method is mostly useful for discoverability. If you already know which model you want to use and want to
-        check its availability, you can directly use [`~InferenceClient.get_model_status`].
-
-        </Tip>
-
-        Args:
-            frameworks (`Literal["all"]` or `List[str]` or `str`, *optional*):
-                The frameworks to filter on. By default only a subset of the available frameworks are tested. If set to
-                "all", all available frameworks will be tested. It is also possible to provide a single framework or a
-                custom set of frameworks to check.
-
-        Returns:
-            `Dict[str, List[str]]`: A dictionary mapping task names to a sorted list of model IDs.
-
-        Example:
-        ```py
-        # Must be run in an async context
-        >>> from huggingface_hub import AsyncInferenceClient
-        >>> client = AsyncInferenceClient()
-
-        # Discover zero-shot-classification models currently deployed
-        >>> models = await client.list_deployed_models()
-        >>> models["zero-shot-classification"]
-        ['Narsil/deberta-large-mnli-zero-cls', 'facebook/bart-large-mnli', ...]
-
-        # List from only 1 framework
-        >>> await client.list_deployed_models("text-generation-inference")
-        {'text-generation': ['bigcode/starcoder', 'meta-llama/Llama-2-70b-chat-hf', ...], ...}
-        ```
-        """
-        if self.provider != "hf-inference":
-            raise ValueError(f"Listing deployed models is not supported on '{self.provider}'.")
-
-        # Resolve which frameworks to check
-        if frameworks is None:
-            frameworks = constants.MAIN_INFERENCE_API_FRAMEWORKS
-        elif frameworks == "all":
-            frameworks = constants.ALL_INFERENCE_API_FRAMEWORKS
-        elif isinstance(frameworks, str):
-            frameworks = [frameworks]
-        frameworks = list(set(frameworks))
-
-        # Fetch them iteratively
-        models_by_task: Dict[str, List[str]] = {}
-
-        def _unpack_response(framework: str, items: List[Dict]) -> None:
-            for model in items:
-                if framework == "sentence-transformers":
-                    # Model running with the `sentence-transformers` framework can work with both tasks even if not
-                    # branded as such in the API response
-                    models_by_task.setdefault("feature-extraction", []).append(model["model_id"])
-                    models_by_task.setdefault("sentence-similarity", []).append(model["model_id"])
-                else:
-                    models_by_task.setdefault(model["task"], []).append(model["model_id"])
-
-        for framework in frameworks:
-            response = get_session().get(
-                f"{constants.INFERENCE_ENDPOINT}/framework/{framework}", headers=build_hf_headers(token=self.token)
-            )
-            hf_raise_for_status(response)
-            _unpack_response(framework, response.json())
-
-        # Sort alphabetically for discoverability and return
-        for task, models in models_by_task.items():
-            models_by_task[task] = sorted(set(models), key=lambda x: x.lower())
-        return models_by_task
-
     def _get_client_session(self, headers: Optional[Dict] = None) -> "ClientSession":
         aiohttp = _import_aiohttp()
         client_headers = self.headers.copy()
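The deprecation message on the removed `list_deployed_models` points to `HfApi.list_models` with the `inference_provider` filter as the replacement. A short sketch of that workflow (the provider name and extra filters are illustrative):

```py
from huggingface_hub import HfApi

api = HfApi()
# List models that are currently served by a given inference provider.
for model in api.list_models(inference_provider="hf-inference", pipeline_tag="text-generation", limit=5):
    print(model.id)
```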
@@ -3540,7 +3440,6 @@ class AsyncInferenceClient:
         Check the health of the deployed endpoint.

         Health check is only available with Inference Endpoints powered by Text-Generation-Inference (TGI) or Text-Embedding-Inference (TEI).
-        For Inference API, please use [`InferenceClient.get_model_status`] instead.

         Args:
             model (`str`, *optional*):
@@ -3565,78 +3464,13 @@ class AsyncInferenceClient:
         if model is None:
             raise ValueError("Model id not provided.")
         if not model.startswith(("http://", "https://")):
-            raise ValueError(
-                "Model must be an Inference Endpoint URL. For serverless Inference API, please use `InferenceClient.get_model_status`."
-            )
+            raise ValueError("Model must be an Inference Endpoint URL.")
         url = model.rstrip("/") + "/health"

         async with self._get_client_session(headers=build_hf_headers(token=self.token)) as client:
             response = await client.get(url, proxy=self.proxies)
             return response.status == 200

-    @_deprecate_method(
-        version="0.35.0",
-        message=(
-            "HF Inference API is getting revamped and will only support warm models in the future (no cold start allowed)."
-            " Use `HfApi.model_info` to get the model status both with HF Inference API and external providers."
-        ),
-    )
-    async def get_model_status(self, model: Optional[str] = None) -> ModelStatus:
-        """
-        Get the status of a model hosted on the HF Inference API.
-
-        <Tip>
-
-        This endpoint is mostly useful when you already know which model you want to use and want to check its
-        availability. If you want to discover already deployed models, you should rather use [`~InferenceClient.list_deployed_models`].
-
-        </Tip>
-
-        Args:
-            model (`str`, *optional*):
-                Identifier of the model for witch the status gonna be checked. If model is not provided,
-                the model associated with this instance of [`InferenceClient`] will be used. Only HF Inference API service can be checked so the
-                identifier cannot be a URL.
-
-
-        Returns:
-            [`ModelStatus`]: An instance of ModelStatus dataclass, containing information,
-            about the state of the model: load, state, compute type and framework.
-
-        Example:
-        ```py
-        # Must be run in an async context
-        >>> from huggingface_hub import AsyncInferenceClient
-        >>> client = AsyncInferenceClient()
-        >>> await client.get_model_status("meta-llama/Meta-Llama-3-8B-Instruct")
-        ModelStatus(loaded=True, state='Loaded', compute_type='gpu', framework='text-generation-inference')
-        ```
-        """
-        if self.provider != "hf-inference":
-            raise ValueError(f"Getting model status is not supported on '{self.provider}'.")
-
-        model = model or self.model
-        if model is None:
-            raise ValueError("Model id not provided.")
-        if model.startswith("https://"):
-            raise NotImplementedError("Model status is only available for Inference API endpoints.")
-        url = f"{constants.INFERENCE_ENDPOINT}/status/{model}"
-
-        async with self._get_client_session(headers=build_hf_headers(token=self.token)) as client:
-            response = await client.get(url, proxy=self.proxies)
-            response.raise_for_status()
-            response_data = await response.json()
-
-        if "error" in response_data:
-            raise ValueError(response_data["error"])
-
-        return ModelStatus(
-            loaded=response_data["loaded"],
-            state=response_data["state"],
-            compute_type=response_data["compute_type"],
-            framework=response_data["framework"],
-        )
-
     @property
     def chat(self) -> "ProxyClientChat":
         return ProxyClientChat(self)
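As with `list_deployed_models`, the deprecation text on the removed `get_model_status` recommends `HfApi.model_info` for status checks. A hedged sketch (the `expand=["inference"]` field selection and the returned value are assumptions about the current Hub API surface, not shown in this diff):

```py
from huggingface_hub import HfApi

api = HfApi()
# `inference` is assumed to report whether the model is warm on HF Inference.
info = api.model_info("meta-llama/Meta-Llama-3-8B-Instruct", expand=["inference"])
print(info.inference)
```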
huggingface_hub/inference/_generated/types/chat_completion.py
CHANGED

@@ -239,6 +239,7 @@ class ChatCompletionOutputToolCall(BaseInferenceType):
 class ChatCompletionOutputMessage(BaseInferenceType):
     role: str
     content: Optional[str] = None
+    reasoning: Optional[str] = None
     tool_call_id: Optional[str] = None
     tool_calls: Optional[List[ChatCompletionOutputToolCall]] = None

@@ -292,6 +293,7 @@ class ChatCompletionStreamOutputDeltaToolCall(BaseInferenceType):
 class ChatCompletionStreamOutputDelta(BaseInferenceType):
     role: str
     content: Optional[str] = None
+    reasoning: Optional[str] = None
     tool_call_id: Optional[str] = None
     tool_calls: Optional[List[ChatCompletionStreamOutputDeltaToolCall]] = None

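Both the non-streaming message and the streaming delta gain an optional `reasoning` field. Whether it is populated depends on the model and provider; a minimal sketch of reading it (model id illustrative):

```py
from huggingface_hub import InferenceClient

client = InferenceClient()
response = client.chat_completion(
    model="Qwen/Qwen3-4B",  # illustrative model id
    messages=[{"role": "user", "content": "Why is the sky blue?"}],
)
message = response.choices[0].message
print(message.content)
# New in 0.35.0: reasoning trace, when the provider returns one (may be None).
print(message.reasoning)
```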
huggingface_hub/inference/_mcp/_cli_hacks.py
CHANGED

@@ -17,7 +17,7 @@ def _patch_anyio_open_process():

     if getattr(anyio, "_tiny_agents_patched", False):
         return
-    anyio._tiny_agents_patched = True
+    anyio._tiny_agents_patched = True  # ty: ignore[invalid-assignment]

     original_open_process = anyio.open_process

@@ -32,7 +32,7 @@ def _patch_anyio_open_process():
             kwargs.setdefault("creationflags", subprocess.CREATE_NEW_PROCESS_GROUP)
             return await original_open_process(*args, **kwargs)

-        anyio.open_process = open_process_in_new_group
+        anyio.open_process = open_process_in_new_group  # ty: ignore[invalid-assignment]
     else:
         # For Unix-like systems, we can use setsid to create a new session
         async def open_process_in_new_group(*args, **kwargs):
@@ -42,7 +42,7 @@ def _patch_anyio_open_process():
             kwargs.setdefault("start_new_session", True)
             return await original_open_process(*args, **kwargs)

-        anyio.open_process = open_process_in_new_group
+        anyio.open_process = open_process_in_new_group  # ty: ignore[invalid-assignment]


 async def _async_prompt(exit_event: asyncio.Event, prompt: str = "» ") -> str:
huggingface_hub/inference/_mcp/cli.py
CHANGED

@@ -33,7 +33,7 @@ async def run_agent(

     Args:
         agent_path (`str`, *optional*):
-            Path to a local folder containing an `agent.json` and optionally a custom `PROMPT.md` file or a built-in agent stored in a Hugging Face dataset.
+            Path to a local folder containing an `agent.json` and optionally a custom `PROMPT.md` or `AGENTS.md` file or a built-in agent stored in a Hugging Face dataset.

     """
     _patch_anyio_open_process()  # Hacky way to prevent stdio connections to be stopped by Ctrl+C
huggingface_hub/inference/_mcp/mcp_client.py
CHANGED

@@ -139,21 +139,27 @@ class MCPClient:
                 - args (List[str], optional): Arguments for the command
                 - env (Dict[str, str], optional): Environment variables for the command
                 - cwd (Union[str, Path, None], optional): Working directory for the command
+                - allowed_tools (List[str], optional): List of tool names to allow from this server
             - For SSE servers:
                 - url (str): The URL of the SSE server
                 - headers (Dict[str, Any], optional): Headers for the SSE connection
                 - timeout (float, optional): Connection timeout
                 - sse_read_timeout (float, optional): SSE read timeout
+                - allowed_tools (List[str], optional): List of tool names to allow from this server
             - For StreamableHTTP servers:
                 - url (str): The URL of the StreamableHTTP server
                 - headers (Dict[str, Any], optional): Headers for the StreamableHTTP connection
                 - timeout (timedelta, optional): Connection timeout
                 - sse_read_timeout (timedelta, optional): SSE read timeout
                 - terminate_on_close (bool, optional): Whether to terminate on close
+                - allowed_tools (List[str], optional): List of tool names to allow from this server
         """
         from mcp import ClientSession, StdioServerParameters
         from mcp import types as mcp_types

+        # Extract allowed_tools configuration if provided
+        allowed_tools = params.pop("allowed_tools", [])
+
         # Determine server type and create appropriate parameters
         if type == "stdio":
             # Handle stdio server
@@ -211,7 +217,15 @@ class MCPClient:
         response = await session.list_tools()
         logger.debug("Connected to server with tools:", [tool.name for tool in response.tools])

-
+        # Filter tools based on allowed_tools configuration
+        filtered_tools = [tool for tool in response.tools if tool.name in allowed_tools]
+
+        if allowed_tools:
+            logger.debug(
+                f"Tool filtering applied. Using {len(filtered_tools)} of {len(response.tools)} available tools: {[tool.name for tool in filtered_tools]}"
+            )
+
+        for tool in filtered_tools:
             if tool.name in self.sessions:
                 logger.warning(f"Tool '{tool.name}' already defined by another server. Skipping.")
                 continue
@@ -286,16 +300,19 @@ class MCPClient:
             # Process tool calls
             if delta.tool_calls:
                 for tool_call in delta.tool_calls:
-
-
-
-
-
-
-
-
-
-                    final_tool_calls[
+                    idx = tool_call.index
+                    # first chunk for this tool call
+                    if idx not in final_tool_calls:
+                        final_tool_calls[idx] = tool_call
+                        if final_tool_calls[idx].function.arguments is None:
+                            final_tool_calls[idx].function.arguments = ""
+                        continue
+                    # safety before concatenating text to .function.arguments
+                    if final_tool_calls[idx].function.arguments is None:
+                        final_tool_calls[idx].function.arguments = ""
+
+                    if tool_call.function.arguments:
+                        final_tool_calls[idx].function.arguments += tool_call.function.arguments

             # Optionally exit early if no tools in first chunks
             if exit_if_first_chunk_no_tool and num_of_chunks <= 2 and len(final_tool_calls) == 0:
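The rewritten loop accumulates streamed tool-call fragments by `index`, guarding against `None` argument strings before concatenation. A standalone sketch of the same accumulation pattern over simplified delta objects (the dataclasses below are illustrative, not the library's types):

```py
from dataclasses import dataclass, field
from typing import Dict, Optional


@dataclass
class _Function:
    name: Optional[str] = None
    arguments: Optional[str] = None


@dataclass
class _ToolCallDelta:
    index: int
    function: _Function = field(default_factory=_Function)


def accumulate(chunks) -> Dict[int, _ToolCallDelta]:
    final: Dict[int, _ToolCallDelta] = {}
    for call in chunks:
        idx = call.index
        if idx not in final:
            # First fragment for this tool call: keep it and normalize arguments to "".
            final[idx] = call
            if final[idx].function.arguments is None:
                final[idx].function.arguments = ""
            continue
        if call.function.arguments:
            final[idx].function.arguments += call.function.arguments
    return final


# Example: arguments for tool-call 0 arrive in three fragments.
parts = [
    _ToolCallDelta(0, _Function(name="get_weather")),
    _ToolCallDelta(0, _Function(arguments='{"city": ')),
    _ToolCallDelta(0, _Function(arguments='"Paris"}')),
]
print(accumulate(parts)[0].function.arguments)  # {"city": "Paris"}
```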
huggingface_hub/inference/_mcp/types.py
CHANGED

@@ -16,18 +16,21 @@ class StdioServerConfig(TypedDict):
     args: List[str]
     env: Dict[str, str]
     cwd: str
+    allowed_tools: NotRequired[List[str]]


 class HTTPServerConfig(TypedDict):
     type: Literal["http"]
     url: str
     headers: Dict[str, str]
+    allowed_tools: NotRequired[List[str]]


 class SSEServerConfig(TypedDict):
     type: Literal["sse"]
     url: str
     headers: Dict[str, str]
+    allowed_tools: NotRequired[List[str]]


 ServerConfig = Union[StdioServerConfig, HTTPServerConfig, SSEServerConfig]
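Together with the tool-filtering logic in the MCPClient changes above, a server entry can now restrict which tools it exposes via `allowed_tools`. A hedged example of such a config dict (the command, extra keys not shown in this hunk, and tool names are illustrative):

```py
from huggingface_hub.inference._mcp.types import StdioServerConfig

server: StdioServerConfig = {
    "type": "stdio",        # assumed key, declared above the diffed lines
    "command": "uvx",       # assumed key, declared above the diffed lines
    "args": ["mcp-server-fetch"],
    "env": {},
    "cwd": ".",
    # Only these tools will be registered from this server (new in 0.35.0).
    "allowed_tools": ["fetch"],
}
```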
huggingface_hub/inference/_mcp/utils.py
CHANGED

@@ -11,7 +11,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple
 from huggingface_hub import snapshot_download
 from huggingface_hub.errors import EntryNotFoundError

-from .constants import DEFAULT_AGENT, DEFAULT_REPO_ID, FILENAME_CONFIG,
+from .constants import DEFAULT_AGENT, DEFAULT_REPO_ID, FILENAME_CONFIG, PROMPT_FILENAMES
 from .types import AgentConfig


@@ -93,8 +93,12 @@ def _load_agent_config(agent_path: Optional[str]) -> Tuple[AgentConfig, Optional
         raise FileNotFoundError(f" Config file not found in {directory}! Please make sure it exists locally")

     config: AgentConfig = json.loads(cfg_file.read_text(encoding="utf-8"))
-
-
+    prompt: Optional[str] = None
+    for filename in PROMPT_FILENAMES:
+        prompt_file = directory / filename
+        if prompt_file.exists():
+            prompt = prompt_file.read_text(encoding="utf-8")
+            break
     return config, prompt

     if agent_path is None:
huggingface_hub/inference/_providers/_common.py
CHANGED

@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Union, overload

 from huggingface_hub import constants
 from huggingface_hub.hf_api import InferenceProviderMapping
-from huggingface_hub.inference._common import RequestParameters
+from huggingface_hub.inference._common import MimeBytes, RequestParameters
 from huggingface_hub.inference._generated.types.chat_completion import ChatCompletionInputMessage
 from huggingface_hub.utils import build_hf_headers, get_token, logging

@@ -109,8 +109,17 @@ class TaskProviderHelper:
             raise ValueError("Both payload and data cannot be set in the same request.")
         if payload is None and data is None:
             raise ValueError("Either payload or data must be set in the request.")
+
+        # normalize headers to lowercase and add content-type if not present
+        normalized_headers = self._normalize_headers(headers, payload, data)
+
         return RequestParameters(
-            url=url,
+            url=url,
+            task=self.task,
+            model=provider_mapping_info.provider_id,
+            json=payload,
+            data=data,
+            headers=normalized_headers,
         )

     def get_response(
@@ -173,7 +182,22 @@ class TaskProviderHelper:
         )
         return provider_mapping

-    def
+    def _normalize_headers(
+        self, headers: Dict[str, Any], payload: Optional[Dict[str, Any]], data: Optional[MimeBytes]
+    ) -> Dict[str, Any]:
+        """Normalize the headers to use for the request.
+
+        Override this method in subclasses for customized headers.
+        """
+        normalized_headers = {key.lower(): value for key, value in headers.items() if value is not None}
+        if normalized_headers.get("content-type") is None:
+            if data is not None and data.mime_type is not None:
+                normalized_headers["content-type"] = data.mime_type
+            elif payload is not None:
+                normalized_headers["content-type"] = "application/json"
+        return normalized_headers
+
+    def _prepare_headers(self, headers: Dict, api_key: str) -> Dict[str, Any]:
         """Return the headers to use for the request.

         Override this method in subclasses for customized headers.
@@ -223,7 +247,7 @@ class TaskProviderHelper:
         parameters: Dict,
         provider_mapping_info: InferenceProviderMapping,
         extra_payload: Optional[Dict],
-    ) -> Optional[
+    ) -> Optional[MimeBytes]:
         """Return the body to use for the request, as bytes.

         Override this method in subclasses for customized body data.
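The providers now pass request bodies around as `MimeBytes` so the MIME type can travel with the payload (see `_normalize_headers` above and the `hf_inference` changes below). `MimeBytes` itself lives in `huggingface_hub/inference/_common.py`, which is not part of this diff; a plausible minimal sketch, inferred only from how it is constructed and read here, is a `bytes` subclass carrying a `mime_type` attribute:

```py
from typing import Optional


class MimeBytes(bytes):
    """Assumed shape: raw bytes plus the MIME type of the content (not the library's actual definition)."""

    def __new__(cls, data: bytes = b"", mime_type: Optional[str] = None) -> "MimeBytes":
        obj = super().__new__(cls, data)
        obj.mime_type = mime_type
        return obj


# Matches the call sites visible in this diff:
body = MimeBytes(b'{"inputs": "..."}', mime_type="application/json")
print(len(body), body.mime_type)  # 17 application/json
```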
huggingface_hub/inference/_providers/black_forest_labs.py
CHANGED

@@ -18,7 +18,7 @@ class BlackForestLabsTextToImageTask(TaskProviderHelper):
     def __init__(self):
         super().__init__(provider="black-forest-labs", base_url="https://api.us1.bfl.ai", task="text-to-image")

-    def _prepare_headers(self, headers: Dict, api_key: str) -> Dict:
+    def _prepare_headers(self, headers: Dict, api_key: str) -> Dict[str, Any]:
         headers = super()._prepare_headers(headers, api_key)
         if not api_key.startswith("hf_"):
             _ = headers.pop("authorization")
huggingface_hub/inference/_providers/fal_ai.py
CHANGED

@@ -22,7 +22,7 @@ class FalAITask(TaskProviderHelper, ABC):
     def __init__(self, task: str):
         super().__init__(provider="fal-ai", base_url="https://fal.run", task=task)

-    def _prepare_headers(self, headers: Dict, api_key: str) -> Dict:
+    def _prepare_headers(self, headers: Dict, api_key: str) -> Dict[str, Any]:
         headers = super()._prepare_headers(headers, api_key)
         if not api_key.startswith("hf_"):
             headers["authorization"] = f"Key {api_key}"
@@ -36,7 +36,7 @@ class FalAIQueueTask(TaskProviderHelper, ABC):
     def __init__(self, task: str):
         super().__init__(provider="fal-ai", base_url="https://queue.fal.run", task=task)

-    def _prepare_headers(self, headers: Dict, api_key: str) -> Dict:
+    def _prepare_headers(self, headers: Dict, api_key: str) -> Dict[str, Any]:
         headers = super()._prepare_headers(headers, api_key)
         if not api_key.startswith("hf_"):
             headers["authorization"] = f"Key {api_key}"
huggingface_hub/inference/_providers/hf_inference.py
CHANGED

@@ -6,7 +6,13 @@ from urllib.parse import urlparse, urlunparse

 from huggingface_hub import constants
 from huggingface_hub.hf_api import InferenceProviderMapping
-from huggingface_hub.inference._common import
+from huggingface_hub.inference._common import (
+    MimeBytes,
+    RequestParameters,
+    _b64_encode,
+    _bytes_to_dict,
+    _open_as_mime_bytes,
+)
 from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none
 from huggingface_hub.utils import build_hf_headers, get_session, get_token, hf_raise_for_status

@@ -75,7 +81,7 @@ class HFInferenceBinaryInputTask(HFInferenceTask):
         parameters: Dict,
         provider_mapping_info: InferenceProviderMapping,
         extra_payload: Optional[Dict],
-    ) -> Optional[
+    ) -> Optional[MimeBytes]:
         parameters = filter_none(parameters)
         extra_payload = extra_payload or {}
         has_parameters = len(parameters) > 0 or len(extra_payload) > 0
@@ -86,12 +92,13 @@ class HFInferenceBinaryInputTask(HFInferenceTask):

         # Send inputs as raw content when no parameters are provided
         if not has_parameters:
-
-            data_as_bytes = data if isinstance(data, bytes) else data.read()
-            return data_as_bytes
+            return _open_as_mime_bytes(inputs)

         # Otherwise encode as b64
-        return
+        return MimeBytes(
+            json.dumps({"inputs": _b64_encode(inputs), "parameters": parameters, **extra_payload}).encode("utf-8"),
+            mime_type="application/json",
+        )


 class HFInferenceConversational(HFInferenceTask):
@@ -144,7 +151,8 @@ def _build_chat_completion_url(model_url: str) -> str:
     new_path = path + "/v1/chat/completions"

     # Reconstruct the URL with the new path and original query parameters.
-
+    new_parsed = parsed._replace(path=new_path)
+    return str(urlunparse(new_parsed))


 @lru_cache(maxsize=1)
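The restored body of `_build_chat_completion_url` swaps only the path component, so query parameters survive. A small standalone illustration of the same `urlparse`/`_replace`/`urlunparse` round trip (URL is illustrative):

```py
from urllib.parse import urlparse, urlunparse

model_url = "https://example.endpoint.cloud/models/my-model?token=abc"  # illustrative URL
parsed = urlparse(model_url)
new_parsed = parsed._replace(path=parsed.path.rstrip("/") + "/v1/chat/completions")
print(urlunparse(new_parsed))
# https://example.endpoint.cloud/models/my-model/v1/chat/completions?token=abc
```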
huggingface_hub/inference/_providers/replicate.py
CHANGED

@@ -14,7 +14,7 @@ class ReplicateTask(TaskProviderHelper):
    def __init__(self, task: str):
         super().__init__(provider=_PROVIDER, base_url=_BASE_URL, task=task)

-    def _prepare_headers(self, headers: Dict, api_key: str) -> Dict:
+    def _prepare_headers(self, headers: Dict, api_key: str) -> Dict[str, Any]:
         headers = super()._prepare_headers(headers, api_key)
         headers["Prefer"] = "wait"
         return headers
huggingface_hub/repocard.py
CHANGED

@@ -771,7 +771,8 @@ def metadata_update(
         raise ValueError("Cannot update metadata on a Space that doesn't contain a `README.md` file.")

     # Initialize a ModelCard or DatasetCard from default template and no data.
-    card
+    # Cast to the concrete expected card type to satisfy type checkers.
+    card = card_class.from_template(CardData())  # type: ignore[return-value]

     for key, value in metadata.items():
         if key == "model-index":

huggingface_hub/utils/_git_credential.py
CHANGED

@@ -27,7 +27,7 @@ GIT_CREDENTIAL_REGEX = re.compile(
     ^\s* # start of line
     credential\.helper # credential.helper value
     \s*=\s* # separator
-    (\w+) # the helper name (group 1)
+    ([\w\-\/]+) # the helper name or absolute path (group 1)
     (\s|$) # whitespace or end of line
     """,
     flags=re.MULTILINE | re.IGNORECASE | re.VERBOSE,
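The widened capture group now also matches helpers configured as absolute paths. A quick check of the new pattern against both styles (the sample config lines are illustrative):

```py
import re

GIT_CREDENTIAL_REGEX = re.compile(
    r"""
    ^\s* # start of line
    credential\.helper # credential.helper value
    \s*=\s* # separator
    ([\w\-\/]+) # the helper name or absolute path (group 1)
    (\s|$) # whitespace or end of line
    """,
    flags=re.MULTILINE | re.IGNORECASE | re.VERBOSE,
)

config = """
credential.helper = store
credential.helper = /usr/local/bin/git-credential-manager
"""
print(GIT_CREDENTIAL_REGEX.findall(config))
# [('store', '\n'), ('/usr/local/bin/git-credential-manager', '\n')]
```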