huggingface-hub 0.29.0rc2__py3-none-any.whl → 1.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- huggingface_hub/__init__.py +160 -46
- huggingface_hub/_commit_api.py +277 -71
- huggingface_hub/_commit_scheduler.py +15 -15
- huggingface_hub/_inference_endpoints.py +33 -22
- huggingface_hub/_jobs_api.py +301 -0
- huggingface_hub/_local_folder.py +18 -3
- huggingface_hub/_login.py +31 -63
- huggingface_hub/_oauth.py +460 -0
- huggingface_hub/_snapshot_download.py +241 -81
- huggingface_hub/_space_api.py +18 -10
- huggingface_hub/_tensorboard_logger.py +15 -19
- huggingface_hub/_upload_large_folder.py +196 -76
- huggingface_hub/_webhooks_payload.py +3 -3
- huggingface_hub/_webhooks_server.py +15 -25
- huggingface_hub/{commands → cli}/__init__.py +1 -15
- huggingface_hub/cli/_cli_utils.py +173 -0
- huggingface_hub/cli/auth.py +147 -0
- huggingface_hub/cli/cache.py +841 -0
- huggingface_hub/cli/download.py +189 -0
- huggingface_hub/cli/hf.py +60 -0
- huggingface_hub/cli/inference_endpoints.py +377 -0
- huggingface_hub/cli/jobs.py +772 -0
- huggingface_hub/cli/lfs.py +175 -0
- huggingface_hub/cli/repo.py +315 -0
- huggingface_hub/cli/repo_files.py +94 -0
- huggingface_hub/{commands/env.py → cli/system.py} +10 -13
- huggingface_hub/cli/upload.py +294 -0
- huggingface_hub/cli/upload_large_folder.py +117 -0
- huggingface_hub/community.py +20 -12
- huggingface_hub/constants.py +83 -59
- huggingface_hub/dataclasses.py +609 -0
- huggingface_hub/errors.py +99 -30
- huggingface_hub/fastai_utils.py +30 -41
- huggingface_hub/file_download.py +606 -346
- huggingface_hub/hf_api.py +2445 -1132
- huggingface_hub/hf_file_system.py +269 -152
- huggingface_hub/hub_mixin.py +61 -66
- huggingface_hub/inference/_client.py +501 -630
- huggingface_hub/inference/_common.py +133 -121
- huggingface_hub/inference/_generated/_async_client.py +536 -722
- huggingface_hub/inference/_generated/types/__init__.py +6 -1
- huggingface_hub/inference/_generated/types/automatic_speech_recognition.py +5 -6
- huggingface_hub/inference/_generated/types/base.py +10 -7
- huggingface_hub/inference/_generated/types/chat_completion.py +77 -31
- huggingface_hub/inference/_generated/types/depth_estimation.py +2 -2
- huggingface_hub/inference/_generated/types/document_question_answering.py +2 -2
- huggingface_hub/inference/_generated/types/feature_extraction.py +2 -2
- huggingface_hub/inference/_generated/types/fill_mask.py +2 -2
- huggingface_hub/inference/_generated/types/image_to_image.py +8 -2
- huggingface_hub/inference/_generated/types/image_to_text.py +2 -3
- huggingface_hub/inference/_generated/types/image_to_video.py +60 -0
- huggingface_hub/inference/_generated/types/sentence_similarity.py +3 -3
- huggingface_hub/inference/_generated/types/summarization.py +2 -2
- huggingface_hub/inference/_generated/types/table_question_answering.py +5 -5
- huggingface_hub/inference/_generated/types/text2text_generation.py +2 -2
- huggingface_hub/inference/_generated/types/text_generation.py +11 -11
- huggingface_hub/inference/_generated/types/text_to_audio.py +1 -2
- huggingface_hub/inference/_generated/types/text_to_speech.py +1 -2
- huggingface_hub/inference/_generated/types/text_to_video.py +2 -2
- huggingface_hub/inference/_generated/types/token_classification.py +2 -2
- huggingface_hub/inference/_generated/types/translation.py +2 -2
- huggingface_hub/inference/_generated/types/zero_shot_classification.py +2 -2
- huggingface_hub/inference/_generated/types/zero_shot_image_classification.py +2 -2
- huggingface_hub/inference/_generated/types/zero_shot_object_detection.py +1 -3
- huggingface_hub/inference/_mcp/__init__.py +0 -0
- huggingface_hub/inference/_mcp/_cli_hacks.py +88 -0
- huggingface_hub/inference/_mcp/agent.py +100 -0
- huggingface_hub/inference/_mcp/cli.py +247 -0
- huggingface_hub/inference/_mcp/constants.py +81 -0
- huggingface_hub/inference/_mcp/mcp_client.py +395 -0
- huggingface_hub/inference/_mcp/types.py +45 -0
- huggingface_hub/inference/_mcp/utils.py +128 -0
- huggingface_hub/inference/_providers/__init__.py +149 -20
- huggingface_hub/inference/_providers/_common.py +160 -37
- huggingface_hub/inference/_providers/black_forest_labs.py +12 -9
- huggingface_hub/inference/_providers/cerebras.py +6 -0
- huggingface_hub/inference/_providers/clarifai.py +13 -0
- huggingface_hub/inference/_providers/cohere.py +32 -0
- huggingface_hub/inference/_providers/fal_ai.py +231 -22
- huggingface_hub/inference/_providers/featherless_ai.py +38 -0
- huggingface_hub/inference/_providers/fireworks_ai.py +22 -1
- huggingface_hub/inference/_providers/groq.py +9 -0
- huggingface_hub/inference/_providers/hf_inference.py +143 -33
- huggingface_hub/inference/_providers/hyperbolic.py +9 -5
- huggingface_hub/inference/_providers/nebius.py +47 -5
- huggingface_hub/inference/_providers/novita.py +48 -5
- huggingface_hub/inference/_providers/nscale.py +44 -0
- huggingface_hub/inference/_providers/openai.py +25 -0
- huggingface_hub/inference/_providers/publicai.py +6 -0
- huggingface_hub/inference/_providers/replicate.py +46 -9
- huggingface_hub/inference/_providers/sambanova.py +37 -1
- huggingface_hub/inference/_providers/scaleway.py +28 -0
- huggingface_hub/inference/_providers/together.py +34 -5
- huggingface_hub/inference/_providers/wavespeed.py +138 -0
- huggingface_hub/inference/_providers/zai_org.py +17 -0
- huggingface_hub/lfs.py +33 -100
- huggingface_hub/repocard.py +34 -38
- huggingface_hub/repocard_data.py +79 -59
- huggingface_hub/serialization/__init__.py +0 -1
- huggingface_hub/serialization/_base.py +12 -15
- huggingface_hub/serialization/_dduf.py +8 -8
- huggingface_hub/serialization/_torch.py +69 -69
- huggingface_hub/utils/__init__.py +27 -8
- huggingface_hub/utils/_auth.py +7 -7
- huggingface_hub/utils/_cache_manager.py +92 -147
- huggingface_hub/utils/_chunk_utils.py +2 -3
- huggingface_hub/utils/_deprecation.py +1 -1
- huggingface_hub/utils/_dotenv.py +55 -0
- huggingface_hub/utils/_experimental.py +7 -5
- huggingface_hub/utils/_fixes.py +0 -10
- huggingface_hub/utils/_git_credential.py +5 -5
- huggingface_hub/utils/_headers.py +8 -30
- huggingface_hub/utils/_http.py +399 -237
- huggingface_hub/utils/_pagination.py +6 -6
- huggingface_hub/utils/_parsing.py +98 -0
- huggingface_hub/utils/_paths.py +5 -5
- huggingface_hub/utils/_runtime.py +74 -22
- huggingface_hub/utils/_safetensors.py +21 -21
- huggingface_hub/utils/_subprocess.py +13 -11
- huggingface_hub/utils/_telemetry.py +4 -4
- huggingface_hub/{commands/_cli_utils.py → utils/_terminal.py} +4 -4
- huggingface_hub/utils/_typing.py +25 -5
- huggingface_hub/utils/_validators.py +55 -74
- huggingface_hub/utils/_verification.py +167 -0
- huggingface_hub/utils/_xet.py +235 -0
- huggingface_hub/utils/_xet_progress_reporting.py +162 -0
- huggingface_hub/utils/insecure_hashlib.py +3 -5
- huggingface_hub/utils/logging.py +8 -11
- huggingface_hub/utils/tqdm.py +33 -4
- {huggingface_hub-0.29.0rc2.dist-info → huggingface_hub-1.1.3.dist-info}/METADATA +94 -82
- huggingface_hub-1.1.3.dist-info/RECORD +155 -0
- {huggingface_hub-0.29.0rc2.dist-info → huggingface_hub-1.1.3.dist-info}/WHEEL +1 -1
- huggingface_hub-1.1.3.dist-info/entry_points.txt +6 -0
- huggingface_hub/commands/delete_cache.py +0 -428
- huggingface_hub/commands/download.py +0 -200
- huggingface_hub/commands/huggingface_cli.py +0 -61
- huggingface_hub/commands/lfs.py +0 -200
- huggingface_hub/commands/repo_files.py +0 -128
- huggingface_hub/commands/scan_cache.py +0 -181
- huggingface_hub/commands/tag.py +0 -159
- huggingface_hub/commands/upload.py +0 -299
- huggingface_hub/commands/upload_large_folder.py +0 -129
- huggingface_hub/commands/user.py +0 -304
- huggingface_hub/commands/version.py +0 -37
- huggingface_hub/inference_api.py +0 -217
- huggingface_hub/keras_mixin.py +0 -500
- huggingface_hub/repository.py +0 -1477
- huggingface_hub/serialization/_tensorflow.py +0 -95
- huggingface_hub/utils/_hf_folder.py +0 -68
- huggingface_hub-0.29.0rc2.dist-info/RECORD +0 -131
- huggingface_hub-0.29.0rc2.dist-info/entry_points.txt +0 -6
- {huggingface_hub-0.29.0rc2.dist-info → huggingface_hub-1.1.3.dist-info/licenses}/LICENSE +0 -0
- {huggingface_hub-0.29.0rc2.dist-info → huggingface_hub-1.1.3.dist-info}/top_level.txt +0 -0
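
Most of the diff shown below is in `huggingface_hub/inference/_generated/_async_client.py`: `AsyncInferenceClient` replaces its per-call aiohttp sessions with a single `httpx.AsyncClient` managed through an `AsyncExitStack`, drops the deprecated `post()` method and the `trust_env`/`proxies` arguments, and adds a `bill_to` parameter plus a much longer provider list. A minimal sketch of the context-manager usage pattern the updated code and docstrings describe (the model ID is illustrative only, not taken from this diff):

```python
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    # The 1.x client owns a single httpx.AsyncClient behind an AsyncExitStack,
    # so closing it matters; the async context manager added in this version
    # (__aenter__/__aexit__ -> close()) handles that automatically.
    async with AsyncInferenceClient() as client:
        response = await client.chat_completion(
            messages=[{"role": "user", "content": "Say hello in one sentence."}],
            model="meta-llama/Llama-3.1-8B-Instruct",  # illustrative model id
            max_tokens=50,
        )
        print(response.choices[0].message.content)


asyncio.run(main())
```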
|
@@ -21,16 +21,19 @@
|
|
|
21
21
|
import asyncio
|
|
22
22
|
import base64
|
|
23
23
|
import logging
|
|
24
|
+
import os
|
|
24
25
|
import re
|
|
25
26
|
import warnings
|
|
26
|
-
from
|
|
27
|
+
from contextlib import AsyncExitStack
|
|
28
|
+
from typing import TYPE_CHECKING, Any, AsyncIterable, Literal, Optional, Union, overload
|
|
29
|
+
|
|
30
|
+
import httpx
|
|
27
31
|
|
|
28
32
|
from huggingface_hub import constants
|
|
29
|
-
from huggingface_hub.errors import InferenceTimeoutError
|
|
33
|
+
from huggingface_hub.errors import BadRequestError, HfHubHTTPError, InferenceTimeoutError
|
|
30
34
|
from huggingface_hub.inference._common import (
|
|
31
35
|
TASKS_EXPECTING_IMAGES,
|
|
32
36
|
ContentT,
|
|
33
|
-
ModelStatus,
|
|
34
37
|
RequestParameters,
|
|
35
38
|
_async_stream_chat_completion_response,
|
|
36
39
|
_async_stream_text_generation_response,
|
|
@@ -41,7 +44,6 @@ from huggingface_hub.inference._common import (
|
|
|
41
44
|
_bytes_to_list,
|
|
42
45
|
_get_unsupported_text_generation_kwargs,
|
|
43
46
|
_import_numpy,
|
|
44
|
-
_open_as_binary,
|
|
45
47
|
_set_unsupported_text_generation_kwargs,
|
|
46
48
|
raise_text_generation_error,
|
|
47
49
|
)
|
|
@@ -51,6 +53,7 @@ from huggingface_hub.inference._generated.types import (
|
|
|
51
53
|
AudioToAudioOutputElement,
|
|
52
54
|
AutomaticSpeechRecognitionOutput,
|
|
53
55
|
ChatCompletionInputGrammarType,
|
|
56
|
+
ChatCompletionInputMessage,
|
|
54
57
|
ChatCompletionInputStreamOptions,
|
|
55
58
|
ChatCompletionInputTool,
|
|
56
59
|
ChatCompletionInputToolChoiceClass,
|
|
@@ -65,6 +68,7 @@ from huggingface_hub.inference._generated.types import (
|
|
|
65
68
|
ImageSegmentationSubtask,
|
|
66
69
|
ImageToImageTargetSize,
|
|
67
70
|
ImageToTextOutput,
|
|
71
|
+
ImageToVideoTargetSize,
|
|
68
72
|
ObjectDetectionOutputElement,
|
|
69
73
|
Padding,
|
|
70
74
|
QuestionAnsweringOutputElement,
|
|
@@ -85,16 +89,20 @@ from huggingface_hub.inference._generated.types import (
|
|
|
85
89
|
ZeroShotClassificationOutputElement,
|
|
86
90
|
ZeroShotImageClassificationOutputElement,
|
|
87
91
|
)
|
|
88
|
-
from huggingface_hub.inference._providers import
|
|
89
|
-
from huggingface_hub.utils import
|
|
90
|
-
|
|
92
|
+
from huggingface_hub.inference._providers import PROVIDER_OR_POLICY_T, get_provider_helper
|
|
93
|
+
from huggingface_hub.utils import (
|
|
94
|
+
build_hf_headers,
|
|
95
|
+
get_async_session,
|
|
96
|
+
hf_raise_for_status,
|
|
97
|
+
validate_hf_hub_args,
|
|
98
|
+
)
|
|
99
|
+
from huggingface_hub.utils._auth import get_token
|
|
91
100
|
|
|
92
|
-
from .._common import _async_yield_from
|
|
101
|
+
from .._common import _async_yield_from
|
|
93
102
|
|
|
94
103
|
|
|
95
104
|
if TYPE_CHECKING:
|
|
96
105
|
import numpy as np
|
|
97
|
-
from aiohttp import ClientResponse, ClientSession
|
|
98
106
|
from PIL.Image import Image
|
|
99
107
|
|
|
100
108
|
logger = logging.getLogger(__name__)
|
|
@@ -116,30 +124,25 @@ class AsyncInferenceClient:
|
|
|
116
124
|
or a URL to a deployed Inference Endpoint. Defaults to None, in which case a recommended model is
|
|
117
125
|
automatically selected for the task.
|
|
118
126
|
Note: for better compatibility with OpenAI's client, `model` has been aliased as `base_url`. Those 2
|
|
119
|
-
arguments are mutually exclusive. If
|
|
120
|
-
path will be appended to the base URL (see the [TGI Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api)
|
|
121
|
-
documentation for details). When passing a URL as `model`, the client will not append any suffix path to it.
|
|
127
|
+
arguments are mutually exclusive. If a URL is passed as `model` or `base_url` for chat completion, the `(/v1)/chat/completions` suffix path will be appended to the URL.
|
|
122
128
|
provider (`str`, *optional*):
|
|
123
|
-
Name of the provider to use for inference. Can be `"black-forest-labs"`, `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"replicate"`, "sambanova"` or `"
|
|
124
|
-
|
|
129
|
+
Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"clarifai"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"publicai"`, `"replicate"`, `"sambanova"`, `"scaleway"`, `"together"`, `"wavespeed"` or `"zai-org"`.
|
|
130
|
+
Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
|
|
125
131
|
If model is a URL or `base_url` is passed, then `provider` is not used.
|
|
126
|
-
token (`str
|
|
132
|
+
token (`str`, *optional*):
|
|
127
133
|
Hugging Face token. Will default to the locally saved token if not provided.
|
|
128
|
-
Pass `token=False` if you don't want to send your token to the server.
|
|
129
134
|
Note: for better compatibility with OpenAI's client, `token` has been aliased as `api_key`. Those 2
|
|
130
135
|
arguments are mutually exclusive and have the exact same behavior.
|
|
131
136
|
timeout (`float`, `optional`):
|
|
132
|
-
The maximum number of seconds to wait for a response from the server.
|
|
133
|
-
|
|
134
|
-
headers (`Dict[str, str]`, `optional`):
|
|
137
|
+
The maximum number of seconds to wait for a response from the server. Defaults to None, meaning it will loop until the server is available.
|
|
138
|
+
headers (`dict[str, str]`, `optional`):
|
|
135
139
|
Additional headers to send to the server. By default only the authorization and user-agent headers are sent.
|
|
136
140
|
Values in this dictionary will override the default values.
|
|
137
|
-
|
|
141
|
+
bill_to (`str`, `optional`):
|
|
142
|
+
The billing account to use for the requests. By default the requests are billed on the user's account.
|
|
143
|
+
Requests can only be billed to an organization the user is a member of, and which has subscribed to Enterprise Hub.
|
|
144
|
+
cookies (`dict[str, str]`, `optional`):
|
|
138
145
|
Additional cookies to send to the server.
|
|
139
|
-
trust_env ('bool', 'optional'):
|
|
140
|
-
Trust environment settings for proxy configuration if the parameter is `True` (`False` by default).
|
|
141
|
-
proxies (`Any`, `optional`):
|
|
142
|
-
Proxies to use for the request.
|
|
143
146
|
base_url (`str`, `optional`):
|
|
144
147
|
Base URL to run inference. This is a duplicated argument from `model` to make [`InferenceClient`]
|
|
145
148
|
follow the same pattern as `openai.OpenAI` client. Cannot be used if `model` is set. Defaults to None.
|
|
@@ -148,17 +151,17 @@ class AsyncInferenceClient:
|
|
|
148
151
|
follow the same pattern as `openai.OpenAI` client. Cannot be used if `token` is set. Defaults to None.
|
|
149
152
|
"""
|
|
150
153
|
|
|
154
|
+
@validate_hf_hub_args
|
|
151
155
|
def __init__(
|
|
152
156
|
self,
|
|
153
157
|
model: Optional[str] = None,
|
|
154
158
|
*,
|
|
155
|
-
provider: Optional[
|
|
159
|
+
provider: Optional[PROVIDER_OR_POLICY_T] = None,
|
|
156
160
|
token: Optional[str] = None,
|
|
157
161
|
timeout: Optional[float] = None,
|
|
158
|
-
headers: Optional[
|
|
159
|
-
cookies: Optional[
|
|
160
|
-
|
|
161
|
-
proxies: Optional[Any] = None,
|
|
162
|
+
headers: Optional[dict[str, str]] = None,
|
|
163
|
+
cookies: Optional[dict[str, str]] = None,
|
|
164
|
+
bill_to: Optional[str] = None,
|
|
162
165
|
# OpenAI compatibility
|
|
163
166
|
base_url: Optional[str] = None,
|
|
164
167
|
api_key: Optional[str] = None,
|
|
@@ -176,101 +179,78 @@ class AsyncInferenceClient:
|
|
|
176
179
|
" `api_key` is an alias for `token` to make the API compatible with OpenAI's client."
|
|
177
180
|
" It has the exact same behavior as `token`."
|
|
178
181
|
)
|
|
182
|
+
token = token if token is not None else api_key
|
|
183
|
+
if isinstance(token, bool):
|
|
184
|
+
# Legacy behavior: previously it was possible to pass `token=False` to disable authentication. This is not
|
|
185
|
+
# supported anymore as authentication is required. Better to explicitly raise here rather than risking
|
|
186
|
+
# sending the locally saved token without the user knowing about it.
|
|
187
|
+
if token is False:
|
|
188
|
+
raise ValueError(
|
|
189
|
+
"Cannot use `token=False` to disable authentication as authentication is required to run Inference."
|
|
190
|
+
)
|
|
191
|
+
warnings.warn(
|
|
192
|
+
"Using `token=True` to automatically use the locally saved token is deprecated and will be removed in a future release. "
|
|
193
|
+
"Please use `token=None` instead (default).",
|
|
194
|
+
DeprecationWarning,
|
|
195
|
+
)
|
|
196
|
+
token = get_token()
|
|
179
197
|
|
|
180
198
|
self.model: Optional[str] = base_url or model
|
|
181
|
-
self.token: Optional[str] = token
|
|
182
|
-
|
|
199
|
+
self.token: Optional[str] = token
|
|
200
|
+
|
|
201
|
+
self.headers = {**headers} if headers is not None else {}
|
|
202
|
+
if bill_to is not None:
|
|
203
|
+
if (
|
|
204
|
+
constants.HUGGINGFACE_HEADER_X_BILL_TO in self.headers
|
|
205
|
+
and self.headers[constants.HUGGINGFACE_HEADER_X_BILL_TO] != bill_to
|
|
206
|
+
):
|
|
207
|
+
warnings.warn(
|
|
208
|
+
f"Overriding existing '{self.headers[constants.HUGGINGFACE_HEADER_X_BILL_TO]}' value in headers with '{bill_to}'.",
|
|
209
|
+
UserWarning,
|
|
210
|
+
)
|
|
211
|
+
self.headers[constants.HUGGINGFACE_HEADER_X_BILL_TO] = bill_to
|
|
212
|
+
|
|
213
|
+
if token is not None and not token.startswith("hf_"):
|
|
214
|
+
warnings.warn(
|
|
215
|
+
"You've provided an external provider's API key, so requests will be billed directly by the provider. "
|
|
216
|
+
"The `bill_to` parameter is only applicable for Hugging Face billing and will be ignored.",
|
|
217
|
+
UserWarning,
|
|
218
|
+
)
|
|
183
219
|
|
|
184
220
|
# Configure provider
|
|
185
|
-
self.provider = provider
|
|
221
|
+
self.provider = provider
|
|
186
222
|
|
|
187
223
|
self.cookies = cookies
|
|
188
224
|
self.timeout = timeout
|
|
189
|
-
self.trust_env = trust_env
|
|
190
|
-
self.proxies = proxies
|
|
191
225
|
|
|
192
|
-
|
|
193
|
-
self.
|
|
226
|
+
self.exit_stack = AsyncExitStack()
|
|
227
|
+
self._async_client: Optional[httpx.AsyncClient] = None
|
|
194
228
|
|
|
195
229
|
def __repr__(self):
|
|
196
230
|
return f"<InferenceClient(model='{self.model if self.model else ''}', timeout={self.timeout})>"
|
|
197
231
|
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
self,
|
|
201
|
-
*,
|
|
202
|
-
json: Optional[Union[str, Dict, List]] = None,
|
|
203
|
-
data: Optional[ContentT] = None,
|
|
204
|
-
model: Optional[str] = None,
|
|
205
|
-
task: Optional[str] = None,
|
|
206
|
-
stream: Literal[False] = ...,
|
|
207
|
-
) -> bytes: ...
|
|
232
|
+
async def __aenter__(self):
|
|
233
|
+
return self
|
|
208
234
|
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
self,
|
|
212
|
-
*,
|
|
213
|
-
json: Optional[Union[str, Dict, List]] = None,
|
|
214
|
-
data: Optional[ContentT] = None,
|
|
215
|
-
model: Optional[str] = None,
|
|
216
|
-
task: Optional[str] = None,
|
|
217
|
-
stream: Literal[True] = ...,
|
|
218
|
-
) -> AsyncIterable[bytes]: ...
|
|
235
|
+
async def __aexit__(self, exc_type, exc_value, traceback):
|
|
236
|
+
await self.close()
|
|
219
237
|
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
json: Optional[Union[str, Dict, List]] = None,
|
|
225
|
-
data: Optional[ContentT] = None,
|
|
226
|
-
model: Optional[str] = None,
|
|
227
|
-
task: Optional[str] = None,
|
|
228
|
-
stream: bool = False,
|
|
229
|
-
) -> Union[bytes, AsyncIterable[bytes]]: ...
|
|
230
|
-
|
|
231
|
-
@_deprecate_method(
|
|
232
|
-
version="0.31.0",
|
|
233
|
-
message=(
|
|
234
|
-
"Making direct POST requests to the inference server is not supported anymore. "
|
|
235
|
-
"Please use task methods instead (e.g. `InferenceClient.chat_completion`). "
|
|
236
|
-
"If your use case is not supported, please open an issue in https://github.com/huggingface/huggingface_hub."
|
|
237
|
-
),
|
|
238
|
-
)
|
|
239
|
-
async def post(
|
|
240
|
-
self,
|
|
241
|
-
*,
|
|
242
|
-
json: Optional[Union[str, Dict, List]] = None,
|
|
243
|
-
data: Optional[ContentT] = None,
|
|
244
|
-
model: Optional[str] = None,
|
|
245
|
-
task: Optional[str] = None,
|
|
246
|
-
stream: bool = False,
|
|
247
|
-
) -> Union[bytes, AsyncIterable[bytes]]:
|
|
238
|
+
async def close(self):
|
|
239
|
+
"""Close the client.
|
|
240
|
+
|
|
241
|
+
This method is automatically called when using the client as a context manager.
|
|
248
242
|
"""
|
|
249
|
-
|
|
243
|
+
await self.exit_stack.aclose()
|
|
250
244
|
|
|
251
|
-
|
|
252
|
-
|
|
245
|
+
async def _get_async_client(self):
|
|
246
|
+
"""Get a unique async client for this AsyncInferenceClient instance.
|
|
247
|
+
|
|
248
|
+
Returns the same client instance on subsequent calls, ensuring proper
|
|
249
|
+
connection reuse and resource management through the exit stack.
|
|
253
250
|
"""
|
|
254
|
-
if self.
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
"`InferenceClient.post` is deprecated and should not be used directly anymore."
|
|
258
|
-
)
|
|
259
|
-
provider_helper = HFInferenceTask(task or "unknown")
|
|
260
|
-
mapped_model = provider_helper._prepare_mapped_model(model or self.model)
|
|
261
|
-
url = provider_helper._prepare_url(self.token, mapped_model) # type: ignore[arg-type]
|
|
262
|
-
headers = provider_helper._prepare_headers(self.headers, self.token) # type: ignore[arg-type]
|
|
263
|
-
return await self._inner_post(
|
|
264
|
-
request_parameters=RequestParameters(
|
|
265
|
-
url=url,
|
|
266
|
-
task=task or "unknown",
|
|
267
|
-
model=model or "unknown",
|
|
268
|
-
json=json,
|
|
269
|
-
data=data,
|
|
270
|
-
headers=headers,
|
|
271
|
-
),
|
|
272
|
-
stream=stream,
|
|
273
|
-
)
|
|
251
|
+
if self._async_client is None:
|
|
252
|
+
self._async_client = await self.exit_stack.enter_async_context(get_async_session())
|
|
253
|
+
return self._async_client
|
|
274
254
|
|
|
275
255
|
@overload
|
|
276
256
|
async def _inner_post( # type: ignore[misc]
|
|
@@ -280,84 +260,59 @@ class AsyncInferenceClient:
|
|
|
280
260
|
@overload
|
|
281
261
|
async def _inner_post( # type: ignore[misc]
|
|
282
262
|
self, request_parameters: RequestParameters, *, stream: Literal[True] = ...
|
|
283
|
-
) -> AsyncIterable[
|
|
263
|
+
) -> AsyncIterable[str]: ...
|
|
284
264
|
|
|
285
265
|
@overload
|
|
286
266
|
async def _inner_post(
|
|
287
267
|
self, request_parameters: RequestParameters, *, stream: bool = False
|
|
288
|
-
) -> Union[bytes, AsyncIterable[
|
|
268
|
+
) -> Union[bytes, AsyncIterable[str]]: ...
|
|
289
269
|
|
|
290
270
|
async def _inner_post(
|
|
291
271
|
self, request_parameters: RequestParameters, *, stream: bool = False
|
|
292
|
-
) -> Union[bytes, AsyncIterable[
|
|
272
|
+
) -> Union[bytes, AsyncIterable[str]]:
|
|
293
273
|
"""Make a request to the inference server."""
|
|
294
274
|
|
|
295
|
-
aiohttp = _import_aiohttp()
|
|
296
|
-
|
|
297
275
|
# TODO: this should be handled in provider helpers directly
|
|
298
276
|
if request_parameters.task in TASKS_EXPECTING_IMAGES and "Accept" not in request_parameters.headers:
|
|
299
277
|
request_parameters.headers["Accept"] = "image/png"
|
|
300
278
|
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
279
|
+
try:
|
|
280
|
+
client = await self._get_async_client()
|
|
281
|
+
if stream:
|
|
282
|
+
response = await self.exit_stack.enter_async_context(
|
|
283
|
+
client.stream(
|
|
284
|
+
"POST",
|
|
285
|
+
request_parameters.url,
|
|
286
|
+
json=request_parameters.json,
|
|
287
|
+
data=request_parameters.data,
|
|
288
|
+
headers=request_parameters.headers,
|
|
289
|
+
cookies=self.cookies,
|
|
290
|
+
timeout=self.timeout,
|
|
310
291
|
)
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
async def __aenter__(self):
|
|
337
|
-
return self
|
|
338
|
-
|
|
339
|
-
async def __aexit__(self, exc_type, exc_value, traceback):
|
|
340
|
-
await self.close()
|
|
341
|
-
|
|
342
|
-
def __del__(self):
|
|
343
|
-
if len(self._sessions) > 0:
|
|
344
|
-
warnings.warn(
|
|
345
|
-
"Deleting 'AsyncInferenceClient' client but some sessions are still open. "
|
|
346
|
-
"This can happen if you've stopped streaming data from the server before the stream was complete. "
|
|
347
|
-
"To close the client properly, you must call `await client.close()` "
|
|
348
|
-
"or use an async context (e.g. `async with AsyncInferenceClient(): ...`."
|
|
349
|
-
)
|
|
350
|
-
|
|
351
|
-
async def close(self):
|
|
352
|
-
"""Close all open sessions.
|
|
353
|
-
|
|
354
|
-
By default, 'aiohttp.ClientSession' objects are closed automatically when a call is completed. However, if you
|
|
355
|
-
are streaming data from the server and you stop before the stream is complete, you must call this method to
|
|
356
|
-
close the session properly.
|
|
357
|
-
|
|
358
|
-
Another possibility is to use an async context (e.g. `async with AsyncInferenceClient(): ...`).
|
|
359
|
-
"""
|
|
360
|
-
await asyncio.gather(*[session.close() for session in self._sessions.keys()])
|
|
292
|
+
)
|
|
293
|
+
hf_raise_for_status(response)
|
|
294
|
+
return _async_yield_from(client, response)
|
|
295
|
+
else:
|
|
296
|
+
response = await client.post(
|
|
297
|
+
request_parameters.url,
|
|
298
|
+
json=request_parameters.json,
|
|
299
|
+
data=request_parameters.data,
|
|
300
|
+
headers=request_parameters.headers,
|
|
301
|
+
cookies=self.cookies,
|
|
302
|
+
timeout=self.timeout,
|
|
303
|
+
)
|
|
304
|
+
hf_raise_for_status(response)
|
|
305
|
+
return response.content
|
|
306
|
+
except asyncio.TimeoutError as error:
|
|
307
|
+
# Convert any `TimeoutError` to a `InferenceTimeoutError`
|
|
308
|
+
raise InferenceTimeoutError(f"Inference call timed out: {request_parameters.url}") from error # type: ignore
|
|
309
|
+
except HfHubHTTPError as error:
|
|
310
|
+
if error.response.status_code == 422 and request_parameters.task != "unknown":
|
|
311
|
+
msg = str(error.args[0])
|
|
312
|
+
if len(error.response.text) > 0:
|
|
313
|
+
msg += f"{os.linesep}{error.response.text}{os.linesep}"
|
|
314
|
+
error.args = (msg,) + error.args[1:]
|
|
315
|
+
raise
|
|
361
316
|
|
|
362
317
|
async def audio_classification(
|
|
363
318
|
self,
|
|
@@ -366,7 +321,7 @@ class AsyncInferenceClient:
|
|
|
366
321
|
model: Optional[str] = None,
|
|
367
322
|
top_k: Optional[int] = None,
|
|
368
323
|
function_to_apply: Optional["AudioClassificationOutputTransform"] = None,
|
|
369
|
-
) ->
|
|
324
|
+
) -> list[AudioClassificationOutputElement]:
|
|
370
325
|
"""
|
|
371
326
|
Perform audio classification on the provided audio content.
|
|
372
327
|
|
|
@@ -384,12 +339,12 @@ class AsyncInferenceClient:
|
|
|
384
339
|
The function to apply to the model outputs in order to retrieve the scores.
|
|
385
340
|
|
|
386
341
|
Returns:
|
|
387
|
-
`
|
|
342
|
+
`list[AudioClassificationOutputElement]`: List of [`AudioClassificationOutputElement`] items containing the predicted labels and their confidence.
|
|
388
343
|
|
|
389
344
|
Raises:
|
|
390
345
|
[`InferenceTimeoutError`]:
|
|
391
346
|
If the model is unavailable or the request times out.
|
|
392
|
-
`
|
|
347
|
+
[`HfHubHTTPError`]:
|
|
393
348
|
If the request fails with an HTTP error status code other than HTTP 503.
|
|
394
349
|
|
|
395
350
|
Example:
|
|
@@ -405,12 +360,13 @@ class AsyncInferenceClient:
|
|
|
405
360
|
]
|
|
406
361
|
```
|
|
407
362
|
"""
|
|
408
|
-
|
|
363
|
+
model_id = model or self.model
|
|
364
|
+
provider_helper = get_provider_helper(self.provider, task="audio-classification", model=model_id)
|
|
409
365
|
request_parameters = provider_helper.prepare_request(
|
|
410
366
|
inputs=audio,
|
|
411
367
|
parameters={"function_to_apply": function_to_apply, "top_k": top_k},
|
|
412
368
|
headers=self.headers,
|
|
413
|
-
model=
|
|
369
|
+
model=model_id,
|
|
414
370
|
api_key=self.token,
|
|
415
371
|
)
|
|
416
372
|
response = await self._inner_post(request_parameters)
|
|
@@ -421,7 +377,7 @@ class AsyncInferenceClient:
|
|
|
421
377
|
audio: ContentT,
|
|
422
378
|
*,
|
|
423
379
|
model: Optional[str] = None,
|
|
424
|
-
) ->
|
|
380
|
+
) -> list[AudioToAudioOutputElement]:
|
|
425
381
|
"""
|
|
426
382
|
Performs multiple tasks related to audio-to-audio depending on the model (eg: speech enhancement, source separation).
|
|
427
383
|
|
|
@@ -435,12 +391,12 @@ class AsyncInferenceClient:
|
|
|
435
391
|
audio_to_audio will be used.
|
|
436
392
|
|
|
437
393
|
Returns:
|
|
438
|
-
`
|
|
394
|
+
`list[AudioToAudioOutputElement]`: A list of [`AudioToAudioOutputElement`] items containing audios label, content-type, and audio content in blob.
|
|
439
395
|
|
|
440
396
|
Raises:
|
|
441
397
|
`InferenceTimeoutError`:
|
|
442
398
|
If the model is unavailable or the request times out.
|
|
443
|
-
`
|
|
399
|
+
[`HfHubHTTPError`]:
|
|
444
400
|
If the request fails with an HTTP error status code other than HTTP 503.
|
|
445
401
|
|
|
446
402
|
Example:
|
|
@@ -454,12 +410,13 @@ class AsyncInferenceClient:
|
|
|
454
410
|
f.write(item.blob)
|
|
455
411
|
```
|
|
456
412
|
"""
|
|
457
|
-
|
|
413
|
+
model_id = model or self.model
|
|
414
|
+
provider_helper = get_provider_helper(self.provider, task="audio-to-audio", model=model_id)
|
|
458
415
|
request_parameters = provider_helper.prepare_request(
|
|
459
416
|
inputs=audio,
|
|
460
417
|
parameters={},
|
|
461
418
|
headers=self.headers,
|
|
462
|
-
model=
|
|
419
|
+
model=model_id,
|
|
463
420
|
api_key=self.token,
|
|
464
421
|
)
|
|
465
422
|
response = await self._inner_post(request_parameters)
|
|
@@ -473,7 +430,7 @@ class AsyncInferenceClient:
|
|
|
473
430
|
audio: ContentT,
|
|
474
431
|
*,
|
|
475
432
|
model: Optional[str] = None,
|
|
476
|
-
extra_body: Optional[
|
|
433
|
+
extra_body: Optional[dict] = None,
|
|
477
434
|
) -> AutomaticSpeechRecognitionOutput:
|
|
478
435
|
"""
|
|
479
436
|
Perform automatic speech recognition (ASR or audio-to-text) on the given audio content.
|
|
@@ -484,7 +441,7 @@ class AsyncInferenceClient:
|
|
|
484
441
|
model (`str`, *optional*):
|
|
485
442
|
The model to use for ASR. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
|
|
486
443
|
Inference Endpoint. If not provided, the default recommended model for ASR will be used.
|
|
487
|
-
extra_body (`
|
|
444
|
+
extra_body (`dict`, *optional*):
|
|
488
445
|
Additional provider-specific parameters to pass to the model. Refer to the provider's documentation
|
|
489
446
|
for supported parameters.
|
|
490
447
|
Returns:
|
|
@@ -493,7 +450,7 @@ class AsyncInferenceClient:
|
|
|
493
450
|
Raises:
|
|
494
451
|
[`InferenceTimeoutError`]:
|
|
495
452
|
If the model is unavailable or the request times out.
|
|
496
|
-
`
|
|
453
|
+
[`HfHubHTTPError`]:
|
|
497
454
|
If the request fails with an HTTP error status code other than HTTP 503.
|
|
498
455
|
|
|
499
456
|
Example:
|
|
@@ -505,12 +462,13 @@ class AsyncInferenceClient:
|
|
|
505
462
|
"hello world"
|
|
506
463
|
```
|
|
507
464
|
"""
|
|
508
|
-
|
|
465
|
+
model_id = model or self.model
|
|
466
|
+
provider_helper = get_provider_helper(self.provider, task="automatic-speech-recognition", model=model_id)
|
|
509
467
|
request_parameters = provider_helper.prepare_request(
|
|
510
468
|
inputs=audio,
|
|
511
469
|
parameters={**(extra_body or {})},
|
|
512
470
|
headers=self.headers,
|
|
513
|
-
model=
|
|
471
|
+
model=model_id,
|
|
514
472
|
api_key=self.token,
|
|
515
473
|
)
|
|
516
474
|
response = await self._inner_post(request_parameters)
|
|
@@ -519,121 +477,117 @@ class AsyncInferenceClient:
|
|
|
519
477
|
@overload
|
|
520
478
|
async def chat_completion( # type: ignore
|
|
521
479
|
self,
|
|
522
|
-
messages:
|
|
480
|
+
messages: list[Union[dict, ChatCompletionInputMessage]],
|
|
523
481
|
*,
|
|
524
482
|
model: Optional[str] = None,
|
|
525
483
|
stream: Literal[False] = False,
|
|
526
484
|
frequency_penalty: Optional[float] = None,
|
|
527
|
-
logit_bias: Optional[
|
|
485
|
+
logit_bias: Optional[list[float]] = None,
|
|
528
486
|
logprobs: Optional[bool] = None,
|
|
529
487
|
max_tokens: Optional[int] = None,
|
|
530
488
|
n: Optional[int] = None,
|
|
531
489
|
presence_penalty: Optional[float] = None,
|
|
532
490
|
response_format: Optional[ChatCompletionInputGrammarType] = None,
|
|
533
491
|
seed: Optional[int] = None,
|
|
534
|
-
stop: Optional[
|
|
492
|
+
stop: Optional[list[str]] = None,
|
|
535
493
|
stream_options: Optional[ChatCompletionInputStreamOptions] = None,
|
|
536
494
|
temperature: Optional[float] = None,
|
|
537
495
|
tool_choice: Optional[Union[ChatCompletionInputToolChoiceClass, "ChatCompletionInputToolChoiceEnum"]] = None,
|
|
538
496
|
tool_prompt: Optional[str] = None,
|
|
539
|
-
tools: Optional[
|
|
497
|
+
tools: Optional[list[ChatCompletionInputTool]] = None,
|
|
540
498
|
top_logprobs: Optional[int] = None,
|
|
541
499
|
top_p: Optional[float] = None,
|
|
542
|
-
extra_body: Optional[
|
|
500
|
+
extra_body: Optional[dict] = None,
|
|
543
501
|
) -> ChatCompletionOutput: ...
|
|
544
502
|
|
|
545
503
|
@overload
|
|
546
504
|
async def chat_completion( # type: ignore
|
|
547
505
|
self,
|
|
548
|
-
messages:
|
|
506
|
+
messages: list[Union[dict, ChatCompletionInputMessage]],
|
|
549
507
|
*,
|
|
550
508
|
model: Optional[str] = None,
|
|
551
509
|
stream: Literal[True] = True,
|
|
552
510
|
frequency_penalty: Optional[float] = None,
|
|
553
|
-
logit_bias: Optional[
|
|
511
|
+
logit_bias: Optional[list[float]] = None,
|
|
554
512
|
logprobs: Optional[bool] = None,
|
|
555
513
|
max_tokens: Optional[int] = None,
|
|
556
514
|
n: Optional[int] = None,
|
|
557
515
|
presence_penalty: Optional[float] = None,
|
|
558
516
|
response_format: Optional[ChatCompletionInputGrammarType] = None,
|
|
559
517
|
seed: Optional[int] = None,
|
|
560
|
-
stop: Optional[
|
|
518
|
+
stop: Optional[list[str]] = None,
|
|
561
519
|
stream_options: Optional[ChatCompletionInputStreamOptions] = None,
|
|
562
520
|
temperature: Optional[float] = None,
|
|
563
521
|
tool_choice: Optional[Union[ChatCompletionInputToolChoiceClass, "ChatCompletionInputToolChoiceEnum"]] = None,
|
|
564
522
|
tool_prompt: Optional[str] = None,
|
|
565
|
-
tools: Optional[
|
|
523
|
+
tools: Optional[list[ChatCompletionInputTool]] = None,
|
|
566
524
|
top_logprobs: Optional[int] = None,
|
|
567
525
|
top_p: Optional[float] = None,
|
|
568
|
-
extra_body: Optional[
|
|
526
|
+
extra_body: Optional[dict] = None,
|
|
569
527
|
) -> AsyncIterable[ChatCompletionStreamOutput]: ...
|
|
570
528
|
|
|
571
529
|
@overload
|
|
572
530
|
async def chat_completion(
|
|
573
531
|
self,
|
|
574
|
-
messages:
|
|
532
|
+
messages: list[Union[dict, ChatCompletionInputMessage]],
|
|
575
533
|
*,
|
|
576
534
|
model: Optional[str] = None,
|
|
577
535
|
stream: bool = False,
|
|
578
536
|
frequency_penalty: Optional[float] = None,
|
|
579
|
-
logit_bias: Optional[
|
|
537
|
+
logit_bias: Optional[list[float]] = None,
|
|
580
538
|
logprobs: Optional[bool] = None,
|
|
581
539
|
max_tokens: Optional[int] = None,
|
|
582
540
|
n: Optional[int] = None,
|
|
583
541
|
presence_penalty: Optional[float] = None,
|
|
584
542
|
response_format: Optional[ChatCompletionInputGrammarType] = None,
|
|
585
543
|
seed: Optional[int] = None,
|
|
586
|
-
stop: Optional[
|
|
544
|
+
stop: Optional[list[str]] = None,
|
|
587
545
|
stream_options: Optional[ChatCompletionInputStreamOptions] = None,
|
|
588
546
|
temperature: Optional[float] = None,
|
|
589
547
|
tool_choice: Optional[Union[ChatCompletionInputToolChoiceClass, "ChatCompletionInputToolChoiceEnum"]] = None,
|
|
590
548
|
tool_prompt: Optional[str] = None,
|
|
591
|
-
tools: Optional[
|
|
549
|
+
tools: Optional[list[ChatCompletionInputTool]] = None,
|
|
592
550
|
top_logprobs: Optional[int] = None,
|
|
593
551
|
top_p: Optional[float] = None,
|
|
594
|
-
extra_body: Optional[
|
|
552
|
+
extra_body: Optional[dict] = None,
|
|
595
553
|
) -> Union[ChatCompletionOutput, AsyncIterable[ChatCompletionStreamOutput]]: ...
|
|
596
554
|
|
|
597
555
|
async def chat_completion(
|
|
598
556
|
self,
|
|
599
|
-
messages:
|
|
557
|
+
messages: list[Union[dict, ChatCompletionInputMessage]],
|
|
600
558
|
*,
|
|
601
559
|
model: Optional[str] = None,
|
|
602
560
|
stream: bool = False,
|
|
603
561
|
# Parameters from ChatCompletionInput (handled manually)
|
|
604
562
|
frequency_penalty: Optional[float] = None,
|
|
605
|
-
logit_bias: Optional[
|
|
563
|
+
logit_bias: Optional[list[float]] = None,
|
|
606
564
|
logprobs: Optional[bool] = None,
|
|
607
565
|
max_tokens: Optional[int] = None,
|
|
608
566
|
n: Optional[int] = None,
|
|
609
567
|
presence_penalty: Optional[float] = None,
|
|
610
568
|
response_format: Optional[ChatCompletionInputGrammarType] = None,
|
|
611
569
|
seed: Optional[int] = None,
|
|
612
|
-
stop: Optional[
|
|
570
|
+
stop: Optional[list[str]] = None,
|
|
613
571
|
stream_options: Optional[ChatCompletionInputStreamOptions] = None,
|
|
614
572
|
temperature: Optional[float] = None,
|
|
615
573
|
tool_choice: Optional[Union[ChatCompletionInputToolChoiceClass, "ChatCompletionInputToolChoiceEnum"]] = None,
|
|
616
574
|
tool_prompt: Optional[str] = None,
|
|
617
|
-
tools: Optional[
|
|
575
|
+
tools: Optional[list[ChatCompletionInputTool]] = None,
|
|
618
576
|
top_logprobs: Optional[int] = None,
|
|
619
577
|
top_p: Optional[float] = None,
|
|
620
|
-
extra_body: Optional[
|
|
578
|
+
extra_body: Optional[dict] = None,
|
|
621
579
|
) -> Union[ChatCompletionOutput, AsyncIterable[ChatCompletionStreamOutput]]:
|
|
622
580
|
"""
|
|
623
581
|
A method for completing conversations using a specified language model.
|
|
624
582
|
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
for more details about OpenAI's compatibility.
|
|
583
|
+
> [!TIP]
|
|
584
|
+
> The `client.chat_completion` method is aliased as `client.chat.completions.create` for compatibility with OpenAI's client.
|
|
585
|
+
> Inputs and outputs are strictly the same and using either syntax will yield the same results.
|
|
586
|
+
> Check out the [Inference guide](https://huggingface.co/docs/huggingface_hub/guides/inference#openai-compatibility)
|
|
587
|
+
> for more details about OpenAI's compatibility.
|
|
631
588
|
|
|
632
|
-
|
|
633
|
-
|
|
634
|
-
<Tip>
|
|
635
|
-
You can pass provider-specific parameters to the model by using the `extra_body` argument.
|
|
636
|
-
</Tip>
|
|
589
|
+
> [!TIP]
|
|
590
|
+
> You can pass provider-specific parameters to the model by using the `extra_body` argument.
|
|
637
591
|
|
|
638
592
|
Args:
|
|
639
593
|
messages (List of [`ChatCompletionInputMessage`]):
|
|
@@ -647,7 +601,7 @@ class AsyncInferenceClient:
|
|
|
647
601
|
frequency_penalty (`float`, *optional*):
|
|
648
602
|
Penalizes new tokens based on their existing frequency
|
|
649
603
|
in the text so far. Range: [-2.0, 2.0]. Defaults to 0.0.
|
|
650
|
-
logit_bias (`
|
|
604
|
+
logit_bias (`list[float]`, *optional*):
|
|
651
605
|
Adjusts the likelihood of specific tokens appearing in the generated output.
|
|
652
606
|
logprobs (`bool`, *optional*):
|
|
653
607
|
Whether to return log probabilities of the output tokens or not. If true, returns the log
|
|
@@ -663,7 +617,7 @@ class AsyncInferenceClient:
|
|
|
663
617
|
Grammar constraints. Can be either a JSONSchema or a regex.
|
|
664
618
|
seed (Optional[`int`], *optional*):
|
|
665
619
|
Seed for reproducible control flow. Defaults to None.
|
|
666
|
-
stop (`
|
|
620
|
+
stop (`list[str]`, *optional*):
|
|
667
621
|
Up to four strings which trigger the end of the response.
|
|
668
622
|
Defaults to None.
|
|
669
623
|
stream (`bool`, *optional*):
|
|
@@ -687,7 +641,7 @@ class AsyncInferenceClient:
|
|
|
687
641
|
tools (List of [`ChatCompletionInputTool`], *optional*):
|
|
688
642
|
A list of tools the model may call. Currently, only functions are supported as a tool. Use this to
|
|
689
643
|
provide a list of functions the model may generate JSON inputs for.
|
|
690
|
-
extra_body (`
|
|
644
|
+
extra_body (`dict`, *optional*):
|
|
691
645
|
Additional provider-specific parameters to pass to the model. Refer to the provider's documentation
|
|
692
646
|
for supported parameters.
|
|
693
647
|
Returns:
|
|
@@ -699,7 +653,7 @@ class AsyncInferenceClient:
|
|
|
699
653
|
Raises:
|
|
700
654
|
[`InferenceTimeoutError`]:
|
|
701
655
|
If the model is unavailable or the request times out.
|
|
702
|
-
`
|
|
656
|
+
[`HfHubHTTPError`]:
|
|
703
657
|
If the request fails with an HTTP error status code other than HTTP 503.
|
|
704
658
|
|
|
705
659
|
Example:
|
|
@@ -931,7 +885,7 @@ class AsyncInferenceClient:
|
|
|
931
885
|
>>> messages = [
|
|
932
886
|
... {
|
|
933
887
|
... "role": "user",
|
|
934
|
-
... "content": "I saw a puppy a cat and a raccoon during my bike ride in the park. What did I
|
|
888
|
+
... "content": "I saw a puppy a cat and a raccoon during my bike ride in the park. What did I see and when?",
|
|
935
889
|
... },
|
|
936
890
|
... ]
|
|
937
891
|
>>> response_format = {
|
|
@@ -950,20 +904,26 @@ class AsyncInferenceClient:
|
|
|
950
904
|
... messages=messages,
|
|
951
905
|
... response_format=response_format,
|
|
952
906
|
... max_tokens=500,
|
|
953
|
-
)
|
|
907
|
+
... )
|
|
954
908
|
>>> response.choices[0].message.content
|
|
955
909
|
'{\n\n"activity": "bike ride",\n"animals": ["puppy", "cat", "raccoon"],\n"animals_seen": 3,\n"location": "park"}'
|
|
956
910
|
```
|
|
957
911
|
"""
|
|
958
|
-
# Get the provider helper
|
|
959
|
-
provider_helper = get_provider_helper(self.provider, task="conversational")
|
|
960
|
-
|
|
961
912
|
# Since `chat_completion(..., model=xxx)` is also a payload parameter for the server, we need to handle 'model' differently.
|
|
962
913
|
# `self.model` takes precedence over 'model' argument for building URL.
|
|
963
914
|
# `model` takes precedence for payload value.
|
|
964
915
|
model_id_or_url = self.model or model
|
|
965
916
|
payload_model = model or self.model
|
|
966
917
|
|
|
918
|
+
# Get the provider helper
|
|
919
|
+
provider_helper = get_provider_helper(
|
|
920
|
+
self.provider,
|
|
921
|
+
task="conversational",
|
|
922
|
+
model=model_id_or_url
|
|
923
|
+
if model_id_or_url is not None and model_id_or_url.startswith(("http://", "https://"))
|
|
924
|
+
else payload_model,
|
|
925
|
+
)
|
|
926
|
+
|
|
967
927
|
# Prepare the payload
|
|
968
928
|
parameters = {
|
|
969
929
|
"model": payload_model,
|
|
@@ -1013,8 +973,8 @@ class AsyncInferenceClient:
|
|
|
1013
973
|
max_question_len: Optional[int] = None,
|
|
1014
974
|
max_seq_len: Optional[int] = None,
|
|
1015
975
|
top_k: Optional[int] = None,
|
|
1016
|
-
word_boxes: Optional[
|
|
1017
|
-
) ->
|
|
976
|
+
word_boxes: Optional[list[Union[list[float], str]]] = None,
|
|
977
|
+
) -> list[DocumentQuestionAnsweringOutputElement]:
|
|
1018
978
|
"""
|
|
1019
979
|
Answer questions on document images.
|
|
1020
980
|
|
|
@@ -1044,16 +1004,16 @@ class AsyncInferenceClient:
|
|
|
1044
1004
|
top_k (`int`, *optional*):
|
|
1045
1005
|
The number of answers to return (will be chosen by order of likelihood). Can return less than top_k
|
|
1046
1006
|
answers if there are not enough options available within the context.
|
|
1047
|
-
word_boxes (`
|
|
1007
|
+
word_boxes (`list[Union[list[float], str`, *optional*):
|
|
1048
1008
|
A list of words and bounding boxes (normalized 0->1000). If provided, the inference will skip the OCR
|
|
1049
1009
|
step and use the provided bounding boxes instead.
|
|
1050
1010
|
Returns:
|
|
1051
|
-
`
|
|
1011
|
+
`list[DocumentQuestionAnsweringOutputElement]`: a list of [`DocumentQuestionAnsweringOutputElement`] items containing the predicted label, associated probability, word ids, and page number.
|
|
1052
1012
|
|
|
1053
1013
|
Raises:
|
|
1054
1014
|
[`InferenceTimeoutError`]:
|
|
1055
1015
|
If the model is unavailable or the request times out.
|
|
1056
|
-
`
|
|
1016
|
+
[`HfHubHTTPError`]:
|
|
1057
1017
|
If the request fails with an HTTP error status code other than HTTP 503.
|
|
1058
1018
|
|
|
1059
1019
|
|
|
@@ -1066,8 +1026,9 @@ class AsyncInferenceClient:
|
|
|
1066
1026
|
[DocumentQuestionAnsweringOutputElement(answer='us-001', end=16, score=0.9999666213989258, start=16)]
|
|
1067
1027
|
```
|
|
1068
1028
|
"""
|
|
1069
|
-
|
|
1070
|
-
provider_helper = get_provider_helper(self.provider, task="document-question-answering")
|
|
1029
|
+
model_id = model or self.model
|
|
1030
|
+
provider_helper = get_provider_helper(self.provider, task="document-question-answering", model=model_id)
|
|
1031
|
+
inputs: dict[str, Any] = {"question": question, "image": _b64_encode(image)}
|
|
1071
1032
|
request_parameters = provider_helper.prepare_request(
|
|
1072
1033
|
inputs=inputs,
|
|
1073
1034
|
parameters={
|
|
@@ -1081,7 +1042,7 @@ class AsyncInferenceClient:
|
|
|
1081
1042
|
"word_boxes": word_boxes,
|
|
1082
1043
|
},
|
|
1083
1044
|
headers=self.headers,
|
|
1084
|
-
model=
|
|
1045
|
+
model=model_id,
|
|
1085
1046
|
api_key=self.token,
|
|
1086
1047
|
)
|
|
1087
1048
|
response = await self._inner_post(request_parameters)
|
|
@@ -1104,8 +1065,8 @@ class AsyncInferenceClient:
|
|
|
1104
1065
|
text (`str`):
|
|
1105
1066
|
The text to embed.
|
|
1106
1067
|
model (`str`, *optional*):
|
|
1107
|
-
The model to use for the
|
|
1108
|
-
a deployed Inference Endpoint. If not provided, the default recommended
|
|
1068
|
+
The model to use for the feature extraction task. Can be a model ID hosted on the Hugging Face Hub or a URL to
|
|
1069
|
+
a deployed Inference Endpoint. If not provided, the default recommended feature extraction model will be used.
|
|
1109
1070
|
Defaults to None.
|
|
1110
1071
|
normalize (`bool`, *optional*):
|
|
1111
1072
|
Whether to normalize the embeddings or not.
|
|
@@ -1128,7 +1089,7 @@ class AsyncInferenceClient:
|
|
|
1128
1089
|
Raises:
|
|
1129
1090
|
[`InferenceTimeoutError`]:
|
|
1130
1091
|
If the model is unavailable or the request times out.
|
|
1131
|
-
`
|
|
1092
|
+
[`HfHubHTTPError`]:
|
|
1132
1093
|
If the request fails with an HTTP error status code other than HTTP 503.
|
|
1133
1094
|
|
|
1134
1095
|
Example:
|
|
@@ -1143,7 +1104,8 @@ class AsyncInferenceClient:
|
|
|
1143
1104
|
[ 0.28552425, -0.928395 , -1.2077185 , ..., 0.76810825, -2.1069427 , 0.6236161 ]], dtype=float32)
|
|
1144
1105
|
```
|
|
1145
1106
|
"""
|
|
1146
|
-
|
|
1107
|
+
model_id = model or self.model
|
|
1108
|
+
provider_helper = get_provider_helper(self.provider, task="feature-extraction", model=model_id)
|
|
1147
1109
|
request_parameters = provider_helper.prepare_request(
|
|
1148
1110
|
inputs=text,
|
|
1149
1111
|
parameters={
|
|
@@ -1153,21 +1115,21 @@ class AsyncInferenceClient:
|
|
|
1153
1115
|
"truncation_direction": truncation_direction,
|
|
1154
1116
|
},
|
|
1155
1117
|
headers=self.headers,
|
|
1156
|
-
model=
|
|
1118
|
+
model=model_id,
|
|
1157
1119
|
api_key=self.token,
|
|
1158
1120
|
)
|
|
1159
1121
|
response = await self._inner_post(request_parameters)
|
|
1160
1122
|
np = _import_numpy()
|
|
1161
|
-
return np.array(
|
|
1123
|
+
return np.array(provider_helper.get_response(response), dtype="float32")
|
|
1162
1124
|
|
|
1163
1125
|
async def fill_mask(
|
|
1164
1126
|
self,
|
|
1165
1127
|
text: str,
|
|
1166
1128
|
*,
|
|
1167
1129
|
model: Optional[str] = None,
|
|
1168
|
-
targets: Optional[
|
|
1130
|
+
targets: Optional[list[str]] = None,
|
|
1169
1131
|
top_k: Optional[int] = None,
|
|
1170
|
-
) ->
|
|
1132
|
+
) -> list[FillMaskOutputElement]:
|
|
1171
1133
|
"""
|
|
1172
1134
|
Fill in a hole with a missing word (token to be precise).
|
|
1173
1135
|
|
|
@@ -1177,20 +1139,20 @@ class AsyncInferenceClient:
|
|
|
1177
1139
|
model (`str`, *optional*):
|
|
1178
1140
|
The model to use for the fill mask task. Can be a model ID hosted on the Hugging Face Hub or a URL to
|
|
1179
1141
|
a deployed Inference Endpoint. If not provided, the default recommended fill mask model will be used.
|
|
1180
|
-
targets (`
|
|
1142
|
+
targets (`list[str`, *optional*):
|
|
1181
1143
|
When passed, the model will limit the scores to the passed targets instead of looking up in the whole
|
|
1182
1144
|
vocabulary. If the provided targets are not in the model vocab, they will be tokenized and the first
|
|
1183
1145
|
resulting token will be used (with a warning, and that might be slower).
|
|
1184
1146
|
top_k (`int`, *optional*):
|
|
1185
1147
|
When passed, overrides the number of predictions to return.
|
|
1186
1148
|
Returns:
|
|
1187
|
-
`
|
|
1149
|
+
`list[FillMaskOutputElement]`: a list of [`FillMaskOutputElement`] items containing the predicted label, associated
|
|
1188
1150
|
probability, token reference, and completed text.
|
|
1189
1151
|
|
|
1190
1152
|
Raises:
|
|
1191
1153
|
[`InferenceTimeoutError`]:
|
|
1192
1154
|
If the model is unavailable or the request times out.
|
|
1193
|
-
`
|
|
1155
|
+
[`HfHubHTTPError`]:
|
|
1194
1156
|
If the request fails with an HTTP error status code other than HTTP 503.
|
|
1195
1157
|
|
|
1196
1158
|
Example:
|
|
@@ -1205,12 +1167,13 @@ class AsyncInferenceClient:
|
|
|
1205
1167
|
]
|
|
1206
1168
|
```
|
|
1207
1169
|
"""
|
|
1208
|
-
|
|
1170
|
+
model_id = model or self.model
|
|
1171
|
+
provider_helper = get_provider_helper(self.provider, task="fill-mask", model=model_id)
|
|
1209
1172
|
request_parameters = provider_helper.prepare_request(
|
|
1210
1173
|
inputs=text,
|
|
1211
1174
|
parameters={"targets": targets, "top_k": top_k},
|
|
1212
1175
|
headers=self.headers,
|
|
1213
|
-
model=
|
|
1176
|
+
model=model_id,
|
|
1214
1177
|
api_key=self.token,
|
|
1215
1178
|
)
|
|
1216
1179
|
response = await self._inner_post(request_parameters)
|
|
@@ -1223,13 +1186,13 @@ class AsyncInferenceClient:
|
|
|
1223
1186
|
model: Optional[str] = None,
|
|
1224
1187
|
function_to_apply: Optional["ImageClassificationOutputTransform"] = None,
|
|
1225
1188
|
top_k: Optional[int] = None,
|
|
1226
|
-
) ->
|
|
1189
|
+
) -> list[ImageClassificationOutputElement]:
|
|
1227
1190
|
"""
|
|
1228
1191
|
Perform image classification on the given image using the specified model.
|
|
1229
1192
|
|
|
1230
1193
|
Args:
|
|
1231
|
-
image (`Union[str, Path, bytes, BinaryIO]`):
|
|
1232
|
-
The image to classify. It can be raw bytes, an image file,
|
|
1194
|
+
image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`):
|
|
1195
|
+
The image to classify. It can be raw bytes, an image file, a URL to an online image, or a PIL Image.
|
|
1233
1196
|
model (`str`, *optional*):
|
|
1234
1197
|
The model to use for image classification. Can be a model ID hosted on the Hugging Face Hub or a URL to a
|
|
1235
1198
|
deployed Inference Endpoint. If not provided, the default recommended model for image classification will be used.
|
|
@@ -1238,12 +1201,12 @@ class AsyncInferenceClient:
|
|
|
1238
1201
|
top_k (`int`, *optional*):
|
|
1239
1202
|
When specified, limits the output to the top K most probable classes.
|
|
1240
1203
|
Returns:
|
|
1241
|
-
`
|
|
1204
|
+
`list[ImageClassificationOutputElement]`: a list of [`ImageClassificationOutputElement`] items containing the predicted label and associated probability.
|
|
1242
1205
|
|
|
1243
1206
|
Raises:
|
|
1244
1207
|
[`InferenceTimeoutError`]:
|
|
1245
1208
|
If the model is unavailable or the request times out.
|
|
1246
|
-
`
|
|
1209
|
+
[`HfHubHTTPError`]:
|
|
1247
1210
|
If the request fails with an HTTP error status code other than HTTP 503.
|
|
1248
1211
|
|
|
1249
1212
|
Example:
|
|
@@ -1255,12 +1218,13 @@ class AsyncInferenceClient:
|
|
|
1255
1218
|
[ImageClassificationOutputElement(label='Blenheim spaniel', score=0.9779096841812134), ...]
|
|
1256
1219
|
```
|
|
1257
1220
|
"""
|
|
1258
|
-
|
|
1221
|
+
model_id = model or self.model
|
|
1222
|
+
provider_helper = get_provider_helper(self.provider, task="image-classification", model=model_id)
|
|
1259
1223
|
request_parameters = provider_helper.prepare_request(
|
|
1260
1224
|
inputs=image,
|
|
1261
1225
|
parameters={"function_to_apply": function_to_apply, "top_k": top_k},
|
|
1262
1226
|
headers=self.headers,
|
|
1263
|
-
model=
|
|
1227
|
+
model=model_id,
|
|
1264
1228
|
api_key=self.token,
|
|
1265
1229
|
)
|
|
1266
1230
|
response = await self._inner_post(request_parameters)
|
|
@@ -1275,19 +1239,16 @@ class AsyncInferenceClient:
|
|
|
1275
1239
|
overlap_mask_area_threshold: Optional[float] = None,
|
|
1276
1240
|
subtask: Optional["ImageSegmentationSubtask"] = None,
|
|
1277
1241
|
threshold: Optional[float] = None,
|
|
1278
|
-
) ->
|
|
1242
|
+
) -> list[ImageSegmentationOutputElement]:
|
|
1279
1243
|
"""
|
|
1280
1244
|
Perform image segmentation on the given image using the specified model.
|
|
1281
1245
|
|
|
1282
|
-
|
|
1283
|
-
|
|
1284
|
-
You must have `PIL` installed if you want to work with images (`pip install Pillow`).
|
|
1285
|
-
|
|
1286
|
-
</Tip>
|
|
1246
|
+
> [!WARNING]
|
|
1247
|
+
> You must have `PIL` installed if you want to work with images (`pip install Pillow`).
|
|
1287
1248
|
|
|
1288
1249
|
Args:
|
|
1289
|
-
image (`Union[str, Path, bytes, BinaryIO]`):
|
|
1290
|
-
The image to segment. It can be raw bytes, an image file,
|
|
1250
|
+
image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`):
|
|
1251
|
+
The image to segment. It can be raw bytes, an image file, a URL to an online image, or a PIL Image.
|
|
1291
1252
|
model (`str`, *optional*):
|
|
1292
1253
|
The model to use for image segmentation. Can be a model ID hosted on the Hugging Face Hub or a URL to a
|
|
1293
1254
|
deployed Inference Endpoint. If not provided, the default recommended model for image segmentation will be used.
|
|
@@ -1300,12 +1261,12 @@ class AsyncInferenceClient:
|
|
|
1300
1261
|
threshold (`float`, *optional*):
|
|
1301
1262
|
Probability threshold to filter out predicted masks.
|
|
1302
1263
|
Returns:
|
|
1303
|
-
`
|
|
1264
|
+
`list[ImageSegmentationOutputElement]`: A list of [`ImageSegmentationOutputElement`] items containing the segmented masks and associated attributes.
|
|
1304
1265
|
|
|
1305
1266
|
Raises:
|
|
1306
1267
|
[`InferenceTimeoutError`]:
|
|
1307
1268
|
If the model is unavailable or the request times out.
|
|
1308
|
-
`
|
|
1269
|
+
[`HfHubHTTPError`]:
|
|
1309
1270
|
If the request fails with an HTTP error status code other than HTTP 503.
|
|
1310
1271
|
|
|
1311
1272
|
Example:
|
|
@@ -1317,7 +1278,8 @@ class AsyncInferenceClient:
|
|
|
1317
1278
|
[ImageSegmentationOutputElement(score=0.989008, label='LABEL_184', mask=<PIL.PngImagePlugin.PngImageFile image mode=L size=400x300 at 0x7FDD2B129CC0>), ...]
|
|
1318
1279
|
```
|
|
1319
1280
|
"""
|
|
1320
|
-
|
|
1281
|
+
model_id = model or self.model
|
|
1282
|
+
provider_helper = get_provider_helper(self.provider, task="image-segmentation", model=model_id)
|
|
1321
1283
|
request_parameters = provider_helper.prepare_request(
|
|
1322
1284
|
inputs=image,
|
|
1323
1285
|
parameters={
|
|
@@ -1327,10 +1289,11 @@ class AsyncInferenceClient:
|
|
|
1327
1289
|
"threshold": threshold,
|
|
1328
1290
|
},
|
|
1329
1291
|
headers=self.headers,
|
|
1330
|
-
model=
|
|
1292
|
+
model=model_id,
|
|
1331
1293
|
api_key=self.token,
|
|
1332
1294
|
)
|
|
1333
1295
|
response = await self._inner_post(request_parameters)
|
|
1296
|
+
response = provider_helper.get_response(response, request_parameters)
|
|
1334
1297
|
output = ImageSegmentationOutputElement.parse_obj_as_list(response)
|
|
1335
1298
|
for item in output:
|
|
1336
1299
|
item.mask = _b64_to_image(item.mask) # type: ignore [assignment]
|
|
````diff
@@ -1351,15 +1314,12 @@ class AsyncInferenceClient:
 """
 Perform image-to-image translation using a specified model.
 
-<Tip>
-
-You must have `PIL` installed if you want to work with images (`pip install Pillow`).
-
-</Tip>
+> [!WARNING]
+> You must have `PIL` installed if you want to work with images (`pip install Pillow`).
 
 Args:
-image (`Union[str, Path, bytes, BinaryIO]`):
-The input image for translation. It can be raw bytes, an image file,
+image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`):
+The input image for translation. It can be raw bytes, an image file, a URL to an online image, or a PIL Image.
 prompt (`str`, *optional*):
 The text prompt to guide the image generation.
 negative_prompt (`str`, *optional*):
@@ -1374,7 +1334,8 @@ class AsyncInferenceClient:
 The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
 Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
 target_size (`ImageToImageTargetSize`, *optional*):
-The size in
+The size in pixels of the output image. This parameter is only supported by some providers and for
+specific models. It will be ignored when unsupported.
 
 Returns:
 `Image`: The translated image.
@@ -1382,7 +1343,7 @@ class AsyncInferenceClient:
 Raises:
 [`InferenceTimeoutError`]:
 If the model is unavailable or the request times out.
-`
+[`HfHubHTTPError`]:
 If the request fails with an HTTP error status code other than HTTP 503.
 
 Example:
@@ -1393,8 +1354,10 @@ class AsyncInferenceClient:
 >>> image = await client.image_to_image("cat.jpg", prompt="turn the cat into a tiger")
 >>> image.save("tiger.jpg")
 ```
+
 """
-
+model_id = model or self.model
+provider_helper = get_provider_helper(self.provider, task="image-to-image", model=model_id)
 request_parameters = provider_helper.prepare_request(
 inputs=image,
 parameters={
````
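The hunks above update the `image_to_image` docstring (image types, `target_size` clarification, error type); the call itself is unchanged. A minimal sketch mirroring the docstring example; file names are placeholders and Pillow is required:

```python
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient()
    # Returns a PIL.Image.Image (Pillow must be installed).
    image = await client.image_to_image("cat.jpg", prompt="turn the cat into a tiger")
    image.save("tiger.jpg")


asyncio.run(main())
```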
````diff
@@ -1406,22 +1369,103 @@ class AsyncInferenceClient:
 **kwargs,
 },
 headers=self.headers,
-model=
+model=model_id,
 api_key=self.token,
 )
 response = await self._inner_post(request_parameters)
+response = provider_helper.get_response(response, request_parameters)
 return _bytes_to_image(response)
 
+async def image_to_video(
+self,
+image: ContentT,
+*,
+model: Optional[str] = None,
+prompt: Optional[str] = None,
+negative_prompt: Optional[str] = None,
+num_frames: Optional[float] = None,
+num_inference_steps: Optional[int] = None,
+guidance_scale: Optional[float] = None,
+seed: Optional[int] = None,
+target_size: Optional[ImageToVideoTargetSize] = None,
+**kwargs,
+) -> bytes:
+"""
+Generate a video from an input image.
+
+Args:
+image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`):
+The input image to generate a video from. It can be raw bytes, an image file, a URL to an online image, or a PIL Image.
+model (`str`, *optional*):
+The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
+Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
+prompt (`str`, *optional*):
+The text prompt to guide the video generation.
+negative_prompt (`str`, *optional*):
+One prompt to guide what NOT to include in video generation.
+num_frames (`float`, *optional*):
+The num_frames parameter determines how many video frames are generated.
+num_inference_steps (`int`, *optional*):
+For diffusion models. The number of denoising steps. More denoising steps usually lead to a higher
+quality image at the expense of slower inference.
+guidance_scale (`float`, *optional*):
+For diffusion models. A higher guidance scale value encourages the model to generate videos closely
+linked to the text prompt at the expense of lower image quality.
+seed (`int`, *optional*):
+The seed to use for the video generation.
+target_size (`ImageToVideoTargetSize`, *optional*):
+The size in pixel of the output video frames.
+num_inference_steps (`int`, *optional*):
+The number of denoising steps. More denoising steps usually lead to a higher quality video at the
+expense of slower inference.
+seed (`int`, *optional*):
+Seed for the random number generator.
+
+Returns:
+`bytes`: The generated video.
+
+Examples:
+```py
+# Must be run in an async context
+>>> from huggingface_hub import AsyncInferenceClient
+>>> client = AsyncInferenceClient()
+>>> video = await client.image_to_video("cat.jpg", model="Wan-AI/Wan2.2-I2V-A14B", prompt="turn the cat into a tiger")
+>>> with open("tiger.mp4", "wb") as f:
+... f.write(video)
+```
+"""
+model_id = model or self.model
+provider_helper = get_provider_helper(self.provider, task="image-to-video", model=model_id)
+request_parameters = provider_helper.prepare_request(
+inputs=image,
+parameters={
+"prompt": prompt,
+"negative_prompt": negative_prompt,
+"num_frames": num_frames,
+"num_inference_steps": num_inference_steps,
+"guidance_scale": guidance_scale,
+"seed": seed,
+"target_size": target_size,
+**kwargs,
+},
+headers=self.headers,
+model=model_id,
+api_key=self.token,
+)
+response = await self._inner_post(request_parameters)
+response = provider_helper.get_response(response, request_parameters)
+return response
+
 async def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -> ImageToTextOutput:
 """
 Takes an input image and return text.
 
 Models can have very different outputs depending on your use case (image captioning, optical character recognition
-(OCR), Pix2Struct, etc). Please have a look to the model card to learn more about a model's specificities.
+(OCR), Pix2Struct, etc.). Please have a look to the model card to learn more about a model's specificities.
 
 Args:
-image (`Union[str, Path, bytes, BinaryIO]`):
-The input image to caption. It can be raw bytes, an image file,
+image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`):
+The input image to caption. It can be raw bytes, an image file, a URL to an online image, or a PIL Image.
 model (`str`, *optional*):
 The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
 Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
````
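The new `image_to_video` method added above returns raw video bytes rather than a PIL object. A usage sketch based on its docstring example; the extra generation parameters are optional and, as with `target_size`, individual providers may ignore them (the specific values below are placeholders):

```python
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient()
    video = await client.image_to_video(
        "cat.jpg",                       # placeholder input image
        model="Wan-AI/Wan2.2-I2V-A14B",  # model ID taken from the docstring example
        prompt="turn the cat into a tiger",
        num_frames=81,                   # placeholder; may be ignored by some providers
        guidance_scale=5.0,              # placeholder
        seed=42,
    )
    with open("tiger.mp4", "wb") as f:
        f.write(video)


asyncio.run(main())
```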
@@ -1432,7 +1476,7 @@ class AsyncInferenceClient:
|
|
|
1432
1476
|
Raises:
|
|
1433
1477
|
[`InferenceTimeoutError`]:
|
|
1434
1478
|
If the model is unavailable or the request times out.
|
|
1435
|
-
`
|
|
1479
|
+
[`HfHubHTTPError`]:
|
|
1436
1480
|
If the request fails with an HTTP error status code other than HTTP 503.
|
|
1437
1481
|
|
|
1438
1482
|
Example:
|
|
@@ -1446,45 +1490,43 @@ class AsyncInferenceClient:
|
|
|
1446
1490
|
'a dog laying on the grass next to a flower pot '
|
|
1447
1491
|
```
|
|
1448
1492
|
"""
|
|
1449
|
-
|
|
1493
|
+
model_id = model or self.model
|
|
1494
|
+
provider_helper = get_provider_helper(self.provider, task="image-to-text", model=model_id)
|
|
1450
1495
|
request_parameters = provider_helper.prepare_request(
|
|
1451
1496
|
inputs=image,
|
|
1452
1497
|
parameters={},
|
|
1453
1498
|
headers=self.headers,
|
|
1454
|
-
model=
|
|
1499
|
+
model=model_id,
|
|
1455
1500
|
api_key=self.token,
|
|
1456
1501
|
)
|
|
1457
1502
|
response = await self._inner_post(request_parameters)
|
|
1458
|
-
|
|
1459
|
-
return
|
|
1503
|
+
output_list: list[ImageToTextOutput] = ImageToTextOutput.parse_obj_as_list(response)
|
|
1504
|
+
return output_list[0]
|
|
1460
1505
|
|
|
1461
1506
|
async def object_detection(
|
|
1462
1507
|
self, image: ContentT, *, model: Optional[str] = None, threshold: Optional[float] = None
|
|
1463
|
-
) ->
|
|
1508
|
+
) -> list[ObjectDetectionOutputElement]:
|
|
1464
1509
|
"""
|
|
1465
1510
|
Perform object detection on the given image using the specified model.
|
|
1466
1511
|
|
|
1467
|
-
|
|
1468
|
-
|
|
1469
|
-
You must have `PIL` installed if you want to work with images (`pip install Pillow`).
|
|
1470
|
-
|
|
1471
|
-
</Tip>
|
|
1512
|
+
> [!WARNING]
|
|
1513
|
+
> You must have `PIL` installed if you want to work with images (`pip install Pillow`).
|
|
1472
1514
|
|
|
1473
1515
|
Args:
|
|
1474
|
-
image (`Union[str, Path, bytes, BinaryIO]`):
|
|
1475
|
-
The image to detect objects on. It can be raw bytes, an image file,
|
|
1516
|
+
image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`):
|
|
1517
|
+
The image to detect objects on. It can be raw bytes, an image file, a URL to an online image, or a PIL Image.
|
|
1476
1518
|
model (`str`, *optional*):
|
|
1477
1519
|
The model to use for object detection. Can be a model ID hosted on the Hugging Face Hub or a URL to a
|
|
1478
1520
|
deployed Inference Endpoint. If not provided, the default recommended model for object detection (DETR) will be used.
|
|
1479
1521
|
threshold (`float`, *optional*):
|
|
1480
1522
|
The probability necessary to make a prediction.
|
|
1481
1523
|
Returns:
|
|
1482
|
-
`
|
|
1524
|
+
`list[ObjectDetectionOutputElement]`: A list of [`ObjectDetectionOutputElement`] items containing the bounding boxes and associated attributes.
|
|
1483
1525
|
|
|
1484
1526
|
Raises:
|
|
1485
1527
|
[`InferenceTimeoutError`]:
|
|
1486
1528
|
If the model is unavailable or the request times out.
|
|
1487
|
-
`
|
|
1529
|
+
[`HfHubHTTPError`]:
|
|
1488
1530
|
If the request fails with an HTTP error status code other than HTTP 503.
|
|
1489
1531
|
`ValueError`:
|
|
1490
1532
|
If the request output is not a List.
|
|
@@ -1498,12 +1540,13 @@ class AsyncInferenceClient:
|
|
|
1498
1540
|
[ObjectDetectionOutputElement(score=0.9486683011054993, label='person', box=ObjectDetectionBoundingBox(xmin=59, ymin=39, xmax=420, ymax=510)), ...]
|
|
1499
1541
|
```
|
|
1500
1542
|
"""
|
|
1501
|
-
|
|
1543
|
+
model_id = model or self.model
|
|
1544
|
+
provider_helper = get_provider_helper(self.provider, task="object-detection", model=model_id)
|
|
1502
1545
|
request_parameters = provider_helper.prepare_request(
|
|
1503
1546
|
inputs=image,
|
|
1504
1547
|
parameters={"threshold": threshold},
|
|
1505
1548
|
headers=self.headers,
|
|
1506
|
-
model=
|
|
1549
|
+
model=model_id,
|
|
1507
1550
|
api_key=self.token,
|
|
1508
1551
|
)
|
|
1509
1552
|
response = await self._inner_post(request_parameters)
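For reference, a short sketch of `object_detection` against the `list[ObjectDetectionOutputElement]` return type documented above; the image path and threshold are placeholders:

```python
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient()
    detections = await client.object_detection("people.jpg", threshold=0.75)
    for det in detections:
        box = det.box
        print(f"{det.label} ({det.score:.2f}): ({box.xmin}, {box.ymin}) -> ({box.xmax}, {box.ymax})")


asyncio.run(main())
```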
|
|
@@ -1522,7 +1565,7 @@ class AsyncInferenceClient:
|
|
|
1522
1565
|
max_question_len: Optional[int] = None,
|
|
1523
1566
|
max_seq_len: Optional[int] = None,
|
|
1524
1567
|
top_k: Optional[int] = None,
|
|
1525
|
-
) -> Union[QuestionAnsweringOutputElement,
|
|
1568
|
+
) -> Union[QuestionAnsweringOutputElement, list[QuestionAnsweringOutputElement]]:
|
|
1526
1569
|
"""
|
|
1527
1570
|
Retrieve the answer to a question from a given text.
|
|
1528
1571
|
|
|
@@ -1554,13 +1597,13 @@ class AsyncInferenceClient:
|
|
|
1554
1597
|
topk answers if there are not enough options available within the context.
|
|
1555
1598
|
|
|
1556
1599
|
Returns:
|
|
1557
|
-
Union[`QuestionAnsweringOutputElement`,
|
|
1600
|
+
Union[`QuestionAnsweringOutputElement`, list[`QuestionAnsweringOutputElement`]]:
|
|
1558
1601
|
When top_k is 1 or not provided, it returns a single `QuestionAnsweringOutputElement`.
|
|
1559
1602
|
When top_k is greater than 1, it returns a list of `QuestionAnsweringOutputElement`.
|
|
1560
1603
|
Raises:
|
|
1561
1604
|
[`InferenceTimeoutError`]:
|
|
1562
1605
|
If the model is unavailable or the request times out.
|
|
1563
|
-
`
|
|
1606
|
+
[`HfHubHTTPError`]:
|
|
1564
1607
|
If the request fails with an HTTP error status code other than HTTP 503.
|
|
1565
1608
|
|
|
1566
1609
|
Example:
|
|
@@ -1572,9 +1615,10 @@ class AsyncInferenceClient:
|
|
|
1572
1615
|
QuestionAnsweringOutputElement(answer='Clara', end=16, score=0.9326565265655518, start=11)
|
|
1573
1616
|
```
|
|
1574
1617
|
"""
|
|
1575
|
-
|
|
1618
|
+
model_id = model or self.model
|
|
1619
|
+
provider_helper = get_provider_helper(self.provider, task="question-answering", model=model_id)
|
|
1576
1620
|
request_parameters = provider_helper.prepare_request(
|
|
1577
|
-
inputs=
|
|
1621
|
+
inputs={"question": question, "context": context},
|
|
1578
1622
|
parameters={
|
|
1579
1623
|
"align_to_words": align_to_words,
|
|
1580
1624
|
"doc_stride": doc_stride,
|
|
@@ -1584,9 +1628,8 @@ class AsyncInferenceClient:
|
|
|
1584
1628
|
"max_seq_len": max_seq_len,
|
|
1585
1629
|
"top_k": top_k,
|
|
1586
1630
|
},
|
|
1587
|
-
extra_payload={"question": question, "context": context},
|
|
1588
1631
|
headers=self.headers,
|
|
1589
|
-
model=
|
|
1632
|
+
model=model_id,
|
|
1590
1633
|
api_key=self.token,
|
|
1591
1634
|
)
|
|
1592
1635
|
response = await self._inner_post(request_parameters)
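The `question_answering` change above moves the question/context pair from `extra_payload` into the request `inputs`; callers are unaffected. A sketch, assuming the default recommended model:

```python
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient()
    answer = await client.question_answering(
        question="What's my name?",
        context="My name is Clara and I live in Berkeley.",
    )
    print(answer.answer, answer.score)

    # With top_k > 1, a list of QuestionAnsweringOutputElement is returned instead.
    answers = await client.question_answering(
        question="What's my name?",
        context="My name is Clara and I live in Berkeley.",
        top_k=3,
    )
    print([a.answer for a in answers])


asyncio.run(main())
```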
|
|
@@ -1595,28 +1638,28 @@ class AsyncInferenceClient:
|
|
|
1595
1638
|
return output
|
|
1596
1639
|
|
|
1597
1640
|
async def sentence_similarity(
|
|
1598
|
-
self, sentence: str, other_sentences:
|
|
1599
|
-
) ->
|
|
1641
|
+
self, sentence: str, other_sentences: list[str], *, model: Optional[str] = None
|
|
1642
|
+
) -> list[float]:
|
|
1600
1643
|
"""
|
|
1601
1644
|
Compute the semantic similarity between a sentence and a list of other sentences by comparing their embeddings.
|
|
1602
1645
|
|
|
1603
1646
|
Args:
|
|
1604
1647
|
sentence (`str`):
|
|
1605
1648
|
The main sentence to compare to others.
|
|
1606
|
-
other_sentences (`
|
|
1649
|
+
other_sentences (`list[str]`):
|
|
1607
1650
|
The list of sentences to compare to.
|
|
1608
1651
|
model (`str`, *optional*):
|
|
1609
|
-
The model to use for the
|
|
1610
|
-
a deployed Inference Endpoint. If not provided, the default recommended
|
|
1652
|
+
The model to use for the sentence similarity task. Can be a model ID hosted on the Hugging Face Hub or a URL to
|
|
1653
|
+
a deployed Inference Endpoint. If not provided, the default recommended sentence similarity model will be used.
|
|
1611
1654
|
Defaults to None.
|
|
1612
1655
|
|
|
1613
1656
|
Returns:
|
|
1614
|
-
`
|
|
1657
|
+
`list[float]`: The embedding representing the input text.
|
|
1615
1658
|
|
|
1616
1659
|
Raises:
|
|
1617
1660
|
[`InferenceTimeoutError`]:
|
|
1618
1661
|
If the model is unavailable or the request times out.
|
|
1619
|
-
`
|
|
1662
|
+
[`HfHubHTTPError`]:
|
|
1620
1663
|
If the request fails with an HTTP error status code other than HTTP 503.
|
|
1621
1664
|
|
|
1622
1665
|
Example:
|
|
@@ -1635,13 +1678,14 @@ class AsyncInferenceClient:
|
|
|
1635
1678
|
[0.7785726189613342, 0.45876261591911316, 0.2906220555305481]
|
|
1636
1679
|
```
|
|
1637
1680
|
"""
|
|
1638
|
-
|
|
1681
|
+
model_id = model or self.model
|
|
1682
|
+
provider_helper = get_provider_helper(self.provider, task="sentence-similarity", model=model_id)
|
|
1639
1683
|
request_parameters = provider_helper.prepare_request(
|
|
1640
|
-
inputs=
|
|
1684
|
+
inputs={"source_sentence": sentence, "sentences": other_sentences},
|
|
1641
1685
|
parameters={},
|
|
1642
|
-
extra_payload={
|
|
1686
|
+
extra_payload={},
|
|
1643
1687
|
headers=self.headers,
|
|
1644
|
-
model=
|
|
1688
|
+
model=model_id,
|
|
1645
1689
|
api_key=self.token,
|
|
1646
1690
|
)
|
|
1647
1691
|
response = await self._inner_post(request_parameters)
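Likewise for `sentence_similarity`, where the source and comparison sentences now travel as `inputs` and the return type is annotated as `list[float]`. Sketch with arbitrary example sentences:

```python
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient()
    scores = await client.sentence_similarity(
        "Machine learning is so easy.",
        other_sentences=[
            "Deep learning is so straightforward.",
            "This is so difficult, like rocket science.",
            "I can't believe how much I struggled with this.",
        ],
    )
    print(scores)  # one float per entry in other_sentences


asyncio.run(main())
```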
|
|
@@ -1653,7 +1697,7 @@ class AsyncInferenceClient:
|
|
|
1653
1697
|
*,
|
|
1654
1698
|
model: Optional[str] = None,
|
|
1655
1699
|
clean_up_tokenization_spaces: Optional[bool] = None,
|
|
1656
|
-
generate_parameters: Optional[
|
|
1700
|
+
generate_parameters: Optional[dict[str, Any]] = None,
|
|
1657
1701
|
truncation: Optional["SummarizationTruncationStrategy"] = None,
|
|
1658
1702
|
) -> SummarizationOutput:
|
|
1659
1703
|
"""
|
|
@@ -1667,7 +1711,7 @@ class AsyncInferenceClient:
|
|
|
1667
1711
|
Inference Endpoint. If not provided, the default recommended model for summarization will be used.
|
|
1668
1712
|
clean_up_tokenization_spaces (`bool`, *optional*):
|
|
1669
1713
|
Whether to clean up the potential extra spaces in the text output.
|
|
1670
|
-
generate_parameters (`
|
|
1714
|
+
generate_parameters (`dict[str, Any]`, *optional*):
|
|
1671
1715
|
Additional parametrization of the text generation algorithm.
|
|
1672
1716
|
truncation (`"SummarizationTruncationStrategy"`, *optional*):
|
|
1673
1717
|
The truncation strategy to use.
|
|
@@ -1677,7 +1721,7 @@ class AsyncInferenceClient:
|
|
|
1677
1721
|
Raises:
|
|
1678
1722
|
[`InferenceTimeoutError`]:
|
|
1679
1723
|
If the model is unavailable or the request times out.
|
|
1680
|
-
`
|
|
1724
|
+
[`HfHubHTTPError`]:
|
|
1681
1725
|
If the request fails with an HTTP error status code other than HTTP 503.
|
|
1682
1726
|
|
|
1683
1727
|
Example:
|
|
@@ -1694,12 +1738,13 @@ class AsyncInferenceClient:
|
|
|
1694
1738
|
"generate_parameters": generate_parameters,
|
|
1695
1739
|
"truncation": truncation,
|
|
1696
1740
|
}
|
|
1697
|
-
|
|
1741
|
+
model_id = model or self.model
|
|
1742
|
+
provider_helper = get_provider_helper(self.provider, task="summarization", model=model_id)
|
|
1698
1743
|
request_parameters = provider_helper.prepare_request(
|
|
1699
1744
|
inputs=text,
|
|
1700
1745
|
parameters=parameters,
|
|
1701
1746
|
headers=self.headers,
|
|
1702
|
-
model=
|
|
1747
|
+
model=model_id,
|
|
1703
1748
|
api_key=self.token,
|
|
1704
1749
|
)
|
|
1705
1750
|
response = await self._inner_post(request_parameters)
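A sketch of `summarization` with the clarified `generate_parameters` type (`dict[str, Any]`). Reading `summary_text` off the returned `SummarizationOutput`, and the `max_length` key inside `generate_parameters`, are assumptions not shown in this hunk:

```python
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient()
    article = (
        "The tower is 324 metres tall, about the same height as an 81-storey building, "
        "and the tallest structure in Paris."
    )
    output = await client.summarization(
        article,
        generate_parameters={"max_length": 40},  # key name is an assumption, model dependent
    )
    print(output.summary_text)  # attribute from the generated SummarizationOutput type


asyncio.run(main())
```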
|
|
@@ -1707,7 +1752,7 @@ class AsyncInferenceClient:
|
|
|
1707
1752
|
|
|
1708
1753
|
async def table_question_answering(
|
|
1709
1754
|
self,
|
|
1710
|
-
table:
|
|
1755
|
+
table: dict[str, Any],
|
|
1711
1756
|
query: str,
|
|
1712
1757
|
*,
|
|
1713
1758
|
model: Optional[str] = None,
|
|
@@ -1742,7 +1787,7 @@ class AsyncInferenceClient:
|
|
|
1742
1787
|
Raises:
|
|
1743
1788
|
[`InferenceTimeoutError`]:
|
|
1744
1789
|
If the model is unavailable or the request times out.
|
|
1745
|
-
`
|
|
1790
|
+
[`HfHubHTTPError`]:
|
|
1746
1791
|
If the request fails with an HTTP error status code other than HTTP 503.
|
|
1747
1792
|
|
|
1748
1793
|
Example:
|
|
@@ -1756,24 +1801,24 @@ class AsyncInferenceClient:
|
|
|
1756
1801
|
TableQuestionAnsweringOutputElement(answer='36542', coordinates=[[0, 1]], cells=['36542'], aggregator='AVERAGE')
|
|
1757
1802
|
```
|
|
1758
1803
|
"""
|
|
1759
|
-
|
|
1804
|
+
model_id = model or self.model
|
|
1805
|
+
provider_helper = get_provider_helper(self.provider, task="table-question-answering", model=model_id)
|
|
1760
1806
|
request_parameters = provider_helper.prepare_request(
|
|
1761
|
-
inputs=
|
|
1807
|
+
inputs={"query": query, "table": table},
|
|
1762
1808
|
parameters={"model": model, "padding": padding, "sequential": sequential, "truncation": truncation},
|
|
1763
|
-
extra_payload={"query": query, "table": table},
|
|
1764
1809
|
headers=self.headers,
|
|
1765
|
-
model=
|
|
1810
|
+
model=model_id,
|
|
1766
1811
|
api_key=self.token,
|
|
1767
1812
|
)
|
|
1768
1813
|
response = await self._inner_post(request_parameters)
|
|
1769
1814
|
return TableQuestionAnsweringOutputElement.parse_obj_as_instance(response)
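For `table_question_answering`, the query/table pair also moves into `inputs`. A sketch with a toy table; the model ID is a placeholder:

```python
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient()
    table = {
        "Repository": ["Transformers", "Datasets", "Tokenizers"],
        "Stars": ["36542", "4512", "3934"],
    }
    result = await client.table_question_answering(
        table=table,
        query="How many stars does the transformers repository have?",
        model="google/tapas-base-finetuned-wtq",  # placeholder model ID
    )
    print(result.answer, result.aggregator, result.cells)


asyncio.run(main())
```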
|
|
1770
1815
|
|
|
1771
|
-
async def tabular_classification(self, table:
|
|
1816
|
+
async def tabular_classification(self, table: dict[str, Any], *, model: Optional[str] = None) -> list[str]:
|
|
1772
1817
|
"""
|
|
1773
1818
|
Classifying a target category (a group) based on a set of attributes.
|
|
1774
1819
|
|
|
1775
1820
|
Args:
|
|
1776
|
-
table (`
|
|
1821
|
+
table (`dict[str, Any]`):
|
|
1777
1822
|
Set of attributes to classify.
|
|
1778
1823
|
model (`str`, *optional*):
|
|
1779
1824
|
The model to use for the tabular classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to
|
|
@@ -1786,7 +1831,7 @@ class AsyncInferenceClient:
|
|
|
1786
1831
|
Raises:
|
|
1787
1832
|
[`InferenceTimeoutError`]:
|
|
1788
1833
|
If the model is unavailable or the request times out.
|
|
1789
|
-
`
|
|
1834
|
+
[`HfHubHTTPError`]:
|
|
1790
1835
|
If the request fails with an HTTP error status code other than HTTP 503.
|
|
1791
1836
|
|
|
1792
1837
|
Example:
|
|
@@ -1811,24 +1856,25 @@ class AsyncInferenceClient:
|
|
|
1811
1856
|
["5", "5", "5"]
|
|
1812
1857
|
```
|
|
1813
1858
|
"""
|
|
1814
|
-
|
|
1859
|
+
model_id = model or self.model
|
|
1860
|
+
provider_helper = get_provider_helper(self.provider, task="tabular-classification", model=model_id)
|
|
1815
1861
|
request_parameters = provider_helper.prepare_request(
|
|
1816
1862
|
inputs=None,
|
|
1817
1863
|
extra_payload={"table": table},
|
|
1818
1864
|
parameters={},
|
|
1819
1865
|
headers=self.headers,
|
|
1820
|
-
model=
|
|
1866
|
+
model=model_id,
|
|
1821
1867
|
api_key=self.token,
|
|
1822
1868
|
)
|
|
1823
1869
|
response = await self._inner_post(request_parameters)
|
|
1824
1870
|
return _bytes_to_list(response)
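`tabular_classification` (and `tabular_regression` just below) now annotate the table argument as `dict[str, Any]`. A sketch with placeholder feature columns and model ID; values are passed as strings, one list entry per row:

```python
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient()
    table = {
        "fixed_acidity": ["7.4", "7.8", "10.3"],
        "volatile_acidity": ["0.7", "0.88", "0.32"],
        "citric_acid": ["0.0", "0.0", "0.45"],
    }
    # Placeholder model ID; the endpoint must serve a tabular-classification model.
    labels = await client.tabular_classification(table=table, model="julien-c/wine-quality")
    print(labels)  # one predicted class per row, as strings


asyncio.run(main())
```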
|
|
1825
1871
|
|
|
1826
|
-
async def tabular_regression(self, table:
|
|
1872
|
+
async def tabular_regression(self, table: dict[str, Any], *, model: Optional[str] = None) -> list[float]:
|
|
1827
1873
|
"""
|
|
1828
1874
|
Predicting a numerical target value given a set of attributes/features in a table.
|
|
1829
1875
|
|
|
1830
1876
|
Args:
|
|
1831
|
-
table (`
|
|
1877
|
+
table (`dict[str, Any]`):
|
|
1832
1878
|
Set of attributes stored in a table. The attributes used to predict the target can be both numerical and categorical.
|
|
1833
1879
|
model (`str`, *optional*):
|
|
1834
1880
|
The model to use for the tabular regression task. Can be a model ID hosted on the Hugging Face Hub or a URL to
|
|
@@ -1841,7 +1887,7 @@ class AsyncInferenceClient:
|
|
|
1841
1887
|
Raises:
|
|
1842
1888
|
[`InferenceTimeoutError`]:
|
|
1843
1889
|
If the model is unavailable or the request times out.
|
|
1844
|
-
`
|
|
1890
|
+
[`HfHubHTTPError`]:
|
|
1845
1891
|
If the request fails with an HTTP error status code other than HTTP 503.
|
|
1846
1892
|
|
|
1847
1893
|
Example:
|
|
@@ -1861,13 +1907,14 @@ class AsyncInferenceClient:
|
|
|
1861
1907
|
[110, 120, 130]
|
|
1862
1908
|
```
|
|
1863
1909
|
"""
|
|
1864
|
-
|
|
1910
|
+
model_id = model or self.model
|
|
1911
|
+
provider_helper = get_provider_helper(self.provider, task="tabular-regression", model=model_id)
|
|
1865
1912
|
request_parameters = provider_helper.prepare_request(
|
|
1866
1913
|
inputs=None,
|
|
1867
1914
|
parameters={},
|
|
1868
1915
|
extra_payload={"table": table},
|
|
1869
1916
|
headers=self.headers,
|
|
1870
|
-
model=
|
|
1917
|
+
model=model_id,
|
|
1871
1918
|
api_key=self.token,
|
|
1872
1919
|
)
|
|
1873
1920
|
response = await self._inner_post(request_parameters)
|
|
@@ -1880,7 +1927,7 @@ class AsyncInferenceClient:
|
|
|
1880
1927
|
model: Optional[str] = None,
|
|
1881
1928
|
top_k: Optional[int] = None,
|
|
1882
1929
|
function_to_apply: Optional["TextClassificationOutputTransform"] = None,
|
|
1883
|
-
) ->
|
|
1930
|
+
) -> list[TextClassificationOutputElement]:
|
|
1884
1931
|
"""
|
|
1885
1932
|
Perform text classification (e.g. sentiment-analysis) on the given text.
|
|
1886
1933
|
|
|
@@ -1897,12 +1944,12 @@ class AsyncInferenceClient:
|
|
|
1897
1944
|
The function to apply to the model outputs in order to retrieve the scores.
|
|
1898
1945
|
|
|
1899
1946
|
Returns:
|
|
1900
|
-
`
|
|
1947
|
+
`list[TextClassificationOutputElement]`: a list of [`TextClassificationOutputElement`] items containing the predicted label and associated probability.
|
|
1901
1948
|
|
|
1902
1949
|
Raises:
|
|
1903
1950
|
[`InferenceTimeoutError`]:
|
|
1904
1951
|
If the model is unavailable or the request times out.
|
|
1905
|
-
`
|
|
1952
|
+
[`HfHubHTTPError`]:
|
|
1906
1953
|
If the request fails with an HTTP error status code other than HTTP 503.
|
|
1907
1954
|
|
|
1908
1955
|
Example:
|
|
@@ -1917,7 +1964,8 @@ class AsyncInferenceClient:
|
|
|
1917
1964
|
]
|
|
1918
1965
|
```
|
|
1919
1966
|
"""
|
|
1920
|
-
|
|
1967
|
+
model_id = model or self.model
|
|
1968
|
+
provider_helper = get_provider_helper(self.provider, task="text-classification", model=model_id)
|
|
1921
1969
|
request_parameters = provider_helper.prepare_request(
|
|
1922
1970
|
inputs=text,
|
|
1923
1971
|
parameters={
|
|
@@ -1925,33 +1973,33 @@ class AsyncInferenceClient:
|
|
|
1925
1973
|
"top_k": top_k,
|
|
1926
1974
|
},
|
|
1927
1975
|
headers=self.headers,
|
|
1928
|
-
model=
|
|
1976
|
+
model=model_id,
|
|
1929
1977
|
api_key=self.token,
|
|
1930
1978
|
)
|
|
1931
1979
|
response = await self._inner_post(request_parameters)
|
|
1932
1980
|
return TextClassificationOutputElement.parse_obj_as_list(response)[0] # type: ignore [return-value]
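A short sketch of `text_classification` against the `list[TextClassificationOutputElement]` annotation above; the `.label` and `.score` attributes come from the library's generated types, not from this hunk:

```python
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient()
    results = await client.text_classification("I like you. I love you.", top_k=2)
    for item in results:
        print(item.label, item.score)


asyncio.run(main())
```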
|
|
1933
1981
|
|
|
1934
1982
|
@overload
|
|
1935
|
-
async def text_generation(
|
|
1983
|
+
async def text_generation(
|
|
1936
1984
|
self,
|
|
1937
1985
|
prompt: str,
|
|
1938
1986
|
*,
|
|
1939
|
-
details: Literal[
|
|
1940
|
-
stream: Literal[
|
|
1987
|
+
details: Literal[True],
|
|
1988
|
+
stream: Literal[True],
|
|
1941
1989
|
model: Optional[str] = None,
|
|
1942
1990
|
# Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
|
|
1943
1991
|
adapter_id: Optional[str] = None,
|
|
1944
1992
|
best_of: Optional[int] = None,
|
|
1945
1993
|
decoder_input_details: Optional[bool] = None,
|
|
1946
|
-
do_sample: Optional[bool] =
|
|
1994
|
+
do_sample: Optional[bool] = None,
|
|
1947
1995
|
frequency_penalty: Optional[float] = None,
|
|
1948
1996
|
grammar: Optional[TextGenerationInputGrammarType] = None,
|
|
1949
1997
|
max_new_tokens: Optional[int] = None,
|
|
1950
1998
|
repetition_penalty: Optional[float] = None,
|
|
1951
|
-
return_full_text: Optional[bool] =
|
|
1999
|
+
return_full_text: Optional[bool] = None,
|
|
1952
2000
|
seed: Optional[int] = None,
|
|
1953
|
-
stop: Optional[
|
|
1954
|
-
stop_sequences: Optional[
|
|
2001
|
+
stop: Optional[list[str]] = None,
|
|
2002
|
+
stop_sequences: Optional[list[str]] = None, # Deprecated, use `stop` instead
|
|
1955
2003
|
temperature: Optional[float] = None,
|
|
1956
2004
|
top_k: Optional[int] = None,
|
|
1957
2005
|
top_n_tokens: Optional[int] = None,
|
|
@@ -1959,29 +2007,29 @@ class AsyncInferenceClient:
|
|
|
1959
2007
|
truncate: Optional[int] = None,
|
|
1960
2008
|
typical_p: Optional[float] = None,
|
|
1961
2009
|
watermark: Optional[bool] = None,
|
|
1962
|
-
) ->
|
|
2010
|
+
) -> AsyncIterable[TextGenerationStreamOutput]: ...
|
|
1963
2011
|
|
|
1964
2012
|
@overload
|
|
1965
|
-
async def text_generation(
|
|
2013
|
+
async def text_generation(
|
|
1966
2014
|
self,
|
|
1967
2015
|
prompt: str,
|
|
1968
2016
|
*,
|
|
1969
|
-
details: Literal[True]
|
|
1970
|
-
stream: Literal[False] =
|
|
2017
|
+
details: Literal[True],
|
|
2018
|
+
stream: Optional[Literal[False]] = None,
|
|
1971
2019
|
model: Optional[str] = None,
|
|
1972
2020
|
# Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
|
|
1973
2021
|
adapter_id: Optional[str] = None,
|
|
1974
2022
|
best_of: Optional[int] = None,
|
|
1975
2023
|
decoder_input_details: Optional[bool] = None,
|
|
1976
|
-
do_sample: Optional[bool] =
|
|
2024
|
+
do_sample: Optional[bool] = None,
|
|
1977
2025
|
frequency_penalty: Optional[float] = None,
|
|
1978
2026
|
grammar: Optional[TextGenerationInputGrammarType] = None,
|
|
1979
2027
|
max_new_tokens: Optional[int] = None,
|
|
1980
2028
|
repetition_penalty: Optional[float] = None,
|
|
1981
|
-
return_full_text: Optional[bool] =
|
|
2029
|
+
return_full_text: Optional[bool] = None,
|
|
1982
2030
|
seed: Optional[int] = None,
|
|
1983
|
-
stop: Optional[
|
|
1984
|
-
stop_sequences: Optional[
|
|
2031
|
+
stop: Optional[list[str]] = None,
|
|
2032
|
+
stop_sequences: Optional[list[str]] = None, # Deprecated, use `stop` instead
|
|
1985
2033
|
temperature: Optional[float] = None,
|
|
1986
2034
|
top_k: Optional[int] = None,
|
|
1987
2035
|
top_n_tokens: Optional[int] = None,
|
|
@@ -1992,26 +2040,26 @@ class AsyncInferenceClient:
|
|
|
1992
2040
|
) -> TextGenerationOutput: ...
|
|
1993
2041
|
|
|
1994
2042
|
@overload
|
|
1995
|
-
async def text_generation(
|
|
2043
|
+
async def text_generation(
|
|
1996
2044
|
self,
|
|
1997
2045
|
prompt: str,
|
|
1998
2046
|
*,
|
|
1999
|
-
details: Literal[False] =
|
|
2000
|
-
stream: Literal[True]
|
|
2047
|
+
details: Optional[Literal[False]] = None,
|
|
2048
|
+
stream: Literal[True],
|
|
2001
2049
|
model: Optional[str] = None,
|
|
2002
2050
|
# Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
|
|
2003
2051
|
adapter_id: Optional[str] = None,
|
|
2004
2052
|
best_of: Optional[int] = None,
|
|
2005
2053
|
decoder_input_details: Optional[bool] = None,
|
|
2006
|
-
do_sample: Optional[bool] =
|
|
2054
|
+
do_sample: Optional[bool] = None,
|
|
2007
2055
|
frequency_penalty: Optional[float] = None,
|
|
2008
2056
|
grammar: Optional[TextGenerationInputGrammarType] = None,
|
|
2009
2057
|
max_new_tokens: Optional[int] = None,
|
|
2010
2058
|
repetition_penalty: Optional[float] = None,
|
|
2011
|
-
return_full_text: Optional[bool] =
|
|
2059
|
+
return_full_text: Optional[bool] = None, # Manual default value
|
|
2012
2060
|
seed: Optional[int] = None,
|
|
2013
|
-
stop: Optional[
|
|
2014
|
-
stop_sequences: Optional[
|
|
2061
|
+
stop: Optional[list[str]] = None,
|
|
2062
|
+
stop_sequences: Optional[list[str]] = None, # Deprecated, use `stop` instead
|
|
2015
2063
|
temperature: Optional[float] = None,
|
|
2016
2064
|
top_k: Optional[int] = None,
|
|
2017
2065
|
top_n_tokens: Optional[int] = None,
|
|
@@ -2022,26 +2070,26 @@ class AsyncInferenceClient:
|
|
|
2022
2070
|
) -> AsyncIterable[str]: ...
|
|
2023
2071
|
|
|
2024
2072
|
@overload
|
|
2025
|
-
async def text_generation(
|
|
2073
|
+
async def text_generation(
|
|
2026
2074
|
self,
|
|
2027
2075
|
prompt: str,
|
|
2028
2076
|
*,
|
|
2029
|
-
details: Literal[
|
|
2030
|
-
stream: Literal[
|
|
2077
|
+
details: Optional[Literal[False]] = None,
|
|
2078
|
+
stream: Optional[Literal[False]] = None,
|
|
2031
2079
|
model: Optional[str] = None,
|
|
2032
2080
|
# Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
|
|
2033
2081
|
adapter_id: Optional[str] = None,
|
|
2034
2082
|
best_of: Optional[int] = None,
|
|
2035
2083
|
decoder_input_details: Optional[bool] = None,
|
|
2036
|
-
do_sample: Optional[bool] =
|
|
2084
|
+
do_sample: Optional[bool] = None,
|
|
2037
2085
|
frequency_penalty: Optional[float] = None,
|
|
2038
2086
|
grammar: Optional[TextGenerationInputGrammarType] = None,
|
|
2039
2087
|
max_new_tokens: Optional[int] = None,
|
|
2040
2088
|
repetition_penalty: Optional[float] = None,
|
|
2041
|
-
return_full_text: Optional[bool] =
|
|
2089
|
+
return_full_text: Optional[bool] = None,
|
|
2042
2090
|
seed: Optional[int] = None,
|
|
2043
|
-
stop: Optional[
|
|
2044
|
-
stop_sequences: Optional[
|
|
2091
|
+
stop: Optional[list[str]] = None,
|
|
2092
|
+
stop_sequences: Optional[list[str]] = None, # Deprecated, use `stop` instead
|
|
2045
2093
|
temperature: Optional[float] = None,
|
|
2046
2094
|
top_k: Optional[int] = None,
|
|
2047
2095
|
top_n_tokens: Optional[int] = None,
|
|
@@ -2049,29 +2097,29 @@ class AsyncInferenceClient:
|
|
|
2049
2097
|
truncate: Optional[int] = None,
|
|
2050
2098
|
typical_p: Optional[float] = None,
|
|
2051
2099
|
watermark: Optional[bool] = None,
|
|
2052
|
-
) ->
|
|
2100
|
+
) -> str: ...
|
|
2053
2101
|
|
|
2054
2102
|
@overload
|
|
2055
2103
|
async def text_generation(
|
|
2056
2104
|
self,
|
|
2057
2105
|
prompt: str,
|
|
2058
2106
|
*,
|
|
2059
|
-
details:
|
|
2060
|
-
stream: bool =
|
|
2107
|
+
details: Optional[bool] = None,
|
|
2108
|
+
stream: Optional[bool] = None,
|
|
2061
2109
|
model: Optional[str] = None,
|
|
2062
2110
|
# Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
|
|
2063
2111
|
adapter_id: Optional[str] = None,
|
|
2064
2112
|
best_of: Optional[int] = None,
|
|
2065
2113
|
decoder_input_details: Optional[bool] = None,
|
|
2066
|
-
do_sample: Optional[bool] =
|
|
2114
|
+
do_sample: Optional[bool] = None,
|
|
2067
2115
|
frequency_penalty: Optional[float] = None,
|
|
2068
2116
|
grammar: Optional[TextGenerationInputGrammarType] = None,
|
|
2069
2117
|
max_new_tokens: Optional[int] = None,
|
|
2070
2118
|
repetition_penalty: Optional[float] = None,
|
|
2071
|
-
return_full_text: Optional[bool] =
|
|
2119
|
+
return_full_text: Optional[bool] = None,
|
|
2072
2120
|
seed: Optional[int] = None,
|
|
2073
|
-
stop: Optional[
|
|
2074
|
-
stop_sequences: Optional[
|
|
2121
|
+
stop: Optional[list[str]] = None,
|
|
2122
|
+
stop_sequences: Optional[list[str]] = None, # Deprecated, use `stop` instead
|
|
2075
2123
|
temperature: Optional[float] = None,
|
|
2076
2124
|
top_k: Optional[int] = None,
|
|
2077
2125
|
top_n_tokens: Optional[int] = None,
|
|
@@ -2079,28 +2127,28 @@ class AsyncInferenceClient:
|
|
|
2079
2127
|
truncate: Optional[int] = None,
|
|
2080
2128
|
typical_p: Optional[float] = None,
|
|
2081
2129
|
watermark: Optional[bool] = None,
|
|
2082
|
-
) -> Union[TextGenerationOutput, AsyncIterable[TextGenerationStreamOutput]]: ...
|
|
2130
|
+
) -> Union[str, TextGenerationOutput, AsyncIterable[str], AsyncIterable[TextGenerationStreamOutput]]: ...
|
|
2083
2131
|
|
|
2084
2132
|
async def text_generation(
|
|
2085
2133
|
self,
|
|
2086
2134
|
prompt: str,
|
|
2087
2135
|
*,
|
|
2088
|
-
details: bool =
|
|
2089
|
-
stream: bool =
|
|
2136
|
+
details: Optional[bool] = None,
|
|
2137
|
+
stream: Optional[bool] = None,
|
|
2090
2138
|
model: Optional[str] = None,
|
|
2091
2139
|
# Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
|
|
2092
2140
|
adapter_id: Optional[str] = None,
|
|
2093
2141
|
best_of: Optional[int] = None,
|
|
2094
2142
|
decoder_input_details: Optional[bool] = None,
|
|
2095
|
-
do_sample: Optional[bool] =
|
|
2143
|
+
do_sample: Optional[bool] = None,
|
|
2096
2144
|
frequency_penalty: Optional[float] = None,
|
|
2097
2145
|
grammar: Optional[TextGenerationInputGrammarType] = None,
|
|
2098
2146
|
max_new_tokens: Optional[int] = None,
|
|
2099
2147
|
repetition_penalty: Optional[float] = None,
|
|
2100
|
-
return_full_text: Optional[bool] =
|
|
2148
|
+
return_full_text: Optional[bool] = None,
|
|
2101
2149
|
seed: Optional[int] = None,
|
|
2102
|
-
stop: Optional[
|
|
2103
|
-
stop_sequences: Optional[
|
|
2150
|
+
stop: Optional[list[str]] = None,
|
|
2151
|
+
stop_sequences: Optional[list[str]] = None, # Deprecated, use `stop` instead
|
|
2104
2152
|
temperature: Optional[float] = None,
|
|
2105
2153
|
top_k: Optional[int] = None,
|
|
2106
2154
|
top_n_tokens: Optional[int] = None,
|
|
@@ -2112,12 +2160,9 @@ class AsyncInferenceClient:
|
|
|
2112
2160
|
"""
|
|
2113
2161
|
Given a prompt, generate the following text.
|
|
2114
2162
|
|
|
2115
|
-
|
|
2116
|
-
|
|
2117
|
-
|
|
2118
|
-
It accepts a list of messages instead of a single text prompt and handles the chat templating for you.
|
|
2119
|
-
|
|
2120
|
-
</Tip>
|
|
2163
|
+
> [!TIP]
|
|
2164
|
+
> If you want to generate a response from chat messages, you should use the [`InferenceClient.chat_completion`] method.
|
|
2165
|
+
> It accepts a list of messages instead of a single text prompt and handles the chat templating for you.
|
|
2121
2166
|
|
|
2122
2167
|
Args:
|
|
2123
2168
|
prompt (`str`):
|
|
@@ -2156,9 +2201,9 @@ class AsyncInferenceClient:
|
|
|
2156
2201
|
Whether to prepend the prompt to the generated text
|
|
2157
2202
|
seed (`int`, *optional*):
|
|
2158
2203
|
Random sampling seed
|
|
2159
|
-
stop (`
|
|
2204
|
+
stop (`list[str]`, *optional*):
|
|
2160
2205
|
Stop generating tokens if a member of `stop` is generated.
|
|
2161
|
-
stop_sequences (`
|
|
2206
|
+
stop_sequences (`list[str]`, *optional*):
|
|
2162
2207
|
Deprecated argument. Use `stop` instead.
|
|
2163
2208
|
temperature (`float`, *optional*):
|
|
2164
2209
|
The value used to module the logits distribution.
|
|
@@ -2175,14 +2220,14 @@ class AsyncInferenceClient:
|
|
|
2175
2220
|
typical_p (`float`, *optional`):
|
|
2176
2221
|
Typical Decoding mass
|
|
2177
2222
|
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
|
|
2178
|
-
watermark (`bool`, *optional
|
|
2223
|
+
watermark (`bool`, *optional*):
|
|
2179
2224
|
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
|
|
2180
2225
|
|
|
2181
2226
|
Returns:
|
|
2182
|
-
`Union[str, TextGenerationOutput,
|
|
2227
|
+
`Union[str, TextGenerationOutput, AsyncIterable[str], AsyncIterable[TextGenerationStreamOutput]]`:
|
|
2183
2228
|
Generated text returned from the server:
|
|
2184
2229
|
- if `stream=False` and `details=False`, the generated text is returned as a `str` (default)
|
|
2185
|
-
- if `stream=True` and `details=False`, the generated text is returned token by token as a `
|
|
2230
|
+
- if `stream=True` and `details=False`, the generated text is returned token by token as a `AsyncIterable[str]`
|
|
2186
2231
|
- if `stream=False` and `details=True`, the generated text is returned with more details as a [`~huggingface_hub.TextGenerationOutput`]
|
|
2187
2232
|
- if `details=True` and `stream=True`, the generated text is returned token by token as a iterable of [`~huggingface_hub.TextGenerationStreamOutput`]
|
|
2188
2233
|
|
|
@@ -2191,7 +2236,7 @@ class AsyncInferenceClient:
|
|
|
2191
2236
|
If input values are not valid. No HTTP call is made to the server.
|
|
2192
2237
|
[`InferenceTimeoutError`]:
|
|
2193
2238
|
If the model is unavailable or the request times out.
|
|
2194
|
-
`
|
|
2239
|
+
[`HfHubHTTPError`]:
|
|
2195
2240
|
If the request fails with an HTTP error status code other than HTTP 503.
|
|
2196
2241
|
|
|
2197
2242
|
Example:
|
|
@@ -2326,7 +2371,7 @@ class AsyncInferenceClient:
|
|
|
2326
2371
|
"repetition_penalty": repetition_penalty,
|
|
2327
2372
|
"return_full_text": return_full_text,
|
|
2328
2373
|
"seed": seed,
|
|
2329
|
-
"stop": stop
|
|
2374
|
+
"stop": stop,
|
|
2330
2375
|
"temperature": temperature,
|
|
2331
2376
|
"top_k": top_k,
|
|
2332
2377
|
"top_n_tokens": top_n_tokens,
|
|
@@ -2367,29 +2412,30 @@ class AsyncInferenceClient:
|
|
|
2367
2412
|
" Please pass `stream=False` as input."
|
|
2368
2413
|
)
|
|
2369
2414
|
|
|
2370
|
-
|
|
2415
|
+
model_id = model or self.model
|
|
2416
|
+
provider_helper = get_provider_helper(self.provider, task="text-generation", model=model_id)
|
|
2371
2417
|
request_parameters = provider_helper.prepare_request(
|
|
2372
2418
|
inputs=prompt,
|
|
2373
2419
|
parameters=parameters,
|
|
2374
2420
|
extra_payload={"stream": stream},
|
|
2375
2421
|
headers=self.headers,
|
|
2376
|
-
model=
|
|
2422
|
+
model=model_id,
|
|
2377
2423
|
api_key=self.token,
|
|
2378
2424
|
)
|
|
2379
2425
|
|
|
2380
2426
|
# Handle errors separately for more precise error messages
|
|
2381
2427
|
try:
|
|
2382
|
-
bytes_output = await self._inner_post(request_parameters, stream=stream)
|
|
2383
|
-
except
|
|
2384
|
-
match = MODEL_KWARGS_NOT_USED_REGEX.search(e
|
|
2385
|
-
if e
|
|
2428
|
+
bytes_output = await self._inner_post(request_parameters, stream=stream or False)
|
|
2429
|
+
except HfHubHTTPError as e:
|
|
2430
|
+
match = MODEL_KWARGS_NOT_USED_REGEX.search(str(e))
|
|
2431
|
+
if isinstance(e, BadRequestError) and match:
|
|
2386
2432
|
unused_params = [kwarg.strip("' ") for kwarg in match.group(1).split(",")]
|
|
2387
2433
|
_set_unsupported_text_generation_kwargs(model, unused_params)
|
|
2388
2434
|
return await self.text_generation( # type: ignore
|
|
2389
2435
|
prompt=prompt,
|
|
2390
2436
|
details=details,
|
|
2391
2437
|
stream=stream,
|
|
2392
|
-
model=
|
|
2438
|
+
model=model_id,
|
|
2393
2439
|
adapter_id=adapter_id,
|
|
2394
2440
|
best_of=best_of,
|
|
2395
2441
|
decoder_input_details=decoder_input_details,
|
|
@@ -2420,8 +2466,8 @@ class AsyncInferenceClient:
|
|
|
2420
2466
|
# Data can be a single element (dict) or an iterable of dicts where we select the first element of.
|
|
2421
2467
|
if isinstance(data, list):
|
|
2422
2468
|
data = data[0]
|
|
2423
|
-
|
|
2424
|
-
return TextGenerationOutput.parse_obj_as_instance(
|
|
2469
|
+
response = provider_helper.get_response(data, request_parameters)
|
|
2470
|
+
return TextGenerationOutput.parse_obj_as_instance(response) if details else response["generated_text"]
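The `text_generation` overloads above now default `details` and `stream` to `None`, and the implementation returns either the raw string or the parsed `TextGenerationOutput` via `provider_helper.get_response`. A sketch of the plain, detailed and streaming paths; the model ID is a placeholder:

```python
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    # Placeholder model ID; any text-generation model served by your provider works.
    client = AsyncInferenceClient(model="HuggingFaceH4/zephyr-7b-beta")

    # Defaults (stream=None, details=None) behave like stream=False / details=False -> str
    text = await client.text_generation("The huggingface_hub library is ", max_new_tokens=20)
    print(text)

    # details=True -> TextGenerationOutput
    output = await client.text_generation("The huggingface_hub library is ", max_new_tokens=20, details=True)
    print(output.generated_text)

    # stream=True -> async iterable of generated tokens
    async for token in await client.text_generation("The huggingface_hub library is ", max_new_tokens=20, stream=True):
        print(token, end="")


asyncio.run(main())
```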
|
|
2425
2471
|
|
|
2426
2472
|
async def text_to_image(
|
|
2427
2473
|
self,
|
|
@@ -2435,20 +2481,16 @@ class AsyncInferenceClient:
|
|
|
2435
2481
|
model: Optional[str] = None,
|
|
2436
2482
|
scheduler: Optional[str] = None,
|
|
2437
2483
|
seed: Optional[int] = None,
|
|
2438
|
-
extra_body: Optional[
|
|
2484
|
+
extra_body: Optional[dict[str, Any]] = None,
|
|
2439
2485
|
) -> "Image":
|
|
2440
2486
|
"""
|
|
2441
2487
|
Generate an image based on a given text using a specified model.
|
|
2442
2488
|
|
|
2443
|
-
|
|
2444
|
-
|
|
2445
|
-
You must have `PIL` installed if you want to work with images (`pip install Pillow`).
|
|
2489
|
+
> [!WARNING]
|
|
2490
|
+
> You must have `PIL` installed if you want to work with images (`pip install Pillow`).
|
|
2446
2491
|
|
|
2447
|
-
|
|
2448
|
-
|
|
2449
|
-
<Tip>
|
|
2450
|
-
You can pass provider-specific parameters to the model by using the `extra_body` argument.
|
|
2451
|
-
</Tip>
|
|
2492
|
+
> [!TIP]
|
|
2493
|
+
> You can pass provider-specific parameters to the model by using the `extra_body` argument.
|
|
2452
2494
|
|
|
2453
2495
|
Args:
|
|
2454
2496
|
prompt (`str`):
|
|
@@ -2473,7 +2515,7 @@ class AsyncInferenceClient:
|
|
|
2473
2515
|
Override the scheduler with a compatible one.
|
|
2474
2516
|
seed (`int`, *optional*):
|
|
2475
2517
|
Seed for the random number generator.
|
|
2476
|
-
extra_body (`
|
|
2518
|
+
extra_body (`dict[str, Any]`, *optional*):
|
|
2477
2519
|
Additional provider-specific parameters to pass to the model. Refer to the provider's documentation
|
|
2478
2520
|
for supported parameters.
|
|
2479
2521
|
|
|
@@ -2483,7 +2525,7 @@ class AsyncInferenceClient:
|
|
|
2483
2525
|
Raises:
|
|
2484
2526
|
[`InferenceTimeoutError`]:
|
|
2485
2527
|
If the model is unavailable or the request times out.
|
|
2486
|
-
`
|
|
2528
|
+
[`HfHubHTTPError`]:
|
|
2487
2529
|
If the request fails with an HTTP error status code other than HTTP 503.
|
|
2488
2530
|
|
|
2489
2531
|
Example:
|
|
@@ -2544,8 +2586,10 @@ class AsyncInferenceClient:
|
|
|
2544
2586
|
... )
|
|
2545
2587
|
>>> image.save("astronaut.png")
|
|
2546
2588
|
```
|
|
2589
|
+
|
|
2547
2590
|
"""
|
|
2548
|
-
|
|
2591
|
+
model_id = model or self.model
|
|
2592
|
+
provider_helper = get_provider_helper(self.provider, task="text-to-image", model=model_id)
|
|
2549
2593
|
request_parameters = provider_helper.prepare_request(
|
|
2550
2594
|
inputs=prompt,
|
|
2551
2595
|
parameters={
|
|
@@ -2559,11 +2603,11 @@ class AsyncInferenceClient:
|
|
|
2559
2603
|
**(extra_body or {}),
|
|
2560
2604
|
},
|
|
2561
2605
|
headers=self.headers,
|
|
2562
|
-
model=
|
|
2606
|
+
model=model_id,
|
|
2563
2607
|
api_key=self.token,
|
|
2564
2608
|
)
|
|
2565
2609
|
response = await self._inner_post(request_parameters)
|
|
2566
|
-
response = provider_helper.get_response(response)
|
|
2610
|
+
response = provider_helper.get_response(response, request_parameters)
|
|
2567
2611
|
return _bytes_to_image(response)
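A sketch of `text_to_image` with the `dict[str, Any]`-typed `extra_body` used for provider-specific parameters; the provider name and the keys inside `extra_body` are assumptions that depend on your setup:

```python
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient(provider="fal-ai")  # placeholder provider
    image = await client.text_to_image(
        "An astronaut riding a horse on the moon",
        negative_prompt="blurry, low quality",
        seed=42,
        extra_body={"output_quality": 90},  # provider-specific key, assumption
    )
    image.save("astronaut.png")  # PIL image, requires Pillow


asyncio.run(main())
```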
|
|
2568
2612
|
|
|
2569
2613
|
async def text_to_video(
|
|
@@ -2572,18 +2616,17 @@ class AsyncInferenceClient:
|
|
|
2572
2616
|
*,
|
|
2573
2617
|
model: Optional[str] = None,
|
|
2574
2618
|
guidance_scale: Optional[float] = None,
|
|
2575
|
-
negative_prompt: Optional[
|
|
2619
|
+
negative_prompt: Optional[list[str]] = None,
|
|
2576
2620
|
num_frames: Optional[float] = None,
|
|
2577
2621
|
num_inference_steps: Optional[int] = None,
|
|
2578
2622
|
seed: Optional[int] = None,
|
|
2579
|
-
extra_body: Optional[
|
|
2623
|
+
extra_body: Optional[dict[str, Any]] = None,
|
|
2580
2624
|
) -> bytes:
|
|
2581
2625
|
"""
|
|
2582
2626
|
Generate a video based on a given text.
|
|
2583
2627
|
|
|
2584
|
-
|
|
2585
|
-
You can pass provider-specific parameters to the model by using the `extra_body` argument.
|
|
2586
|
-
</Tip>
|
|
2628
|
+
> [!TIP]
|
|
2629
|
+
> You can pass provider-specific parameters to the model by using the `extra_body` argument.
|
|
2587
2630
|
|
|
2588
2631
|
Args:
|
|
2589
2632
|
prompt (`str`):
|
|
@@ -2595,7 +2638,7 @@ class AsyncInferenceClient:
|
|
|
2595
2638
|
guidance_scale (`float`, *optional*):
|
|
2596
2639
|
A higher guidance scale value encourages the model to generate videos closely linked to the text
|
|
2597
2640
|
prompt, but values too high may cause saturation and other artifacts.
|
|
2598
|
-
negative_prompt (`
|
|
2641
|
+
negative_prompt (`list[str]`, *optional*):
|
|
2599
2642
|
One or several prompt to guide what NOT to include in video generation.
|
|
2600
2643
|
num_frames (`float`, *optional*):
|
|
2601
2644
|
The num_frames parameter determines how many video frames are generated.
|
|
@@ -2604,7 +2647,7 @@ class AsyncInferenceClient:
|
|
|
2604
2647
|
expense of slower inference.
|
|
2605
2648
|
seed (`int`, *optional*):
|
|
2606
2649
|
Seed for the random number generator.
|
|
2607
|
-
extra_body (`
|
|
2650
|
+
extra_body (`dict[str, Any]`, *optional*):
|
|
2608
2651
|
Additional provider-specific parameters to pass to the model. Refer to the provider's documentation
|
|
2609
2652
|
for supported parameters.
|
|
2610
2653
|
|
|
@@ -2642,8 +2685,10 @@ class AsyncInferenceClient:
|
|
|
2642
2685
|
>>> with open("cat.mp4", "wb") as file:
|
|
2643
2686
|
... file.write(video)
|
|
2644
2687
|
```
|
|
2688
|
+
|
|
2645
2689
|
"""
|
|
2646
|
-
|
|
2690
|
+
model_id = model or self.model
|
|
2691
|
+
provider_helper = get_provider_helper(self.provider, task="text-to-video", model=model_id)
|
|
2647
2692
|
request_parameters = provider_helper.prepare_request(
|
|
2648
2693
|
inputs=prompt,
|
|
2649
2694
|
parameters={
|
|
@@ -2655,11 +2700,11 @@ class AsyncInferenceClient:
|
|
|
2655
2700
|
**(extra_body or {}),
|
|
2656
2701
|
},
|
|
2657
2702
|
headers=self.headers,
|
|
2658
|
-
model=
|
|
2703
|
+
model=model_id,
|
|
2659
2704
|
api_key=self.token,
|
|
2660
2705
|
)
|
|
2661
2706
|
response = await self._inner_post(request_parameters)
|
|
2662
|
-
response = provider_helper.get_response(response)
|
|
2707
|
+
response = provider_helper.get_response(response, request_parameters)
|
|
2663
2708
|
return response
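And the matching `text_to_video` call, which returns raw bytes; parameter values are placeholders and provider support for each of them varies:

```python
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient()
    video = await client.text_to_video(
        "A young man walking on the street",
        num_frames=81,       # placeholder; provider support varies
        guidance_scale=5.0,  # placeholder
        seed=42,
    )
    with open("street.mp4", "wb") as f:
        f.write(video)


asyncio.run(main())
```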
|
|
2664
2709
|
|
|
2665
2710
|
async def text_to_speech(
|
|
@@ -2683,14 +2728,13 @@ class AsyncInferenceClient:
|
|
|
2683
2728
|
top_p: Optional[float] = None,
|
|
2684
2729
|
typical_p: Optional[float] = None,
|
|
2685
2730
|
use_cache: Optional[bool] = None,
|
|
2686
|
-
extra_body: Optional[
|
|
2731
|
+
extra_body: Optional[dict[str, Any]] = None,
|
|
2687
2732
|
) -> bytes:
|
|
2688
2733
|
"""
|
|
2689
2734
|
Synthesize an audio of a voice pronouncing a given text.
|
|
2690
2735
|
|
|
2691
|
-
|
|
2692
|
-
You can pass provider-specific parameters to the model by using the `extra_body` argument.
|
|
2693
|
-
</Tip>
|
|
2736
|
+
> [!TIP]
|
|
2737
|
+
> You can pass provider-specific parameters to the model by using the `extra_body` argument.
|
|
2694
2738
|
|
|
2695
2739
|
Args:
|
|
2696
2740
|
text (`str`):
|
|
@@ -2745,7 +2789,7 @@ class AsyncInferenceClient:
|
|
|
2745
2789
|
paper](https://hf.co/papers/2202.00666) for more details.
|
|
2746
2790
|
use_cache (`bool`, *optional*):
|
|
2747
2791
|
Whether the model should use the past last key/values attentions to speed up decoding
|
|
2748
|
-
extra_body (`
|
|
2792
|
+
extra_body (`dict[str, Any]`, *optional*):
|
|
2749
2793
|
Additional provider-specific parameters to pass to the model. Refer to the provider's documentation
|
|
2750
2794
|
for supported parameters.
|
|
2751
2795
|
Returns:
|
|
@@ -2754,7 +2798,7 @@ class AsyncInferenceClient:
|
|
|
2754
2798
|
Raises:
|
|
2755
2799
|
[`InferenceTimeoutError`]:
|
|
2756
2800
|
If the model is unavailable or the request times out.
|
|
2757
|
-
`
|
|
2801
|
+
[`HfHubHTTPError`]:
|
|
2758
2802
|
If the request fails with an HTTP error status code other than HTTP 503.
|
|
2759
2803
|
|
|
2760
2804
|
Example:
|
|
@@ -2841,7 +2885,8 @@ class AsyncInferenceClient:
|
|
|
2841
2885
|
... f.write(audio)
|
|
2842
2886
|
```
|
|
2843
2887
|
"""
|
|
2844
|
-
|
|
2888
|
+
model_id = model or self.model
|
|
2889
|
+
provider_helper = get_provider_helper(self.provider, task="text-to-speech", model=model_id)
|
|
2845
2890
|
request_parameters = provider_helper.prepare_request(
|
|
2846
2891
|
inputs=text,
|
|
2847
2892
|
parameters={
|
|
@@ -2864,7 +2909,7 @@ class AsyncInferenceClient:
|
|
|
2864
2909
|
**(extra_body or {}),
|
|
2865
2910
|
},
|
|
2866
2911
|
headers=self.headers,
|
|
2867
|
-
model=
|
|
2912
|
+
model=model_id,
|
|
2868
2913
|
api_key=self.token,
|
|
2869
2914
|
)
|
|
2870
2915
|
response = await self._inner_post(request_parameters)
|
|
@@ -2877,9 +2922,9 @@ class AsyncInferenceClient:
|
|
|
2877
2922
|
*,
|
|
2878
2923
|
model: Optional[str] = None,
|
|
2879
2924
|
aggregation_strategy: Optional["TokenClassificationAggregationStrategy"] = None,
|
|
2880
|
-
ignore_labels: Optional[
|
|
2925
|
+
ignore_labels: Optional[list[str]] = None,
|
|
2881
2926
|
stride: Optional[int] = None,
|
|
2882
|
-
) ->
|
|
2927
|
+
) -> list[TokenClassificationOutputElement]:
|
|
2883
2928
|
"""
|
|
2884
2929
|
Perform token classification on the given text.
|
|
2885
2930
|
Usually used for sentence parsing, either grammatical, or Named Entity Recognition (NER) to understand keywords contained within text.
|
|
@@ -2893,18 +2938,18 @@ class AsyncInferenceClient:
|
|
|
2893
2938
|
Defaults to None.
|
|
2894
2939
|
aggregation_strategy (`"TokenClassificationAggregationStrategy"`, *optional*):
|
|
2895
2940
|
The strategy used to fuse tokens based on model predictions
|
|
2896
|
-
ignore_labels (`
|
|
2941
|
+
ignore_labels (`list[str`, *optional*):
|
|
2897
2942
|
A list of labels to ignore
|
|
2898
2943
|
stride (`int`, *optional*):
|
|
2899
2944
|
The number of overlapping tokens between chunks when splitting the input text.
|
|
2900
2945
|
|
|
2901
2946
|
Returns:
|
|
2902
|
-
`
|
|
2947
|
+
`list[TokenClassificationOutputElement]`: List of [`TokenClassificationOutputElement`] items containing the entity group, confidence score, word, start and end index.
|
|
2903
2948
|
|
|
2904
2949
|
Raises:
|
|
2905
2950
|
[`InferenceTimeoutError`]:
|
|
2906
2951
|
If the model is unavailable or the request times out.
|
|
2907
|
-
`
|
|
2952
|
+
[`HfHubHTTPError`]:
|
|
2908
2953
|
If the request fails with an HTTP error status code other than HTTP 503.
|
|
2909
2954
|
|
|
2910
2955
|
Example:
|
|
@@ -2931,7 +2976,8 @@ class AsyncInferenceClient:
|
|
|
2931
2976
|
]
|
|
2932
2977
|
```
|
|
2933
2978
|
"""
|
|
2934
|
-
|
|
2979
|
+
model_id = model or self.model
|
|
2980
|
+
provider_helper = get_provider_helper(self.provider, task="token-classification", model=model_id)
|
|
2935
2981
|
request_parameters = provider_helper.prepare_request(
|
|
2936
2982
|
inputs=text,
|
|
2937
2983
|
parameters={
|
|
@@ -2940,7 +2986,7 @@ class AsyncInferenceClient:
|
|
|
2940
2986
|
"stride": stride,
|
|
2941
2987
|
},
|
|
2942
2988
|
headers=self.headers,
|
|
2943
|
-
model=
|
|
2989
|
+
model=model_id,
|
|
2944
2990
|
api_key=self.token,
|
|
2945
2991
|
)
|
|
2946
2992
|
response = await self._inner_post(request_parameters)
|
|
@@ -2955,7 +3001,7 @@ class AsyncInferenceClient:
|
|
|
2955
3001
|
tgt_lang: Optional[str] = None,
|
|
2956
3002
|
clean_up_tokenization_spaces: Optional[bool] = None,
|
|
2957
3003
|
truncation: Optional["TranslationTruncationStrategy"] = None,
|
|
2958
|
-
generate_parameters: Optional[
|
|
3004
|
+
generate_parameters: Optional[dict[str, Any]] = None,
|
|
2959
3005
|
) -> TranslationOutput:
|
|
2960
3006
|
"""
|
|
2961
3007
|
Convert text from one language to another.
|
|
@@ -2980,7 +3026,7 @@ class AsyncInferenceClient:
|
|
|
2980
3026
|
Whether to clean up the potential extra spaces in the text output.
|
|
2981
3027
|
truncation (`"TranslationTruncationStrategy"`, *optional*):
|
|
2982
3028
|
The truncation strategy to use.
|
|
2983
|
-
generate_parameters (`
|
|
3029
|
+
generate_parameters (`dict[str, Any]`, *optional*):
|
|
2984
3030
|
Additional parametrization of the text generation algorithm.
|
|
2985
3031
|
|
|
2986
3032
|
Returns:
|
|
@@ -2989,7 +3035,7 @@ class AsyncInferenceClient:
|
|
|
2989
3035
|
Raises:
|
|
2990
3036
|
[`InferenceTimeoutError`]:
|
|
2991
3037
|
If the model is unavailable or the request times out.
|
|
2992
|
-
`
|
|
3038
|
+
[`HfHubHTTPError`]:
|
|
2993
3039
|
If the request fails with an HTTP error status code other than HTTP 503.
|
|
2994
3040
|
`ValueError`:
|
|
2995
3041
|
If only one of the `src_lang` and `tgt_lang` arguments are provided.
|
|
@@ -3018,7 +3064,8 @@ class AsyncInferenceClient:
|
|
|
3018
3064
|
if src_lang is None and tgt_lang is not None:
|
|
3019
3065
|
raise ValueError("You cannot specify `tgt_lang` without specifying `src_lang`.")
|
|
3020
3066
|
|
|
3021
|
-
|
|
3067
|
+
model_id = model or self.model
|
|
3068
|
+
provider_helper = get_provider_helper(self.provider, task="translation", model=model_id)
|
|
3022
3069
|
request_parameters = provider_helper.prepare_request(
|
|
3023
3070
|
inputs=text,
|
|
3024
3071
|
parameters={
|
|
@@ -3029,7 +3076,7 @@ class AsyncInferenceClient:
|
|
|
3029
3076
|
"generate_parameters": generate_parameters,
|
|
3030
3077
|
},
|
|
3031
3078
|
headers=self.headers,
|
|
3032
|
-
model=
|
|
3079
|
+
model=model_id,
|
|
3033
3080
|
api_key=self.token,
|
|
3034
3081
|
)
|
|
3035
3082
|
response = await self._inner_post(request_parameters)
|
|
@@ -3042,13 +3089,13 @@ class AsyncInferenceClient:
|
|
|
3042
3089
|
*,
|
|
3043
3090
|
model: Optional[str] = None,
|
|
3044
3091
|
top_k: Optional[int] = None,
|
|
3045
|
-
) ->
|
|
3092
|
+
) -> list[VisualQuestionAnsweringOutputElement]:
|
|
3046
3093
|
"""
|
|
3047
3094
|
Answering open-ended questions based on an image.
|
|
3048
3095
|
|
|
3049
3096
|
Args:
|
|
3050
|
-
image (`Union[str, Path, bytes, BinaryIO]`):
|
|
3051
|
-
The input image for the context. It can be raw bytes, an image file,
|
|
3097
|
+
image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`):
|
|
3098
|
+
The input image for the context. It can be raw bytes, an image file, a URL to an online image, or a PIL Image.
|
|
3052
3099
|
question (`str`):
|
|
3053
3100
|
Question to be answered.
|
|
3054
3101
|
model (`str`, *optional*):
|
|
@@ -3059,12 +3106,12 @@ class AsyncInferenceClient:
|
|
|
3059
3106
|
The number of answers to return (will be chosen by order of likelihood). Note that we return less than
|
|
3060
3107
|
topk answers if there are not enough options available within the context.
|
|
3061
3108
|
Returns:
|
|
3062
|
-
`
|
|
3109
|
+
`list[VisualQuestionAnsweringOutputElement]`: a list of [`VisualQuestionAnsweringOutputElement`] items containing the predicted label and associated probability.
|
|
3063
3110
|
|
|
3064
3111
|
Raises:
|
|
3065
3112
|
`InferenceTimeoutError`:
|
|
3066
3113
|
If the model is unavailable or the request times out.
|
|
3067
|
-
`
|
|
3114
|
+
[`HfHubHTTPError`]:
|
|
3068
3115
|
If the request fails with an HTTP error status code other than HTTP 503.
|
|
3069
3116
|
|
|
3070
3117
|
Example:
|
|
@@ -3082,44 +3129,37 @@ class AsyncInferenceClient:
         ]
         ```
         """
-        provider_helper = get_provider_helper(self.provider, task="visual-question-answering")
+        model_id = model or self.model
+        provider_helper = get_provider_helper(self.provider, task="visual-question-answering", model=model_id)
         request_parameters = provider_helper.prepare_request(
             inputs=image,
             parameters={"top_k": top_k},
             headers=self.headers,
-            model=model or self.model,
+            model=model_id,
             api_key=self.token,
             extra_payload={"question": question, "image": _b64_encode(image)},
         )
         response = await self._inner_post(request_parameters)
         return VisualQuestionAnsweringOutputElement.parse_obj_as_list(response)

-    @_deprecate_arguments(
-        version="0.30.0",
-        deprecated_args=["labels"],
-        custom_message="`labels`has been renamed to `candidate_labels` and will be removed in huggingface_hub>=0.30.0.",
-    )
     async def zero_shot_classification(
         self,
         text: str,
-        # temporarily keeping it optional for backward compatibility.
-        candidate_labels: List[str] = None,  # type: ignore
+        candidate_labels: list[str],
         *,
         multi_label: Optional[bool] = False,
         hypothesis_template: Optional[str] = None,
         model: Optional[str] = None,
-        # deprecated argument
-        labels: List[str] = None,  # type: ignore
-    ) -> List[ZeroShotClassificationOutputElement]:
+    ) -> list[ZeroShotClassificationOutputElement]:
         """
         Provide as input a text and a set of candidate labels to classify the input text.

         Args:
             text (`str`):
                 The input text to classify.
-            candidate_labels (`List[str]`):
+            candidate_labels (`list[str]`):
                 The set of possible class labels to classify the text into.
-            labels (`List[str]`, *optional*):
+            labels (`list[str]`, *optional*):
                 (deprecated) List of strings. Each string is the verbalization of a possible label for the input text.
             multi_label (`bool`, *optional*):
                 Whether multiple candidate labels can be true. If false, the scores are normalized such that the sum of
@@ -3134,12 +3174,12 @@ class AsyncInferenceClient:


         Returns:
-            `List[ZeroShotClassificationOutputElement]`: List of [`ZeroShotClassificationOutputElement`] items containing the predicted labels and their confidence.
+            `list[ZeroShotClassificationOutputElement]`: List of [`ZeroShotClassificationOutputElement`] items containing the predicted labels and their confidence.

         Raises:
             [`InferenceTimeoutError`]:
                 If the model is unavailable or the request times out.
-            `HTTPError`:
+            [`HfHubHTTPError`]:
                 If the request fails with an HTTP error status code other than HTTP 503.

         Example with `multi_label=False`:
@@ -3190,17 +3230,8 @@ class AsyncInferenceClient:
         ]
         ```
         """
-        # handle deprecation
-        if labels is not None:
-            if candidate_labels is not None:
-                raise ValueError(
-                    "Cannot specify both `labels` and `candidate_labels`. Use `candidate_labels` instead."
-                )
-            candidate_labels = labels
-        elif candidate_labels is None:
-            raise ValueError("Must specify `candidate_labels`")
-
-        provider_helper = get_provider_helper(self.provider, task="zero-shot-classification")
+        model_id = model or self.model
+        provider_helper = get_provider_helper(self.provider, task="zero-shot-classification", model=model_id)
         request_parameters = provider_helper.prepare_request(
             inputs=text,
             parameters={
@@ -3209,7 +3240,7 @@ class AsyncInferenceClient:
                 "hypothesis_template": hypothesis_template,
             },
             headers=self.headers,
-            model=model or self.model,
+            model=model_id,
             api_key=self.token,
         )
         response = await self._inner_post(request_parameters)
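With the deprecated `labels` path removed, `candidate_labels` is now a required argument of `zero_shot_classification`. An illustrative call under the new signature (the input text and labels are examples only):

```python
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient()
    results = await client.zero_shot_classification(
        "I really enjoyed this movie, the plot was fantastic.",
        candidate_labels=["positive", "negative", "neutral"],  # required; the old `labels` alias is gone
        multi_label=False,
    )
    for element in results:
        print(element.label, element.score)


asyncio.run(main())
```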
@@ -3219,31 +3250,25 @@ class AsyncInferenceClient:
             for label, score in zip(output["labels"], output["scores"])
         ]

-    @_deprecate_arguments(
-        version="0.30.0",
-        deprecated_args=["labels"],
-        custom_message="`labels`has been renamed to `candidate_labels` and will be removed in huggingface_hub>=0.30.0.",
-    )
     async def zero_shot_image_classification(
         self,
         image: ContentT,
-        # temporarily keeping it optional for backward compatibility.
-        candidate_labels: List[str] = None,  # type: ignore
+        candidate_labels: list[str],
         *,
         model: Optional[str] = None,
         hypothesis_template: Optional[str] = None,
         # deprecated argument
-        labels: List[str] = None,  # type: ignore
-    ) -> List[ZeroShotImageClassificationOutputElement]:
+        labels: list[str] = None,  # type: ignore
+    ) -> list[ZeroShotImageClassificationOutputElement]:
         """
         Provide input image and text labels to predict text labels for the image.

         Args:
-            image (`Union[str, Path, bytes, BinaryIO]`):
-                The input image to caption. It can be raw bytes, an image file, or a URL to an online image.
-            candidate_labels (`List[str]`):
+            image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`):
+                The input image to caption. It can be raw bytes, an image file, a URL to an online image, or a PIL Image.
+            candidate_labels (`list[str]`):
                 The candidate labels for this image
-            labels (`List[str]`, *optional*):
+            labels (`list[str]`, *optional*):
                 (deprecated) List of string possible labels. There must be at least 2 labels.
             model (`str`, *optional*):
                 The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
@@ -3253,12 +3278,12 @@ class AsyncInferenceClient:
                 replacing the placeholder with the candidate labels.

         Returns:
-            `List[ZeroShotImageClassificationOutputElement]`: List of [`ZeroShotImageClassificationOutputElement`] items containing the predicted labels and their confidence.
+            `list[ZeroShotImageClassificationOutputElement]`: List of [`ZeroShotImageClassificationOutputElement`] items containing the predicted labels and their confidence.

         Raises:
             [`InferenceTimeoutError`]:
                 If the model is unavailable or the request times out.
-            `HTTPError`:
+            [`HfHubHTTPError`]:
                 If the request fails with an HTTP error status code other than HTTP 503.

         Example:
@@ -3274,20 +3299,12 @@ class AsyncInferenceClient:
         [ZeroShotImageClassificationOutputElement(label='dog', score=0.956),...]
         ```
         """
-        # handle deprecation
-        if labels is not None:
-            if candidate_labels is not None:
-                raise ValueError(
-                    "Cannot specify both `labels` and `candidate_labels`. Use `candidate_labels` instead."
-                )
-            candidate_labels = labels
-        elif candidate_labels is None:
-            raise ValueError("Must specify `candidate_labels`")
         # Raise ValueError if input is less than 2 labels
         if len(candidate_labels) < 2:
             raise ValueError("You must specify at least 2 classes to compare.")

-        provider_helper = get_provider_helper(self.provider, task="zero-shot-image-classification")
+        model_id = model or self.model
+        provider_helper = get_provider_helper(self.provider, task="zero-shot-image-classification", model=model_id)
         request_parameters = provider_helper.prepare_request(
             inputs=image,
             parameters={
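The retained check above still requires at least two candidate labels. An illustrative call under the new `zero_shot_image_classification` signature (the image URL and labels are placeholders):

```python
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient()
    results = await client.zero_shot_image_classification(
        "https://example.com/cat.jpg",  # placeholder image URL
        candidate_labels=["cat", "dog"],  # fewer than two labels raises ValueError
    )
    for element in results:
        print(element.label, element.score)


asyncio.run(main())
```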
@@ -3295,150 +3312,13 @@ class AsyncInferenceClient:
                 "hypothesis_template": hypothesis_template,
             },
             headers=self.headers,
-            model=model or self.model,
+            model=model_id,
             api_key=self.token,
         )
         response = await self._inner_post(request_parameters)
         return ZeroShotImageClassificationOutputElement.parse_obj_as_list(response)

-    @_deprecate_method(
-        version="0.33.0",
-        message=(
-            "HF Inference API is getting revamped and will only support warm models in the future (no cold start allowed)."
-            " Use `HfApi.list_models(..., inference_provider='...')` to list warm models per provider."
-        ),
-    )
-    async def list_deployed_models(
-        self, frameworks: Union[None, str, Literal["all"], List[str]] = None
-    ) -> Dict[str, List[str]]:
-        """
-        List models deployed on the HF Serverless Inference API service.
-
-        This helper checks deployed models framework by framework. By default, it will check the 4 main frameworks that
-        are supported and account for 95% of the hosted models. However, if you want a complete list of models you can
-        specify `frameworks="all"` as input. Alternatively, if you know before-hand which framework you are interested
-        in, you can also restrict to search to this one (e.g. `frameworks="text-generation-inference"`). The more
-        frameworks are checked, the more time it will take.
-
-        <Tip warning={true}>
-
-        This endpoint method does not return a live list of all models available for the HF Inference API service.
-        It searches over a cached list of models that were recently available and the list may not be up to date.
-        If you want to know the live status of a specific model, use [`~InferenceClient.get_model_status`].
-
-        </Tip>
-
-        <Tip>
-
-        This endpoint method is mostly useful for discoverability. If you already know which model you want to use and want to
-        check its availability, you can directly use [`~InferenceClient.get_model_status`].
-
-        </Tip>
-
-        Args:
-            frameworks (`Literal["all"]` or `List[str]` or `str`, *optional*):
-                The frameworks to filter on. By default only a subset of the available frameworks are tested. If set to
-                "all", all available frameworks will be tested. It is also possible to provide a single framework or a
-                custom set of frameworks to check.
-
-        Returns:
-            `Dict[str, List[str]]`: A dictionary mapping task names to a sorted list of model IDs.
-
-        Example:
-        ```py
-        # Must be run in an async contextthon
-        >>> from huggingface_hub import AsyncInferenceClient
-        >>> client = AsyncInferenceClient()
-
-        # Discover zero-shot-classification models currently deployed
-        >>> models = await client.list_deployed_models()
-        >>> models["zero-shot-classification"]
-        ['Narsil/deberta-large-mnli-zero-cls', 'facebook/bart-large-mnli', ...]
-
-        # List from only 1 framework
-        >>> await client.list_deployed_models("text-generation-inference")
-        {'text-generation': ['bigcode/starcoder', 'meta-llama/Llama-2-70b-chat-hf', ...], ...}
-        ```
-        """
-        if self.provider != "hf-inference":
-            raise ValueError(f"Listing deployed models is not supported on '{self.provider}'.")
-
-        # Resolve which frameworks to check
-        if frameworks is None:
-            frameworks = constants.MAIN_INFERENCE_API_FRAMEWORKS
-        elif frameworks == "all":
-            frameworks = constants.ALL_INFERENCE_API_FRAMEWORKS
-        elif isinstance(frameworks, str):
-            frameworks = [frameworks]
-        frameworks = list(set(frameworks))
-
-        # Fetch them iteratively
-        models_by_task: Dict[str, List[str]] = {}
-
-        def _unpack_response(framework: str, items: List[Dict]) -> None:
-            for model in items:
-                if framework == "sentence-transformers":
-                    # Model running with the `sentence-transformers` framework can work with both tasks even if not
-                    # branded as such in the API response
-                    models_by_task.setdefault("feature-extraction", []).append(model["model_id"])
-                    models_by_task.setdefault("sentence-similarity", []).append(model["model_id"])
-                else:
-                    models_by_task.setdefault(model["task"], []).append(model["model_id"])
-
-        for framework in frameworks:
-            response = get_session().get(
-                f"{constants.INFERENCE_ENDPOINT}/framework/{framework}", headers=build_hf_headers(token=self.token)
-            )
-            hf_raise_for_status(response)
-            _unpack_response(framework, response.json())
-
-        # Sort alphabetically for discoverability and return
-        for task, models in models_by_task.items():
-            models_by_task[task] = sorted(set(models), key=lambda x: x.lower())
-        return models_by_task
-
-    def _get_client_session(self, headers: Optional[Dict] = None) -> "ClientSession":
-        aiohttp = _import_aiohttp()
-        client_headers = self.headers.copy()
-        if headers is not None:
-            client_headers.update(headers)
-
-        # Return a new aiohttp ClientSession with correct settings.
-        session = aiohttp.ClientSession(
-            headers=client_headers,
-            cookies=self.cookies,
-            timeout=aiohttp.ClientTimeout(self.timeout),
-            trust_env=self.trust_env,
-        )
-
-        # Keep track of sessions to close them later
-        self._sessions[session] = set()
-
-        # Override the `._request` method to register responses to be closed
-        session._wrapped_request = session._request
-
-        async def _request(method, url, **kwargs):
-            response = await session._wrapped_request(method, url, **kwargs)
-            self._sessions[session].add(response)
-            return response
-
-        session._request = _request
-
-        # Override the 'close' method to
-        # 1. close ongoing responses
-        # 2. deregister the session when closed
-        session._close = session.close
-
-        async def close_session():
-            for response in self._sessions[session]:
-                response.close()
-            await session._close()
-            self._sessions.pop(session, None)
-
-        session.close = close_session
-        return session
-
-    async def get_endpoint_info(self, *, model: Optional[str] = None) -> Dict[str, Any]:
+    async def get_endpoint_info(self, *, model: Optional[str] = None) -> dict[str, Any]:
         """
         Get information about the deployed endpoint.

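The removed `list_deployed_models` helper pointed callers to `HfApi.list_models(..., inference_provider='...')` in its own deprecation message. A sketch of that replacement (the provider name and filter values are examples, not mandated by the diff):

```python
from huggingface_hub import HfApi

api = HfApi()
# List warm models served by a given inference provider, filtered by task (example values).
for model in api.list_models(inference_provider="hf-inference", pipeline_tag="text-generation", limit=5):
    print(model.id)
```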
@@ -3451,7 +3331,7 @@ class AsyncInferenceClient:
                 Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.

         Returns:
-            `Dict[str, Any]`: Information about the endpoint.
+            `dict[str, Any]`: Information about the endpoint.

         Example:
         ```py
@@ -3493,17 +3373,16 @@ class AsyncInferenceClient:
         else:
             url = f"{constants.INFERENCE_ENDPOINT}/models/{model}/info"

-        async with self._get_client_session(headers=build_hf_headers(token=self.token)) as client:
-            response = await client.get(url, proxy=self.proxies)
-            response.raise_for_status()
-            return await response.json()
+        client = await self._get_async_client()
+        response = await client.get(url, headers=build_hf_headers(token=self.token))
+        hf_raise_for_status(response)
+        return response.json()

     async def health_check(self, model: Optional[str] = None) -> bool:
         """
         Check the health of the deployed endpoint.

         Health check is only available with Inference Endpoints powered by Text-Generation-Inference (TGI) or Text-Embedding-Inference (TEI).
-        For Inference API, please use [`InferenceClient.get_model_status`] instead.

         Args:
             model (`str`, *optional*):
@@ -3528,77 +3407,12 @@ class AsyncInferenceClient:
         if model is None:
             raise ValueError("Model id not provided.")
         if not model.startswith(("http://", "https://")):
-            raise ValueError(
-                "Model must be an Inference Endpoint URL. For serverless Inference API, please use `InferenceClient.get_model_status`."
-            )
+            raise ValueError("Model must be an Inference Endpoint URL.")
         url = model.rstrip("/") + "/health"

-        async with self._get_client_session(headers=build_hf_headers(token=self.token)) as client:
-            response = await client.get(url, proxy=self.proxies)
-            return response.status == 200
-
-    @_deprecate_method(
-        version="0.33.0",
-        message=(
-            "HF Inference API is getting revamped and will only support warm models in the future (no cold start allowed)."
-            " Use `HfApi.model_info` to get the model status both with HF Inference API and external providers."
-        ),
-    )
-    async def get_model_status(self, model: Optional[str] = None) -> ModelStatus:
-        """
-        Get the status of a model hosted on the HF Inference API.
-
-        <Tip>
-
-        This endpoint is mostly useful when you already know which model you want to use and want to check its
-        availability. If you want to discover already deployed models, you should rather use [`~InferenceClient.list_deployed_models`].
-
-        </Tip>
-
-        Args:
-            model (`str`, *optional*):
-                Identifier of the model for witch the status gonna be checked. If model is not provided,
-                the model associated with this instance of [`InferenceClient`] will be used. Only HF Inference API service can be checked so the
-                identifier cannot be a URL.
-
-
-        Returns:
-            [`ModelStatus`]: An instance of ModelStatus dataclass, containing information,
-                about the state of the model: load, state, compute type and framework.
-
-        Example:
-        ```py
-        # Must be run in an async context
-        >>> from huggingface_hub import AsyncInferenceClient
-        >>> client = AsyncInferenceClient()
-        >>> await client.get_model_status("meta-llama/Meta-Llama-3-8B-Instruct")
-        ModelStatus(loaded=True, state='Loaded', compute_type='gpu', framework='text-generation-inference')
-        ```
-        """
-        if self.provider != "hf-inference":
-            raise ValueError(f"Getting model status is not supported on '{self.provider}'.")
-
-        model = model or self.model
-        if model is None:
-            raise ValueError("Model id not provided.")
-        if model.startswith("https://"):
-            raise NotImplementedError("Model status is only available for Inference API endpoints.")
-        url = f"{constants.INFERENCE_ENDPOINT}/status/{model}"
-
-        async with self._get_client_session(headers=build_hf_headers(token=self.token)) as client:
-            response = await client.get(url, proxy=self.proxies)
-            response.raise_for_status()
-            response_data = await response.json()
-
-            if "error" in response_data:
-                raise ValueError(response_data["error"])
-
-            return ModelStatus(
-                loaded=response_data["loaded"],
-                state=response_data["state"],
-                compute_type=response_data["compute_type"],
-                framework=response_data["framework"],
-            )
+        client = await self._get_async_client()
+        response = await client.get(url, headers=build_hf_headers(token=self.token))
+        return response.status_code == 200

     @property
     def chat(self) -> "ProxyClientChat":