huggingface-hub 0.25.2__py3-none-any.whl → 0.26.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of huggingface-hub might be problematic.
- huggingface_hub/__init__.py +45 -11
- huggingface_hub/_login.py +172 -33
- huggingface_hub/commands/user.py +125 -9
- huggingface_hub/constants.py +1 -1
- huggingface_hub/errors.py +6 -9
- huggingface_hub/file_download.py +2 -372
- huggingface_hub/hf_api.py +170 -13
- huggingface_hub/hf_file_system.py +3 -3
- huggingface_hub/hub_mixin.py +2 -1
- huggingface_hub/inference/_client.py +500 -145
- huggingface_hub/inference/_common.py +42 -4
- huggingface_hub/inference/_generated/_async_client.py +499 -144
- huggingface_hub/inference/_generated/types/__init__.py +37 -7
- huggingface_hub/inference/_generated/types/audio_classification.py +8 -5
- huggingface_hub/inference/_generated/types/automatic_speech_recognition.py +9 -7
- huggingface_hub/inference/_generated/types/chat_completion.py +23 -4
- huggingface_hub/inference/_generated/types/image_classification.py +8 -5
- huggingface_hub/inference/_generated/types/image_segmentation.py +9 -7
- huggingface_hub/inference/_generated/types/image_to_image.py +7 -5
- huggingface_hub/inference/_generated/types/image_to_text.py +4 -4
- huggingface_hub/inference/_generated/types/object_detection.py +11 -5
- huggingface_hub/inference/_generated/types/summarization.py +11 -13
- huggingface_hub/inference/_generated/types/text_classification.py +10 -5
- huggingface_hub/inference/_generated/types/text_generation.py +1 -0
- huggingface_hub/inference/_generated/types/text_to_audio.py +2 -2
- huggingface_hub/inference/_generated/types/text_to_image.py +9 -7
- huggingface_hub/inference/_generated/types/text_to_speech.py +107 -0
- huggingface_hub/inference/_generated/types/translation.py +17 -11
- huggingface_hub/inference/_generated/types/video_classification.py +2 -2
- huggingface_hub/repocard.py +2 -1
- huggingface_hub/repocard_data.py +10 -2
- huggingface_hub/serialization/_torch.py +7 -4
- huggingface_hub/utils/__init__.py +4 -20
- huggingface_hub/utils/{_token.py → _auth.py} +86 -3
- huggingface_hub/utils/_headers.py +1 -1
- huggingface_hub/utils/_hf_folder.py +1 -1
- huggingface_hub/utils/_http.py +10 -4
- huggingface_hub/utils/_runtime.py +1 -10
- {huggingface_hub-0.25.2.dist-info → huggingface_hub-0.26.0rc0.dist-info}/METADATA +12 -12
- {huggingface_hub-0.25.2.dist-info → huggingface_hub-0.26.0rc0.dist-info}/RECORD +44 -44
- huggingface_hub/inference/_templating.py +0 -102
- {huggingface_hub-0.25.2.dist-info → huggingface_hub-0.26.0rc0.dist-info}/LICENSE +0 -0
- {huggingface_hub-0.25.2.dist-info → huggingface_hub-0.26.0rc0.dist-info}/WHEEL +0 -0
- {huggingface_hub-0.25.2.dist-info → huggingface_hub-0.26.0rc0.dist-info}/entry_points.txt +0 -0
- {huggingface_hub-0.25.2.dist-info → huggingface_hub-0.26.0rc0.dist-info}/top_level.txt +0 -0
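Illustrative usage sketch (not part of the release itself): the main theme of the diff below is that `AsyncInferenceClient` task methods now expose task parameters as explicit keyword arguments and build their request via a shared `_prepare_payload` helper. The snippet is a hedged example based only on the new signatures shown in the diff; the audio file path and the `top_k` / `function_to_apply` values are placeholder assumptions, not values taken from the release.

```python
# Hedged sketch based on the 0.26.0rc0 signatures shown in the diff below.
# "sample.flac", top_k=3 and function_to_apply="softmax" are placeholder values.
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient()  # uses the recommended default model per task
    # Task parameters are now plain keyword arguments (no manual payload dict):
    results = await client.audio_classification(
        "sample.flac",
        top_k=3,
        function_to_apply="softmax",
    )
    for item in results:
        print(item.label, item.score)


asyncio.run(main())
```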
huggingface_hub/inference/_generated/_async_client.py

@@ -24,18 +24,7 @@ import logging
 import re
 import time
 import warnings
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    AsyncIterable,
-    Dict,
-    List,
-    Literal,
-    Optional,
-    Set,
-    Union,
-    overload,
-)
+from typing import TYPE_CHECKING, Any, AsyncIterable, Dict, List, Literal, Optional, Set, Union, overload
 
 from requests.structures import CaseInsensitiveDict
 
@@ -56,16 +45,18 @@ from huggingface_hub.inference._common import (
     _get_unsupported_text_generation_kwargs,
     _import_numpy,
     _open_as_binary,
+    _prepare_payload,
     _set_unsupported_text_generation_kwargs,
     raise_text_generation_error,
 )
 from huggingface_hub.inference._generated.types import (
     AudioClassificationOutputElement,
+    AudioClassificationOutputTransform,
     AudioToAudioOutputElement,
     AutomaticSpeechRecognitionOutput,
     ChatCompletionInputGrammarType,
-
-
+    ChatCompletionInputStreamOptions,
+    ChatCompletionInputToolType,
     ChatCompletionOutput,
     ChatCompletionStreamOutput,
     DocumentQuestionAnsweringOutputElement,
@@ -78,19 +69,21 @@ from huggingface_hub.inference._generated.types import (
     SummarizationOutput,
     TableQuestionAnsweringOutputElement,
     TextClassificationOutputElement,
+    TextClassificationOutputTransform,
     TextGenerationInputGrammarType,
     TextGenerationOutput,
     TextGenerationStreamOutput,
+    TextToImageTargetSize,
+    TextToSpeechEarlyStoppingEnum,
     TokenClassificationOutputElement,
+    ToolElement,
     TranslationOutput,
     VisualQuestionAnsweringOutputElement,
     ZeroShotClassificationOutputElement,
     ZeroShotImageClassificationOutputElement,
 )
-from huggingface_hub.utils import (
-    build_hf_headers,
-)
-from huggingface_hub.utils._deprecation import _deprecate_positional_args
+from huggingface_hub.utils import build_hf_headers
+from huggingface_hub.utils._deprecation import _deprecate_arguments
 
 from .._common import _async_yield_from, _import_aiohttp
 
@@ -147,7 +140,6 @@ class AsyncInferenceClient:
             follow the same pattern as `openai.OpenAI` client. Cannot be used if `token` is set. Defaults to None.
     """
 
-    @_deprecate_positional_args(version="0.26")
     def __init__(
         self,
         model: Optional[str] = None,
@@ -365,6 +357,8 @@ class AsyncInferenceClient:
         audio: ContentT,
         *,
         model: Optional[str] = None,
+        top_k: Optional[int] = None,
+        function_to_apply: Optional["AudioClassificationOutputTransform"] = None,
     ) -> List[AudioClassificationOutputElement]:
         """
         Perform audio classification on the provided audio content.
@@ -377,6 +371,10 @@
                 The model to use for audio classification. Can be a model ID hosted on the Hugging Face Hub
                 or a URL to a deployed Inference Endpoint. If not provided, the default recommended model for
                 audio classification will be used.
+            top_k (`int`, *optional*):
+                When specified, limits the output to the top K most probable classes.
+            function_to_apply (`"AudioClassificationOutputTransform"`, *optional*):
+                The function to apply to the output.
 
         Returns:
             `List[AudioClassificationOutputElement]`: List of [`AudioClassificationOutputElement`] items containing the predicted labels and their confidence.
@@ -400,7 +398,9 @@
         ]
         ```
         """
-
+        parameters = {"function_to_apply": function_to_apply, "top_k": top_k}
+        payload = _prepare_payload(audio, parameters=parameters, expect_binary=True)
+        response = await self.post(**payload, model=model, task="audio-classification")
         return AudioClassificationOutputElement.parse_obj_as_list(response)
 
     async def audio_to_audio(
@@ -487,7 +487,7 @@
     @overload
     async def chat_completion(  # type: ignore
         self,
-        messages: List[Dict
+        messages: List[Dict],
         *,
         model: Optional[str] = None,
         stream: Literal[False] = False,
@@ -500,10 +500,11 @@
         response_format: Optional[ChatCompletionInputGrammarType] = None,
         seed: Optional[int] = None,
         stop: Optional[List[str]] = None,
+        stream_options: Optional[ChatCompletionInputStreamOptions] = None,
         temperature: Optional[float] = None,
-        tool_choice: Optional[Union[
+        tool_choice: Optional[Union[ChatCompletionInputToolType, str]] = None,
         tool_prompt: Optional[str] = None,
-        tools: Optional[List[
+        tools: Optional[List[ToolElement]] = None,
         top_logprobs: Optional[int] = None,
         top_p: Optional[float] = None,
     ) -> ChatCompletionOutput: ...
@@ -511,7 +512,7 @@
     @overload
     async def chat_completion(  # type: ignore
         self,
-        messages: List[Dict
+        messages: List[Dict],
         *,
         model: Optional[str] = None,
         stream: Literal[True] = True,
@@ -524,10 +525,11 @@
         response_format: Optional[ChatCompletionInputGrammarType] = None,
         seed: Optional[int] = None,
         stop: Optional[List[str]] = None,
+        stream_options: Optional[ChatCompletionInputStreamOptions] = None,
         temperature: Optional[float] = None,
-        tool_choice: Optional[Union[
+        tool_choice: Optional[Union[ChatCompletionInputToolType, str]] = None,
         tool_prompt: Optional[str] = None,
-        tools: Optional[List[
+        tools: Optional[List[ToolElement]] = None,
        top_logprobs: Optional[int] = None,
         top_p: Optional[float] = None,
     ) -> AsyncIterable[ChatCompletionStreamOutput]: ...
@@ -535,7 +537,7 @@
     @overload
     async def chat_completion(
         self,
-        messages: List[Dict
+        messages: List[Dict],
         *,
         model: Optional[str] = None,
         stream: bool = False,
@@ -548,17 +550,18 @@
         response_format: Optional[ChatCompletionInputGrammarType] = None,
         seed: Optional[int] = None,
         stop: Optional[List[str]] = None,
+        stream_options: Optional[ChatCompletionInputStreamOptions] = None,
         temperature: Optional[float] = None,
-        tool_choice: Optional[Union[
+        tool_choice: Optional[Union[ChatCompletionInputToolType, str]] = None,
         tool_prompt: Optional[str] = None,
-        tools: Optional[List[
+        tools: Optional[List[ToolElement]] = None,
         top_logprobs: Optional[int] = None,
         top_p: Optional[float] = None,
     ) -> Union[ChatCompletionOutput, AsyncIterable[ChatCompletionStreamOutput]]: ...
 
     async def chat_completion(
         self,
-        messages: List[Dict
+        messages: List[Dict],
         *,
         model: Optional[str] = None,
         stream: bool = False,
@@ -572,10 +575,11 @@
         response_format: Optional[ChatCompletionInputGrammarType] = None,
         seed: Optional[int] = None,
         stop: Optional[List[str]] = None,
+        stream_options: Optional[ChatCompletionInputStreamOptions] = None,
         temperature: Optional[float] = None,
-        tool_choice: Optional[Union[
+        tool_choice: Optional[Union[ChatCompletionInputToolType, str]] = None,
         tool_prompt: Optional[str] = None,
-        tools: Optional[List[
+        tools: Optional[List[ToolElement]] = None,
         top_logprobs: Optional[int] = None,
         top_p: Optional[float] = None,
     ) -> Union[ChatCompletionOutput, AsyncIterable[ChatCompletionStreamOutput]]:
@@ -592,7 +596,7 @@
         </Tip>
 
         Args:
-            messages (List
+            messages (List of [`ChatCompletionInputMessage`]):
                 Conversation history consisting of roles and content pairs.
             model (`str`, *optional*):
                 The model to use for chat-completion. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
@@ -629,6 +633,8 @@
                 Defaults to None.
             stream (`bool`, *optional*):
                 Enable realtime streaming of responses. Defaults to False.
+            stream_options ([`ChatCompletionInputStreamOptions`], *optional*):
+                Options for streaming completions.
             temperature (`float`, *optional*):
                 Controls randomness of the generations. Lower values ensure
                 less random completions. Range: [0, 2]. Defaults to 1.0.
@@ -639,11 +645,11 @@
             top_p (`float`, *optional*):
                 Fraction of the most likely next words to sample from.
                 Must be between 0 and 1. Defaults to 1.0.
-            tool_choice ([`
+            tool_choice ([`ChatCompletionInputToolType`] or `str`, *optional*):
                 The tool to use for the completion. Defaults to "auto".
             tool_prompt (`str`, *optional*):
                 A prompt to be appended before the tools.
-            tools (List of [`
+            tools (List of [`ToolElement`], *optional*):
                 A list of tools the model may call. Currently, only functions are supported as a tool. Use this to
                 provide a list of functions the model may generate JSON inputs for.
 
@@ -694,7 +700,7 @@
         )
         ```
 
-        Example
+        Example using streaming:
         ```py
         # Must be run in an async context
         >>> from huggingface_hub import AsyncInferenceClient
@@ -734,6 +740,41 @@
             print(chunk.choices[0].delta.content)
         ```
 
+        Example using Image + Text as input:
+        ```py
+        # Must be run in an async context
+        >>> from huggingface_hub import AsyncInferenceClient
+
+        # provide a remote URL
+        >>> image_url ="https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
+        # or a base64-encoded image
+        >>> image_path = "/path/to/image.jpeg"
+        >>> with open(image_path, "rb") as f:
+        ...     base64_image = base64.b64encode(f.read()).decode("utf-8")
+        >>> image_url = f"data:image/jpeg;base64,{base64_image}"
+
+        >>> client = AsyncInferenceClient("meta-llama/Llama-3.2-11B-Vision-Instruct")
+        >>> output = await client.chat.completions.create(
+        ...     messages=[
+        ...         {
+        ...             "role": "user",
+        ...             "content": [
+        ...                 {
+        ...                     "type": "image_url",
+        ...                     "image_url": {"url": image_url},
+        ...                 },
+        ...                 {
+        ...                     "type": "text",
+        ...                     "text": "Describe this image in one sentence.",
+        ...                 },
+        ...             ],
+        ...         },
+        ...     ],
+        ... )
+        >>> output
+        The image depicts the iconic Statue of Liberty situated in New York Harbor, New York, on a clear day.
+        ```
+
         Example using tools:
         ```py
         # Must be run in an async context
@@ -877,6 +918,7 @@
             top_logprobs=top_logprobs,
             top_p=top_p,
             stream=stream,
+            stream_options=stream_options,
         )
         payload = {key: value for key, value in payload.items() if value is not None}
         data = await self.post(model=model_url, json=payload, stream=stream)
@@ -917,6 +959,14 @@
         question: str,
         *,
         model: Optional[str] = None,
+        doc_stride: Optional[int] = None,
+        handle_impossible_answer: Optional[bool] = None,
+        lang: Optional[str] = None,
+        max_answer_len: Optional[int] = None,
+        max_question_len: Optional[int] = None,
+        max_seq_len: Optional[int] = None,
+        top_k: Optional[int] = None,
+        word_boxes: Optional[List[Union[List[float], str]]] = None,
     ) -> List[DocumentQuestionAnsweringOutputElement]:
         """
         Answer questions on document images.
@@ -930,7 +980,29 @@
                 The model to use for the document question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to
                 a deployed Inference Endpoint. If not provided, the default recommended document question answering model will be used.
                 Defaults to None.
-
+            doc_stride (`int`, *optional*):
+                If the words in the document are too long to fit with the question for the model, it will
+                be split in several chunks with some overlap. This argument controls the size of that
+                overlap.
+            handle_impossible_answer (`bool`, *optional*):
+                Whether to accept impossible as an answer.
+            lang (`str`, *optional*):
+                Language to use while running OCR.
+            max_answer_len (`int`, *optional*):
+                The maximum length of predicted answers (e.g., only answers with a shorter length are
+                considered).
+            max_question_len (`int`, *optional*):
+                The maximum length of the question after tokenization. It will be truncated if needed.
+            max_seq_len (`int`, *optional*):
+                The maximum length of the total sentence (context + question) in tokens of each chunk
+                passed to the model. The context will be split in several chunks (using doc_stride as
+                overlap) if needed.
+            top_k (`int`, *optional*):
+                The number of answers to return (will be chosen by order of likelihood). Can return less
+                than top_k answers if there are not enough options available within the context.
+            word_boxes (`List[Union[List[float], str]]`, *optional*):
+                A list of words and bounding boxes (normalized 0->1000). If provided, the inference will
+                skip the OCR step and use the provided bounding boxes instead.
         Returns:
             `List[DocumentQuestionAnsweringOutputElement]`: a list of [`DocumentQuestionAnsweringOutputElement`] items containing the predicted label, associated probability, word ids, and page number.
 
@@ -940,17 +1012,29 @@
             `aiohttp.ClientResponseError`:
                 If the request fails with an HTTP error status code other than HTTP 503.
 
+
         Example:
         ```py
         # Must be run in an async context
         >>> from huggingface_hub import AsyncInferenceClient
         >>> client = AsyncInferenceClient()
         >>> await client.document_question_answering(image="https://huggingface.co/spaces/impira/docquery/resolve/2359223c1837a7587402bda0f2643382a6eefeab/invoice.png", question="What is the invoice number?")
-        [DocumentQuestionAnsweringOutputElement(
+        [DocumentQuestionAnsweringOutputElement(answer='us-001', end=16, score=0.9999666213989258, start=16, words=None)]
         ```
         """
-
-
+        inputs: Dict[str, Any] = {"question": question, "image": _b64_encode(image)}
+        parameters = {
+            "doc_stride": doc_stride,
+            "handle_impossible_answer": handle_impossible_answer,
+            "lang": lang,
+            "max_answer_len": max_answer_len,
+            "max_question_len": max_question_len,
+            "max_seq_len": max_seq_len,
+            "top_k": top_k,
+            "word_boxes": word_boxes,
+        }
+        payload = _prepare_payload(inputs, parameters=parameters)
+        response = await self.post(**payload, model=model, task="document-question-answering")
         return DocumentQuestionAnsweringOutputElement.parse_obj_as_list(response)
 
     async def feature_extraction(
@@ -974,7 +1058,7 @@
                 a deployed Inference Endpoint. If not provided, the default recommended conversational model will be used.
                 Defaults to None.
             normalize (`bool`, *optional*):
-                Whether to normalize the embeddings or not. 
+                Whether to normalize the embeddings or not.
                 Only available on server powered by Text-Embedding-Inference.
             prompt_name (`str`, *optional*):
                 The name of the prompt that should be used by for encoding. If not set, no prompt will be applied.
@@ -983,7 +1067,7 @@
                 then the sentence "What is the capital of France?" will be encoded as "query: What is the capital of France?"
                 because the prompt text will be prepended before any text to encode.
             truncate (`bool`, *optional*):
-                Whether to truncate the embeddings or not. 
+                Whether to truncate the embeddings or not.
                 Only available on server powered by Text-Embedding-Inference.
             truncation_direction (`Literal["Left", "Right"]`, *optional*):
                 Which side of the input should be truncated when `truncate=True` is passed.
@@ -1009,20 +1093,25 @@
         [ 0.28552425, -0.928395 , -1.2077185 , ..., 0.76810825, -2.1069427 , 0.6236161 ]], dtype=float32)
         ```
         """
-
-
-
-
-
-
-
-
-        payload["truncation_direction"] = truncation_direction
-        response = await self.post(json=payload, model=model, task="feature-extraction")
+        parameters = {
+            "normalize": normalize,
+            "prompt_name": prompt_name,
+            "truncate": truncate,
+            "truncation_direction": truncation_direction,
+        }
+        payload = _prepare_payload(text, parameters=parameters)
+        response = await self.post(**payload, model=model, task="feature-extraction")
         np = _import_numpy()
         return np.array(_bytes_to_dict(response), dtype="float32")
 
-    async def fill_mask(
+    async def fill_mask(
+        self,
+        text: str,
+        *,
+        model: Optional[str] = None,
+        targets: Optional[List[str]] = None,
+        top_k: Optional[int] = None,
+    ) -> List[FillMaskOutputElement]:
         """
         Fill in a hole with a missing word (token to be precise).
 
@@ -1032,8 +1121,13 @@
             model (`str`, *optional*):
                 The model to use for the fill mask task. Can be a model ID hosted on the Hugging Face Hub or a URL to
                 a deployed Inference Endpoint. If not provided, the default recommended fill mask model will be used.
-
-
+            targets (`List[str]`, *optional*):
+                When passed, the model will limit the scores to the passed targets instead of looking up
+                in the whole vocabulary. If the provided targets are not in the model vocab, they will be
+                tokenized and the first resulting token will be used (with a warning, and that might be
+                slower).
+            top_k (`int`, *optional*):
+                When passed, overrides the number of predictions to return.
         Returns:
             `List[FillMaskOutputElement]`: a list of [`FillMaskOutputElement`] items containing the predicted label, associated
             probability, token reference, and completed text.
@@ -1056,7 +1150,9 @@
         ]
         ```
         """
-
+        parameters = {"targets": targets, "top_k": top_k}
+        payload = _prepare_payload(text, parameters=parameters)
+        response = await self.post(**payload, model=model, task="fill-mask")
         return FillMaskOutputElement.parse_obj_as_list(response)
 
     async def image_classification(
@@ -1064,6 +1160,8 @@
         image: ContentT,
         *,
         model: Optional[str] = None,
+        function_to_apply: Optional[Literal["sigmoid", "softmax", "none"]] = None,
+        top_k: Optional[int] = None,
     ) -> List[ImageClassificationOutputElement]:
         """
         Perform image classification on the given image using the specified model.
@@ -1074,7 +1172,10 @@
             model (`str`, *optional*):
                 The model to use for image classification. Can be a model ID hosted on the Hugging Face Hub or a URL to a
                 deployed Inference Endpoint. If not provided, the default recommended model for image classification will be used.
-
+            function_to_apply (`Literal["sigmoid", "softmax", "none"]`, *optional*):
+                The function to apply to the output scores.
+            top_k (`int`, *optional*):
+                When specified, limits the output to the top K most probable classes.
         Returns:
             `List[ImageClassificationOutputElement]`: a list of [`ImageClassificationOutputElement`] items containing the predicted label and associated probability.
 
@@ -1090,10 +1191,12 @@
         >>> from huggingface_hub import AsyncInferenceClient
         >>> client = AsyncInferenceClient()
         >>> await client.image_classification("https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg")
-        [ImageClassificationOutputElement(
+        [ImageClassificationOutputElement(label='Blenheim spaniel', score=0.9779096841812134), ...]
         ```
         """
-
+        parameters = {"function_to_apply": function_to_apply, "top_k": top_k}
+        payload = _prepare_payload(image, parameters=parameters, expect_binary=True)
+        response = await self.post(**payload, model=model, task="image-classification")
         return ImageClassificationOutputElement.parse_obj_as_list(response)
 
     async def image_segmentation(
@@ -1101,6 +1204,10 @@
         image: ContentT,
         *,
         model: Optional[str] = None,
+        mask_threshold: Optional[float] = None,
+        overlap_mask_area_threshold: Optional[float] = None,
+        subtask: Optional[Literal["instance", "panoptic", "semantic"]] = None,
+        threshold: Optional[float] = None,
     ) -> List[ImageSegmentationOutputElement]:
         """
         Perform image segmentation on the given image using the specified model.
@@ -1117,7 +1224,14 @@
             model (`str`, *optional*):
                 The model to use for image segmentation. Can be a model ID hosted on the Hugging Face Hub or a URL to a
                 deployed Inference Endpoint. If not provided, the default recommended model for image segmentation will be used.
-
+            mask_threshold (`float`, *optional*):
+                Threshold to use when turning the predicted masks into binary values.
+            overlap_mask_area_threshold (`float`, *optional*):
+                Mask overlap threshold to eliminate small, disconnected segments.
+            subtask (`Literal["instance", "panoptic", "semantic"]`, *optional*):
+                Segmentation task to be performed, depending on model capabilities.
+            threshold (`float`, *optional*):
+                Probability threshold to filter out predicted masks.
         Returns:
             `List[ImageSegmentationOutputElement]`: A list of [`ImageSegmentationOutputElement`] items containing the segmented masks and associated attributes.
 
@@ -1132,14 +1246,21 @@
         # Must be run in an async context
         >>> from huggingface_hub import AsyncInferenceClient
         >>> client = AsyncInferenceClient()
-        >>> await client.image_segmentation("cat.jpg")
+        >>> await client.image_segmentation("cat.jpg")
         [ImageSegmentationOutputElement(score=0.989008, label='LABEL_184', mask=<PIL.PngImagePlugin.PngImageFile image mode=L size=400x300 at 0x7FDD2B129CC0>), ...]
         ```
         """
-
+        parameters = {
+            "mask_threshold": mask_threshold,
+            "overlap_mask_area_threshold": overlap_mask_area_threshold,
+            "subtask": subtask,
+            "threshold": threshold,
+        }
+        payload = _prepare_payload(image, parameters=parameters, expect_binary=True)
+        response = await self.post(**payload, model=model, task="image-segmentation")
         output = ImageSegmentationOutputElement.parse_obj_as_list(response)
         for item in output:
-            item.mask = _b64_to_image(item.mask)
+            item.mask = _b64_to_image(item.mask)  # type: ignore [assignment]
         return output
 
     async def image_to_image(
@@ -1212,19 +1333,8 @@
             "guidance_scale": guidance_scale,
             **kwargs,
         }
-
-
-            data = image
-            payload: Optional[Dict[str, Any]] = None
-        else:
-            # Or an image + some parameters => use base64 encoding
-            data = None
-            payload = {"inputs": _b64_encode(image)}
-            for key, value in parameters.items():
-                if value is not None:
-                    payload.setdefault("parameters", {})[key] = value
-
-        response = await self.post(json=payload, data=data, model=model, task="image-to-image")
+        payload = _prepare_payload(image, parameters=parameters, expect_binary=True)
+        response = await self.post(**payload, model=model, task="image-to-image")
         return _bytes_to_image(response)
 
     async def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -> ImageToTextOutput:
@@ -1355,10 +1465,7 @@
         return models_by_task
 
     async def object_detection(
-        self,
-        image: ContentT,
-        *,
-        model: Optional[str] = None,
+        self, image: ContentT, *, model: Optional[str] = None, threshold: Optional[float] = None
     ) -> List[ObjectDetectionOutputElement]:
         """
         Perform object detection on the given image using the specified model.
@@ -1375,7 +1482,8 @@
             model (`str`, *optional*):
                 The model to use for object detection. Can be a model ID hosted on the Hugging Face Hub or a URL to a
                 deployed Inference Endpoint. If not provided, the default recommended model for object detection (DETR) will be used.
-
+            threshold (`float`, *optional*):
+                The probability necessary to make a prediction.
         Returns:
             `List[ObjectDetectionOutputElement]`: A list of [`ObjectDetectionOutputElement`] items containing the bounding boxes and associated attributes.
 
@@ -1392,17 +1500,31 @@
         # Must be run in an async context
         >>> from huggingface_hub import AsyncInferenceClient
         >>> client = AsyncInferenceClient()
-        >>> await client.object_detection("people.jpg")
+        >>> await client.object_detection("people.jpg")
         [ObjectDetectionOutputElement(score=0.9486683011054993, label='person', box=ObjectDetectionBoundingBox(xmin=59, ymin=39, xmax=420, ymax=510)), ...]
         ```
         """
-
-
+        parameters = {
+            "threshold": threshold,
+        }
+        payload = _prepare_payload(image, parameters=parameters, expect_binary=True)
+        response = await self.post(**payload, model=model, task="object-detection")
         return ObjectDetectionOutputElement.parse_obj_as_list(response)
 
     async def question_answering(
-        self,
-
+        self,
+        question: str,
+        context: str,
+        *,
+        model: Optional[str] = None,
+        align_to_words: Optional[bool] = None,
+        doc_stride: Optional[int] = None,
+        handle_impossible_answer: Optional[bool] = None,
+        max_answer_len: Optional[int] = None,
+        max_question_len: Optional[int] = None,
+        max_seq_len: Optional[int] = None,
+        top_k: Optional[int] = None,
+    ) -> Union[QuestionAnsweringOutputElement, List[QuestionAnsweringOutputElement]]:
         """
         Retrieve the answer to a question from a given text.
 
@@ -1414,10 +1536,31 @@
             model (`str`):
                 The model to use for the question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to
                 a deployed Inference Endpoint.
-
+            align_to_words (`bool`, *optional*):
+                Attempts to align the answer to real words. Improves quality on space separated
+                languages. Might hurt on non-space-separated languages (like Japanese or Chinese).
+            doc_stride (`int`, *optional*):
+                If the context is too long to fit with the question for the model, it will be split in
+                several chunks with some overlap. This argument controls the size of that overlap.
+            handle_impossible_answer (`bool`, *optional*):
+                Whether to accept impossible as an answer.
+            max_answer_len (`int`, *optional*):
+                The maximum length of predicted answers (e.g., only answers with a shorter length are
+                considered).
+            max_question_len (`int`, *optional*):
+                The maximum length of the question after tokenization. It will be truncated if needed.
+            max_seq_len (`int`, *optional*):
+                The maximum length of the total sentence (context + question) in tokens of each chunk
+                passed to the model. The context will be split in several chunks (using docStride as
+                overlap) if needed.
+            top_k (`int`, *optional*):
+                The number of answers to return (will be chosen by order of likelihood). Note that we
+                return less than topk answers if there are not enough options available within the
+                context.
         Returns:
-            [`QuestionAnsweringOutputElement`]:
-
+            Union[`QuestionAnsweringOutputElement`, List[`QuestionAnsweringOutputElement`]]:
+                When top_k is 1 or not provided, it returns a single `QuestionAnsweringOutputElement`.
+                When top_k is greater than 1, it returns a list of `QuestionAnsweringOutputElement`.
         Raises:
             [`InferenceTimeoutError`]:
                 If the model is unavailable or the request times out.
@@ -1430,17 +1573,28 @@
         >>> from huggingface_hub import AsyncInferenceClient
         >>> client = AsyncInferenceClient()
         >>> await client.question_answering(question="What's my name?", context="My name is Clara and I live in Berkeley.")
-        QuestionAnsweringOutputElement(
+        QuestionAnsweringOutputElement(answer='Clara', end=16, score=0.9326565265655518, start=11)
         ```
         """
-
-
+        parameters = {
+            "align_to_words": align_to_words,
+            "doc_stride": doc_stride,
+            "handle_impossible_answer": handle_impossible_answer,
+            "max_answer_len": max_answer_len,
+            "max_question_len": max_question_len,
+            "max_seq_len": max_seq_len,
+            "top_k": top_k,
+        }
+        inputs: Dict[str, Any] = {"question": question, "context": context}
+        payload = _prepare_payload(inputs, parameters=parameters)
         response = await self.post(
-
+            **payload,
             model=model,
             task="question-answering",
         )
-
+        # Parse the response as a single `QuestionAnsweringOutputElement` when top_k is 1 or not provided, or a list of `QuestionAnsweringOutputElement` to ensure backward compatibility.
+        output = QuestionAnsweringOutputElement.parse_obj(response)
+        return output
 
     async def sentence_similarity(
         self, sentence: str, other_sentences: List[str], *, model: Optional[str] = None
@@ -1490,12 +1644,23 @@
         )
         return _bytes_to_list(response)
 
+    @_deprecate_arguments(
+        version="0.29",
+        deprecated_args=["parameters"],
+        custom_message=(
+            "The `parameters` argument is deprecated and will be removed in a future version. "
+            "Provide individual parameters instead: `clean_up_tokenization_spaces`, `generate_parameters`, and `truncation`."
+        ),
+    )
     async def summarization(
         self,
         text: str,
         *,
         parameters: Optional[Dict[str, Any]] = None,
         model: Optional[str] = None,
+        clean_up_tokenization_spaces: Optional[bool] = None,
+        generate_parameters: Optional[Dict[str, Any]] = None,
+        truncation: Optional[Literal["do_not_truncate", "longest_first", "only_first", "only_second"]] = None,
     ) -> SummarizationOutput:
         """
         Generate a summary of a given text using a specified model.
@@ -1508,8 +1673,13 @@
                 for more details.
             model (`str`, *optional*):
                 The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
-                Inference Endpoint.
-
+                Inference Endpoint. If not provided, the default recommended model for summarization will be used.
+            clean_up_tokenization_spaces (`bool`, *optional*):
+                Whether to clean up the potential extra spaces in the text output.
+            generate_parameters (`Dict[str, Any]`, *optional*):
+                Additional parametrization of the text generation algorithm.
+            truncation (`Literal["do_not_truncate", "longest_first", "only_first", "only_second"]`, *optional*):
+                The truncation strategy to use.
         Returns:
             [`SummarizationOutput`]: The generated summary text.
 
@@ -1528,14 +1698,23 @@
         SummarizationOutput(generated_text="The Eiffel tower is one of the most famous landmarks in the world....")
         ```
         """
-
-
-
-
+        if parameters is None:
+            parameters = {
+                "clean_up_tokenization_spaces": clean_up_tokenization_spaces,
+                "generate_parameters": generate_parameters,
+                "truncation": truncation,
+            }
+        payload = _prepare_payload(text, parameters=parameters)
+        response = await self.post(**payload, model=model, task="summarization")
         return SummarizationOutput.parse_obj_as_list(response)[0]
 
     async def table_question_answering(
-        self,
+        self,
+        table: Dict[str, Any],
+        query: str,
+        *,
+        model: Optional[str] = None,
+        parameters: Optional[Dict[str, Any]] = None,
     ) -> TableQuestionAnsweringOutputElement:
         """
         Retrieve the answer to a question from information given in a table.
@@ -1549,6 +1728,8 @@
             model (`str`):
                 The model to use for the table-question-answering task. Can be a model ID hosted on the Hugging Face
                 Hub or a URL to a deployed Inference Endpoint.
+            parameters (`Dict[str, Any]`, *optional*):
+                Additional inference parameters. Defaults to None.
 
         Returns:
             [`TableQuestionAnsweringOutputElement`]: a table question answering output containing the answer, coordinates, cells and the aggregator used.
@@ -1570,11 +1751,13 @@
         TableQuestionAnsweringOutputElement(answer='36542', coordinates=[[0, 1]], cells=['36542'], aggregator='AVERAGE')
         ```
         """
+        inputs = {
+            "query": query,
+            "table": table,
+        }
+        payload = _prepare_payload(inputs, parameters=parameters)
         response = await self.post(
-            json={
-                "query": query,
-                "table": table,
-            },
+            **payload,
             model=model,
             task="table-question-answering",
         )
@@ -1623,7 +1806,11 @@
         ["5", "5", "5"]
         ```
         """
-        response = await self.post(
+        response = await self.post(
+            json={"table": table},
+            model=model,
+            task="tabular-classification",
+        )
         return _bytes_to_list(response)
 
     async def tabular_regression(self, table: Dict[str, Any], *, model: Optional[str] = None) -> List[float]:
@@ -1668,7 +1855,12 @@
         return _bytes_to_list(response)
 
     async def text_classification(
-        self,
+        self,
+        text: str,
+        *,
+        model: Optional[str] = None,
+        top_k: Optional[int] = None,
+        function_to_apply: Optional["TextClassificationOutputTransform"] = None,
     ) -> List[TextClassificationOutputElement]:
         """
         Perform text classification (e.g. sentiment-analysis) on the given text.
@@ -1680,6 +1872,10 @@
                 The model to use for the text classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to
                 a deployed Inference Endpoint. If not provided, the default recommended text classification model will be used.
                 Defaults to None.
+            top_k (`int`, *optional*):
+                When specified, limits the output to the top K most probable classes.
+            function_to_apply (`"TextClassificationOutputTransform"`, *optional*):
+                The function to apply to the output.
 
         Returns:
             `List[TextClassificationOutputElement]`: a list of [`TextClassificationOutputElement`] items containing the predicted label and associated probability.
@@ -1702,7 +1898,16 @@
         ]
         ```
         """
-
+        parameters = {
+            "function_to_apply": function_to_apply,
+            "top_k": top_k,
+        }
+        payload = _prepare_payload(text, parameters=parameters)
+        response = await self.post(
+            **payload,
+            model=model,
+            task="text-classification",
+        )
         return TextClassificationOutputElement.parse_obj_as_list(response)[0]  # type: ignore [return-value]
 
     @overload
@@ -2212,6 +2417,9 @@
         num_inference_steps: Optional[float] = None,
         guidance_scale: Optional[float] = None,
         model: Optional[str] = None,
+        scheduler: Optional[str] = None,
+        target_size: Optional[TextToImageTargetSize] = None,
+        seed: Optional[int] = None,
         **kwargs,
     ) -> "Image":
         """
@@ -2240,7 +2448,14 @@
                 usually at the expense of lower image quality.
             model (`str`, *optional*):
                 The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
-                Inference Endpoint.
+                Inference Endpoint. If not provided, the default recommended text-to-image model will be used.
+                Defaults to None.
+            scheduler (`str`, *optional*):
+                Override the scheduler with a compatible one.
+            target_size (`TextToImageTargetSize`, *optional*):
+                The size in pixel of the output image
+            seed (`int`, *optional*):
+                Seed for the random number generator.
 
         Returns:
             `Image`: The generated image.
@@ -2268,22 +2483,44 @@
         >>> image.save("better_astronaut.png")
         ```
         """
-
+
         parameters = {
             "negative_prompt": negative_prompt,
             "height": height,
             "width": width,
             "num_inference_steps": num_inference_steps,
             "guidance_scale": guidance_scale,
+            "scheduler": scheduler,
+            "target_size": target_size,
+            "seed": seed,
             **kwargs,
         }
-
-
-                payload.setdefault("parameters", {})[key] = value  # type: ignore
-        response = await self.post(json=payload, model=model, task="text-to-image")
+        payload = _prepare_payload(prompt, parameters=parameters)
+        response = await self.post(**payload, model=model, task="text-to-image")
         return _bytes_to_image(response)
 
-    async def text_to_speech(
+    async def text_to_speech(
+        self,
+        text: str,
+        *,
+        model: Optional[str] = None,
+        do_sample: Optional[bool] = None,
+        early_stopping: Optional[Union[bool, "TextToSpeechEarlyStoppingEnum"]] = None,
+        epsilon_cutoff: Optional[float] = None,
+        eta_cutoff: Optional[float] = None,
+        max_length: Optional[int] = None,
+        max_new_tokens: Optional[int] = None,
+        min_length: Optional[int] = None,
+        min_new_tokens: Optional[int] = None,
+        num_beam_groups: Optional[int] = None,
+        num_beams: Optional[int] = None,
+        penalty_alpha: Optional[float] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        typical_p: Optional[float] = None,
+        use_cache: Optional[bool] = None,
+    ) -> bytes:
         """
         Synthesize an audio of a voice pronouncing a given text.
 
@@ -2292,7 +2529,56 @@
                 The text to synthesize.
             model (`str`, *optional*):
                 The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
-                Inference Endpoint.
+                Inference Endpoint. If not provided, the default recommended text-to-speech model will be used.
+                Defaults to None.
+            do_sample (`bool`, *optional*):
+                Whether to use sampling instead of greedy decoding when generating new tokens.
+            early_stopping (`Union[bool, "TextToSpeechEarlyStoppingEnum"`, *optional*):
+                Controls the stopping condition for beam-based methods.
+            epsilon_cutoff (`float`, *optional*):
+                If set to float strictly between 0 and 1, only tokens with a conditional probability
+                greater than epsilon_cutoff will be sampled. In the paper, suggested values range from
+                3e-4 to 9e-4, depending on the size of the model. See [Truncation Sampling as Language
+                Model Desmoothing](https://hf.co/papers/2210.15191) for more details.
+            eta_cutoff (`float`, *optional*):
+                Eta sampling is a hybrid of locally typical sampling and epsilon sampling. If set to
+                float strictly between 0 and 1, a token is only considered if it is greater than either
+                eta_cutoff or sqrt(eta_cutoff) * exp(-entropy(softmax(next_token_logits))). The latter
+                term is intuitively the expected next token probability, scaled by sqrt(eta_cutoff). In
+                the paper, suggested values range from 3e-4 to 2e-3, depending on the size of the model.
+                See [Truncation Sampling as Language Model Desmoothing](https://hf.co/papers/2210.15191)
+                for more details.
+            max_length (`int`, *optional*):
+                The maximum length (in tokens) of the generated text, including the input.
+            max_new_tokens (`int`, *optional*):
+                The maximum number of tokens to generate. Takes precedence over maxLength.
+            min_length (`int`, *optional*):
+                The minimum length (in tokens) of the generated text, including the input.
+            min_new_tokens (`int`, *optional*):
+                The minimum number of tokens to generate. Takes precedence over maxLength.
+            num_beam_groups (`int`, *optional*):
+                Number of groups to divide num_beams into in order to ensure diversity among different
+                groups of beams. See [this paper](https://hf.co/papers/1610.02424) for more details.
+            num_beams (`int`, *optional*):
+                Number of beams to use for beam search.
+            penalty_alpha (`float`, *optional*):
+                The value balances the model confidence and the degeneration penalty in contrastive
+                search decoding.
+            temperature (`float`, *optional*):
+                The value used to modulate the next token probabilities.
+            top_k (`int`, *optional*):
+                The number of highest probability vocabulary tokens to keep for top-k-filtering.
+            top_p (`float`, *optional*):
+                If set to float < 1, only the smallest set of most probable tokens with probabilities
+                that add up to top_p or higher are kept for generation.
+            typical_p (`float`, *optional*):
+                Local typicality measures how similar the conditional probability of predicting a target token next is
+                to the expected conditional probability of predicting a random token next, given the partial text
+                already generated. If set to float < 1, the smallest set of the most locally typical tokens with
+                probabilities that add up to typical_p or higher are kept for generation. See [this
+                paper](https://hf.co/papers/2202.00666) for more details.
+            use_cache (`bool`, *optional*):
+                Whether the model should use the past last key/values attentions to speed up decoding
 
         Returns:
             `bytes`: The generated audio.
@@ -2314,10 +2600,36 @@
         >>> Path("hello_world.flac").write_bytes(audio)
         ```
         """
-
+        parameters = {
+            "do_sample": do_sample,
+            "early_stopping": early_stopping,
+            "epsilon_cutoff": epsilon_cutoff,
+            "eta_cutoff": eta_cutoff,
+            "max_length": max_length,
+            "max_new_tokens": max_new_tokens,
+            "min_length": min_length,
+            "min_new_tokens": min_new_tokens,
+            "num_beam_groups": num_beam_groups,
+            "num_beams": num_beams,
+            "penalty_alpha": penalty_alpha,
+            "temperature": temperature,
+            "top_k": top_k,
+            "top_p": top_p,
+            "typical_p": typical_p,
+            "use_cache": use_cache,
+        }
+        payload = _prepare_payload(text, parameters=parameters)
+        response = await self.post(**payload, model=model, task="text-to-speech")
+        return response
 
     async def token_classification(
-        self,
+        self,
+        text: str,
+        *,
+        model: Optional[str] = None,
+        aggregation_strategy: Optional[Literal["none", "simple", "first", "average", "max"]] = None,
+        ignore_labels: Optional[List[str]] = None,
+        stride: Optional[int] = None,
     ) -> List[TokenClassificationOutputElement]:
         """
         Perform token classification on the given text.
@@ -2330,6 +2642,12 @@ class AsyncInferenceClient:
|
|
|
2330
2642
|
The model to use for the token classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to
|
|
2331
2643
|
a deployed Inference Endpoint. If not provided, the default recommended token classification model will be used.
|
|
2332
2644
|
Defaults to None.
|
|
2645
|
+
aggregation_strategy (`Literal["none", "simple", "first", "average", "max"]`, *optional*):
|
|
2646
|
+
The strategy used to fuse tokens based on model predictions.
|
|
2647
|
+
ignore_labels (`List[str]`, *optional*):
|
|
2648
|
+
A list of labels to ignore.
|
|
2649
|
+
stride (`int`, *optional*):
|
|
2650
|
+
The number of overlapping tokens between chunks when splitting the input text.
|
|
2333
2651
|
|
|
2334
2652
|
Returns:
|
|
2335
2653
|
`List[TokenClassificationOutputElement]`: List of [`TokenClassificationOutputElement`] items containing the entity group, confidence score, word, start and end index.
|
|
@@ -2364,16 +2682,30 @@ class AsyncInferenceClient:
         ]
         ```
         """
-
+
+        parameters = {
+            "aggregation_strategy": aggregation_strategy,
+            "ignore_labels": ignore_labels,
+            "stride": stride,
+        }
+        payload = _prepare_payload(text, parameters=parameters)
         response = await self.post(
-
+            **payload,
             model=model,
             task="token-classification",
         )
         return TokenClassificationOutputElement.parse_obj_as_list(response)
 
     async def translation(
-        self,
+        self,
+        text: str,
+        *,
+        model: Optional[str] = None,
+        src_lang: Optional[str] = None,
+        tgt_lang: Optional[str] = None,
+        clean_up_tokenization_spaces: Optional[bool] = None,
+        truncation: Optional[Literal["do_not_truncate", "longest_first", "only_first", "only_second"]] = None,
+        generate_parameters: Optional[Dict[str, Any]] = None,
     ) -> TranslationOutput:
         """
         Convert text from one language to another.
@@ -2382,7 +2714,6 @@ class AsyncInferenceClient:
         your specific use case. Source and target languages usually depend on the model.
         However, it is possible to specify source and target languages for certain models. If you are working with one of these models,
         you can use `src_lang` and `tgt_lang` arguments to pass the relevant information.
-        You can find this information in the model card.
 
         Args:
             text (`str`):
@@ -2392,9 +2723,15 @@ class AsyncInferenceClient:
                 a deployed Inference Endpoint. If not provided, the default recommended translation model will be used.
                 Defaults to None.
             src_lang (`str`, *optional*):
-
+                The source language of the text. Required for models that can translate from multiple languages.
             tgt_lang (`str`, *optional*):
-                Target language
+                Target language to translate to. Required for models that can translate to multiple languages.
+            clean_up_tokenization_spaces (`bool`, *optional*):
+                Whether to clean up the potential extra spaces in the text output.
+            truncation (`Literal["do_not_truncate", "longest_first", "only_first", "only_second"]`, *optional*):
+                The truncation strategy to use.
+            generate_parameters (`Dict[str, Any]`, *optional*):
+                Additional parametrization of the text generation algorithm.
 
         Returns:
             [`TranslationOutput`]: The generated translated text.
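A minimal sketch of the extended `translation` call; the model ID and language codes below are illustrative examples of a multilingual checkpoint, not values taken from the diff:

```python
import asyncio
from huggingface_hub import AsyncInferenceClient

async def main():
    client = AsyncInferenceClient()
    # src_lang / tgt_lang are required for models that translate between many languages.
    result = await client.translation(
        "My name is Wolfgang and I live in Berlin",
        model="facebook/nllb-200-distilled-600M",  # illustrative model ID
        src_lang="eng_Latn",
        tgt_lang="fra_Latn",
    )
    print(result.translation_text)

asyncio.run(main())
```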
@@ -2430,12 +2767,15 @@ class AsyncInferenceClient:
 
         if src_lang is None and tgt_lang is not None:
             raise ValueError("You cannot specify `tgt_lang` without specifying `src_lang`.")
-
-
-
-
-
-
+        parameters = {
+            "src_lang": src_lang,
+            "tgt_lang": tgt_lang,
+            "clean_up_tokenization_spaces": clean_up_tokenization_spaces,
+            "truncation": truncation,
+            "generate_parameters": generate_parameters,
+        }
+        payload = _prepare_payload(text, parameters=parameters)
+        response = await self.post(**payload, model=model, task="translation")
         return TranslationOutput.parse_obj_as_list(response)[0]
 
     async def visual_question_answering(
@@ -2444,6 +2784,7 @@ class AsyncInferenceClient:
         question: str,
         *,
         model: Optional[str] = None,
+        top_k: Optional[int] = None,
     ) -> List[VisualQuestionAnsweringOutputElement]:
         """
         Answering open-ended questions based on an image.
@@ -2457,7 +2798,10 @@ class AsyncInferenceClient:
                 The model to use for the visual question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to
                 a deployed Inference Endpoint. If not provided, the default recommended visual question answering model will be used.
                 Defaults to None.
-
+            top_k (`int`, *optional*):
+                The number of answers to return (will be chosen by order of likelihood). Note that we
+                return less than topk answers if there are not enough options available within the
+                context.
         Returns:
             `List[VisualQuestionAnsweringOutputElement]`: a list of [`VisualQuestionAnsweringOutputElement`] items containing the predicted label and associated probability.
 
@@ -2483,6 +2827,8 @@ class AsyncInferenceClient:
         ```
         """
         payload: Dict[str, Any] = {"question": question, "image": _b64_encode(image)}
+        if top_k is not None:
+            payload.setdefault("parameters", {})["top_k"] = top_k
         response = await self.post(json=payload, model=model, task="visual-question-answering")
         return VisualQuestionAnsweringOutputElement.parse_obj_as_list(response)
 
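A minimal sketch of `visual_question_answering` with the new `top_k` parameter; the image path and question are illustrative:

```python
import asyncio
from huggingface_hub import AsyncInferenceClient

async def main():
    client = AsyncInferenceClient()
    # top_k limits the number of returned answers, ranked by likelihood.
    answers = await client.visual_question_answering(
        image="cat.jpg",  # illustrative local path; a URL or raw bytes also work
        question="What animal is in the picture?",
        top_k=3,
    )
    for answer in answers:
        print(answer.answer, answer.score)

asyncio.run(main())
```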
@@ -2513,7 +2859,7 @@ class AsyncInferenceClient:
                 The model then evaluates for both hypotheses if they are entailed in the provided `text` or not.
             model (`str`, *optional*):
                 The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
-                Inference Endpoint. This parameter overrides the model defined at the instance level.
+                Inference Endpoint. This parameter overrides the model defined at the instance level. If not provided, the default recommended zero-shot classification model will be used.
 
         Returns:
             `List[ZeroShotClassificationOutputElement]`: List of [`ZeroShotClassificationOutputElement`] items containing the predicted labels and their confidence.
@@ -2573,15 +2919,14 @@ class AsyncInferenceClient:
         ```
         """
 
-        parameters = {
-
-
-
+        parameters = {
+            "candidate_labels": labels,
+            "multi_label": multi_label,
+            "hypothesis_template": hypothesis_template,
+        }
+        payload = _prepare_payload(text, parameters=parameters)
         response = await self.post(
-
-                "inputs": text,
-                "parameters": parameters,
-            },
+            **payload,
             task="zero-shot-classification",
             model=model,
         )
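A minimal sketch of `zero_shot_classification` after this change; the input text, candidate labels and hypothesis template are illustrative values:

```python
import asyncio
from huggingface_hub import AsyncInferenceClient

async def main():
    client = AsyncInferenceClient()
    # Candidate labels are now sent as payload parameters and scored against the input text.
    results = await client.zero_shot_classification(
        "I really enjoyed the new restaurant downtown.",
        labels=["food", "politics", "sports"],
        multi_label=False,
        hypothesis_template="This text is about {}.",
    )
    for item in results:
        print(item.label, item.score)

asyncio.run(main())
```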
@@ -2592,7 +2937,12 @@ class AsyncInferenceClient:
         ]
 
     async def zero_shot_image_classification(
-        self,
+        self,
+        image: ContentT,
+        labels: List[str],
+        *,
+        model: Optional[str] = None,
+        hypothesis_template: Optional[str] = None,
     ) -> List[ZeroShotImageClassificationOutputElement]:
         """
         Provide input image and text labels to predict text labels for the image.
@@ -2604,8 +2954,10 @@ class AsyncInferenceClient:
                 List of string possible labels. There must be at least 2 labels.
             model (`str`, *optional*):
                 The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
-                Inference Endpoint. This parameter overrides the model defined at the instance level.
-
+                Inference Endpoint. This parameter overrides the model defined at the instance level. If not provided, the default recommended zero-shot image classification model will be used.
+            hypothesis_template (`str`, *optional*):
+                The sentence used in conjunction with `labels` to attempt the text classification by replacing the
+                placeholder with the candidate labels.
         Returns:
             `List[ZeroShotImageClassificationOutputElement]`: List of [`ZeroShotImageClassificationOutputElement`] items containing the predicted labels and their confidence.
 
@@ -2632,8 +2984,11 @@ class AsyncInferenceClient:
         if len(labels) < 2:
             raise ValueError("You must specify at least 2 classes to compare.")
 
+        inputs = {"image": _b64_encode(image), "candidateLabels": ",".join(labels)}
+        parameters = {"hypothesis_template": hypothesis_template}
+        payload = _prepare_payload(inputs, parameters=parameters)
         response = await self.post(
-
+            **payload,
             model=model,
             task="zero-shot-image-classification",
         )
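A minimal sketch of `zero_shot_image_classification` with the new `hypothesis_template` parameter; the image URL, labels and template text are illustrative:

```python
import asyncio
from huggingface_hub import AsyncInferenceClient

async def main():
    client = AsyncInferenceClient()
    # hypothesis_template wraps each candidate label before it is scored against the image.
    results = await client.zero_shot_image_classification(
        image="https://example.com/cat.jpg",  # illustrative URL; local paths and bytes also work
        labels=["cat", "dog", "bird"],
        hypothesis_template="A photo of a {}.",
    )
    for item in results:
        print(item.label, item.score)

asyncio.run(main())
```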