huggingface-hub 0.27.1__py3-none-any.whl → 0.28.0rc0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their respective public registries, and is provided for informational purposes only.
Potentially problematic release: this version of huggingface-hub has been flagged as potentially problematic.
- huggingface_hub/__init__.py +418 -12
- huggingface_hub/_commit_api.py +33 -4
- huggingface_hub/_inference_endpoints.py +8 -2
- huggingface_hub/_local_folder.py +14 -3
- huggingface_hub/commands/scan_cache.py +1 -1
- huggingface_hub/commands/upload_large_folder.py +1 -1
- huggingface_hub/constants.py +7 -2
- huggingface_hub/file_download.py +1 -2
- huggingface_hub/hf_api.py +64 -83
- huggingface_hub/inference/_client.py +706 -450
- huggingface_hub/inference/_common.py +32 -64
- huggingface_hub/inference/_generated/_async_client.py +722 -470
- huggingface_hub/inference/_generated/types/__init__.py +1 -0
- huggingface_hub/inference/_generated/types/image_to_image.py +3 -3
- huggingface_hub/inference/_generated/types/text_to_audio.py +1 -2
- huggingface_hub/inference/_generated/types/text_to_image.py +3 -3
- huggingface_hub/inference/_generated/types/text_to_speech.py +3 -6
- huggingface_hub/inference/_generated/types/text_to_video.py +47 -0
- huggingface_hub/inference/_generated/types/visual_question_answering.py +1 -1
- huggingface_hub/inference/_providers/__init__.py +89 -0
- huggingface_hub/inference/_providers/fal_ai.py +155 -0
- huggingface_hub/inference/_providers/hf_inference.py +202 -0
- huggingface_hub/inference/_providers/replicate.py +144 -0
- huggingface_hub/inference/_providers/sambanova.py +85 -0
- huggingface_hub/inference/_providers/together.py +148 -0
- huggingface_hub/py.typed +0 -0
- huggingface_hub/repocard.py +1 -1
- huggingface_hub/repocard_data.py +2 -1
- huggingface_hub/serialization/_base.py +1 -1
- huggingface_hub/serialization/_torch.py +1 -1
- huggingface_hub/utils/_fixes.py +25 -13
- huggingface_hub/utils/_http.py +2 -2
- huggingface_hub/utils/logging.py +1 -1
- {huggingface_hub-0.27.1.dist-info → huggingface_hub-0.28.0rc0.dist-info}/METADATA +4 -4
- {huggingface_hub-0.27.1.dist-info → huggingface_hub-0.28.0rc0.dist-info}/RECORD +39 -31
- {huggingface_hub-0.27.1.dist-info → huggingface_hub-0.28.0rc0.dist-info}/LICENSE +0 -0
- {huggingface_hub-0.27.1.dist-info → huggingface_hub-0.28.0rc0.dist-info}/WHEEL +0 -0
- {huggingface_hub-0.27.1.dist-info → huggingface_hub-0.28.0rc0.dist-info}/entry_points.txt +0 -0
- {huggingface_hub-0.27.1.dist-info → huggingface_hub-0.28.0rc0.dist-info}/top_level.txt +0 -0
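The headline change in this release is the new `huggingface_hub/inference/_providers/` package and the `provider` argument on the inference clients, visible throughout the `_async_client.py` hunks below. A minimal usage sketch, based on the examples embedded in the updated docstrings; the model ID and the placeholder API key are illustrative only:

```py
# Sketch based on the docstring examples in this diff; model ID and key are placeholders.
from huggingface_hub import InferenceClient

# Route the request through a third-party provider instead of the default "hf-inference".
client = InferenceClient(
    provider="together",           # one of: "replicate", "together", "fal-ai", "sambanova", "hf-inference"
    api_key="<together_api_key>",  # or an "hf_..." token to bill usage to your Hugging Face account
)
client.chat_completion(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    messages=[{"role": "user", "content": "What is the capital of France?"}],
)
```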
@@ -26,14 +26,13 @@ import time
 import warnings
 from typing import TYPE_CHECKING, Any, AsyncIterable, Dict, List, Literal, Optional, Set, Union, overload

-from requests.structures import CaseInsensitiveDict
-
 from huggingface_hub.constants import ALL_INFERENCE_API_FRAMEWORKS, INFERENCE_ENDPOINT, MAIN_INFERENCE_API_FRAMEWORKS
 from huggingface_hub.errors import InferenceTimeoutError
 from huggingface_hub.inference._common import (
     TASKS_EXPECTING_IMAGES,
     ContentT,
     ModelStatus,
+    RequestParameters,
     _async_stream_chat_completion_response,
     _async_stream_text_generation_response,
     _b64_encode,
@@ -41,11 +40,9 @@ from huggingface_hub.inference._common import (
     _bytes_to_dict,
     _bytes_to_image,
     _bytes_to_list,
-    _fetch_recommended_models,
     _get_unsupported_text_generation_kwargs,
     _import_numpy,
     _open_as_binary,
-    _prepare_payload,
     _set_unsupported_text_generation_kwargs,
     raise_text_generation_error,
 )
@@ -90,8 +87,9 @@ from huggingface_hub.inference._generated.types import (
     ZeroShotClassificationOutputElement,
     ZeroShotImageClassificationOutputElement,
 )
-from huggingface_hub.
-from huggingface_hub.utils
+from huggingface_hub.inference._providers import PROVIDER_T, HFInferenceTask, get_provider_helper
+from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status
+from huggingface_hub.utils._deprecation import _deprecate_arguments, _deprecate_method

 from .._common import _async_yield_from, _import_aiohttp

@@ -112,7 +110,7 @@ class AsyncInferenceClient:
     Initialize a new Inference Client.

     [`InferenceClient`] aims to provide a unified experience to perform inference. The client can be used
-    seamlessly with either the (free) Inference API
+    seamlessly with either the (free) Inference API, self-hosted Inference Endpoints, or third-party Inference Providers.

     Args:
         model (`str`, `optional`):
@@ -123,6 +121,10 @@ class AsyncInferenceClient:
             arguments are mutually exclusive. If using `base_url` for chat completion, the `/chat/completions` suffix
             path will be appended to the base URL (see the [TGI Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api)
             documentation for details). When passing a URL as `model`, the client will not append any suffix path to it.
+        provider (`str`, *optional*):
+            Name of the provider to use for inference. Can be `"replicate"`, `"together"`, `"fal-ai"`, `"sambanova"` or `"hf-inference"`.
+            defaults to hf-inference (Hugging Face Serverless Inference API).
+            If model is a URL or `base_url` is passed, then `provider` is not used.
         token (`str` or `bool`, *optional*):
             Hugging Face token. Will default to the locally saved token if not provided.
             Pass `token=False` if you don't want to send your token to the server.
@@ -152,7 +154,8 @@ class AsyncInferenceClient:
         self,
         model: Optional[str] = None,
         *,
-
+        provider: Optional[PROVIDER_T] = None,
+        token: Optional[str] = None,
         timeout: Optional[float] = None,
         headers: Optional[Dict[str, str]] = None,
         cookies: Optional[Dict[str, str]] = None,
@@ -177,12 +180,12 @@ class AsyncInferenceClient:
            )

         self.model: Optional[str] = model
-        self.token:
-        self.headers
-
-
-        if
-
+        self.token: Optional[str] = token if token is not None else api_key
+        self.headers = headers if headers is not None else {}
+
+        # Configure provider
+        self.provider = provider if provider is not None else "hf-inference"
+
         self.cookies = cookies
         self.timeout = timeout
         self.trust_env = trust_env
@@ -230,6 +233,14 @@ class AsyncInferenceClient:
         stream: bool = False,
     ) -> Union[bytes, AsyncIterable[bytes]]: ...

+    @_deprecate_method(
+        version="0.31.0",
+        message=(
+            "Making direct POST requests to the inference server is not supported anymore. "
+            "Please use task methods instead (e.g. `InferenceClient.chat_completion`). "
+            "If your use case is not supported, please open an issue in https://github.com/huggingface/huggingface_hub."
+        ),
+    )
     async def post(
         self,
         *,
@@ -242,56 +253,67 @@ class AsyncInferenceClient:
         """
         Make a POST request to the inference server.

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        This method is deprecated and will be removed in the future.
+        Please use task methods instead (e.g. `InferenceClient.chat_completion`).
+        """
+        if self.provider != "hf-inference":
+            raise ValueError(
+                "Cannot use `post` with another provider than `hf-inference`. "
+                "`InferenceClient.post` is deprecated and should not be used directly anymore."
+            )
+        provider_helper = HFInferenceTask(task or "unknown")
+        url = provider_helper.build_url(provider_helper.map_model(model))
+        headers = provider_helper.prepare_headers(headers=self.headers, api_key=self.token)
+        return await self._inner_post(
+            request_parameters=RequestParameters(
+                url=url,
+                task=task or "unknown",
+                model=model or "unknown",
+                json=json,
+                data=data,
+                headers=headers,
+            ),
+            stream=stream,
+        )

-
-
+    @overload
+    async def _inner_post(  # type: ignore[misc]
+        self, request_parameters: RequestParameters, *, stream: Literal[False] = ...
+    ) -> bytes: ...

-
-
-
-
-                If the request fails with an HTTP error status code other than HTTP 503.
-        """
+    @overload
+    async def _inner_post(  # type: ignore[misc]
+        self, request_parameters: RequestParameters, *, stream: Literal[True] = ...
+    ) -> AsyncIterable[bytes]: ...

-
+    @overload
+    async def _inner_post(
+        self, request_parameters: RequestParameters, *, stream: bool = False
+    ) -> Union[bytes, AsyncIterable[bytes]]: ...

-
+    async def _inner_post(
+        self, request_parameters: RequestParameters, *, stream: bool = False
+    ) -> Union[bytes, AsyncIterable[bytes]]:
+        """Make a request to the inference server."""

-
-            warnings.warn("Ignoring `json` as `data` is passed as binary.")
+        aiohttp = _import_aiohttp()

-        #
-
-
-            headers["Accept"] = "image/png"
+        # TODO: this should be handled in provider helpers directly
+        if request_parameters.task in TASKS_EXPECTING_IMAGES and "Accept" not in request_parameters.headers:
+            request_parameters.headers["Accept"] = "image/png"

         t0 = time.time()
         timeout = self.timeout
         while True:
-            with _open_as_binary(data) as data_as_binary:
+            with _open_as_binary(request_parameters.data) as data_as_binary:
                 # Do not use context manager as we don't want to close the connection immediately when returning
                 # a stream
-                session = self._get_client_session(headers=headers)
+                session = self._get_client_session(headers=request_parameters.headers)

                 try:
-                    response = await session.post(
+                    response = await session.post(
+                        request_parameters.url, json=request_parameters.json, data=data_as_binary, proxy=self.proxies
+                    )
                     response_error_payload = None
                     if response.status != 200:
                         try:
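As the hunk above shows, `post()` is deprecated in this release (removal announced for 0.31.0) and only works with the default `hf-inference` provider; the task methods now build `RequestParameters` through a provider helper and call the private `_inner_post()`. A hedged migration sketch, with an illustrative model ID:

```py
# Before (deprecated in 0.28.0, slated for removal in 0.31.0):
#     await client.post(json={"inputs": "I like you"}, model="...", task="text-classification")
# After: use the dedicated task method, which goes through the provider helper.
import asyncio
from huggingface_hub import AsyncInferenceClient

async def main() -> None:
    client = AsyncInferenceClient()
    result = await client.text_classification(
        "I like you",
        model="distilbert-base-uncased-finetuned-sst-2-english",  # illustrative model ID
    )
    print(result)

asyncio.run(main())
```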
@@ -308,25 +330,27 @@ class AsyncInferenceClient:
|
|
|
308
330
|
except asyncio.TimeoutError as error:
|
|
309
331
|
await session.close()
|
|
310
332
|
# Convert any `TimeoutError` to a `InferenceTimeoutError`
|
|
311
|
-
raise InferenceTimeoutError(f"Inference call timed out: {url}") from error # type: ignore
|
|
333
|
+
raise InferenceTimeoutError(f"Inference call timed out: {request_parameters.url}") from error # type: ignore
|
|
312
334
|
except aiohttp.ClientResponseError as error:
|
|
313
335
|
error.response_error_payload = response_error_payload
|
|
314
336
|
await session.close()
|
|
315
|
-
if response.status == 422 and task
|
|
316
|
-
error.message += f". Make sure '{task}' task is supported by the model."
|
|
337
|
+
if response.status == 422 and request_parameters.task != "unknown":
|
|
338
|
+
error.message += f". Make sure '{request_parameters.task}' task is supported by the model."
|
|
317
339
|
if response.status == 503:
|
|
318
340
|
# If Model is unavailable, either raise a TimeoutError...
|
|
319
341
|
if timeout is not None and time.time() - t0 > timeout:
|
|
320
342
|
raise InferenceTimeoutError(
|
|
321
|
-
f"Model not loaded on the server: {url}. Please retry with a higher timeout"
|
|
343
|
+
f"Model not loaded on the server: {request_parameters.url}. Please retry with a higher timeout"
|
|
322
344
|
f" (current: {self.timeout}).",
|
|
323
345
|
request=error.request,
|
|
324
346
|
response=error.response,
|
|
325
347
|
) from error
|
|
326
348
|
# ...or wait 1s and retry
|
|
327
349
|
logger.info(f"Waiting for model to be loaded on the server: {error}")
|
|
328
|
-
if "X-wait-for-model" not in headers and url.startswith(
|
|
329
|
-
|
|
350
|
+
if "X-wait-for-model" not in request_parameters.headers and request_parameters.url.startswith(
|
|
351
|
+
INFERENCE_ENDPOINT
|
|
352
|
+
):
|
|
353
|
+
request_parameters.headers["X-wait-for-model"] = "1"
|
|
330
354
|
await asyncio.sleep(1)
|
|
331
355
|
if timeout is not None:
|
|
332
356
|
timeout = max(self.timeout - (time.time() - t0), 1) # type: ignore
|
|
@@ -408,9 +432,15 @@ class AsyncInferenceClient:
|
|
|
408
432
|
]
|
|
409
433
|
```
|
|
410
434
|
"""
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
435
|
+
provider_helper = get_provider_helper(self.provider, task="audio-classification")
|
|
436
|
+
request_parameters = provider_helper.prepare_request(
|
|
437
|
+
inputs=audio,
|
|
438
|
+
parameters={"function_to_apply": function_to_apply, "top_k": top_k},
|
|
439
|
+
headers=self.headers,
|
|
440
|
+
model=model or self.model,
|
|
441
|
+
api_key=self.token,
|
|
442
|
+
)
|
|
443
|
+
response = await self._inner_post(request_parameters)
|
|
414
444
|
return AudioClassificationOutputElement.parse_obj_as_list(response)
|
|
415
445
|
|
|
416
446
|
async def audio_to_audio(
|
|
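Every task method in the rewritten client follows the same pattern shown above: a provider helper prepares the `RequestParameters`, then `_inner_post()` sends them. From the caller's side nothing changes; a minimal sketch (the audio file path is a placeholder, and the keyword arguments assume the existing public signature):

```py
import asyncio
from huggingface_hub import AsyncInferenceClient

async def main() -> None:
    client = AsyncInferenceClient()  # defaults to provider="hf-inference"
    # "sample.flac" is a placeholder; any audio content accepted by the task works.
    labels = await client.audio_classification("sample.flac", top_k=3)
    for label in labels:
        print(label.label, label.score)

asyncio.run(main())
```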
@@ -451,7 +481,15 @@ class AsyncInferenceClient:
|
|
|
451
481
|
f.write(item.blob)
|
|
452
482
|
```
|
|
453
483
|
"""
|
|
454
|
-
|
|
484
|
+
provider_helper = get_provider_helper(self.provider, task="audio-to-audio")
|
|
485
|
+
request_parameters = provider_helper.prepare_request(
|
|
486
|
+
inputs=audio,
|
|
487
|
+
parameters={},
|
|
488
|
+
headers=self.headers,
|
|
489
|
+
model=model or self.model,
|
|
490
|
+
api_key=self.token,
|
|
491
|
+
)
|
|
492
|
+
response = await self._inner_post(request_parameters)
|
|
455
493
|
audio_output = AudioToAudioOutputElement.parse_obj_as_list(response)
|
|
456
494
|
for item in audio_output:
|
|
457
495
|
item.blob = base64.b64decode(item.blob)
|
|
@@ -472,7 +510,8 @@ class AsyncInferenceClient:
|
|
|
472
510
|
model (`str`, *optional*):
|
|
473
511
|
The model to use for ASR. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
|
|
474
512
|
Inference Endpoint. If not provided, the default recommended model for ASR will be used.
|
|
475
|
-
|
|
513
|
+
parameters (Dict[str, Any], *optional*):
|
|
514
|
+
Additional parameters to pass to the model.
|
|
476
515
|
Returns:
|
|
477
516
|
[`AutomaticSpeechRecognitionOutput`]: An item containing the transcribed text and optionally the timestamp chunks.
|
|
478
517
|
|
|
@@ -491,7 +530,15 @@ class AsyncInferenceClient:
|
|
|
491
530
|
"hello world"
|
|
492
531
|
```
|
|
493
532
|
"""
|
|
494
|
-
|
|
533
|
+
provider_helper = get_provider_helper(self.provider, task="automatic-speech-recognition")
|
|
534
|
+
request_parameters = provider_helper.prepare_request(
|
|
535
|
+
inputs=audio,
|
|
536
|
+
parameters={},
|
|
537
|
+
headers=self.headers,
|
|
538
|
+
model=model or self.model,
|
|
539
|
+
api_key=self.token,
|
|
540
|
+
)
|
|
541
|
+
response = await self._inner_post(request_parameters)
|
|
495
542
|
return AutomaticSpeechRecognitionOutput.parse_obj_as_instance(response)
|
|
496
543
|
|
|
497
544
|
@overload
|
|
@@ -605,6 +652,10 @@ class AsyncInferenceClient:
|
|
|
605
652
|
|
|
606
653
|
</Tip>
|
|
607
654
|
|
|
655
|
+
<Tip>
|
|
656
|
+
Some parameters might not be supported by some providers.
|
|
657
|
+
</Tip>
|
|
658
|
+
|
|
608
659
|
Args:
|
|
609
660
|
messages (List of [`ChatCompletionInputMessage`]):
|
|
610
661
|
Conversation history consisting of roles and content pairs.
|
|
@@ -612,25 +663,20 @@ class AsyncInferenceClient:
                 The model to use for chat-completion. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
                 Inference Endpoint. If not provided, the default recommended model for chat-based text-generation will be used.
                 See https://huggingface.co/tasks/text-generation for more details.
-
                 If `model` is a model ID, it is passed to the server as the `model` parameter. If you want to define a
                 custom URL while setting `model` in the request payload, you must set `base_url` when initializing [`InferenceClient`].
             frequency_penalty (`float`, *optional*):
                 Penalizes new tokens based on their existing frequency
                 in the text so far. Range: [-2.0, 2.0]. Defaults to 0.0.
             logit_bias (`List[float]`, *optional*):
-
-                (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,
-                the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,
-                but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should
-                result in a ban or exclusive selection of the relevant token. Defaults to None.
+                Adjusts the likelihood of specific tokens appearing in the generated output.
             logprobs (`bool`, *optional*):
                 Whether to return log probabilities of the output tokens or not. If true, returns the log
                 probabilities of each output token returned in the content of message.
             max_tokens (`int`, *optional*):
                 Maximum number of tokens allowed in the response. Defaults to 100.
             n (`int`, *optional*):
-
+                The number of completions to generate for each prompt.
             presence_penalty (`float`, *optional*):
                 Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the
                 text so far, increasing the model's likelihood to talk about new topics.
@@ -638,7 +684,7 @@ class AsyncInferenceClient:
                 Grammar constraints. Can be either a JSONSchema or a regex.
             seed (Optional[`int`], *optional*):
                 Seed for reproducible control flow. Defaults to None.
-            stop (
+            stop (`List[str]`, *optional*):
                 Up to four strings which trigger the end of the response.
                 Defaults to None.
             stream (`bool`, *optional*):
@@ -750,6 +796,32 @@ class AsyncInferenceClient:
|
|
|
750
796
|
print(chunk.choices[0].delta.content)
|
|
751
797
|
```
|
|
752
798
|
|
|
799
|
+
Example using a third-party provider directly. Usage will be billed on your Together AI account.
|
|
800
|
+
```py
|
|
801
|
+
>>> from huggingface_hub import InferenceClient
|
|
802
|
+
>>> client = InferenceClient(
|
|
803
|
+
... provider="together", # Use Together AI provider
|
|
804
|
+
... api_key="<together_api_key>", # Pass your Together API key directly
|
|
805
|
+
... )
|
|
806
|
+
>>> client.chat_completion(
|
|
807
|
+
... model="meta-llama/Meta-Llama-3-8B-Instruct",
|
|
808
|
+
... messages=[{"role": "user", "content": "What is the capital of France?"}],
|
|
809
|
+
... )
|
|
810
|
+
```
|
|
811
|
+
|
|
812
|
+
Example using a third-party provider through Hugging Face Routing. Usage will be billed on your Hugging Face account.
|
|
813
|
+
```py
|
|
814
|
+
>>> from huggingface_hub import InferenceClient
|
|
815
|
+
>>> client = InferenceClient(
|
|
816
|
+
... provider="sambanova", # Use Sambanova provider
|
|
817
|
+
... api_key="hf_...", # Pass your HF token
|
|
818
|
+
... )
|
|
819
|
+
>>> client.chat_completion(
|
|
820
|
+
... model="meta-llama/Meta-Llama-3-8B-Instruct",
|
|
821
|
+
... messages=[{"role": "user", "content": "What is the capital of France?"}],
|
|
822
|
+
... )
|
|
823
|
+
```
|
|
824
|
+
|
|
753
825
|
Example using Image + Text as input:
|
|
754
826
|
```py
|
|
755
827
|
# Must be run in an async context
|
|
@@ -901,68 +973,50 @@ class AsyncInferenceClient:
|
|
|
901
973
|
'{\n\n"activity": "bike ride",\n"animals": ["puppy", "cat", "raccoon"],\n"animals_seen": 3,\n"location": "park"}'
|
|
902
974
|
```
|
|
903
975
|
"""
|
|
904
|
-
|
|
905
|
-
|
|
906
|
-
|
|
907
|
-
#
|
|
908
|
-
|
|
909
|
-
|
|
910
|
-
|
|
911
|
-
|
|
912
|
-
|
|
913
|
-
|
|
914
|
-
|
|
915
|
-
|
|
916
|
-
|
|
917
|
-
|
|
918
|
-
|
|
919
|
-
|
|
920
|
-
|
|
921
|
-
|
|
922
|
-
|
|
923
|
-
|
|
924
|
-
|
|
925
|
-
|
|
926
|
-
|
|
927
|
-
|
|
928
|
-
|
|
929
|
-
|
|
930
|
-
|
|
931
|
-
|
|
976
|
+
# Get the provider helper
|
|
977
|
+
provider_helper = get_provider_helper(self.provider, task="conversational")
|
|
978
|
+
|
|
979
|
+
# Since `chat_completion(..., model=xxx)` is also a payload parameter for the server, we need to handle 'model' differently.
|
|
980
|
+
# `self.base_url` and `self.model` takes precedence over 'model' argument for building URL.
|
|
981
|
+
# `model` takes precedence for payload value.
|
|
982
|
+
model_id_or_url = self.base_url or self.model or model
|
|
983
|
+
payload_model = model or self.model
|
|
984
|
+
|
|
985
|
+
# Prepare the payload
|
|
986
|
+
parameters = {
|
|
987
|
+
"model": payload_model,
|
|
988
|
+
"frequency_penalty": frequency_penalty,
|
|
989
|
+
"logit_bias": logit_bias,
|
|
990
|
+
"logprobs": logprobs,
|
|
991
|
+
"max_tokens": max_tokens,
|
|
992
|
+
"n": n,
|
|
993
|
+
"presence_penalty": presence_penalty,
|
|
994
|
+
"response_format": response_format,
|
|
995
|
+
"seed": seed,
|
|
996
|
+
"stop": stop,
|
|
997
|
+
"temperature": temperature,
|
|
998
|
+
"tool_choice": tool_choice,
|
|
999
|
+
"tool_prompt": tool_prompt,
|
|
1000
|
+
"tools": tools,
|
|
1001
|
+
"top_logprobs": top_logprobs,
|
|
1002
|
+
"top_p": top_p,
|
|
1003
|
+
"stream": stream,
|
|
1004
|
+
"stream_options": stream_options,
|
|
1005
|
+
}
|
|
1006
|
+
request_parameters = provider_helper.prepare_request(
|
|
1007
|
+
inputs=messages,
|
|
1008
|
+
parameters=parameters,
|
|
1009
|
+
headers=self.headers,
|
|
1010
|
+
model=model_id_or_url,
|
|
1011
|
+
api_key=self.token,
|
|
932
1012
|
)
|
|
933
|
-
|
|
934
|
-
data = await self.post(model=model_url, json=payload, stream=stream)
|
|
1013
|
+
data = await self._inner_post(request_parameters, stream=stream)
|
|
935
1014
|
|
|
936
1015
|
if stream:
|
|
937
1016
|
return _async_stream_chat_completion_response(data) # type: ignore[arg-type]
|
|
938
1017
|
|
|
939
1018
|
return ChatCompletionOutput.parse_obj_as_instance(data) # type: ignore[arg-type]
|
|
940
1019
|
|
|
941
|
-
def _resolve_chat_completion_url(self, model: Optional[str] = None) -> str:
|
|
942
|
-
# Since `chat_completion(..., model=xxx)` is also a payload parameter for the server, we need to handle 'model' differently.
|
|
943
|
-
# `self.base_url` and `self.model` takes precedence over 'model' argument only in `chat_completion`.
|
|
944
|
-
model_id_or_url = self.base_url or self.model or model or self.get_recommended_model("text-generation")
|
|
945
|
-
|
|
946
|
-
# Resolve URL if it's a model ID
|
|
947
|
-
model_url = (
|
|
948
|
-
model_id_or_url
|
|
949
|
-
if model_id_or_url.startswith(("http://", "https://"))
|
|
950
|
-
else self._resolve_url(model_id_or_url, task="text-generation")
|
|
951
|
-
)
|
|
952
|
-
|
|
953
|
-
# Strip trailing /
|
|
954
|
-
model_url = model_url.rstrip("/")
|
|
955
|
-
|
|
956
|
-
# Append /chat/completions if not already present
|
|
957
|
-
if model_url.endswith("/v1"):
|
|
958
|
-
model_url += "/chat/completions"
|
|
959
|
-
|
|
960
|
-
# Append /v1/chat/completions if not already present
|
|
961
|
-
if not model_url.endswith("/chat/completions"):
|
|
962
|
-
model_url += "/v1/chat/completions"
|
|
963
|
-
|
|
964
|
-
return model_url
|
|
965
|
-
|
|
966
1020
|
async def document_question_answering(
|
|
967
1021
|
self,
|
|
968
1022
|
image: ContentT,
|
|
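The hunk above removes `_resolve_chat_completion_url` entirely: `chat_completion` now hands the flat parameter dict to the provider helper, with `self.base_url`/`self.model` taking precedence for routing and the `model` argument taking precedence in the payload. The docstring examples in this file use the sync client; a hedged async equivalent (key and model are illustrative):

```py
import asyncio
from huggingface_hub import AsyncInferenceClient

async def main() -> None:
    # Provider routing also works on the async client; key and model below are illustrative.
    client = AsyncInferenceClient(provider="sambanova", api_key="hf_***")
    out = await client.chat_completion(
        model="meta-llama/Meta-Llama-3-8B-Instruct",
        messages=[{"role": "user", "content": "What is the capital of France?"}],
        max_tokens=64,
    )
    print(out.choices[0].message.content)

asyncio.run(main())
```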
@@ -1030,18 +1084,24 @@ class AsyncInferenceClient:
|
|
|
1030
1084
|
```
|
|
1031
1085
|
"""
|
|
1032
1086
|
inputs: Dict[str, Any] = {"question": question, "image": _b64_encode(image)}
|
|
1033
|
-
|
|
1034
|
-
|
|
1035
|
-
|
|
1036
|
-
|
|
1037
|
-
|
|
1038
|
-
|
|
1039
|
-
|
|
1040
|
-
|
|
1041
|
-
|
|
1042
|
-
|
|
1043
|
-
|
|
1044
|
-
|
|
1087
|
+
provider_helper = get_provider_helper(self.provider, task="document-question-answering")
|
|
1088
|
+
request_parameters = provider_helper.prepare_request(
|
|
1089
|
+
inputs=inputs,
|
|
1090
|
+
parameters={
|
|
1091
|
+
"doc_stride": doc_stride,
|
|
1092
|
+
"handle_impossible_answer": handle_impossible_answer,
|
|
1093
|
+
"lang": lang,
|
|
1094
|
+
"max_answer_len": max_answer_len,
|
|
1095
|
+
"max_question_len": max_question_len,
|
|
1096
|
+
"max_seq_len": max_seq_len,
|
|
1097
|
+
"top_k": top_k,
|
|
1098
|
+
"word_boxes": word_boxes,
|
|
1099
|
+
},
|
|
1100
|
+
headers=self.headers,
|
|
1101
|
+
model=model or self.model,
|
|
1102
|
+
api_key=self.token,
|
|
1103
|
+
)
|
|
1104
|
+
response = await self._inner_post(request_parameters)
|
|
1045
1105
|
return DocumentQuestionAnsweringOutputElement.parse_obj_as_list(response)
|
|
1046
1106
|
|
|
1047
1107
|
async def feature_extraction(
|
|
@@ -1100,14 +1160,20 @@ class AsyncInferenceClient:
|
|
|
1100
1160
|
[ 0.28552425, -0.928395 , -1.2077185 , ..., 0.76810825, -2.1069427 , 0.6236161 ]], dtype=float32)
|
|
1101
1161
|
```
|
|
1102
1162
|
"""
|
|
1103
|
-
|
|
1104
|
-
|
|
1105
|
-
|
|
1106
|
-
|
|
1107
|
-
|
|
1108
|
-
|
|
1109
|
-
|
|
1110
|
-
|
|
1163
|
+
provider_helper = get_provider_helper(self.provider, task="feature-extraction")
|
|
1164
|
+
request_parameters = provider_helper.prepare_request(
|
|
1165
|
+
inputs=text,
|
|
1166
|
+
parameters={
|
|
1167
|
+
"normalize": normalize,
|
|
1168
|
+
"prompt_name": prompt_name,
|
|
1169
|
+
"truncate": truncate,
|
|
1170
|
+
"truncation_direction": truncation_direction,
|
|
1171
|
+
},
|
|
1172
|
+
headers=self.headers,
|
|
1173
|
+
model=model or self.model,
|
|
1174
|
+
api_key=self.token,
|
|
1175
|
+
)
|
|
1176
|
+
response = await self._inner_post(request_parameters)
|
|
1111
1177
|
np = _import_numpy()
|
|
1112
1178
|
return np.array(_bytes_to_dict(response), dtype="float32")
|
|
1113
1179
|
|
|
@@ -1156,9 +1222,15 @@ class AsyncInferenceClient:
|
|
|
1156
1222
|
]
|
|
1157
1223
|
```
|
|
1158
1224
|
"""
|
|
1159
|
-
|
|
1160
|
-
|
|
1161
|
-
|
|
1225
|
+
provider_helper = get_provider_helper(self.provider, task="fill-mask")
|
|
1226
|
+
request_parameters = provider_helper.prepare_request(
|
|
1227
|
+
inputs=text,
|
|
1228
|
+
parameters={"targets": targets, "top_k": top_k},
|
|
1229
|
+
headers=self.headers,
|
|
1230
|
+
model=model or self.model,
|
|
1231
|
+
api_key=self.token,
|
|
1232
|
+
)
|
|
1233
|
+
response = await self._inner_post(request_parameters)
|
|
1162
1234
|
return FillMaskOutputElement.parse_obj_as_list(response)
|
|
1163
1235
|
|
|
1164
1236
|
async def image_classification(
|
|
@@ -1200,9 +1272,15 @@ class AsyncInferenceClient:
|
|
|
1200
1272
|
[ImageClassificationOutputElement(label='Blenheim spaniel', score=0.9779096841812134), ...]
|
|
1201
1273
|
```
|
|
1202
1274
|
"""
|
|
1203
|
-
|
|
1204
|
-
|
|
1205
|
-
|
|
1275
|
+
provider_helper = get_provider_helper(self.provider, task="image-classification")
|
|
1276
|
+
request_parameters = provider_helper.prepare_request(
|
|
1277
|
+
inputs=image,
|
|
1278
|
+
parameters={"function_to_apply": function_to_apply, "top_k": top_k},
|
|
1279
|
+
headers=self.headers,
|
|
1280
|
+
model=model or self.model,
|
|
1281
|
+
api_key=self.token,
|
|
1282
|
+
)
|
|
1283
|
+
response = await self._inner_post(request_parameters)
|
|
1206
1284
|
return ImageClassificationOutputElement.parse_obj_as_list(response)
|
|
1207
1285
|
|
|
1208
1286
|
async def image_segmentation(
|
|
@@ -1256,14 +1334,20 @@ class AsyncInferenceClient:
|
|
|
1256
1334
|
[ImageSegmentationOutputElement(score=0.989008, label='LABEL_184', mask=<PIL.PngImagePlugin.PngImageFile image mode=L size=400x300 at 0x7FDD2B129CC0>), ...]
|
|
1257
1335
|
```
|
|
1258
1336
|
"""
|
|
1259
|
-
|
|
1260
|
-
|
|
1261
|
-
|
|
1262
|
-
|
|
1263
|
-
|
|
1264
|
-
|
|
1265
|
-
|
|
1266
|
-
|
|
1337
|
+
provider_helper = get_provider_helper(self.provider, task="audio-classification")
|
|
1338
|
+
request_parameters = provider_helper.prepare_request(
|
|
1339
|
+
inputs=image,
|
|
1340
|
+
parameters={
|
|
1341
|
+
"mask_threshold": mask_threshold,
|
|
1342
|
+
"overlap_mask_area_threshold": overlap_mask_area_threshold,
|
|
1343
|
+
"subtask": subtask,
|
|
1344
|
+
"threshold": threshold,
|
|
1345
|
+
},
|
|
1346
|
+
headers=self.headers,
|
|
1347
|
+
model=model or self.model,
|
|
1348
|
+
api_key=self.token,
|
|
1349
|
+
)
|
|
1350
|
+
response = await self._inner_post(request_parameters)
|
|
1267
1351
|
output = ImageSegmentationOutputElement.parse_obj_as_list(response)
|
|
1268
1352
|
for item in output:
|
|
1269
1353
|
item.mask = _b64_to_image(item.mask) # type: ignore [assignment]
|
|
@@ -1274,7 +1358,7 @@ class AsyncInferenceClient:
         image: ContentT,
         prompt: Optional[str] = None,
         *,
-        negative_prompt: Optional[
+        negative_prompt: Optional[str] = None,
         num_inference_steps: Optional[int] = None,
         guidance_scale: Optional[float] = None,
         model: Optional[str] = None,
@@ -1295,8 +1379,8 @@ class AsyncInferenceClient:
                 The input image for translation. It can be raw bytes, an image file, or a URL to an online image.
             prompt (`str`, *optional*):
                 The text prompt to guide the image generation.
-            negative_prompt (`
-                One
+            negative_prompt (`str`, *optional*):
+                One prompt to guide what NOT to include in image generation.
             num_inference_steps (`int`, *optional*):
                 For diffusion models. The number of denoising steps. More denoising steps usually lead to a higher
                 quality image at the expense of slower inference.
@@ -1327,16 +1411,22 @@ class AsyncInferenceClient:
|
|
|
1327
1411
|
>>> image.save("tiger.jpg")
|
|
1328
1412
|
```
|
|
1329
1413
|
"""
|
|
1330
|
-
|
|
1331
|
-
|
|
1332
|
-
|
|
1333
|
-
|
|
1334
|
-
|
|
1335
|
-
|
|
1336
|
-
|
|
1337
|
-
|
|
1338
|
-
|
|
1339
|
-
|
|
1414
|
+
provider_helper = get_provider_helper(self.provider, task="image-to-image")
|
|
1415
|
+
request_parameters = provider_helper.prepare_request(
|
|
1416
|
+
inputs=image,
|
|
1417
|
+
parameters={
|
|
1418
|
+
"prompt": prompt,
|
|
1419
|
+
"negative_prompt": negative_prompt,
|
|
1420
|
+
"target_size": target_size,
|
|
1421
|
+
"num_inference_steps": num_inference_steps,
|
|
1422
|
+
"guidance_scale": guidance_scale,
|
|
1423
|
+
**kwargs,
|
|
1424
|
+
},
|
|
1425
|
+
headers=self.headers,
|
|
1426
|
+
model=model or self.model,
|
|
1427
|
+
api_key=self.token,
|
|
1428
|
+
)
|
|
1429
|
+
response = await self._inner_post(request_parameters)
|
|
1340
1430
|
return _bytes_to_image(response)
|
|
1341
1431
|
|
|
1342
1432
|
async def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -> ImageToTextOutput:
|
|
@@ -1373,99 +1463,18 @@ class AsyncInferenceClient:
|
|
|
1373
1463
|
'a dog laying on the grass next to a flower pot '
|
|
1374
1464
|
```
|
|
1375
1465
|
"""
|
|
1376
|
-
|
|
1466
|
+
provider_helper = get_provider_helper(self.provider, task="image-to-text")
|
|
1467
|
+
request_parameters = provider_helper.prepare_request(
|
|
1468
|
+
inputs=image,
|
|
1469
|
+
parameters={},
|
|
1470
|
+
headers=self.headers,
|
|
1471
|
+
model=model or self.model,
|
|
1472
|
+
api_key=self.token,
|
|
1473
|
+
)
|
|
1474
|
+
response = await self._inner_post(request_parameters)
|
|
1377
1475
|
output = ImageToTextOutput.parse_obj(response)
|
|
1378
1476
|
return output[0] if isinstance(output, list) else output
|
|
1379
1477
|
|
|
1380
|
-
async def list_deployed_models(
|
|
1381
|
-
self, frameworks: Union[None, str, Literal["all"], List[str]] = None
|
|
1382
|
-
) -> Dict[str, List[str]]:
|
|
1383
|
-
"""
|
|
1384
|
-
List models deployed on the Serverless Inference API service.
|
|
1385
|
-
|
|
1386
|
-
This helper checks deployed models framework by framework. By default, it will check the 4 main frameworks that
|
|
1387
|
-
are supported and account for 95% of the hosted models. However, if you want a complete list of models you can
|
|
1388
|
-
specify `frameworks="all"` as input. Alternatively, if you know before-hand which framework you are interested
|
|
1389
|
-
in, you can also restrict to search to this one (e.g. `frameworks="text-generation-inference"`). The more
|
|
1390
|
-
frameworks are checked, the more time it will take.
|
|
1391
|
-
|
|
1392
|
-
<Tip warning={true}>
|
|
1393
|
-
|
|
1394
|
-
This endpoint method does not return a live list of all models available for the Serverless Inference API service.
|
|
1395
|
-
It searches over a cached list of models that were recently available and the list may not be up to date.
|
|
1396
|
-
If you want to know the live status of a specific model, use [`~InferenceClient.get_model_status`].
|
|
1397
|
-
|
|
1398
|
-
</Tip>
|
|
1399
|
-
|
|
1400
|
-
<Tip>
|
|
1401
|
-
|
|
1402
|
-
This endpoint method is mostly useful for discoverability. If you already know which model you want to use and want to
|
|
1403
|
-
check its availability, you can directly use [`~InferenceClient.get_model_status`].
|
|
1404
|
-
|
|
1405
|
-
</Tip>
|
|
1406
|
-
|
|
1407
|
-
Args:
|
|
1408
|
-
frameworks (`Literal["all"]` or `List[str]` or `str`, *optional*):
|
|
1409
|
-
The frameworks to filter on. By default only a subset of the available frameworks are tested. If set to
|
|
1410
|
-
"all", all available frameworks will be tested. It is also possible to provide a single framework or a
|
|
1411
|
-
custom set of frameworks to check.
|
|
1412
|
-
|
|
1413
|
-
Returns:
|
|
1414
|
-
`Dict[str, List[str]]`: A dictionary mapping task names to a sorted list of model IDs.
|
|
1415
|
-
|
|
1416
|
-
Example:
|
|
1417
|
-
```py
|
|
1418
|
-
# Must be run in an async contextthon
|
|
1419
|
-
>>> from huggingface_hub import AsyncInferenceClient
|
|
1420
|
-
>>> client = AsyncInferenceClient()
|
|
1421
|
-
|
|
1422
|
-
# Discover zero-shot-classification models currently deployed
|
|
1423
|
-
>>> models = await client.list_deployed_models()
|
|
1424
|
-
>>> models["zero-shot-classification"]
|
|
1425
|
-
['Narsil/deberta-large-mnli-zero-cls', 'facebook/bart-large-mnli', ...]
|
|
1426
|
-
|
|
1427
|
-
# List from only 1 framework
|
|
1428
|
-
>>> await client.list_deployed_models("text-generation-inference")
|
|
1429
|
-
{'text-generation': ['bigcode/starcoder', 'meta-llama/Llama-2-70b-chat-hf', ...], ...}
|
|
1430
|
-
```
|
|
1431
|
-
"""
|
|
1432
|
-
# Resolve which frameworks to check
|
|
1433
|
-
if frameworks is None:
|
|
1434
|
-
frameworks = MAIN_INFERENCE_API_FRAMEWORKS
|
|
1435
|
-
elif frameworks == "all":
|
|
1436
|
-
frameworks = ALL_INFERENCE_API_FRAMEWORKS
|
|
1437
|
-
elif isinstance(frameworks, str):
|
|
1438
|
-
frameworks = [frameworks]
|
|
1439
|
-
frameworks = list(set(frameworks))
|
|
1440
|
-
|
|
1441
|
-
# Fetch them iteratively
|
|
1442
|
-
models_by_task: Dict[str, List[str]] = {}
|
|
1443
|
-
|
|
1444
|
-
def _unpack_response(framework: str, items: List[Dict]) -> None:
|
|
1445
|
-
for model in items:
|
|
1446
|
-
if framework == "sentence-transformers":
|
|
1447
|
-
# Model running with the `sentence-transformers` framework can work with both tasks even if not
|
|
1448
|
-
# branded as such in the API response
|
|
1449
|
-
models_by_task.setdefault("feature-extraction", []).append(model["model_id"])
|
|
1450
|
-
models_by_task.setdefault("sentence-similarity", []).append(model["model_id"])
|
|
1451
|
-
else:
|
|
1452
|
-
models_by_task.setdefault(model["task"], []).append(model["model_id"])
|
|
1453
|
-
|
|
1454
|
-
async def _fetch_framework(framework: str) -> None:
|
|
1455
|
-
async with self._get_client_session() as client:
|
|
1456
|
-
response = await client.get(f"{INFERENCE_ENDPOINT}/framework/{framework}", proxy=self.proxies)
|
|
1457
|
-
response.raise_for_status()
|
|
1458
|
-
_unpack_response(framework, await response.json())
|
|
1459
|
-
|
|
1460
|
-
import asyncio
|
|
1461
|
-
|
|
1462
|
-
await asyncio.gather(*[_fetch_framework(framework) for framework in frameworks])
|
|
1463
|
-
|
|
1464
|
-
# Sort alphabetically for discoverability and return
|
|
1465
|
-
for task, models in models_by_task.items():
|
|
1466
|
-
models_by_task[task] = sorted(set(models), key=lambda x: x.lower())
|
|
1467
|
-
return models_by_task
|
|
1468
|
-
|
|
1469
1478
|
async def object_detection(
|
|
1470
1479
|
self, image: ContentT, *, model: Optional[str] = None, threshold: Optional[float] = None
|
|
1471
1480
|
) -> List[ObjectDetectionOutputElement]:
|
|
@@ -1506,11 +1515,15 @@ class AsyncInferenceClient:
|
|
|
1506
1515
|
[ObjectDetectionOutputElement(score=0.9486683011054993, label='person', box=ObjectDetectionBoundingBox(xmin=59, ymin=39, xmax=420, ymax=510)), ...]
|
|
1507
1516
|
```
|
|
1508
1517
|
"""
|
|
1509
|
-
|
|
1510
|
-
|
|
1511
|
-
|
|
1512
|
-
|
|
1513
|
-
|
|
1518
|
+
provider_helper = get_provider_helper(self.provider, task="object-detection")
|
|
1519
|
+
request_parameters = provider_helper.prepare_request(
|
|
1520
|
+
inputs=image,
|
|
1521
|
+
parameters={"threshold": threshold},
|
|
1522
|
+
headers=self.headers,
|
|
1523
|
+
model=model or self.model,
|
|
1524
|
+
api_key=self.token,
|
|
1525
|
+
)
|
|
1526
|
+
response = await self._inner_post(request_parameters)
|
|
1514
1527
|
return ObjectDetectionOutputElement.parse_obj_as_list(response)
|
|
1515
1528
|
|
|
1516
1529
|
async def question_answering(
|
|
@@ -1576,22 +1589,24 @@ class AsyncInferenceClient:
|
|
|
1576
1589
|
QuestionAnsweringOutputElement(answer='Clara', end=16, score=0.9326565265655518, start=11)
|
|
1577
1590
|
```
|
|
1578
1591
|
"""
|
|
1579
|
-
|
|
1580
|
-
|
|
1581
|
-
|
|
1582
|
-
|
|
1583
|
-
|
|
1584
|
-
|
|
1585
|
-
|
|
1586
|
-
|
|
1587
|
-
|
|
1588
|
-
|
|
1589
|
-
|
|
1590
|
-
|
|
1591
|
-
|
|
1592
|
-
|
|
1593
|
-
|
|
1592
|
+
provider_helper = get_provider_helper(self.provider, task="question-answering")
|
|
1593
|
+
request_parameters = provider_helper.prepare_request(
|
|
1594
|
+
inputs=None,
|
|
1595
|
+
parameters={
|
|
1596
|
+
"align_to_words": align_to_words,
|
|
1597
|
+
"doc_stride": doc_stride,
|
|
1598
|
+
"handle_impossible_answer": handle_impossible_answer,
|
|
1599
|
+
"max_answer_len": max_answer_len,
|
|
1600
|
+
"max_question_len": max_question_len,
|
|
1601
|
+
"max_seq_len": max_seq_len,
|
|
1602
|
+
"top_k": top_k,
|
|
1603
|
+
},
|
|
1604
|
+
extra_payload={"question": question, "context": context},
|
|
1605
|
+
headers=self.headers,
|
|
1606
|
+
model=model or self.model,
|
|
1607
|
+
api_key=self.token,
|
|
1594
1608
|
)
|
|
1609
|
+
response = await self._inner_post(request_parameters)
|
|
1595
1610
|
# Parse the response as a single `QuestionAnsweringOutputElement` when top_k is 1 or not provided, or a list of `QuestionAnsweringOutputElement` to ensure backward compatibility.
|
|
1596
1611
|
output = QuestionAnsweringOutputElement.parse_obj(response)
|
|
1597
1612
|
return output
|
|
@@ -1637,11 +1652,16 @@ class AsyncInferenceClient:
|
|
|
1637
1652
|
[0.7785726189613342, 0.45876261591911316, 0.2906220555305481]
|
|
1638
1653
|
```
|
|
1639
1654
|
"""
|
|
1640
|
-
|
|
1641
|
-
|
|
1642
|
-
|
|
1643
|
-
|
|
1655
|
+
provider_helper = get_provider_helper(self.provider, task="sentence-similarity")
|
|
1656
|
+
request_parameters = provider_helper.prepare_request(
|
|
1657
|
+
inputs=None,
|
|
1658
|
+
parameters={},
|
|
1659
|
+
extra_payload={"source_sentence": sentence, "sentences": other_sentences},
|
|
1660
|
+
headers=self.headers,
|
|
1661
|
+
model=model or self.model,
|
|
1662
|
+
api_key=self.token,
|
|
1644
1663
|
)
|
|
1664
|
+
response = await self._inner_post(request_parameters)
|
|
1645
1665
|
return _bytes_to_list(response)
|
|
1646
1666
|
|
|
1647
1667
|
@_deprecate_arguments(
|
|
@@ -1704,8 +1724,15 @@ class AsyncInferenceClient:
|
|
|
1704
1724
|
"generate_parameters": generate_parameters,
|
|
1705
1725
|
"truncation": truncation,
|
|
1706
1726
|
}
|
|
1707
|
-
|
|
1708
|
-
|
|
1727
|
+
provider_helper = get_provider_helper(self.provider, task="summarization")
|
|
1728
|
+
request_parameters = provider_helper.prepare_request(
|
|
1729
|
+
inputs=text,
|
|
1730
|
+
parameters=parameters,
|
|
1731
|
+
headers=self.headers,
|
|
1732
|
+
model=model or self.model,
|
|
1733
|
+
api_key=self.token,
|
|
1734
|
+
)
|
|
1735
|
+
response = await self._inner_post(request_parameters)
|
|
1709
1736
|
return SummarizationOutput.parse_obj_as_list(response)[0]
|
|
1710
1737
|
|
|
1711
1738
|
async def table_question_answering(
|
|
@@ -1759,21 +1786,16 @@ class AsyncInferenceClient:
|
|
|
1759
1786
|
TableQuestionAnsweringOutputElement(answer='36542', coordinates=[[0, 1]], cells=['36542'], aggregator='AVERAGE')
|
|
1760
1787
|
```
|
|
1761
1788
|
"""
|
|
1762
|
-
|
|
1763
|
-
|
|
1764
|
-
|
|
1765
|
-
"truncation": truncation,
|
|
1766
|
-
|
|
1767
|
-
|
|
1768
|
-
|
|
1769
|
-
|
|
1770
|
-
}
|
|
1771
|
-
payload = _prepare_payload(inputs, parameters=parameters)
|
|
1772
|
-
response = await self.post(
|
|
1773
|
-
**payload,
|
|
1774
|
-
model=model,
|
|
1775
|
-
task="table-question-answering",
|
|
1789
|
+
provider_helper = get_provider_helper(self.provider, task="table-question-answering")
|
|
1790
|
+
request_parameters = provider_helper.prepare_request(
|
|
1791
|
+
inputs=None,
|
|
1792
|
+
parameters={"model": model, "padding": padding, "sequential": sequential, "truncation": truncation},
|
|
1793
|
+
extra_payload={"query": query, "table": table},
|
|
1794
|
+
headers=self.headers,
|
|
1795
|
+
model=model or self.model,
|
|
1796
|
+
api_key=self.token,
|
|
1776
1797
|
)
|
|
1798
|
+
response = await self._inner_post(request_parameters)
|
|
1777
1799
|
return TableQuestionAnsweringOutputElement.parse_obj_as_instance(response)
|
|
1778
1800
|
|
|
1779
1801
|
async def tabular_classification(self, table: Dict[str, Any], *, model: Optional[str] = None) -> List[str]:
|
|
@@ -1819,11 +1841,16 @@ class AsyncInferenceClient:
|
|
|
1819
1841
|
["5", "5", "5"]
|
|
1820
1842
|
```
|
|
1821
1843
|
"""
|
|
1822
|
-
|
|
1823
|
-
|
|
1824
|
-
|
|
1825
|
-
|
|
1844
|
+
provider_helper = get_provider_helper(self.provider, task="tabular-classification")
|
|
1845
|
+
request_parameters = provider_helper.prepare_request(
|
|
1846
|
+
inputs=None,
|
|
1847
|
+
extra_payload={"table": table},
|
|
1848
|
+
parameters={},
|
|
1849
|
+
headers=self.headers,
|
|
1850
|
+
model=model or self.model,
|
|
1851
|
+
api_key=self.token,
|
|
1826
1852
|
)
|
|
1853
|
+
response = await self._inner_post(request_parameters)
|
|
1827
1854
|
return _bytes_to_list(response)
|
|
1828
1855
|
|
|
1829
1856
|
async def tabular_regression(self, table: Dict[str, Any], *, model: Optional[str] = None) -> List[float]:
|
|
@@ -1864,7 +1891,16 @@ class AsyncInferenceClient:
|
|
|
1864
1891
|
[110, 120, 130]
|
|
1865
1892
|
```
|
|
1866
1893
|
"""
|
|
1867
|
-
|
|
1894
|
+
provider_helper = get_provider_helper(self.provider, task="tabular-regression")
|
|
1895
|
+
request_parameters = provider_helper.prepare_request(
|
|
1896
|
+
inputs=None,
|
|
1897
|
+
parameters={},
|
|
1898
|
+
extra_payload={"table": table},
|
|
1899
|
+
headers=self.headers,
|
|
1900
|
+
model=model or self.model,
|
|
1901
|
+
api_key=self.token,
|
|
1902
|
+
)
|
|
1903
|
+
response = await self._inner_post(request_parameters)
|
|
1868
1904
|
return _bytes_to_list(response)
|
|
1869
1905
|
|
|
1870
1906
|
async def text_classification(
|
|
@@ -1911,16 +1947,18 @@ class AsyncInferenceClient:
|
|
|
1911
1947
|
]
|
|
1912
1948
|
```
|
|
1913
1949
|
"""
|
|
1914
|
-
|
|
1915
|
-
|
|
1916
|
-
|
|
1917
|
-
|
|
1918
|
-
|
|
1919
|
-
|
|
1920
|
-
|
|
1921
|
-
|
|
1922
|
-
|
|
1950
|
+
provider_helper = get_provider_helper(self.provider, task="text-classification")
|
|
1951
|
+
request_parameters = provider_helper.prepare_request(
|
|
1952
|
+
inputs=text,
|
|
1953
|
+
parameters={
|
|
1954
|
+
"function_to_apply": function_to_apply,
|
|
1955
|
+
"top_k": top_k,
|
|
1956
|
+
},
|
|
1957
|
+
headers=self.headers,
|
|
1958
|
+
model=model or self.model,
|
|
1959
|
+
api_key=self.token,
|
|
1923
1960
|
)
|
|
1961
|
+
response = await self._inner_post(request_parameters)
|
|
1924
1962
|
return TextClassificationOutputElement.parse_obj_as_list(response)[0] # type: ignore [return-value]
|
|
1925
1963
|
|
|
1926
1964
|
@overload
|
|
@@ -2104,15 +2142,6 @@ class AsyncInferenceClient:
|
|
|
2104
2142
|
"""
|
|
2105
2143
|
Given a prompt, generate the following text.
|
|
2106
2144
|
|
|
2107
|
-
API endpoint is supposed to run with the `text-generation-inference` backend (TGI). This backend is the
|
|
2108
|
-
go-to solution to run large language models at scale. However, for some smaller models (e.g. "gpt2") the
|
|
2109
|
-
default `transformers` + `api-inference` solution is still in use. Both approaches have very similar APIs, but
|
|
2110
|
-
not exactly the same. This method is compatible with both approaches but some parameters are only available for
|
|
2111
|
-
`text-generation-inference`. If some parameters are ignored, a warning message is triggered but the process
|
|
2112
|
-
continues correctly.
|
|
2113
|
-
|
|
2114
|
-
To learn more about the TGI project, please refer to https://github.com/huggingface/text-generation-inference.
|
|
2115
|
-
|
|
2116
2145
|
<Tip>
|
|
2117
2146
|
|
|
2118
2147
|
If you want to generate a response from chat messages, you should use the [`InferenceClient.chat_completion`] method.
|
|
@@ -2336,12 +2365,6 @@ class AsyncInferenceClient:
|
|
|
2336
2365
|
"typical_p": typical_p,
|
|
2337
2366
|
"watermark": watermark,
|
|
2338
2367
|
}
|
|
2339
|
-
parameters = {k: v for k, v in parameters.items() if v is not None}
|
|
2340
|
-
payload = {
|
|
2341
|
-
"inputs": prompt,
|
|
2342
|
-
"parameters": parameters,
|
|
2343
|
-
"stream": stream,
|
|
2344
|
-
}
|
|
2345
2368
|
|
|
2346
2369
|
# Remove some parameters if not a TGI server
|
|
2347
2370
|
unsupported_kwargs = _get_unsupported_text_generation_kwargs(model)
|
|
@@ -2374,9 +2397,19 @@ class AsyncInferenceClient:
|
|
|
2374
2397
|
" Please pass `stream=False` as input."
|
|
2375
2398
|
)
|
|
2376
2399
|
|
|
2400
|
+
provider_helper = get_provider_helper(self.provider, task="text-generation")
|
|
2401
|
+
request_parameters = provider_helper.prepare_request(
|
|
2402
|
+
inputs=prompt,
|
|
2403
|
+
parameters=parameters,
|
|
2404
|
+
extra_payload={"stream": stream},
|
|
2405
|
+
headers=self.headers,
|
|
2406
|
+
model=model or self.model,
|
|
2407
|
+
api_key=self.token,
|
|
2408
|
+
)
|
|
2409
|
+
|
|
2377
2410
|
# Handle errors separately for more precise error messages
|
|
2378
2411
|
try:
|
|
2379
|
-
bytes_output = await self.
|
|
2412
|
+
bytes_output = await self._inner_post(request_parameters, stream=stream)
|
|
2380
2413
|
except _import_aiohttp().ClientResponseError as e:
|
|
2381
2414
|
match = MODEL_KWARGS_NOT_USED_REGEX.search(e.response_error_payload["error"])
|
|
2382
2415
|
if e.status == 400 and match:
|
|
@@ -2386,7 +2419,7 @@ class AsyncInferenceClient:
|
|
|
2386
2419
|
prompt=prompt,
|
|
2387
2420
|
details=details,
|
|
2388
2421
|
stream=stream,
|
|
2389
|
-
model=model,
|
|
2422
|
+
model=model or self.model,
|
|
2390
2423
|
adapter_id=adapter_id,
|
|
2391
2424
|
best_of=best_of,
|
|
2392
2425
|
decoder_input_details=decoder_input_details,
|
|
@@ -2424,7 +2457,7 @@ class AsyncInferenceClient:
         self,
         prompt: str,
         *,
-        negative_prompt: Optional[
+        negative_prompt: Optional[str] = None,
         height: Optional[float] = None,
         width: Optional[float] = None,
         num_inference_steps: Optional[int] = None,
@@ -2447,8 +2480,8 @@ class AsyncInferenceClient:
         Args:
             prompt (`str`):
                 The prompt to generate an image from.
-            negative_prompt (`
-                One
+            negative_prompt (`str`, *optional*):
+                One prompt to guide what NOT to include in image generation.
             height (`float`, *optional*):
                 The height in pixels of the image to generate.
             width (`float`, *optional*):
@@ -2495,23 +2528,143 @@ class AsyncInferenceClient:
|
|
|
2495
2528
|
... )
|
|
2496
2529
|
>>> image.save("better_astronaut.png")
|
|
2497
2530
|
```
|
|
2498
|
-
|
|
2531
|
+
Example using a third-party provider directly. Usage will be billed on your fal.ai account.
|
|
2532
|
+
```py
|
|
2533
|
+
>>> from huggingface_hub import InferenceClient
|
|
2534
|
+
>>> client = InferenceClient(
|
|
2535
|
+
... provider="fal-ai", # Use fal.ai provider
|
|
2536
|
+
... api_key="fal-ai-api-key", # Pass your fal.ai API key
|
|
2537
|
+
... )
|
|
2538
|
+
>>> image = client.text_to_image(
|
|
2539
|
+
... "A majestic lion in a fantasy forest",
|
|
2540
|
+
... model="black-forest-labs/FLUX.1-schnell",
|
|
2541
|
+
... )
|
|
2542
|
+
>>> image.save("lion.png")
|
|
2543
|
+
```
|
|
2499
2544
|
|
|
2500
|
-
|
|
2501
|
-
|
|
2502
|
-
|
|
2503
|
-
|
|
2504
|
-
|
|
2505
|
-
|
|
2506
|
-
|
|
2507
|
-
|
|
2508
|
-
|
|
2509
|
-
|
|
2510
|
-
|
|
2511
|
-
|
|
2512
|
-
|
|
2545
|
+
Example using a third-party provider through Hugging Face Routing. Usage will be billed on your Hugging Face account.
|
|
2546
|
+
```py
|
|
2547
|
+
>>> from huggingface_hub import InferenceClient
|
|
2548
|
+
>>> client = InferenceClient(
|
|
2549
|
+
... provider="replicate", # Use replicate provider
|
|
2550
|
+
... api_key="hf_...", # Pass your HF token
|
|
2551
|
+
... )
|
|
2552
|
+
>>> image = client.text_to_image(
|
|
2553
|
+
... "An astronaut riding a horse on the moon.",
|
|
2554
|
+
... model="black-forest-labs/FLUX.1-dev",
|
|
2555
|
+
... )
|
|
2556
|
+
>>> image.save("astronaut.png")
|
|
2557
|
+
```
|
|
2558
|
+
"""
|
|
2559
|
+
provider_helper = get_provider_helper(self.provider, task="text-to-image")
|
|
2560
|
+
request_parameters = provider_helper.prepare_request(
|
|
2561
|
+
inputs=prompt,
|
|
2562
|
+
parameters={
|
|
2563
|
+
"negative_prompt": negative_prompt,
|
|
2564
|
+
"height": height,
|
|
2565
|
+
"width": width,
|
|
2566
|
+
"num_inference_steps": num_inference_steps,
|
|
2567
|
+
"guidance_scale": guidance_scale,
|
|
2568
|
+
"scheduler": scheduler,
|
|
2569
|
+
"target_size": target_size,
|
|
2570
|
+
"seed": seed,
|
|
2571
|
+
**kwargs,
|
|
2572
|
+
},
|
|
2573
|
+
headers=self.headers,
|
|
2574
|
+
model=model or self.model,
|
|
2575
|
+
api_key=self.token,
|
|
2576
|
+
)
|
|
2577
|
+
response = await self._inner_post(request_parameters)
|
|
2578
|
+
response = provider_helper.get_response(response)
|
|
2513
2579
|
return _bytes_to_image(response)
|
|
2514
2580
|
|
|
2581
|
+
async def text_to_video(
|
|
2582
|
+
self,
|
|
2583
|
+
prompt: str,
|
|
2584
|
+
*,
|
|
2585
|
+
model: Optional[str] = None,
|
|
2586
|
+
guidance_scale: Optional[float] = None,
|
|
2587
|
+
negative_prompt: Optional[List[str]] = None,
|
|
2588
|
+
num_frames: Optional[float] = None,
|
|
2589
|
+
num_inference_steps: Optional[int] = None,
|
|
2590
|
+
seed: Optional[int] = None,
|
|
2591
|
+
) -> bytes:
|
|
2592
|
+
"""
|
|
2593
|
+
Generate a video based on a given text.
|
|
2594
|
+
|
|
2595
|
+
Args:
|
|
2596
|
+
prompt (`str`):
|
|
2597
|
+
The prompt to generate a video from.
|
|
2598
|
+
model (`str`, *optional*):
|
|
2599
|
+
The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
|
|
2600
|
+
Inference Endpoint. If not provided, the default recommended text-to-video model will be used.
|
|
2601
|
+
Defaults to None.
|
|
2602
|
+
guidance_scale (`float`, *optional*):
|
|
2603
|
+
A higher guidance scale value encourages the model to generate videos closely linked to the text
|
|
2604
|
+
prompt, but values too high may cause saturation and other artifacts.
|
|
2605
|
+
negative_prompt (`List[str]`, *optional*):
|
|
2606
|
+
One or several prompt to guide what NOT to include in video generation.
|
|
2607
|
+
num_frames (`float`, *optional*):
|
|
2608
|
+
The num_frames parameter determines how many video frames are generated.
|
|
2609
|
+
num_inference_steps (`int`, *optional*):
|
|
2610
|
+
The number of denoising steps. More denoising steps usually lead to a higher quality video at the
|
|
2611
|
+
expense of slower inference.
|
|
2612
|
+
seed (`int`, *optional*):
|
|
2613
|
+
Seed for the random number generator.
|
|
2614
|
+
|
|
2615
|
+
Returns:
|
|
2616
|
+
`bytes`: The generated video.
|
|
2617
|
+
|
|
2618
|
+
Example:
|
|
2619
|
+
|
|
2620
|
+
Example using a third-party provider directly. Usage will be billed on your fal.ai account.
|
|
2621
|
+
```py
|
|
2622
|
+
>>> from huggingface_hub import InferenceClient
|
|
2623
|
+
>>> client = InferenceClient(
|
|
2624
|
+
... provider="fal-ai", # Using fal.ai provider
|
|
2625
|
+
... api_key="fal-ai-api-key", # Pass your fal.ai API key
|
|
2626
|
+
... )
|
|
2627
|
+
>>> video = client.text_to_video(
|
|
2628
|
+
... "A majestic lion running in a fantasy forest",
|
|
2629
|
+
... model="tencent/HunyuanVideo",
|
|
2630
|
+
... )
|
|
2631
|
+
>>> with open("lion.mp4", "wb") as file:
|
|
2632
|
+
... file.write(video)
|
|
2633
|
+
```
|
|
2634
|
+
|
|
2635
|
+
Example using a third-party provider through Hugging Face Routing. Usage will be billed on your Hugging Face account.
|
|
2636
|
+
```py
|
|
2637
|
+
>>> from huggingface_hub import InferenceClient
|
|
2638
|
+
>>> client = InferenceClient(
|
|
2639
|
+
... provider="replicate", # Using replicate provider
|
|
2640
|
+
... api_key="hf_...", # Pass your HF token
|
|
2641
|
+
... )
|
|
2642
|
+
>>> video = client.text_to_video(
|
|
2643
|
+
... "A cat running in a park",
|
|
2644
|
+
... model="genmo/mochi-1-preview",
|
|
2645
|
+
... )
|
|
2646
|
+
>>> with open("cat.mp4", "wb") as file:
|
|
2647
|
+
... file.write(video)
|
|
2648
|
+
```
|
|
2649
|
+
"""
|
|
2650
|
+
provider_helper = get_provider_helper(self.provider, task="text-to-video")
|
|
2651
|
+
request_parameters = provider_helper.prepare_request(
|
|
2652
|
+
inputs=prompt,
|
|
2653
|
+
parameters={
|
|
2654
|
+
"guidance_scale": guidance_scale,
|
|
2655
|
+
"negative_prompt": negative_prompt,
|
|
2656
|
+
"num_frames": num_frames,
|
|
2657
|
+
"num_inference_steps": num_inference_steps,
|
|
2658
|
+
"seed": seed,
|
|
2659
|
+
},
|
|
2660
|
+
headers=self.headers,
|
|
2661
|
+
model=model or self.model,
|
|
2662
|
+
api_key=self.token,
|
|
2663
|
+
)
|
|
2664
|
+
response = await self._inner_post(request_parameters)
|
|
2665
|
+
response = provider_helper.get_response(response)
|
|
2666
|
+
return response
|
|
2667
|
+
|
|
2515
2668
|
async def text_to_speech(
|
|
2516
2669
|
self,
|
|
2517
2670
|
text: str,
|
|
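The hunk above also introduces the new `text_to_video` task method, which returns the raw video bytes after `provider_helper.get_response()` post-processing. A hedged async sketch mirroring the docstring example; the provider key and model ID are illustrative:

```py
import asyncio
from huggingface_hub import AsyncInferenceClient

async def main() -> None:
    # text_to_video is new in 0.28.0; the fal.ai key and model below are placeholders.
    client = AsyncInferenceClient(provider="fal-ai", api_key="<fal-ai-api-key>")
    video = await client.text_to_video(
        "A majestic lion running in a fantasy forest",
        model="tencent/HunyuanVideo",
    )
    with open("lion.mp4", "wb") as f:
        f.write(video)  # raw video bytes returned by the provider

asyncio.run(main())
```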
@@ -2610,27 +2763,62 @@ class AsyncInferenceClient:
|
|
|
2610
2763
|
>>> audio = await client.text_to_speech("Hello world")
|
|
2611
2764
|
>>> Path("hello_world.flac").write_bytes(audio)
|
|
2612
2765
|
```
|
|
2766
|
+
|
|
2767
|
+
Example using a third-party provider directly. Usage will be billed on your Replicate account.
|
|
2768
|
+
```py
|
|
2769
|
+
>>> from huggingface_hub import InferenceClient
|
|
2770
|
+
>>> client = InferenceClient(
|
|
2771
|
+
... provider="replicate",
|
|
2772
|
+
... api_key="your-replicate-api-key", # Pass your Replicate API key directly
|
|
2773
|
+
... )
|
|
2774
|
+
>>> audio = client.text_to_speech(
|
|
2775
|
+
... text="Hello world",
|
|
2776
|
+
... model="OuteAI/OuteTTS-0.3-500M",
|
|
2777
|
+
... )
|
|
2778
|
+
>>> Path("hello_world.flac").write_bytes(audio)
|
|
2779
|
+
```
|
|
2780
|
+
|
|
2781
|
+
Example using a third-party provider through Hugging Face Routing. Usage will be billed on your Hugging Face account.
|
|
2782
|
+
```py
|
|
2783
|
+
>>> from huggingface_hub import InferenceClient
|
|
2784
|
+
>>> client = InferenceClient(
|
|
2785
|
+
... provider="replicate",
|
|
2786
|
+
... api_key="hf_...", # Pass your HF token
|
|
2787
|
+
... )
|
|
2788
|
+
>>> audio = client.text_to_speech(
|
|
2789
|
+
... text="Hello world",
|
|
2790
|
+
... model="OuteAI/OuteTTS-0.3-500M",
|
|
2791
|
+
... )
|
|
2792
|
+
>>> Path("hello_world.flac").write_bytes(audio)
|
|
2793
|
+
```
|
|
2613
2794
|
"""
|
|
2614
|
-
|
|
2615
|
-
|
|
2616
|
-
|
|
2617
|
-
|
|
2618
|
-
|
|
2619
|
-
|
|
2620
|
-
|
|
2621
|
-
|
|
2622
|
-
|
|
2623
|
-
|
|
2624
|
-
|
|
2625
|
-
|
|
2626
|
-
|
|
2627
|
-
|
|
2628
|
-
|
|
2629
|
-
|
|
2630
|
-
|
|
2631
|
-
|
|
2632
|
-
|
|
2633
|
-
|
|
2795
|
+
provider_helper = get_provider_helper(self.provider, task="text-to-speech")
|
|
2796
|
+
request_parameters = provider_helper.prepare_request(
|
|
2797
|
+
inputs=text,
|
|
2798
|
+
parameters={
|
|
2799
|
+
"do_sample": do_sample,
|
|
2800
|
+
"early_stopping": early_stopping,
|
|
2801
|
+
"epsilon_cutoff": epsilon_cutoff,
|
|
2802
|
+
"eta_cutoff": eta_cutoff,
|
|
2803
|
+
"max_length": max_length,
|
|
2804
|
+
"max_new_tokens": max_new_tokens,
|
|
2805
|
+
"min_length": min_length,
|
|
2806
|
+
"min_new_tokens": min_new_tokens,
|
|
2807
|
+
"num_beam_groups": num_beam_groups,
|
|
2808
|
+
"num_beams": num_beams,
|
|
2809
|
+
"penalty_alpha": penalty_alpha,
|
|
2810
|
+
"temperature": temperature,
|
|
2811
|
+
"top_k": top_k,
|
|
2812
|
+
"top_p": top_p,
|
|
2813
|
+
"typical_p": typical_p,
|
|
2814
|
+
"use_cache": use_cache,
|
|
2815
|
+
},
|
|
2816
|
+
headers=self.headers,
|
|
2817
|
+
model=model or self.model,
|
|
2818
|
+
api_key=self.token,
|
|
2819
|
+
)
|
|
2820
|
+
response = await self._inner_post(request_parameters)
|
|
2821
|
+
response = provider_helper.get_response(response)
|
|
2634
2822
|
return response
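The same flow applies to text-to-speech, with the generation parameters simply forwarded to the provider helper. A sketch of the awaited call, reusing the model id from the docstring above (the token is a placeholder):

```py
import asyncio
from pathlib import Path

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    # Routed through Hugging Face: usage is billed on the HF account behind the token.
    client = AsyncInferenceClient(provider="replicate", api_key="hf_...")
    audio = await client.text_to_speech(
        "Hello world",
        model="OuteAI/OuteTTS-0.3-500M",
    )
    Path("hello_world.flac").write_bytes(audio)


asyncio.run(main())
```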
|
|
2635
2823
|
|
|
2636
2824
|
async def token_classification(
|
|
@@ -2693,18 +2881,19 @@ class AsyncInferenceClient:
|
|
|
2693
2881
|
]
|
|
2694
2882
|
```
|
|
2695
2883
|
"""
|
|
2696
|
-
|
|
2697
|
-
|
|
2698
|
-
|
|
2699
|
-
|
|
2700
|
-
|
|
2701
|
-
|
|
2702
|
-
|
|
2703
|
-
|
|
2704
|
-
|
|
2705
|
-
model=model,
|
|
2706
|
-
|
|
2884
|
+
provider_helper = get_provider_helper(self.provider, task="token-classification")
|
|
2885
|
+
request_parameters = provider_helper.prepare_request(
|
|
2886
|
+
inputs=text,
|
|
2887
|
+
parameters={
|
|
2888
|
+
"aggregation_strategy": aggregation_strategy,
|
|
2889
|
+
"ignore_labels": ignore_labels,
|
|
2890
|
+
"stride": stride,
|
|
2891
|
+
},
|
|
2892
|
+
headers=self.headers,
|
|
2893
|
+
model=model or self.model,
|
|
2894
|
+
api_key=self.token,
|
|
2707
2895
|
)
|
|
2896
|
+
response = await self._inner_post(request_parameters)
|
|
2708
2897
|
return TokenClassificationOutputElement.parse_obj_as_list(response)
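For token classification the new code only changes how the request is built, not the public call. A short usage sketch; the model id and aggregation strategy are illustrative assumptions:

```py
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient()  # default setup targets the HF Inference API
    entities = await client.token_classification(
        "My name is Sarah Jessica Parker but you can call me Jessica",
        model="dslim/bert-base-NER",  # assumed NER checkpoint
        aggregation_strategy="simple",  # merge sub-tokens into whole entities
    )
    for entity in entities:
        print(entity.entity_group, entity.word, entity.score)


asyncio.run(main())
```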
|
|
2709
2898
|
|
|
2710
2899
|
async def translation(
|
|
@@ -2778,15 +2967,22 @@ class AsyncInferenceClient:
|
|
|
2778
2967
|
|
|
2779
2968
|
if src_lang is None and tgt_lang is not None:
|
|
2780
2969
|
raise ValueError("You cannot specify `tgt_lang` without specifying `src_lang`.")
|
|
2781
|
-
|
|
2782
|
-
|
|
2783
|
-
|
|
2784
|
-
|
|
2785
|
-
|
|
2786
|
-
|
|
2787
|
-
|
|
2788
|
-
|
|
2789
|
-
|
|
2970
|
+
|
|
2971
|
+
provider_helper = get_provider_helper(self.provider, task="translation")
|
|
2972
|
+
request_parameters = provider_helper.prepare_request(
|
|
2973
|
+
inputs=text,
|
|
2974
|
+
parameters={
|
|
2975
|
+
"src_lang": src_lang,
|
|
2976
|
+
"tgt_lang": tgt_lang,
|
|
2977
|
+
"clean_up_tokenization_spaces": clean_up_tokenization_spaces,
|
|
2978
|
+
"truncation": truncation,
|
|
2979
|
+
"generate_parameters": generate_parameters,
|
|
2980
|
+
},
|
|
2981
|
+
headers=self.headers,
|
|
2982
|
+
model=model or self.model,
|
|
2983
|
+
api_key=self.token,
|
|
2984
|
+
)
|
|
2985
|
+
response = await self._inner_post(request_parameters)
|
|
2790
2986
|
return TranslationOutput.parse_obj_as_list(response)[0]
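Translation follows the same pattern, with `src_lang`/`tgt_lang` passed straight through as parameters. A sketch assuming an mBART-style multilingual checkpoint; the model id and language codes are illustrative and must match whatever model you actually use:

```py
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient()
    result = await client.translation(
        "Hello, how are you?",
        model="facebook/mbart-large-50-many-to-many-mmt",  # assumed multilingual model
        src_lang="en_XX",  # language codes follow the chosen model's convention
        tgt_lang="fr_XX",
    )
    print(result.translation_text)


asyncio.run(main())
```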
|
|
2791
2987
|
|
|
2792
2988
|
async def visual_question_answering(
|
|
@@ -2836,10 +3032,16 @@ class AsyncInferenceClient:
|
|
|
2836
3032
|
]
|
|
2837
3033
|
```
|
|
2838
3034
|
"""
|
|
2839
|
-
|
|
2840
|
-
|
|
2841
|
-
|
|
2842
|
-
|
|
3035
|
+
provider_helper = get_provider_helper(self.provider, task="visual-question-answering")
|
|
3036
|
+
request_parameters = provider_helper.prepare_request(
|
|
3037
|
+
inputs=image,
|
|
3038
|
+
parameters={"top_k": top_k},
|
|
3039
|
+
headers=self.headers,
|
|
3040
|
+
model=model or self.model,
|
|
3041
|
+
api_key=self.token,
|
|
3042
|
+
extra_payload={"question": question, "image": _b64_encode(image)},
|
|
3043
|
+
)
|
|
3044
|
+
response = await self._inner_post(request_parameters)
|
|
2843
3045
|
return VisualQuestionAnsweringOutputElement.parse_obj_as_list(response)
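Note that the question travels in `extra_payload` while the image is base64-encoded, so the public signature stays unchanged. A sketch with placeholder inputs (file name, question, and `top_k` are arbitrary):

```py
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient()
    answers = await client.visual_question_answering(
        image="cat.jpg",  # local path, URL, or raw bytes
        question="What animal is in the picture?",
        top_k=3,  # keep the three most likely answers
    )
    for answer in answers:
        print(answer.answer, answer.score)


asyncio.run(main())
```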
|
|
2844
3046
|
|
|
2845
3047
|
@_deprecate_arguments(
|
|
@@ -2947,17 +3149,20 @@ class AsyncInferenceClient:
|
|
|
2947
3149
|
candidate_labels = labels
|
|
2948
3150
|
elif candidate_labels is None:
|
|
2949
3151
|
raise ValueError("Must specify `candidate_labels`")
|
|
2950
|
-
|
|
2951
|
-
|
|
2952
|
-
|
|
2953
|
-
|
|
2954
|
-
|
|
2955
|
-
|
|
2956
|
-
|
|
2957
|
-
|
|
2958
|
-
|
|
2959
|
-
|
|
3152
|
+
|
|
3153
|
+
provider_helper = get_provider_helper(self.provider, task="zero-shot-classification")
|
|
3154
|
+
request_parameters = provider_helper.prepare_request(
|
|
3155
|
+
inputs=text,
|
|
3156
|
+
parameters={
|
|
3157
|
+
"candidate_labels": candidate_labels,
|
|
3158
|
+
"multi_label": multi_label,
|
|
3159
|
+
"hypothesis_template": hypothesis_template,
|
|
3160
|
+
},
|
|
3161
|
+
headers=self.headers,
|
|
3162
|
+
model=model or self.model,
|
|
3163
|
+
api_key=self.token,
|
|
2960
3164
|
)
|
|
3165
|
+
response = await self._inner_post(request_parameters)
|
|
2961
3166
|
output = _bytes_to_dict(response)
|
|
2962
3167
|
return [
|
|
2963
3168
|
ZeroShotClassificationOutputElement.parse_obj_as_instance({"label": label, "score": score})
|
|
@@ -3031,18 +3236,110 @@ class AsyncInferenceClient:
|
|
|
3031
3236
|
# Raise ValueError if input is less than 2 labels
|
|
3032
3237
|
if len(candidate_labels) < 2:
|
|
3033
3238
|
raise ValueError("You must specify at least 2 classes to compare.")
|
|
3034
|
-
|
|
3035
|
-
|
|
3036
|
-
|
|
3037
|
-
|
|
3038
|
-
|
|
3039
|
-
|
|
3040
|
-
|
|
3041
|
-
|
|
3042
|
-
|
|
3239
|
+
|
|
3240
|
+
provider_helper = get_provider_helper(self.provider, task="zero-shot-image-classification")
|
|
3241
|
+
request_parameters = provider_helper.prepare_request(
|
|
3242
|
+
inputs=image,
|
|
3243
|
+
parameters={
|
|
3244
|
+
"candidate_labels": candidate_labels,
|
|
3245
|
+
"hypothesis_template": hypothesis_template,
|
|
3246
|
+
},
|
|
3247
|
+
headers=self.headers,
|
|
3248
|
+
model=model or self.model,
|
|
3249
|
+
api_key=self.token,
|
|
3043
3250
|
)
|
|
3251
|
+
response = await self._inner_post(request_parameters)
|
|
3044
3252
|
return ZeroShotImageClassificationOutputElement.parse_obj_as_list(response)
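A matching sketch for zero-shot image classification; at least two candidate labels are required, as enforced above, and the file name and labels here are placeholders:

```py
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient()
    predictions = await client.zero_shot_image_classification(
        image="cat.jpg",  # local path, URL, or raw bytes
        candidate_labels=["cat", "dog", "bird"],  # must contain at least 2 classes
        hypothesis_template="This is a photo of a {}.",
    )
    for prediction in predictions:
        print(prediction.label, prediction.score)


asyncio.run(main())
```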
|
|
3045
3253
|
|
|
3254
|
+
async def list_deployed_models(
|
|
3255
|
+
self, frameworks: Union[None, str, Literal["all"], List[str]] = None
|
|
3256
|
+
) -> Dict[str, List[str]]:
|
|
3257
|
+
"""
|
|
3258
|
+
List models deployed on the Serverless Inference API service.
|
|
3259
|
+
|
|
3260
|
+
This helper checks deployed models framework by framework. By default, it will check the 4 main frameworks that
|
|
3261
|
+
are supported and account for 95% of the hosted models. However, if you want a complete list of models you can
|
|
3262
|
+
specify `frameworks="all"` as input. Alternatively, if you know beforehand which framework you are interested
|
|
3263
|
+
in, you can also restrict the search to this one (e.g. `frameworks="text-generation-inference"`). The more
|
|
3264
|
+
frameworks are checked, the more time it will take.
|
|
3265
|
+
|
|
3266
|
+
<Tip warning={true}>
|
|
3267
|
+
|
|
3268
|
+
This endpoint method does not return a live list of all models available for the Serverless Inference API service.
|
|
3269
|
+
It searches over a cached list of models that were recently available and the list may not be up to date.
|
|
3270
|
+
If you want to know the live status of a specific model, use [`~InferenceClient.get_model_status`].
|
|
3271
|
+
|
|
3272
|
+
</Tip>
|
|
3273
|
+
|
|
3274
|
+
<Tip>
|
|
3275
|
+
|
|
3276
|
+
This endpoint method is mostly useful for discoverability. If you already know which model you want to use and want to
|
|
3277
|
+
check its availability, you can directly use [`~InferenceClient.get_model_status`].
|
|
3278
|
+
|
|
3279
|
+
</Tip>
|
|
3280
|
+
|
|
3281
|
+
Args:
|
|
3282
|
+
frameworks (`Literal["all"]` or `List[str]` or `str`, *optional*):
|
|
3283
|
+
The frameworks to filter on. By default only a subset of the available frameworks are tested. If set to
|
|
3284
|
+
"all", all available frameworks will be tested. It is also possible to provide a single framework or a
|
|
3285
|
+
custom set of frameworks to check.
|
|
3286
|
+
|
|
3287
|
+
Returns:
|
|
3288
|
+
`Dict[str, List[str]]`: A dictionary mapping task names to a sorted list of model IDs.
|
|
3289
|
+
|
|
3290
|
+
Example:
|
|
3291
|
+
```py
|
|
3292
|
+
# Must be run in an async context
|
|
3293
|
+
>>> from huggingface_hub import AsyncInferenceClient
|
|
3294
|
+
>>> client = AsyncInferenceClient()
|
|
3295
|
+
|
|
3296
|
+
# Discover zero-shot-classification models currently deployed
|
|
3297
|
+
>>> models = await client.list_deployed_models()
|
|
3298
|
+
>>> models["zero-shot-classification"]
|
|
3299
|
+
['Narsil/deberta-large-mnli-zero-cls', 'facebook/bart-large-mnli', ...]
|
|
3300
|
+
|
|
3301
|
+
# List from only 1 framework
|
|
3302
|
+
>>> await client.list_deployed_models("text-generation-inference")
|
|
3303
|
+
{'text-generation': ['bigcode/starcoder', 'meta-llama/Llama-2-70b-chat-hf', ...], ...}
|
|
3304
|
+
```
|
|
3305
|
+
"""
|
|
3306
|
+
if self.provider != "hf-inference":
|
|
3307
|
+
raise ValueError(f"Listing deployed models is not supported on '{self.provider}'.")
|
|
3308
|
+
|
|
3309
|
+
# Resolve which frameworks to check
|
|
3310
|
+
if frameworks is None:
|
|
3311
|
+
frameworks = MAIN_INFERENCE_API_FRAMEWORKS
|
|
3312
|
+
elif frameworks == "all":
|
|
3313
|
+
frameworks = ALL_INFERENCE_API_FRAMEWORKS
|
|
3314
|
+
elif isinstance(frameworks, str):
|
|
3315
|
+
frameworks = [frameworks]
|
|
3316
|
+
frameworks = list(set(frameworks))
|
|
3317
|
+
|
|
3318
|
+
# Fetch them iteratively
|
|
3319
|
+
models_by_task: Dict[str, List[str]] = {}
|
|
3320
|
+
|
|
3321
|
+
def _unpack_response(framework: str, items: List[Dict]) -> None:
|
|
3322
|
+
for model in items:
|
|
3323
|
+
if framework == "sentence-transformers":
|
|
3324
|
+
# Models running with the `sentence-transformers` framework can work with both tasks even if not
|
|
3325
|
+
# branded as such in the API response
|
|
3326
|
+
models_by_task.setdefault("feature-extraction", []).append(model["model_id"])
|
|
3327
|
+
models_by_task.setdefault("sentence-similarity", []).append(model["model_id"])
|
|
3328
|
+
else:
|
|
3329
|
+
models_by_task.setdefault(model["task"], []).append(model["model_id"])
|
|
3330
|
+
|
|
3331
|
+
for framework in frameworks:
|
|
3332
|
+
response = get_session().get(
|
|
3333
|
+
f"{INFERENCE_ENDPOINT}/framework/{framework}", headers=build_hf_headers(token=self.token)
|
|
3334
|
+
)
|
|
3335
|
+
hf_raise_for_status(response)
|
|
3336
|
+
_unpack_response(framework, response.json())
|
|
3337
|
+
|
|
3338
|
+
# Sort alphabetically for discoverability and return
|
|
3339
|
+
for task, models in models_by_task.items():
|
|
3340
|
+
models_by_task[task] = sorted(set(models), key=lambda x: x.lower())
|
|
3341
|
+
return models_by_task
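Because this helper and `get_model_status` are both restricted to the `hf-inference` provider, they combine naturally: list what is currently deployed, then confirm the live status of a candidate. A hedged sketch (the framework filter and task name are just examples):

```py
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    # Explicit provider for clarity; other providers raise ValueError for these helpers.
    client = AsyncInferenceClient(provider="hf-inference")
    deployed = await client.list_deployed_models("text-generation-inference")
    candidates = deployed.get("text-generation", [])
    if candidates:
        status = await client.get_model_status(candidates[0])
        print(candidates[0], status.state, status.loaded)


asyncio.run(main())
```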
|
|
3342
|
+
|
|
3046
3343
|
def _get_client_session(self, headers: Optional[Dict] = None) -> "ClientSession":
|
|
3047
3344
|
aiohttp = _import_aiohttp()
|
|
3048
3345
|
client_headers = self.headers.copy()
|
|
@@ -3084,60 +3381,6 @@ class AsyncInferenceClient:
|
|
|
3084
3381
|
session.close = close_session
|
|
3085
3382
|
return session
|
|
3086
3383
|
|
|
3087
|
-
def _resolve_url(self, model: Optional[str] = None, task: Optional[str] = None) -> str:
|
|
3088
|
-
model = model or self.model or self.base_url
|
|
3089
|
-
|
|
3090
|
-
# If model is already a URL, ignore `task` and return directly
|
|
3091
|
-
if model is not None and (model.startswith("http://") or model.startswith("https://")):
|
|
3092
|
-
return model
|
|
3093
|
-
|
|
3094
|
-
# # If no model but task is set => fetch the recommended one for this task
|
|
3095
|
-
if model is None:
|
|
3096
|
-
if task is None:
|
|
3097
|
-
raise ValueError(
|
|
3098
|
-
"You must specify at least a model (repo_id or URL) or a task, either when instantiating"
|
|
3099
|
-
" `InferenceClient` or when making a request."
|
|
3100
|
-
)
|
|
3101
|
-
model = self.get_recommended_model(task)
|
|
3102
|
-
logger.info(
|
|
3103
|
-
f"Using recommended model {model} for task {task}. Note that it is"
|
|
3104
|
-
f" encouraged to explicitly set `model='{model}'` as the recommended"
|
|
3105
|
-
" models list might get updated without prior notice."
|
|
3106
|
-
)
|
|
3107
|
-
|
|
3108
|
-
# Compute InferenceAPI url
|
|
3109
|
-
return (
|
|
3110
|
-
# Feature-extraction and sentence-similarity are the only cases where we handle models with several tasks.
|
|
3111
|
-
f"{INFERENCE_ENDPOINT}/pipeline/{task}/{model}"
|
|
3112
|
-
if task in ("feature-extraction", "sentence-similarity")
|
|
3113
|
-
# Otherwise, we use the default endpoint
|
|
3114
|
-
else f"{INFERENCE_ENDPOINT}/models/{model}"
|
|
3115
|
-
)
|
|
3116
|
-
|
|
3117
|
-
@staticmethod
|
|
3118
|
-
def get_recommended_model(task: str) -> str:
|
|
3119
|
-
"""
|
|
3120
|
-
Get the model Hugging Face recommends for the input task.
|
|
3121
|
-
|
|
3122
|
-
Args:
|
|
3123
|
-
task (`str`):
|
|
3124
|
-
The Hugging Face task to get which model Hugging Face recommends.
|
|
3125
|
-
All available tasks can be found [here](https://huggingface.co/tasks).
|
|
3126
|
-
|
|
3127
|
-
Returns:
|
|
3128
|
-
`str`: Name of the model recommended for the input task.
|
|
3129
|
-
|
|
3130
|
-
Raises:
|
|
3131
|
-
`ValueError`: If Hugging Face has no recommendation for the input task.
|
|
3132
|
-
"""
|
|
3133
|
-
model = _fetch_recommended_models().get(task)
|
|
3134
|
-
if model is None:
|
|
3135
|
-
raise ValueError(
|
|
3136
|
-
f"Task {task} has no recommended model. Please specify a model"
|
|
3137
|
-
" explicitly. Visit https://huggingface.co/tasks for more info."
|
|
3138
|
-
)
|
|
3139
|
-
return model
|
|
3140
|
-
|
|
3141
3384
|
async def get_endpoint_info(self, *, model: Optional[str] = None) -> Dict[str, Any]:
|
|
3142
3385
|
"""
|
|
3143
3386
|
Get information about the deployed endpoint.
|
|
@@ -3182,6 +3425,9 @@ class AsyncInferenceClient:
|
|
|
3182
3425
|
}
|
|
3183
3426
|
```
|
|
3184
3427
|
"""
|
|
3428
|
+
if self.provider != "hf-inference":
|
|
3429
|
+
raise ValueError(f"Getting endpoint info is not supported on '{self.provider}'.")
|
|
3430
|
+
|
|
3185
3431
|
model = model or self.model
|
|
3186
3432
|
if model is None:
|
|
3187
3433
|
raise ValueError("Model id not provided.")
|
|
@@ -3190,7 +3436,7 @@ class AsyncInferenceClient:
|
|
|
3190
3436
|
else:
|
|
3191
3437
|
url = f"{INFERENCE_ENDPOINT}/models/{model}/info"
|
|
3192
3438
|
|
|
3193
|
-
async with self._get_client_session() as client:
|
|
3439
|
+
async with self._get_client_session(headers=build_hf_headers(token=self.token)) as client:
|
|
3194
3440
|
response = await client.get(url, proxy=self.proxies)
|
|
3195
3441
|
response.raise_for_status()
|
|
3196
3442
|
return await response.json()
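The endpoint-info request now goes through an authenticated session (note the `build_hf_headers(token=self.token)` change above) and is limited to the `hf-inference` provider. A brief sketch; the token and model id are placeholders, and the model must be served by a backend that exposes an `/info` route:

```py
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient(provider="hf-inference", api_key="hf_...")  # placeholder token
    info = await client.get_endpoint_info(model="meta-llama/Meta-Llama-3-70B-Instruct")  # placeholder model
    print(info)  # raw dictionary returned by the endpoint


asyncio.run(main())
```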
|
|
@@ -3218,6 +3464,9 @@ class AsyncInferenceClient:
|
|
|
3218
3464
|
True
|
|
3219
3465
|
```
|
|
3220
3466
|
"""
|
|
3467
|
+
if self.provider != "hf-inference":
|
|
3468
|
+
raise ValueError(f"Health check is not supported on '{self.provider}'.")
|
|
3469
|
+
|
|
3221
3470
|
model = model or self.model
|
|
3222
3471
|
if model is None:
|
|
3223
3472
|
raise ValueError("Model id not provided.")
|
|
@@ -3227,7 +3476,7 @@ class AsyncInferenceClient:
|
|
|
3227
3476
|
)
|
|
3228
3477
|
url = model.rstrip("/") + "/health"
|
|
3229
3478
|
|
|
3230
|
-
async with self._get_client_session() as client:
|
|
3479
|
+
async with self._get_client_session(headers=build_hf_headers(token=self.token)) as client:
|
|
3231
3480
|
response = await client.get(url, proxy=self.proxies)
|
|
3232
3481
|
return response.status == 200
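The same guard appears on each of these deployment helpers (endpoint info above, the health check here, and `get_model_status` below), so calls against a third-party provider fail fast instead of sending a request that cannot be routed. A small sketch of the expected behaviour, using `get_model_status`; the model id is irrelevant because the provider check runs first:

```py
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient(provider="replicate", api_key="hf_...")
    try:
        await client.get_model_status("any/model-id")  # placeholder id, never sent over the network
    except ValueError as err:
        print(err)  # Getting model status is not supported on 'replicate'.


asyncio.run(main())
```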
|
|
3233
3482
|
|
|
@@ -3262,6 +3511,9 @@ class AsyncInferenceClient:
|
|
|
3262
3511
|
ModelStatus(loaded=True, state='Loaded', compute_type='gpu', framework='text-generation-inference')
|
|
3263
3512
|
```
|
|
3264
3513
|
"""
|
|
3514
|
+
if self.provider != "hf-inference":
|
|
3515
|
+
raise ValueError(f"Getting model status is not supported on '{self.provider}'.")
|
|
3516
|
+
|
|
3265
3517
|
model = model or self.model
|
|
3266
3518
|
if model is None:
|
|
3267
3519
|
raise ValueError("Model id not provided.")
|
|
@@ -3269,7 +3521,7 @@ class AsyncInferenceClient:
|
|
|
3269
3521
|
raise NotImplementedError("Model status is only available for Inference API endpoints.")
|
|
3270
3522
|
url = f"{INFERENCE_ENDPOINT}/status/{model}"
|
|
3271
3523
|
|
|
3272
|
-
async with self._get_client_session() as client:
|
|
3524
|
+
async with self._get_client_session(headers=build_hf_headers(token=self.token)) as client:
|
|
3273
3525
|
response = await client.get(url, proxy=self.proxies)
|
|
3274
3526
|
response.raise_for_status()
|
|
3275
3527
|
response_data = await response.json()
|