huggingface-hub 0.21.2__py3-none-any.whl → 0.22.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- huggingface_hub/__init__.py +217 -1
- huggingface_hub/_commit_api.py +14 -15
- huggingface_hub/_inference_endpoints.py +12 -11
- huggingface_hub/_login.py +1 -0
- huggingface_hub/_multi_commits.py +1 -0
- huggingface_hub/_snapshot_download.py +9 -1
- huggingface_hub/_tensorboard_logger.py +1 -0
- huggingface_hub/_webhooks_payload.py +1 -0
- huggingface_hub/_webhooks_server.py +1 -0
- huggingface_hub/commands/_cli_utils.py +1 -0
- huggingface_hub/commands/delete_cache.py +1 -0
- huggingface_hub/commands/download.py +1 -0
- huggingface_hub/commands/env.py +1 -0
- huggingface_hub/commands/scan_cache.py +1 -0
- huggingface_hub/commands/upload.py +1 -0
- huggingface_hub/community.py +1 -0
- huggingface_hub/constants.py +3 -1
- huggingface_hub/errors.py +38 -0
- huggingface_hub/file_download.py +102 -95
- huggingface_hub/hf_api.py +47 -35
- huggingface_hub/hf_file_system.py +77 -3
- huggingface_hub/hub_mixin.py +230 -61
- huggingface_hub/inference/_client.py +554 -239
- huggingface_hub/inference/_common.py +195 -41
- huggingface_hub/inference/_generated/_async_client.py +558 -239
- huggingface_hub/inference/_generated/types/__init__.py +115 -0
- huggingface_hub/inference/_generated/types/audio_classification.py +43 -0
- huggingface_hub/inference/_generated/types/audio_to_audio.py +31 -0
- huggingface_hub/inference/_generated/types/automatic_speech_recognition.py +116 -0
- huggingface_hub/inference/_generated/types/base.py +149 -0
- huggingface_hub/inference/_generated/types/chat_completion.py +106 -0
- huggingface_hub/inference/_generated/types/depth_estimation.py +29 -0
- huggingface_hub/inference/_generated/types/document_question_answering.py +85 -0
- huggingface_hub/inference/_generated/types/feature_extraction.py +19 -0
- huggingface_hub/inference/_generated/types/fill_mask.py +50 -0
- huggingface_hub/inference/_generated/types/image_classification.py +43 -0
- huggingface_hub/inference/_generated/types/image_segmentation.py +52 -0
- huggingface_hub/inference/_generated/types/image_to_image.py +55 -0
- huggingface_hub/inference/_generated/types/image_to_text.py +105 -0
- huggingface_hub/inference/_generated/types/object_detection.py +55 -0
- huggingface_hub/inference/_generated/types/question_answering.py +77 -0
- huggingface_hub/inference/_generated/types/sentence_similarity.py +28 -0
- huggingface_hub/inference/_generated/types/summarization.py +46 -0
- huggingface_hub/inference/_generated/types/table_question_answering.py +45 -0
- huggingface_hub/inference/_generated/types/text2text_generation.py +45 -0
- huggingface_hub/inference/_generated/types/text_classification.py +43 -0
- huggingface_hub/inference/_generated/types/text_generation.py +161 -0
- huggingface_hub/inference/_generated/types/text_to_audio.py +105 -0
- huggingface_hub/inference/_generated/types/text_to_image.py +57 -0
- huggingface_hub/inference/_generated/types/token_classification.py +53 -0
- huggingface_hub/inference/_generated/types/translation.py +46 -0
- huggingface_hub/inference/_generated/types/video_classification.py +47 -0
- huggingface_hub/inference/_generated/types/visual_question_answering.py +53 -0
- huggingface_hub/inference/_generated/types/zero_shot_classification.py +56 -0
- huggingface_hub/inference/_generated/types/zero_shot_image_classification.py +51 -0
- huggingface_hub/inference/_generated/types/zero_shot_object_detection.py +55 -0
- huggingface_hub/inference/_templating.py +105 -0
- huggingface_hub/inference/_types.py +4 -152
- huggingface_hub/keras_mixin.py +39 -17
- huggingface_hub/lfs.py +20 -8
- huggingface_hub/repocard.py +11 -3
- huggingface_hub/repocard_data.py +12 -2
- huggingface_hub/serialization/__init__.py +1 -0
- huggingface_hub/serialization/_base.py +1 -0
- huggingface_hub/serialization/_numpy.py +1 -0
- huggingface_hub/serialization/_tensorflow.py +1 -0
- huggingface_hub/serialization/_torch.py +1 -0
- huggingface_hub/utils/__init__.py +4 -1
- huggingface_hub/utils/_cache_manager.py +7 -0
- huggingface_hub/utils/_chunk_utils.py +1 -0
- huggingface_hub/utils/_datetime.py +1 -0
- huggingface_hub/utils/_errors.py +10 -1
- huggingface_hub/utils/_experimental.py +1 -0
- huggingface_hub/utils/_fixes.py +19 -3
- huggingface_hub/utils/_git_credential.py +1 -0
- huggingface_hub/utils/_headers.py +10 -3
- huggingface_hub/utils/_hf_folder.py +1 -0
- huggingface_hub/utils/_http.py +1 -0
- huggingface_hub/utils/_pagination.py +1 -0
- huggingface_hub/utils/_paths.py +1 -0
- huggingface_hub/utils/_runtime.py +22 -0
- huggingface_hub/utils/_subprocess.py +1 -0
- huggingface_hub/utils/_token.py +1 -0
- huggingface_hub/utils/_typing.py +29 -1
- huggingface_hub/utils/_validators.py +1 -0
- huggingface_hub/utils/endpoint_helpers.py +1 -0
- huggingface_hub/utils/logging.py +1 -1
- huggingface_hub/utils/sha.py +1 -0
- huggingface_hub/utils/tqdm.py +1 -0
- {huggingface_hub-0.21.2.dist-info → huggingface_hub-0.22.0.dist-info}/METADATA +14 -15
- huggingface_hub-0.22.0.dist-info/RECORD +113 -0
- {huggingface_hub-0.21.2.dist-info → huggingface_hub-0.22.0.dist-info}/WHEEL +1 -1
- huggingface_hub/inference/_text_generation.py +0 -551
- huggingface_hub-0.21.2.dist-info/RECORD +0 -81
- {huggingface_hub-0.21.2.dist-info → huggingface_hub-0.22.0.dist-info}/LICENSE +0 -0
- {huggingface_hub-0.21.2.dist-info → huggingface_hub-0.22.0.dist-info}/entry_points.txt +0 -0
- {huggingface_hub-0.21.2.dist-info → huggingface_hub-0.22.0.dist-info}/top_level.txt +0 -0
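The headline change in this release is visible in the file list: the hand-written `inference/_text_generation.py` module is removed, a set of generated task output types is added under `inference/_generated/types/`, and both inference clients gain a `chat_completion` method (plus `inference/_templating.py` for client-side chat-template rendering). Below is a minimal usage sketch of the new API, adapted from the docstring examples quoted in the diff that follows; the model id and the printed fields come from those examples, not from running the code.

```py
# Sketch of the new chat_completion API added in 0.22.0 (based on the docstring
# examples shown in the _async_client.py diff below).
import asyncio
from huggingface_hub import AsyncInferenceClient

async def main() -> None:
    client = AsyncInferenceClient("HuggingFaceH4/zephyr-7b-beta")
    messages = [{"role": "user", "content": "What is the capital of France?"}]

    # stream=False (default): returns a ChatCompletionOutput dataclass.
    output = await client.chat_completion(messages, max_tokens=100)
    print(output.choices[0].message.content)

    # stream=True: yields ChatCompletionStreamOutput items token by token.
    async for chunk in await client.chat_completion(messages, max_tokens=10, stream=True):
        print(chunk.choices[0].delta.content)

asyncio.run(main())
```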
--- huggingface_hub/inference/_generated/_async_client.py (0.21.2)
+++ huggingface_hub/inference/_generated/_async_client.py (0.22.0)
@@ -23,7 +23,6 @@ import base64
 import logging
 import time
 import warnings
-from dataclasses import asdict
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -39,11 +38,13 @@ from typing import (
 from requests.structures import CaseInsensitiveDict

 from huggingface_hub.constants import ALL_INFERENCE_API_FRAMEWORKS, INFERENCE_ENDPOINT, MAIN_INFERENCE_API_FRAMEWORKS
+from huggingface_hub.errors import InferenceTimeoutError
 from huggingface_hub.inference._common import (
     TASKS_EXPECTING_IMAGES,
     ContentT,
-    InferenceTimeoutError,
     ModelStatus,
+    _async_stream_chat_completion_response_from_bytes,
+    _async_stream_chat_completion_response_from_text_generation,
     _async_stream_text_generation_response,
     _b64_encode,
     _b64_to_image,
@@ -52,27 +53,42 @@ from huggingface_hub.inference._common import (
     _bytes_to_list,
     _fetch_recommended_models,
     _import_numpy,
+    _is_chat_completion_server,
     _is_tgi_server,
     _open_as_binary,
+    _set_as_non_chat_completion_server,
     _set_as_non_tgi,
-)
-from huggingface_hub.inference._text_generation import (
-    TextGenerationParameters,
-    TextGenerationRequest,
-    TextGenerationResponse,
-    TextGenerationStreamResponse,
     raise_text_generation_error,
 )
+from huggingface_hub.inference._generated.types import (
+    AudioClassificationOutputElement,
+    AudioToAudioOutputElement,
+    AutomaticSpeechRecognitionOutput,
+    ChatCompletionOutput,
+    ChatCompletionOutputChoice,
+    ChatCompletionOutputChoiceMessage,
+    ChatCompletionStreamOutput,
+    DocumentQuestionAnsweringOutputElement,
+    FillMaskOutputElement,
+    ImageClassificationOutputElement,
+    ImageSegmentationOutputElement,
+    ImageToTextOutput,
+    ObjectDetectionOutputElement,
+    QuestionAnsweringOutputElement,
+    SummarizationOutput,
+    TableQuestionAnsweringOutputElement,
+    TextClassificationOutputElement,
+    TextGenerationOutput,
+    TextGenerationStreamOutput,
+    TokenClassificationOutputElement,
+    TranslationOutput,
+    VisualQuestionAnsweringOutputElement,
+    ZeroShotClassificationOutputElement,
+    ZeroShotImageClassificationOutputElement,
+)
+from huggingface_hub.inference._templating import render_chat_prompt
 from huggingface_hub.inference._types import (
-
-    ClassificationOutput,
-    ConversationalOutput,
-    FillMaskOutput,
-    ImageSegmentationOutput,
-    ObjectDetectionOutput,
-    QuestionAnsweringOutput,
-    TableQuestionAnsweringOutput,
-    TokenClassificationOutput,
+    ConversationalOutput,  # soon to be removed
 )
 from huggingface_hub.utils import (
     build_hf_headers,
@@ -100,9 +116,9 @@ class AsyncInferenceClient:
             The model to run inference with. Can be a model id hosted on the Hugging Face Hub, e.g. `bigcode/starcoder`
             or a URL to a deployed Inference Endpoint. Defaults to None, in which case a recommended model is
             automatically selected for the task.
-        token (`str`, *optional*):
-            Hugging Face token. Will default to the locally saved token
-            your token to the server.
+        token (`str` or `bool`, *optional*):
+            Hugging Face token. Will default to the locally saved token if not provided.
+            Pass `token=False` if you don't want to send your token to the server.
         timeout (`float`, `optional`):
             The maximum number of seconds to wait for a response from the server. Loading a new model in Inference
             API can take up to several minutes. Defaults to None, meaning it will loop until the server is available.
@@ -122,6 +138,7 @@ class AsyncInferenceClient:
         cookies: Optional[Dict[str, str]] = None,
     ) -> None:
         self.model: Optional[str] = model
+        self.token: Union[str, bool, None] = token
         self.headers = CaseInsensitiveDict(build_hf_headers(token=token))  # contains 'authorization' + 'user-agent'
         if headers is not None:
             self.headers.update(headers)
@@ -140,11 +157,10 @@ class AsyncInferenceClient:
         model: Optional[str] = None,
         task: Optional[str] = None,
         stream: Literal[False] = ...,
-    ) -> bytes:
-        pass
+    ) -> bytes: ...

     @overload
-    async def post(
+    async def post(  # type: ignore[misc]
         self,
         *,
         json: Optional[Union[str, Dict, List]] = None,
@@ -152,8 +168,18 @@ class AsyncInferenceClient:
         model: Optional[str] = None,
         task: Optional[str] = None,
         stream: Literal[True] = ...,
-    ) -> AsyncIterable[bytes]:
-
+    ) -> AsyncIterable[bytes]: ...
+
+    @overload
+    async def post(
+        self,
+        *,
+        json: Optional[Union[str, Dict, List]] = None,
+        data: Optional[ContentT] = None,
+        model: Optional[str] = None,
+        task: Optional[str] = None,
+        stream: bool = False,
+    ) -> Union[bytes, AsyncIterable[bytes]]: ...

     async def post(
         self,
@@ -263,7 +289,7 @@ class AsyncInferenceClient:
         audio: ContentT,
         *,
         model: Optional[str] = None,
-    ) -> List[
+    ) -> List[AudioClassificationOutputElement]:
         """
         Perform audio classification on the provided audio content.

@@ -277,7 +303,7 @@ class AsyncInferenceClient:
                 audio classification will be used.

         Returns:
-            `List[
+            `List[AudioClassificationOutputElement]`: List of [`AudioClassificationOutputElement`] items containing the predicted labels and their confidence.

         Raises:
             [`InferenceTimeoutError`]:
@@ -291,18 +317,22 @@ class AsyncInferenceClient:
        >>> from huggingface_hub import AsyncInferenceClient
        >>> client = AsyncInferenceClient()
        >>> await client.audio_classification("audio.flac")
-        [
+        [
+            AudioClassificationOutputElement(score=0.4976358711719513, label='hap'),
+            AudioClassificationOutputElement(score=0.3677836060523987, label='neu'),
+            ...
+        ]
         ```
         """
         response = await self.post(data=audio, model=model, task="audio-classification")
-        return
+        return AudioClassificationOutputElement.parse_obj_as_list(response)

     async def audio_to_audio(
         self,
         audio: ContentT,
         *,
         model: Optional[str] = None,
-    ) -> List[
+    ) -> List[AudioToAudioOutputElement]:
         """
         Performs multiple tasks related to audio-to-audio depending on the model (eg: speech enhancement, source separation).

@@ -316,7 +346,7 @@ class AsyncInferenceClient:
                 audio_to_audio will be used.

         Returns:
-            `List[
+            `List[AudioToAudioOutputElement]`: A list of [`AudioToAudioOutputElement`] items containing audios label, content-type, and audio content in blob.

         Raises:
             `InferenceTimeoutError`:
@@ -332,13 +362,13 @@ class AsyncInferenceClient:
        >>> audio_output = await client.audio_to_audio("audio.flac")
        >>> async for i, item in enumerate(audio_output):
        >>>     with open(f"output_{i}.flac", "wb") as f:
-                    f.write(item
+                    f.write(item.blob)
         ```
         """
         response = await self.post(data=audio, model=model, task="audio-to-audio")
-        audio_output =
+        audio_output = AudioToAudioOutputElement.parse_obj_as_list(response)
         for item in audio_output:
-            item
+            item.blob = base64.b64decode(item.blob)
         return audio_output

     async def automatic_speech_recognition(
@@ -346,7 +376,7 @@ class AsyncInferenceClient:
         audio: ContentT,
         *,
         model: Optional[str] = None,
-    ) ->
+    ) -> AutomaticSpeechRecognitionOutput:
         """
         Perform automatic speech recognition (ASR or audio-to-text) on the given audio content.

@@ -358,7 +388,7 @@ class AsyncInferenceClient:
                 Inference Endpoint. If not provided, the default recommended model for ASR will be used.

         Returns:
-
+            [`AutomaticSpeechRecognitionOutput`]: An item containing the transcribed text and optionally the timestamp chunks.

         Raises:
             [`InferenceTimeoutError`]:
@@ -371,12 +401,266 @@ class AsyncInferenceClient:
        # Must be run in an async context
        >>> from huggingface_hub import AsyncInferenceClient
        >>> client = AsyncInferenceClient()
-        >>> await client.automatic_speech_recognition("hello_world.flac")
+        >>> await client.automatic_speech_recognition("hello_world.flac").text
        "hello world"
        ```
        """
        response = await self.post(data=audio, model=model, task="automatic-speech-recognition")
-        return
+        return AutomaticSpeechRecognitionOutput.parse_obj_as_instance(response)
+
+    @overload
+    async def chat_completion(  # type: ignore
+        self,
+        messages: List[Dict[str, str]],
+        *,
+        model: Optional[str] = None,
+        stream: Literal[False] = False,
+        max_tokens: int = 20,
+        seed: Optional[int] = None,
+        stop: Optional[Union[List[str], str]] = None,
+        temperature: float = 1.0,
+        top_p: Optional[float] = None,
+    ) -> ChatCompletionOutput: ...
+
+    @overload
+    async def chat_completion(  # type: ignore
+        self,
+        messages: List[Dict[str, str]],
+        *,
+        model: Optional[str] = None,
+        stream: Literal[True] = True,
+        max_tokens: int = 20,
+        seed: Optional[int] = None,
+        stop: Optional[Union[List[str], str]] = None,
+        temperature: float = 1.0,
+        top_p: Optional[float] = None,
+    ) -> AsyncIterable[ChatCompletionStreamOutput]: ...
+
+    @overload
+    async def chat_completion(
+        self,
+        messages: List[Dict[str, str]],
+        *,
+        model: Optional[str] = None,
+        stream: bool = False,
+        max_tokens: int = 20,
+        seed: Optional[int] = None,
+        stop: Optional[Union[List[str], str]] = None,
+        temperature: float = 1.0,
+        top_p: Optional[float] = None,
+    ) -> Union[ChatCompletionOutput, AsyncIterable[ChatCompletionStreamOutput]]: ...
+
+    async def chat_completion(
+        self,
+        messages: List[Dict[str, str]],
+        *,
+        model: Optional[str] = None,
+        stream: bool = False,
+        max_tokens: int = 20,
+        seed: Optional[int] = None,
+        stop: Optional[Union[List[str], str]] = None,
+        temperature: float = 1.0,
+        top_p: Optional[float] = None,
+    ) -> Union[ChatCompletionOutput, AsyncIterable[ChatCompletionStreamOutput]]:
+        """
+        A method for completing conversations using a specified language model.
+
+        <Tip>
+
+        If the model is served by a server supporting chat-completion, the method will directly call the server's
+        `/v1/chat/completions` endpoint. If the server does not support chat-completion, the method will render the
+        chat template client-side based on the information fetched from the Hub API. In this case, you will need to
+        have `minijinja` template engine installed. Run `pip install "huggingface_hub[inference]"` or `pip install minijinja`
+        to install it.
+
+        </Tip>
+
+        Args:
+            messages (List[Union[`SystemMessage`, `UserMessage`, `AssistantMessage`]]):
+                Conversation history consisting of roles and content pairs.
+            model (`str`, *optional*):
+                The model to use for chat-completion. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
+                Inference Endpoint. If not provided, the default recommended model for chat-based text-generation will be used.
+                See https://huggingface.co/tasks/text-generation for more details.
+            frequency_penalty (`float`, optional):
+                Penalizes new tokens based on their existing frequency
+                in the text so far. Range: [-2.0, 2.0]. Defaults to 0.0.
+            max_tokens (`int`, optional):
+                Maximum number of tokens allowed in the response. Defaults to 20.
+            seed (Optional[`int`], optional):
+                Seed for reproducible control flow. Defaults to None.
+            stop (Optional[`str`], optional):
+                Up to four strings which trigger the end of the response.
+                Defaults to None.
+            stream (`bool`, optional):
+                Enable realtime streaming of responses. Defaults to False.
+            temperature (`float`, optional):
+                Controls randomness of the generations. Lower values ensure
+                less random completions. Range: [0, 2]. Defaults to 1.0.
+            top_p (`float`, optional):
+                Fraction of the most likely next words to sample from.
+                Must be between 0 and 1. Defaults to 1.0.
+
+        Returns:
+            `Union[ChatCompletionOutput, Iterable[ChatCompletionStreamOutput]]`:
+            Generated text returned from the server:
+            - if `stream=False`, the generated text is returned as a [`ChatCompletionOutput`] (default).
+            - if `stream=True`, the generated text is returned token by token as a sequence of [`ChatCompletionStreamOutput`].
+
+        Raises:
+            [`InferenceTimeoutError`]:
+                If the model is unavailable or the request times out.
+            `aiohttp.ClientResponseError`:
+                If the request fails with an HTTP error status code other than HTTP 503.
+
+        Example:
+        ```py
+        # Must be run in an async context
+        >>> from huggingface_hub import AsyncInferenceClient
+        >>> messages = [{"role": "user", "content": "What is the capital of France?"}]
+        >>> client = AsyncInferenceClient("HuggingFaceH4/zephyr-7b-beta")
+        >>> await client.chat_completion(messages, max_tokens=100)
+        ChatCompletionOutput(
+            choices=[
+                ChatCompletionOutputChoice(
+                    finish_reason='eos_token',
+                    index=0,
+                    message=ChatCompletionOutputChoiceMessage(
+                        content='The capital of France is Paris. The official name of the city is "Ville de Paris" (City of Paris) and the name of the country\'s governing body, which is located in Paris, is "La République française" (The French Republic). \nI hope that helps! Let me know if you need any further information.'
+                    )
+                )
+            ],
+            created=1710498360
+        )
+
+        >>> async for token in await client.chat_completion(messages, max_tokens=10, stream=True):
+        ...     print(token)
+        ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content='The', role='assistant'), index=0, finish_reason=None)], created=1710498504)
+        ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content=' capital', role='assistant'), index=0, finish_reason=None)], created=1710498504)
+        (...)
+        ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content=' may', role='assistant'), index=0, finish_reason=None)], created=1710498504)
+        ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content=None, role=None), index=0, finish_reason='length')], created=1710498504)
+        ```
+        """
+        # determine model
+        model = model or self.model or self.get_recommended_model("text-generation")
+
+        if _is_chat_completion_server(model):
+            # First, let's consider the server has a `/v1/chat/completions` endpoint.
+            # If that's the case, we don't have to render the chat template client-side.
+            model_url = self._resolve_url(model) + "/v1/chat/completions"
+
+            try:
+                data = await self.post(
+                    model=model_url,
+                    json=dict(
+                        model="tgi",  # random string
+                        messages=messages,
+                        max_tokens=max_tokens,
+                        seed=seed,
+                        stop=stop,
+                        temperature=temperature,
+                        top_p=top_p,
+                        stream=stream,
+                    ),
+                    stream=stream,
+                )
+            except _import_aiohttp().ClientResponseError:
+                # Let's consider the server is not a chat completion server.
+                # Then we call again `chat_completion` which will render the chat template client side.
+                # (can be HTTP 500, HTTP 400, HTTP 404 depending on the server)
+                _set_as_non_chat_completion_server(model)
+                return await self.chat_completion(
+                    messages=messages,
+                    model=model,
+                    stream=stream,
+                    max_tokens=max_tokens,
+                    seed=seed,
+                    stop=stop,
+                    temperature=temperature,
+                    top_p=top_p,
+                )
+
+            if stream:
+                return _async_stream_chat_completion_response_from_bytes(data)  # type: ignore[arg-type]
+
+            return ChatCompletionOutput.parse_obj_as_instance(data)  # type: ignore[arg-type]
+
+        # At this point, we know the server is not a chat completion server.
+        # We need to render the chat template client side based on the information we can fetch from
+        # the Hub API.
+
+        model_id = None
+        if model.startswith(("http://", "https://")):
+            # If URL, we need to know which model is served. This is not always possible.
+            # A workaround is to list the user Inference Endpoints and check if one of them correspond to the model URL.
+            # If not, we raise an error.
+            # TODO: fix when we have a proper API for this (at least for Inference Endpoints)
+            # TODO: what if Sagemaker URL?
+            # TODO: what if Azure URL?
+            from ..hf_api import HfApi
+
+            for endpoint in HfApi(token=self.token).list_inference_endpoints():
+                if endpoint.url == model:
+                    model_id = endpoint.repository
+                    break
+        else:
+            model_id = model
+
+        if model_id is None:
+            # If we don't have the model ID, we can't fetch the chat template.
+            # We raise an error.
+            raise ValueError(
+                "Request can't be processed as the model ID can't be inferred from model URL. "
+                "This is needed to fetch the chat template from the Hub since the model is not "
+                "served with a Chat-completion API."
+            )
+
+        # fetch chat template + tokens
+        prompt = render_chat_prompt(model_id=model_id, token=self.token, messages=messages)
+
+        # generate response
+        stop_sequences = [stop] if isinstance(stop, str) else stop
+        text_generation_output = await self.text_generation(
+            prompt=prompt,
+            details=True,
+            stream=stream,
+            model=model,
+            max_new_tokens=max_tokens,
+            seed=seed,
+            stop_sequences=stop_sequences,
+            temperature=temperature,
+            top_p=top_p,
+        )
+
+        created = int(time.time())
+
+        if stream:
+            return _async_stream_chat_completion_response_from_text_generation(text_generation_output)  # type: ignore [arg-type]
+
+        if isinstance(text_generation_output, TextGenerationOutput):
+            # General use case => format ChatCompletionOutput from text generation details
+            content: str = text_generation_output.generated_text
+            finish_reason: str = text_generation_output.details.finish_reason  # type: ignore[union-attr]
+        else:
+            # Corner case: if server doesn't support details (e.g. if not a TGI server), we only receive an output string.
+            # In such a case, `finish_reason` is set to `"unk"`.
+            content = text_generation_output  # type: ignore[assignment]
+            finish_reason = "unk"
+
+        return ChatCompletionOutput(
+            created=created,
+            choices=[
+                ChatCompletionOutputChoice(
+                    finish_reason=finish_reason,  # type: ignore
+                    index=0,
+                    message=ChatCompletionOutputChoiceMessage(
+                        content=content,
+                        role="assistant",
+                    ),
+                )
+            ],
+        )

     async def conversational(
         self,
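Beyond `chat_completion`, the recurring pattern in the remaining hunks is that each task method now parses the raw server bytes into one of the generated dataclasses via `parse_obj_as_instance`/`parse_obj_as_list` instead of returning plain dicts or lists. A rough before/after sketch for callers, using `audio_classification` as in the docstring above (the field names are the ones shown in the diff; the concrete scores and labels depend on the model):

```py
# Sketch: typed outputs in 0.22.0 mean attribute access instead of dict lookups.
import asyncio
from huggingface_hub import AsyncInferenceClient

async def main() -> None:
    client = AsyncInferenceClient()
    results = await client.audio_classification("audio.flac")
    for element in results:
        # 0.21.x: element["label"], element["score"] (list of dicts)
        # 0.22.0: AudioClassificationOutputElement dataclass attributes
        print(f"{element.label}: {element.score:.3f}")

asyncio.run(main())
```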
@@ -390,6 +674,13 @@ class AsyncInferenceClient:
         """
         Generate conversational responses based on the given input text (i.e. chat with the API).

+        <Tip warning={true}>
+
+        [`InferenceClient.conversational`] API is deprecated and will be removed in a future release. Please use
+        [`InferenceClient.chat_completion`] instead.
+
+        </Tip>
+
         Args:
             text (`str`):
                 The last input from the user in the conversation.
@@ -430,6 +721,11 @@ class AsyncInferenceClient:
        ... )
        ```
        """
+        warnings.warn(
+            "'InferenceClient.conversational' is deprecated and will be removed starting from huggingface_hub>=0.25. "
+            "Please use the more appropriate 'InferenceClient.chat_completion' API instead.",
+            FutureWarning,
+        )
         payload: Dict[str, Any] = {"inputs": {"text": text}}
         if generated_responses is not None:
             payload["inputs"]["generated_responses"] = generated_responses
@@ -440,58 +736,13 @@ class AsyncInferenceClient:
         response = await self.post(json=payload, model=model, task="conversational")
         return _bytes_to_dict(response)  # type: ignore

-    async def visual_question_answering(
-        self,
-        image: ContentT,
-        question: str,
-        *,
-        model: Optional[str] = None,
-    ) -> List[str]:
-        """
-        Answering open-ended questions based on an image.
-
-        Args:
-            image (`Union[str, Path, bytes, BinaryIO]`):
-                The input image for the context. It can be raw bytes, an image file, or a URL to an online image.
-            question (`str`):
-                Question to be answered.
-            model (`str`, *optional*):
-                The model to use for the visual question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to
-                a deployed Inference Endpoint. If not provided, the default recommended visual question answering model will be used.
-                Defaults to None.
-
-        Returns:
-            `List[Dict]`: a list of dictionaries containing the predicted label and associated probability.
-
-        Raises:
-            `InferenceTimeoutError`:
-                If the model is unavailable or the request times out.
-            `aiohttp.ClientResponseError`:
-                If the request fails with an HTTP error status code other than HTTP 503.
-
-        Example:
-        ```py
-        # Must be run in an async context
-        >>> from huggingface_hub import AsyncInferenceClient
-        >>> client = AsyncInferenceClient()
-        >>> await client.visual_question_answering(
-        ...     image="https://huggingface.co/datasets/mishig/sample_images/resolve/main/tiger.jpg",
-        ...     question="What is the animal doing?"
-        ... )
-        [{'score': 0.778609573841095, 'answer': 'laying down'},{'score': 0.6957435607910156, 'answer': 'sitting'}, ...]
-        ```
-        """
-        payload: Dict[str, Any] = {"question": question, "image": _b64_encode(image)}
-        response = await self.post(json=payload, model=model, task="visual-question-answering")
-        return _bytes_to_list(response)
-
     async def document_question_answering(
         self,
         image: ContentT,
         question: str,
         *,
         model: Optional[str] = None,
-    ) -> List[
+    ) -> List[DocumentQuestionAnsweringOutputElement]:
         """
         Answer questions on document images.

@@ -506,7 +757,7 @@ class AsyncInferenceClient:
                 Defaults to None.

         Returns:
-            `List[
+            `List[DocumentQuestionAnsweringOutputElement]`: a list of [`DocumentQuestionAnsweringOutputElement`] items containing the predicted label, associated probability, word ids, and page number.

         Raises:
             [`InferenceTimeoutError`]:
@@ -520,12 +771,12 @@ class AsyncInferenceClient:
        >>> from huggingface_hub import AsyncInferenceClient
        >>> client = AsyncInferenceClient()
        >>> await client.document_question_answering(image="https://huggingface.co/spaces/impira/docquery/resolve/2359223c1837a7587402bda0f2643382a6eefeab/invoice.png", question="What is the invoice number?")
-        [
+        [DocumentQuestionAnsweringOutputElement(score=0.42515629529953003, answer='us-001', start=16, end=16)]
        ```
        """
        payload: Dict[str, Any] = {"question": question, "image": _b64_encode(image)}
        response = await self.post(json=payload, model=model, task="document-question-answering")
-        return
+        return DocumentQuestionAnsweringOutputElement.parse_obj_as_list(response)

     async def feature_extraction(self, text: str, *, model: Optional[str] = None) -> "np.ndarray":
         """
@@ -564,7 +815,7 @@ class AsyncInferenceClient:
         np = _import_numpy()
         return np.array(_bytes_to_dict(response), dtype="float32")

-    async def fill_mask(self, text: str, *, model: Optional[str] = None) -> List[
+    async def fill_mask(self, text: str, *, model: Optional[str] = None) -> List[FillMaskOutputElement]:
         """
         Fill in a hole with a missing word (token to be precise).

@@ -577,7 +828,7 @@ class AsyncInferenceClient:
                 Defaults to None.

         Returns:
-            `List[
+            `List[FillMaskOutputElement]`: a list of [`FillMaskOutputElement`] items containing the predicted label, associated
             probability, token reference, and completed text.

         Raises:
@@ -592,25 +843,21 @@ class AsyncInferenceClient:
        >>> from huggingface_hub import AsyncInferenceClient
        >>> client = AsyncInferenceClient()
        >>> await client.fill_mask("The goal of life is <mask>.")
-        [
-
-
-
-        {'score': 0.06554922461509705,
-         'token': 45075,
-         'token_str': ' immortality',
-         'sequence': 'The goal of life is immortality.'}]
+        [
+            FillMaskOutputElement(score=0.06897063553333282, token=11098, token_str=' happiness', sequence='The goal of life is happiness.'),
+            FillMaskOutputElement(score=0.06554922461509705, token=45075, token_str=' immortality', sequence='The goal of life is immortality.')
+        ]
        ```
        """
        response = await self.post(json={"inputs": text}, model=model, task="fill-mask")
-        return
+        return FillMaskOutputElement.parse_obj_as_list(response)

     async def image_classification(
         self,
         image: ContentT,
         *,
         model: Optional[str] = None,
-    ) -> List[
+    ) -> List[ImageClassificationOutputElement]:
         """
         Perform image classification on the given image using the specified model.

@@ -622,7 +869,7 @@ class AsyncInferenceClient:
                 deployed Inference Endpoint. If not provided, the default recommended model for image classification will be used.

         Returns:
-            `List[
+            `List[ImageClassificationOutputElement]`: a list of [`ImageClassificationOutputElement`] items containing the predicted label and associated probability.

         Raises:
             [`InferenceTimeoutError`]:
@@ -636,18 +883,18 @@ class AsyncInferenceClient:
        >>> from huggingface_hub import AsyncInferenceClient
        >>> client = AsyncInferenceClient()
        >>> await client.image_classification("https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg")
-        [
+        [ImageClassificationOutputElement(score=0.9779096841812134, label='Blenheim spaniel'), ...]
        ```
        """
        response = await self.post(data=image, model=model, task="image-classification")
-        return
+        return ImageClassificationOutputElement.parse_obj_as_list(response)

     async def image_segmentation(
         self,
         image: ContentT,
         *,
         model: Optional[str] = None,
-    ) -> List[
+    ) -> List[ImageSegmentationOutputElement]:
         """
         Perform image segmentation on the given image using the specified model.

@@ -665,7 +912,7 @@ class AsyncInferenceClient:
                 deployed Inference Endpoint. If not provided, the default recommended model for image segmentation will be used.

         Returns:
-            `List[
+            `List[ImageSegmentationOutputElement]`: A list of [`ImageSegmentationOutputElement`] items containing the segmented masks and associated attributes.

         Raises:
             [`InferenceTimeoutError`]:
@@ -679,19 +926,13 @@ class AsyncInferenceClient:
        >>> from huggingface_hub import AsyncInferenceClient
        >>> client = AsyncInferenceClient()
        >>> await client.image_segmentation("cat.jpg"):
-        [
+        [ImageSegmentationOutputElement(score=0.989008, label='LABEL_184', mask=<PIL.PngImagePlugin.PngImageFile image mode=L size=400x300 at 0x7FDD2B129CC0>), ...]
        ```
        """
-
-        # Segment
        response = await self.post(data=image, model=model, task="image-segmentation")
-        output =
-
-        # Parse masks as PIL Image
-        if not isinstance(output, list):
-            raise ValueError(f"Server output must be a list. Got {type(output)}: {str(output)[:200]}...")
+        output = ImageSegmentationOutputElement.parse_obj_as_list(response)
        for item in output:
-            item
+            item.mask = _b64_to_image(item.mask)
        return output

     async def image_to_image(
@@ -779,7 +1020,7 @@ class AsyncInferenceClient:
         response = await self.post(json=payload, data=data, model=model, task="image-to-image")
         return _bytes_to_image(response)

-    async def image_to_text(self, image: ContentT, *, model: Optional[str] = None) ->
+    async def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -> ImageToTextOutput:
         """
         Takes an input image and return text.

@@ -794,7 +1035,7 @@ class AsyncInferenceClient:
                 Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.

         Returns:
-            `
+            [`ImageToTextOutput`]: The generated text.

         Raises:
             [`InferenceTimeoutError`]:
@@ -814,7 +1055,7 @@ class AsyncInferenceClient:
        ```
        """
        response = await self.post(data=image, model=model, task="image-to-text")
-        return
+        return ImageToTextOutput.parse_obj_as_instance(response)

     async def list_deployed_models(
         self, frameworks: Union[None, str, Literal["all"], List[str]] = None
@@ -902,7 +1143,7 @@ class AsyncInferenceClient:
         image: ContentT,
         *,
         model: Optional[str] = None,
-    ) -> List[
+    ) -> List[ObjectDetectionOutputElement]:
         """
         Perform object detection on the given image using the specified model.

@@ -920,7 +1161,7 @@ class AsyncInferenceClient:
                 deployed Inference Endpoint. If not provided, the default recommended model for object detection (DETR) will be used.

         Returns:
-            `List[
+            `List[ObjectDetectionOutputElement]`: A list of [`ObjectDetectionOutputElement`] items containing the bounding boxes and associated attributes.

         Raises:
             [`InferenceTimeoutError`]:
@@ -936,19 +1177,16 @@ class AsyncInferenceClient:
        >>> from huggingface_hub import AsyncInferenceClient
        >>> client = AsyncInferenceClient()
        >>> await client.object_detection("people.jpg"):
-        [
+        [ObjectDetectionOutputElement(score=0.9486683011054993, label='person', box=ObjectDetectionBoundingBox(xmin=59, ymin=39, xmax=420, ymax=510)), ...]
        ```
        """
        # detect objects
        response = await self.post(data=image, model=model, task="object-detection")
-
-        if not isinstance(output, list):
-            raise ValueError(f"Server output must be a list. Got {type(output)}: {str(output)[:200]}...")
-        return output
+        return ObjectDetectionOutputElement.parse_obj_as_list(response)

     async def question_answering(
         self, question: str, context: str, *, model: Optional[str] = None
-    ) ->
+    ) -> QuestionAnsweringOutputElement:
         """
         Retrieve the answer to a question from a given text.

@@ -962,7 +1200,7 @@ class AsyncInferenceClient:
                 a deployed Inference Endpoint.

         Returns:
-            `
+            [`QuestionAnsweringOutputElement`]: an question answering output containing the score, start index, end index, and answer.

         Raises:
             [`InferenceTimeoutError`]:
@@ -976,7 +1214,7 @@ class AsyncInferenceClient:
        >>> from huggingface_hub import AsyncInferenceClient
        >>> client = AsyncInferenceClient()
        >>> await client.question_answering(question="What's my name?", context="My name is Clara and I live in Berkeley.")
-
+        QuestionAnsweringOutputElement(score=0.9326562285423279, start=11, end=16, answer='Clara')
        ```
        """

@@ -986,7 +1224,7 @@ class AsyncInferenceClient:
             model=model,
             task="question-answering",
         )
-        return
+        return QuestionAnsweringOutputElement.parse_obj_as_instance(response)

     async def sentence_similarity(
         self, sentence: str, other_sentences: List[str], *, model: Optional[str] = None
@@ -1042,7 +1280,7 @@ class AsyncInferenceClient:
         *,
         parameters: Optional[Dict[str, Any]] = None,
         model: Optional[str] = None,
-    ) ->
+    ) -> SummarizationOutput:
         """
         Generate a summary of a given text using a specified model.

@@ -1057,7 +1295,7 @@ class AsyncInferenceClient:
                 Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.

         Returns:
-            `
+            [`SummarizationOutput`]: The generated summary text.

         Raises:
             [`InferenceTimeoutError`]:
@@ -1071,18 +1309,18 @@ class AsyncInferenceClient:
        >>> from huggingface_hub import AsyncInferenceClient
        >>> client = AsyncInferenceClient()
        >>> await client.summarization("The Eiffel tower...")
-
+        SummarizationOutput(generated_text="The Eiffel tower is one of the most famous landmarks in the world....")
        ```
        """
        payload: Dict[str, Any] = {"inputs": text}
        if parameters is not None:
            payload["parameters"] = parameters
        response = await self.post(json=payload, model=model, task="summarization")
-        return
+        return SummarizationOutput.parse_obj_as_list(response)[0]

     async def table_question_answering(
         self, table: Dict[str, Any], query: str, *, model: Optional[str] = None
-    ) ->
+    ) -> TableQuestionAnsweringOutputElement:
         """
         Retrieve the answer to a question from information given in a table.

@@ -1097,7 +1335,7 @@ class AsyncInferenceClient:
                 Hub or a URL to a deployed Inference Endpoint.

         Returns:
-            `
+            [`TableQuestionAnsweringOutputElement`]: a table question answering output containing the answer, coordinates, cells and the aggregator used.

         Raises:
             [`InferenceTimeoutError`]:
@@ -1113,7 +1351,7 @@ class AsyncInferenceClient:
        >>> query = "How many stars does the transformers repository have?"
        >>> table = {"Repository": ["Transformers", "Datasets", "Tokenizers"], "Stars": ["36542", "4512", "3934"]}
        >>> await client.table_question_answering(table, query, model="google/tapas-base-finetuned-wtq")
-
+        TableQuestionAnsweringOutputElement(answer='36542', coordinates=[[0, 1]], cells=['36542'], aggregator='AVERAGE')
        ```
        """
        response = await self.post(
@@ -1124,7 +1362,7 @@ class AsyncInferenceClient:
             model=model,
             task="table-question-answering",
         )
-        return
+        return TableQuestionAnsweringOutputElement.parse_obj_as_instance(response)

     async def tabular_classification(self, table: Dict[str, Any], *, model: Optional[str] = None) -> List[str]:
         """
@@ -1213,7 +1451,9 @@ class AsyncInferenceClient:
         response = await self.post(json={"table": table}, model=model, task="tabular-regression")
         return _bytes_to_list(response)

-    async def text_classification(
+    async def text_classification(
+        self, text: str, *, model: Optional[str] = None
+    ) -> List[TextClassificationOutputElement]:
         """
         Perform text classification (e.g. sentiment-analysis) on the given text.

@@ -1226,7 +1466,7 @@ class AsyncInferenceClient:
                 Defaults to None.

         Returns:
-            `List[
+            `List[TextClassificationOutputElement]`: a list of [`TextClassificationOutputElement`] items containing the predicted label and associated probability.

         Raises:
             [`InferenceTimeoutError`]:
@@ -1240,11 +1480,14 @@ class AsyncInferenceClient:
        >>> from huggingface_hub import AsyncInferenceClient
        >>> client = AsyncInferenceClient()
        >>> await client.text_classification("I like you")
-        [
+        [
+            TextClassificationOutputElement(label='POSITIVE', score=0.9998695850372314),
+            TextClassificationOutputElement(label='NEGATIVE', score=0.0001304351753788069),
+        ]
        ```
        """
        response = await self.post(json={"inputs": text}, model=model, task="text-classification")
-        return
+        return TextClassificationOutputElement.parse_obj_as_list(response)[0]  # type: ignore [return-value]

     @overload
     async def text_generation(  # type: ignore
@@ -1267,8 +1510,7 @@ class AsyncInferenceClient:
|
|
|
1267
1510
|
truncate: Optional[int] = None,
|
|
1268
1511
|
typical_p: Optional[float] = None,
|
|
1269
1512
|
watermark: bool = False,
|
|
1270
|
-
) -> str:
|
|
1271
|
-
...
|
|
1513
|
+
) -> str: ...
|
|
1272
1514
|
|
|
1273
1515
|
@overload
|
|
1274
1516
|
async def text_generation( # type: ignore
|
|
@@ -1291,8 +1533,7 @@ class AsyncInferenceClient:
|
|
|
1291
1533
|
truncate: Optional[int] = None,
|
|
1292
1534
|
typical_p: Optional[float] = None,
|
|
1293
1535
|
watermark: bool = False,
|
|
1294
|
-
) ->
|
|
1295
|
-
...
|
|
1536
|
+
) -> TextGenerationOutput: ...
|
|
1296
1537
|
|
|
1297
1538
|
@overload
|
|
1298
1539
|
async def text_generation( # type: ignore
|
|
@@ -1315,11 +1556,10 @@ class AsyncInferenceClient:
|
|
|
1315
1556
|
truncate: Optional[int] = None,
|
|
1316
1557
|
typical_p: Optional[float] = None,
|
|
1317
1558
|
watermark: bool = False,
|
|
1318
|
-
) -> AsyncIterable[str]:
|
|
1319
|
-
...
|
|
1559
|
+
) -> AsyncIterable[str]: ...
|
|
1320
1560
|
|
|
1321
1561
|
@overload
|
|
1322
|
-
async def text_generation(
|
|
1562
|
+
async def text_generation( # type: ignore
|
|
1323
1563
|
self,
|
|
1324
1564
|
prompt: str,
|
|
1325
1565
|
*,
|
|
@@ -1339,8 +1579,30 @@ class AsyncInferenceClient:
|
|
|
1339
1579
|
truncate: Optional[int] = None,
|
|
1340
1580
|
typical_p: Optional[float] = None,
|
|
1341
1581
|
watermark: bool = False,
|
|
1342
|
-
) -> AsyncIterable[
|
|
1343
|
-
|
|
1582
|
+
) -> AsyncIterable[TextGenerationStreamOutput]: ...
|
|
1583
|
+
|
|
1584
|
+
@overload
|
|
1585
|
+
async def text_generation(
|
|
1586
|
+
self,
|
|
1587
|
+
prompt: str,
|
|
1588
|
+
*,
|
|
1589
|
+
details: Literal[True] = ...,
|
|
1590
|
+
stream: bool = ...,
|
|
1591
|
+
model: Optional[str] = None,
|
|
1592
|
+
do_sample: bool = False,
|
|
1593
|
+
max_new_tokens: int = 20,
|
|
1594
|
+
best_of: Optional[int] = None,
|
|
1595
|
+
repetition_penalty: Optional[float] = None,
|
|
1596
|
+
return_full_text: bool = False,
|
|
1597
|
+
seed: Optional[int] = None,
|
|
1598
|
+
stop_sequences: Optional[List[str]] = None,
|
|
1599
|
+
temperature: Optional[float] = None,
|
|
1600
|
+
top_k: Optional[int] = None,
|
|
1601
|
+
top_p: Optional[float] = None,
|
|
1602
|
+
truncate: Optional[int] = None,
|
|
1603
|
+
typical_p: Optional[float] = None,
|
|
1604
|
+
watermark: bool = False,
|
|
1605
|
+
) -> Union[TextGenerationOutput, AsyncIterable[TextGenerationStreamOutput]]: ...
|
|
1344
1606
|
|
|
1345
1607
|
async def text_generation(
|
|
1346
1608
|
self,
|
|
@@ -1363,13 +1625,10 @@ class AsyncInferenceClient:
|
|
|
1363
1625
|
typical_p: Optional[float] = None,
|
|
1364
1626
|
watermark: bool = False,
|
|
1365
1627
|
decoder_input_details: bool = False,
|
|
1366
|
-
) -> Union[str,
|
|
1628
|
+
) -> Union[str, TextGenerationOutput, AsyncIterable[str], AsyncIterable[TextGenerationStreamOutput]]:
|
|
1367
1629
|
"""
|
|
1368
1630
|
Given a prompt, generate the following text.
|
|
1369
1631
|
|
|
1370
|
-
It is recommended to have Pydantic installed in order to get inputs validated. This is preferable as it allow
|
|
1371
|
-
early failures.
|
|
1372
|
-
|
|
1373
1632
|
API endpoint is supposed to run with the `text-generation-inference` backend (TGI). This backend is the
|
|
1374
1633
|
go-to solution to run large language models at scale. However, for some smaller models (e.g. "gpt2") the
|
|
1375
1634
|
default `transformers` + `api-inference` solution is still in use. Both approaches have very similar APIs, but
|
|
@@ -1427,12 +1686,12 @@ class AsyncInferenceClient:
|
|
|
1427
1686
|
into account. Defaults to `False`.
|
|
1428
1687
|
|
|
1429
1688
|
Returns:
|
|
1430
|
-
`Union[str,
|
|
1689
|
+
`Union[str, TextGenerationOutput, Iterable[str], Iterable[TextGenerationStreamOutput]]`:
|
|
1431
1690
|
Generated text returned from the server:
|
|
1432
1691
|
- if `stream=False` and `details=False`, the generated text is returned as a `str` (default)
|
|
1433
1692
|
- if `stream=True` and `details=False`, the generated text is returned token by token as a `Iterable[str]`
|
|
1434
|
-
- if `stream=False` and `details=True`, the generated text is returned with more details as a [`~huggingface_hub.
|
|
1435
|
-
- if `details=True` and `stream=True`, the generated text is returned token by token as a iterable of [`~huggingface_hub.
|
|
1693
|
+
- if `stream=False` and `details=True`, the generated text is returned with more details as a [`~huggingface_hub.TextGenerationOutput`]
|
|
1694
|
+
- if `details=True` and `stream=True`, the generated text is returned token by token as a iterable of [`~huggingface_hub.TextGenerationStreamOutput`]
|
|
1436
1695
|
|
|
1437
1696
|
Raises:
|
|
1438
1697
|
`ValidationError`:
|
|
@@ -1470,23 +1729,23 @@ class AsyncInferenceClient:
|
|
|
1470
1729
|
|
|
1471
1730
|
# Case 3: get more details about the generation process.
|
|
1472
1731
|
>>> await client.text_generation("The huggingface_hub library is ", max_new_tokens=12, details=True)
|
|
1473
|
-
|
|
1732
|
+
TextGenerationOutput(
|
|
1474
1733
|
generated_text='100% open source and built to be easy to use.',
|
|
1475
|
-
details=
|
|
1476
|
-
finish_reason
|
|
1734
|
+
details=TextGenerationDetails(
|
|
1735
|
+
finish_reason='length',
|
|
1477
1736
|
generated_tokens=12,
|
|
1478
1737
|
seed=None,
|
|
1479
1738
|
prefill=[
|
|
1480
|
-
|
|
1481
|
-
|
|
1739
|
+
TextGenerationPrefillToken(id=487, text='The', logprob=None),
|
|
1740
|
+
TextGenerationPrefillToken(id=53789, text=' hugging', logprob=-13.171875),
|
|
1482
1741
|
(...)
|
|
1483
|
-
|
|
1742
|
+
TextGenerationPrefillToken(id=204, text=' ', logprob=-7.0390625)
|
|
1484
1743
|
],
|
|
1485
1744
|
tokens=[
|
|
1486
|
-
|
|
1487
|
-
|
|
1745
|
+
TokenElement(id=1425, text='100', logprob=-1.0175781, special=False),
|
|
1746
|
+
TokenElement(id=16, text='%', logprob=-0.0463562, special=False),
|
|
1488
1747
|
(...)
|
|
1489
|
-
|
|
1748
|
+
TokenElement(id=25, text='.', logprob=-0.5703125, special=False)
|
|
1490
1749
|
],
|
|
1491
1750
|
best_of_sequences=None
|
|
1492
1751
|
)
|
|
@@ -1497,30 +1756,27 @@ class AsyncInferenceClient:
 >>> async for details in await client.text_generation("The huggingface_hub library is ", max_new_tokens=12, details=True, stream=True):
 ...     print(details)
 ...
+TextGenerationStreamOutput(token=TokenElement(id=1425, text='100', logprob=-1.0175781, special=False), generated_text=None, details=None)
+TextGenerationStreamOutput(token=TokenElement(id=16, text='%', logprob=-0.0463562, special=False), generated_text=None, details=None)
+TextGenerationStreamOutput(token=TokenElement(id=1314, text=' open', logprob=-1.3359375, special=False), generated_text=None, details=None)
+TextGenerationStreamOutput(token=TokenElement(id=3178, text=' source', logprob=-0.28100586, special=False), generated_text=None, details=None)
+TextGenerationStreamOutput(token=TokenElement(id=273, text=' and', logprob=-0.5961914, special=False), generated_text=None, details=None)
+TextGenerationStreamOutput(token=TokenElement(id=3426, text=' built', logprob=-1.9423828, special=False), generated_text=None, details=None)
+TextGenerationStreamOutput(token=TokenElement(id=271, text=' to', logprob=-1.4121094, special=False), generated_text=None, details=None)
+TextGenerationStreamOutput(token=TokenElement(id=314, text=' be', logprob=-1.5224609, special=False), generated_text=None, details=None)
+TextGenerationStreamOutput(token=TokenElement(id=1833, text=' easy', logprob=-2.1132812, special=False), generated_text=None, details=None)
+TextGenerationStreamOutput(token=TokenElement(id=271, text=' to', logprob=-0.08520508, special=False), generated_text=None, details=None)
+TextGenerationStreamOutput(token=TokenElement(id=745, text=' use', logprob=-0.39453125, special=False), generated_text=None, details=None)
+TextGenerationStreamOutput(token=TokenElement(
     id=25,
     text='.',
     logprob=-0.5703125,
     special=False),
     generated_text='100% open source and built to be easy to use.',
-    details=
+    details=TextGenerationStreamDetails(finish_reason='length', generated_tokens=12, seed=None)
 )
 ```
 """
-        # NOTE: Text-generation integration is taken from the text-generation-inference project. It has more features
-        # like input/output validation (if Pydantic is installed). See `_text_generation.py` header for more details.
         if decoder_input_details and not details:
             warnings.warn(
                 "`decoder_input_details=True` has been passed to the server but `details=False` is set meaning that"
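A matching sketch for the streaming path documented above, iterating the typed stream objects (same environment assumptions as before):

```py
# Sketch only: iterates TextGenerationStreamOutput chunks as shown in the docstring.
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient()
    async for chunk in await client.text_generation(
        "The huggingface_hub library is ",
        max_new_tokens=12,
        details=True,
        stream=True,
    ):
        # Each chunk is a TextGenerationStreamOutput; only the final chunk carries
        # generated_text and TextGenerationStreamDetails.
        print(chunk.token.text, end="", flush=True)
        if chunk.details is not None:
            print("\nfinish_reason:", chunk.details.finish_reason)


asyncio.run(main())
```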
@@ -1528,34 +1784,38 @@ class AsyncInferenceClient:
             )
             decoder_input_details = False

-        #
+        # Build payload
+        payload = {
+            "inputs": prompt,
+            "parameters": {
+                "best_of": best_of,
+                "decoder_input_details": decoder_input_details,
+                "details": details,
+                "do_sample": do_sample,
+                "max_new_tokens": max_new_tokens,
+                "repetition_penalty": repetition_penalty,
+                "return_full_text": return_full_text,
+                "seed": seed,
+                "stop": stop_sequences if stop_sequences is not None else [],
+                "temperature": temperature,
+                "top_k": top_k,
+                "top_p": top_p,
+                "truncate": truncate,
+                "typical_p": typical_p,
+                "watermark": watermark,
+            },
+            "stream": stream,
+        }

         # Remove some parameters if not a TGI server
         if not _is_tgi_server(model):
+            parameters: Dict = payload["parameters"]  # type: ignore [assignment]
+
             ignored_parameters = []
-            for key in "watermark", "
-                if
+            for key in "watermark", "details", "decoder_input_details", "best_of", "stop", "return_full_text":
+                if parameters[key] is not None:
                     ignored_parameters.append(key)
-                    del
+                    del parameters[key]
             if len(ignored_parameters) > 0:
                 warnings.warn(
                     "API endpoint/model for text-generation is not served via TGI. Ignoring parameters"
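For illustration, the payload built by the new code above for a call such as `text_generation("Hello", max_new_tokens=4, temperature=0.7)` would look roughly as follows. This is a hand-written sketch, not library output, and it assumes the signature's documented defaults (unspecified parameters stay None or False):

```py
# Rough shape of the request body produced by the payload-building code above
# (illustrative values; exact defaults depend on the text_generation signature).
payload = {
    "inputs": "Hello",
    "parameters": {
        "best_of": None,
        "decoder_input_details": False,
        "details": False,
        "do_sample": False,
        "max_new_tokens": 4,
        "repetition_penalty": None,
        "return_full_text": False,
        "seed": None,
        "stop": [],
        "temperature": 0.7,
        "top_k": None,
        "top_p": None,
        "truncate": None,
        "typical_p": None,
        "watermark": False,
    },
    "stream": False,
}
```

On a non-TGI endpoint, the keys listed in the loop above (`watermark`, `details`, `decoder_input_details`, `best_of`, `stop`, `return_full_text`) are stripped from this dict before the request is sent.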
@@ -1608,8 +1868,8 @@ class AsyncInferenceClient:
         if stream:
             return _async_stream_text_generation_response(bytes_output, details)  # type: ignore

-        data = _bytes_to_dict(bytes_output)[0]
-        return
+        data = _bytes_to_dict(bytes_output)[0]  # type: ignore[arg-type]
+        return TextGenerationOutput.parse_obj_as_instance(data) if details else data["generated_text"]

     async def text_to_image(
         self,
@@ -1725,7 +1985,9 @@ class AsyncInferenceClient:
         """
         return await self.post(json={"inputs": text}, model=model, task="text-to-speech")

-    async def token_classification(
+    async def token_classification(
+        self, text: str, *, model: Optional[str] = None
+    ) -> List[TokenClassificationOutputElement]:
         """
         Perform token classification on the given text.
         Usually used for sentence parsing, either grammatical, or Named Entity Recognition (NER) to understand keywords contained within text.
@@ -1739,7 +2001,7 @@ class AsyncInferenceClient:
                 Defaults to None.

         Returns:
-            `List[
+            `List[TokenClassificationOutputElement]`: List of [`TokenClassificationOutputElement`] items containing the entity group, confidence score, word, start and end index.

         Raises:
             [`InferenceTimeoutError`]:
@@ -1753,16 +2015,22 @@ class AsyncInferenceClient:
        >>> from huggingface_hub import AsyncInferenceClient
        >>> client = AsyncInferenceClient()
        >>> await client.token_classification("My name is Sarah Jessica Parker but you can call me Jessica")
-       [
+       [
+           TokenClassificationOutputElement(
+               entity_group='PER',
+               score=0.9971321225166321,
+               word='Sarah Jessica Parker',
+               start=11,
+               end=31,
+           ),
+           TokenClassificationOutputElement(
+               entity_group='PER',
+               score=0.9773476123809814,
+               word='Jessica',
+               start=52,
+               end=59,
+           )
+       ]
        ```
        """
        payload: Dict[str, Any] = {"inputs": text}
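A minimal sketch of calling the now-typed `token_classification` helper and reading the fields documented in the hunk above:

```py
# Sketch only: prints the typed fields of TokenClassificationOutputElement.
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient()
    entities = await client.token_classification(
        "My name is Sarah Jessica Parker but you can call me Jessica"
    )
    for entity in entities:
        # Each item is a TokenClassificationOutputElement.
        print(entity.entity_group, entity.word, round(entity.score, 3), entity.start, entity.end)


asyncio.run(main())
```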
@@ -1771,11 +2039,11 @@ class AsyncInferenceClient:
             model=model,
             task="token-classification",
         )
-        return
+        return TokenClassificationOutputElement.parse_obj_as_list(response)

     async def translation(
         self, text: str, *, model: Optional[str] = None, src_lang: Optional[str] = None, tgt_lang: Optional[str] = None
-    ) ->
+    ) -> TranslationOutput:
        """
        Convert text from one language to another.

@@ -1798,7 +2066,7 @@ class AsyncInferenceClient:
                Target language of the translation task, i.e. output language. Cannot be passed without `src_lang`.

        Returns:
-           `
+           [`TranslationOutput`]: The generated translated text.

        Raises:
            [`InferenceTimeoutError`]:
@@ -1816,7 +2084,7 @@ class AsyncInferenceClient:
        >>> await client.translation("My name is Wolfgang and I live in Berlin")
        'Mein Name ist Wolfgang und ich lebe in Berlin.'
        >>> await client.translation("My name is Wolfgang and I live in Berlin", model="Helsinki-NLP/opus-mt-en-fr")
+       TranslationOutput(translation_text='Je m\'appelle Wolfgang et je vis à Berlin.')
        ```

        Specifying languages:
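A sketch of the `translation` call from the docstring above, showing the new `TranslationOutput` return type (the Helsinki-NLP model is the one used in the example; its availability is assumed):

```py
# Sketch only: translation() now returns a TranslationOutput instead of a raw string.
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient()
    result = await client.translation(
        "My name is Wolfgang and I live in Berlin",
        model="Helsinki-NLP/opus-mt-en-fr",
    )
    print(result.translation_text)


asyncio.run(main())
```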
@@ -1837,11 +2105,59 @@ class AsyncInferenceClient:
         if src_lang and tgt_lang:
             payload["parameters"] = {"src_lang": src_lang, "tgt_lang": tgt_lang}
         response = await self.post(json=payload, model=model, task="translation")
-        return
+        return TranslationOutput.parse_obj_as_list(response)[0]
+
+    async def visual_question_answering(
+        self,
+        image: ContentT,
+        question: str,
+        *,
+        model: Optional[str] = None,
+    ) -> List[VisualQuestionAnsweringOutputElement]:
+        """
+        Answering open-ended questions based on an image.
+
+        Args:
+            image (`Union[str, Path, bytes, BinaryIO]`):
+                The input image for the context. It can be raw bytes, an image file, or a URL to an online image.
+            question (`str`):
+                Question to be answered.
+            model (`str`, *optional*):
+                The model to use for the visual question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to
+                a deployed Inference Endpoint. If not provided, the default recommended visual question answering model will be used.
+                Defaults to None.
+
+        Returns:
+            `List[VisualQuestionAnsweringOutputElement]`: a list of [`VisualQuestionAnsweringOutputElement`] items containing the predicted label and associated probability.
+
+        Raises:
+            `InferenceTimeoutError`:
+                If the model is unavailable or the request times out.
+            `aiohttp.ClientResponseError`:
+                If the request fails with an HTTP error status code other than HTTP 503.
+
+        Example:
+        ```py
+        # Must be run in an async context
+        >>> from huggingface_hub import AsyncInferenceClient
+        >>> client = AsyncInferenceClient()
+        >>> await client.visual_question_answering(
+        ...     image="https://huggingface.co/datasets/mishig/sample_images/resolve/main/tiger.jpg",
+        ...     question="What is the animal doing?"
+        ... )
+        [
+            VisualQuestionAnsweringOutputElement(score=0.778609573841095, answer='laying down'),
+            VisualQuestionAnsweringOutputElement(score=0.6957435607910156, answer='sitting'),
+        ]
+        ```
+        """
+        payload: Dict[str, Any] = {"question": question, "image": _b64_encode(image)}
+        response = await self.post(json=payload, model=model, task="visual-question-answering")
+        return VisualQuestionAnsweringOutputElement.parse_obj_as_list(response)

     async def zero_shot_classification(
         self, text: str, labels: List[str], *, multi_label: bool = False, model: Optional[str] = None
-    ) -> List[
+    ) -> List[ZeroShotClassificationOutputElement]:
        """
        Provide as input a text and a set of candidate labels to classify the input text.

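A short sketch exercising the new `visual_question_answering` helper added in the hunk above (image URL and question are the ones from its docstring example):

```py
# Sketch only: calls the new visual_question_answering method.
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient()
    answers = await client.visual_question_answering(
        image="https://huggingface.co/datasets/mishig/sample_images/resolve/main/tiger.jpg",
        question="What is the animal doing?",
    )
    for answer in answers:
        # Each item is a VisualQuestionAnsweringOutputElement.
        print(f"{answer.answer}: {answer.score:.3f}")


asyncio.run(main())
```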
@@ -1857,7 +2173,7 @@ class AsyncInferenceClient:
                Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.

        Returns:
-           `List[
+           `List[ZeroShotClassificationOutputElement]`: List of [`ZeroShotClassificationOutputElement`] items containing the predicted labels and their confidence.

        Raises:
            [`InferenceTimeoutError`]:
@@ -1878,19 +2194,19 @@ class AsyncInferenceClient:
        >>> labels = ["space & cosmos", "scientific discovery", "microbiology", "robots", "archeology"]
        >>> await client.zero_shot_classification(text, labels)
        [
+           ZeroShotClassificationOutputElement(label='scientific discovery', score=0.7961668968200684),
+           ZeroShotClassificationOutputElement(label='space & cosmos', score=0.18570658564567566),
+           ZeroShotClassificationOutputElement(label='microbiology', score=0.00730885099619627),
+           ZeroShotClassificationOutputElement(label='archeology', score=0.006258360575884581),
+           ZeroShotClassificationOutputElement(label='robots', score=0.004559356719255447),
        ]
        >>> await client.zero_shot_classification(text, labels, multi_label=True)
        [
+           ZeroShotClassificationOutputElement(label='scientific discovery', score=0.9829297661781311),
+           ZeroShotClassificationOutputElement(label='space & cosmos', score=0.755190908908844),
+           ZeroShotClassificationOutputElement(label='microbiology', score=0.0005462635890580714),
+           ZeroShotClassificationOutputElement(label='archeology', score=0.00047131875180639327),
+           ZeroShotClassificationOutputElement(label='robots', score=0.00030448526376858354),
        ]
        ```
        """
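A sketch of the typed `zero_shot_classification` call shown above; the input sentence here is an arbitrary illustrative example (the docstring's own `text` variable is defined outside this hunk), while the labels are the documented ones:

```py
# Sketch only: each result is a ZeroShotClassificationOutputElement.
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient()
    text = "Astronomers report new evidence about how the outer planets formed."  # illustrative input
    labels = ["space & cosmos", "scientific discovery", "microbiology", "robots", "archeology"]
    results = await client.zero_shot_classification(text, labels, multi_label=True)
    for item in results:
        print(f"{item.label}: {item.score:.4f}")


asyncio.run(main())
```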
@@ -1910,11 +2226,14 @@ class AsyncInferenceClient:
             task="zero-shot-classification",
         )
         output = _bytes_to_dict(response)
-        return [
+        return [
+            ZeroShotClassificationOutputElement.parse_obj_as_instance({"label": label, "score": score})
+            for label, score in zip(output["labels"], output["scores"])
+        ]

     async def zero_shot_image_classification(
         self, image: ContentT, labels: List[str], *, model: Optional[str] = None
-    ) -> List[
+    ) -> List[ZeroShotImageClassificationOutputElement]:
        """
        Provide input image and text labels to predict text labels for the image.

@@ -1928,7 +2247,7 @@ class AsyncInferenceClient:
                Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.

        Returns:
-           `List[
+           `List[ZeroShotImageClassificationOutputElement]`: List of [`ZeroShotImageClassificationOutputElement`] items containing the predicted labels and their confidence.

        Raises:
            [`InferenceTimeoutError`]:
@@ -1946,7 +2265,7 @@ class AsyncInferenceClient:
        ...     "https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg",
        ...     labels=["dog", "cat", "horse"],
        ... )
-       [
+       [ZeroShotImageClassificationOutputElement(label='dog', score=0.956),...]
        ```
        """
        # Raise ValueError if input is less than 2 labels
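A matching sketch for `zero_shot_image_classification`, using the image URL and labels from the docstring example above:

```py
# Sketch only: each result is a ZeroShotImageClassificationOutputElement.
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient()
    results = await client.zero_shot_image_classification(
        "https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg",
        labels=["dog", "cat", "horse"],
    )
    for item in results:
        print(f"{item.label}: {item.score:.3f}")


asyncio.run(main())
```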
@@ -1958,7 +2277,7 @@ class AsyncInferenceClient:
             model=model,
             task="zero-shot-image-classification",
         )
-        return
+        return ZeroShotImageClassificationOutputElement.parse_obj_as_list(response)

     def _resolve_url(self, model: Optional[str] = None, task: Optional[str] = None) -> str:
         model = model or self.model