huggingface-hub 0.29.0rc2__py3-none-any.whl → 1.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- huggingface_hub/__init__.py +160 -46
- huggingface_hub/_commit_api.py +277 -71
- huggingface_hub/_commit_scheduler.py +15 -15
- huggingface_hub/_inference_endpoints.py +33 -22
- huggingface_hub/_jobs_api.py +301 -0
- huggingface_hub/_local_folder.py +18 -3
- huggingface_hub/_login.py +31 -63
- huggingface_hub/_oauth.py +460 -0
- huggingface_hub/_snapshot_download.py +241 -81
- huggingface_hub/_space_api.py +18 -10
- huggingface_hub/_tensorboard_logger.py +15 -19
- huggingface_hub/_upload_large_folder.py +196 -76
- huggingface_hub/_webhooks_payload.py +3 -3
- huggingface_hub/_webhooks_server.py +15 -25
- huggingface_hub/{commands → cli}/__init__.py +1 -15
- huggingface_hub/cli/_cli_utils.py +173 -0
- huggingface_hub/cli/auth.py +147 -0
- huggingface_hub/cli/cache.py +841 -0
- huggingface_hub/cli/download.py +189 -0
- huggingface_hub/cli/hf.py +60 -0
- huggingface_hub/cli/inference_endpoints.py +377 -0
- huggingface_hub/cli/jobs.py +772 -0
- huggingface_hub/cli/lfs.py +175 -0
- huggingface_hub/cli/repo.py +315 -0
- huggingface_hub/cli/repo_files.py +94 -0
- huggingface_hub/{commands/env.py → cli/system.py} +10 -13
- huggingface_hub/cli/upload.py +294 -0
- huggingface_hub/cli/upload_large_folder.py +117 -0
- huggingface_hub/community.py +20 -12
- huggingface_hub/constants.py +83 -59
- huggingface_hub/dataclasses.py +609 -0
- huggingface_hub/errors.py +99 -30
- huggingface_hub/fastai_utils.py +30 -41
- huggingface_hub/file_download.py +606 -346
- huggingface_hub/hf_api.py +2445 -1132
- huggingface_hub/hf_file_system.py +269 -152
- huggingface_hub/hub_mixin.py +61 -66
- huggingface_hub/inference/_client.py +501 -630
- huggingface_hub/inference/_common.py +133 -121
- huggingface_hub/inference/_generated/_async_client.py +536 -722
- huggingface_hub/inference/_generated/types/__init__.py +6 -1
- huggingface_hub/inference/_generated/types/automatic_speech_recognition.py +5 -6
- huggingface_hub/inference/_generated/types/base.py +10 -7
- huggingface_hub/inference/_generated/types/chat_completion.py +77 -31
- huggingface_hub/inference/_generated/types/depth_estimation.py +2 -2
- huggingface_hub/inference/_generated/types/document_question_answering.py +2 -2
- huggingface_hub/inference/_generated/types/feature_extraction.py +2 -2
- huggingface_hub/inference/_generated/types/fill_mask.py +2 -2
- huggingface_hub/inference/_generated/types/image_to_image.py +8 -2
- huggingface_hub/inference/_generated/types/image_to_text.py +2 -3
- huggingface_hub/inference/_generated/types/image_to_video.py +60 -0
- huggingface_hub/inference/_generated/types/sentence_similarity.py +3 -3
- huggingface_hub/inference/_generated/types/summarization.py +2 -2
- huggingface_hub/inference/_generated/types/table_question_answering.py +5 -5
- huggingface_hub/inference/_generated/types/text2text_generation.py +2 -2
- huggingface_hub/inference/_generated/types/text_generation.py +11 -11
- huggingface_hub/inference/_generated/types/text_to_audio.py +1 -2
- huggingface_hub/inference/_generated/types/text_to_speech.py +1 -2
- huggingface_hub/inference/_generated/types/text_to_video.py +2 -2
- huggingface_hub/inference/_generated/types/token_classification.py +2 -2
- huggingface_hub/inference/_generated/types/translation.py +2 -2
- huggingface_hub/inference/_generated/types/zero_shot_classification.py +2 -2
- huggingface_hub/inference/_generated/types/zero_shot_image_classification.py +2 -2
- huggingface_hub/inference/_generated/types/zero_shot_object_detection.py +1 -3
- huggingface_hub/inference/_mcp/__init__.py +0 -0
- huggingface_hub/inference/_mcp/_cli_hacks.py +88 -0
- huggingface_hub/inference/_mcp/agent.py +100 -0
- huggingface_hub/inference/_mcp/cli.py +247 -0
- huggingface_hub/inference/_mcp/constants.py +81 -0
- huggingface_hub/inference/_mcp/mcp_client.py +395 -0
- huggingface_hub/inference/_mcp/types.py +45 -0
- huggingface_hub/inference/_mcp/utils.py +128 -0
- huggingface_hub/inference/_providers/__init__.py +149 -20
- huggingface_hub/inference/_providers/_common.py +160 -37
- huggingface_hub/inference/_providers/black_forest_labs.py +12 -9
- huggingface_hub/inference/_providers/cerebras.py +6 -0
- huggingface_hub/inference/_providers/clarifai.py +13 -0
- huggingface_hub/inference/_providers/cohere.py +32 -0
- huggingface_hub/inference/_providers/fal_ai.py +231 -22
- huggingface_hub/inference/_providers/featherless_ai.py +38 -0
- huggingface_hub/inference/_providers/fireworks_ai.py +22 -1
- huggingface_hub/inference/_providers/groq.py +9 -0
- huggingface_hub/inference/_providers/hf_inference.py +143 -33
- huggingface_hub/inference/_providers/hyperbolic.py +9 -5
- huggingface_hub/inference/_providers/nebius.py +47 -5
- huggingface_hub/inference/_providers/novita.py +48 -5
- huggingface_hub/inference/_providers/nscale.py +44 -0
- huggingface_hub/inference/_providers/openai.py +25 -0
- huggingface_hub/inference/_providers/publicai.py +6 -0
- huggingface_hub/inference/_providers/replicate.py +46 -9
- huggingface_hub/inference/_providers/sambanova.py +37 -1
- huggingface_hub/inference/_providers/scaleway.py +28 -0
- huggingface_hub/inference/_providers/together.py +34 -5
- huggingface_hub/inference/_providers/wavespeed.py +138 -0
- huggingface_hub/inference/_providers/zai_org.py +17 -0
- huggingface_hub/lfs.py +33 -100
- huggingface_hub/repocard.py +34 -38
- huggingface_hub/repocard_data.py +79 -59
- huggingface_hub/serialization/__init__.py +0 -1
- huggingface_hub/serialization/_base.py +12 -15
- huggingface_hub/serialization/_dduf.py +8 -8
- huggingface_hub/serialization/_torch.py +69 -69
- huggingface_hub/utils/__init__.py +27 -8
- huggingface_hub/utils/_auth.py +7 -7
- huggingface_hub/utils/_cache_manager.py +92 -147
- huggingface_hub/utils/_chunk_utils.py +2 -3
- huggingface_hub/utils/_deprecation.py +1 -1
- huggingface_hub/utils/_dotenv.py +55 -0
- huggingface_hub/utils/_experimental.py +7 -5
- huggingface_hub/utils/_fixes.py +0 -10
- huggingface_hub/utils/_git_credential.py +5 -5
- huggingface_hub/utils/_headers.py +8 -30
- huggingface_hub/utils/_http.py +399 -237
- huggingface_hub/utils/_pagination.py +6 -6
- huggingface_hub/utils/_parsing.py +98 -0
- huggingface_hub/utils/_paths.py +5 -5
- huggingface_hub/utils/_runtime.py +74 -22
- huggingface_hub/utils/_safetensors.py +21 -21
- huggingface_hub/utils/_subprocess.py +13 -11
- huggingface_hub/utils/_telemetry.py +4 -4
- huggingface_hub/{commands/_cli_utils.py → utils/_terminal.py} +4 -4
- huggingface_hub/utils/_typing.py +25 -5
- huggingface_hub/utils/_validators.py +55 -74
- huggingface_hub/utils/_verification.py +167 -0
- huggingface_hub/utils/_xet.py +235 -0
- huggingface_hub/utils/_xet_progress_reporting.py +162 -0
- huggingface_hub/utils/insecure_hashlib.py +3 -5
- huggingface_hub/utils/logging.py +8 -11
- huggingface_hub/utils/tqdm.py +33 -4
- {huggingface_hub-0.29.0rc2.dist-info → huggingface_hub-1.1.3.dist-info}/METADATA +94 -82
- huggingface_hub-1.1.3.dist-info/RECORD +155 -0
- {huggingface_hub-0.29.0rc2.dist-info → huggingface_hub-1.1.3.dist-info}/WHEEL +1 -1
- huggingface_hub-1.1.3.dist-info/entry_points.txt +6 -0
- huggingface_hub/commands/delete_cache.py +0 -428
- huggingface_hub/commands/download.py +0 -200
- huggingface_hub/commands/huggingface_cli.py +0 -61
- huggingface_hub/commands/lfs.py +0 -200
- huggingface_hub/commands/repo_files.py +0 -128
- huggingface_hub/commands/scan_cache.py +0 -181
- huggingface_hub/commands/tag.py +0 -159
- huggingface_hub/commands/upload.py +0 -299
- huggingface_hub/commands/upload_large_folder.py +0 -129
- huggingface_hub/commands/user.py +0 -304
- huggingface_hub/commands/version.py +0 -37
- huggingface_hub/inference_api.py +0 -217
- huggingface_hub/keras_mixin.py +0 -500
- huggingface_hub/repository.py +0 -1477
- huggingface_hub/serialization/_tensorflow.py +0 -95
- huggingface_hub/utils/_hf_folder.py +0 -68
- huggingface_hub-0.29.0rc2.dist-info/RECORD +0 -131
- huggingface_hub-0.29.0rc2.dist-info/entry_points.txt +0 -6
- {huggingface_hub-0.29.0rc2.dist-info → huggingface_hub-1.1.3.dist-info/licenses}/LICENSE +0 -0
- {huggingface_hub-0.29.0rc2.dist-info → huggingface_hub-1.1.3.dist-info}/top_level.txt +0 -0
@@ -34,18 +34,17 @@
 # - Only the main parameters are publicly exposed. Power users can always read the docs for more options.
 import base64
 import logging
+import os
 import re
 import warnings
-from
-
-from requests import HTTPError
+from contextlib import ExitStack
+from typing import TYPE_CHECKING, Any, Iterable, Literal, Optional, Union, overload

 from huggingface_hub import constants
-from huggingface_hub.errors import BadRequestError, InferenceTimeoutError
+from huggingface_hub.errors import BadRequestError, HfHubHTTPError, InferenceTimeoutError
 from huggingface_hub.inference._common import (
     TASKS_EXPECTING_IMAGES,
     ContentT,
-    ModelStatus,
     RequestParameters,
     _b64_encode,
     _b64_to_image,
@@ -54,7 +53,6 @@ from huggingface_hub.inference._common import (
     _bytes_to_list,
     _get_unsupported_text_generation_kwargs,
     _import_numpy,
-    _open_as_binary,
     _set_unsupported_text_generation_kwargs,
     _stream_chat_completion_response,
     _stream_text_generation_response,
@@ -66,6 +64,7 @@ from huggingface_hub.inference._generated.types import (
     AudioToAudioOutputElement,
     AutomaticSpeechRecognitionOutput,
     ChatCompletionInputGrammarType,
+    ChatCompletionInputMessage,
     ChatCompletionInputStreamOptions,
     ChatCompletionInputTool,
     ChatCompletionInputToolChoiceClass,
@@ -80,6 +79,7 @@ from huggingface_hub.inference._generated.types import (
     ImageSegmentationSubtask,
     ImageToImageTargetSize,
     ImageToTextOutput,
+    ImageToVideoTargetSize,
     ObjectDetectionOutputElement,
     Padding,
     QuestionAnsweringOutputElement,
@@ -100,9 +100,14 @@ from huggingface_hub.inference._generated.types import (
     ZeroShotClassificationOutputElement,
     ZeroShotImageClassificationOutputElement,
 )
-from huggingface_hub.inference._providers import
-from huggingface_hub.utils import
-
+from huggingface_hub.inference._providers import PROVIDER_OR_POLICY_T, get_provider_helper
+from huggingface_hub.utils import (
+    build_hf_headers,
+    get_session,
+    hf_raise_for_status,
+    validate_hf_hub_args,
+)
+from huggingface_hub.utils._auth import get_token


 if TYPE_CHECKING:
@@ -128,28 +133,25 @@ class InferenceClient:
             or a URL to a deployed Inference Endpoint. Defaults to None, in which case a recommended model is
             automatically selected for the task.
             Note: for better compatibility with OpenAI's client, `model` has been aliased as `base_url`. Those 2
-            arguments are mutually exclusive. If
-            path will be appended to the base URL (see the [TGI Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api)
-            documentation for details). When passing a URL as `model`, the client will not append any suffix path to it.
+            arguments are mutually exclusive. If a URL is passed as `model` or `base_url` for chat completion, the `(/v1)/chat/completions` suffix path will be appended to the URL.
         provider (`str`, *optional*):
-            Name of the provider to use for inference. Can be `"black-forest-labs"`, `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"replicate"`, "sambanova"` or `"
-
+            Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"clarifai"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"publicai"`, `"replicate"`, `"sambanova"`, `"scaleway"`, `"together"`, `"wavespeed"` or `"zai-org"`.
+            Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
             If model is a URL or `base_url` is passed, then `provider` is not used.
-        token (`str
+        token (`str`, *optional*):
             Hugging Face token. Will default to the locally saved token if not provided.
-            Pass `token=False` if you don't want to send your token to the server.
             Note: for better compatibility with OpenAI's client, `token` has been aliased as `api_key`. Those 2
             arguments are mutually exclusive and have the exact same behavior.
         timeout (`float`, `optional`):
-            The maximum number of seconds to wait for a response from the server.
-
-        headers (`Dict[str, str]`, `optional`):
+            The maximum number of seconds to wait for a response from the server. Defaults to None, meaning it will loop until the server is available.
+        headers (`dict[str, str]`, `optional`):
             Additional headers to send to the server. By default only the authorization and user-agent headers are sent.
             Values in this dictionary will override the default values.
-
+        bill_to (`str`, `optional`):
+            The billing account to use for the requests. By default the requests are billed on the user's account.
+            Requests can only be billed to an organization the user is a member of, and which has subscribed to Enterprise Hub.
+        cookies (`dict[str, str]`, `optional`):
             Additional cookies to send to the server.
-        proxies (`Any`, `optional`):
-            Proxies to use for the request.
         base_url (`str`, `optional`):
             Base URL to run inference. This is a duplicated argument from `model` to make [`InferenceClient`]
             follow the same pattern as `openai.OpenAI` client. Cannot be used if `model` is set. Defaults to None.
@@ -158,16 +160,17 @@ class InferenceClient:
             follow the same pattern as `openai.OpenAI` client. Cannot be used if `token` is set. Defaults to None.
     """

+    @validate_hf_hub_args
     def __init__(
         self,
         model: Optional[str] = None,
         *,
-        provider: Optional[
+        provider: Optional[PROVIDER_OR_POLICY_T] = None,
         token: Optional[str] = None,
         timeout: Optional[float] = None,
-        headers: Optional[
-        cookies: Optional[
-
+        headers: Optional[dict[str, str]] = None,
+        cookies: Optional[dict[str, str]] = None,
+        bill_to: Optional[str] = None,
         # OpenAI compatibility
         base_url: Optional[str] = None,
         api_key: Optional[str] = None,
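To illustrate the new constructor surface shown above, here is a minimal, hedged usage sketch in Python. The `provider="auto"` policy, `bill_to`, `headers` and `timeout` parameters come straight from the diff; the organization name and custom header below are placeholders.

```python
from huggingface_hub import InferenceClient

# Minimal sketch of the 1.x constructor. "auto" picks the first provider
# available for the model, following the user's provider order settings.
client = InferenceClient(
    provider="auto",
    bill_to="my-enterprise-org",          # placeholder; must be an Enterprise Hub org you belong to
    headers={"X-Custom-Header": "demo"},  # optional extra headers, merged over the defaults
    timeout=60,
)
```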
@@ -185,97 +188,63 @@ class InferenceClient:
                 " `api_key` is an alias for `token` to make the API compatible with OpenAI's client."
                 " It has the exact same behavior as `token`."
             )
+        token = token if token is not None else api_key
+        if isinstance(token, bool):
+            # Legacy behavior: previously it was possible to pass `token=False` to disable authentication. This is not
+            # supported anymore as authentication is required. Better to explicitly raise here rather than risking
+            # sending the locally saved token without the user knowing about it.
+            if token is False:
+                raise ValueError(
+                    "Cannot use `token=False` to disable authentication as authentication is required to run Inference."
+                )
+            warnings.warn(
+                "Using `token=True` to automatically use the locally saved token is deprecated and will be removed in a future release. "
+                "Please use `token=None` instead (default).",
+                DeprecationWarning,
+            )
+            token = get_token()

         self.model: Optional[str] = base_url or model
-        self.token: Optional[str] = token
-
+        self.token: Optional[str] = token
+
+        self.headers = {**headers} if headers is not None else {}
+        if bill_to is not None:
+            if (
+                constants.HUGGINGFACE_HEADER_X_BILL_TO in self.headers
+                and self.headers[constants.HUGGINGFACE_HEADER_X_BILL_TO] != bill_to
+            ):
+                warnings.warn(
+                    f"Overriding existing '{self.headers[constants.HUGGINGFACE_HEADER_X_BILL_TO]}' value in headers with '{bill_to}'.",
+                    UserWarning,
+                )
+            self.headers[constants.HUGGINGFACE_HEADER_X_BILL_TO] = bill_to
+
+            if token is not None and not token.startswith("hf_"):
+                warnings.warn(
+                    "You've provided an external provider's API key, so requests will be billed directly by the provider. "
+                    "The `bill_to` parameter is only applicable for Hugging Face billing and will be ignored.",
+                    UserWarning,
+                )

         # Configure provider
-        self.provider = provider
+        self.provider = provider

         self.cookies = cookies
         self.timeout = timeout
-
+
+        self.exit_stack = ExitStack()

     def __repr__(self):
         return f"<InferenceClient(model='{self.model if self.model else ''}', timeout={self.timeout})>"

-
-
-        self,
-        *,
-        json: Optional[Union[str, Dict, List]] = None,
-        data: Optional[ContentT] = None,
-        model: Optional[str] = None,
-        task: Optional[str] = None,
-        stream: Literal[False] = ...,
-    ) -> bytes: ...
+    def __enter__(self):
+        return self

-
-
-        self,
-        *,
-        json: Optional[Union[str, Dict, List]] = None,
-        data: Optional[ContentT] = None,
-        model: Optional[str] = None,
-        task: Optional[str] = None,
-        stream: Literal[True] = ...,
-    ) -> Iterable[bytes]: ...
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.exit_stack.close()

-
-
-        self,
-        *,
-        json: Optional[Union[str, Dict, List]] = None,
-        data: Optional[ContentT] = None,
-        model: Optional[str] = None,
-        task: Optional[str] = None,
-        stream: bool = False,
-    ) -> Union[bytes, Iterable[bytes]]: ...
-
-    @_deprecate_method(
-        version="0.31.0",
-        message=(
-            "Making direct POST requests to the inference server is not supported anymore. "
-            "Please use task methods instead (e.g. `InferenceClient.chat_completion`). "
-            "If your use case is not supported, please open an issue in https://github.com/huggingface/huggingface_hub."
-        ),
-    )
-    def post(
-        self,
-        *,
-        json: Optional[Union[str, Dict, List]] = None,
-        data: Optional[ContentT] = None,
-        model: Optional[str] = None,
-        task: Optional[str] = None,
-        stream: bool = False,
-    ) -> Union[bytes, Iterable[bytes]]:
-        """
-        Make a POST request to the inference server.
-
-        This method is deprecated and will be removed in the future.
-        Please use task methods instead (e.g. `InferenceClient.chat_completion`).
-        """
-        if self.provider != "hf-inference":
-            raise ValueError(
-                "Cannot use `post` with another provider than `hf-inference`. "
-                "`InferenceClient.post` is deprecated and should not be used directly anymore."
-            )
-        provider_helper = HFInferenceTask(task or "unknown")
-        mapped_model = provider_helper._prepare_mapped_model(model or self.model)
-        url = provider_helper._prepare_url(self.token, mapped_model)  # type: ignore[arg-type]
-        headers = provider_helper._prepare_headers(self.headers, self.token)  # type: ignore[arg-type]
-        return self._inner_post(
-            request_parameters=RequestParameters(
-                url=url,
-                task=task or "unknown",
-                model=model or "unknown",
-                json=json,
-                data=data,
-                headers=headers,
-            ),
-            stream=stream,
-        )
+    def close(self):
+        self.exit_stack.close()

     @overload
     def _inner_post(  # type: ignore[misc]
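The hunk above removes the deprecated `post()` method and adds context-manager support backed by an `ExitStack`. A minimal sketch of the two cleanup styles, assuming a recommended model can be resolved for the task; note that `token=False` now raises a `ValueError` instead of silently disabling authentication.

```python
from huggingface_hub import InferenceClient

# Preferred: the context manager closes the internal ExitStack on exit,
# releasing any streamed HTTP responses opened during inference calls.
with InferenceClient() as client:
    out = client.chat_completion(
        messages=[{"role": "user", "content": "Hello!"}],
        max_tokens=32,
    )

# Equivalent manual form:
client = InferenceClient()
try:
    out = client.chat_completion(messages=[{"role": "user", "content": "Hello!"}], max_tokens=32)
finally:
    client.close()
```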
@@ -285,48 +254,48 @@ class InferenceClient:
     @overload
     def _inner_post(  # type: ignore[misc]
         self, request_parameters: RequestParameters, *, stream: Literal[True] = ...
-    ) -> Iterable[
+    ) -> Iterable[str]: ...

     @overload
     def _inner_post(
         self, request_parameters: RequestParameters, *, stream: bool = False
-    ) -> Union[bytes, Iterable[
+    ) -> Union[bytes, Iterable[str]]: ...

     def _inner_post(
         self, request_parameters: RequestParameters, *, stream: bool = False
-    ) -> Union[bytes, Iterable[
+    ) -> Union[bytes, Iterable[str]]:
         """Make a request to the inference server."""
         # TODO: this should be handled in provider helpers directly
         if request_parameters.task in TASKS_EXPECTING_IMAGES and "Accept" not in request_parameters.headers:
             request_parameters.headers["Accept"] = "image/png"

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        try:
+            response = self.exit_stack.enter_context(
+                get_session().stream(
+                    "POST",
+                    request_parameters.url,
+                    json=request_parameters.json,
+                    content=request_parameters.data,
+                    headers=request_parameters.headers,
+                    cookies=self.cookies,
+                    timeout=self.timeout,
+                )
+            )
+            hf_raise_for_status(response)
+            if stream:
+                return response.iter_lines()
+            else:
+                return response.read()
+        except TimeoutError as error:
+            # Convert any `TimeoutError` to a `InferenceTimeoutError`
+            raise InferenceTimeoutError(f"Inference call timed out: {request_parameters.url}") from error  # type: ignore
+        except HfHubHTTPError as error:
+            if error.response.status_code == 422 and request_parameters.task != "unknown":
+                msg = str(error.args[0])
+                if len(error.response.text) > 0:
+                    msg += f"{os.linesep}{error.response.text}{os.linesep}"
+                error.args = (msg,) + error.args[1:]
+            raise

     def audio_classification(
         self,
@@ -335,7 +304,7 @@ class InferenceClient:
         model: Optional[str] = None,
         top_k: Optional[int] = None,
         function_to_apply: Optional["AudioClassificationOutputTransform"] = None,
-    ) ->
+    ) -> list[AudioClassificationOutputElement]:
         """
         Perform audio classification on the provided audio content.

@@ -353,12 +322,12 @@ class InferenceClient:
                 The function to apply to the model outputs in order to retrieve the scores.

         Returns:
-            `
+            `list[AudioClassificationOutputElement]`: List of [`AudioClassificationOutputElement`] items containing the predicted labels and their confidence.

         Raises:
             [`InferenceTimeoutError`]:
                 If the model is unavailable or the request times out.
-            `
+            [`HfHubHTTPError`]:
                 If the request fails with an HTTP error status code other than HTTP 503.

         Example:
@@ -373,12 +342,13 @@ class InferenceClient:
         ]
         ```
         """
-
+        model_id = model or self.model
+        provider_helper = get_provider_helper(self.provider, task="audio-classification", model=model_id)
         request_parameters = provider_helper.prepare_request(
             inputs=audio,
             parameters={"function_to_apply": function_to_apply, "top_k": top_k},
             headers=self.headers,
-            model=
+            model=model_id,
             api_key=self.token,
         )
         response = self._inner_post(request_parameters)
@@ -389,7 +359,7 @@ class InferenceClient:
         audio: ContentT,
         *,
         model: Optional[str] = None,
-    ) ->
+    ) -> list[AudioToAudioOutputElement]:
         """
         Performs multiple tasks related to audio-to-audio depending on the model (eg: speech enhancement, source separation).

@@ -403,12 +373,12 @@ class InferenceClient:
                 audio_to_audio will be used.

         Returns:
-            `
+            `list[AudioToAudioOutputElement]`: A list of [`AudioToAudioOutputElement`] items containing audios label, content-type, and audio content in blob.

         Raises:
             `InferenceTimeoutError`:
                 If the model is unavailable or the request times out.
-            `
+            [`HfHubHTTPError`]:
                 If the request fails with an HTTP error status code other than HTTP 503.

         Example:
@@ -421,12 +391,13 @@ class InferenceClient:
                     f.write(item.blob)
         ```
         """
-
+        model_id = model or self.model
+        provider_helper = get_provider_helper(self.provider, task="audio-to-audio", model=model_id)
         request_parameters = provider_helper.prepare_request(
             inputs=audio,
             parameters={},
             headers=self.headers,
-            model=
+            model=model_id,
             api_key=self.token,
         )
         response = self._inner_post(request_parameters)
@@ -440,7 +411,7 @@ class InferenceClient:
         audio: ContentT,
         *,
         model: Optional[str] = None,
-        extra_body: Optional[
+        extra_body: Optional[dict] = None,
     ) -> AutomaticSpeechRecognitionOutput:
         """
         Perform automatic speech recognition (ASR or audio-to-text) on the given audio content.
@@ -451,7 +422,7 @@ class InferenceClient:
             model (`str`, *optional*):
                 The model to use for ASR. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
                 Inference Endpoint. If not provided, the default recommended model for ASR will be used.
-            extra_body (`
+            extra_body (`dict`, *optional*):
                 Additional provider-specific parameters to pass to the model. Refer to the provider's documentation
                 for supported parameters.
         Returns:
@@ -460,7 +431,7 @@ class InferenceClient:
         Raises:
             [`InferenceTimeoutError`]:
                 If the model is unavailable or the request times out.
-            `
+            [`HfHubHTTPError`]:
                 If the request fails with an HTTP error status code other than HTTP 503.

         Example:
@@ -471,12 +442,13 @@ class InferenceClient:
         "hello world"
         ```
         """
-
+        model_id = model or self.model
+        provider_helper = get_provider_helper(self.provider, task="automatic-speech-recognition", model=model_id)
         request_parameters = provider_helper.prepare_request(
             inputs=audio,
             parameters={**(extra_body or {})},
             headers=self.headers,
-            model=
+            model=model_id,
             api_key=self.token,
         )
         response = self._inner_post(request_parameters)
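Each task method above now resolves its provider helper per call from `model or self.model`, so the per-call `model` argument takes effect even when the client was built without a default. A small hedged sketch for ASR; the `extra_body` keys are hypothetical and provider-specific.

```python
from huggingface_hub import InferenceClient

client = InferenceClient()  # no default model configured

output = client.automatic_speech_recognition(
    "sample.flac",                    # placeholder path to a local audio file
    model="openai/whisper-large-v3",  # resolved per call via get_provider_helper(..., model=model_id)
    extra_body={"language": "en"},    # hypothetical provider-specific parameter
)
print(output.text)
```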
@@ -485,121 +457,117 @@ class InferenceClient:
     @overload
     def chat_completion(  # type: ignore
         self,
-        messages:
+        messages: list[Union[dict, ChatCompletionInputMessage]],
         *,
         model: Optional[str] = None,
         stream: Literal[False] = False,
         frequency_penalty: Optional[float] = None,
-        logit_bias: Optional[
+        logit_bias: Optional[list[float]] = None,
         logprobs: Optional[bool] = None,
         max_tokens: Optional[int] = None,
         n: Optional[int] = None,
         presence_penalty: Optional[float] = None,
         response_format: Optional[ChatCompletionInputGrammarType] = None,
         seed: Optional[int] = None,
-        stop: Optional[
+        stop: Optional[list[str]] = None,
         stream_options: Optional[ChatCompletionInputStreamOptions] = None,
         temperature: Optional[float] = None,
         tool_choice: Optional[Union[ChatCompletionInputToolChoiceClass, "ChatCompletionInputToolChoiceEnum"]] = None,
         tool_prompt: Optional[str] = None,
-        tools: Optional[
+        tools: Optional[list[ChatCompletionInputTool]] = None,
         top_logprobs: Optional[int] = None,
         top_p: Optional[float] = None,
-        extra_body: Optional[
+        extra_body: Optional[dict] = None,
     ) -> ChatCompletionOutput: ...

     @overload
     def chat_completion(  # type: ignore
         self,
-        messages:
+        messages: list[Union[dict, ChatCompletionInputMessage]],
         *,
         model: Optional[str] = None,
         stream: Literal[True] = True,
         frequency_penalty: Optional[float] = None,
-        logit_bias: Optional[
+        logit_bias: Optional[list[float]] = None,
         logprobs: Optional[bool] = None,
         max_tokens: Optional[int] = None,
         n: Optional[int] = None,
         presence_penalty: Optional[float] = None,
         response_format: Optional[ChatCompletionInputGrammarType] = None,
         seed: Optional[int] = None,
-        stop: Optional[
+        stop: Optional[list[str]] = None,
         stream_options: Optional[ChatCompletionInputStreamOptions] = None,
         temperature: Optional[float] = None,
         tool_choice: Optional[Union[ChatCompletionInputToolChoiceClass, "ChatCompletionInputToolChoiceEnum"]] = None,
         tool_prompt: Optional[str] = None,
-        tools: Optional[
+        tools: Optional[list[ChatCompletionInputTool]] = None,
         top_logprobs: Optional[int] = None,
         top_p: Optional[float] = None,
-        extra_body: Optional[
+        extra_body: Optional[dict] = None,
     ) -> Iterable[ChatCompletionStreamOutput]: ...

     @overload
     def chat_completion(
         self,
-        messages:
+        messages: list[Union[dict, ChatCompletionInputMessage]],
         *,
         model: Optional[str] = None,
         stream: bool = False,
         frequency_penalty: Optional[float] = None,
-        logit_bias: Optional[
+        logit_bias: Optional[list[float]] = None,
         logprobs: Optional[bool] = None,
         max_tokens: Optional[int] = None,
         n: Optional[int] = None,
         presence_penalty: Optional[float] = None,
         response_format: Optional[ChatCompletionInputGrammarType] = None,
         seed: Optional[int] = None,
-        stop: Optional[
+        stop: Optional[list[str]] = None,
         stream_options: Optional[ChatCompletionInputStreamOptions] = None,
         temperature: Optional[float] = None,
         tool_choice: Optional[Union[ChatCompletionInputToolChoiceClass, "ChatCompletionInputToolChoiceEnum"]] = None,
         tool_prompt: Optional[str] = None,
-        tools: Optional[
+        tools: Optional[list[ChatCompletionInputTool]] = None,
         top_logprobs: Optional[int] = None,
         top_p: Optional[float] = None,
-        extra_body: Optional[
+        extra_body: Optional[dict] = None,
     ) -> Union[ChatCompletionOutput, Iterable[ChatCompletionStreamOutput]]: ...

     def chat_completion(
         self,
-        messages:
+        messages: list[Union[dict, ChatCompletionInputMessage]],
         *,
         model: Optional[str] = None,
         stream: bool = False,
         # Parameters from ChatCompletionInput (handled manually)
         frequency_penalty: Optional[float] = None,
-        logit_bias: Optional[
+        logit_bias: Optional[list[float]] = None,
         logprobs: Optional[bool] = None,
         max_tokens: Optional[int] = None,
         n: Optional[int] = None,
         presence_penalty: Optional[float] = None,
         response_format: Optional[ChatCompletionInputGrammarType] = None,
         seed: Optional[int] = None,
-        stop: Optional[
+        stop: Optional[list[str]] = None,
         stream_options: Optional[ChatCompletionInputStreamOptions] = None,
         temperature: Optional[float] = None,
         tool_choice: Optional[Union[ChatCompletionInputToolChoiceClass, "ChatCompletionInputToolChoiceEnum"]] = None,
         tool_prompt: Optional[str] = None,
-        tools: Optional[
+        tools: Optional[list[ChatCompletionInputTool]] = None,
         top_logprobs: Optional[int] = None,
         top_p: Optional[float] = None,
-        extra_body: Optional[
+        extra_body: Optional[dict] = None,
     ) -> Union[ChatCompletionOutput, Iterable[ChatCompletionStreamOutput]]:
         """
         A method for completing conversations using a specified language model.

-
+        > [!TIP]
+        > The `client.chat_completion` method is aliased as `client.chat.completions.create` for compatibility with OpenAI's client.
+        > Inputs and outputs are strictly the same and using either syntax will yield the same results.
+        > Check out the [Inference guide](https://huggingface.co/docs/huggingface_hub/guides/inference#openai-compatibility)
+        > for more details about OpenAI's compatibility.

-
-
-        Check out the [Inference guide](https://huggingface.co/docs/huggingface_hub/guides/inference#openai-compatibility)
-        for more details about OpenAI's compatibility.
-
-        </Tip>
-
-        <Tip>
-        You can pass provider-specific parameters to the model by using the `extra_body` argument.
-        </Tip>
+        > [!TIP]
+        > You can pass provider-specific parameters to the model by using the `extra_body` argument.

         Args:
             messages (List of [`ChatCompletionInputMessage`]):
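With the overloads above, `stream=True` statically selects the `Iterable[ChatCompletionStreamOutput]` return type. A minimal streaming sketch; the model ID is a placeholder and the delta access pattern assumes the OpenAI-compatible output types.

```python
from huggingface_hub import InferenceClient

client = InferenceClient(model="meta-llama/Llama-3.1-8B-Instruct")  # placeholder model ID

for chunk in client.chat_completion(
    messages=[{"role": "user", "content": "Tell me a joke."}],
    stream=True,
    max_tokens=64,
):
    delta = chunk.choices[0].delta.content  # incremental token text, may be None
    if delta:
        print(delta, end="", flush=True)
```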
@@ -613,7 +581,7 @@ class InferenceClient:
             frequency_penalty (`float`, *optional*):
                 Penalizes new tokens based on their existing frequency
                 in the text so far. Range: [-2.0, 2.0]. Defaults to 0.0.
-            logit_bias (`
+            logit_bias (`list[float]`, *optional*):
                 Adjusts the likelihood of specific tokens appearing in the generated output.
             logprobs (`bool`, *optional*):
                 Whether to return log probabilities of the output tokens or not. If true, returns the log
@@ -629,7 +597,7 @@ class InferenceClient:
                 Grammar constraints. Can be either a JSONSchema or a regex.
             seed (Optional[`int`], *optional*):
                 Seed for reproducible control flow. Defaults to None.
-            stop (`
+            stop (`list[str]`, *optional*):
                 Up to four strings which trigger the end of the response.
                 Defaults to None.
             stream (`bool`, *optional*):
@@ -653,7 +621,7 @@ class InferenceClient:
             tools (List of [`ChatCompletionInputTool`], *optional*):
                 A list of tools the model may call. Currently, only functions are supported as a tool. Use this to
                 provide a list of functions the model may generate JSON inputs for.
-            extra_body (`
+            extra_body (`dict`, *optional*):
                 Additional provider-specific parameters to pass to the model. Refer to the provider's documentation
                 for supported parameters.
         Returns:
@@ -665,7 +633,7 @@ class InferenceClient:
         Raises:
             [`InferenceTimeoutError`]:
                 If the model is unavailable or the request times out.
-            `
+            [`HfHubHTTPError`]:
                 If the request fails with an HTTP error status code other than HTTP 503.

         Example:
@@ -891,7 +859,7 @@ class InferenceClient:
         >>> messages = [
         ...     {
         ...         "role": "user",
-        ...         "content": "I saw a puppy a cat and a raccoon during my bike ride in the park. What did I
+        ...         "content": "I saw a puppy a cat and a raccoon during my bike ride in the park. What did I see and when?",
         ...     },
         ... ]
         >>> response_format = {
@@ -910,20 +878,26 @@ class InferenceClient:
         ...     messages=messages,
         ...     response_format=response_format,
         ...     max_tokens=500,
-        )
+        ... )
         >>> response.choices[0].message.content
         '{\n\n"activity": "bike ride",\n"animals": ["puppy", "cat", "raccoon"],\n"animals_seen": 3,\n"location": "park"}'
         ```
         """
-        # Get the provider helper
-        provider_helper = get_provider_helper(self.provider, task="conversational")
-
         # Since `chat_completion(..., model=xxx)` is also a payload parameter for the server, we need to handle 'model' differently.
         # `self.model` takes precedence over 'model' argument for building URL.
         # `model` takes precedence for payload value.
         model_id_or_url = self.model or model
         payload_model = model or self.model
+
+        # Get the provider helper
+        provider_helper = get_provider_helper(
+            self.provider,
+            task="conversational",
+            model=model_id_or_url
+            if model_id_or_url is not None and model_id_or_url.startswith(("http://", "https://"))
+            else payload_model,
+        )
+
         # Prepare the payload
         parameters = {
             "model": payload_model,
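The resolution logic above means a URL set on the client drives request routing, while the per-call `model` argument only fills the payload's `model` field. A hedged sketch under that assumption; the endpoint URL and model name are placeholders.

```python
from huggingface_hub import InferenceClient

# `self.model` (here a URL) takes precedence for building the request URL;
# the `(/v1)/chat/completions` suffix is appended automatically for chat completion.
client = InferenceClient(base_url="https://my-endpoint.example")  # placeholder endpoint URL

response = client.chat_completion(
    messages=[{"role": "user", "content": "Hi"}],
    model="my-deployed-model",  # placeholder; used only as the payload's "model" value
)
print(response.choices[0].message.content)
```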
@@ -973,8 +947,8 @@ class InferenceClient:
         max_question_len: Optional[int] = None,
         max_seq_len: Optional[int] = None,
         top_k: Optional[int] = None,
-        word_boxes: Optional[
-    ) ->
+        word_boxes: Optional[list[Union[list[float], str]]] = None,
+    ) -> list[DocumentQuestionAnsweringOutputElement]:
         """
         Answer questions on document images.

@@ -1004,16 +978,16 @@ class InferenceClient:
             top_k (`int`, *optional*):
                 The number of answers to return (will be chosen by order of likelihood). Can return less than top_k
                 answers if there are not enough options available within the context.
-            word_boxes (`
+            word_boxes (`list[Union[list[float], str`, *optional*):
                 A list of words and bounding boxes (normalized 0->1000). If provided, the inference will skip the OCR
                 step and use the provided bounding boxes instead.
         Returns:
-            `
+            `list[DocumentQuestionAnsweringOutputElement]`: a list of [`DocumentQuestionAnsweringOutputElement`] items containing the predicted label, associated probability, word ids, and page number.

         Raises:
             [`InferenceTimeoutError`]:
                 If the model is unavailable or the request times out.
-            `
+            [`HfHubHTTPError`]:
                 If the request fails with an HTTP error status code other than HTTP 503.


@@ -1025,8 +999,9 @@ class InferenceClient:
         [DocumentQuestionAnsweringOutputElement(answer='us-001', end=16, score=0.9999666213989258, start=16)]
         ```
         """
-
-        provider_helper = get_provider_helper(self.provider, task="document-question-answering")
+        model_id = model or self.model
+        provider_helper = get_provider_helper(self.provider, task="document-question-answering", model=model_id)
+        inputs: dict[str, Any] = {"question": question, "image": _b64_encode(image)}
         request_parameters = provider_helper.prepare_request(
             inputs=inputs,
             parameters={
@@ -1040,7 +1015,7 @@ class InferenceClient:
                 "word_boxes": word_boxes,
             },
             headers=self.headers,
-            model=
+            model=model_id,
             api_key=self.token,
         )
         response = self._inner_post(request_parameters)
@@ -1063,8 +1038,8 @@ class InferenceClient:
             text (`str`):
                 The text to embed.
             model (`str`, *optional*):
-                The model to use for the
-                a deployed Inference Endpoint. If not provided, the default recommended
+                The model to use for the feature extraction task. Can be a model ID hosted on the Hugging Face Hub or a URL to
+                a deployed Inference Endpoint. If not provided, the default recommended feature extraction model will be used.
                 Defaults to None.
             normalize (`bool`, *optional*):
                 Whether to normalize the embeddings or not.
@@ -1087,7 +1062,7 @@ class InferenceClient:
         Raises:
             [`InferenceTimeoutError`]:
                 If the model is unavailable or the request times out.
-            `
+            [`HfHubHTTPError`]:
                 If the request fails with an HTTP error status code other than HTTP 503.

         Example:
@@ -1101,7 +1076,8 @@ class InferenceClient:
         [ 0.28552425, -0.928395 , -1.2077185 , ..., 0.76810825, -2.1069427 , 0.6236161 ]], dtype=float32)
         ```
         """
-
+        model_id = model or self.model
+        provider_helper = get_provider_helper(self.provider, task="feature-extraction", model=model_id)
         request_parameters = provider_helper.prepare_request(
             inputs=text,
             parameters={
@@ -1111,21 +1087,21 @@ class InferenceClient:
                 "truncation_direction": truncation_direction,
             },
             headers=self.headers,
-            model=
+            model=model_id,
             api_key=self.token,
         )
         response = self._inner_post(request_parameters)
         np = _import_numpy()
-        return np.array(
+        return np.array(provider_helper.get_response(response), dtype="float32")

     def fill_mask(
         self,
         text: str,
         *,
         model: Optional[str] = None,
-        targets: Optional[
+        targets: Optional[list[str]] = None,
         top_k: Optional[int] = None,
-    ) ->
+    ) -> list[FillMaskOutputElement]:
         """
         Fill in a hole with a missing word (token to be precise).

@@ -1135,20 +1111,20 @@ class InferenceClient:
             model (`str`, *optional*):
                 The model to use for the fill mask task. Can be a model ID hosted on the Hugging Face Hub or a URL to
                 a deployed Inference Endpoint. If not provided, the default recommended fill mask model will be used.
-            targets (`
+            targets (`list[str`, *optional*):
                 When passed, the model will limit the scores to the passed targets instead of looking up in the whole
                 vocabulary. If the provided targets are not in the model vocab, they will be tokenized and the first
                 resulting token will be used (with a warning, and that might be slower).
             top_k (`int`, *optional*):
                 When passed, overrides the number of predictions to return.
         Returns:
-            `
+            `list[FillMaskOutputElement]`: a list of [`FillMaskOutputElement`] items containing the predicted label, associated
             probability, token reference, and completed text.

         Raises:
             [`InferenceTimeoutError`]:
                 If the model is unavailable or the request times out.
-            `
+            [`HfHubHTTPError`]:
                 If the request fails with an HTTP error status code other than HTTP 503.

         Example:
@@ -1162,12 +1138,13 @@ class InferenceClient:
         ]
         ```
         """
-
+        model_id = model or self.model
+        provider_helper = get_provider_helper(self.provider, task="fill-mask", model=model_id)
         request_parameters = provider_helper.prepare_request(
             inputs=text,
             parameters={"targets": targets, "top_k": top_k},
             headers=self.headers,
-            model=
+            model=model_id,
             api_key=self.token,
         )
         response = self._inner_post(request_parameters)
@@ -1180,13 +1157,13 @@ class InferenceClient:
         model: Optional[str] = None,
         function_to_apply: Optional["ImageClassificationOutputTransform"] = None,
         top_k: Optional[int] = None,
-    ) ->
+    ) -> list[ImageClassificationOutputElement]:
         """
         Perform image classification on the given image using the specified model.

         Args:
-            image (`Union[str, Path, bytes, BinaryIO]`):
-                The image to classify. It can be raw bytes, an image file,
+            image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`):
+                The image to classify. It can be raw bytes, an image file, a URL to an online image, or a PIL Image.
             model (`str`, *optional*):
                 The model to use for image classification. Can be a model ID hosted on the Hugging Face Hub or a URL to a
                 deployed Inference Endpoint. If not provided, the default recommended model for image classification will be used.
@@ -1195,12 +1172,12 @@ class InferenceClient:
             top_k (`int`, *optional*):
                 When specified, limits the output to the top K most probable classes.
         Returns:
-            `
+            `list[ImageClassificationOutputElement]`: a list of [`ImageClassificationOutputElement`] items containing the predicted label and associated probability.

         Raises:
             [`InferenceTimeoutError`]:
                 If the model is unavailable or the request times out.
-            `
+            [`HfHubHTTPError`]:
                 If the request fails with an HTTP error status code other than HTTP 503.

         Example:
@@ -1211,12 +1188,13 @@ class InferenceClient:
         [ImageClassificationOutputElement(label='Blenheim spaniel', score=0.9779096841812134), ...]
         ```
         """
-
+        model_id = model or self.model
+        provider_helper = get_provider_helper(self.provider, task="image-classification", model=model_id)
         request_parameters = provider_helper.prepare_request(
             inputs=image,
             parameters={"function_to_apply": function_to_apply, "top_k": top_k},
             headers=self.headers,
-            model=
+            model=model_id,
             api_key=self.token,
         )
         response = self._inner_post(request_parameters)
@@ -1231,19 +1209,16 @@ class InferenceClient:
         overlap_mask_area_threshold: Optional[float] = None,
         subtask: Optional["ImageSegmentationSubtask"] = None,
         threshold: Optional[float] = None,
-    ) ->
+    ) -> list[ImageSegmentationOutputElement]:
         """
         Perform image segmentation on the given image using the specified model.

-
-
-        You must have `PIL` installed if you want to work with images (`pip install Pillow`).
-
-        </Tip>
+        > [!WARNING]
+        > You must have `PIL` installed if you want to work with images (`pip install Pillow`).

         Args:
-            image (`Union[str, Path, bytes, BinaryIO]`):
-                The image to segment. It can be raw bytes, an image file,
+            image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`):
+                The image to segment. It can be raw bytes, an image file, a URL to an online image, or a PIL Image.
             model (`str`, *optional*):
                 The model to use for image segmentation. Can be a model ID hosted on the Hugging Face Hub or a URL to a
                 deployed Inference Endpoint. If not provided, the default recommended model for image segmentation will be used.
@@ -1256,12 +1231,12 @@ class InferenceClient:
             threshold (`float`, *optional*):
                 Probability threshold to filter out predicted masks.
         Returns:
-            `
+            `list[ImageSegmentationOutputElement]`: A list of [`ImageSegmentationOutputElement`] items containing the segmented masks and associated attributes.

         Raises:
             [`InferenceTimeoutError`]:
                 If the model is unavailable or the request times out.
-            `
+            [`HfHubHTTPError`]:
                 If the request fails with an HTTP error status code other than HTTP 503.

         Example:
@@ -1272,7 +1247,8 @@ class InferenceClient:
         [ImageSegmentationOutputElement(score=0.989008, label='LABEL_184', mask=<PIL.PngImagePlugin.PngImageFile image mode=L size=400x300 at 0x7FDD2B129CC0>), ...]
         ```
         """
-
+        model_id = model or self.model
+        provider_helper = get_provider_helper(self.provider, task="image-segmentation", model=model_id)
         request_parameters = provider_helper.prepare_request(
             inputs=image,
             parameters={
@@ -1282,10 +1258,11 @@ class InferenceClient:
                 "threshold": threshold,
             },
             headers=self.headers,
-            model=
+            model=model_id,
             api_key=self.token,
         )
         response = self._inner_post(request_parameters)
+        response = provider_helper.get_response(response, request_parameters)
         output = ImageSegmentationOutputElement.parse_obj_as_list(response)
         for item in output:
             item.mask = _b64_to_image(item.mask)  # type: ignore [assignment]
@@ -1306,15 +1283,12 @@ class InferenceClient:
         """
         Perform image-to-image translation using a specified model.

-
-
-        You must have `PIL` installed if you want to work with images (`pip install Pillow`).
-
-        </Tip>
+        > [!WARNING]
+        > You must have `PIL` installed if you want to work with images (`pip install Pillow`).

         Args:
-            image (`Union[str, Path, bytes, BinaryIO]`):
-                The input image for translation. It can be raw bytes, an image file,
+            image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`):
+                The input image for translation. It can be raw bytes, an image file, a URL to an online image, or a PIL Image.
             prompt (`str`, *optional*):
                 The text prompt to guide the image generation.
             negative_prompt (`str`, *optional*):
@@ -1329,7 +1303,8 @@ class InferenceClient:
                 The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
                 Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
             target_size (`ImageToImageTargetSize`, *optional*):
-                The size in
+                The size in pixels of the output image. This parameter is only supported by some providers and for
+                specific models. It will be ignored when unsupported.

         Returns:
             `Image`: The translated image.
@@ -1337,7 +1312,7 @@ class InferenceClient:
         Raises:
             [`InferenceTimeoutError`]:
                 If the model is unavailable or the request times out.
-            `
+            [`HfHubHTTPError`]:
                 If the request fails with an HTTP error status code other than HTTP 503.

         Example:
@@ -1347,8 +1322,10 @@ class InferenceClient:
         >>> image = client.image_to_image("cat.jpg", prompt="turn the cat into a tiger")
         >>> image.save("tiger.jpg")
         ```
+
         """
-
+        model_id = model or self.model
+        provider_helper = get_provider_helper(self.provider, task="image-to-image", model=model_id)
         request_parameters = provider_helper.prepare_request(
             inputs=image,
             parameters={
@@ -1360,22 +1337,102 @@ class InferenceClient:
                 **kwargs,
             },
             headers=self.headers,
-            model=
+            model=model_id,
             api_key=self.token,
         )
         response = self._inner_post(request_parameters)
+        response = provider_helper.get_response(response, request_parameters)
         return _bytes_to_image(response)

+    def image_to_video(
+        self,
+        image: ContentT,
+        *,
+        model: Optional[str] = None,
+        prompt: Optional[str] = None,
+        negative_prompt: Optional[str] = None,
+        num_frames: Optional[float] = None,
+        num_inference_steps: Optional[int] = None,
+        guidance_scale: Optional[float] = None,
+        seed: Optional[int] = None,
+        target_size: Optional[ImageToVideoTargetSize] = None,
+        **kwargs,
+    ) -> bytes:
+        """
+        Generate a video from an input image.
+
+        Args:
+            image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`):
+                The input image to generate a video from. It can be raw bytes, an image file, a URL to an online image, or a PIL Image.
+            model (`str`, *optional*):
+                The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
+                Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
+            prompt (`str`, *optional*):
+                The text prompt to guide the video generation.
+            negative_prompt (`str`, *optional*):
+                One prompt to guide what NOT to include in video generation.
+            num_frames (`float`, *optional*):
+                The num_frames parameter determines how many video frames are generated.
+            num_inference_steps (`int`, *optional*):
+                For diffusion models. The number of denoising steps. More denoising steps usually lead to a higher
+                quality image at the expense of slower inference.
+            guidance_scale (`float`, *optional*):
+                For diffusion models. A higher guidance scale value encourages the model to generate videos closely
+                linked to the text prompt at the expense of lower image quality.
+            seed (`int`, *optional*):
+                The seed to use for the video generation.
+            target_size (`ImageToVideoTargetSize`, *optional*):
+                The size in pixel of the output video frames.
+            num_inference_steps (`int`, *optional*):
+                The number of denoising steps. More denoising steps usually lead to a higher quality video at the
+                expense of slower inference.
+            seed (`int`, *optional*):
+                Seed for the random number generator.
+
+        Returns:
+            `bytes`: The generated video.
+
+        Examples:
+        ```py
+        >>> from huggingface_hub import InferenceClient
+        >>> client = InferenceClient()
+        >>> video = client.image_to_video("cat.jpg", model="Wan-AI/Wan2.2-I2V-A14B", prompt="turn the cat into a tiger")
+        >>> with open("tiger.mp4", "wb") as f:
+        ...     f.write(video)
+        ```
+        """
+        model_id = model or self.model
+        provider_helper = get_provider_helper(self.provider, task="image-to-video", model=model_id)
+        request_parameters = provider_helper.prepare_request(
+            inputs=image,
+            parameters={
+                "prompt": prompt,
+                "negative_prompt": negative_prompt,
+                "num_frames": num_frames,
+                "num_inference_steps": num_inference_steps,
+                "guidance_scale": guidance_scale,
+                "seed": seed,
+                "target_size": target_size,
+                **kwargs,
+            },
+            headers=self.headers,
+            model=model_id,
+            api_key=self.token,
+        )
+        response = self._inner_post(request_parameters)
+        response = provider_helper.get_response(response, request_parameters)
+        return response
+
     def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -> ImageToTextOutput:
         """
         Takes an input image and return text.

         Models can have very different outputs depending on your use case (image captioning, optical character recognition
|
|
1374
|
-
(OCR), Pix2Struct, etc). Please have a look to the model card to learn more about a model's specificities.
|
|
1431
|
+
(OCR), Pix2Struct, etc.). Please have a look to the model card to learn more about a model's specificities.
|
|
1375
1432
|
|
|
1376
1433
|
Args:
|
|
1377
|
-
image (`Union[str, Path, bytes, BinaryIO]`):
|
|
1378
|
-
The input image to caption. It can be raw bytes, an image file,
|
|
1434
|
+
image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`):
|
|
1435
|
+
The input image to caption. It can be raw bytes, an image file, a URL to an online image, or a PIL Image.
|
|
1379
1436
|
model (`str`, *optional*):
|
|
1380
1437
|
The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
|
|
1381
1438
|
Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
|
|
@@ -1386,7 +1443,7 @@ class InferenceClient:
|
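The hunks above add a new `InferenceClient.image_to_video` task method next to `image_to_image`. A minimal usage sketch based on the docstring example introduced in this diff; the model ID comes from that example, and the specific parameter values are arbitrary illustrations, not requirements:

```py
# Sketch only: follows the image_to_video() docstring added above.
from huggingface_hub import InferenceClient

client = InferenceClient()
video = client.image_to_video(
    "cat.jpg",                           # local file, raw bytes, URL or PIL image
    model="Wan-AI/Wan2.2-I2V-A14B",      # model ID taken from the docstring example
    prompt="turn the cat into a tiger",
    num_frames=49,                       # arbitrary value for an optional parameter
    seed=42,                             # arbitrary value for an optional parameter
)
with open("tiger.mp4", "wb") as f:
    f.write(video)                       # the method returns raw video bytes
```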
@@ -1386,7 +1443,7 @@ class InferenceClient:
         Raises:
             [`InferenceTimeoutError`]:
                 If the model is unavailable or the request times out.
-            `
+            [`HfHubHTTPError`]:
                 If the request fails with an HTTP error status code other than HTTP 503.

         Example:
@@ -1399,45 +1456,43 @@ class InferenceClient:
         'a dog laying on the grass next to a flower pot '
         ```
         """
-
+        model_id = model or self.model
+        provider_helper = get_provider_helper(self.provider, task="image-to-text", model=model_id)
         request_parameters = provider_helper.prepare_request(
             inputs=image,
             parameters={},
             headers=self.headers,
-            model=
+            model=model_id,
             api_key=self.token,
         )
         response = self._inner_post(request_parameters)
-
-        return
+        output_list: list[ImageToTextOutput] = ImageToTextOutput.parse_obj_as_list(response)
+        return output_list[0]

     def object_detection(
         self, image: ContentT, *, model: Optional[str] = None, threshold: Optional[float] = None
-    ) ->
+    ) -> list[ObjectDetectionOutputElement]:
         """
         Perform object detection on the given image using the specified model.

-
-
-        You must have `PIL` installed if you want to work with images (`pip install Pillow`).
-
-        </Tip>
+        > [!WARNING]
+        > You must have `PIL` installed if you want to work with images (`pip install Pillow`).

         Args:
-            image (`Union[str, Path, bytes, BinaryIO]`):
-                The image to detect objects on. It can be raw bytes, an image file,
+            image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`):
+                The image to detect objects on. It can be raw bytes, an image file, a URL to an online image, or a PIL Image.
             model (`str`, *optional*):
                 The model to use for object detection. Can be a model ID hosted on the Hugging Face Hub or a URL to a
                 deployed Inference Endpoint. If not provided, the default recommended model for object detection (DETR) will be used.
             threshold (`float`, *optional*):
                 The probability necessary to make a prediction.
         Returns:
-            `
+            `list[ObjectDetectionOutputElement]`: A list of [`ObjectDetectionOutputElement`] items containing the bounding boxes and associated attributes.

         Raises:
             [`InferenceTimeoutError`]:
                 If the model is unavailable or the request times out.
-            `
+            [`HfHubHTTPError`]:
                 If the request fails with an HTTP error status code other than HTTP 503.
             `ValueError`:
                 If the request output is not a List.
@@ -1450,12 +1505,13 @@ class InferenceClient:
         [ObjectDetectionOutputElement(score=0.9486683011054993, label='person', box=ObjectDetectionBoundingBox(xmin=59, ymin=39, xmax=420, ymax=510)), ...]
         ```
         """
-
+        model_id = model or self.model
+        provider_helper = get_provider_helper(self.provider, task="object-detection", model=model_id)
         request_parameters = provider_helper.prepare_request(
             inputs=image,
             parameters={"threshold": threshold},
             headers=self.headers,
-            model=
+            model=model_id,
             api_key=self.token,
         )
         response = self._inner_post(request_parameters)
@@ -1474,7 +1530,7 @@ class InferenceClient:
         max_question_len: Optional[int] = None,
         max_seq_len: Optional[int] = None,
         top_k: Optional[int] = None,
-    ) -> Union[QuestionAnsweringOutputElement,
+    ) -> Union[QuestionAnsweringOutputElement, list[QuestionAnsweringOutputElement]]:
         """
         Retrieve the answer to a question from a given text.

@@ -1506,13 +1562,13 @@ class InferenceClient:
                 topk answers if there are not enough options available within the context.

         Returns:
-            Union[`QuestionAnsweringOutputElement`,
+            Union[`QuestionAnsweringOutputElement`, list[`QuestionAnsweringOutputElement`]]:
                 When top_k is 1 or not provided, it returns a single `QuestionAnsweringOutputElement`.
                 When top_k is greater than 1, it returns a list of `QuestionAnsweringOutputElement`.
         Raises:
             [`InferenceTimeoutError`]:
                 If the model is unavailable or the request times out.
-            `
+            [`HfHubHTTPError`]:
                 If the request fails with an HTTP error status code other than HTTP 503.

         Example:
@@ -1523,9 +1579,10 @@ class InferenceClient:
         QuestionAnsweringOutputElement(answer='Clara', end=16, score=0.9326565265655518, start=11)
         ```
         """
-
+        model_id = model or self.model
+        provider_helper = get_provider_helper(self.provider, task="question-answering", model=model_id)
         request_parameters = provider_helper.prepare_request(
-            inputs=
+            inputs={"question": question, "context": context},
             parameters={
                 "align_to_words": align_to_words,
                 "doc_stride": doc_stride,
@@ -1535,9 +1592,8 @@ class InferenceClient:
                 "max_seq_len": max_seq_len,
                 "top_k": top_k,
             },
-            extra_payload={"question": question, "context": context},
             headers=self.headers,
-            model=
+            model=model_id,
             api_key=self.token,
         )
         response = self._inner_post(request_parameters)
@@ -1546,28 +1602,28 @@ class InferenceClient:
         return output

     def sentence_similarity(
-        self, sentence: str, other_sentences:
-    ) ->
+        self, sentence: str, other_sentences: list[str], *, model: Optional[str] = None
+    ) -> list[float]:
         """
         Compute the semantic similarity between a sentence and a list of other sentences by comparing their embeddings.

         Args:
             sentence (`str`):
                 The main sentence to compare to others.
-            other_sentences (`
+            other_sentences (`list[str]`):
                 The list of sentences to compare to.
             model (`str`, *optional*):
-                The model to use for the
-                a deployed Inference Endpoint. If not provided, the default recommended
+                The model to use for the sentence similarity task. Can be a model ID hosted on the Hugging Face Hub or a URL to
+                a deployed Inference Endpoint. If not provided, the default recommended sentence similarity model will be used.
                 Defaults to None.

         Returns:
-            `
+            `list[float]`: The embedding representing the input text.

         Raises:
             [`InferenceTimeoutError`]:
                 If the model is unavailable or the request times out.
-            `
+            [`HfHubHTTPError`]:
                 If the request fails with an HTTP error status code other than HTTP 503.

         Example:
@@ -1585,13 +1641,14 @@ class InferenceClient:
         [0.7785726189613342, 0.45876261591911316, 0.2906220555305481]
         ```
         """
-
+        model_id = model or self.model
+        provider_helper = get_provider_helper(self.provider, task="sentence-similarity", model=model_id)
         request_parameters = provider_helper.prepare_request(
-            inputs=
+            inputs={"source_sentence": sentence, "sentences": other_sentences},
             parameters={},
-            extra_payload={
+            extra_payload={},
             headers=self.headers,
-            model=
+            model=model_id,
             api_key=self.token,
         )
         response = self._inner_post(request_parameters)
@@ -1603,7 +1660,7 @@ class InferenceClient:
         *,
         model: Optional[str] = None,
         clean_up_tokenization_spaces: Optional[bool] = None,
-        generate_parameters: Optional[
+        generate_parameters: Optional[dict[str, Any]] = None,
         truncation: Optional["SummarizationTruncationStrategy"] = None,
     ) -> SummarizationOutput:
         """
@@ -1617,7 +1674,7 @@ class InferenceClient:
                 Inference Endpoint. If not provided, the default recommended model for summarization will be used.
             clean_up_tokenization_spaces (`bool`, *optional*):
                 Whether to clean up the potential extra spaces in the text output.
-            generate_parameters (`
+            generate_parameters (`dict[str, Any]`, *optional*):
                 Additional parametrization of the text generation algorithm.
             truncation (`"SummarizationTruncationStrategy"`, *optional*):
                 The truncation strategy to use.
@@ -1627,7 +1684,7 @@ class InferenceClient:
         Raises:
             [`InferenceTimeoutError`]:
                 If the model is unavailable or the request times out.
-            `
+            [`HfHubHTTPError`]:
                 If the request fails with an HTTP error status code other than HTTP 503.

         Example:
@@ -1643,12 +1700,13 @@ class InferenceClient:
             "generate_parameters": generate_parameters,
             "truncation": truncation,
         }
-
+        model_id = model or self.model
+        provider_helper = get_provider_helper(self.provider, task="summarization", model=model_id)
         request_parameters = provider_helper.prepare_request(
             inputs=text,
             parameters=parameters,
             headers=self.headers,
-            model=
+            model=model_id,
             api_key=self.token,
         )
         response = self._inner_post(request_parameters)
@@ -1656,7 +1714,7 @@ class InferenceClient:

     def table_question_answering(
         self,
-        table:
+        table: dict[str, Any],
         query: str,
         *,
         model: Optional[str] = None,
@@ -1691,7 +1749,7 @@ class InferenceClient:
         Raises:
             [`InferenceTimeoutError`]:
                 If the model is unavailable or the request times out.
-            `
+            [`HfHubHTTPError`]:
                 If the request fails with an HTTP error status code other than HTTP 503.

         Example:
@@ -1704,24 +1762,24 @@ class InferenceClient:
         TableQuestionAnsweringOutputElement(answer='36542', coordinates=[[0, 1]], cells=['36542'], aggregator='AVERAGE')
         ```
         """
-
+        model_id = model or self.model
+        provider_helper = get_provider_helper(self.provider, task="table-question-answering", model=model_id)
         request_parameters = provider_helper.prepare_request(
-            inputs=
+            inputs={"query": query, "table": table},
             parameters={"model": model, "padding": padding, "sequential": sequential, "truncation": truncation},
-            extra_payload={"query": query, "table": table},
             headers=self.headers,
-            model=
+            model=model_id,
             api_key=self.token,
         )
         response = self._inner_post(request_parameters)
         return TableQuestionAnsweringOutputElement.parse_obj_as_instance(response)

-    def tabular_classification(self, table:
+    def tabular_classification(self, table: dict[str, Any], *, model: Optional[str] = None) -> list[str]:
         """
         Classifying a target category (a group) based on a set of attributes.

         Args:
-            table (`
+            table (`dict[str, Any]`):
                 Set of attributes to classify.
             model (`str`, *optional*):
                 The model to use for the tabular classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to
@@ -1734,7 +1792,7 @@ class InferenceClient:
         Raises:
             [`InferenceTimeoutError`]:
                 If the model is unavailable or the request times out.
-            `
+            [`HfHubHTTPError`]:
                 If the request fails with an HTTP error status code other than HTTP 503.

         Example:
@@ -1758,24 +1816,25 @@ class InferenceClient:
         ["5", "5", "5"]
         ```
         """
-
+        model_id = model or self.model
+        provider_helper = get_provider_helper(self.provider, task="tabular-classification", model=model_id)
         request_parameters = provider_helper.prepare_request(
             inputs=None,
             extra_payload={"table": table},
             parameters={},
             headers=self.headers,
-            model=
+            model=model_id,
             api_key=self.token,
         )
         response = self._inner_post(request_parameters)
         return _bytes_to_list(response)

-    def tabular_regression(self, table:
+    def tabular_regression(self, table: dict[str, Any], *, model: Optional[str] = None) -> list[float]:
         """
         Predicting a numerical target value given a set of attributes/features in a table.

         Args:
-            table (`
+            table (`dict[str, Any]`):
                 Set of attributes stored in a table. The attributes used to predict the target can be both numerical and categorical.
             model (`str`, *optional*):
                 The model to use for the tabular regression task. Can be a model ID hosted on the Hugging Face Hub or a URL to
@@ -1788,7 +1847,7 @@ class InferenceClient:
         Raises:
             [`InferenceTimeoutError`]:
                 If the model is unavailable or the request times out.
-            `
+            [`HfHubHTTPError`]:
                 If the request fails with an HTTP error status code other than HTTP 503.

         Example:
@@ -1807,13 +1866,14 @@ class InferenceClient:
         [110, 120, 130]
         ```
         """
-
+        model_id = model or self.model
+        provider_helper = get_provider_helper(self.provider, task="tabular-regression", model=model_id)
         request_parameters = provider_helper.prepare_request(
             inputs=None,
             parameters={},
             extra_payload={"table": table},
             headers=self.headers,
-            model=
+            model=model_id,
             api_key=self.token,
         )
         response = self._inner_post(request_parameters)
@@ -1826,7 +1886,7 @@ class InferenceClient:
         model: Optional[str] = None,
         top_k: Optional[int] = None,
         function_to_apply: Optional["TextClassificationOutputTransform"] = None,
-    ) ->
+    ) -> list[TextClassificationOutputElement]:
         """
         Perform text classification (e.g. sentiment-analysis) on the given text.

@@ -1843,12 +1903,12 @@ class InferenceClient:
                 The function to apply to the model outputs in order to retrieve the scores.

         Returns:
-            `
+            `list[TextClassificationOutputElement]`: a list of [`TextClassificationOutputElement`] items containing the predicted label and associated probability.

         Raises:
             [`InferenceTimeoutError`]:
                 If the model is unavailable or the request times out.
-            `
+            [`HfHubHTTPError`]:
                 If the request fails with an HTTP error status code other than HTTP 503.

         Example:
@@ -1862,7 +1922,8 @@ class InferenceClient:
         ]
         ```
         """
-
+        model_id = model or self.model
+        provider_helper = get_provider_helper(self.provider, task="text-classification", model=model_id)
         request_parameters = provider_helper.prepare_request(
             inputs=text,
             parameters={
@@ -1870,33 +1931,33 @@ class InferenceClient:
                 "top_k": top_k,
             },
             headers=self.headers,
-            model=
+            model=model_id,
             api_key=self.token,
         )
         response = self._inner_post(request_parameters)
         return TextClassificationOutputElement.parse_obj_as_list(response)[0]  # type: ignore [return-value]

     @overload
-    def text_generation(
+    def text_generation(
         self,
         prompt: str,
         *,
-        details: Literal[
-        stream: Literal[
+        details: Literal[True],
+        stream: Literal[True],
         model: Optional[str] = None,
         # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
         adapter_id: Optional[str] = None,
         best_of: Optional[int] = None,
         decoder_input_details: Optional[bool] = None,
-        do_sample: Optional[bool] =
+        do_sample: Optional[bool] = None,
         frequency_penalty: Optional[float] = None,
         grammar: Optional[TextGenerationInputGrammarType] = None,
         max_new_tokens: Optional[int] = None,
         repetition_penalty: Optional[float] = None,
-        return_full_text: Optional[bool] =
+        return_full_text: Optional[bool] = None,
         seed: Optional[int] = None,
-        stop: Optional[
-        stop_sequences: Optional[
+        stop: Optional[list[str]] = None,
+        stop_sequences: Optional[list[str]] = None,  # Deprecated, use `stop` instead
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
         top_n_tokens: Optional[int] = None,
@@ -1904,29 +1965,29 @@ class InferenceClient:
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
         watermark: Optional[bool] = None,
-    ) ->
+    ) -> Iterable[TextGenerationStreamOutput]: ...

     @overload
-    def text_generation(
+    def text_generation(
         self,
         prompt: str,
         *,
-        details: Literal[True]
-        stream: Literal[False] =
+        details: Literal[True],
+        stream: Optional[Literal[False]] = None,
         model: Optional[str] = None,
         # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
         adapter_id: Optional[str] = None,
         best_of: Optional[int] = None,
         decoder_input_details: Optional[bool] = None,
-        do_sample: Optional[bool] =
+        do_sample: Optional[bool] = None,
         frequency_penalty: Optional[float] = None,
         grammar: Optional[TextGenerationInputGrammarType] = None,
         max_new_tokens: Optional[int] = None,
         repetition_penalty: Optional[float] = None,
-        return_full_text: Optional[bool] =
+        return_full_text: Optional[bool] = None,
         seed: Optional[int] = None,
-        stop: Optional[
-        stop_sequences: Optional[
+        stop: Optional[list[str]] = None,
+        stop_sequences: Optional[list[str]] = None,  # Deprecated, use `stop` instead
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
         top_n_tokens: Optional[int] = None,
@@ -1937,26 +1998,26 @@ class InferenceClient:
     ) -> TextGenerationOutput: ...

     @overload
-    def text_generation(
+    def text_generation(
         self,
         prompt: str,
         *,
-        details: Literal[False] =
-        stream: Literal[True]
+        details: Optional[Literal[False]] = None,
+        stream: Literal[True],
         model: Optional[str] = None,
         # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
         adapter_id: Optional[str] = None,
         best_of: Optional[int] = None,
         decoder_input_details: Optional[bool] = None,
-        do_sample: Optional[bool] =
+        do_sample: Optional[bool] = None,
         frequency_penalty: Optional[float] = None,
         grammar: Optional[TextGenerationInputGrammarType] = None,
         max_new_tokens: Optional[int] = None,
         repetition_penalty: Optional[float] = None,
-        return_full_text: Optional[bool] =
+        return_full_text: Optional[bool] = None,  # Manual default value
         seed: Optional[int] = None,
-        stop: Optional[
-        stop_sequences: Optional[
+        stop: Optional[list[str]] = None,
+        stop_sequences: Optional[list[str]] = None,  # Deprecated, use `stop` instead
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
         top_n_tokens: Optional[int] = None,
@@ -1967,26 +2028,26 @@ class InferenceClient:
     ) -> Iterable[str]: ...

     @overload
-    def text_generation(
+    def text_generation(
         self,
         prompt: str,
         *,
-        details: Literal[
-        stream: Literal[
+        details: Optional[Literal[False]] = None,
+        stream: Optional[Literal[False]] = None,
         model: Optional[str] = None,
         # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
         adapter_id: Optional[str] = None,
         best_of: Optional[int] = None,
         decoder_input_details: Optional[bool] = None,
-        do_sample: Optional[bool] =
+        do_sample: Optional[bool] = None,
         frequency_penalty: Optional[float] = None,
         grammar: Optional[TextGenerationInputGrammarType] = None,
         max_new_tokens: Optional[int] = None,
         repetition_penalty: Optional[float] = None,
-        return_full_text: Optional[bool] =
+        return_full_text: Optional[bool] = None,
         seed: Optional[int] = None,
-        stop: Optional[
-        stop_sequences: Optional[
+        stop: Optional[list[str]] = None,
+        stop_sequences: Optional[list[str]] = None,  # Deprecated, use `stop` instead
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
         top_n_tokens: Optional[int] = None,
@@ -1994,29 +2055,29 @@ class InferenceClient:
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
         watermark: Optional[bool] = None,
-    ) ->
+    ) -> str: ...

     @overload
     def text_generation(
         self,
         prompt: str,
         *,
-        details:
-        stream: bool =
+        details: Optional[bool] = None,
+        stream: Optional[bool] = None,
         model: Optional[str] = None,
         # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
         adapter_id: Optional[str] = None,
         best_of: Optional[int] = None,
         decoder_input_details: Optional[bool] = None,
-        do_sample: Optional[bool] =
+        do_sample: Optional[bool] = None,
         frequency_penalty: Optional[float] = None,
         grammar: Optional[TextGenerationInputGrammarType] = None,
         max_new_tokens: Optional[int] = None,
         repetition_penalty: Optional[float] = None,
-        return_full_text: Optional[bool] =
+        return_full_text: Optional[bool] = None,
        seed: Optional[int] = None,
-        stop: Optional[
-        stop_sequences: Optional[
+        stop: Optional[list[str]] = None,
+        stop_sequences: Optional[list[str]] = None,  # Deprecated, use `stop` instead
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
         top_n_tokens: Optional[int] = None,
@@ -2024,28 +2085,28 @@ class InferenceClient:
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
         watermark: Optional[bool] = None,
-    ) -> Union[TextGenerationOutput, Iterable[TextGenerationStreamOutput]]: ...
+    ) -> Union[str, TextGenerationOutput, Iterable[str], Iterable[TextGenerationStreamOutput]]: ...

     def text_generation(
         self,
         prompt: str,
         *,
-        details: bool =
-        stream: bool =
+        details: Optional[bool] = None,
+        stream: Optional[bool] = None,
         model: Optional[str] = None,
         # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
         adapter_id: Optional[str] = None,
         best_of: Optional[int] = None,
         decoder_input_details: Optional[bool] = None,
-        do_sample: Optional[bool] =
+        do_sample: Optional[bool] = None,
         frequency_penalty: Optional[float] = None,
         grammar: Optional[TextGenerationInputGrammarType] = None,
         max_new_tokens: Optional[int] = None,
         repetition_penalty: Optional[float] = None,
-        return_full_text: Optional[bool] =
+        return_full_text: Optional[bool] = None,
         seed: Optional[int] = None,
-        stop: Optional[
-        stop_sequences: Optional[
+        stop: Optional[list[str]] = None,
+        stop_sequences: Optional[list[str]] = None,  # Deprecated, use `stop` instead
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
         top_n_tokens: Optional[int] = None,
@@ -2057,12 +2118,9 @@ class InferenceClient:
         """
         Given a prompt, generate the following text.

-
-
-
-        It accepts a list of messages instead of a single text prompt and handles the chat templating for you.
-
-        </Tip>
+        > [!TIP]
+        > If you want to generate a response from chat messages, you should use the [`InferenceClient.chat_completion`] method.
+        > It accepts a list of messages instead of a single text prompt and handles the chat templating for you.

         Args:
             prompt (`str`):
@@ -2101,9 +2159,9 @@ class InferenceClient:
                 Whether to prepend the prompt to the generated text
             seed (`int`, *optional*):
                 Random sampling seed
-            stop (`
+            stop (`list[str]`, *optional*):
                 Stop generating tokens if a member of `stop` is generated.
-            stop_sequences (`
+            stop_sequences (`list[str]`, *optional*):
                 Deprecated argument. Use `stop` instead.
             temperature (`float`, *optional*):
                 The value used to module the logits distribution.
@@ -2120,7 +2178,7 @@ class InferenceClient:
             typical_p (`float`, *optional`):
                 Typical Decoding mass
                 See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
-            watermark (`bool`, *optional
+            watermark (`bool`, *optional*):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)

         Returns:
@@ -2136,7 +2194,7 @@ class InferenceClient:
                 If input values are not valid. No HTTP call is made to the server.
             [`InferenceTimeoutError`]:
                 If the model is unavailable or the request times out.
-            `
+            [`HfHubHTTPError`]:
                 If the request fails with an HTTP error status code other than HTTP 503.

         Example:
@@ -2270,7 +2328,7 @@ class InferenceClient:
             "repetition_penalty": repetition_penalty,
             "return_full_text": return_full_text,
             "seed": seed,
-            "stop": stop
+            "stop": stop,
             "temperature": temperature,
             "top_k": top_k,
             "top_n_tokens": top_n_tokens,
@@ -2311,20 +2369,21 @@ class InferenceClient:
                     " Please pass `stream=False` as input."
                 )

-
+        model_id = model or self.model
+        provider_helper = get_provider_helper(self.provider, task="text-generation", model=model_id)
         request_parameters = provider_helper.prepare_request(
             inputs=prompt,
             parameters=parameters,
             extra_payload={"stream": stream},
             headers=self.headers,
-            model=
+            model=model_id,
             api_key=self.token,
         )

         # Handle errors separately for more precise error messages
         try:
-            bytes_output = self._inner_post(request_parameters, stream=stream)
-        except
+            bytes_output = self._inner_post(request_parameters, stream=stream or False)
+        except HfHubHTTPError as e:
             match = MODEL_KWARGS_NOT_USED_REGEX.search(str(e))
             if isinstance(e, BadRequestError) and match:
                 unused_params = [kwarg.strip("' ") for kwarg in match.group(1).split(",")]
@@ -2333,7 +2392,7 @@ class InferenceClient:
                     prompt=prompt,
                     details=details,
                     stream=stream,
-                    model=
+                    model=model_id,
                     adapter_id=adapter_id,
                     best_of=best_of,
                     decoder_input_details=decoder_input_details,
@@ -2364,8 +2423,8 @@ class InferenceClient:
         # Data can be a single element (dict) or an iterable of dicts where we select the first element of.
         if isinstance(data, list):
             data = data[0]
-
-        return TextGenerationOutput.parse_obj_as_instance(
+        response = provider_helper.get_response(data, request_parameters)
+        return TextGenerationOutput.parse_obj_as_instance(response) if details else response["generated_text"]

     def text_to_image(
         self,
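The `text_generation` hunks above change the defaults (`details` and `stream` now default to `None`), type `stop` as a `list[str]` while keeping `stop_sequences` only as a deprecated alias, and document `HfHubHTTPError` as the error raised for non-503 HTTP failures. A small sketch of a call against the updated signature; the prompt and parameter values are arbitrary:

```py
# Sketch only: plain (non-streaming, no details) call to the updated text_generation().
from huggingface_hub import InferenceClient
from huggingface_hub.errors import HfHubHTTPError

client = InferenceClient()
try:
    text = client.text_generation(
        "The huggingface_hub library is ",
        max_new_tokens=20,
        stop=["\n"],           # replaces the deprecated stop_sequences argument
    )
    print(text)                # a plain str when details/stream are left unset
except HfHubHTTPError as err:  # raised for HTTP errors other than 503
    print(f"Inference request failed: {err}")
```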
@@ -2379,20 +2438,16 @@ class InferenceClient:
         model: Optional[str] = None,
         scheduler: Optional[str] = None,
         seed: Optional[int] = None,
-        extra_body: Optional[
+        extra_body: Optional[dict[str, Any]] = None,
     ) -> "Image":
         """
         Generate an image based on a given text using a specified model.

-
+        > [!WARNING]
+        > You must have `PIL` installed if you want to work with images (`pip install Pillow`).

-
-
-        </Tip>
-
-        <Tip>
-        You can pass provider-specific parameters to the model by using the `extra_body` argument.
-        </Tip>
+        > [!TIP]
+        > You can pass provider-specific parameters to the model by using the `extra_body` argument.

         Args:
             prompt (`str`):
@@ -2417,7 +2472,7 @@ class InferenceClient:
                 Override the scheduler with a compatible one.
             seed (`int`, *optional*):
                 Seed for the random number generator.
-            extra_body (`
+            extra_body (`dict[str, Any]`, *optional*):
                 Additional provider-specific parameters to pass to the model. Refer to the provider's documentation
                 for supported parameters.

@@ -2427,7 +2482,7 @@ class InferenceClient:
         Raises:
             [`InferenceTimeoutError`]:
                 If the model is unavailable or the request times out.
-            `
+            [`HfHubHTTPError`]:
                 If the request fails with an HTTP error status code other than HTTP 503.

         Example:
@@ -2487,8 +2542,10 @@ class InferenceClient:
         ... )
         >>> image.save("astronaut.png")
         ```
+
         """
-
+        model_id = model or self.model
+        provider_helper = get_provider_helper(self.provider, task="text-to-image", model=model_id)
         request_parameters = provider_helper.prepare_request(
             inputs=prompt,
             parameters={
@@ -2502,11 +2559,11 @@ class InferenceClient:
                 **(extra_body or {}),
             },
             headers=self.headers,
-            model=
+            model=model_id,
             api_key=self.token,
         )
         response = self._inner_post(request_parameters)
-        response = provider_helper.get_response(response)
+        response = provider_helper.get_response(response, request_parameters)
         return _bytes_to_image(response)

     def text_to_video(
@@ -2515,18 +2572,17 @@ class InferenceClient:
         *,
         model: Optional[str] = None,
         guidance_scale: Optional[float] = None,
-        negative_prompt: Optional[
+        negative_prompt: Optional[list[str]] = None,
         num_frames: Optional[float] = None,
         num_inference_steps: Optional[int] = None,
         seed: Optional[int] = None,
-        extra_body: Optional[
+        extra_body: Optional[dict[str, Any]] = None,
     ) -> bytes:
         """
         Generate a video based on a given text.

-
-        You can pass provider-specific parameters to the model by using the `extra_body` argument.
-        </Tip>
+        > [!TIP]
+        > You can pass provider-specific parameters to the model by using the `extra_body` argument.

         Args:
             prompt (`str`):
@@ -2538,7 +2594,7 @@ class InferenceClient:
             guidance_scale (`float`, *optional*):
                 A higher guidance scale value encourages the model to generate videos closely linked to the text
                 prompt, but values too high may cause saturation and other artifacts.
-            negative_prompt (`
+            negative_prompt (`list[str]`, *optional*):
                 One or several prompt to guide what NOT to include in video generation.
             num_frames (`float`, *optional*):
                 The num_frames parameter determines how many video frames are generated.
@@ -2547,7 +2603,7 @@ class InferenceClient:
                 expense of slower inference.
             seed (`int`, *optional*):
                 Seed for the random number generator.
-            extra_body (`
+            extra_body (`dict[str, Any]`, *optional*):
                 Additional provider-specific parameters to pass to the model. Refer to the provider's documentation
                 for supported parameters.

@@ -2585,8 +2641,10 @@ class InferenceClient:
         >>> with open("cat.mp4", "wb") as file:
         ... file.write(video)
         ```
+
         """
-
+        model_id = model or self.model
+        provider_helper = get_provider_helper(self.provider, task="text-to-video", model=model_id)
         request_parameters = provider_helper.prepare_request(
             inputs=prompt,
             parameters={
@@ -2598,11 +2656,11 @@ class InferenceClient:
                 **(extra_body or {}),
             },
             headers=self.headers,
-            model=
+            model=model_id,
             api_key=self.token,
         )
         response = self._inner_post(request_parameters)
-        response = provider_helper.get_response(response)
+        response = provider_helper.get_response(response, request_parameters)
         return response

     def text_to_speech(
@@ -2626,14 +2684,13 @@ class InferenceClient:
         top_p: Optional[float] = None,
         typical_p: Optional[float] = None,
         use_cache: Optional[bool] = None,
-        extra_body: Optional[
+        extra_body: Optional[dict[str, Any]] = None,
     ) -> bytes:
         """
         Synthesize an audio of a voice pronouncing a given text.

-
-        You can pass provider-specific parameters to the model by using the `extra_body` argument.
-        </Tip>
+        > [!TIP]
+        > You can pass provider-specific parameters to the model by using the `extra_body` argument.

         Args:
             text (`str`):
@@ -2688,7 +2745,7 @@ class InferenceClient:
                 paper](https://hf.co/papers/2202.00666) for more details.
             use_cache (`bool`, *optional*):
                 Whether the model should use the past last key/values attentions to speed up decoding
-            extra_body (`
+            extra_body (`dict[str, Any]`, *optional*):
                 Additional provider-specific parameters to pass to the model. Refer to the provider's documentation
                 for supported parameters.
         Returns:
@@ -2697,7 +2754,7 @@ class InferenceClient:
         Raises:
             [`InferenceTimeoutError`]:
                 If the model is unavailable or the request times out.
-            `
+            [`HfHubHTTPError`]:
                 If the request fails with an HTTP error status code other than HTTP 503.

         Example:
@@ -2783,7 +2840,8 @@ class InferenceClient:
         ... f.write(audio)
         ```
         """
-
+        model_id = model or self.model
+        provider_helper = get_provider_helper(self.provider, task="text-to-speech", model=model_id)
         request_parameters = provider_helper.prepare_request(
             inputs=text,
             parameters={
@@ -2806,7 +2864,7 @@ class InferenceClient:
                 **(extra_body or {}),
             },
             headers=self.headers,
-            model=
+            model=model_id,
             api_key=self.token,
         )
         response = self._inner_post(request_parameters)
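The `text_to_image`, `text_to_video` and `text_to_speech` docstrings above all describe the `extra_body` argument for provider-specific parameters, which are spread into the task parameters sent to the provider. A hedged sketch of that pattern; the provider name is one of the supported values, and the `"output_format"` key is a hypothetical provider-specific field rather than something defined by this library:

```py
# Sketch only: extra_body entries are merged into the task parameters as-is,
# so they must match what the selected provider actually accepts.
from huggingface_hub import InferenceClient

client = InferenceClient(provider="fal-ai")    # assumption: any supported provider
image = client.text_to_image(
    "An astronaut riding a horse on the moon.",
    extra_body={"output_format": "png"},       # hypothetical provider-specific field
)
image.save("astronaut.png")
```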
|
|
2819
2877
|
*,
|
|
2820
2878
|
model: Optional[str] = None,
|
|
2821
2879
|
aggregation_strategy: Optional["TokenClassificationAggregationStrategy"] = None,
|
|
2822
|
-
ignore_labels: Optional[
|
|
2880
|
+
ignore_labels: Optional[list[str]] = None,
|
|
2823
2881
|
stride: Optional[int] = None,
|
|
2824
|
-
) ->
|
|
2882
|
+
) -> list[TokenClassificationOutputElement]:
|
|
2825
2883
|
"""
|
|
2826
2884
|
Perform token classification on the given text.
|
|
2827
2885
|
Usually used for sentence parsing, either grammatical, or Named Entity Recognition (NER) to understand keywords contained within text.
|
|
@@ -2835,18 +2893,18 @@ class InferenceClient:
|
|
|
2835
2893
|
Defaults to None.
|
|
2836
2894
|
aggregation_strategy (`"TokenClassificationAggregationStrategy"`, *optional*):
|
|
2837
2895
|
The strategy used to fuse tokens based on model predictions
|
|
2838
|
-
ignore_labels (`
|
|
2896
|
+
ignore_labels (`list[str`, *optional*):
|
|
2839
2897
|
A list of labels to ignore
|
|
2840
2898
|
stride (`int`, *optional*):
|
|
2841
2899
|
The number of overlapping tokens between chunks when splitting the input text.
|
|
2842
2900
|
|
|
2843
2901
|
Returns:
|
|
2844
|
-
`
|
|
2902
|
+
`list[TokenClassificationOutputElement]`: List of [`TokenClassificationOutputElement`] items containing the entity group, confidence score, word, start and end index.
|
|
2845
2903
|
|
|
2846
2904
|
Raises:
|
|
2847
2905
|
[`InferenceTimeoutError`]:
|
|
2848
2906
|
If the model is unavailable or the request times out.
|
|
2849
|
-
`
|
|
2907
|
+
[`HfHubHTTPError`]:
|
|
2850
2908
|
If the request fails with an HTTP error status code other than HTTP 503.
|
|
2851
2909
|
|
|
2852
2910
|
Example:
|
|
@@ -2872,7 +2930,8 @@ class InferenceClient:
|
|
|
2872
2930
|
]
|
|
2873
2931
|
```
|
|
2874
2932
|
"""
|
|
2875
|
-
|
|
2933
|
+
model_id = model or self.model
|
|
2934
|
+
provider_helper = get_provider_helper(self.provider, task="token-classification", model=model_id)
|
|
2876
2935
|
request_parameters = provider_helper.prepare_request(
|
|
2877
2936
|
inputs=text,
|
|
2878
2937
|
parameters={
|
|
@@ -2881,7 +2940,7 @@ class InferenceClient:
|
|
|
2881
2940
|
"stride": stride,
|
|
2882
2941
|
},
|
|
2883
2942
|
headers=self.headers,
|
|
2884
|
-
model=
|
|
2943
|
+
model=model_id,
|
|
2885
2944
|
api_key=self.token,
|
|
2886
2945
|
)
|
|
2887
2946
|
response = self._inner_post(request_parameters)
|
|
@@ -2896,7 +2955,7 @@ class InferenceClient:
|
|
|
2896
2955
|
tgt_lang: Optional[str] = None,
|
|
2897
2956
|
clean_up_tokenization_spaces: Optional[bool] = None,
|
|
2898
2957
|
truncation: Optional["TranslationTruncationStrategy"] = None,
|
|
2899
|
-
generate_parameters: Optional[
|
|
2958
|
+
generate_parameters: Optional[dict[str, Any]] = None,
|
|
2900
2959
|
) -> TranslationOutput:
|
|
2901
2960
|
"""
|
|
2902
2961
|
Convert text from one language to another.
|
|
@@ -2921,7 +2980,7 @@ class InferenceClient:
|
|
|
2921
2980
|
Whether to clean up the potential extra spaces in the text output.
|
|
2922
2981
|
truncation (`"TranslationTruncationStrategy"`, *optional*):
|
|
2923
2982
|
The truncation strategy to use.
|
|
2924
|
-
generate_parameters (`
|
|
2983
|
+
generate_parameters (`dict[str, Any]`, *optional*):
|
|
2925
2984
|
Additional parametrization of the text generation algorithm.
|
|
2926
2985
|
|
|
2927
2986
|
Returns:
|
|
@@ -2930,7 +2989,7 @@ class InferenceClient:
|
|
|
2930
2989
|
Raises:
|
|
2931
2990
|
[`InferenceTimeoutError`]:
|
|
2932
2991
|
If the model is unavailable or the request times out.
|
|
2933
|
-
`
|
|
2992
|
+
[`HfHubHTTPError`]:
|
|
2934
2993
|
If the request fails with an HTTP error status code other than HTTP 503.
|
|
2935
2994
|
`ValueError`:
|
|
2936
2995
|
If only one of the `src_lang` and `tgt_lang` arguments are provided.
|
|
@@ -2958,7 +3017,8 @@ class InferenceClient:
|
|
|
2958
3017
|
if src_lang is None and tgt_lang is not None:
|
|
2959
3018
|
raise ValueError("You cannot specify `tgt_lang` without specifying `src_lang`.")
|
|
2960
3019
|
|
|
2961
|
-
|
|
3020
|
+
model_id = model or self.model
|
|
3021
|
+
provider_helper = get_provider_helper(self.provider, task="translation", model=model_id)
|
|
2962
3022
|
request_parameters = provider_helper.prepare_request(
|
|
2963
3023
|
inputs=text,
|
|
2964
3024
|
parameters={
|
|
@@ -2969,7 +3029,7 @@ class InferenceClient:
|
|
|
2969
3029
|
"generate_parameters": generate_parameters,
|
|
2970
3030
|
},
|
|
2971
3031
|
headers=self.headers,
|
|
2972
|
-
model=
|
|
3032
|
+
model=model_id,
|
|
2973
3033
|
api_key=self.token,
|
|
2974
3034
|
)
|
|
2975
3035
|
response = self._inner_post(request_parameters)
|
|
@@ -2982,13 +3042,13 @@ class InferenceClient:
|
|
|
2982
3042
|
*,
|
|
2983
3043
|
model: Optional[str] = None,
|
|
2984
3044
|
top_k: Optional[int] = None,
|
|
2985
|
-
) ->
|
|
3045
|
+
) -> list[VisualQuestionAnsweringOutputElement]:
|
|
2986
3046
|
"""
|
|
2987
3047
|
Answering open-ended questions based on an image.
|
|
2988
3048
|
|
|
2989
3049
|
Args:
|
|
2990
|
-
image (`Union[str, Path, bytes, BinaryIO]`):
|
|
2991
|
-
The input image for the context. It can be raw bytes, an image file,
|
|
3050
|
+
image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`):
|
|
3051
|
+
The input image for the context. It can be raw bytes, an image file, a URL to an online image, or a PIL Image.
|
|
2992
3052
|
question (`str`):
|
|
2993
3053
|
Question to be answered.
|
|
2994
3054
|
model (`str`, *optional*):
|
|
@@ -2999,12 +3059,12 @@ class InferenceClient:
|
|
|
2999
3059
|
The number of answers to return (will be chosen by order of likelihood). Note that we return less than
|
|
3000
3060
|
topk answers if there are not enough options available within the context.
|
|
3001
3061
|
Returns:
|
|
3002
|
-
`
|
|
3062
|
+
`list[VisualQuestionAnsweringOutputElement]`: a list of [`VisualQuestionAnsweringOutputElement`] items containing the predicted label and associated probability.
|
|
3003
3063
|
|
|
3004
3064
|
Raises:
|
|
3005
3065
|
`InferenceTimeoutError`:
|
|
3006
3066
|
If the model is unavailable or the request times out.
|
|
3007
|
-
`
|
|
3067
|
+
[`HfHubHTTPError`]:
|
|
3008
3068
|
If the request fails with an HTTP error status code other than HTTP 503.
|
|
3009
3069
|
|
|
3010
3070
|
Example:
|
|
@@ -3021,44 +3081,37 @@ class InferenceClient:
|
|
|
3021
3081
|
]
|
|
3022
3082
|
```
|
|
3023
3083
|
"""
|
|
3024
|
-
|
|
3084
|
+
model_id = model or self.model
|
|
3085
|
+
provider_helper = get_provider_helper(self.provider, task="visual-question-answering", model=model_id)
|
|
3025
3086
|
request_parameters = provider_helper.prepare_request(
|
|
3026
3087
|
inputs=image,
|
|
3027
3088
|
parameters={"top_k": top_k},
|
|
3028
3089
|
headers=self.headers,
|
|
3029
|
-
model=
|
|
3090
|
+
model=model_id,
|
|
3030
3091
|
api_key=self.token,
|
|
3031
3092
|
extra_payload={"question": question, "image": _b64_encode(image)},
|
|
3032
3093
|
)
|
|
3033
3094
|
response = self._inner_post(request_parameters)
|
|
3034
3095
|
return VisualQuestionAnsweringOutputElement.parse_obj_as_list(response)
|
|
3035
3096
|
|
|
3036
|
-
@_deprecate_arguments(
|
|
3037
|
-
version="0.30.0",
|
|
3038
|
-
deprecated_args=["labels"],
|
|
3039
|
-
custom_message="`labels`has been renamed to `candidate_labels` and will be removed in huggingface_hub>=0.30.0.",
|
|
3040
|
-
)
|
|
3041
3097
|
def zero_shot_classification(
|
|
3042
3098
|
self,
|
|
3043
3099
|
text: str,
|
|
3044
|
-
|
|
3045
|
-
candidate_labels: List[str] = None, # type: ignore
|
|
3100
|
+
candidate_labels: list[str],
|
|
3046
3101
|
*,
|
|
3047
3102
|
multi_label: Optional[bool] = False,
|
|
3048
3103
|
hypothesis_template: Optional[str] = None,
|
|
3049
3104
|
model: Optional[str] = None,
|
|
3050
|
-
|
|
3051
|
-
labels: List[str] = None, # type: ignore
|
|
3052
|
-
) -> List[ZeroShotClassificationOutputElement]:
|
|
3105
|
+
) -> list[ZeroShotClassificationOutputElement]:
|
|
3053
3106
|
"""
|
|
3054
3107
|
Provide as input a text and a set of candidate labels to classify the input text.
|
|
3055
3108
|
|
|
3056
3109
|
Args:
|
|
3057
3110
|
text (`str`):
|
|
3058
3111
|
The input text to classify.
|
|
3059
|
-
candidate_labels (`
|
|
3112
|
+
candidate_labels (`list[str]`):
|
|
3060
3113
|
The set of possible class labels to classify the text into.
|
|
3061
|
-
labels (`
|
|
3114
|
+
labels (`list[str]`, *optional*):
|
|
3062
3115
|
(deprecated) List of strings. Each string is the verbalization of a possible label for the input text.
|
|
3063
3116
|
multi_label (`bool`, *optional*):
|
|
3064
3117
|
Whether multiple candidate labels can be true. If false, the scores are normalized such that the sum of
|
|
@@ -3073,12 +3126,12 @@ class InferenceClient:
|
|
|
3073
3126
|
|
|
3074
3127
|
|
|
3075
3128
|
Returns:
|
|
3076
|
-
`
|
|
3129
|
+
`list[ZeroShotClassificationOutputElement]`: List of [`ZeroShotClassificationOutputElement`] items containing the predicted labels and their confidence.
|
|
3077
3130
|
|
|
3078
3131
|
Raises:
|
|
3079
3132
|
[`InferenceTimeoutError`]:
|
|
3080
3133
|
If the model is unavailable or the request times out.
|
|
3081
|
-
`
|
|
3134
|
+
[`HfHubHTTPError`]:
|
|
3082
3135
|
If the request fails with an HTTP error status code other than HTTP 503.
|
|
3083
3136
|
|
|
3084
3137
|
Example with `multi_label=False`:
|
|
@@ -3127,17 +3180,8 @@ class InferenceClient:
|
|
|
3127
3180
|
]
|
|
3128
3181
|
```
|
|
3129
3182
|
"""
|
|
3130
|
-
|
|
3131
|
-
|
|
3132
|
-
if candidate_labels is not None:
|
|
3133
|
-
raise ValueError(
|
|
3134
|
-
"Cannot specify both `labels` and `candidate_labels`. Use `candidate_labels` instead."
|
|
3135
|
-
)
|
|
3136
|
-
candidate_labels = labels
|
|
3137
|
-
elif candidate_labels is None:
|
|
3138
|
-
raise ValueError("Must specify `candidate_labels`")
|
|
3139
|
-
|
|
3140
|
-
provider_helper = get_provider_helper(self.provider, task="zero-shot-classification")
|
|
3183
|
+
model_id = model or self.model
|
|
3184
|
+
provider_helper = get_provider_helper(self.provider, task="zero-shot-classification", model=model_id)
|
|
3141
3185
|
request_parameters = provider_helper.prepare_request(
|
|
3142
3186
|
inputs=text,
|
|
3143
3187
|
parameters={
|
|
@@ -3146,7 +3190,7 @@ class InferenceClient:
|
|
|
3146
3190
|
"hypothesis_template": hypothesis_template,
|
|
3147
3191
|
},
|
|
3148
3192
|
headers=self.headers,
|
|
3149
|
-
model=
|
|
3193
|
+
model=model_id,
|
|
3150
3194
|
api_key=self.token,
|
|
3151
3195
|
)
|
|
3152
3196
|
response = self._inner_post(request_parameters)
|
|
@@ -3156,31 +3200,25 @@ class InferenceClient:
             for label, score in zip(output["labels"], output["scores"])
         ]

-    @_deprecate_arguments(
-        version="0.30.0",
-        deprecated_args=["labels"],
-        custom_message="`labels`has been renamed to `candidate_labels` and will be removed in huggingface_hub>=0.30.0.",
-    )
     def zero_shot_image_classification(
         self,
         image: ContentT,
-
-        candidate_labels: List[str] = None,  # type: ignore
+        candidate_labels: list[str],
         *,
         model: Optional[str] = None,
         hypothesis_template: Optional[str] = None,
         # deprecated argument
-        labels: List[str] = None,  # type: ignore
-    ) -> List[ZeroShotImageClassificationOutputElement]:
+        labels: list[str] = None,  # type: ignore
+    ) -> list[ZeroShotImageClassificationOutputElement]:
         """
         Provide input image and text labels to predict text labels for the image.

         Args:
-            image (`Union[str, Path, bytes, BinaryIO]`):
-                The input image to caption. It can be raw bytes, an image file, or a URL to an online image.
-            candidate_labels (`List[str]`):
+            image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`):
+                The input image to caption. It can be raw bytes, an image file, a URL to an online image, or a PIL Image.
+            candidate_labels (`list[str]`):
                 The candidate labels for this image
-            labels (`List[str]`, *optional*):
+            labels (`list[str]`, *optional*):
                 (deprecated) List of string possible labels. There must be at least 2 labels.
             model (`str`, *optional*):
                 The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
@@ -3190,12 +3228,12 @@ class InferenceClient:
                 replacing the placeholder with the candidate labels.

         Returns:
-            `List[ZeroShotImageClassificationOutputElement]`: List of [`ZeroShotImageClassificationOutputElement`] items containing the predicted labels and their confidence.
+            `list[ZeroShotImageClassificationOutputElement]`: List of [`ZeroShotImageClassificationOutputElement`] items containing the predicted labels and their confidence.

         Raises:
             [`InferenceTimeoutError`]:
                 If the model is unavailable or the request times out.
-            `HTTPError`:
+            [`HfHubHTTPError`]:
                 If the request fails with an HTTP error status code other than HTTP 503.

         Example:
@@ -3210,20 +3248,12 @@ class InferenceClient:
         [ZeroShotImageClassificationOutputElement(label='dog', score=0.956),...]
         ```
         """
-        # handle deprecation
-        if labels is not None:
-            if candidate_labels is not None:
-                raise ValueError(
-                    "Cannot specify both `labels` and `candidate_labels`. Use `candidate_labels` instead."
-                )
-            candidate_labels = labels
-        elif candidate_labels is None:
-            raise ValueError("Must specify `candidate_labels`")
         # Raise ValueError if input is less than 2 labels
         if len(candidate_labels) < 2:
             raise ValueError("You must specify at least 2 classes to compare.")

-        provider_helper = get_provider_helper(self.provider, task="zero-shot-image-classification")
+        model_id = model or self.model
+        provider_helper = get_provider_helper(self.provider, task="zero-shot-image-classification", model=model_id)
         request_parameters = provider_helper.prepare_request(
             inputs=image,
             parameters={
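The same refactor is applied to `zero_shot_image_classification`: `candidate_labels` becomes a required `list[str]`, the `@_deprecate_arguments` decorator and the `labels` handling are dropped from the body, and the provider helper receives the resolved model id. A minimal usage sketch; the CLIP-style model id and image URL are placeholders, not part of the diff:

```python
# Sketch only: model id and image URL below are illustrative placeholders.
from huggingface_hub import InferenceClient

client = InferenceClient(model="openai/clip-vit-base-patch32")  # assumed example model
predictions = client.zero_shot_image_classification(
    "https://example.com/photo.jpg",          # raw bytes, a file path, a URL, or a PIL image
    candidate_labels=["cat", "dog", "bird"],  # at least 2 labels are still required
)
for element in predictions:
    print(element.label, element.score)
```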
@@ -3231,108 +3261,13 @@ class InferenceClient:
                 "hypothesis_template": hypothesis_template,
             },
             headers=self.headers,
-            model=model or self.model,
+            model=model_id,
             api_key=self.token,
         )
         response = self._inner_post(request_parameters)
         return ZeroShotImageClassificationOutputElement.parse_obj_as_list(response)

-    @_deprecate_method(
-        version="0.33.0",
-        message=(
-            "HF Inference API is getting revamped and will only support warm models in the future (no cold start allowed)."
-            " Use `HfApi.list_models(..., inference_provider='...')` to list warm models per provider."
-        ),
-    )
-    def list_deployed_models(
-        self, frameworks: Union[None, str, Literal["all"], List[str]] = None
-    ) -> Dict[str, List[str]]:
-        """
-        List models deployed on the HF Serverless Inference API service.
-
-        This helper checks deployed models framework by framework. By default, it will check the 4 main frameworks that
-        are supported and account for 95% of the hosted models. However, if you want a complete list of models you can
-        specify `frameworks="all"` as input. Alternatively, if you know before-hand which framework you are interested
-        in, you can also restrict to search to this one (e.g. `frameworks="text-generation-inference"`). The more
-        frameworks are checked, the more time it will take.
-
-        <Tip warning={true}>
-
-        This endpoint method does not return a live list of all models available for the HF Inference API service.
-        It searches over a cached list of models that were recently available and the list may not be up to date.
-        If you want to know the live status of a specific model, use [`~InferenceClient.get_model_status`].
-
-        </Tip>
-
-        <Tip>
-
-        This endpoint method is mostly useful for discoverability. If you already know which model you want to use and want to
-        check its availability, you can directly use [`~InferenceClient.get_model_status`].
-
-        </Tip>
-
-        Args:
-            frameworks (`Literal["all"]` or `List[str]` or `str`, *optional*):
-                The frameworks to filter on. By default only a subset of the available frameworks are tested. If set to
-                "all", all available frameworks will be tested. It is also possible to provide a single framework or a
-                custom set of frameworks to check.
-
-        Returns:
-            `Dict[str, List[str]]`: A dictionary mapping task names to a sorted list of model IDs.
-
-        Example:
-        ```python
-        >>> from huggingface_hub import InferenceClient
-        >>> client = InferenceClient()
-
-        # Discover zero-shot-classification models currently deployed
-        >>> models = client.list_deployed_models()
-        >>> models["zero-shot-classification"]
-        ['Narsil/deberta-large-mnli-zero-cls', 'facebook/bart-large-mnli', ...]
-
-        # List from only 1 framework
-        >>> client.list_deployed_models("text-generation-inference")
-        {'text-generation': ['bigcode/starcoder', 'meta-llama/Llama-2-70b-chat-hf', ...], ...}
-        ```
-        """
-        if self.provider != "hf-inference":
-            raise ValueError(f"Listing deployed models is not supported on '{self.provider}'.")
-
-        # Resolve which frameworks to check
-        if frameworks is None:
-            frameworks = constants.MAIN_INFERENCE_API_FRAMEWORKS
-        elif frameworks == "all":
-            frameworks = constants.ALL_INFERENCE_API_FRAMEWORKS
-        elif isinstance(frameworks, str):
-            frameworks = [frameworks]
-        frameworks = list(set(frameworks))
-
-        # Fetch them iteratively
-        models_by_task: Dict[str, List[str]] = {}
-
-        def _unpack_response(framework: str, items: List[Dict]) -> None:
-            for model in items:
-                if framework == "sentence-transformers":
-                    # Model running with the `sentence-transformers` framework can work with both tasks even if not
-                    # branded as such in the API response
-                    models_by_task.setdefault("feature-extraction", []).append(model["model_id"])
-                    models_by_task.setdefault("sentence-similarity", []).append(model["model_id"])
-                else:
-                    models_by_task.setdefault(model["task"], []).append(model["model_id"])
-
-        for framework in frameworks:
-            response = get_session().get(
-                f"{constants.INFERENCE_ENDPOINT}/framework/{framework}", headers=build_hf_headers(token=self.token)
-            )
-            hf_raise_for_status(response)
-            _unpack_response(framework, response.json())
-
-        # Sort alphabetically for discoverability and return
-        for task, models in models_by_task.items():
-            models_by_task[task] = sorted(set(models), key=lambda x: x.lower())
-        return models_by_task
-
-    def get_endpoint_info(self, *, model: Optional[str] = None) -> Dict[str, Any]:
+    def get_endpoint_info(self, *, model: Optional[str] = None) -> dict[str, Any]:
         """
         Get information about the deployed endpoint.

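`list_deployed_models` is removed outright; its deleted deprecation decorator pointed users to `HfApi.list_models(..., inference_provider='...')` to list warm models per provider. A hedged sketch of that replacement, assuming the `inference_provider` filter exposed by `HfApi.list_models`; the filter values below are illustrative:

```python
# Sketch of the replacement suggested by the removed deprecation message.
from huggingface_hub import HfApi

api = HfApi()
warm_models = api.list_models(
    inference_provider="hf-inference",        # provider to filter on (illustrative value)
    pipeline_tag="zero-shot-classification",  # task filter, mirroring the removed docstring example
    limit=5,
)
for model in warm_models:
    print(model.id)
```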
@@ -3345,7 +3280,7 @@ class InferenceClient:
                 Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.

         Returns:
-            `Dict[str, Any]`: Information about the endpoint.
+            `dict[str, Any]`: Information about the endpoint.

         Example:
         ```py
@@ -3395,7 +3330,6 @@ class InferenceClient:
         Check the health of the deployed endpoint.

         Health check is only available with Inference Endpoints powered by Text-Generation-Inference (TGI) or Text-Embedding-Inference (TEI).
-        For Inference API, please use [`InferenceClient.get_model_status`] instead.

         Args:
             model (`str`, *optional*):
@@ -3419,75 +3353,12 @@ class InferenceClient:
         if model is None:
             raise ValueError("Model id not provided.")
         if not model.startswith(("http://", "https://")):
-            raise ValueError(
-                "Model must be an Inference Endpoint URL. For serverless Inference API, please use `InferenceClient.get_model_status`."
-            )
+            raise ValueError("Model must be an Inference Endpoint URL.")
         url = model.rstrip("/") + "/health"

         response = get_session().get(url, headers=build_hf_headers(token=self.token))
         return response.status_code == 200

-    @_deprecate_method(
-        version="0.33.0",
-        message=(
-            "HF Inference API is getting revamped and will only support warm models in the future (no cold start allowed)."
-            " Use `HfApi.model_info` to get the model status both with HF Inference API and external providers."
-        ),
-    )
-    def get_model_status(self, model: Optional[str] = None) -> ModelStatus:
-        """
-        Get the status of a model hosted on the HF Inference API.
-
-        <Tip>
-
-        This endpoint is mostly useful when you already know which model you want to use and want to check its
-        availability. If you want to discover already deployed models, you should rather use [`~InferenceClient.list_deployed_models`].
-
-        </Tip>
-
-        Args:
-            model (`str`, *optional*):
-                Identifier of the model for witch the status gonna be checked. If model is not provided,
-                the model associated with this instance of [`InferenceClient`] will be used. Only HF Inference API service can be checked so the
-                identifier cannot be a URL.
-
-
-        Returns:
-            [`ModelStatus`]: An instance of ModelStatus dataclass, containing information,
-                about the state of the model: load, state, compute type and framework.
-
-        Example:
-        ```py
-        >>> from huggingface_hub import InferenceClient
-        >>> client = InferenceClient()
-        >>> client.get_model_status("meta-llama/Meta-Llama-3-8B-Instruct")
-        ModelStatus(loaded=True, state='Loaded', compute_type='gpu', framework='text-generation-inference')
-        ```
-        """
-        if self.provider != "hf-inference":
-            raise ValueError(f"Getting model status is not supported on '{self.provider}'.")
-
-        model = model or self.model
-        if model is None:
-            raise ValueError("Model id not provided.")
-        if model.startswith("https://"):
-            raise NotImplementedError("Model status is only available for Inference API endpoints.")
-        url = f"{constants.INFERENCE_ENDPOINT}/status/{model}"
-
-        response = get_session().get(url, headers=build_hf_headers(token=self.token))
-        hf_raise_for_status(response)
-        response_data = response.json()
-
-        if "error" in response_data:
-            raise ValueError(response_data["error"])
-
-        return ModelStatus(
-            loaded=response_data["loaded"],
-            state=response_data["state"],
-            compute_type=response_data["compute_type"],
-            framework=response_data["framework"],
-        )
-
     @property
     def chat(self) -> "ProxyClientChat":
         return ProxyClientChat(self)
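Finally, `get_model_status` and its `ModelStatus`-based body are removed, and `health_check` now only validates Inference Endpoint URLs since the serverless status route is gone. The deleted deprecation decorator recommended `HfApi.model_info` for checking availability across HF Inference and external providers; a hedged sketch of that route, assuming the `expand=["inference"]` option of `model_info` and an illustrative model id:

```python
# Sketch of the replacement suggested by the removed deprecation message.
# The expand value and model id are assumptions, not part of this diff.
from huggingface_hub import HfApi

api = HfApi()
info = api.model_info("meta-llama/Meta-Llama-3-8B-Instruct", expand=["inference"])
print(info.inference)  # e.g. "warm" when the model is deployed for inference
```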