huggingface-hub 0.31.0rc0__py3-none-any.whl → 1.1.3__py3-none-any.whl
This diff covers the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- huggingface_hub/__init__.py +145 -46
- huggingface_hub/_commit_api.py +168 -119
- huggingface_hub/_commit_scheduler.py +15 -15
- huggingface_hub/_inference_endpoints.py +15 -12
- huggingface_hub/_jobs_api.py +301 -0
- huggingface_hub/_local_folder.py +18 -3
- huggingface_hub/_login.py +31 -63
- huggingface_hub/_oauth.py +460 -0
- huggingface_hub/_snapshot_download.py +239 -80
- huggingface_hub/_space_api.py +5 -5
- huggingface_hub/_tensorboard_logger.py +15 -19
- huggingface_hub/_upload_large_folder.py +172 -76
- huggingface_hub/_webhooks_payload.py +3 -3
- huggingface_hub/_webhooks_server.py +13 -25
- huggingface_hub/{commands → cli}/__init__.py +1 -15
- huggingface_hub/cli/_cli_utils.py +173 -0
- huggingface_hub/cli/auth.py +147 -0
- huggingface_hub/cli/cache.py +841 -0
- huggingface_hub/cli/download.py +189 -0
- huggingface_hub/cli/hf.py +60 -0
- huggingface_hub/cli/inference_endpoints.py +377 -0
- huggingface_hub/cli/jobs.py +772 -0
- huggingface_hub/cli/lfs.py +175 -0
- huggingface_hub/cli/repo.py +315 -0
- huggingface_hub/cli/repo_files.py +94 -0
- huggingface_hub/{commands/env.py → cli/system.py} +10 -13
- huggingface_hub/cli/upload.py +294 -0
- huggingface_hub/cli/upload_large_folder.py +117 -0
- huggingface_hub/community.py +20 -12
- huggingface_hub/constants.py +38 -53
- huggingface_hub/dataclasses.py +609 -0
- huggingface_hub/errors.py +80 -30
- huggingface_hub/fastai_utils.py +30 -41
- huggingface_hub/file_download.py +435 -351
- huggingface_hub/hf_api.py +2050 -1124
- huggingface_hub/hf_file_system.py +269 -152
- huggingface_hub/hub_mixin.py +43 -63
- huggingface_hub/inference/_client.py +347 -434
- huggingface_hub/inference/_common.py +133 -121
- huggingface_hub/inference/_generated/_async_client.py +397 -541
- huggingface_hub/inference/_generated/types/__init__.py +5 -1
- huggingface_hub/inference/_generated/types/automatic_speech_recognition.py +3 -3
- huggingface_hub/inference/_generated/types/base.py +10 -7
- huggingface_hub/inference/_generated/types/chat_completion.py +59 -23
- huggingface_hub/inference/_generated/types/depth_estimation.py +2 -2
- huggingface_hub/inference/_generated/types/document_question_answering.py +2 -2
- huggingface_hub/inference/_generated/types/feature_extraction.py +2 -2
- huggingface_hub/inference/_generated/types/fill_mask.py +2 -2
- huggingface_hub/inference/_generated/types/image_to_image.py +6 -2
- huggingface_hub/inference/_generated/types/image_to_video.py +60 -0
- huggingface_hub/inference/_generated/types/sentence_similarity.py +3 -3
- huggingface_hub/inference/_generated/types/summarization.py +2 -2
- huggingface_hub/inference/_generated/types/table_question_answering.py +5 -5
- huggingface_hub/inference/_generated/types/text2text_generation.py +2 -2
- huggingface_hub/inference/_generated/types/text_generation.py +10 -10
- huggingface_hub/inference/_generated/types/text_to_video.py +2 -2
- huggingface_hub/inference/_generated/types/token_classification.py +2 -2
- huggingface_hub/inference/_generated/types/translation.py +2 -2
- huggingface_hub/inference/_generated/types/zero_shot_classification.py +2 -2
- huggingface_hub/inference/_generated/types/zero_shot_image_classification.py +2 -2
- huggingface_hub/inference/_generated/types/zero_shot_object_detection.py +1 -3
- huggingface_hub/inference/_mcp/__init__.py +0 -0
- huggingface_hub/inference/_mcp/_cli_hacks.py +88 -0
- huggingface_hub/inference/_mcp/agent.py +100 -0
- huggingface_hub/inference/_mcp/cli.py +247 -0
- huggingface_hub/inference/_mcp/constants.py +81 -0
- huggingface_hub/inference/_mcp/mcp_client.py +395 -0
- huggingface_hub/inference/_mcp/types.py +45 -0
- huggingface_hub/inference/_mcp/utils.py +128 -0
- huggingface_hub/inference/_providers/__init__.py +82 -7
- huggingface_hub/inference/_providers/_common.py +129 -27
- huggingface_hub/inference/_providers/black_forest_labs.py +6 -6
- huggingface_hub/inference/_providers/cerebras.py +1 -1
- huggingface_hub/inference/_providers/clarifai.py +13 -0
- huggingface_hub/inference/_providers/cohere.py +20 -3
- huggingface_hub/inference/_providers/fal_ai.py +183 -56
- huggingface_hub/inference/_providers/featherless_ai.py +38 -0
- huggingface_hub/inference/_providers/fireworks_ai.py +18 -0
- huggingface_hub/inference/_providers/groq.py +9 -0
- huggingface_hub/inference/_providers/hf_inference.py +69 -30
- huggingface_hub/inference/_providers/hyperbolic.py +4 -4
- huggingface_hub/inference/_providers/nebius.py +33 -5
- huggingface_hub/inference/_providers/novita.py +5 -5
- huggingface_hub/inference/_providers/nscale.py +44 -0
- huggingface_hub/inference/_providers/openai.py +3 -1
- huggingface_hub/inference/_providers/publicai.py +6 -0
- huggingface_hub/inference/_providers/replicate.py +31 -13
- huggingface_hub/inference/_providers/sambanova.py +18 -4
- huggingface_hub/inference/_providers/scaleway.py +28 -0
- huggingface_hub/inference/_providers/together.py +20 -5
- huggingface_hub/inference/_providers/wavespeed.py +138 -0
- huggingface_hub/inference/_providers/zai_org.py +17 -0
- huggingface_hub/lfs.py +33 -100
- huggingface_hub/repocard.py +34 -38
- huggingface_hub/repocard_data.py +57 -57
- huggingface_hub/serialization/__init__.py +0 -1
- huggingface_hub/serialization/_base.py +12 -15
- huggingface_hub/serialization/_dduf.py +8 -8
- huggingface_hub/serialization/_torch.py +69 -69
- huggingface_hub/utils/__init__.py +19 -8
- huggingface_hub/utils/_auth.py +7 -7
- huggingface_hub/utils/_cache_manager.py +92 -147
- huggingface_hub/utils/_chunk_utils.py +2 -3
- huggingface_hub/utils/_deprecation.py +1 -1
- huggingface_hub/utils/_dotenv.py +55 -0
- huggingface_hub/utils/_experimental.py +7 -5
- huggingface_hub/utils/_fixes.py +0 -10
- huggingface_hub/utils/_git_credential.py +5 -5
- huggingface_hub/utils/_headers.py +8 -30
- huggingface_hub/utils/_http.py +398 -239
- huggingface_hub/utils/_pagination.py +4 -4
- huggingface_hub/utils/_parsing.py +98 -0
- huggingface_hub/utils/_paths.py +5 -5
- huggingface_hub/utils/_runtime.py +61 -24
- huggingface_hub/utils/_safetensors.py +21 -21
- huggingface_hub/utils/_subprocess.py +9 -9
- huggingface_hub/utils/_telemetry.py +4 -4
- huggingface_hub/{commands/_cli_utils.py → utils/_terminal.py} +4 -4
- huggingface_hub/utils/_typing.py +25 -5
- huggingface_hub/utils/_validators.py +55 -74
- huggingface_hub/utils/_verification.py +167 -0
- huggingface_hub/utils/_xet.py +64 -17
- huggingface_hub/utils/_xet_progress_reporting.py +162 -0
- huggingface_hub/utils/insecure_hashlib.py +3 -5
- huggingface_hub/utils/logging.py +8 -11
- huggingface_hub/utils/tqdm.py +5 -4
- {huggingface_hub-0.31.0rc0.dist-info → huggingface_hub-1.1.3.dist-info}/METADATA +94 -85
- huggingface_hub-1.1.3.dist-info/RECORD +155 -0
- {huggingface_hub-0.31.0rc0.dist-info → huggingface_hub-1.1.3.dist-info}/WHEEL +1 -1
- huggingface_hub-1.1.3.dist-info/entry_points.txt +6 -0
- huggingface_hub/commands/delete_cache.py +0 -474
- huggingface_hub/commands/download.py +0 -200
- huggingface_hub/commands/huggingface_cli.py +0 -61
- huggingface_hub/commands/lfs.py +0 -200
- huggingface_hub/commands/repo_files.py +0 -128
- huggingface_hub/commands/scan_cache.py +0 -181
- huggingface_hub/commands/tag.py +0 -159
- huggingface_hub/commands/upload.py +0 -314
- huggingface_hub/commands/upload_large_folder.py +0 -129
- huggingface_hub/commands/user.py +0 -304
- huggingface_hub/commands/version.py +0 -37
- huggingface_hub/inference_api.py +0 -217
- huggingface_hub/keras_mixin.py +0 -500
- huggingface_hub/repository.py +0 -1477
- huggingface_hub/serialization/_tensorflow.py +0 -95
- huggingface_hub/utils/_hf_folder.py +0 -68
- huggingface_hub-0.31.0rc0.dist-info/RECORD +0 -135
- huggingface_hub-0.31.0rc0.dist-info/entry_points.txt +0 -6
- {huggingface_hub-0.31.0rc0.dist-info → huggingface_hub-1.1.3.dist-info/licenses}/LICENSE +0 -0
- {huggingface_hub-0.31.0rc0.dist-info → huggingface_hub-1.1.3.dist-info}/top_level.txt +0 -0
@@ -21,16 +21,19 @@
 import asyncio
 import base64
 import logging
+import os
 import re
 import warnings
-from …
+from contextlib import AsyncExitStack
+from typing import TYPE_CHECKING, Any, AsyncIterable, Literal, Optional, Union, overload
+
+import httpx

 from huggingface_hub import constants
-from huggingface_hub.errors import InferenceTimeoutError
+from huggingface_hub.errors import BadRequestError, HfHubHTTPError, InferenceTimeoutError
 from huggingface_hub.inference._common import (
     TASKS_EXPECTING_IMAGES,
     ContentT,
-    ModelStatus,
     RequestParameters,
     _async_stream_chat_completion_response,
     _async_stream_text_generation_response,
@@ -41,7 +44,6 @@ from huggingface_hub.inference._common import (
     _bytes_to_list,
     _get_unsupported_text_generation_kwargs,
     _import_numpy,
-    _open_as_binary,
     _set_unsupported_text_generation_kwargs,
     raise_text_generation_error,
 )
@@ -51,6 +53,7 @@ from huggingface_hub.inference._generated.types import (
     AudioToAudioOutputElement,
     AutomaticSpeechRecognitionOutput,
     ChatCompletionInputGrammarType,
+    ChatCompletionInputMessage,
     ChatCompletionInputStreamOptions,
     ChatCompletionInputTool,
     ChatCompletionInputToolChoiceClass,
@@ -65,6 +68,7 @@ from huggingface_hub.inference._generated.types import (
     ImageSegmentationSubtask,
     ImageToImageTargetSize,
     ImageToTextOutput,
+    ImageToVideoTargetSize,
     ObjectDetectionOutputElement,
     Padding,
     QuestionAnsweringOutputElement,
@@ -85,17 +89,20 @@ from huggingface_hub.inference._generated.types import (
     ZeroShotClassificationOutputElement,
     ZeroShotImageClassificationOutputElement,
 )
-from huggingface_hub.inference._providers import …
-from huggingface_hub.utils import …
+from huggingface_hub.inference._providers import PROVIDER_OR_POLICY_T, get_provider_helper
+from huggingface_hub.utils import (
+    build_hf_headers,
+    get_async_session,
+    hf_raise_for_status,
+    validate_hf_hub_args,
+)
 from huggingface_hub.utils._auth import get_token
-from huggingface_hub.utils._deprecation import _deprecate_method

-from .._common import _async_yield_from
+from .._common import _async_yield_from


 if TYPE_CHECKING:
     import numpy as np
-    from aiohttp import ClientResponse, ClientSession
     from PIL.Image import Image

 logger = logging.getLogger(__name__)
@@ -117,11 +124,9 @@ class AsyncInferenceClient:
            or a URL to a deployed Inference Endpoint. Defaults to None, in which case a recommended model is
            automatically selected for the task.
            Note: for better compatibility with OpenAI's client, `model` has been aliased as `base_url`. Those 2
-            arguments are mutually exclusive. If …
-            path will be appended to the base URL (see the [TGI Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api)
-            documentation for details). When passing a URL as `model`, the client will not append any suffix path to it.
+            arguments are mutually exclusive. If a URL is passed as `model` or `base_url` for chat completion, the `(/v1)/chat/completions` suffix path will be appended to the URL.
        provider (`str`, *optional*):
-            Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"openai"`, `"replicate"`, "sambanova"` or `" …
+            Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"clarifai"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"publicai"`, `"replicate"`, `"sambanova"`, `"scaleway"`, `"together"`, `"wavespeed"` or `"zai-org"`.
            Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
            If model is a URL or `base_url` is passed, then `provider` is not used.
        token (`str`, *optional*):
@@ -130,18 +135,14 @@ class AsyncInferenceClient:
            arguments are mutually exclusive and have the exact same behavior.
        timeout (`float`, `optional`):
            The maximum number of seconds to wait for a response from the server. Defaults to None, meaning it will loop until the server is available.
-        headers (` …
+        headers (`dict[str, str]`, `optional`):
            Additional headers to send to the server. By default only the authorization and user-agent headers are sent.
            Values in this dictionary will override the default values.
        bill_to (`str`, `optional`):
            The billing account to use for the requests. By default the requests are billed on the user's account.
            Requests can only be billed to an organization the user is a member of, and which has subscribed to Enterprise Hub.
-        cookies (` …
+        cookies (`dict[str, str]`, `optional`):
            Additional cookies to send to the server.
-        trust_env ('bool', 'optional'):
-            Trust environment settings for proxy configuration if the parameter is `True` (`False` by default).
-        proxies (`Any`, `optional`):
-            Proxies to use for the request.
        base_url (`str`, `optional`):
            Base URL to run inference. This is a duplicated argument from `model` to make [`InferenceClient`]
            follow the same pattern as `openai.OpenAI` client. Cannot be used if `model` is set. Defaults to None.
@@ -150,17 +151,16 @@ class AsyncInferenceClient:
            follow the same pattern as `openai.OpenAI` client. Cannot be used if `token` is set. Defaults to None.
    """

+    @validate_hf_hub_args
    def __init__(
        self,
        model: Optional[str] = None,
        *,
-        provider: …
+        provider: Optional[PROVIDER_OR_POLICY_T] = None,
        token: Optional[str] = None,
        timeout: Optional[float] = None,
-        headers: Optional[ …
-        cookies: Optional[ …
-        trust_env: bool = False,
-        proxies: Optional[Any] = None,
+        headers: Optional[dict[str, str]] = None,
+        cookies: Optional[dict[str, str]] = None,
        bill_to: Optional[str] = None,
        # OpenAI compatibility
        base_url: Optional[str] = None,
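The hunks above drop the aiohttp-era `trust_env`/`proxies` arguments and tighten the remaining types. A minimal construction sketch against the new 1.x signature (the header value and organization name below are illustrative placeholders, not library defaults):

```python
from huggingface_hub import AsyncInferenceClient

client = AsyncInferenceClient(
    provider="auto",                 # a named provider (e.g. "together") or the "auto" routing policy
    timeout=60,
    headers={"x-demo-header": "1"},  # now typed as dict[str, str]
    cookies={},                      # now typed as dict[str, str]
    bill_to="my-enterprise-org",     # illustrative Enterprise Hub organization name
)
# trust_env=... and proxies=... are no longer accepted by __init__ in 1.x.
```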
@@ -181,7 +181,7 @@ class AsyncInferenceClient:
            )
        token = token if token is not None else api_key
        if isinstance(token, bool):
-            # Legacy behavior: previously …
+            # Legacy behavior: previously it was possible to pass `token=False` to disable authentication. This is not
            # supported anymore as authentication is required. Better to explicitly raise here rather than risking
            # sending the locally saved token without the user knowing about it.
            if token is False:
@@ -222,15 +222,36 @@ class AsyncInferenceClient:

        self.cookies = cookies
        self.timeout = timeout
-        self.trust_env = trust_env
-        self.proxies = proxies

-
-        self. …
+        self.exit_stack = AsyncExitStack()
+        self._async_client: Optional[httpx.AsyncClient] = None

    def __repr__(self):
        return f"<InferenceClient(model='{self.model if self.model else ''}', timeout={self.timeout})>"

+    async def __aenter__(self):
+        return self
+
+    async def __aexit__(self, exc_type, exc_value, traceback):
+        await self.close()
+
+    async def close(self):
+        """Close the client.
+
+        This method is automatically called when using the client as a context manager.
+        """
+        await self.exit_stack.aclose()
+
+    async def _get_async_client(self):
+        """Get a unique async client for this AsyncInferenceClient instance.
+
+        Returns the same client instance on subsequent calls, ensuring proper
+        connection reuse and resource management through the exit stack.
+        """
+        if self._async_client is None:
+            self._async_client = await self.exit_stack.enter_async_context(get_async_session())
+        return self._async_client
+
    @overload
    async def _inner_post(  # type: ignore[misc]
        self, request_parameters: RequestParameters, *, stream: Literal[False] = ...
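Taken together with the new imports, this hunk replaces the per-call aiohttp sessions with a single lazily created `httpx.AsyncClient` managed through an `AsyncExitStack`. A minimal usage sketch (the model ID and prompt are placeholders):

```python
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    # Leaving the async context calls close(), which unwinds the exit stack
    # and releases the underlying httpx.AsyncClient.
    async with AsyncInferenceClient(model="meta-llama/Llama-3.1-8B-Instruct") as client:
        output = await client.chat_completion(
            messages=[{"role": "user", "content": "Say hello in one sentence."}],
            max_tokens=32,
        )
        print(output.choices[0].message.content)


asyncio.run(main())
```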
@@ -239,83 +260,59 @@ class AsyncInferenceClient:
    @overload
    async def _inner_post(  # type: ignore[misc]
        self, request_parameters: RequestParameters, *, stream: Literal[True] = ...
-    ) -> AsyncIterable[ …
+    ) -> AsyncIterable[str]: ...

    @overload
    async def _inner_post(
        self, request_parameters: RequestParameters, *, stream: bool = False
-    ) -> Union[bytes, AsyncIterable[ …
+    ) -> Union[bytes, AsyncIterable[str]]: ...

    async def _inner_post(
        self, request_parameters: RequestParameters, *, stream: bool = False
-    ) -> Union[bytes, AsyncIterable[ …
+    ) -> Union[bytes, AsyncIterable[str]]:
        """Make a request to the inference server."""

-        aiohttp = _import_aiohttp()
-
        # TODO: this should be handled in provider helpers directly
        if request_parameters.task in TASKS_EXPECTING_IMAGES and "Accept" not in request_parameters.headers:
            request_parameters.headers["Accept"] = "image/png"

-        [… 8 removed lines truncated in the source rendering …]
+        try:
+            client = await self._get_async_client()
+            if stream:
+                response = await self.exit_stack.enter_async_context(
+                    client.stream(
+                        "POST",
+                        request_parameters.url,
+                        json=request_parameters.json,
+                        data=request_parameters.data,
+                        headers=request_parameters.headers,
+                        cookies=self.cookies,
+                        timeout=self.timeout,
+                    )
                )
-        [… 23 removed lines truncated in the source rendering …]
-            raise
-
-    async def __aenter__(self):
-        return self
-
-    async def __aexit__(self, exc_type, exc_value, traceback):
-        await self.close()
-
-    def __del__(self):
-        if len(self._sessions) > 0:
-            warnings.warn(
-                "Deleting 'AsyncInferenceClient' client but some sessions are still open. "
-                "This can happen if you've stopped streaming data from the server before the stream was complete. "
-                "To close the client properly, you must call `await client.close()` "
-                "or use an async context (e.g. `async with AsyncInferenceClient(): ...`."
-            )
-
-    async def close(self):
-        """Close all open sessions.
-
-        By default, 'aiohttp.ClientSession' objects are closed automatically when a call is completed. However, if you
-        are streaming data from the server and you stop before the stream is complete, you must call this method to
-        close the session properly.
-
-        Another possibility is to use an async context (e.g. `async with AsyncInferenceClient(): ...`).
-        """
-        await asyncio.gather(*[session.close() for session in self._sessions.keys()])
+                hf_raise_for_status(response)
+                return _async_yield_from(client, response)
+            else:
+                response = await client.post(
+                    request_parameters.url,
+                    json=request_parameters.json,
+                    data=request_parameters.data,
+                    headers=request_parameters.headers,
+                    cookies=self.cookies,
+                    timeout=self.timeout,
+                )
+                hf_raise_for_status(response)
+                return response.content
+        except asyncio.TimeoutError as error:
+            # Convert any `TimeoutError` to a `InferenceTimeoutError`
+            raise InferenceTimeoutError(f"Inference call timed out: {request_parameters.url}") from error  # type: ignore
+        except HfHubHTTPError as error:
+            if error.response.status_code == 422 and request_parameters.task != "unknown":
+                msg = str(error.args[0])
+                if len(error.response.text) > 0:
+                    msg += f"{os.linesep}{error.response.text}{os.linesep}"
+                error.args = (msg,) + error.args[1:]
+            raise

    async def audio_classification(
        self,
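With this rewrite, the request path relies on `hf_raise_for_status`, so HTTP failures surface as `HfHubHTTPError` (with the response body appended for 422s) and timeouts as `InferenceTimeoutError`. A hedged sketch of catching both around any task method:

```python
from huggingface_hub import AsyncInferenceClient
from huggingface_hub.errors import HfHubHTTPError, InferenceTimeoutError


async def classify(audio_path: str):
    async with AsyncInferenceClient(timeout=30) as client:
        try:
            return await client.audio_classification(audio_path)
        except InferenceTimeoutError:
            # Raised when the underlying asyncio.TimeoutError fires.
            print("model unavailable or request timed out")
        except HfHubHTTPError as err:
            # Raised by hf_raise_for_status(); the server response is attached.
            print(f"HTTP {err.response.status_code}: {err}")
```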
@@ -324,7 +321,7 @@ class AsyncInferenceClient:
        model: Optional[str] = None,
        top_k: Optional[int] = None,
        function_to_apply: Optional["AudioClassificationOutputTransform"] = None,
-    ) -> …
+    ) -> list[AudioClassificationOutputElement]:
        """
        Perform audio classification on the provided audio content.

@@ -342,12 +339,12 @@ class AsyncInferenceClient:
                The function to apply to the model outputs in order to retrieve the scores.

        Returns:
-            ` …
+            `list[AudioClassificationOutputElement]`: List of [`AudioClassificationOutputElement`] items containing the predicted labels and their confidence.

        Raises:
            [`InferenceTimeoutError`]:
                If the model is unavailable or the request times out.
-            ` …
+            [`HfHubHTTPError`]:
                If the request fails with an HTTP error status code other than HTTP 503.

        Example:
@@ -380,7 +377,7 @@ class AsyncInferenceClient:
        audio: ContentT,
        *,
        model: Optional[str] = None,
-    ) -> …
+    ) -> list[AudioToAudioOutputElement]:
        """
        Performs multiple tasks related to audio-to-audio depending on the model (eg: speech enhancement, source separation).

@@ -394,12 +391,12 @@ class AsyncInferenceClient:
                audio_to_audio will be used.

        Returns:
-            ` …
+            `list[AudioToAudioOutputElement]`: A list of [`AudioToAudioOutputElement`] items containing audios label, content-type, and audio content in blob.

        Raises:
            `InferenceTimeoutError`:
                If the model is unavailable or the request times out.
-            ` …
+            [`HfHubHTTPError`]:
                If the request fails with an HTTP error status code other than HTTP 503.

        Example:
@@ -433,7 +430,7 @@ class AsyncInferenceClient:
        audio: ContentT,
        *,
        model: Optional[str] = None,
-        extra_body: Optional[ …
+        extra_body: Optional[dict] = None,
    ) -> AutomaticSpeechRecognitionOutput:
        """
        Perform automatic speech recognition (ASR or audio-to-text) on the given audio content.
@@ -444,7 +441,7 @@ class AsyncInferenceClient:
            model (`str`, *optional*):
                The model to use for ASR. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
                Inference Endpoint. If not provided, the default recommended model for ASR will be used.
-            extra_body (` …
+            extra_body (`dict`, *optional*):
                Additional provider-specific parameters to pass to the model. Refer to the provider's documentation
                for supported parameters.
        Returns:
@@ -453,7 +450,7 @@ class AsyncInferenceClient:
        Raises:
            [`InferenceTimeoutError`]:
                If the model is unavailable or the request times out.
-            ` …
+            [`HfHubHTTPError`]:
                If the request fails with an HTTP error status code other than HTTP 503.

        Example:
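The new `extra_body` argument (also added to several task methods below) forwards provider-specific parameters verbatim. A sketch; the `language` key here is a hypothetical provider-specific parameter that the client itself does not validate:

```python
from huggingface_hub import AsyncInferenceClient


async def transcribe(audio_path: str):
    async with AsyncInferenceClient() as client:
        return await client.automatic_speech_recognition(
            audio_path,
            # Forwarded as-is to the selected provider; keys are provider-specific.
            extra_body={"language": "en"},
        )
```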
@@ -480,121 +477,117 @@ class AsyncInferenceClient:
    @overload
    async def chat_completion(  # type: ignore
        self,
-        messages: …
+        messages: list[Union[dict, ChatCompletionInputMessage]],
        *,
        model: Optional[str] = None,
        stream: Literal[False] = False,
        frequency_penalty: Optional[float] = None,
-        logit_bias: Optional[ …
+        logit_bias: Optional[list[float]] = None,
        logprobs: Optional[bool] = None,
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[float] = None,
        response_format: Optional[ChatCompletionInputGrammarType] = None,
        seed: Optional[int] = None,
-        stop: Optional[ …
+        stop: Optional[list[str]] = None,
        stream_options: Optional[ChatCompletionInputStreamOptions] = None,
        temperature: Optional[float] = None,
        tool_choice: Optional[Union[ChatCompletionInputToolChoiceClass, "ChatCompletionInputToolChoiceEnum"]] = None,
        tool_prompt: Optional[str] = None,
-        tools: Optional[ …
+        tools: Optional[list[ChatCompletionInputTool]] = None,
        top_logprobs: Optional[int] = None,
        top_p: Optional[float] = None,
-        extra_body: Optional[ …
+        extra_body: Optional[dict] = None,
    ) -> ChatCompletionOutput: ...

    @overload
    async def chat_completion(  # type: ignore
        self,
-        messages: …
+        messages: list[Union[dict, ChatCompletionInputMessage]],
        *,
        model: Optional[str] = None,
        stream: Literal[True] = True,
        frequency_penalty: Optional[float] = None,
-        logit_bias: Optional[ …
+        logit_bias: Optional[list[float]] = None,
        logprobs: Optional[bool] = None,
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[float] = None,
        response_format: Optional[ChatCompletionInputGrammarType] = None,
        seed: Optional[int] = None,
-        stop: Optional[ …
+        stop: Optional[list[str]] = None,
        stream_options: Optional[ChatCompletionInputStreamOptions] = None,
        temperature: Optional[float] = None,
        tool_choice: Optional[Union[ChatCompletionInputToolChoiceClass, "ChatCompletionInputToolChoiceEnum"]] = None,
        tool_prompt: Optional[str] = None,
-        tools: Optional[ …
+        tools: Optional[list[ChatCompletionInputTool]] = None,
        top_logprobs: Optional[int] = None,
        top_p: Optional[float] = None,
-        extra_body: Optional[ …
+        extra_body: Optional[dict] = None,
    ) -> AsyncIterable[ChatCompletionStreamOutput]: ...

    @overload
    async def chat_completion(
        self,
-        messages: …
+        messages: list[Union[dict, ChatCompletionInputMessage]],
        *,
        model: Optional[str] = None,
        stream: bool = False,
        frequency_penalty: Optional[float] = None,
-        logit_bias: Optional[ …
+        logit_bias: Optional[list[float]] = None,
        logprobs: Optional[bool] = None,
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[float] = None,
        response_format: Optional[ChatCompletionInputGrammarType] = None,
        seed: Optional[int] = None,
-        stop: Optional[ …
+        stop: Optional[list[str]] = None,
        stream_options: Optional[ChatCompletionInputStreamOptions] = None,
        temperature: Optional[float] = None,
        tool_choice: Optional[Union[ChatCompletionInputToolChoiceClass, "ChatCompletionInputToolChoiceEnum"]] = None,
        tool_prompt: Optional[str] = None,
-        tools: Optional[ …
+        tools: Optional[list[ChatCompletionInputTool]] = None,
        top_logprobs: Optional[int] = None,
        top_p: Optional[float] = None,
-        extra_body: Optional[ …
+        extra_body: Optional[dict] = None,
    ) -> Union[ChatCompletionOutput, AsyncIterable[ChatCompletionStreamOutput]]: ...

    async def chat_completion(
        self,
-        messages: …
+        messages: list[Union[dict, ChatCompletionInputMessage]],
        *,
        model: Optional[str] = None,
        stream: bool = False,
        # Parameters from ChatCompletionInput (handled manually)
        frequency_penalty: Optional[float] = None,
-        logit_bias: Optional[ …
+        logit_bias: Optional[list[float]] = None,
        logprobs: Optional[bool] = None,
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[float] = None,
        response_format: Optional[ChatCompletionInputGrammarType] = None,
        seed: Optional[int] = None,
-        stop: Optional[ …
+        stop: Optional[list[str]] = None,
        stream_options: Optional[ChatCompletionInputStreamOptions] = None,
        temperature: Optional[float] = None,
        tool_choice: Optional[Union[ChatCompletionInputToolChoiceClass, "ChatCompletionInputToolChoiceEnum"]] = None,
        tool_prompt: Optional[str] = None,
-        tools: Optional[ …
+        tools: Optional[list[ChatCompletionInputTool]] = None,
        top_logprobs: Optional[int] = None,
        top_p: Optional[float] = None,
-        extra_body: Optional[ …
+        extra_body: Optional[dict] = None,
    ) -> Union[ChatCompletionOutput, AsyncIterable[ChatCompletionStreamOutput]]:
        """
        A method for completing conversations using a specified language model.

-        [… 5 removed lines truncated in the source rendering …]
-        for more details about OpenAI's compatibility.
+        > [!TIP]
+        > The `client.chat_completion` method is aliased as `client.chat.completions.create` for compatibility with OpenAI's client.
+        > Inputs and outputs are strictly the same and using either syntax will yield the same results.
+        > Check out the [Inference guide](https://huggingface.co/docs/huggingface_hub/guides/inference#openai-compatibility)
+        > for more details about OpenAI's compatibility.

-
-
-        <Tip>
-        You can pass provider-specific parameters to the model by using the `extra_body` argument.
-        </Tip>
+        > [!TIP]
+        > You can pass provider-specific parameters to the model by using the `extra_body` argument.

        Args:
            messages (List of [`ChatCompletionInputMessage`]):
@@ -608,7 +601,7 @@ class AsyncInferenceClient:
            frequency_penalty (`float`, *optional*):
                Penalizes new tokens based on their existing frequency
                in the text so far. Range: [-2.0, 2.0]. Defaults to 0.0.
-            logit_bias (` …
+            logit_bias (`list[float]`, *optional*):
                Adjusts the likelihood of specific tokens appearing in the generated output.
            logprobs (`bool`, *optional*):
                Whether to return log probabilities of the output tokens or not. If true, returns the log
@@ -624,7 +617,7 @@ class AsyncInferenceClient:
                Grammar constraints. Can be either a JSONSchema or a regex.
            seed (Optional[`int`], *optional*):
                Seed for reproducible control flow. Defaults to None.
-            stop (` …
+            stop (`list[str]`, *optional*):
                Up to four strings which trigger the end of the response.
                Defaults to None.
            stream (`bool`, *optional*):
@@ -648,7 +641,7 @@ class AsyncInferenceClient:
            tools (List of [`ChatCompletionInputTool`], *optional*):
                A list of tools the model may call. Currently, only functions are supported as a tool. Use this to
                provide a list of functions the model may generate JSON inputs for.
-            extra_body (` …
+            extra_body (`dict`, *optional*):
                Additional provider-specific parameters to pass to the model. Refer to the provider's documentation
                for supported parameters.
        Returns:
@@ -660,7 +653,7 @@ class AsyncInferenceClient:
        Raises:
            [`InferenceTimeoutError`]:
                If the model is unavailable or the request times out.
-            ` …
+            [`HfHubHTTPError`]:
                If the request fails with an HTTP error status code other than HTTP 503.

        Example:
@@ -892,7 +885,7 @@ class AsyncInferenceClient:
        >>> messages = [
        ...     {
        ...         "role": "user",
-        ...         "content": "I saw a puppy a cat and a raccoon during my bike ride in the park. What did I …
+        ...         "content": "I saw a puppy a cat and a raccoon during my bike ride in the park. What did I see and when?",
        ...     },
        ... ]
        >>> response_format = {
@@ -980,8 +973,8 @@ class AsyncInferenceClient:
        max_question_len: Optional[int] = None,
        max_seq_len: Optional[int] = None,
        top_k: Optional[int] = None,
-        word_boxes: Optional[ …
-    ) -> …
+        word_boxes: Optional[list[Union[list[float], str]]] = None,
+    ) -> list[DocumentQuestionAnsweringOutputElement]:
        """
        Answer questions on document images.

@@ -1011,16 +1004,16 @@ class AsyncInferenceClient:
            top_k (`int`, *optional*):
                The number of answers to return (will be chosen by order of likelihood). Can return less than top_k
                answers if there are not enough options available within the context.
-            word_boxes (` …
+            word_boxes (`list[Union[list[float], str`, *optional*):
                A list of words and bounding boxes (normalized 0->1000). If provided, the inference will skip the OCR
                step and use the provided bounding boxes instead.
        Returns:
-            ` …
+            `list[DocumentQuestionAnsweringOutputElement]`: a list of [`DocumentQuestionAnsweringOutputElement`] items containing the predicted label, associated probability, word ids, and page number.

        Raises:
            [`InferenceTimeoutError`]:
                If the model is unavailable or the request times out.
-            ` …
+            [`HfHubHTTPError`]:
                If the request fails with an HTTP error status code other than HTTP 503.


@@ -1035,7 +1028,7 @@ class AsyncInferenceClient:
        """
        model_id = model or self.model
        provider_helper = get_provider_helper(self.provider, task="document-question-answering", model=model_id)
-        inputs: …
+        inputs: dict[str, Any] = {"question": question, "image": _b64_encode(image)}
        request_parameters = provider_helper.prepare_request(
            inputs=inputs,
            parameters={
@@ -1096,7 +1089,7 @@ class AsyncInferenceClient:
        Raises:
            [`InferenceTimeoutError`]:
                If the model is unavailable or the request times out.
-            ` …
+            [`HfHubHTTPError`]:
                If the request fails with an HTTP error status code other than HTTP 503.

        Example:
@@ -1134,9 +1127,9 @@ class AsyncInferenceClient:
        text: str,
        *,
        model: Optional[str] = None,
-        targets: Optional[ …
+        targets: Optional[list[str]] = None,
        top_k: Optional[int] = None,
-    ) -> …
+    ) -> list[FillMaskOutputElement]:
        """
        Fill in a hole with a missing word (token to be precise).

@@ -1146,20 +1139,20 @@ class AsyncInferenceClient:
            model (`str`, *optional*):
                The model to use for the fill mask task. Can be a model ID hosted on the Hugging Face Hub or a URL to
                a deployed Inference Endpoint. If not provided, the default recommended fill mask model will be used.
-            targets (` …
+            targets (`list[str`, *optional*):
                When passed, the model will limit the scores to the passed targets instead of looking up in the whole
                vocabulary. If the provided targets are not in the model vocab, they will be tokenized and the first
                resulting token will be used (with a warning, and that might be slower).
            top_k (`int`, *optional*):
                When passed, overrides the number of predictions to return.
        Returns:
-            ` …
+            `list[FillMaskOutputElement]`: a list of [`FillMaskOutputElement`] items containing the predicted label, associated
            probability, token reference, and completed text.

        Raises:
            [`InferenceTimeoutError`]:
                If the model is unavailable or the request times out.
-            ` …
+            [`HfHubHTTPError`]:
                If the request fails with an HTTP error status code other than HTTP 503.

        Example:
@@ -1193,13 +1186,13 @@ class AsyncInferenceClient:
        model: Optional[str] = None,
        function_to_apply: Optional["ImageClassificationOutputTransform"] = None,
        top_k: Optional[int] = None,
-    ) -> …
+    ) -> list[ImageClassificationOutputElement]:
        """
        Perform image classification on the given image using the specified model.

        Args:
-            image (`Union[str, Path, bytes, BinaryIO]`):
-                The image to classify. It can be raw bytes, an image file, …
+            image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`):
+                The image to classify. It can be raw bytes, an image file, a URL to an online image, or a PIL Image.
            model (`str`, *optional*):
                The model to use for image classification. Can be a model ID hosted on the Hugging Face Hub or a URL to a
                deployed Inference Endpoint. If not provided, the default recommended model for image classification will be used.
@@ -1208,12 +1201,12 @@ class AsyncInferenceClient:
            top_k (`int`, *optional*):
                When specified, limits the output to the top K most probable classes.
        Returns:
-            ` …
+            `list[ImageClassificationOutputElement]`: a list of [`ImageClassificationOutputElement`] items containing the predicted label and associated probability.

        Raises:
            [`InferenceTimeoutError`]:
                If the model is unavailable or the request times out.
-            ` …
+            [`HfHubHTTPError`]:
                If the request fails with an HTTP error status code other than HTTP 503.

        Example:
@@ -1246,19 +1239,16 @@ class AsyncInferenceClient:
        overlap_mask_area_threshold: Optional[float] = None,
        subtask: Optional["ImageSegmentationSubtask"] = None,
        threshold: Optional[float] = None,
-    ) -> …
+    ) -> list[ImageSegmentationOutputElement]:
        """
        Perform image segmentation on the given image using the specified model.

-
-
-        You must have `PIL` installed if you want to work with images (`pip install Pillow`).
-
-        </Tip>
+        > [!WARNING]
+        > You must have `PIL` installed if you want to work with images (`pip install Pillow`).

        Args:
-            image (`Union[str, Path, bytes, BinaryIO]`):
-                The image to segment. It can be raw bytes, an image file, …
+            image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`):
+                The image to segment. It can be raw bytes, an image file, a URL to an online image, or a PIL Image.
            model (`str`, *optional*):
                The model to use for image segmentation. Can be a model ID hosted on the Hugging Face Hub or a URL to a
                deployed Inference Endpoint. If not provided, the default recommended model for image segmentation will be used.
@@ -1271,12 +1261,12 @@ class AsyncInferenceClient:
            threshold (`float`, *optional*):
                Probability threshold to filter out predicted masks.
        Returns:
-            ` …
+            `list[ImageSegmentationOutputElement]`: A list of [`ImageSegmentationOutputElement`] items containing the segmented masks and associated attributes.

        Raises:
            [`InferenceTimeoutError`]:
                If the model is unavailable or the request times out.
-            ` …
+            [`HfHubHTTPError`]:
                If the request fails with an HTTP error status code other than HTTP 503.

        Example:
@@ -1303,6 +1293,7 @@ class AsyncInferenceClient:
            api_key=self.token,
        )
        response = await self._inner_post(request_parameters)
+        response = provider_helper.get_response(response, request_parameters)
        output = ImageSegmentationOutputElement.parse_obj_as_list(response)
        for item in output:
            item.mask = _b64_to_image(item.mask)  # type: ignore [assignment]
@@ -1323,15 +1314,12 @@ class AsyncInferenceClient:
        """
        Perform image-to-image translation using a specified model.

-
-
-        You must have `PIL` installed if you want to work with images (`pip install Pillow`).
-
-        </Tip>
+        > [!WARNING]
+        > You must have `PIL` installed if you want to work with images (`pip install Pillow`).

        Args:
-            image (`Union[str, Path, bytes, BinaryIO]`):
-                The input image for translation. It can be raw bytes, an image file, …
+            image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`):
+                The input image for translation. It can be raw bytes, an image file, a URL to an online image, or a PIL Image.
            prompt (`str`, *optional*):
                The text prompt to guide the image generation.
            negative_prompt (`str`, *optional*):
@@ -1346,7 +1334,8 @@ class AsyncInferenceClient:
                The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
                Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
            target_size (`ImageToImageTargetSize`, *optional*):
-                The size in …
+                The size in pixels of the output image. This parameter is only supported by some providers and for
+                specific models. It will be ignored when unsupported.

        Returns:
            `Image`: The translated image.
@@ -1354,7 +1343,7 @@ class AsyncInferenceClient:
        Raises:
            [`InferenceTimeoutError`]:
                If the model is unavailable or the request times out.
-            ` …
+            [`HfHubHTTPError`]:
                If the request fails with an HTTP error status code other than HTTP 503.

        Example:
@@ -1365,6 +1354,7 @@ class AsyncInferenceClient:
        >>> image = await client.image_to_image("cat.jpg", prompt="turn the cat into a tiger")
        >>> image.save("tiger.jpg")
        ```
+
        """
        model_id = model or self.model
        provider_helper = get_provider_helper(self.provider, task="image-to-image", model=model_id)
@@ -1383,18 +1373,99 @@ class AsyncInferenceClient:
            api_key=self.token,
        )
        response = await self._inner_post(request_parameters)
+        response = provider_helper.get_response(response, request_parameters)
        return _bytes_to_image(response)

+    async def image_to_video(
+        self,
+        image: ContentT,
+        *,
+        model: Optional[str] = None,
+        prompt: Optional[str] = None,
+        negative_prompt: Optional[str] = None,
+        num_frames: Optional[float] = None,
+        num_inference_steps: Optional[int] = None,
+        guidance_scale: Optional[float] = None,
+        seed: Optional[int] = None,
+        target_size: Optional[ImageToVideoTargetSize] = None,
+        **kwargs,
+    ) -> bytes:
+        """
+        Generate a video from an input image.
+
+        Args:
+            image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`):
+                The input image to generate a video from. It can be raw bytes, an image file, a URL to an online image, or a PIL Image.
+            model (`str`, *optional*):
+                The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
+                Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
+            prompt (`str`, *optional*):
+                The text prompt to guide the video generation.
+            negative_prompt (`str`, *optional*):
+                One prompt to guide what NOT to include in video generation.
+            num_frames (`float`, *optional*):
+                The num_frames parameter determines how many video frames are generated.
+            num_inference_steps (`int`, *optional*):
+                For diffusion models. The number of denoising steps. More denoising steps usually lead to a higher
+                quality image at the expense of slower inference.
+            guidance_scale (`float`, *optional*):
+                For diffusion models. A higher guidance scale value encourages the model to generate videos closely
+                linked to the text prompt at the expense of lower image quality.
+            seed (`int`, *optional*):
+                The seed to use for the video generation.
+            target_size (`ImageToVideoTargetSize`, *optional*):
+                The size in pixel of the output video frames.
+            num_inference_steps (`int`, *optional*):
+                The number of denoising steps. More denoising steps usually lead to a higher quality video at the
+                expense of slower inference.
+            seed (`int`, *optional*):
+                Seed for the random number generator.
+
+        Returns:
+            `bytes`: The generated video.
+
+        Examples:
+        ```py
+        # Must be run in an async context
+        >>> from huggingface_hub import AsyncInferenceClient
+        >>> client = AsyncInferenceClient()
+        >>> video = await client.image_to_video("cat.jpg", model="Wan-AI/Wan2.2-I2V-A14B", prompt="turn the cat into a tiger")
+        >>> with open("tiger.mp4", "wb") as f:
+        ...     f.write(video)
+        ```
+        """
+        model_id = model or self.model
+        provider_helper = get_provider_helper(self.provider, task="image-to-video", model=model_id)
+        request_parameters = provider_helper.prepare_request(
+            inputs=image,
+            parameters={
+                "prompt": prompt,
+                "negative_prompt": negative_prompt,
+                "num_frames": num_frames,
+                "num_inference_steps": num_inference_steps,
+                "guidance_scale": guidance_scale,
+                "seed": seed,
+                "target_size": target_size,
+                **kwargs,
+            },
+            headers=self.headers,
+            model=model_id,
+            api_key=self.token,
+        )
+        response = await self._inner_post(request_parameters)
+        response = provider_helper.get_response(response, request_parameters)
+        return response
+
    async def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -> ImageToTextOutput:
        """
        Takes an input image and return text.

        Models can have very different outputs depending on your use case (image captioning, optical character recognition
-        (OCR), Pix2Struct, etc). Please have a look to the model card to learn more about a model's specificities.
+        (OCR), Pix2Struct, etc.). Please have a look to the model card to learn more about a model's specificities.

        Args:
-            image (`Union[str, Path, bytes, BinaryIO]`):
-                The input image to caption. It can be raw bytes, an image file, …
+            image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`):
+                The input image to caption. It can be raw bytes, an image file, a URL to an online image, or a PIL Image.
            model (`str`, *optional*):
                The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
                Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
@@ -1405,7 +1476,7 @@ class AsyncInferenceClient:
|
|
|
1405
1476
|
Raises:
|
|
1406
1477
|
[`InferenceTimeoutError`]:
|
|
1407
1478
|
If the model is unavailable or the request times out.
|
|
1408
|
-
`
|
|
1479
|
+
[`HfHubHTTPError`]:
|
|
1409
1480
|
If the request fails with an HTTP error status code other than HTTP 503.
|
|
1410
1481
|
|
|
1411
1482
|
Example:
|
|
@@ -1429,36 +1500,33 @@ class AsyncInferenceClient:
|
|
|
1429
1500
|
api_key=self.token,
|
|
1430
1501
|
)
|
|
1431
1502
|
response = await self._inner_post(request_parameters)
|
|
1432
|
-
|
|
1433
|
-
return
|
|
1503
|
+
output_list: list[ImageToTextOutput] = ImageToTextOutput.parse_obj_as_list(response)
|
|
1504
|
+
return output_list[0]
|
|
1434
1505
|
|
|
1435
1506
|
async def object_detection(
|
|
1436
1507
|
self, image: ContentT, *, model: Optional[str] = None, threshold: Optional[float] = None
|
|
1437
|
-
) ->
|
|
1508
|
+
) -> list[ObjectDetectionOutputElement]:
|
|
1438
1509
|
"""
|
|
1439
1510
|
Perform object detection on the given image using the specified model.
|
|
1440
1511
|
|
|
1441
|
-
|
|
1442
|
-
|
|
1443
|
-
You must have `PIL` installed if you want to work with images (`pip install Pillow`).
|
|
1444
|
-
|
|
1445
|
-
</Tip>
|
|
1512
|
+
> [!WARNING]
|
|
1513
|
+
> You must have `PIL` installed if you want to work with images (`pip install Pillow`).
|
|
1446
1514
|
|
|
1447
1515
|
Args:
|
|
1448
|
-
image (`Union[str, Path, bytes, BinaryIO]`):
|
|
1449
|
-
The image to detect objects on. It can be raw bytes, an image file,
|
|
1516
|
+
image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`):
|
|
1517
|
+
The image to detect objects on. It can be raw bytes, an image file, a URL to an online image, or a PIL Image.
|
|
1450
1518
|
model (`str`, *optional*):
|
|
1451
1519
|
The model to use for object detection. Can be a model ID hosted on the Hugging Face Hub or a URL to a
|
|
1452
1520
|
deployed Inference Endpoint. If not provided, the default recommended model for object detection (DETR) will be used.
|
|
1453
1521
|
threshold (`float`, *optional*):
|
|
1454
1522
|
The probability necessary to make a prediction.
|
|
1455
1523
|
Returns:
|
|
1456
|
-
`
|
|
1524
|
+
`list[ObjectDetectionOutputElement]`: A list of [`ObjectDetectionOutputElement`] items containing the bounding boxes and associated attributes.
|
|
1457
1525
|
|
|
1458
1526
|
Raises:
|
|
1459
1527
|
[`InferenceTimeoutError`]:
|
|
1460
1528
|
If the model is unavailable or the request times out.
|
|
1461
|
-
`
|
|
1529
|
+
[`HfHubHTTPError`]:
|
|
1462
1530
|
If the request fails with an HTTP error status code other than HTTP 503.
|
|
1463
1531
|
`ValueError`:
|
|
1464
1532
|
If the request output is not a List.
|
|
@@ -1497,7 +1565,7 @@ class AsyncInferenceClient:
|
|
|
1497
1565
|
max_question_len: Optional[int] = None,
|
|
1498
1566
|
max_seq_len: Optional[int] = None,
|
|
1499
1567
|
top_k: Optional[int] = None,
|
|
1500
|
-
) -> Union[QuestionAnsweringOutputElement,
|
|
1568
|
+
) -> Union[QuestionAnsweringOutputElement, list[QuestionAnsweringOutputElement]]:
|
|
1501
1569
|
"""
|
|
1502
1570
|
Retrieve the answer to a question from a given text.
|
|
1503
1571
|
|
|
@@ -1529,13 +1597,13 @@ class AsyncInferenceClient:
|
|
|
1529
1597
|
topk answers if there are not enough options available within the context.
|
|
1530
1598
|
|
|
1531
1599
|
Returns:
|
|
1532
|
-
Union[`QuestionAnsweringOutputElement`,
|
|
1600
|
+
Union[`QuestionAnsweringOutputElement`, list[`QuestionAnsweringOutputElement`]]:
|
|
1533
1601
|
When top_k is 1 or not provided, it returns a single `QuestionAnsweringOutputElement`.
|
|
1534
1602
|
When top_k is greater than 1, it returns a list of `QuestionAnsweringOutputElement`.
|
|
1535
1603
|
Raises:
|
|
1536
1604
|
[`InferenceTimeoutError`]:
|
|
1537
1605
|
If the model is unavailable or the request times out.
|
|
1538
|
-
`
|
|
1606
|
+
[`HfHubHTTPError`]:
|
|
1539
1607
|
If the request fails with an HTTP error status code other than HTTP 503.
|
|
1540
1608
|
|
|
1541
1609
|
Example:
|
|
@@ -1550,7 +1618,7 @@ class AsyncInferenceClient:
|
|
|
1550
1618
|
model_id = model or self.model
|
|
1551
1619
|
provider_helper = get_provider_helper(self.provider, task="question-answering", model=model_id)
|
|
1552
1620
|
request_parameters = provider_helper.prepare_request(
|
|
1553
|
-
inputs=
|
|
1621
|
+
inputs={"question": question, "context": context},
|
|
1554
1622
|
parameters={
|
|
1555
1623
|
"align_to_words": align_to_words,
|
|
1556
1624
|
"doc_stride": doc_stride,
|
|
@@ -1560,7 +1628,6 @@ class AsyncInferenceClient:
|
|
|
1560
1628
|
"max_seq_len": max_seq_len,
|
|
1561
1629
|
"top_k": top_k,
|
|
1562
1630
|
},
|
|
1563
|
-
extra_payload={"question": question, "context": context},
|
|
1564
1631
|
headers=self.headers,
|
|
1565
1632
|
model=model_id,
|
|
1566
1633
|
api_key=self.token,
|
|
@@ -1571,15 +1638,15 @@ class AsyncInferenceClient:
 return output

 async def sentence_similarity(
-self, sentence: str, other_sentences:
-) ->
+self, sentence: str, other_sentences: list[str], *, model: Optional[str] = None
+) -> list[float]:
 """
 Compute the semantic similarity between a sentence and a list of other sentences by comparing their embeddings.

 Args:
 sentence (`str`):
 The main sentence to compare to others.
-other_sentences (`
+other_sentences (`list[str]`):
 The list of sentences to compare to.
 model (`str`, *optional*):
 The model to use for the sentence similarity task. Can be a model ID hosted on the Hugging Face Hub or a URL to
@@ -1587,12 +1654,12 @@ class AsyncInferenceClient:
 Defaults to None.

 Returns:
-`
+`list[float]`: The embedding representing the input text.

 Raises:
 [`InferenceTimeoutError`]:
 If the model is unavailable or the request times out.
-`
+[`HfHubHTTPError`]:
 If the request fails with an HTTP error status code other than HTTP 503.

 Example:
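The annotations above switch to built-in generics (`list[str]` in, `list[float]` out). A minimal usage sketch of the method this hunk documents (the sentences are placeholders):

```python
# Illustrative only: returns one similarity score per entry in other_sentences.
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient()
    scores = await client.sentence_similarity(
        "Machine learning is so easy.",
        other_sentences=[
            "Deep learning is so straightforward.",
            "This is so difficult, like rocket science.",
        ],
    )
    print(scores)  # a list[float], one score per candidate sentence


asyncio.run(main())
```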
@@ -1630,7 +1697,7 @@ class AsyncInferenceClient:
 *,
 model: Optional[str] = None,
 clean_up_tokenization_spaces: Optional[bool] = None,
-generate_parameters: Optional[
+generate_parameters: Optional[dict[str, Any]] = None,
 truncation: Optional["SummarizationTruncationStrategy"] = None,
 ) -> SummarizationOutput:
 """
@@ -1644,7 +1711,7 @@ class AsyncInferenceClient:
 Inference Endpoint. If not provided, the default recommended model for summarization will be used.
 clean_up_tokenization_spaces (`bool`, *optional*):
 Whether to clean up the potential extra spaces in the text output.
-generate_parameters (`
+generate_parameters (`dict[str, Any]`, *optional*):
 Additional parametrization of the text generation algorithm.
 truncation (`"SummarizationTruncationStrategy"`, *optional*):
 The truncation strategy to use.
@@ -1654,7 +1721,7 @@ class AsyncInferenceClient:
 Raises:
 [`InferenceTimeoutError`]:
 If the model is unavailable or the request times out.
-`
+[`HfHubHTTPError`]:
 If the request fails with an HTTP error status code other than HTTP 503.

 Example:
@@ -1685,7 +1752,7 @@ class AsyncInferenceClient:

 async def table_question_answering(
 self,
-table:
+table: dict[str, Any],
 query: str,
 *,
 model: Optional[str] = None,
@@ -1720,7 +1787,7 @@ class AsyncInferenceClient:
 Raises:
 [`InferenceTimeoutError`]:
 If the model is unavailable or the request times out.
-`
+[`HfHubHTTPError`]:
 If the request fails with an HTTP error status code other than HTTP 503.

 Example:
@@ -1737,9 +1804,8 @@ class AsyncInferenceClient:
 model_id = model or self.model
 provider_helper = get_provider_helper(self.provider, task="table-question-answering", model=model_id)
 request_parameters = provider_helper.prepare_request(
-inputs=
+inputs={"query": query, "table": table},
 parameters={"model": model, "padding": padding, "sequential": sequential, "truncation": truncation},
-extra_payload={"query": query, "table": table},
 headers=self.headers,
 model=model_id,
 api_key=self.token,
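As with question answering above, the table/query payload now travels in `inputs` rather than `extra_payload`. A hedged sketch of a call against the unchanged public API (the table contents are invented):

```python
# Illustrative only: cell values are passed as strings, as in the upstream docstrings.
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient()
    table = {
        "Repository": ["Transformers", "Datasets", "Tokenizers"],
        "Stars": ["36542", "4512", "3934"],
    }
    result = await client.table_question_answering(
        table, query="How many stars does the transformers repository have?"
    )
    print(result.answer)


asyncio.run(main())
```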
@@ -1747,12 +1813,12 @@ class AsyncInferenceClient:
 response = await self._inner_post(request_parameters)
 return TableQuestionAnsweringOutputElement.parse_obj_as_instance(response)

-async def tabular_classification(self, table:
+async def tabular_classification(self, table: dict[str, Any], *, model: Optional[str] = None) -> list[str]:
 """
 Classifying a target category (a group) based on a set of attributes.

 Args:
-table (`
+table (`dict[str, Any]`):
 Set of attributes to classify.
 model (`str`, *optional*):
 The model to use for the tabular classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to
@@ -1765,7 +1831,7 @@ class AsyncInferenceClient:
 Raises:
 [`InferenceTimeoutError`]:
 If the model is unavailable or the request times out.
-`
+[`HfHubHTTPError`]:
 If the request fails with an HTTP error status code other than HTTP 503.

 Example:
@@ -1803,12 +1869,12 @@ class AsyncInferenceClient:
 response = await self._inner_post(request_parameters)
 return _bytes_to_list(response)

-async def tabular_regression(self, table:
+async def tabular_regression(self, table: dict[str, Any], *, model: Optional[str] = None) -> list[float]:
 """
 Predicting a numerical target value given a set of attributes/features in a table.

 Args:
-table (`
+table (`dict[str, Any]`):
 Set of attributes stored in a table. The attributes used to predict the target can be both numerical and categorical.
 model (`str`, *optional*):
 The model to use for the tabular regression task. Can be a model ID hosted on the Hugging Face Hub or a URL to
@@ -1821,7 +1887,7 @@ class AsyncInferenceClient:
 Raises:
 [`InferenceTimeoutError`]:
 If the model is unavailable or the request times out.
-`
+[`HfHubHTTPError`]:
 If the request fails with an HTTP error status code other than HTTP 503.

 Example:
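Both tabular helpers above now take a plain `dict[str, Any]` table. A minimal sketch of the regression variant (the feature names, values, and model ID are placeholders, not taken from the diff):

```python
# Illustrative only: a tabular-regression model must be specified explicitly.
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient()
    table = {
        "Height": ["11.52", "12.48"],
        "Length1": ["23.2", "24.0"],
        "Length2": ["25.4", "26.3"],
    }
    # Hypothetical model ID -- replace with a real tabular-regression checkpoint.
    predictions = await client.tabular_regression(table, model="some-user/fish-weight-regressor")
    print(predictions)  # one float per row


asyncio.run(main())
```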
@@ -1861,7 +1927,7 @@ class AsyncInferenceClient:
 model: Optional[str] = None,
 top_k: Optional[int] = None,
 function_to_apply: Optional["TextClassificationOutputTransform"] = None,
-) ->
+) -> list[TextClassificationOutputElement]:
 """
 Perform text classification (e.g. sentiment-analysis) on the given text.

@@ -1878,12 +1944,12 @@ class AsyncInferenceClient:
 The function to apply to the model outputs in order to retrieve the scores.

 Returns:
-`
+`list[TextClassificationOutputElement]`: a list of [`TextClassificationOutputElement`] items containing the predicted label and associated probability.

 Raises:
 [`InferenceTimeoutError`]:
 If the model is unavailable or the request times out.
-`
+[`HfHubHTTPError`]:
 If the request fails with an HTTP error status code other than HTTP 503.

 Example:
@@ -1914,26 +1980,26 @@ class AsyncInferenceClient:
 return TextClassificationOutputElement.parse_obj_as_list(response)[0]  # type: ignore [return-value]

 @overload
-async def text_generation(
+async def text_generation(
 self,
 prompt: str,
 *,
-details: Literal[
-stream: Literal[
+details: Literal[True],
+stream: Literal[True],
 model: Optional[str] = None,
 # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
 adapter_id: Optional[str] = None,
 best_of: Optional[int] = None,
 decoder_input_details: Optional[bool] = None,
-do_sample: Optional[bool] =
+do_sample: Optional[bool] = None,
 frequency_penalty: Optional[float] = None,
 grammar: Optional[TextGenerationInputGrammarType] = None,
 max_new_tokens: Optional[int] = None,
 repetition_penalty: Optional[float] = None,
-return_full_text: Optional[bool] =
+return_full_text: Optional[bool] = None,
 seed: Optional[int] = None,
-stop: Optional[
-stop_sequences: Optional[
+stop: Optional[list[str]] = None,
+stop_sequences: Optional[list[str]] = None,  # Deprecated, use `stop` instead
 temperature: Optional[float] = None,
 top_k: Optional[int] = None,
 top_n_tokens: Optional[int] = None,
@@ -1941,29 +2007,29 @@ class AsyncInferenceClient:
 truncate: Optional[int] = None,
 typical_p: Optional[float] = None,
 watermark: Optional[bool] = None,
-) ->
+) -> AsyncIterable[TextGenerationStreamOutput]: ...

 @overload
-async def text_generation(
+async def text_generation(
 self,
 prompt: str,
 *,
-details: Literal[True]
-stream: Literal[False] =
+details: Literal[True],
+stream: Optional[Literal[False]] = None,
 model: Optional[str] = None,
 # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
 adapter_id: Optional[str] = None,
 best_of: Optional[int] = None,
 decoder_input_details: Optional[bool] = None,
-do_sample: Optional[bool] =
+do_sample: Optional[bool] = None,
 frequency_penalty: Optional[float] = None,
 grammar: Optional[TextGenerationInputGrammarType] = None,
 max_new_tokens: Optional[int] = None,
 repetition_penalty: Optional[float] = None,
-return_full_text: Optional[bool] =
+return_full_text: Optional[bool] = None,
 seed: Optional[int] = None,
-stop: Optional[
-stop_sequences: Optional[
+stop: Optional[list[str]] = None,
+stop_sequences: Optional[list[str]] = None,  # Deprecated, use `stop` instead
 temperature: Optional[float] = None,
 top_k: Optional[int] = None,
 top_n_tokens: Optional[int] = None,
@@ -1974,26 +2040,26 @@ class AsyncInferenceClient:
 ) -> TextGenerationOutput: ...

 @overload
-async def text_generation(
+async def text_generation(
 self,
 prompt: str,
 *,
-details: Literal[False] =
-stream: Literal[True]
+details: Optional[Literal[False]] = None,
+stream: Literal[True],
 model: Optional[str] = None,
 # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
 adapter_id: Optional[str] = None,
 best_of: Optional[int] = None,
 decoder_input_details: Optional[bool] = None,
-do_sample: Optional[bool] =
+do_sample: Optional[bool] = None,
 frequency_penalty: Optional[float] = None,
 grammar: Optional[TextGenerationInputGrammarType] = None,
 max_new_tokens: Optional[int] = None,
 repetition_penalty: Optional[float] = None,
-return_full_text: Optional[bool] =
+return_full_text: Optional[bool] = None,  # Manual default value
 seed: Optional[int] = None,
-stop: Optional[
-stop_sequences: Optional[
+stop: Optional[list[str]] = None,
+stop_sequences: Optional[list[str]] = None,  # Deprecated, use `stop` instead
 temperature: Optional[float] = None,
 top_k: Optional[int] = None,
 top_n_tokens: Optional[int] = None,
@@ -2004,26 +2070,26 @@ class AsyncInferenceClient:
 ) -> AsyncIterable[str]: ...

 @overload
-async def text_generation(
+async def text_generation(
 self,
 prompt: str,
 *,
-details: Literal[
-stream: Literal[
+details: Optional[Literal[False]] = None,
+stream: Optional[Literal[False]] = None,
 model: Optional[str] = None,
 # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
 adapter_id: Optional[str] = None,
 best_of: Optional[int] = None,
 decoder_input_details: Optional[bool] = None,
-do_sample: Optional[bool] =
+do_sample: Optional[bool] = None,
 frequency_penalty: Optional[float] = None,
 grammar: Optional[TextGenerationInputGrammarType] = None,
 max_new_tokens: Optional[int] = None,
 repetition_penalty: Optional[float] = None,
-return_full_text: Optional[bool] =
+return_full_text: Optional[bool] = None,
 seed: Optional[int] = None,
-stop: Optional[
-stop_sequences: Optional[
+stop: Optional[list[str]] = None,
+stop_sequences: Optional[list[str]] = None,  # Deprecated, use `stop` instead
 temperature: Optional[float] = None,
 top_k: Optional[int] = None,
 top_n_tokens: Optional[int] = None,
@@ -2031,29 +2097,29 @@ class AsyncInferenceClient:
 truncate: Optional[int] = None,
 typical_p: Optional[float] = None,
 watermark: Optional[bool] = None,
-) ->
+) -> str: ...

 @overload
 async def text_generation(
 self,
 prompt: str,
 *,
-details:
-stream: bool =
+details: Optional[bool] = None,
+stream: Optional[bool] = None,
 model: Optional[str] = None,
 # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
 adapter_id: Optional[str] = None,
 best_of: Optional[int] = None,
 decoder_input_details: Optional[bool] = None,
-do_sample: Optional[bool] =
+do_sample: Optional[bool] = None,
 frequency_penalty: Optional[float] = None,
 grammar: Optional[TextGenerationInputGrammarType] = None,
 max_new_tokens: Optional[int] = None,
 repetition_penalty: Optional[float] = None,
-return_full_text: Optional[bool] =
+return_full_text: Optional[bool] = None,
 seed: Optional[int] = None,
-stop: Optional[
-stop_sequences: Optional[
+stop: Optional[list[str]] = None,
+stop_sequences: Optional[list[str]] = None,  # Deprecated, use `stop` instead
 temperature: Optional[float] = None,
 top_k: Optional[int] = None,
 top_n_tokens: Optional[int] = None,
@@ -2061,28 +2127,28 @@ class AsyncInferenceClient:
 truncate: Optional[int] = None,
 typical_p: Optional[float] = None,
 watermark: Optional[bool] = None,
-) -> Union[TextGenerationOutput, AsyncIterable[TextGenerationStreamOutput]]: ...
+) -> Union[str, TextGenerationOutput, AsyncIterable[str], AsyncIterable[TextGenerationStreamOutput]]: ...

 async def text_generation(
 self,
 prompt: str,
 *,
-details: bool =
-stream: bool =
+details: Optional[bool] = None,
+stream: Optional[bool] = None,
 model: Optional[str] = None,
 # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
 adapter_id: Optional[str] = None,
 best_of: Optional[int] = None,
 decoder_input_details: Optional[bool] = None,
-do_sample: Optional[bool] =
+do_sample: Optional[bool] = None,
 frequency_penalty: Optional[float] = None,
 grammar: Optional[TextGenerationInputGrammarType] = None,
 max_new_tokens: Optional[int] = None,
 repetition_penalty: Optional[float] = None,
-return_full_text: Optional[bool] =
+return_full_text: Optional[bool] = None,
 seed: Optional[int] = None,
-stop: Optional[
-stop_sequences: Optional[
+stop: Optional[list[str]] = None,
+stop_sequences: Optional[list[str]] = None,  # Deprecated, use `stop` instead
 temperature: Optional[float] = None,
 top_k: Optional[int] = None,
 top_n_tokens: Optional[int] = None,
@@ -2094,12 +2160,9 @@ class AsyncInferenceClient:
 """
 Given a prompt, generate the following text.

-
-
-
-It accepts a list of messages instead of a single text prompt and handles the chat templating for you.
-
-</Tip>
+> [!TIP]
+> If you want to generate a response from chat messages, you should use the [`InferenceClient.chat_completion`] method.
+> It accepts a list of messages instead of a single text prompt and handles the chat templating for you.

 Args:
 prompt (`str`):
@@ -2138,9 +2201,9 @@ class AsyncInferenceClient:
 Whether to prepend the prompt to the generated text
 seed (`int`, *optional*):
 Random sampling seed
-stop (`
+stop (`list[str]`, *optional*):
 Stop generating tokens if a member of `stop` is generated.
-stop_sequences (`
+stop_sequences (`list[str]`, *optional*):
 Deprecated argument. Use `stop` instead.
 temperature (`float`, *optional*):
 The value used to module the logits distribution.
@@ -2157,14 +2220,14 @@ class AsyncInferenceClient:
 typical_p (`float`, *optional`):
 Typical Decoding mass
 See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
-watermark (`bool`, *optional
+watermark (`bool`, *optional*):
 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)

 Returns:
-`Union[str, TextGenerationOutput,
+`Union[str, TextGenerationOutput, AsyncIterable[str], AsyncIterable[TextGenerationStreamOutput]]`:
 Generated text returned from the server:
 - if `stream=False` and `details=False`, the generated text is returned as a `str` (default)
-- if `stream=True` and `details=False`, the generated text is returned token by token as a `
+- if `stream=True` and `details=False`, the generated text is returned token by token as a `AsyncIterable[str]`
 - if `stream=False` and `details=True`, the generated text is returned with more details as a [`~huggingface_hub.TextGenerationOutput`]
 - if `details=True` and `stream=True`, the generated text is returned token by token as a iterable of [`~huggingface_hub.TextGenerationStreamOutput`]

@@ -2173,7 +2236,7 @@ class AsyncInferenceClient:
 If input values are not valid. No HTTP call is made to the server.
 [`InferenceTimeoutError`]:
 If the model is unavailable or the request times out.
-`
+[`HfHubHTTPError`]:
 If the request fails with an HTTP error status code other than HTTP 503.

 Example:
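The overload block above tightens the allowed `details`/`stream` combinations and the documented return types. A hedged sketch of the two most common call shapes (prompt and token limits are placeholders):

```python
# Illustrative only: non-streaming vs. streaming text generation.
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient()

    # stream=False, details=False -> plain str
    text = await client.text_generation("The huggingface_hub library is ", max_new_tokens=12)
    print(text)

    # stream=True, details=False -> AsyncIterable[str], consumed with `async for`
    async for token in await client.text_generation(
        "The huggingface_hub library is ", max_new_tokens=12, stream=True
    ):
        print(token, end="")


asyncio.run(main())
```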
@@ -2308,7 +2371,7 @@ class AsyncInferenceClient:
 "repetition_penalty": repetition_penalty,
 "return_full_text": return_full_text,
 "seed": seed,
-"stop": stop
+"stop": stop,
 "temperature": temperature,
 "top_k": top_k,
 "top_n_tokens": top_n_tokens,
@@ -2362,10 +2425,10 @@ class AsyncInferenceClient:

 # Handle errors separately for more precise error messages
 try:
-bytes_output = await self._inner_post(request_parameters, stream=stream)
-except
-match = MODEL_KWARGS_NOT_USED_REGEX.search(e
-if e
+bytes_output = await self._inner_post(request_parameters, stream=stream or False)
+except HfHubHTTPError as e:
+match = MODEL_KWARGS_NOT_USED_REGEX.search(str(e))
+if isinstance(e, BadRequestError) and match:
 unused_params = [kwarg.strip("' ") for kwarg in match.group(1).split(",")]
 _set_unsupported_text_generation_kwargs(model, unused_params)
 return await self.text_generation(  # type: ignore
@@ -2418,20 +2481,16 @@ class AsyncInferenceClient:
 model: Optional[str] = None,
 scheduler: Optional[str] = None,
 seed: Optional[int] = None,
-extra_body: Optional[
+extra_body: Optional[dict[str, Any]] = None,
 ) -> "Image":
 """
 Generate an image based on a given text using a specified model.

-
+> [!WARNING]
+> You must have `PIL` installed if you want to work with images (`pip install Pillow`).

-
-
-</Tip>
-
-<Tip>
-You can pass provider-specific parameters to the model by using the `extra_body` argument.
-</Tip>
+> [!TIP]
+> You can pass provider-specific parameters to the model by using the `extra_body` argument.

 Args:
 prompt (`str`):
@@ -2456,7 +2515,7 @@ class AsyncInferenceClient:
 Override the scheduler with a compatible one.
 seed (`int`, *optional*):
 Seed for the random number generator.
-extra_body (`
+extra_body (`dict[str, Any]`, *optional*):
 Additional provider-specific parameters to pass to the model. Refer to the provider's documentation
 for supported parameters.

@@ -2466,7 +2525,7 @@ class AsyncInferenceClient:
 Raises:
 [`InferenceTimeoutError`]:
 If the model is unavailable or the request times out.
-`
+[`HfHubHTTPError`]:
 If the request fails with an HTTP error status code other than HTTP 503.

 Example:
@@ -2527,6 +2586,7 @@ class AsyncInferenceClient:
 ... )
 >>> image.save("astronaut.png")
 ```
+
 """
 model_id = model or self.model
 provider_helper = get_provider_helper(self.provider, task="text-to-image", model=model_id)
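`extra_body` is now typed as `dict[str, Any]` and the installation/usage tips use GitHub-style callouts. A hedged sketch of passing provider-specific parameters (the `output_quality` key is a made-up example, not a documented parameter of any provider):

```python
# Illustrative only: extra_body keys depend entirely on the chosen provider.
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient()
    image = await client.text_to_image(
        "An astronaut riding a horse on the moon.",
        extra_body={"output_quality": 90},  # hypothetical provider-specific knob
    )
    image.save("astronaut.png")  # returns a PIL.Image.Image, so Pillow must be installed


asyncio.run(main())
```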
@@ -2547,7 +2607,7 @@ class AsyncInferenceClient:
 api_key=self.token,
 )
 response = await self._inner_post(request_parameters)
-response = provider_helper.get_response(response)
+response = provider_helper.get_response(response, request_parameters)
 return _bytes_to_image(response)

 async def text_to_video(
@@ -2556,18 +2616,17 @@ class AsyncInferenceClient:
 *,
 model: Optional[str] = None,
 guidance_scale: Optional[float] = None,
-negative_prompt: Optional[
+negative_prompt: Optional[list[str]] = None,
 num_frames: Optional[float] = None,
 num_inference_steps: Optional[int] = None,
 seed: Optional[int] = None,
-extra_body: Optional[
+extra_body: Optional[dict[str, Any]] = None,
 ) -> bytes:
 """
 Generate a video based on a given text.

-
-You can pass provider-specific parameters to the model by using the `extra_body` argument.
-</Tip>
+> [!TIP]
+> You can pass provider-specific parameters to the model by using the `extra_body` argument.

 Args:
 prompt (`str`):
@@ -2579,7 +2638,7 @@ class AsyncInferenceClient:
 guidance_scale (`float`, *optional*):
 A higher guidance scale value encourages the model to generate videos closely linked to the text
 prompt, but values too high may cause saturation and other artifacts.
-negative_prompt (`
+negative_prompt (`list[str]`, *optional*):
 One or several prompt to guide what NOT to include in video generation.
 num_frames (`float`, *optional*):
 The num_frames parameter determines how many video frames are generated.
@@ -2588,7 +2647,7 @@ class AsyncInferenceClient:
 expense of slower inference.
 seed (`int`, *optional*):
 Seed for the random number generator.
-extra_body (`
+extra_body (`dict[str, Any]`, *optional*):
 Additional provider-specific parameters to pass to the model. Refer to the provider's documentation
 for supported parameters.

@@ -2626,6 +2685,7 @@ class AsyncInferenceClient:
 >>> with open("cat.mp4", "wb") as file:
 ...     file.write(video)
 ```
+
 """
 model_id = model or self.model
 provider_helper = get_provider_helper(self.provider, task="text-to-video", model=model_id)
@@ -2668,14 +2728,13 @@ class AsyncInferenceClient:
 top_p: Optional[float] = None,
 typical_p: Optional[float] = None,
 use_cache: Optional[bool] = None,
-extra_body: Optional[
+extra_body: Optional[dict[str, Any]] = None,
 ) -> bytes:
 """
 Synthesize an audio of a voice pronouncing a given text.

-
-You can pass provider-specific parameters to the model by using the `extra_body` argument.
-</Tip>
+> [!TIP]
+> You can pass provider-specific parameters to the model by using the `extra_body` argument.

 Args:
 text (`str`):
@@ -2730,7 +2789,7 @@ class AsyncInferenceClient:
 paper](https://hf.co/papers/2202.00666) for more details.
 use_cache (`bool`, *optional*):
 Whether the model should use the past last key/values attentions to speed up decoding
-extra_body (`
+extra_body (`dict[str, Any]`, *optional*):
 Additional provider-specific parameters to pass to the model. Refer to the provider's documentation
 for supported parameters.
 Returns:
@@ -2739,7 +2798,7 @@ class AsyncInferenceClient:
 Raises:
 [`InferenceTimeoutError`]:
 If the model is unavailable or the request times out.
-`
+[`HfHubHTTPError`]:
 If the request fails with an HTTP error status code other than HTTP 503.

 Example:
@@ -2863,9 +2922,9 @@ class AsyncInferenceClient:
 *,
 model: Optional[str] = None,
 aggregation_strategy: Optional["TokenClassificationAggregationStrategy"] = None,
-ignore_labels: Optional[
+ignore_labels: Optional[list[str]] = None,
 stride: Optional[int] = None,
-) ->
+) -> list[TokenClassificationOutputElement]:
 """
 Perform token classification on the given text.
 Usually used for sentence parsing, either grammatical, or Named Entity Recognition (NER) to understand keywords contained within text.
@@ -2879,18 +2938,18 @@ class AsyncInferenceClient:
 Defaults to None.
 aggregation_strategy (`"TokenClassificationAggregationStrategy"`, *optional*):
 The strategy used to fuse tokens based on model predictions
-ignore_labels (`
+ignore_labels (`list[str`, *optional*):
 A list of labels to ignore
 stride (`int`, *optional*):
 The number of overlapping tokens between chunks when splitting the input text.

 Returns:
-`
+`list[TokenClassificationOutputElement]`: List of [`TokenClassificationOutputElement`] items containing the entity group, confidence score, word, start and end index.

 Raises:
 [`InferenceTimeoutError`]:
 If the model is unavailable or the request times out.
-`
+[`HfHubHTTPError`]:
 If the request fails with an HTTP error status code other than HTTP 503.

 Example:
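The hunk above pins the return type of token classification to `list[TokenClassificationOutputElement]`. A minimal usage sketch for NER-style output (the sentence is a placeholder):

```python
# Illustrative only: each element carries the entity group, score, word and offsets.
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient()
    entities = await client.token_classification(
        "My name is Sarah Jessica Parker but you can call me Jessica"
    )
    for entity in entities:
        print(entity.entity_group, entity.word, entity.score)


asyncio.run(main())
```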
@@ -2942,7 +3001,7 @@ class AsyncInferenceClient:
 tgt_lang: Optional[str] = None,
 clean_up_tokenization_spaces: Optional[bool] = None,
 truncation: Optional["TranslationTruncationStrategy"] = None,
-generate_parameters: Optional[
+generate_parameters: Optional[dict[str, Any]] = None,
 ) -> TranslationOutput:
 """
 Convert text from one language to another.
@@ -2967,7 +3026,7 @@ class AsyncInferenceClient:
 Whether to clean up the potential extra spaces in the text output.
 truncation (`"TranslationTruncationStrategy"`, *optional*):
 The truncation strategy to use.
-generate_parameters (`
+generate_parameters (`dict[str, Any]`, *optional*):
 Additional parametrization of the text generation algorithm.

 Returns:
@@ -2976,7 +3035,7 @@ class AsyncInferenceClient:
 Raises:
 [`InferenceTimeoutError`]:
 If the model is unavailable or the request times out.
-`
+[`HfHubHTTPError`]:
 If the request fails with an HTTP error status code other than HTTP 503.
 `ValueError`:
 If only one of the `src_lang` and `tgt_lang` arguments are provided.
@@ -3030,13 +3089,13 @@ class AsyncInferenceClient:
 *,
 model: Optional[str] = None,
 top_k: Optional[int] = None,
-) ->
+) -> list[VisualQuestionAnsweringOutputElement]:
 """
 Answering open-ended questions based on an image.

 Args:
-image (`Union[str, Path, bytes, BinaryIO]`):
-The input image for the context. It can be raw bytes, an image file,
+image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`):
+The input image for the context. It can be raw bytes, an image file, a URL to an online image, or a PIL Image.
 question (`str`):
 Question to be answered.
 model (`str`, *optional*):
@@ -3047,12 +3106,12 @@ class AsyncInferenceClient:
 The number of answers to return (will be chosen by order of likelihood). Note that we return less than
 topk answers if there are not enough options available within the context.
 Returns:
-`
+`list[VisualQuestionAnsweringOutputElement]`: a list of [`VisualQuestionAnsweringOutputElement`] items containing the predicted label and associated probability.

 Raises:
 `InferenceTimeoutError`:
 If the model is unavailable or the request times out.
-`
+[`HfHubHTTPError`]:
 If the request fails with an HTTP error status code other than HTTP 503.

 Example:
@@ -3086,21 +3145,21 @@ class AsyncInferenceClient:
 async def zero_shot_classification(
 self,
 text: str,
-candidate_labels:
+candidate_labels: list[str],
 *,
 multi_label: Optional[bool] = False,
 hypothesis_template: Optional[str] = None,
 model: Optional[str] = None,
-) ->
+) -> list[ZeroShotClassificationOutputElement]:
 """
 Provide as input a text and a set of candidate labels to classify the input text.

 Args:
 text (`str`):
 The input text to classify.
-candidate_labels (`
+candidate_labels (`list[str]`):
 The set of possible class labels to classify the text into.
-labels (`
+labels (`list[str]`, *optional*):
 (deprecated) List of strings. Each string is the verbalization of a possible label for the input text.
 multi_label (`bool`, *optional*):
 Whether multiple candidate labels can be true. If false, the scores are normalized such that the sum of
@@ -3115,12 +3174,12 @@ class AsyncInferenceClient:


 Returns:
-`
+`list[ZeroShotClassificationOutputElement]`: List of [`ZeroShotClassificationOutputElement`] items containing the predicted labels and their confidence.

 Raises:
 [`InferenceTimeoutError`]:
 If the model is unavailable or the request times out.
-`
+[`HfHubHTTPError`]:
 If the request fails with an HTTP error status code other than HTTP 503.

 Example with `multi_label=False`:
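With `candidate_labels` now annotated as `list[str]` and the return type as `list[ZeroShotClassificationOutputElement]`, a typical call looks roughly like this (the text and labels are placeholders):

```python
# Illustrative only: scores are normalized to sum to 1.0 when multi_label=False.
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient()
    results = await client.zero_shot_classification(
        "A new model for protein folding was released today.",
        candidate_labels=["biology", "sports", "politics"],
        multi_label=False,
    )
    for element in results:
        print(element.label, element.score)


asyncio.run(main())
```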
@@ -3194,22 +3253,22 @@ class AsyncInferenceClient:
 async def zero_shot_image_classification(
 self,
 image: ContentT,
-candidate_labels:
+candidate_labels: list[str],
 *,
 model: Optional[str] = None,
 hypothesis_template: Optional[str] = None,
 # deprecated argument
-labels:
-) ->
+labels: list[str] = None,  # type: ignore
+) -> list[ZeroShotImageClassificationOutputElement]:
 """
 Provide input image and text labels to predict text labels for the image.

 Args:
-image (`Union[str, Path, bytes, BinaryIO]`):
-The input image to caption. It can be raw bytes, an image file,
-candidate_labels (`
+image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`):
+The input image to caption. It can be raw bytes, an image file, a URL to an online image, or a PIL Image.
+candidate_labels (`list[str]`):
 The candidate labels for this image
-labels (`
+labels (`list[str]`, *optional*):
 (deprecated) List of string possible labels. There must be at least 2 labels.
 model (`str`, *optional*):
 The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
@@ -3219,12 +3278,12 @@ class AsyncInferenceClient:
 replacing the placeholder with the candidate labels.

 Returns:
-`
+`list[ZeroShotImageClassificationOutputElement]`: List of [`ZeroShotImageClassificationOutputElement`] items containing the predicted labels and their confidence.

 Raises:
 [`InferenceTimeoutError`]:
 If the model is unavailable or the request times out.
-`
+[`HfHubHTTPError`]:
 If the request fails with an HTTP error status code other than HTTP 503.

 Example:
@@ -3259,144 +3318,7 @@ class AsyncInferenceClient:
 response = await self._inner_post(request_parameters)
 return ZeroShotImageClassificationOutputElement.parse_obj_as_list(response)

-
-version="0.33.0",
-message=(
-"HF Inference API is getting revamped and will only support warm models in the future (no cold start allowed)."
-" Use `HfApi.list_models(..., inference_provider='...')` to list warm models per provider."
-),
-)
-async def list_deployed_models(
-self, frameworks: Union[None, str, Literal["all"], List[str]] = None
-) -> Dict[str, List[str]]:
-"""
-List models deployed on the HF Serverless Inference API service.
-
-This helper checks deployed models framework by framework. By default, it will check the 4 main frameworks that
-are supported and account for 95% of the hosted models. However, if you want a complete list of models you can
-specify `frameworks="all"` as input. Alternatively, if you know before-hand which framework you are interested
-in, you can also restrict to search to this one (e.g. `frameworks="text-generation-inference"`). The more
-frameworks are checked, the more time it will take.
-
-<Tip warning={true}>
-
-This endpoint method does not return a live list of all models available for the HF Inference API service.
-It searches over a cached list of models that were recently available and the list may not be up to date.
-If you want to know the live status of a specific model, use [`~InferenceClient.get_model_status`].
-
-</Tip>
-
-<Tip>
-
-This endpoint method is mostly useful for discoverability. If you already know which model you want to use and want to
-check its availability, you can directly use [`~InferenceClient.get_model_status`].
-
-</Tip>
-
-Args:
-frameworks (`Literal["all"]` or `List[str]` or `str`, *optional*):
-The frameworks to filter on. By default only a subset of the available frameworks are tested. If set to
-"all", all available frameworks will be tested. It is also possible to provide a single framework or a
-custom set of frameworks to check.
-
-Returns:
-`Dict[str, List[str]]`: A dictionary mapping task names to a sorted list of model IDs.
-
-Example:
-```py
-# Must be run in an async contextthon
->>> from huggingface_hub import AsyncInferenceClient
->>> client = AsyncInferenceClient()
-
-# Discover zero-shot-classification models currently deployed
->>> models = await client.list_deployed_models()
->>> models["zero-shot-classification"]
-['Narsil/deberta-large-mnli-zero-cls', 'facebook/bart-large-mnli', ...]
-
-# List from only 1 framework
->>> await client.list_deployed_models("text-generation-inference")
-{'text-generation': ['bigcode/starcoder', 'meta-llama/Llama-2-70b-chat-hf', ...], ...}
-```
-"""
-if self.provider != "hf-inference":
-raise ValueError(f"Listing deployed models is not supported on '{self.provider}'.")
-
-# Resolve which frameworks to check
-if frameworks is None:
-frameworks = constants.MAIN_INFERENCE_API_FRAMEWORKS
-elif frameworks == "all":
-frameworks = constants.ALL_INFERENCE_API_FRAMEWORKS
-elif isinstance(frameworks, str):
-frameworks = [frameworks]
-frameworks = list(set(frameworks))
-
-# Fetch them iteratively
-models_by_task: Dict[str, List[str]] = {}
-
-def _unpack_response(framework: str, items: List[Dict]) -> None:
-for model in items:
-if framework == "sentence-transformers":
-# Model running with the `sentence-transformers` framework can work with both tasks even if not
-# branded as such in the API response
-models_by_task.setdefault("feature-extraction", []).append(model["model_id"])
-models_by_task.setdefault("sentence-similarity", []).append(model["model_id"])
-else:
-models_by_task.setdefault(model["task"], []).append(model["model_id"])
-
-for framework in frameworks:
-response = get_session().get(
-f"{constants.INFERENCE_ENDPOINT}/framework/{framework}", headers=build_hf_headers(token=self.token)
-)
-hf_raise_for_status(response)
-_unpack_response(framework, response.json())
-
-# Sort alphabetically for discoverability and return
-for task, models in models_by_task.items():
-models_by_task[task] = sorted(set(models), key=lambda x: x.lower())
-return models_by_task
-
-def _get_client_session(self, headers: Optional[Dict] = None) -> "ClientSession":
-aiohttp = _import_aiohttp()
-client_headers = self.headers.copy()
-if headers is not None:
-client_headers.update(headers)
-
-# Return a new aiohttp ClientSession with correct settings.
-session = aiohttp.ClientSession(
-headers=client_headers,
-cookies=self.cookies,
-timeout=aiohttp.ClientTimeout(self.timeout),
-trust_env=self.trust_env,
-)
-
-# Keep track of sessions to close them later
-self._sessions[session] = set()
-
-# Override the `._request` method to register responses to be closed
-session._wrapped_request = session._request
-
-async def _request(method, url, **kwargs):
-response = await session._wrapped_request(method, url, **kwargs)
-self._sessions[session].add(response)
-return response
-
-session._request = _request
-
-# Override the 'close' method to
-# 1. close ongoing responses
-# 2. deregister the session when closed
-session._close = session.close
-
-async def close_session():
-for response in self._sessions[session]:
-response.close()
-await session._close()
-self._sessions.pop(session, None)
-
-session.close = close_session
-return session
-
-async def get_endpoint_info(self, *, model: Optional[str] = None) -> Dict[str, Any]:
+async def get_endpoint_info(self, *, model: Optional[str] = None) -> dict[str, Any]:
 """
 Get information about the deployed endpoint.

@@ -3409,7 +3331,7 @@ class AsyncInferenceClient:
 Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.

 Returns:
-`
+`dict[str, Any]`: Information about the endpoint.

 Example:
 ```py
@@ -3451,17 +3373,16 @@ class AsyncInferenceClient:
 else:
 url = f"{constants.INFERENCE_ENDPOINT}/models/{model}/info"

-
-
-
-
+client = await self._get_async_client()
+response = await client.get(url, headers=build_hf_headers(token=self.token))
+hf_raise_for_status(response)
+return response.json()

 async def health_check(self, model: Optional[str] = None) -> bool:
 """
 Check the health of the deployed endpoint.

 Health check is only available with Inference Endpoints powered by Text-Generation-Inference (TGI) or Text-Embedding-Inference (TEI).
-For Inference API, please use [`InferenceClient.get_model_status`] instead.

 Args:
 model (`str`, *optional*):
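The hunks above drop the deprecated `list_deployed_models` helper and the aiohttp session plumbing in favor of a shared async HTTP client, and `health_check` is now limited to Inference Endpoint URLs. A hedged sketch of how the surviving helpers are called (the endpoint URL is a placeholder, and the `model_id` key is only an assumed field of the endpoint's `/info` payload):

```python
# Illustrative only: both calls require a reachable TGI/TEI-powered endpoint.
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient(model="https://my-endpoint.example.com")  # placeholder URL
    healthy = await client.health_check()
    print("healthy:", healthy)
    if healthy:
        info = await client.get_endpoint_info()  # dict[str, Any] with endpoint metadata
        print(info.get("model_id"))  # hypothetical field, depends on the endpoint


asyncio.run(main())
```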
@@ -3486,77 +3407,12 @@ class AsyncInferenceClient:
 if model is None:
 raise ValueError("Model id not provided.")
 if not model.startswith(("http://", "https://")):
-raise ValueError(
-"Model must be an Inference Endpoint URL. For serverless Inference API, please use `InferenceClient.get_model_status`."
-)
+raise ValueError("Model must be an Inference Endpoint URL.")
 url = model.rstrip("/") + "/health"

-
-
-
-
-@_deprecate_method(
-version="0.33.0",
-message=(
-"HF Inference API is getting revamped and will only support warm models in the future (no cold start allowed)."
-" Use `HfApi.model_info` to get the model status both with HF Inference API and external providers."
-),
-)
-async def get_model_status(self, model: Optional[str] = None) -> ModelStatus:
-"""
-Get the status of a model hosted on the HF Inference API.
-
-<Tip>
-
-This endpoint is mostly useful when you already know which model you want to use and want to check its
-availability. If you want to discover already deployed models, you should rather use [`~InferenceClient.list_deployed_models`].
-
-</Tip>
-
-Args:
-model (`str`, *optional*):
-Identifier of the model for witch the status gonna be checked. If model is not provided,
-the model associated with this instance of [`InferenceClient`] will be used. Only HF Inference API service can be checked so the
-identifier cannot be a URL.
-
-
-Returns:
-[`ModelStatus`]: An instance of ModelStatus dataclass, containing information,
-about the state of the model: load, state, compute type and framework.
-
-Example:
-```py
-# Must be run in an async context
->>> from huggingface_hub import AsyncInferenceClient
->>> client = AsyncInferenceClient()
->>> await client.get_model_status("meta-llama/Meta-Llama-3-8B-Instruct")
-ModelStatus(loaded=True, state='Loaded', compute_type='gpu', framework='text-generation-inference')
-```
-"""
-if self.provider != "hf-inference":
-raise ValueError(f"Getting model status is not supported on '{self.provider}'.")
-
-model = model or self.model
-if model is None:
-raise ValueError("Model id not provided.")
-if model.startswith("https://"):
-raise NotImplementedError("Model status is only available for Inference API endpoints.")
-url = f"{constants.INFERENCE_ENDPOINT}/status/{model}"
-
-async with self._get_client_session(headers=build_hf_headers(token=self.token)) as client:
-response = await client.get(url, proxy=self.proxies)
-response.raise_for_status()
-response_data = await response.json()
-
-if "error" in response_data:
-raise ValueError(response_data["error"])
-
-return ModelStatus(
-loaded=response_data["loaded"],
-state=response_data["state"],
-compute_type=response_data["compute_type"],
-framework=response_data["framework"],
-)
+client = await self._get_async_client()
+response = await client.get(url, headers=build_hf_headers(token=self.token))
+return response.status_code == 200

 @property
 def chat(self) -> "ProxyClientChat":