huggingface-hub 0.27.1__py3-none-any.whl → 0.28.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- huggingface_hub/__init__.py +418 -12
- huggingface_hub/_commit_api.py +33 -4
- huggingface_hub/_inference_endpoints.py +8 -2
- huggingface_hub/_local_folder.py +14 -3
- huggingface_hub/commands/scan_cache.py +1 -1
- huggingface_hub/commands/upload_large_folder.py +1 -1
- huggingface_hub/constants.py +7 -2
- huggingface_hub/file_download.py +1 -2
- huggingface_hub/hf_api.py +65 -84
- huggingface_hub/inference/_client.py +706 -450
- huggingface_hub/inference/_common.py +32 -64
- huggingface_hub/inference/_generated/_async_client.py +722 -470
- huggingface_hub/inference/_generated/types/__init__.py +1 -0
- huggingface_hub/inference/_generated/types/image_to_image.py +3 -3
- huggingface_hub/inference/_generated/types/text_to_audio.py +1 -2
- huggingface_hub/inference/_generated/types/text_to_image.py +3 -3
- huggingface_hub/inference/_generated/types/text_to_speech.py +3 -6
- huggingface_hub/inference/_generated/types/text_to_video.py +47 -0
- huggingface_hub/inference/_generated/types/visual_question_answering.py +1 -1
- huggingface_hub/inference/_providers/__init__.py +89 -0
- huggingface_hub/inference/_providers/fal_ai.py +159 -0
- huggingface_hub/inference/_providers/hf_inference.py +202 -0
- huggingface_hub/inference/_providers/replicate.py +148 -0
- huggingface_hub/inference/_providers/sambanova.py +89 -0
- huggingface_hub/inference/_providers/together.py +153 -0
- huggingface_hub/py.typed +0 -0
- huggingface_hub/repocard.py +1 -1
- huggingface_hub/repocard_data.py +2 -1
- huggingface_hub/serialization/_base.py +1 -1
- huggingface_hub/serialization/_torch.py +1 -1
- huggingface_hub/utils/_fixes.py +25 -13
- huggingface_hub/utils/_http.py +3 -3
- huggingface_hub/utils/logging.py +1 -1
- {huggingface_hub-0.27.1.dist-info → huggingface_hub-0.28.0.dist-info}/METADATA +4 -4
- {huggingface_hub-0.27.1.dist-info → huggingface_hub-0.28.0.dist-info}/RECORD +39 -31
- {huggingface_hub-0.27.1.dist-info → huggingface_hub-0.28.0.dist-info}/LICENSE +0 -0
- {huggingface_hub-0.27.1.dist-info → huggingface_hub-0.28.0.dist-info}/WHEEL +0 -0
- {huggingface_hub-0.27.1.dist-info → huggingface_hub-0.28.0.dist-info}/entry_points.txt +0 -0
- {huggingface_hub-0.27.1.dist-info → huggingface_hub-0.28.0.dist-info}/top_level.txt +0 -0
@@ -40,7 +40,6 @@ import warnings
 from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Literal, Optional, Union, overload

 from requests import HTTPError
-from requests.structures import CaseInsensitiveDict

 from huggingface_hub.constants import ALL_INFERENCE_API_FRAMEWORKS, INFERENCE_ENDPOINT, MAIN_INFERENCE_API_FRAMEWORKS
 from huggingface_hub.errors import BadRequestError, InferenceTimeoutError
@@ -48,16 +47,15 @@ from huggingface_hub.inference._common import (
     TASKS_EXPECTING_IMAGES,
     ContentT,
     ModelStatus,
+    RequestParameters,
     _b64_encode,
     _b64_to_image,
     _bytes_to_dict,
     _bytes_to_image,
     _bytes_to_list,
-    _fetch_recommended_models,
     _get_unsupported_text_generation_kwargs,
     _import_numpy,
     _open_as_binary,
-    _prepare_payload,
     _set_unsupported_text_generation_kwargs,
     _stream_chat_completion_response,
     _stream_text_generation_response,
@@ -104,8 +102,9 @@ from huggingface_hub.inference._generated.types import (
     ZeroShotClassificationOutputElement,
     ZeroShotImageClassificationOutputElement,
 )
+from huggingface_hub.inference._providers import PROVIDER_T, HFInferenceTask, get_provider_helper
 from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status
-from huggingface_hub.utils._deprecation import _deprecate_arguments
+from huggingface_hub.utils._deprecation import _deprecate_arguments, _deprecate_method


 if TYPE_CHECKING:
@@ -123,7 +122,7 @@ class InferenceClient:
     Initialize a new Inference Client.

     [`InferenceClient`] aims to provide a unified experience to perform inference. The client can be used
-    seamlessly with either the (free) Inference API
+    seamlessly with either the (free) Inference API, self-hosted Inference Endpoints, or third-party Inference Providers.

     Args:
         model (`str`, `optional`):
@@ -134,6 +133,10 @@ class InferenceClient:
             arguments are mutually exclusive. If using `base_url` for chat completion, the `/chat/completions` suffix
             path will be appended to the base URL (see the [TGI Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api)
            documentation for details). When passing a URL as `model`, the client will not append any suffix path to it.
+        provider (`str`, *optional*):
+            Name of the provider to use for inference. Can be `"replicate"`, `"together"`, `"fal-ai"`, `"sambanova"` or `"hf-inference"`.
+            defaults to hf-inference (Hugging Face Serverless Inference API).
+            If model is a URL or `base_url` is passed, then `provider` is not used.
         token (`str` or `bool`, *optional*):
             Hugging Face token. Will default to the locally saved token if not provided.
             Pass `token=False` if you don't want to send your token to the server.
@@ -161,7 +164,8 @@ class InferenceClient:
         self,
         model: Optional[str] = None,
         *,
-
+        provider: Optional[PROVIDER_T] = None,
+        token: Optional[str] = None,
         timeout: Optional[float] = None,
         headers: Optional[Dict[str, str]] = None,
         cookies: Optional[Dict[str, str]] = None,
@@ -185,12 +189,12 @@ class InferenceClient:
         )

         self.model: Optional[str] = model
-        self.token:
-        self.headers
-
-
-        if
-
+        self.token: Optional[str] = token if token is not None else api_key
+        self.headers = headers if headers is not None else {}
+
+        # Configure provider
+        self.provider = provider if provider is not None else "hf-inference"
+
         self.cookies = cookies
         self.timeout = timeout
         self.proxies = proxies
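The constructor now accepts a `provider` argument and defaults it to `"hf-inference"`. A minimal sketch of the new initialization paths, based only on the signature and docstring above (the API key value is a placeholder):

```python
from huggingface_hub import InferenceClient

# Default: requests go through the Hugging Face Serverless Inference API.
client = InferenceClient()  # equivalent to provider="hf-inference"

# Same task methods, routed through a third-party provider instead.
# Valid names per the docstring: "replicate", "together", "fal-ai", "sambanova", "hf-inference".
client = InferenceClient(provider="together", api_key="<together_api_key>")
```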
@@ -234,6 +238,14 @@ class InferenceClient:
         stream: bool = False,
     ) -> Union[bytes, Iterable[bytes]]: ...

+    @_deprecate_method(
+        version="0.31.0",
+        message=(
+            "Making direct POST requests to the inference server is not supported anymore. "
+            "Please use task methods instead (e.g. `InferenceClient.chat_completion`). "
+            "If your use case is not supported, please open an issue in https://github.com/huggingface/huggingface_hub."
+        ),
+    )
     def post(
         self,
         *,
@@ -246,53 +258,62 @@ class InferenceClient:
         """
         Make a POST request to the inference server.

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        This method is deprecated and will be removed in the future.
+        Please use task methods instead (e.g. `InferenceClient.chat_completion`).
+        """
+        if self.provider != "hf-inference":
+            raise ValueError(
+                "Cannot use `post` with another provider than `hf-inference`. "
+                "`InferenceClient.post` is deprecated and should not be used directly anymore."
+            )
+        provider_helper = HFInferenceTask(task or "unknown")
+        url = provider_helper.build_url(provider_helper.map_model(model))
+        headers = provider_helper.prepare_headers(headers=self.headers, api_key=self.token)
+        return self._inner_post(
+            request_parameters=RequestParameters(
+                url=url,
+                task=task or "unknown",
+                model=model or "unknown",
+                json=json,
+                data=data,
+                headers=headers,
+            ),
+            stream=stream,
+        )

-
-
+    @overload
+    def _inner_post(  # type: ignore[misc]
+        self, request_parameters: RequestParameters, *, stream: Literal[False] = ...
+    ) -> bytes: ...

-
-
-
-
-                If the request fails with an HTTP error status code other than HTTP 503.
-        """
-        url = self._resolve_url(model, task)
+    @overload
+    def _inner_post(  # type: ignore[misc]
+        self, request_parameters: RequestParameters, *, stream: Literal[True] = ...
+    ) -> Iterable[bytes]: ...

-
-
+    @overload
+    def _inner_post(
+        self, request_parameters: RequestParameters, *, stream: bool = False
+    ) -> Union[bytes, Iterable[bytes]]: ...

-
-
-
-
+    def _inner_post(
+        self, request_parameters: RequestParameters, *, stream: bool = False
+    ) -> Union[bytes, Iterable[bytes]]:
+        """Make a request to the inference server."""
+        # TODO: this should be handled in provider helpers directly
+        if request_parameters.task in TASKS_EXPECTING_IMAGES and "Accept" not in request_parameters.headers:
+            request_parameters.headers["Accept"] = "image/png"

         t0 = time.time()
         timeout = self.timeout
         while True:
-            with _open_as_binary(data) as data_as_binary:
+            with _open_as_binary(request_parameters.data) as data_as_binary:
                 try:
                     response = get_session().post(
-                        url,
-                        json=json,
+                        request_parameters.url,
+                        json=request_parameters.json,
                         data=data_as_binary,
-                        headers=headers,
+                        headers=request_parameters.headers,
                         cookies=self.cookies,
                         timeout=self.timeout,
                         stream=stream,
@@ -300,21 +321,21 @@ class InferenceClient:
                     )
                 except TimeoutError as error:
                     # Convert any `TimeoutError` to a `InferenceTimeoutError`
-                    raise InferenceTimeoutError(f"Inference call timed out: {url}") from error  # type: ignore
+                    raise InferenceTimeoutError(f"Inference call timed out: {request_parameters.url}") from error  # type: ignore

             try:
                 hf_raise_for_status(response)
                 return response.iter_lines() if stream else response.content
             except HTTPError as error:
-                if error.response.status_code == 422 and task
+                if error.response.status_code == 422 and request_parameters.task != "unknown":
                     error.args = (
-                        f"{error.args[0]}\nMake sure '{task}' task is supported by the model.",
+                        f"{error.args[0]}\nMake sure '{request_parameters.task}' task is supported by the model.",
                     ) + error.args[1:]
                 if error.response.status_code == 503:
                     # If Model is unavailable, either raise a TimeoutError...
                     if timeout is not None and time.time() - t0 > timeout:
                         raise InferenceTimeoutError(
-                            f"Model not loaded on the server: {url}. Please retry with a higher timeout (current:"
+                            f"Model not loaded on the server: {request_parameters.url}. Please retry with a higher timeout (current:"
                             f" {self.timeout}).",
                             request=error.request,
                             response=error.response,
@@ -322,8 +343,10 @@ class InferenceClient:
                     # ...or wait 1s and retry
                     logger.info(f"Waiting for model to be loaded on the server: {error}")
                     time.sleep(1)
-                    if "X-wait-for-model" not in headers and url.startswith(
-
+                    if "X-wait-for-model" not in request_parameters.headers and request_parameters.url.startswith(
+                        INFERENCE_ENDPOINT
+                    ):
+                        request_parameters.headers["X-wait-for-model"] = "1"
                     if timeout is not None:
                         timeout = max(self.timeout - (time.time() - t0), 1)  # type: ignore
                     continue
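With `post` deprecated (scheduled for removal in 0.31.0) and restricted to the `hf-inference` provider, callers are expected to move to task methods. A hedged migration sketch (the model id is illustrative):

```python
from huggingface_hub import InferenceClient

client = InferenceClient()

# Before (now emits a deprecation warning and only works with provider="hf-inference"):
raw = client.post(
    json={"inputs": "I like you."},
    model="distilbert-base-uncased-finetuned-sst-2-english",
    task="text-classification",
)

# After: the task method builds the request through a provider helper internally.
result = client.text_classification(
    "I like you.",
    model="distilbert-base-uncased-finetuned-sst-2-english",
)
```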
@@ -374,9 +397,15 @@ class InferenceClient:
         ]
         ```
         """
-
-
-
+        provider_helper = get_provider_helper(self.provider, task="audio-classification")
+        request_parameters = provider_helper.prepare_request(
+            inputs=audio,
+            parameters={"function_to_apply": function_to_apply, "top_k": top_k},
+            headers=self.headers,
+            model=model or self.model,
+            api_key=self.token,
+        )
+        response = self._inner_post(request_parameters)
         return AudioClassificationOutputElement.parse_obj_as_list(response)

     def audio_to_audio(
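Each task method now follows the same internal flow shown above: resolve a provider helper for `(provider, task)`, let it build a `RequestParameters` object, then send it with `_inner_post`. A sketch of that flow written out by hand, using only the internal calls visible in this diff (these helpers are private API and may change; the model id and payload are placeholders):

```python
from huggingface_hub import InferenceClient
from huggingface_hub.inference._providers import get_provider_helper

client = InferenceClient()

# Roughly what audio_classification() does under the hood (simplified sketch).
provider_helper = get_provider_helper(client.provider, task="audio-classification")
request_parameters = provider_helper.prepare_request(
    inputs=b"<raw audio bytes>",         # placeholder payload
    parameters={"function_to_apply": None, "top_k": 3},
    headers=client.headers,
    model="some-org/some-audio-model",   # hypothetical model id
    api_key=client.token,
)
response = client._inner_post(request_parameters)  # raw bytes, parsed by the task method
```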
@@ -416,7 +445,15 @@ class InferenceClient:
                 f.write(item.blob)
         ```
         """
-
+        provider_helper = get_provider_helper(self.provider, task="audio-to-audio")
+        request_parameters = provider_helper.prepare_request(
+            inputs=audio,
+            parameters={},
+            headers=self.headers,
+            model=model or self.model,
+            api_key=self.token,
+        )
+        response = self._inner_post(request_parameters)
         audio_output = AudioToAudioOutputElement.parse_obj_as_list(response)
         for item in audio_output:
             item.blob = base64.b64decode(item.blob)
@@ -437,7 +474,8 @@ class InferenceClient:
             model (`str`, *optional*):
                 The model to use for ASR. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
                 Inference Endpoint. If not provided, the default recommended model for ASR will be used.
-
+            parameters (Dict[str, Any], *optional*):
+                Additional parameters to pass to the model.
         Returns:
             [`AutomaticSpeechRecognitionOutput`]: An item containing the transcribed text and optionally the timestamp chunks.

@@ -455,7 +493,15 @@ class InferenceClient:
         "hello world"
         ```
         """
-
+        provider_helper = get_provider_helper(self.provider, task="automatic-speech-recognition")
+        request_parameters = provider_helper.prepare_request(
+            inputs=audio,
+            parameters={},
+            headers=self.headers,
+            model=model or self.model,
+            api_key=self.token,
+        )
+        response = self._inner_post(request_parameters)
         return AutomaticSpeechRecognitionOutput.parse_obj_as_instance(response)

     @overload
@@ -569,6 +615,10 @@ class InferenceClient:

         </Tip>

+        <Tip>
+        Some parameters might not be supported by some providers.
+        </Tip>
+
         Args:
             messages (List of [`ChatCompletionInputMessage`]):
                 Conversation history consisting of roles and content pairs.
@@ -576,25 +626,20 @@ class InferenceClient:
                 The model to use for chat-completion. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
                 Inference Endpoint. If not provided, the default recommended model for chat-based text-generation will be used.
                 See https://huggingface.co/tasks/text-generation for more details.
-
                 If `model` is a model ID, it is passed to the server as the `model` parameter. If you want to define a
                 custom URL while setting `model` in the request payload, you must set `base_url` when initializing [`InferenceClient`].
             frequency_penalty (`float`, *optional*):
                 Penalizes new tokens based on their existing frequency
                 in the text so far. Range: [-2.0, 2.0]. Defaults to 0.0.
             logit_bias (`List[float]`, *optional*):
-
-                (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,
-                the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,
-                but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should
-                result in a ban or exclusive selection of the relevant token. Defaults to None.
+                Adjusts the likelihood of specific tokens appearing in the generated output.
             logprobs (`bool`, *optional*):
                 Whether to return log probabilities of the output tokens or not. If true, returns the log
                 probabilities of each output token returned in the content of message.
             max_tokens (`int`, *optional*):
                 Maximum number of tokens allowed in the response. Defaults to 100.
             n (`int`, *optional*):
-
+                The number of completions to generate for each prompt.
             presence_penalty (`float`, *optional*):
                 Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the
                 text so far, increasing the model's likelihood to talk about new topics.
@@ -602,7 +647,7 @@ class InferenceClient:
                 Grammar constraints. Can be either a JSONSchema or a regex.
             seed (Optional[`int`], *optional*):
                 Seed for reproducible control flow. Defaults to None.
-            stop (
+            stop (`List[str]`, *optional*):
                 Up to four strings which trigger the end of the response.
                 Defaults to None.
             stream (`bool`, *optional*):
@@ -711,6 +756,32 @@ class InferenceClient:
            print(chunk.choices[0].delta.content)
        ```

+        Example using a third-party provider directly. Usage will be billed on your Together AI account.
+        ```py
+        >>> from huggingface_hub import InferenceClient
+        >>> client = InferenceClient(
+        ...     provider="together",  # Use Together AI provider
+        ...     api_key="<together_api_key>",  # Pass your Together API key directly
+        ... )
+        >>> client.chat_completion(
+        ...     model="meta-llama/Meta-Llama-3-8B-Instruct",
+        ...     messages=[{"role": "user", "content": "What is the capital of France?"}],
+        ... )
+        ```
+
+        Example using a third-party provider through Hugging Face Routing. Usage will be billed on your Hugging Face account.
+        ```py
+        >>> from huggingface_hub import InferenceClient
+        >>> client = InferenceClient(
+        ...     provider="sambanova",  # Use Sambanova provider
+        ...     api_key="hf_...",  # Pass your HF token
+        ... )
+        >>> client.chat_completion(
+        ...     model="meta-llama/Meta-Llama-3-8B-Instruct",
+        ...     messages=[{"role": "user", "content": "What is the capital of France?"}],
+        ... )
+        ```
+
        Example using Image + Text as input:
        ```py
        >>> from huggingface_hub import InferenceClient
@@ -859,68 +930,50 @@ class InferenceClient:
        '{\n\n"activity": "bike ride",\n"animals": ["puppy", "cat", "raccoon"],\n"animals_seen": 3,\n"location": "park"}'
        ```
        """
-
-
-
-        #
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        # Get the provider helper
+        provider_helper = get_provider_helper(self.provider, task="conversational")
+
+        # Since `chat_completion(..., model=xxx)` is also a payload parameter for the server, we need to handle 'model' differently.
+        # `self.base_url` and `self.model` takes precedence over 'model' argument for building URL.
+        # `model` takes precedence for payload value.
+        model_id_or_url = self.base_url or self.model or model
+        payload_model = model or self.model
+
+        # Prepare the payload
+        parameters = {
+            "model": payload_model,
+            "frequency_penalty": frequency_penalty,
+            "logit_bias": logit_bias,
+            "logprobs": logprobs,
+            "max_tokens": max_tokens,
+            "n": n,
+            "presence_penalty": presence_penalty,
+            "response_format": response_format,
+            "seed": seed,
+            "stop": stop,
+            "temperature": temperature,
+            "tool_choice": tool_choice,
+            "tool_prompt": tool_prompt,
+            "tools": tools,
+            "top_logprobs": top_logprobs,
+            "top_p": top_p,
+            "stream": stream,
+            "stream_options": stream_options,
+        }
+        request_parameters = provider_helper.prepare_request(
+            inputs=messages,
+            parameters=parameters,
+            headers=self.headers,
+            model=model_id_or_url,
+            api_key=self.token,
         )
-
-        data = self.post(model=model_url, json=payload, stream=stream)
+        data = self._inner_post(request_parameters, stream=stream)

         if stream:
             return _stream_chat_completion_response(data)  # type: ignore[arg-type]

         return ChatCompletionOutput.parse_obj_as_instance(data)  # type: ignore[arg-type]

-    def _resolve_chat_completion_url(self, model: Optional[str] = None) -> str:
-        # Since `chat_completion(..., model=xxx)` is also a payload parameter for the server, we need to handle 'model' differently.
-        # `self.base_url` and `self.model` takes precedence over 'model' argument only in `chat_completion`.
-        model_id_or_url = self.base_url or self.model or model or self.get_recommended_model("text-generation")
-
-        # Resolve URL if it's a model ID
-        model_url = (
-            model_id_or_url
-            if model_id_or_url.startswith(("http://", "https://"))
-            else self._resolve_url(model_id_or_url, task="text-generation")
-        )
-
-        # Strip trailing /
-        model_url = model_url.rstrip("/")
-
-        # Append /chat/completions if not already present
-        if model_url.endswith("/v1"):
-            model_url += "/chat/completions"
-
-        # Append /v1/chat/completions if not already present
-        if not model_url.endswith("/chat/completions"):
-            model_url += "/v1/chat/completions"
-
-        return model_url
-
     def document_question_answering(
         self,
         image: ContentT,
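The comments in the new `chat_completion` body describe a precedence rule specific to this method: `self.base_url` / `self.model` win when building the request URL, while the `model` argument wins for the payload's `model` field. A small sketch of the intended usage (URL and model id are illustrative):

```python
from huggingface_hub import InferenceClient

# Client pinned to a custom endpoint (e.g. a local TGI server).
client = InferenceClient(base_url="http://localhost:8080/v1")

client.chat_completion(
    messages=[{"role": "user", "content": "Hello"}],
    model="my-org/my-finetune",  # forwarded in the JSON payload, not used for the URL
)
```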
@@ -987,18 +1040,24 @@ class InferenceClient:
        ```
        """
        inputs: Dict[str, Any] = {"question": question, "image": _b64_encode(image)}
-
-
-
-
-
-
-
-
-
-
-
-
+        provider_helper = get_provider_helper(self.provider, task="document-question-answering")
+        request_parameters = provider_helper.prepare_request(
+            inputs=inputs,
+            parameters={
+                "doc_stride": doc_stride,
+                "handle_impossible_answer": handle_impossible_answer,
+                "lang": lang,
+                "max_answer_len": max_answer_len,
+                "max_question_len": max_question_len,
+                "max_seq_len": max_seq_len,
+                "top_k": top_k,
+                "word_boxes": word_boxes,
+            },
+            headers=self.headers,
+            model=model or self.model,
+            api_key=self.token,
+        )
+        response = self._inner_post(request_parameters)
        return DocumentQuestionAnsweringOutputElement.parse_obj_as_list(response)

    def feature_extraction(
@@ -1056,14 +1115,20 @@ class InferenceClient:
        [ 0.28552425, -0.928395 , -1.2077185 , ..., 0.76810825, -2.1069427 , 0.6236161 ]], dtype=float32)
        ```
        """
-
-
-
-
-
-
-
-
+        provider_helper = get_provider_helper(self.provider, task="feature-extraction")
+        request_parameters = provider_helper.prepare_request(
+            inputs=text,
+            parameters={
+                "normalize": normalize,
+                "prompt_name": prompt_name,
+                "truncate": truncate,
+                "truncation_direction": truncation_direction,
+            },
+            headers=self.headers,
+            model=model or self.model,
+            api_key=self.token,
+        )
+        response = self._inner_post(request_parameters)
        np = _import_numpy()
        return np.array(_bytes_to_dict(response), dtype="float32")

@@ -1111,9 +1176,15 @@ class InferenceClient:
        ]
        ```
        """
-
-
-
+        provider_helper = get_provider_helper(self.provider, task="fill-mask")
+        request_parameters = provider_helper.prepare_request(
+            inputs=text,
+            parameters={"targets": targets, "top_k": top_k},
+            headers=self.headers,
+            model=model or self.model,
+            api_key=self.token,
+        )
+        response = self._inner_post(request_parameters)
        return FillMaskOutputElement.parse_obj_as_list(response)

    def image_classification(
@@ -1154,9 +1225,15 @@ class InferenceClient:
        [ImageClassificationOutputElement(label='Blenheim spaniel', score=0.9779096841812134), ...]
        ```
        """
-
-
-
+        provider_helper = get_provider_helper(self.provider, task="image-classification")
+        request_parameters = provider_helper.prepare_request(
+            inputs=image,
+            parameters={"function_to_apply": function_to_apply, "top_k": top_k},
+            headers=self.headers,
+            model=model or self.model,
+            api_key=self.token,
+        )
+        response = self._inner_post(request_parameters)
        return ImageClassificationOutputElement.parse_obj_as_list(response)

    def image_segmentation(
@@ -1209,14 +1286,20 @@ class InferenceClient:
        [ImageSegmentationOutputElement(score=0.989008, label='LABEL_184', mask=<PIL.PngImagePlugin.PngImageFile image mode=L size=400x300 at 0x7FDD2B129CC0>), ...]
        ```
        """
-
-
-
-
-
-
-
-
+        provider_helper = get_provider_helper(self.provider, task="audio-classification")
+        request_parameters = provider_helper.prepare_request(
+            inputs=image,
+            parameters={
+                "mask_threshold": mask_threshold,
+                "overlap_mask_area_threshold": overlap_mask_area_threshold,
+                "subtask": subtask,
+                "threshold": threshold,
+            },
+            headers=self.headers,
+            model=model or self.model,
+            api_key=self.token,
+        )
+        response = self._inner_post(request_parameters)
        output = ImageSegmentationOutputElement.parse_obj_as_list(response)
        for item in output:
            item.mask = _b64_to_image(item.mask)  # type: ignore [assignment]
@@ -1227,7 +1310,7 @@ class InferenceClient:
        image: ContentT,
        prompt: Optional[str] = None,
        *,
-        negative_prompt: Optional[
+        negative_prompt: Optional[str] = None,
        num_inference_steps: Optional[int] = None,
        guidance_scale: Optional[float] = None,
        model: Optional[str] = None,
@@ -1248,8 +1331,8 @@ class InferenceClient:
                The input image for translation. It can be raw bytes, an image file, or a URL to an online image.
            prompt (`str`, *optional*):
                The text prompt to guide the image generation.
-            negative_prompt (`
-                One
+            negative_prompt (`str`, *optional*):
+                One prompt to guide what NOT to include in image generation.
            num_inference_steps (`int`, *optional*):
                For diffusion models. The number of denoising steps. More denoising steps usually lead to a higher
                quality image at the expense of slower inference.
@@ -1279,16 +1362,22 @@ class InferenceClient:
        >>> image.save("tiger.jpg")
        ```
        """
-
-
-
-
-
-
-
-
-
-
+        provider_helper = get_provider_helper(self.provider, task="image-to-image")
+        request_parameters = provider_helper.prepare_request(
+            inputs=image,
+            parameters={
+                "prompt": prompt,
+                "negative_prompt": negative_prompt,
+                "target_size": target_size,
+                "num_inference_steps": num_inference_steps,
+                "guidance_scale": guidance_scale,
+                **kwargs,
+            },
+            headers=self.headers,
+            model=model or self.model,
+            api_key=self.token,
+        )
+        response = self._inner_post(request_parameters)
        return _bytes_to_image(response)

    def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -> ImageToTextOutput:
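`image_to_image` now takes a plain `negative_prompt` string and routes its request through the provider helper. A short usage sketch (file names and model id are illustrative):

```python
from huggingface_hub import InferenceClient

client = InferenceClient()
image = client.image_to_image(
    "cat.png",
    prompt="turn the cat into a tiger",
    negative_prompt="blurry, low quality",
    model="timbrooks/instruct-pix2pix",  # illustrative image-to-image model
)
image.save("tiger.png")
```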
@@ -1324,93 +1413,18 @@ class InferenceClient:
        'a dog laying on the grass next to a flower pot '
        ```
        """
-
+        provider_helper = get_provider_helper(self.provider, task="image-to-text")
+        request_parameters = provider_helper.prepare_request(
+            inputs=image,
+            parameters={},
+            headers=self.headers,
+            model=model or self.model,
+            api_key=self.token,
+        )
+        response = self._inner_post(request_parameters)
        output = ImageToTextOutput.parse_obj(response)
        return output[0] if isinstance(output, list) else output

-    def list_deployed_models(
-        self, frameworks: Union[None, str, Literal["all"], List[str]] = None
-    ) -> Dict[str, List[str]]:
-        """
-        List models deployed on the Serverless Inference API service.
-
-        This helper checks deployed models framework by framework. By default, it will check the 4 main frameworks that
-        are supported and account for 95% of the hosted models. However, if you want a complete list of models you can
-        specify `frameworks="all"` as input. Alternatively, if you know before-hand which framework you are interested
-        in, you can also restrict to search to this one (e.g. `frameworks="text-generation-inference"`). The more
-        frameworks are checked, the more time it will take.
-
-        <Tip warning={true}>
-
-        This endpoint method does not return a live list of all models available for the Serverless Inference API service.
-        It searches over a cached list of models that were recently available and the list may not be up to date.
-        If you want to know the live status of a specific model, use [`~InferenceClient.get_model_status`].
-
-        </Tip>
-
-        <Tip>
-
-        This endpoint method is mostly useful for discoverability. If you already know which model you want to use and want to
-        check its availability, you can directly use [`~InferenceClient.get_model_status`].
-
-        </Tip>
-
-        Args:
-            frameworks (`Literal["all"]` or `List[str]` or `str`, *optional*):
-                The frameworks to filter on. By default only a subset of the available frameworks are tested. If set to
-                "all", all available frameworks will be tested. It is also possible to provide a single framework or a
-                custom set of frameworks to check.
-
-        Returns:
-            `Dict[str, List[str]]`: A dictionary mapping task names to a sorted list of model IDs.
-
-        Example:
-        ```python
-        >>> from huggingface_hub import InferenceClient
-        >>> client = InferenceClient()
-
-        # Discover zero-shot-classification models currently deployed
-        >>> models = client.list_deployed_models()
-        >>> models["zero-shot-classification"]
-        ['Narsil/deberta-large-mnli-zero-cls', 'facebook/bart-large-mnli', ...]
-
-        # List from only 1 framework
-        >>> client.list_deployed_models("text-generation-inference")
-        {'text-generation': ['bigcode/starcoder', 'meta-llama/Llama-2-70b-chat-hf', ...], ...}
-        ```
-        """
-        # Resolve which frameworks to check
-        if frameworks is None:
-            frameworks = MAIN_INFERENCE_API_FRAMEWORKS
-        elif frameworks == "all":
-            frameworks = ALL_INFERENCE_API_FRAMEWORKS
-        elif isinstance(frameworks, str):
-            frameworks = [frameworks]
-        frameworks = list(set(frameworks))
-
-        # Fetch them iteratively
-        models_by_task: Dict[str, List[str]] = {}
-
-        def _unpack_response(framework: str, items: List[Dict]) -> None:
-            for model in items:
-                if framework == "sentence-transformers":
-                    # Model running with the `sentence-transformers` framework can work with both tasks even if not
-                    # branded as such in the API response
-                    models_by_task.setdefault("feature-extraction", []).append(model["model_id"])
-                    models_by_task.setdefault("sentence-similarity", []).append(model["model_id"])
-                else:
-                    models_by_task.setdefault(model["task"], []).append(model["model_id"])
-
-        for framework in frameworks:
-            response = get_session().get(f"{INFERENCE_ENDPOINT}/framework/{framework}", headers=self.headers)
-            hf_raise_for_status(response)
-            _unpack_response(framework, response.json())
-
-        # Sort alphabetically for discoverability and return
-        for task, models in models_by_task.items():
-            models_by_task[task] = sorted(set(models), key=lambda x: x.lower())
-        return models_by_task
-
    def object_detection(
        self, image: ContentT, *, model: Optional[str] = None, threshold: Optional[float] = None
    ) -> List[ObjectDetectionOutputElement]:
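`image_to_text` keeps its public signature; only the request construction changed. A minimal call sketch (the input file is illustrative):

```python
from huggingface_hub import InferenceClient

client = InferenceClient()
caption = client.image_to_text("dog.jpg")
print(caption.generated_text)
```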
@@ -1450,11 +1464,15 @@ class InferenceClient:
        [ObjectDetectionOutputElement(score=0.9486683011054993, label='person', box=ObjectDetectionBoundingBox(xmin=59, ymin=39, xmax=420, ymax=510)), ...]
        ```
        """
-
-
-
-
-
+        provider_helper = get_provider_helper(self.provider, task="object-detection")
+        request_parameters = provider_helper.prepare_request(
+            inputs=image,
+            parameters={"threshold": threshold},
+            headers=self.headers,
+            model=model or self.model,
+            api_key=self.token,
+        )
+        response = self._inner_post(request_parameters)
        return ObjectDetectionOutputElement.parse_obj_as_list(response)

    def question_answering(
@@ -1519,22 +1537,24 @@ class InferenceClient:
        QuestionAnsweringOutputElement(answer='Clara', end=16, score=0.9326565265655518, start=11)
        ```
        """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        provider_helper = get_provider_helper(self.provider, task="question-answering")
+        request_parameters = provider_helper.prepare_request(
+            inputs=None,
+            parameters={
+                "align_to_words": align_to_words,
+                "doc_stride": doc_stride,
+                "handle_impossible_answer": handle_impossible_answer,
+                "max_answer_len": max_answer_len,
+                "max_question_len": max_question_len,
+                "max_seq_len": max_seq_len,
+                "top_k": top_k,
+            },
+            extra_payload={"question": question, "context": context},
+            headers=self.headers,
+            model=model or self.model,
+            api_key=self.token,
        )
+        response = self._inner_post(request_parameters)
        # Parse the response as a single `QuestionAnsweringOutputElement` when top_k is 1 or not provided, or a list of `QuestionAnsweringOutputElement` to ensure backward compatibility.
        output = QuestionAnsweringOutputElement.parse_obj(response)
        return output
@@ -1579,11 +1599,16 @@ class InferenceClient:
        [0.7785726189613342, 0.45876261591911316, 0.2906220555305481]
        ```
        """
-
-
-
-
+        provider_helper = get_provider_helper(self.provider, task="sentence-similarity")
+        request_parameters = provider_helper.prepare_request(
+            inputs=None,
+            parameters={},
+            extra_payload={"source_sentence": sentence, "sentences": other_sentences},
+            headers=self.headers,
+            model=model or self.model,
+            api_key=self.token,
        )
+        response = self._inner_post(request_parameters)
        return _bytes_to_list(response)

    @_deprecate_arguments(
@@ -1645,8 +1670,15 @@ class InferenceClient:
            "generate_parameters": generate_parameters,
            "truncation": truncation,
        }
-
-
+        provider_helper = get_provider_helper(self.provider, task="summarization")
+        request_parameters = provider_helper.prepare_request(
+            inputs=text,
+            parameters=parameters,
+            headers=self.headers,
+            model=model or self.model,
+            api_key=self.token,
+        )
+        response = self._inner_post(request_parameters)
        return SummarizationOutput.parse_obj_as_list(response)[0]

    def table_question_answering(
@@ -1699,21 +1731,16 @@ class InferenceClient:
        TableQuestionAnsweringOutputElement(answer='36542', coordinates=[[0, 1]], cells=['36542'], aggregator='AVERAGE')
        ```
        """
-
-
-
-            "truncation": truncation,
-
-
-
-
-        }
-        payload = _prepare_payload(inputs, parameters=parameters)
-        response = self.post(
-            **payload,
-            model=model,
-            task="table-question-answering",
+        provider_helper = get_provider_helper(self.provider, task="table-question-answering")
+        request_parameters = provider_helper.prepare_request(
+            inputs=None,
+            parameters={"model": model, "padding": padding, "sequential": sequential, "truncation": truncation},
+            extra_payload={"query": query, "table": table},
+            headers=self.headers,
+            model=model or self.model,
+            api_key=self.token,
        )
+        response = self._inner_post(request_parameters)
        return TableQuestionAnsweringOutputElement.parse_obj_as_instance(response)

    def tabular_classification(self, table: Dict[str, Any], *, model: Optional[str] = None) -> List[str]:
@@ -1758,11 +1785,16 @@ class InferenceClient:
        ["5", "5", "5"]
        ```
        """
-
-
-
-
+        provider_helper = get_provider_helper(self.provider, task="tabular-classification")
+        request_parameters = provider_helper.prepare_request(
+            inputs=None,
+            extra_payload={"table": table},
+            parameters={},
+            headers=self.headers,
+            model=model or self.model,
+            api_key=self.token,
        )
+        response = self._inner_post(request_parameters)
        return _bytes_to_list(response)

    def tabular_regression(self, table: Dict[str, Any], *, model: Optional[str] = None) -> List[float]:
@@ -1802,7 +1834,16 @@ class InferenceClient:
        [110, 120, 130]
        ```
        """
-
+        provider_helper = get_provider_helper(self.provider, task="tabular-regression")
+        request_parameters = provider_helper.prepare_request(
+            inputs=None,
+            parameters={},
+            extra_payload={"table": table},
+            headers=self.headers,
+            model=model or self.model,
+            api_key=self.token,
+        )
+        response = self._inner_post(request_parameters)
        return _bytes_to_list(response)

    def text_classification(
@@ -1848,16 +1889,18 @@ class InferenceClient:
        ]
        ```
        """
-
-
-
-
-
-
-
-
-
+        provider_helper = get_provider_helper(self.provider, task="text-classification")
+        request_parameters = provider_helper.prepare_request(
+            inputs=text,
+            parameters={
+                "function_to_apply": function_to_apply,
+                "top_k": top_k,
+            },
+            headers=self.headers,
+            model=model or self.model,
+            api_key=self.token,
        )
+        response = self._inner_post(request_parameters)
        return TextClassificationOutputElement.parse_obj_as_list(response)[0]  # type: ignore [return-value]

    @overload
@@ -2041,15 +2084,6 @@ class InferenceClient:
        """
        Given a prompt, generate the following text.

-        API endpoint is supposed to run with the `text-generation-inference` backend (TGI). This backend is the
-        go-to solution to run large language models at scale. However, for some smaller models (e.g. "gpt2") the
-        default `transformers` + `api-inference` solution is still in use. Both approaches have very similar APIs, but
-        not exactly the same. This method is compatible with both approaches but some parameters are only available for
-        `text-generation-inference`. If some parameters are ignored, a warning message is triggered but the process
-        continues correctly.
-
-        To learn more about the TGI project, please refer to https://github.com/huggingface/text-generation-inference.
-
        <Tip>

        If you want to generate a response from chat messages, you should use the [`InferenceClient.chat_completion`] method.
@@ -2272,12 +2306,6 @@ class InferenceClient:
            "typical_p": typical_p,
            "watermark": watermark,
        }
-        parameters = {k: v for k, v in parameters.items() if v is not None}
-        payload = {
-            "inputs": prompt,
-            "parameters": parameters,
-            "stream": stream,
-        }

        # Remove some parameters if not a TGI server
        unsupported_kwargs = _get_unsupported_text_generation_kwargs(model)
@@ -2310,9 +2338,19 @@ class InferenceClient:
                " Please pass `stream=False` as input."
            )

+        provider_helper = get_provider_helper(self.provider, task="text-generation")
+        request_parameters = provider_helper.prepare_request(
+            inputs=prompt,
+            parameters=parameters,
+            extra_payload={"stream": stream},
+            headers=self.headers,
+            model=model or self.model,
+            api_key=self.token,
+        )
+
        # Handle errors separately for more precise error messages
        try:
-            bytes_output = self.
+            bytes_output = self._inner_post(request_parameters, stream=stream)
        except HTTPError as e:
            match = MODEL_KWARGS_NOT_USED_REGEX.search(str(e))
            if isinstance(e, BadRequestError) and match:
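`text_generation` keeps its call pattern; the payload construction moved into the provider helper and the model fallback now includes `self.model`. A minimal usage sketch (the model id is illustrative):

```python
from huggingface_hub import InferenceClient

client = InferenceClient(model="HuggingFaceH4/zephyr-7b-beta")
completion = client.text_generation("The huggingface_hub library is ", max_new_tokens=12)
print(completion)
```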
@@ -2322,7 +2360,7 @@ class InferenceClient:
                    prompt=prompt,
                    details=details,
                    stream=stream,
-                    model=model,
+                    model=model or self.model,
                    adapter_id=adapter_id,
                    best_of=best_of,
                    decoder_input_details=decoder_input_details,
@@ -2360,7 +2398,7 @@ class InferenceClient:
        self,
        prompt: str,
        *,
-        negative_prompt: Optional[
+        negative_prompt: Optional[str] = None,
        height: Optional[float] = None,
        width: Optional[float] = None,
        num_inference_steps: Optional[int] = None,
@@ -2383,8 +2421,8 @@ class InferenceClient:
        Args:
            prompt (`str`):
                The prompt to generate an image from.
-            negative_prompt (`
-                One
+            negative_prompt (`str`, *optional*):
+                One prompt to guide what NOT to include in image generation.
            height (`float`, *optional*):
                The height in pixels of the image to generate.
            width (`float`, *optional*):
@@ -2430,23 +2468,143 @@ class InferenceClient:
        ... )
        >>> image.save("better_astronaut.png")
        ```
-
+        Example using a third-party provider directly. Usage will be billed on your fal.ai account.
+        ```py
+        >>> from huggingface_hub import InferenceClient
+        >>> client = InferenceClient(
+        ...     provider="fal-ai",  # Use fal.ai provider
+        ...     api_key="fal-ai-api-key",  # Pass your fal.ai API key
+        ... )
+        >>> image = client.text_to_image(
+        ...     "A majestic lion in a fantasy forest",
+        ...     model="black-forest-labs/FLUX.1-schnell",
+        ... )
+        >>> image.save("lion.png")
+        ```

-
-
-
-
-
-
-
-
-
-
-
-
-
+        Example using a third-party provider through Hugging Face Routing. Usage will be billed on your Hugging Face account.
+        ```py
+        >>> from huggingface_hub import InferenceClient
+        >>> client = InferenceClient(
+        ...     provider="replicate",  # Use replicate provider
+        ...     api_key="hf_...",  # Pass your HF token
+        ... )
+        >>> image = client.text_to_image(
+        ...     "An astronaut riding a horse on the moon.",
+        ...     model="black-forest-labs/FLUX.1-dev",
+        ... )
+        >>> image.save("astronaut.png")
+        ```
+        """
+        provider_helper = get_provider_helper(self.provider, task="text-to-image")
+        request_parameters = provider_helper.prepare_request(
+            inputs=prompt,
+            parameters={
+                "negative_prompt": negative_prompt,
+                "height": height,
+                "width": width,
+                "num_inference_steps": num_inference_steps,
+                "guidance_scale": guidance_scale,
+                "scheduler": scheduler,
+                "target_size": target_size,
+                "seed": seed,
+                **kwargs,
+            },
+            headers=self.headers,
+            model=model or self.model,
+            api_key=self.token,
+        )
+        response = self._inner_post(request_parameters)
+        response = provider_helper.get_response(response)
        return _bytes_to_image(response)

+    def text_to_video(
+        self,
+        prompt: str,
+        *,
+        model: Optional[str] = None,
+        guidance_scale: Optional[float] = None,
+        negative_prompt: Optional[List[str]] = None,
+        num_frames: Optional[float] = None,
+        num_inference_steps: Optional[int] = None,
+        seed: Optional[int] = None,
+    ) -> bytes:
+        """
+        Generate a video based on a given text.
+
+        Args:
+            prompt (`str`):
+                The prompt to generate a video from.
+            model (`str`, *optional*):
+                The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
+                Inference Endpoint. If not provided, the default recommended text-to-video model will be used.
+                Defaults to None.
+            guidance_scale (`float`, *optional*):
+                A higher guidance scale value encourages the model to generate videos closely linked to the text
+                prompt, but values too high may cause saturation and other artifacts.
+            negative_prompt (`List[str]`, *optional*):
+                One or several prompt to guide what NOT to include in video generation.
+            num_frames (`float`, *optional*):
+                The num_frames parameter determines how many video frames are generated.
+            num_inference_steps (`int`, *optional*):
+                The number of denoising steps. More denoising steps usually lead to a higher quality video at the
+                expense of slower inference.
+            seed (`int`, *optional*):
+                Seed for the random number generator.
+
+        Returns:
+            `bytes`: The generated video.
+
+        Example:
+
+        Example using a third-party provider directly. Usage will be billed on your fal.ai account.
+        ```py
+        >>> from huggingface_hub import InferenceClient
+        >>> client = InferenceClient(
+        ...     provider="fal-ai",  # Using fal.ai provider
+        ...     api_key="fal-ai-api-key",  # Pass your fal.ai API key
+        ... )
+        >>> video = client.text_to_video(
+        ...     "A majestic lion running in a fantasy forest",
+        ...     model="tencent/HunyuanVideo",
+        ... )
+        >>> with open("lion.mp4", "wb") as file:
+        ...     file.write(video)
+        ```
+
+        Example using a third-party provider through Hugging Face Routing. Usage will be billed on your Hugging Face account.
+        ```py
+        >>> from huggingface_hub import InferenceClient
+        >>> client = InferenceClient(
+        ...     provider="replicate",  # Using replicate provider
+        ...     api_key="hf_...",  # Pass your HF token
+        ... )
+        >>> video = client.text_to_video(
+        ...     "A cat running in a park",
+        ...     model="genmo/mochi-1-preview",
+        ... )
+        >>> with open("cat.mp4", "wb") as file:
+        ...     file.write(video)
+        ```
+        """
+        provider_helper = get_provider_helper(self.provider, task="text-to-video")
+        request_parameters = provider_helper.prepare_request(
+            inputs=prompt,
+            parameters={
+                "guidance_scale": guidance_scale,
+                "negative_prompt": negative_prompt,
+                "num_frames": num_frames,
+                "num_inference_steps": num_inference_steps,
+                "seed": seed,
+            },
+            headers=self.headers,
+            model=model or self.model,
+            api_key=self.token,
+        )
+        response = self._inner_post(request_parameters)
+        response = provider_helper.get_response(response)
+        return response
+
    def text_to_speech(
        self,
        text: str,
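`text_to_video` is a new task method in 0.28.0; it follows the same helper flow and returns raw video bytes. A minimal sketch in addition to the docstring examples (the key is a placeholder, the model id is taken from the docstring example, and availability depends on the selected provider):

```python
from huggingface_hub import InferenceClient

client = InferenceClient(provider="fal-ai", api_key="<fal_or_hf_key>")  # placeholder key
video = client.text_to_video(
    "A hedgehog exploring a garden at sunrise",
    model="tencent/HunyuanVideo",
    num_inference_steps=25,
)
with open("hedgehog.mp4", "wb") as f:
    f.write(video)
```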
@@ -2544,27 +2702,62 @@ class InferenceClient:
         >>> audio = client.text_to_speech("Hello world")
         >>> Path("hello_world.flac").write_bytes(audio)
         ```
+
+        Example using a third-party provider directly. Usage will be billed on your Replicate account.
+        ```py
+        >>> from huggingface_hub import InferenceClient
+        >>> client = InferenceClient(
+        ...     provider="replicate",
+        ...     api_key="your-replicate-api-key",  # Pass your Replicate API key directly
+        ... )
+        >>> audio = client.text_to_speech(
+        ...     text="Hello world",
+        ...     model="OuteAI/OuteTTS-0.3-500M",
+        ... )
+        >>> Path("hello_world.flac").write_bytes(audio)
+        ```
+
+        Example using a third-party provider through Hugging Face Routing. Usage will be billed on your Hugging Face account.
+        ```py
+        >>> from huggingface_hub import InferenceClient
+        >>> client = InferenceClient(
+        ...     provider="replicate",
+        ...     api_key="hf_...",  # Pass your HF token
+        ... )
+        >>> audio = client.text_to_speech(
+        ...     text="Hello world",
+        ...     model="OuteAI/OuteTTS-0.3-500M",
+        ... )
+        >>> Path("hello_world.flac").write_bytes(audio)
+        ```
         """
- [20 removed lines (2548-2567) not rendered in this diff view]
+        provider_helper = get_provider_helper(self.provider, task="text-to-speech")
+        request_parameters = provider_helper.prepare_request(
+            inputs=text,
+            parameters={
+                "do_sample": do_sample,
+                "early_stopping": early_stopping,
+                "epsilon_cutoff": epsilon_cutoff,
+                "eta_cutoff": eta_cutoff,
+                "max_length": max_length,
+                "max_new_tokens": max_new_tokens,
+                "min_length": min_length,
+                "min_new_tokens": min_new_tokens,
+                "num_beam_groups": num_beam_groups,
+                "num_beams": num_beams,
+                "penalty_alpha": penalty_alpha,
+                "temperature": temperature,
+                "top_k": top_k,
+                "top_p": top_p,
+                "typical_p": typical_p,
+                "use_cache": use_cache,
+            },
+            headers=self.headers,
+            model=model or self.model,
+            api_key=self.token,
+        )
+        response = self._inner_post(request_parameters)
+        response = provider_helper.get_response(response)
         return response

     def token_classification(
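The rewritten `text_to_speech` forwards every optional generation parameter as one flat `parameters` dict and leaves serialization to the provider helper. A minimal sketch of passing a couple of those parameters through the public API; the model id and token are placeholders, not values taken from the diff.

```python
from huggingface_hub import InferenceClient

client = InferenceClient(api_key="hf_...")  # placeholder token
audio = client.text_to_speech(
    "Hello world",
    model="espnet/kan-bayashi_ljspeech_vits",  # placeholder model id
    temperature=0.7,       # ends up in parameters={"temperature": ...}
    max_new_tokens=250,    # ends up in parameters={"max_new_tokens": ...}
)
with open("hello_world.flac", "wb") as f:
    f.write(audio)
```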
@@ -2626,18 +2819,19 @@ class InferenceClient:
         ]
         ```
         """
- [9 removed lines (2629-2637) not rendered in this diff view]
-            model=model,
- [1 removed line (2639) not rendered in this diff view]
+        provider_helper = get_provider_helper(self.provider, task="token-classification")
+        request_parameters = provider_helper.prepare_request(
+            inputs=text,
+            parameters={
+                "aggregation_strategy": aggregation_strategy,
+                "ignore_labels": ignore_labels,
+                "stride": stride,
+            },
+            headers=self.headers,
+            model=model or self.model,
+            api_key=self.token,
         )
+        response = self._inner_post(request_parameters)
         return TokenClassificationOutputElement.parse_obj_as_list(response)

     def translation(
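`token_classification` now goes through the same helper while keeping its task-specific parameters (`aggregation_strategy`, `ignore_labels`, `stride`). A minimal usage sketch; the model id and token are placeholders.

```python
from huggingface_hub import InferenceClient

client = InferenceClient(api_key="hf_...")  # placeholder token
entities = client.token_classification(
    "My name is Sarah Jessica Parker but you can call me Jessica",
    model="dbmdz/bert-large-cased-finetuned-conll03-english",  # placeholder model id
    aggregation_strategy="simple",
)
for entity in entities:  # TokenClassificationOutputElement instances
    print(entity.entity_group, entity.word, round(entity.score, 3))
```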
@@ -2710,15 +2904,22 @@ class InferenceClient:

         if src_lang is None and tgt_lang is not None:
             raise ValueError("You cannot specify `tgt_lang` without specifying `src_lang`.")
- [9 removed lines (2713-2721) not rendered in this diff view]
+
+        provider_helper = get_provider_helper(self.provider, task="translation")
+        request_parameters = provider_helper.prepare_request(
+            inputs=text,
+            parameters={
+                "src_lang": src_lang,
+                "tgt_lang": tgt_lang,
+                "clean_up_tokenization_spaces": clean_up_tokenization_spaces,
+                "truncation": truncation,
+                "generate_parameters": generate_parameters,
+            },
+            headers=self.headers,
+            model=model or self.model,
+            api_key=self.token,
+        )
+        response = self._inner_post(request_parameters)
         return TranslationOutput.parse_obj_as_list(response)[0]

     def visual_question_answering(
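Note that the `src_lang`/`tgt_lang` validation kept as context above still runs before the provider helper builds the request. A short usage sketch, assuming a multilingual model; the model id, token and language codes are placeholders and are model-specific.

```python
from huggingface_hub import InferenceClient

client = InferenceClient(api_key="hf_...")  # placeholder token
result = client.translation(
    "My name is Wolfgang and I live in Berlin.",
    model="facebook/mbart-large-50-many-to-many-mmt",  # placeholder model id
    src_lang="en_XX",  # language codes depend on the chosen model
    tgt_lang="fr_XX",
)
print(result.translation_text)
```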
@@ -2767,10 +2968,16 @@ class InferenceClient:
         ]
         ```
         """
- [4 removed lines (2770-2773) not rendered in this diff view]
+        provider_helper = get_provider_helper(self.provider, task="visual-question-answering")
+        request_parameters = provider_helper.prepare_request(
+            inputs=image,
+            parameters={"top_k": top_k},
+            headers=self.headers,
+            model=model or self.model,
+            api_key=self.token,
+            extra_payload={"question": question, "image": _b64_encode(image)},
+        )
+        response = self._inner_post(request_parameters)
         return VisualQuestionAnsweringOutputElement.parse_obj_as_list(response)

     @_deprecate_arguments(
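For visual question answering, the question and the base64-encoded image travel in `extra_payload` rather than in `parameters`. A minimal caller-side sketch; the image path, model id and token are placeholders.

```python
from huggingface_hub import InferenceClient

client = InferenceClient(api_key="hf_...")  # placeholder token
answers = client.visual_question_answering(
    image="cat.jpg",  # local path, URL or raw bytes; sent base64-encoded via extra_payload
    question="What is the animal doing?",
    model="dandelin/vilt-b32-finetuned-vqa",  # placeholder model id
    top_k=3,
)
for answer in answers:  # VisualQuestionAnsweringOutputElement instances
    print(answer.answer, round(answer.score, 3))
```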
@@ -2876,17 +3083,20 @@ class InferenceClient:
             candidate_labels = labels
         elif candidate_labels is None:
             raise ValueError("Must specify `candidate_labels`")
- [10 removed lines (2879-2888) not rendered in this diff view]
+
+        provider_helper = get_provider_helper(self.provider, task="zero-shot-classification")
+        request_parameters = provider_helper.prepare_request(
+            inputs=text,
+            parameters={
+                "candidate_labels": candidate_labels,
+                "multi_label": multi_label,
+                "hypothesis_template": hypothesis_template,
+            },
+            headers=self.headers,
+            model=model or self.model,
+            api_key=self.token,
         )
+        response = self._inner_post(request_parameters)
         output = _bytes_to_dict(response)
         return [
             ZeroShotClassificationOutputElement.parse_obj_as_instance({"label": label, "score": score})
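The context lines above show the deprecated `labels` argument being mapped onto `candidate_labels`, so new code should pass `candidate_labels` directly. A minimal usage sketch; the model id and token are placeholders.

```python
from huggingface_hub import InferenceClient

client = InferenceClient(api_key="hf_...")  # placeholder token
results = client.zero_shot_classification(
    "I really enjoyed the new restaurant downtown.",
    candidate_labels=["positive", "negative", "neutral"],  # preferred over the deprecated `labels`
    multi_label=False,
    model="facebook/bart-large-mnli",  # placeholder model id
)
for element in results:  # ZeroShotClassificationOutputElement instances
    print(element.label, round(element.score, 3))
```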
@@ -2959,71 +3169,108 @@ class InferenceClient:
         # Raise ValueError if input is less than 2 labels
         if len(candidate_labels) < 2:
             raise ValueError("You must specify at least 2 classes to compare.")
- [9 removed lines (2962-2970) not rendered in this diff view]
+
+        provider_helper = get_provider_helper(self.provider, task="zero-shot-image-classification")
+        request_parameters = provider_helper.prepare_request(
+            inputs=image,
+            parameters={
+                "candidate_labels": candidate_labels,
+                "hypothesis_template": hypothesis_template,
+            },
+            headers=self.headers,
+            model=model or self.model,
+            api_key=self.token,
         )
+        response = self._inner_post(request_parameters)
         return ZeroShotImageClassificationOutputElement.parse_obj_as_list(response)

-    def
- [1 removed line (2975) not rendered in this diff view]
+    def list_deployed_models(
+        self, frameworks: Union[None, str, Literal["all"], List[str]] = None
+    ) -> Dict[str, List[str]]:
+        """
+        List models deployed on the Serverless Inference API service.

- [3 removed lines (2977-2979) not rendered in this diff view]
+        This helper checks deployed models framework by framework. By default, it will check the 4 main frameworks that
+        are supported and account for 95% of the hosted models. However, if you want a complete list of models you can
+        specify `frameworks="all"` as input. Alternatively, if you know before-hand which framework you are interested
+        in, you can also restrict to search to this one (e.g. `frameworks="text-generation-inference"`). The more
+        frameworks are checked, the more time it will take.

- [1 removed line (2981) not rendered in this diff view]
-        if model is None:
-            if task is None:
-                raise ValueError(
-                    "You must specify at least a model (repo_id or URL) or a task, either when instantiating"
-                    " `InferenceClient` or when making a request."
-                )
-            model = self.get_recommended_model(task)
-            logger.info(
-                f"Using recommended model {model} for task {task}. Note that it is"
-                f" encouraged to explicitly set `model='{model}'` as the recommended"
-                " models list might get updated without prior notice."
-            )
+        <Tip warning={true}>

- [3 removed lines (2995-2997) not rendered in this diff view]
-            f"{INFERENCE_ENDPOINT}/pipeline/{task}/{model}"
-            if task in ("feature-extraction", "sentence-similarity")
-            # Otherwise, we use the default endpoint
-            else f"{INFERENCE_ENDPOINT}/models/{model}"
-        )
+        This endpoint method does not return a live list of all models available for the Serverless Inference API service.
+        It searches over a cached list of models that were recently available and the list may not be up to date.
+        If you want to know the live status of a specific model, use [`~InferenceClient.get_model_status`].

- [4 removed lines (3004-3007) not rendered in this diff view]
+        </Tip>
+
+        <Tip>
+
+        This endpoint method is mostly useful for discoverability. If you already know which model you want to use and want to
+        check its availability, you can directly use [`~InferenceClient.get_model_status`].
+
+        </Tip>

         Args:
- [1 removed line (3010) not rendered in this diff view]
-                The
- [1 removed line (3012) not rendered in this diff view]
+            frameworks (`Literal["all"]` or `List[str]` or `str`, *optional*):
+                The frameworks to filter on. By default only a subset of the available frameworks are tested. If set to
+                "all", all available frameworks will be tested. It is also possible to provide a single framework or a
+                custom set of frameworks to check.

         Returns:
-            `str`:
+            `Dict[str, List[str]]`: A dictionary mapping task names to a sorted list of model IDs.

- [2 removed lines (3017-3018) not rendered in this diff view]
+        Example:
+        ```python
+        >>> from huggingface_hub import InferenceClient
+        >>> client = InferenceClient()
+
+        # Discover zero-shot-classification models currently deployed
+        >>> models = client.list_deployed_models()
+        >>> models["zero-shot-classification"]
+        ['Narsil/deberta-large-mnli-zero-cls', 'facebook/bart-large-mnli', ...]
+
+        # List from only 1 framework
+        >>> client.list_deployed_models("text-generation-inference")
+        {'text-generation': ['bigcode/starcoder', 'meta-llama/Llama-2-70b-chat-hf', ...], ...}
+        ```
         """
- [5 removed lines (3020-3024) not rendered in this diff view]
+        if self.provider != "hf-inference":
+            raise ValueError(f"Listing deployed models is not supported on '{self.provider}'.")
+
+        # Resolve which frameworks to check
+        if frameworks is None:
+            frameworks = MAIN_INFERENCE_API_FRAMEWORKS
+        elif frameworks == "all":
+            frameworks = ALL_INFERENCE_API_FRAMEWORKS
+        elif isinstance(frameworks, str):
+            frameworks = [frameworks]
+        frameworks = list(set(frameworks))
+
+        # Fetch them iteratively
+        models_by_task: Dict[str, List[str]] = {}
+
+        def _unpack_response(framework: str, items: List[Dict]) -> None:
+            for model in items:
+                if framework == "sentence-transformers":
+                    # Model running with the `sentence-transformers` framework can work with both tasks even if not
+                    # branded as such in the API response
+                    models_by_task.setdefault("feature-extraction", []).append(model["model_id"])
+                    models_by_task.setdefault("sentence-similarity", []).append(model["model_id"])
+                else:
+                    models_by_task.setdefault(model["task"], []).append(model["model_id"])
+
+        for framework in frameworks:
+            response = get_session().get(
+                f"{INFERENCE_ENDPOINT}/framework/{framework}", headers=build_hf_headers(token=self.token)
             )
- [1 removed line (3026) not rendered in this diff view]
+            hf_raise_for_status(response)
+            _unpack_response(framework, response.json())
+
+        # Sort alphabetically for discoverability and return
+        for task, models in models_by_task.items():
+            models_by_task[task] = sorted(set(models), key=lambda x: x.lower())
+        return models_by_task

     def get_endpoint_info(self, *, model: Optional[str] = None) -> Dict[str, Any]:
         """
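The `list_deployed_models` body above now refuses to run for third-party providers and only queries the hf-inference framework endpoints. A short sketch of both behaviours; the framework names come from the diff's own examples and the commented call is only expected behaviour per the guard above, not output captured from a real run.

```python
from huggingface_hub import InferenceClient

client = InferenceClient()  # the hf-inference provider is the only one supporting this helper
deployed = client.list_deployed_models(frameworks=["text-generation-inference", "sentence-transformers"])
for task, model_ids in deployed.items():
    print(task, len(model_ids))

# With a third-party provider, the same call should raise per the new guard:
# InferenceClient(provider="replicate", api_key="hf_...").list_deployed_models()  # -> ValueError
```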
@@ -3068,6 +3315,9 @@ class InferenceClient:
         }
         ```
         """
+        if self.provider != "hf-inference":
+            raise ValueError(f"Getting endpoint info is not supported on '{self.provider}'.")
+
         model = model or self.model
         if model is None:
             raise ValueError("Model id not provided.")
@@ -3076,7 +3326,7 @@ class InferenceClient:
         else:
             url = f"{INFERENCE_ENDPOINT}/models/{model}/info"

-        response = get_session().get(url, headers=self.headers)
+        response = get_session().get(url, headers=build_hf_headers(token=self.token))
         hf_raise_for_status(response)
         return response.json()

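The same provider guard is added to the endpoint-introspection helpers, so they fail fast instead of issuing a request that a third-party provider cannot serve. A minimal sketch of the expected behaviour; the model id and token are placeholders.

```python
from huggingface_hub import InferenceClient

# Supported: the default hf-inference provider
info = InferenceClient().get_endpoint_info(model="meta-llama/Llama-2-70b-chat-hf")  # placeholder model id
print(info)

# Not supported: third-party providers raise immediately per the new guard
try:
    InferenceClient(provider="together", api_key="hf_...").get_endpoint_info(model="...")
except ValueError as err:
    print(err)
```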
@@ -3102,6 +3352,9 @@ class InferenceClient:
         True
         ```
         """
+        if self.provider != "hf-inference":
+            raise ValueError(f"Health check is not supported on '{self.provider}'.")
+
         model = model or self.model
         if model is None:
             raise ValueError("Model id not provided.")
@@ -3111,7 +3364,7 @@ class InferenceClient:
             )
         url = model.rstrip("/") + "/health"

-        response = get_session().get(url, headers=self.headers)
+        response = get_session().get(url, headers=build_hf_headers(token=self.token))
         return response.status_code == 200

     def get_model_status(self, model: Optional[str] = None) -> ModelStatus:
@@ -3144,6 +3397,9 @@ class InferenceClient:
         ModelStatus(loaded=True, state='Loaded', compute_type='gpu', framework='text-generation-inference')
         ```
         """
+        if self.provider != "hf-inference":
+            raise ValueError(f"Getting model status is not supported on '{self.provider}'.")
+
         model = model or self.model
         if model is None:
             raise ValueError("Model id not provided.")
@@ -3151,7 +3407,7 @@ class InferenceClient:
             raise NotImplementedError("Model status is only available for Inference API endpoints.")
         url = f"{INFERENCE_ENDPOINT}/status/{model}"

-        response = get_session().get(url, headers=self.headers)
+        response = get_session().get(url, headers=build_hf_headers(token=self.token))
         hf_raise_for_status(response)
         response_data = response.json()

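Across these last hunks the auxiliary GET requests stop reusing `self.headers` and instead rebuild headers from the client token with `build_hf_headers()`. A rough standalone sketch of the same request pattern; the token and model id are placeholders.

```python
from huggingface_hub.constants import INFERENCE_ENDPOINT
from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status

token = "hf_..."            # placeholder token
model = "bigscience/bloom"  # placeholder model id

response = get_session().get(
    f"{INFERENCE_ENDPOINT}/status/{model}",
    headers=build_hf_headers(token=token),  # auth + user-agent headers derived from the token
)
hf_raise_for_status(response)
print(response.json())
```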