huggingface-hub 0.21.4__py3-none-any.whl → 0.22.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release: this version of huggingface-hub has been flagged as potentially problematic.
- huggingface_hub/__init__.py +217 -1
- huggingface_hub/_commit_api.py +14 -15
- huggingface_hub/_inference_endpoints.py +12 -11
- huggingface_hub/_login.py +1 -0
- huggingface_hub/_multi_commits.py +1 -0
- huggingface_hub/_snapshot_download.py +9 -1
- huggingface_hub/_tensorboard_logger.py +1 -0
- huggingface_hub/_webhooks_payload.py +1 -0
- huggingface_hub/_webhooks_server.py +1 -0
- huggingface_hub/commands/_cli_utils.py +1 -0
- huggingface_hub/commands/delete_cache.py +1 -0
- huggingface_hub/commands/download.py +1 -0
- huggingface_hub/commands/env.py +1 -0
- huggingface_hub/commands/scan_cache.py +1 -0
- huggingface_hub/commands/upload.py +1 -0
- huggingface_hub/community.py +1 -0
- huggingface_hub/constants.py +3 -1
- huggingface_hub/errors.py +38 -0
- huggingface_hub/file_download.py +102 -95
- huggingface_hub/hf_api.py +47 -35
- huggingface_hub/hf_file_system.py +77 -3
- huggingface_hub/hub_mixin.py +215 -54
- huggingface_hub/inference/_client.py +554 -239
- huggingface_hub/inference/_common.py +195 -41
- huggingface_hub/inference/_generated/_async_client.py +558 -239
- huggingface_hub/inference/_generated/types/__init__.py +115 -0
- huggingface_hub/inference/_generated/types/audio_classification.py +43 -0
- huggingface_hub/inference/_generated/types/audio_to_audio.py +31 -0
- huggingface_hub/inference/_generated/types/automatic_speech_recognition.py +116 -0
- huggingface_hub/inference/_generated/types/base.py +149 -0
- huggingface_hub/inference/_generated/types/chat_completion.py +106 -0
- huggingface_hub/inference/_generated/types/depth_estimation.py +29 -0
- huggingface_hub/inference/_generated/types/document_question_answering.py +85 -0
- huggingface_hub/inference/_generated/types/feature_extraction.py +19 -0
- huggingface_hub/inference/_generated/types/fill_mask.py +50 -0
- huggingface_hub/inference/_generated/types/image_classification.py +43 -0
- huggingface_hub/inference/_generated/types/image_segmentation.py +52 -0
- huggingface_hub/inference/_generated/types/image_to_image.py +55 -0
- huggingface_hub/inference/_generated/types/image_to_text.py +105 -0
- huggingface_hub/inference/_generated/types/object_detection.py +55 -0
- huggingface_hub/inference/_generated/types/question_answering.py +77 -0
- huggingface_hub/inference/_generated/types/sentence_similarity.py +28 -0
- huggingface_hub/inference/_generated/types/summarization.py +46 -0
- huggingface_hub/inference/_generated/types/table_question_answering.py +45 -0
- huggingface_hub/inference/_generated/types/text2text_generation.py +45 -0
- huggingface_hub/inference/_generated/types/text_classification.py +43 -0
- huggingface_hub/inference/_generated/types/text_generation.py +161 -0
- huggingface_hub/inference/_generated/types/text_to_audio.py +105 -0
- huggingface_hub/inference/_generated/types/text_to_image.py +57 -0
- huggingface_hub/inference/_generated/types/token_classification.py +53 -0
- huggingface_hub/inference/_generated/types/translation.py +46 -0
- huggingface_hub/inference/_generated/types/video_classification.py +47 -0
- huggingface_hub/inference/_generated/types/visual_question_answering.py +53 -0
- huggingface_hub/inference/_generated/types/zero_shot_classification.py +56 -0
- huggingface_hub/inference/_generated/types/zero_shot_image_classification.py +51 -0
- huggingface_hub/inference/_generated/types/zero_shot_object_detection.py +55 -0
- huggingface_hub/inference/_templating.py +105 -0
- huggingface_hub/inference/_types.py +4 -152
- huggingface_hub/keras_mixin.py +39 -17
- huggingface_hub/lfs.py +20 -8
- huggingface_hub/repocard.py +11 -3
- huggingface_hub/repocard_data.py +12 -2
- huggingface_hub/serialization/__init__.py +1 -0
- huggingface_hub/serialization/_base.py +1 -0
- huggingface_hub/serialization/_numpy.py +1 -0
- huggingface_hub/serialization/_tensorflow.py +1 -0
- huggingface_hub/serialization/_torch.py +1 -0
- huggingface_hub/utils/__init__.py +4 -1
- huggingface_hub/utils/_cache_manager.py +7 -0
- huggingface_hub/utils/_chunk_utils.py +1 -0
- huggingface_hub/utils/_datetime.py +1 -0
- huggingface_hub/utils/_errors.py +10 -1
- huggingface_hub/utils/_experimental.py +1 -0
- huggingface_hub/utils/_fixes.py +19 -3
- huggingface_hub/utils/_git_credential.py +1 -0
- huggingface_hub/utils/_headers.py +10 -3
- huggingface_hub/utils/_hf_folder.py +1 -0
- huggingface_hub/utils/_http.py +1 -0
- huggingface_hub/utils/_pagination.py +1 -0
- huggingface_hub/utils/_paths.py +1 -0
- huggingface_hub/utils/_runtime.py +22 -0
- huggingface_hub/utils/_subprocess.py +1 -0
- huggingface_hub/utils/_token.py +1 -0
- huggingface_hub/utils/_typing.py +29 -1
- huggingface_hub/utils/_validators.py +1 -0
- huggingface_hub/utils/endpoint_helpers.py +1 -0
- huggingface_hub/utils/logging.py +1 -1
- huggingface_hub/utils/sha.py +1 -0
- huggingface_hub/utils/tqdm.py +1 -0
- {huggingface_hub-0.21.4.dist-info → huggingface_hub-0.22.0.dist-info}/METADATA +14 -15
- huggingface_hub-0.22.0.dist-info/RECORD +113 -0
- {huggingface_hub-0.21.4.dist-info → huggingface_hub-0.22.0.dist-info}/WHEEL +1 -1
- huggingface_hub/inference/_text_generation.py +0 -551
- huggingface_hub-0.21.4.dist-info/RECORD +0 -81
- {huggingface_hub-0.21.4.dist-info → huggingface_hub-0.22.0.dist-info}/LICENSE +0 -0
- {huggingface_hub-0.21.4.dist-info → huggingface_hub-0.22.0.dist-info}/entry_points.txt +0 -0
- {huggingface_hub-0.21.4.dist-info → huggingface_hub-0.22.0.dist-info}/top_level.txt +0 -0
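
The diff below covers `huggingface_hub/inference/_client.py`. The headline changes in 0.22.0 are the new `InferenceClient.chat_completion()` method, the deprecation of `InferenceClient.conversational()`, and the switch from raw dicts to generated output dataclasses (e.g. `AudioClassificationOutputElement`, `TextGenerationOutput`). For orientation, here is a minimal usage sketch assembled from the docstrings visible in the diff; the model name is illustrative and behavior may differ on servers without a chat-completion endpoint:

    # Minimal sketch of the 0.22.0 chat-completion API, based on the docstrings in the diff below.
    from huggingface_hub import InferenceClient

    client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")  # illustrative model
    messages = [{"role": "user", "content": "What is the capital of France?"}]

    # Non-streaming call: returns a ChatCompletionOutput dataclass instead of a raw dict.
    output = client.chat_completion(messages, max_tokens=100)
    print(output.choices[0].message.content)

    # Streaming call: yields ChatCompletionStreamOutput chunks token by token.
    for chunk in client.chat_completion(messages, max_tokens=10, stream=True):
        print(chunk.choices[0].delta.content or "", end="")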
@@ -23,7 +23,6 @@
 # https://github.com/huggingface/unity-api#tasks
 #
 # Some TODO:
-# - validate inputs/options/parameters? with Pydantic for instance? or only optionally?
 # - add all tasks
 #
 # NOTE: the philosophy of this client is "let's make it as easy as possible to use it, even if less optimized". Some
@@ -37,7 +36,6 @@ import base64
 import logging
 import time
 import warnings
-from dataclasses import asdict
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -54,10 +52,10 @@ from requests import HTTPError
 from requests.structures import CaseInsensitiveDict

 from huggingface_hub.constants import ALL_INFERENCE_API_FRAMEWORKS, INFERENCE_ENDPOINT, MAIN_INFERENCE_API_FRAMEWORKS
+from huggingface_hub.errors import InferenceTimeoutError
 from huggingface_hub.inference._common import (
     TASKS_EXPECTING_IMAGES,
     ContentT,
-    InferenceTimeoutError,
     ModelStatus,
     _b64_encode,
     _b64_to_image,
@@ -66,28 +64,45 @@ from huggingface_hub.inference._common import (
     _bytes_to_list,
     _fetch_recommended_models,
     _import_numpy,
+    _is_chat_completion_server,
     _is_tgi_server,
     _open_as_binary,
+    _set_as_non_chat_completion_server,
     _set_as_non_tgi,
+    _stream_chat_completion_response_from_bytes,
+    _stream_chat_completion_response_from_text_generation,
     _stream_text_generation_response,
-)
-from huggingface_hub.inference._text_generation import (
-    TextGenerationParameters,
-    TextGenerationRequest,
-    TextGenerationResponse,
-    TextGenerationStreamResponse,
     raise_text_generation_error,
 )
+from huggingface_hub.inference._generated.types import (
+    AudioClassificationOutputElement,
+    AudioToAudioOutputElement,
+    AutomaticSpeechRecognitionOutput,
+    ChatCompletionOutput,
+    ChatCompletionOutputChoice,
+    ChatCompletionOutputChoiceMessage,
+    ChatCompletionStreamOutput,
+    DocumentQuestionAnsweringOutputElement,
+    FillMaskOutputElement,
+    ImageClassificationOutputElement,
+    ImageSegmentationOutputElement,
+    ImageToTextOutput,
+    ObjectDetectionOutputElement,
+    QuestionAnsweringOutputElement,
+    SummarizationOutput,
+    TableQuestionAnsweringOutputElement,
+    TextClassificationOutputElement,
+    TextGenerationOutput,
+    TextGenerationStreamOutput,
+    TokenClassificationOutputElement,
+    TranslationOutput,
+    VisualQuestionAnsweringOutputElement,
+    ZeroShotClassificationOutputElement,
+    ZeroShotImageClassificationOutputElement,
+)
+from huggingface_hub.inference._templating import render_chat_prompt
 from huggingface_hub.inference._types import (
-
-    ClassificationOutput,
-    ConversationalOutput,
-    FillMaskOutput,
-    ImageSegmentationOutput,
-    ObjectDetectionOutput,
-    QuestionAnsweringOutput,
-    TableQuestionAnsweringOutput,
-    TokenClassificationOutput,
+    ConversationalOutput,  # soon to be removed
 )
 from huggingface_hub.utils import (
     BadRequestError,
@@ -116,9 +131,9 @@ class InferenceClient:
             The model to run inference with. Can be a model id hosted on the Hugging Face Hub, e.g. `bigcode/starcoder`
             or a URL to a deployed Inference Endpoint. Defaults to None, in which case a recommended model is
             automatically selected for the task.
-        token (`str`, *optional*):
-            Hugging Face token. Will default to the locally saved token
-            your token to the server.
+        token (`str` or `bool`, *optional*):
+            Hugging Face token. Will default to the locally saved token if not provided.
+            Pass `token=False` if you don't want to send your token to the server.
         timeout (`float`, `optional`):
             The maximum number of seconds to wait for a response from the server. Loading a new model in Inference
             API can take up to several minutes. Defaults to None, meaning it will loop until the server is available.
@@ -138,6 +153,7 @@ class InferenceClient:
         cookies: Optional[Dict[str, str]] = None,
     ) -> None:
         self.model: Optional[str] = model
+        self.token: Union[str, bool, None] = token
         self.headers = CaseInsensitiveDict(build_hf_headers(token=token))  # contains 'authorization' + 'user-agent'
         if headers is not None:
             self.headers.update(headers)
@@ -156,11 +172,10 @@ class InferenceClient:
         model: Optional[str] = None,
         task: Optional[str] = None,
         stream: Literal[False] = ...,
-    ) -> bytes:
-        pass
+    ) -> bytes: ...

     @overload
-    def post(
+    def post(  # type: ignore[misc]
         self,
         *,
         json: Optional[Union[str, Dict, List]] = None,
@@ -168,8 +183,18 @@ class InferenceClient:
         model: Optional[str] = None,
         task: Optional[str] = None,
         stream: Literal[True] = ...,
-    ) -> Iterable[bytes]:
-
+    ) -> Iterable[bytes]: ...
+
+    @overload
+    def post(
+        self,
+        *,
+        json: Optional[Union[str, Dict, List]] = None,
+        data: Optional[ContentT] = None,
+        model: Optional[str] = None,
+        task: Optional[str] = None,
+        stream: bool = False,
+    ) -> Union[bytes, Iterable[bytes]]: ...

     def post(
         self,
@@ -268,7 +293,7 @@ class InferenceClient:
         audio: ContentT,
         *,
         model: Optional[str] = None,
-    ) -> List[
+    ) -> List[AudioClassificationOutputElement]:
         """
         Perform audio classification on the provided audio content.

@@ -282,7 +307,7 @@ class InferenceClient:
                 audio classification will be used.

         Returns:
-            `List[
+            `List[AudioClassificationOutputElement]`: List of [`AudioClassificationOutputElement`] items containing the predicted labels and their confidence.

         Raises:
             [`InferenceTimeoutError`]:
@@ -295,18 +320,22 @@ class InferenceClient:
         >>> from huggingface_hub import InferenceClient
         >>> client = InferenceClient()
         >>> client.audio_classification("audio.flac")
-        [
+        [
+            AudioClassificationOutputElement(score=0.4976358711719513, label='hap'),
+            AudioClassificationOutputElement(score=0.3677836060523987, label='neu'),
+            ...
+        ]
         ```
         """
         response = self.post(data=audio, model=model, task="audio-classification")
-        return
+        return AudioClassificationOutputElement.parse_obj_as_list(response)

     def audio_to_audio(
         self,
         audio: ContentT,
         *,
         model: Optional[str] = None,
-    ) -> List[
+    ) -> List[AudioToAudioOutputElement]:
         """
         Performs multiple tasks related to audio-to-audio depending on the model (eg: speech enhancement, source separation).

@@ -320,7 +349,7 @@ class InferenceClient:
                 audio_to_audio will be used.

         Returns:
-            `List[
+            `List[AudioToAudioOutputElement]`: A list of [`AudioToAudioOutputElement`] items containing audios label, content-type, and audio content in blob.

         Raises:
             `InferenceTimeoutError`:
@@ -335,13 +364,13 @@ class InferenceClient:
         >>> audio_output = client.audio_to_audio("audio.flac")
         >>> for i, item in enumerate(audio_output):
         >>>     with open(f"output_{i}.flac", "wb") as f:
-                    f.write(item
+                    f.write(item.blob)
         ```
         """
         response = self.post(data=audio, model=model, task="audio-to-audio")
-        audio_output =
+        audio_output = AudioToAudioOutputElement.parse_obj_as_list(response)
         for item in audio_output:
-            item
+            item.blob = base64.b64decode(item.blob)
         return audio_output

     def automatic_speech_recognition(
@@ -349,7 +378,7 @@ class InferenceClient:
         audio: ContentT,
         *,
         model: Optional[str] = None,
-    ) ->
+    ) -> AutomaticSpeechRecognitionOutput:
         """
         Perform automatic speech recognition (ASR or audio-to-text) on the given audio content.

@@ -361,7 +390,7 @@ class InferenceClient:
                 Inference Endpoint. If not provided, the default recommended model for ASR will be used.

         Returns:
-
+            [`AutomaticSpeechRecognitionOutput`]: An item containing the transcribed text and optionally the timestamp chunks.

         Raises:
             [`InferenceTimeoutError`]:
@@ -373,12 +402,265 @@ class InferenceClient:
         ```py
         >>> from huggingface_hub import InferenceClient
         >>> client = InferenceClient()
-        >>> client.automatic_speech_recognition("hello_world.flac")
+        >>> client.automatic_speech_recognition("hello_world.flac").text
         "hello world"
         ```
         """
         response = self.post(data=audio, model=model, task="automatic-speech-recognition")
-        return
+        return AutomaticSpeechRecognitionOutput.parse_obj_as_instance(response)
+
+    @overload
+    def chat_completion(  # type: ignore
+        self,
+        messages: List[Dict[str, str]],
+        *,
+        model: Optional[str] = None,
+        stream: Literal[False] = False,
+        max_tokens: int = 20,
+        seed: Optional[int] = None,
+        stop: Optional[Union[List[str], str]] = None,
+        temperature: float = 1.0,
+        top_p: Optional[float] = None,
+    ) -> ChatCompletionOutput: ...
+
+    @overload
+    def chat_completion(  # type: ignore
+        self,
+        messages: List[Dict[str, str]],
+        *,
+        model: Optional[str] = None,
+        stream: Literal[True] = True,
+        max_tokens: int = 20,
+        seed: Optional[int] = None,
+        stop: Optional[Union[List[str], str]] = None,
+        temperature: float = 1.0,
+        top_p: Optional[float] = None,
+    ) -> Iterable[ChatCompletionStreamOutput]: ...
+
+    @overload
+    def chat_completion(
+        self,
+        messages: List[Dict[str, str]],
+        *,
+        model: Optional[str] = None,
+        stream: bool = False,
+        max_tokens: int = 20,
+        seed: Optional[int] = None,
+        stop: Optional[Union[List[str], str]] = None,
+        temperature: float = 1.0,
+        top_p: Optional[float] = None,
+    ) -> Union[ChatCompletionOutput, Iterable[ChatCompletionStreamOutput]]: ...
+
+    def chat_completion(
+        self,
+        messages: List[Dict[str, str]],
+        *,
+        model: Optional[str] = None,
+        stream: bool = False,
+        max_tokens: int = 20,
+        seed: Optional[int] = None,
+        stop: Optional[Union[List[str], str]] = None,
+        temperature: float = 1.0,
+        top_p: Optional[float] = None,
+    ) -> Union[ChatCompletionOutput, Iterable[ChatCompletionStreamOutput]]:
+        """
+        A method for completing conversations using a specified language model.
+
+        <Tip>
+
+        If the model is served by a server supporting chat-completion, the method will directly call the server's
+        `/v1/chat/completions` endpoint. If the server does not support chat-completion, the method will render the
+        chat template client-side based on the information fetched from the Hub API. In this case, you will need to
+        have `minijinja` template engine installed. Run `pip install "huggingface_hub[inference]"` or `pip install minijinja`
+        to install it.
+
+        </Tip>
+
+        Args:
+            messages (List[Union[`SystemMessage`, `UserMessage`, `AssistantMessage`]]):
+                Conversation history consisting of roles and content pairs.
+            model (`str`, *optional*):
+                The model to use for chat-completion. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
+                Inference Endpoint. If not provided, the default recommended model for chat-based text-generation will be used.
+                See https://huggingface.co/tasks/text-generation for more details.
+            frequency_penalty (`float`, optional):
+                Penalizes new tokens based on their existing frequency
+                in the text so far. Range: [-2.0, 2.0]. Defaults to 0.0.
+            max_tokens (`int`, optional):
+                Maximum number of tokens allowed in the response. Defaults to 20.
+            seed (Optional[`int`], optional):
+                Seed for reproducible control flow. Defaults to None.
+            stop (Optional[`str`], optional):
+                Up to four strings which trigger the end of the response.
+                Defaults to None.
+            stream (`bool`, optional):
+                Enable realtime streaming of responses. Defaults to False.
+            temperature (`float`, optional):
+                Controls randomness of the generations. Lower values ensure
+                less random completions. Range: [0, 2]. Defaults to 1.0.
+            top_p (`float`, optional):
+                Fraction of the most likely next words to sample from.
+                Must be between 0 and 1. Defaults to 1.0.
+
+        Returns:
+            `Union[ChatCompletionOutput, Iterable[ChatCompletionStreamOutput]]`:
+            Generated text returned from the server:
+            - if `stream=False`, the generated text is returned as a [`ChatCompletionOutput`] (default).
+            - if `stream=True`, the generated text is returned token by token as a sequence of [`ChatCompletionStreamOutput`].
+
+        Raises:
+            [`InferenceTimeoutError`]:
+                If the model is unavailable or the request times out.
+            `HTTPError`:
+                If the request fails with an HTTP error status code other than HTTP 503.
+
+        Example:
+        ```py
+        >>> from huggingface_hub import InferenceClient
+        >>> messages = [{"role": "user", "content": "What is the capital of France?"}]
+        >>> client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
+        >>> client.chat_completion(messages, max_tokens=100)
+        ChatCompletionOutput(
+            choices=[
+                ChatCompletionOutputChoice(
+                    finish_reason='eos_token',
+                    index=0,
+                    message=ChatCompletionOutputChoiceMessage(
+                        content='The capital of France is Paris. The official name of the city is "Ville de Paris" (City of Paris) and the name of the country\'s governing body, which is located in Paris, is "La République française" (The French Republic). \nI hope that helps! Let me know if you need any further information.'
+                    )
+                )
+            ],
+            created=1710498360
+        )
+
+        >>> for token in client.chat_completion(messages, max_tokens=10, stream=True):
+        ...     print(token)
+        ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content='The', role='assistant'), index=0, finish_reason=None)], created=1710498504)
+        ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content=' capital', role='assistant'), index=0, finish_reason=None)], created=1710498504)
+        (...)
+        ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content=' may', role='assistant'), index=0, finish_reason=None)], created=1710498504)
+        ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content=None, role=None), index=0, finish_reason='length')], created=1710498504)
+        ```
+        """
+        # determine model
+        model = model or self.model or self.get_recommended_model("text-generation")
+
+        if _is_chat_completion_server(model):
+            # First, let's consider the server has a `/v1/chat/completions` endpoint.
+            # If that's the case, we don't have to render the chat template client-side.
+            model_url = self._resolve_url(model) + "/v1/chat/completions"
+
+            try:
+                data = self.post(
+                    model=model_url,
+                    json=dict(
+                        model="tgi",  # random string
+                        messages=messages,
+                        max_tokens=max_tokens,
+                        seed=seed,
+                        stop=stop,
+                        temperature=temperature,
+                        top_p=top_p,
+                        stream=stream,
+                    ),
+                    stream=stream,
+                )
+            except HTTPError:
+                # Let's consider the server is not a chat completion server.
+                # Then we call again `chat_completion` which will render the chat template client side.
+                # (can be HTTP 500, HTTP 400, HTTP 404 depending on the server)
+                _set_as_non_chat_completion_server(model)
+                return self.chat_completion(
+                    messages=messages,
+                    model=model,
+                    stream=stream,
+                    max_tokens=max_tokens,
+                    seed=seed,
+                    stop=stop,
+                    temperature=temperature,
+                    top_p=top_p,
+                )
+
+            if stream:
+                return _stream_chat_completion_response_from_bytes(data)  # type: ignore[arg-type]
+
+            return ChatCompletionOutput.parse_obj_as_instance(data)  # type: ignore[arg-type]
+
+        # At this point, we know the server is not a chat completion server.
+        # We need to render the chat template client side based on the information we can fetch from
+        # the Hub API.
+
+        model_id = None
+        if model.startswith(("http://", "https://")):
+            # If URL, we need to know which model is served. This is not always possible.
+            # A workaround is to list the user Inference Endpoints and check if one of them correspond to the model URL.
+            # If not, we raise an error.
+            # TODO: fix when we have a proper API for this (at least for Inference Endpoints)
+            # TODO: what if Sagemaker URL?
+            # TODO: what if Azure URL?
+            from ..hf_api import HfApi
+
+            for endpoint in HfApi(token=self.token).list_inference_endpoints():
+                if endpoint.url == model:
+                    model_id = endpoint.repository
+                    break
+        else:
+            model_id = model
+
+        if model_id is None:
+            # If we don't have the model ID, we can't fetch the chat template.
+            # We raise an error.
+            raise ValueError(
+                "Request can't be processed as the model ID can't be inferred from model URL. "
+                "This is needed to fetch the chat template from the Hub since the model is not "
+                "served with a Chat-completion API."
+            )
+
+        # fetch chat template + tokens
+        prompt = render_chat_prompt(model_id=model_id, token=self.token, messages=messages)
+
+        # generate response
+        stop_sequences = [stop] if isinstance(stop, str) else stop
+        text_generation_output = self.text_generation(
+            prompt=prompt,
+            details=True,
+            stream=stream,
+            model=model,
+            max_new_tokens=max_tokens,
+            seed=seed,
+            stop_sequences=stop_sequences,
+            temperature=temperature,
+            top_p=top_p,
+        )
+
+        created = int(time.time())
+
+        if stream:
+            return _stream_chat_completion_response_from_text_generation(text_generation_output)  # type: ignore [arg-type]
+
+        if isinstance(text_generation_output, TextGenerationOutput):
+            # General use case => format ChatCompletionOutput from text generation details
+            content: str = text_generation_output.generated_text
+            finish_reason: str = text_generation_output.details.finish_reason  # type: ignore[union-attr]
+        else:
+            # Corner case: if server doesn't support details (e.g. if not a TGI server), we only receive an output string.
+            # In such a case, `finish_reason` is set to `"unk"`.
+            content = text_generation_output  # type: ignore[assignment]
+            finish_reason = "unk"
+
+        return ChatCompletionOutput(
+            created=created,
+            choices=[
+                ChatCompletionOutputChoice(
+                    finish_reason=finish_reason,  # type: ignore
+                    index=0,
+                    message=ChatCompletionOutputChoiceMessage(
+                        content=content,
+                        role="assistant",
+                    ),
+                )
+            ],
+        )

     def conversational(
         self,
@@ -392,6 +674,13 @@ class InferenceClient:
         """
         Generate conversational responses based on the given input text (i.e. chat with the API).

+        <Tip warning={true}>
+
+        [`InferenceClient.conversational`] API is deprecated and will be removed in a future release. Please use
+        [`InferenceClient.chat_completion`] instead.
+
+        </Tip>
+
         Args:
             text (`str`):
                 The last input from the user in the conversation.
@@ -431,6 +720,11 @@ class InferenceClient:
         ...     )
         ```
         """
+        warnings.warn(
+            "'InferenceClient.conversational' is deprecated and will be removed starting from huggingface_hub>=0.25. "
+            "Please use the more appropriate 'InferenceClient.chat_completion' API instead.",
+            FutureWarning,
+        )
         payload: Dict[str, Any] = {"inputs": {"text": text}}
         if generated_responses is not None:
             payload["inputs"]["generated_responses"] = generated_responses
@@ -441,57 +735,13 @@ class InferenceClient:
         response = self.post(json=payload, model=model, task="conversational")
         return _bytes_to_dict(response)  # type: ignore

-    def visual_question_answering(
-        self,
-        image: ContentT,
-        question: str,
-        *,
-        model: Optional[str] = None,
-    ) -> List[str]:
-        """
-        Answering open-ended questions based on an image.
-
-        Args:
-            image (`Union[str, Path, bytes, BinaryIO]`):
-                The input image for the context. It can be raw bytes, an image file, or a URL to an online image.
-            question (`str`):
-                Question to be answered.
-            model (`str`, *optional*):
-                The model to use for the visual question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to
-                a deployed Inference Endpoint. If not provided, the default recommended visual question answering model will be used.
-                Defaults to None.
-
-        Returns:
-            `List[Dict]`: a list of dictionaries containing the predicted label and associated probability.
-
-        Raises:
-            `InferenceTimeoutError`:
-                If the model is unavailable or the request times out.
-            `HTTPError`:
-                If the request fails with an HTTP error status code other than HTTP 503.
-
-        Example:
-        ```py
-        >>> from huggingface_hub import InferenceClient
-        >>> client = InferenceClient()
-        >>> client.visual_question_answering(
-        ...     image="https://huggingface.co/datasets/mishig/sample_images/resolve/main/tiger.jpg",
-        ...     question="What is the animal doing?"
-        ... )
-        [{'score': 0.778609573841095, 'answer': 'laying down'},{'score': 0.6957435607910156, 'answer': 'sitting'}, ...]
-        ```
-        """
-        payload: Dict[str, Any] = {"question": question, "image": _b64_encode(image)}
-        response = self.post(json=payload, model=model, task="visual-question-answering")
-        return _bytes_to_list(response)
-
     def document_question_answering(
         self,
         image: ContentT,
         question: str,
         *,
         model: Optional[str] = None,
-    ) -> List[
+    ) -> List[DocumentQuestionAnsweringOutputElement]:
         """
         Answer questions on document images.

@@ -506,7 +756,7 @@ class InferenceClient:
                 Defaults to None.

         Returns:
-            `List[
+            `List[DocumentQuestionAnsweringOutputElement]`: a list of [`DocumentQuestionAnsweringOutputElement`] items containing the predicted label, associated probability, word ids, and page number.

         Raises:
             [`InferenceTimeoutError`]:
@@ -519,12 +769,12 @@ class InferenceClient:
         >>> from huggingface_hub import InferenceClient
         >>> client = InferenceClient()
         >>> client.document_question_answering(image="https://huggingface.co/spaces/impira/docquery/resolve/2359223c1837a7587402bda0f2643382a6eefeab/invoice.png", question="What is the invoice number?")
-        [
+        [DocumentQuestionAnsweringOutputElement(score=0.42515629529953003, answer='us-001', start=16, end=16)]
         ```
         """
         payload: Dict[str, Any] = {"question": question, "image": _b64_encode(image)}
         response = self.post(json=payload, model=model, task="document-question-answering")
-        return
+        return DocumentQuestionAnsweringOutputElement.parse_obj_as_list(response)

     def feature_extraction(self, text: str, *, model: Optional[str] = None) -> "np.ndarray":
         """
@@ -562,7 +812,7 @@ class InferenceClient:
         np = _import_numpy()
         return np.array(_bytes_to_dict(response), dtype="float32")

-    def fill_mask(self, text: str, *, model: Optional[str] = None) -> List[
+    def fill_mask(self, text: str, *, model: Optional[str] = None) -> List[FillMaskOutputElement]:
         """
         Fill in a hole with a missing word (token to be precise).

@@ -575,7 +825,7 @@ class InferenceClient:
                 Defaults to None.

         Returns:
-            `List[
+            `List[FillMaskOutputElement]`: a list of [`FillMaskOutputElement`] items containing the predicted label, associated
             probability, token reference, and completed text.

         Raises:
@@ -589,25 +839,21 @@ class InferenceClient:
         >>> from huggingface_hub import InferenceClient
         >>> client = InferenceClient()
         >>> client.fill_mask("The goal of life is <mask>.")
-        [
-
-
-
-        {'score': 0.06554922461509705,
-         'token': 45075,
-         'token_str': ' immortality',
-         'sequence': 'The goal of life is immortality.'}]
+        [
+            FillMaskOutputElement(score=0.06897063553333282, token=11098, token_str=' happiness', sequence='The goal of life is happiness.'),
+            FillMaskOutputElement(score=0.06554922461509705, token=45075, token_str=' immortality', sequence='The goal of life is immortality.')
+        ]
         ```
         """
         response = self.post(json={"inputs": text}, model=model, task="fill-mask")
-        return
+        return FillMaskOutputElement.parse_obj_as_list(response)

     def image_classification(
         self,
         image: ContentT,
         *,
         model: Optional[str] = None,
-    ) -> List[
+    ) -> List[ImageClassificationOutputElement]:
         """
         Perform image classification on the given image using the specified model.

@@ -619,7 +865,7 @@ class InferenceClient:
                 deployed Inference Endpoint. If not provided, the default recommended model for image classification will be used.

         Returns:
-            `List[
+            `List[ImageClassificationOutputElement]`: a list of [`ImageClassificationOutputElement`] items containing the predicted label and associated probability.

         Raises:
             [`InferenceTimeoutError`]:
@@ -632,18 +878,18 @@ class InferenceClient:
         >>> from huggingface_hub import InferenceClient
         >>> client = InferenceClient()
         >>> client.image_classification("https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg")
-        [
+        [ImageClassificationOutputElement(score=0.9779096841812134, label='Blenheim spaniel'), ...]
         ```
         """
         response = self.post(data=image, model=model, task="image-classification")
-        return
+        return ImageClassificationOutputElement.parse_obj_as_list(response)

     def image_segmentation(
         self,
         image: ContentT,
         *,
         model: Optional[str] = None,
-    ) -> List[
+    ) -> List[ImageSegmentationOutputElement]:
         """
         Perform image segmentation on the given image using the specified model.

@@ -661,7 +907,7 @@ class InferenceClient:
                 deployed Inference Endpoint. If not provided, the default recommended model for image segmentation will be used.

         Returns:
-            `List[
+            `List[ImageSegmentationOutputElement]`: A list of [`ImageSegmentationOutputElement`] items containing the segmented masks and associated attributes.

         Raises:
             [`InferenceTimeoutError`]:
@@ -674,19 +920,13 @@ class InferenceClient:
         >>> from huggingface_hub import InferenceClient
         >>> client = InferenceClient()
         >>> client.image_segmentation("cat.jpg"):
-        [
+        [ImageSegmentationOutputElement(score=0.989008, label='LABEL_184', mask=<PIL.PngImagePlugin.PngImageFile image mode=L size=400x300 at 0x7FDD2B129CC0>), ...]
         ```
         """
-
-        # Segment
         response = self.post(data=image, model=model, task="image-segmentation")
-        output =
-
-        # Parse masks as PIL Image
-        if not isinstance(output, list):
-            raise ValueError(f"Server output must be a list. Got {type(output)}: {str(output)[:200]}...")
+        output = ImageSegmentationOutputElement.parse_obj_as_list(response)
         for item in output:
-            item
+            item.mask = _b64_to_image(item.mask)
         return output

     def image_to_image(
@@ -773,7 +1013,7 @@ class InferenceClient:
         response = self.post(json=payload, data=data, model=model, task="image-to-image")
         return _bytes_to_image(response)

-    def image_to_text(self, image: ContentT, *, model: Optional[str] = None) ->
+    def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -> ImageToTextOutput:
         """
         Takes an input image and return text.

@@ -788,7 +1028,7 @@ class InferenceClient:
                 Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.

         Returns:
-            `
+            [`ImageToTextOutput`]: The generated text.

         Raises:
             [`InferenceTimeoutError`]:
@@ -807,7 +1047,7 @@ class InferenceClient:
         ```
         """
         response = self.post(data=image, model=model, task="image-to-text")
-        return
+        return ImageToTextOutput.parse_obj_as_instance(response)

     def list_deployed_models(
         self, frameworks: Union[None, str, Literal["all"], List[str]] = None
@@ -889,7 +1129,7 @@ class InferenceClient:
         image: ContentT,
         *,
         model: Optional[str] = None,
-    ) -> List[
+    ) -> List[ObjectDetectionOutputElement]:
         """
         Perform object detection on the given image using the specified model.

@@ -907,7 +1147,7 @@ class InferenceClient:
                 deployed Inference Endpoint. If not provided, the default recommended model for object detection (DETR) will be used.

         Returns:
-            `List[
+            `List[ObjectDetectionOutputElement]`: A list of [`ObjectDetectionOutputElement`] items containing the bounding boxes and associated attributes.

         Raises:
             [`InferenceTimeoutError`]:
@@ -922,19 +1162,16 @@ class InferenceClient:
         >>> from huggingface_hub import InferenceClient
         >>> client = InferenceClient()
         >>> client.object_detection("people.jpg"):
-        [
+        [ObjectDetectionOutputElement(score=0.9486683011054993, label='person', box=ObjectDetectionBoundingBox(xmin=59, ymin=39, xmax=420, ymax=510)), ...]
         ```
         """
         # detect objects
         response = self.post(data=image, model=model, task="object-detection")
-
-        if not isinstance(output, list):
-            raise ValueError(f"Server output must be a list. Got {type(output)}: {str(output)[:200]}...")
-        return output
+        return ObjectDetectionOutputElement.parse_obj_as_list(response)

     def question_answering(
         self, question: str, context: str, *, model: Optional[str] = None
-    ) ->
+    ) -> QuestionAnsweringOutputElement:
         """
         Retrieve the answer to a question from a given text.

@@ -948,7 +1185,7 @@ class InferenceClient:
                 a deployed Inference Endpoint.

         Returns:
-            `
+            [`QuestionAnsweringOutputElement`]: an question answering output containing the score, start index, end index, and answer.

         Raises:
             [`InferenceTimeoutError`]:
@@ -961,7 +1198,7 @@ class InferenceClient:
         >>> from huggingface_hub import InferenceClient
         >>> client = InferenceClient()
         >>> client.question_answering(question="What's my name?", context="My name is Clara and I live in Berkeley.")
-
+        QuestionAnsweringOutputElement(score=0.9326562285423279, start=11, end=16, answer='Clara')
         ```
         """

@@ -971,7 +1208,7 @@ class InferenceClient:
             model=model,
             task="question-answering",
         )
-        return
+        return QuestionAnsweringOutputElement.parse_obj_as_instance(response)

     def sentence_similarity(
         self, sentence: str, other_sentences: List[str], *, model: Optional[str] = None
@@ -1026,7 +1263,7 @@ class InferenceClient:
         *,
         parameters: Optional[Dict[str, Any]] = None,
         model: Optional[str] = None,
-    ) ->
+    ) -> SummarizationOutput:
         """
         Generate a summary of a given text using a specified model.

@@ -1041,7 +1278,7 @@ class InferenceClient:
                 Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.

         Returns:
-            `
+            [`SummarizationOutput`]: The generated summary text.

         Raises:
             [`InferenceTimeoutError`]:
@@ -1054,18 +1291,18 @@ class InferenceClient:
         >>> from huggingface_hub import InferenceClient
         >>> client = InferenceClient()
         >>> client.summarization("The Eiffel tower...")
-
+        SummarizationOutput(generated_text="The Eiffel tower is one of the most famous landmarks in the world....")
         ```
         """
         payload: Dict[str, Any] = {"inputs": text}
         if parameters is not None:
             payload["parameters"] = parameters
         response = self.post(json=payload, model=model, task="summarization")
-        return
+        return SummarizationOutput.parse_obj_as_list(response)[0]

     def table_question_answering(
         self, table: Dict[str, Any], query: str, *, model: Optional[str] = None
-    ) ->
+    ) -> TableQuestionAnsweringOutputElement:
         """
         Retrieve the answer to a question from information given in a table.

@@ -1080,7 +1317,7 @@ class InferenceClient:
                 Hub or a URL to a deployed Inference Endpoint.

         Returns:
-            `
+            [`TableQuestionAnsweringOutputElement`]: a table question answering output containing the answer, coordinates, cells and the aggregator used.

         Raises:
             [`InferenceTimeoutError`]:
@@ -1095,7 +1332,7 @@ class InferenceClient:
         >>> query = "How many stars does the transformers repository have?"
         >>> table = {"Repository": ["Transformers", "Datasets", "Tokenizers"], "Stars": ["36542", "4512", "3934"]}
         >>> client.table_question_answering(table, query, model="google/tapas-base-finetuned-wtq")
-
+        TableQuestionAnsweringOutputElement(answer='36542', coordinates=[[0, 1]], cells=['36542'], aggregator='AVERAGE')
         ```
         """
         response = self.post(
@@ -1106,7 +1343,7 @@ class InferenceClient:
             model=model,
             task="table-question-answering",
         )
-        return
+        return TableQuestionAnsweringOutputElement.parse_obj_as_instance(response)

     def tabular_classification(self, table: Dict[str, Any], *, model: Optional[str] = None) -> List[str]:
         """
@@ -1193,7 +1430,7 @@ class InferenceClient:
         response = self.post(json={"table": table}, model=model, task="tabular-regression")
         return _bytes_to_list(response)

-    def text_classification(self, text: str, *, model: Optional[str] = None) -> List[
+    def text_classification(self, text: str, *, model: Optional[str] = None) -> List[TextClassificationOutputElement]:
         """
         Perform text classification (e.g. sentiment-analysis) on the given text.

@@ -1206,7 +1443,7 @@ class InferenceClient:
                 Defaults to None.

         Returns:
-            `List[
+            `List[TextClassificationOutputElement]`: a list of [`TextClassificationOutputElement`] items containing the predicted label and associated probability.

         Raises:
             [`InferenceTimeoutError`]:
@@ -1219,11 +1456,14 @@ class InferenceClient:
         >>> from huggingface_hub import InferenceClient
         >>> client = InferenceClient()
         >>> client.text_classification("I like you")
-        [
+        [
+            TextClassificationOutputElement(label='POSITIVE', score=0.9998695850372314),
+            TextClassificationOutputElement(label='NEGATIVE', score=0.0001304351753788069),
+        ]
         ```
         """
         response = self.post(json={"inputs": text}, model=model, task="text-classification")
-        return
+        return TextClassificationOutputElement.parse_obj_as_list(response)[0]  # type: ignore [return-value]

     @overload
     def text_generation(  # type: ignore
@@ -1246,8 +1486,7 @@ class InferenceClient:
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
         watermark: bool = False,
-    ) -> str:
-        ...
+    ) -> str: ...

     @overload
     def text_generation(  # type: ignore
@@ -1270,8 +1509,7 @@ class InferenceClient:
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
         watermark: bool = False,
-    ) ->
-        ...
+    ) -> TextGenerationOutput: ...

     @overload
     def text_generation(  # type: ignore
@@ -1294,11 +1532,10 @@ class InferenceClient:
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
         watermark: bool = False,
-    ) -> Iterable[str]:
-        ...
+    ) -> Iterable[str]: ...

     @overload
-    def text_generation(
+    def text_generation(  # type: ignore
         self,
         prompt: str,
         *,
@@ -1318,8 +1555,30 @@ class InferenceClient:
|
|
|
1318
1555
|
truncate: Optional[int] = None,
|
|
1319
1556
|
typical_p: Optional[float] = None,
|
|
1320
1557
|
watermark: bool = False,
|
|
1321
|
-
) -> Iterable[
|
|
1322
|
-
|
|
1558
|
+
) -> Iterable[TextGenerationStreamOutput]: ...
|
|
1559
|
+
|
|
1560
|
+
@overload
|
|
1561
|
+
def text_generation(
|
|
1562
|
+
self,
|
|
1563
|
+
prompt: str,
|
|
1564
|
+
*,
|
|
1565
|
+
details: Literal[True] = ...,
|
|
1566
|
+
stream: bool = ...,
|
|
1567
|
+
model: Optional[str] = None,
|
|
1568
|
+
do_sample: bool = False,
|
|
1569
|
+
max_new_tokens: int = 20,
|
|
1570
|
+
best_of: Optional[int] = None,
|
|
1571
|
+
repetition_penalty: Optional[float] = None,
|
|
1572
|
+
return_full_text: bool = False,
|
|
1573
|
+
seed: Optional[int] = None,
|
|
1574
|
+
stop_sequences: Optional[List[str]] = None,
|
|
1575
|
+
temperature: Optional[float] = None,
|
|
1576
|
+
top_k: Optional[int] = None,
|
|
1577
|
+
top_p: Optional[float] = None,
|
|
1578
|
+
truncate: Optional[int] = None,
|
|
1579
|
+
typical_p: Optional[float] = None,
|
|
1580
|
+
watermark: bool = False,
|
|
1581
|
+
) -> Union[TextGenerationOutput, Iterable[TextGenerationStreamOutput]]: ...
|
|
1323
1582
|
|
|
1324
1583
|
def text_generation(
|
|
1325
1584
|
self,
|
|
@@ -1342,13 +1601,10 @@ class InferenceClient:
|
|
|
1342
1601
|
typical_p: Optional[float] = None,
|
|
1343
1602
|
watermark: bool = False,
|
|
1344
1603
|
decoder_input_details: bool = False,
|
|
1345
|
-
) -> Union[str,
|
|
1604
|
+
) -> Union[str, TextGenerationOutput, Iterable[str], Iterable[TextGenerationStreamOutput]]:
|
|
1346
1605
|
"""
|
|
1347
1606
|
Given a prompt, generate the following text.
|
|
1348
1607
|
|
|
1349
|
-
It is recommended to have Pydantic installed in order to get inputs validated. This is preferable as it allow
|
|
1350
|
-
early failures.
|
|
1351
|
-
|
|
1352
1608
|
API endpoint is supposed to run with the `text-generation-inference` backend (TGI). This backend is the
|
|
1353
1609
|
go-to solution to run large language models at scale. However, for some smaller models (e.g. "gpt2") the
|
|
1354
1610
|
default `transformers` + `api-inference` solution is still in use. Both approaches have very similar APIs, but
|
|
@@ -1406,12 +1662,12 @@ class InferenceClient:
|
|
|
1406
1662
|
into account. Defaults to `False`.
|
|
1407
1663
|
|
|
1408
1664
|
Returns:
|
|
1409
|
-
`Union[str,
|
|
1665
|
+
`Union[str, TextGenerationOutput, Iterable[str], Iterable[TextGenerationStreamOutput]]`:
|
|
1410
1666
|
Generated text returned from the server:
|
|
1411
1667
|
- if `stream=False` and `details=False`, the generated text is returned as a `str` (default)
|
|
1412
1668
|
- if `stream=True` and `details=False`, the generated text is returned token by token as a `Iterable[str]`
|
|
1413
|
-
- if `stream=False` and `details=True`, the generated text is returned with more details as a [`~huggingface_hub.
|
|
1414
|
-
- if `details=True` and `stream=True`, the generated text is returned token by token as a iterable of [`~huggingface_hub.
|
|
1669
|
+
- if `stream=False` and `details=True`, the generated text is returned with more details as a [`~huggingface_hub.TextGenerationOutput`]
|
|
1670
|
+
- if `details=True` and `stream=True`, the generated text is returned token by token as a iterable of [`~huggingface_hub.TextGenerationStreamOutput`]
|
|
1415
1671
|
|
|
1416
1672
|
Raises:
|
|
1417
1673
|
`ValidationError`:
|
|
@@ -1448,23 +1704,23 @@ class InferenceClient:
|
|
|
1448
1704
|
|
|
1449
1705
|
# Case 3: get more details about the generation process.
|
|
1450
1706
|
>>> client.text_generation("The huggingface_hub library is ", max_new_tokens=12, details=True)
|
|
1451
|
-
|
|
1707
|
+
TextGenerationOutput(
|
|
1452
1708
|
generated_text='100% open source and built to be easy to use.',
|
|
1453
|
-
details=
|
|
1454
|
-
finish_reason
|
|
1709
|
+
details=TextGenerationDetails(
|
|
1710
|
+
finish_reason='length',
|
|
1455
1711
|
generated_tokens=12,
|
|
1456
1712
|
seed=None,
|
|
1457
1713
|
prefill=[
|
|
1458
|
-
|
|
1459
|
-
|
|
1714
|
+
TextGenerationPrefillToken(id=487, text='The', logprob=None),
|
|
1715
|
+
TextGenerationPrefillToken(id=53789, text=' hugging', logprob=-13.171875),
|
|
1460
1716
|
(...)
|
|
1461
|
-
|
|
1717
|
+
TextGenerationPrefillToken(id=204, text=' ', logprob=-7.0390625)
|
|
1462
1718
|
],
|
|
1463
1719
|
tokens=[
|
|
1464
|
-
|
|
1465
|
-
|
|
1720
|
+
TokenElement(id=1425, text='100', logprob=-1.0175781, special=False),
|
|
1721
|
+
TokenElement(id=16, text='%', logprob=-0.0463562, special=False),
|
|
1466
1722
|
(...)
|
|
1467
|
-
|
|
1723
|
+
TokenElement(id=25, text='.', logprob=-0.5703125, special=False)
|
|
1468
1724
|
],
|
|
1469
1725
|
best_of_sequences=None
|
|
1470
1726
|
)
|
|
@@ -1475,30 +1731,27 @@ class InferenceClient:
         >>> for details in client.text_generation("The huggingface_hub library is ", max_new_tokens=12, details=True, stream=True):
         ...     print(details)
         ...
-
-
-
-
-
-
-
-
-
-
-
-
+        TextGenerationStreamOutput(token=TokenElement(id=1425, text='100', logprob=-1.0175781, special=False), generated_text=None, details=None)
+        TextGenerationStreamOutput(token=TokenElement(id=16, text='%', logprob=-0.0463562, special=False), generated_text=None, details=None)
+        TextGenerationStreamOutput(token=TokenElement(id=1314, text=' open', logprob=-1.3359375, special=False), generated_text=None, details=None)
+        TextGenerationStreamOutput(token=TokenElement(id=3178, text=' source', logprob=-0.28100586, special=False), generated_text=None, details=None)
+        TextGenerationStreamOutput(token=TokenElement(id=273, text=' and', logprob=-0.5961914, special=False), generated_text=None, details=None)
+        TextGenerationStreamOutput(token=TokenElement(id=3426, text=' built', logprob=-1.9423828, special=False), generated_text=None, details=None)
+        TextGenerationStreamOutput(token=TokenElement(id=271, text=' to', logprob=-1.4121094, special=False), generated_text=None, details=None)
+        TextGenerationStreamOutput(token=TokenElement(id=314, text=' be', logprob=-1.5224609, special=False), generated_text=None, details=None)
+        TextGenerationStreamOutput(token=TokenElement(id=1833, text=' easy', logprob=-2.1132812, special=False), generated_text=None, details=None)
+        TextGenerationStreamOutput(token=TokenElement(id=271, text=' to', logprob=-0.08520508, special=False), generated_text=None, details=None)
+        TextGenerationStreamOutput(token=TokenElement(id=745, text=' use', logprob=-0.39453125, special=False), generated_text=None, details=None)
+        TextGenerationStreamOutput(token=TokenElement(
             id=25,
             text='.',
             logprob=-0.5703125,
             special=False),
             generated_text='100% open source and built to be easy to use.',
-            details=
+            details=TextGenerationStreamDetails(finish_reason='length', generated_tokens=12, seed=None)
         )
         ```
         """
-        # NOTE: Text-generation integration is taken from the text-generation-inference project. It has more features
-        # like input/output validation (if Pydantic is installed). See `_text_generation.py` header for more details.
-
         if decoder_input_details and not details:
             warnings.warn(
                 "`decoder_input_details=True` has been passed to the server but `details=False` is set meaning that"
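For reference, a minimal sketch of consuming the streamed `TextGenerationStreamOutput` objects shown in the updated docstring above. It assumes a reachable text-generation model on the Inference API; the prompt and parameter values are only illustrative.

```py
from huggingface_hub import InferenceClient

client = InferenceClient()  # a specific model ID or endpoint URL may be passed here

generated = ""
for chunk in client.text_generation(
    "The huggingface_hub library is ",
    max_new_tokens=12,
    details=True,
    stream=True,
):
    # Each chunk is a TextGenerationStreamOutput; only the final chunk carries
    # generated_text and details (finish_reason, generated_tokens, seed).
    generated += chunk.token.text
    if chunk.details is not None:
        print("finish_reason:", chunk.details.finish_reason)

print(generated)
```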
@@ -1506,34 +1759,38 @@ class InferenceClient:
             )
             decoder_input_details = False

-        #
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        # Build payload
+        payload = {
+            "inputs": prompt,
+            "parameters": {
+                "best_of": best_of,
+                "decoder_input_details": decoder_input_details,
+                "details": details,
+                "do_sample": do_sample,
+                "max_new_tokens": max_new_tokens,
+                "repetition_penalty": repetition_penalty,
+                "return_full_text": return_full_text,
+                "seed": seed,
+                "stop": stop_sequences if stop_sequences is not None else [],
+                "temperature": temperature,
+                "top_k": top_k,
+                "top_p": top_p,
+                "truncate": truncate,
+                "typical_p": typical_p,
+                "watermark": watermark,
+            },
+            "stream": stream,
+        }

         # Remove some parameters if not a TGI server
         if not _is_tgi_server(model):
+            parameters: Dict = payload["parameters"]  # type: ignore [assignment]
+
             ignored_parameters = []
-            for key in "watermark", "
-                if
+            for key in "watermark", "details", "decoder_input_details", "best_of", "stop", "return_full_text":
+                if parameters[key] is not None:
                     ignored_parameters.append(key)
-                del
+                del parameters[key]
             if len(ignored_parameters) > 0:
                 warnings.warn(
                     "API endpoint/model for text-generation is not served via TGI. Ignoring parameters"
@@ -1585,8 +1842,8 @@ class InferenceClient:
         if stream:
             return _stream_text_generation_response(bytes_output, details)  # type: ignore

-        data = _bytes_to_dict(bytes_output)[0]
-        return
+        data = _bytes_to_dict(bytes_output)[0]  # type: ignore[arg-type]
+        return TextGenerationOutput.parse_obj_as_instance(data) if details else data["generated_text"]

     def text_to_image(
         self,
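Likewise, a short sketch of the non-streaming path changed above: with `details=True` the method now returns a `TextGenerationOutput` instance instead of a plain dictionary, so token-level data is reachable as attributes (attribute names follow the example output shown earlier in the docstring).

```py
from huggingface_hub import InferenceClient

client = InferenceClient()

output = client.text_generation(
    "The huggingface_hub library is ",
    max_new_tokens=12,
    details=True,
)
print(output.generated_text)
# details carries finish_reason, generated_tokens, prefill and tokens lists
for token in output.details.tokens:
    print(token.id, repr(token.text), token.logprob)
```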
@@ -1700,7 +1957,9 @@ class InferenceClient:
         """
         return self.post(json={"inputs": text}, model=model, task="text-to-speech")

-    def token_classification(
+    def token_classification(
+        self, text: str, *, model: Optional[str] = None
+    ) -> List[TokenClassificationOutputElement]:
         """
         Perform token classification on the given text.
         Usually used for sentence parsing, either grammatical, or Named Entity Recognition (NER) to understand keywords contained within text.
@@ -1714,7 +1973,7 @@ class InferenceClient:
                 Defaults to None.

         Returns:
-            `List[
+            `List[TokenClassificationOutputElement]`: List of [`TokenClassificationOutputElement`] items containing the entity group, confidence score, word, start and end index.

         Raises:
             [`InferenceTimeoutError`]:
@@ -1727,16 +1986,22 @@ class InferenceClient:
         >>> from huggingface_hub import InferenceClient
         >>> client = InferenceClient()
         >>> client.token_classification("My name is Sarah Jessica Parker but you can call me Jessica")
-        [
-
-
-
-
-
-
-
-
-
+        [
+            TokenClassificationOutputElement(
+                entity_group='PER',
+                score=0.9971321225166321,
+                word='Sarah Jessica Parker',
+                start=11,
+                end=31,
+            ),
+            TokenClassificationOutputElement(
+                entity_group='PER',
+                score=0.9773476123809814,
+                word='Jessica',
+                start=52,
+                end=59,
+            )
+        ]
         ```
         """
         payload: Dict[str, Any] = {"inputs": text}
@@ -1745,11 +2010,11 @@ class InferenceClient:
             model=model,
             task="token-classification",
         )
-        return
+        return TokenClassificationOutputElement.parse_obj_as_list(response)

     def translation(
         self, text: str, *, model: Optional[str] = None, src_lang: Optional[str] = None, tgt_lang: Optional[str] = None
-    ) ->
+    ) -> TranslationOutput:
         """
         Convert text from one language to another.

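As a usage note, the typed return value introduced above can be filtered directly; a small sketch (entity group and threshold are illustrative):

```py
from huggingface_hub import InferenceClient

client = InferenceClient()

elements = client.token_classification(
    "My name is Sarah Jessica Parker but you can call me Jessica"
)
# Each element exposes entity_group, score, word, start and end
people = [e.word for e in elements if e.entity_group == "PER" and e.score > 0.9]
print(people)
```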
@@ -1772,7 +2037,7 @@ class InferenceClient:
                 Target language of the translation task, i.e. output language. Cannot be passed without `src_lang`.

         Returns:
-            `
+            [`TranslationOutput`]: The generated translated text.

         Raises:
             [`InferenceTimeoutError`]:
@@ -1789,7 +2054,7 @@ class InferenceClient:
         >>> client.translation("My name is Wolfgang and I live in Berlin")
         'Mein Name ist Wolfgang und ich lebe in Berlin.'
         >>> client.translation("My name is Wolfgang and I live in Berlin", model="Helsinki-NLP/opus-mt-en-fr")
-
+        TranslationOutput(translation_text='Je m\'appelle Wolfgang et je vis à Berlin.')
         ```

         Specifying languages:
@@ -1810,11 +2075,58 @@ class InferenceClient:
         if src_lang and tgt_lang:
             payload["parameters"] = {"src_lang": src_lang, "tgt_lang": tgt_lang}
         response = self.post(json=payload, model=model, task="translation")
-        return
+        return TranslationOutput.parse_obj_as_list(response)[0]
+
+    def visual_question_answering(
+        self,
+        image: ContentT,
+        question: str,
+        *,
+        model: Optional[str] = None,
+    ) -> List[VisualQuestionAnsweringOutputElement]:
+        """
+        Answering open-ended questions based on an image.
+
+        Args:
+            image (`Union[str, Path, bytes, BinaryIO]`):
+                The input image for the context. It can be raw bytes, an image file, or a URL to an online image.
+            question (`str`):
+                Question to be answered.
+            model (`str`, *optional*):
+                The model to use for the visual question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to
+                a deployed Inference Endpoint. If not provided, the default recommended visual question answering model will be used.
+                Defaults to None.
+
+        Returns:
+            `List[VisualQuestionAnsweringOutputElement]`: a list of [`VisualQuestionAnsweringOutputElement`] items containing the predicted label and associated probability.
+
+        Raises:
+            `InferenceTimeoutError`:
+                If the model is unavailable or the request times out.
+            `HTTPError`:
+                If the request fails with an HTTP error status code other than HTTP 503.
+
+        Example:
+        ```py
+        >>> from huggingface_hub import InferenceClient
+        >>> client = InferenceClient()
+        >>> client.visual_question_answering(
+        ...     image="https://huggingface.co/datasets/mishig/sample_images/resolve/main/tiger.jpg",
+        ...     question="What is the animal doing?"
+        ... )
+        [
+            VisualQuestionAnsweringOutputElement(score=0.778609573841095, answer='laying down'),
+            VisualQuestionAnsweringOutputElement(score=0.6957435607910156, answer='sitting'),
+        ]
+        ```
+        """
+        payload: Dict[str, Any] = {"question": question, "image": _b64_encode(image)}
+        response = self.post(json=payload, model=model, task="visual-question-answering")
+        return VisualQuestionAnsweringOutputElement.parse_obj_as_list(response)

     def zero_shot_classification(
         self, text: str, labels: List[str], *, multi_label: bool = False, model: Optional[str] = None
-    ) -> List[
+    ) -> List[ZeroShotClassificationOutputElement]:
         """
         Provide as input a text and a set of candidate labels to classify the input text.

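A brief sketch of calling the newly added `visual_question_answering` method and keeping only the highest-scoring answer; the image URL is the one used in the docstring above.

```py
from huggingface_hub import InferenceClient

client = InferenceClient()

answers = client.visual_question_answering(
    image="https://huggingface.co/datasets/mishig/sample_images/resolve/main/tiger.jpg",
    question="What is the animal doing?",
)
# answers is a list of VisualQuestionAnsweringOutputElement(score=..., answer=...)
best = max(answers, key=lambda a: a.score)
print(best.answer, best.score)
```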
@@ -1830,7 +2142,7 @@ class InferenceClient:
                 Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.

         Returns:
-            `List[
+            `List[ZeroShotClassificationOutputElement]`: List of [`ZeroShotClassificationOutputElement`] items containing the predicted labels and their confidence.

         Raises:
             [`InferenceTimeoutError`]:
@@ -1850,19 +2162,19 @@ class InferenceClient:
         >>> labels = ["space & cosmos", "scientific discovery", "microbiology", "robots", "archeology"]
         >>> client.zero_shot_classification(text, labels)
         [
-
-
-
-
-
+            ZeroShotClassificationOutputElement(label='scientific discovery', score=0.7961668968200684),
+            ZeroShotClassificationOutputElement(label='space & cosmos', score=0.18570658564567566),
+            ZeroShotClassificationOutputElement(label='microbiology', score=0.00730885099619627),
+            ZeroShotClassificationOutputElement(label='archeology', score=0.006258360575884581),
+            ZeroShotClassificationOutputElement(label='robots', score=0.004559356719255447),
         ]
         >>> client.zero_shot_classification(text, labels, multi_label=True)
         [
-
-
-
-
-
+            ZeroShotClassificationOutputElement(label='scientific discovery', score=0.9829297661781311),
+            ZeroShotClassificationOutputElement(label='space & cosmos', score=0.755190908908844),
+            ZeroShotClassificationOutputElement(label='microbiology', score=0.0005462635890580714),
+            ZeroShotClassificationOutputElement(label='archeology', score=0.00047131875180639327),
+            ZeroShotClassificationOutputElement(label='robots', score=0.00030448526376858354),
         ]
         ```
         """
@@ -1882,11 +2194,14 @@ class InferenceClient:
             task="zero-shot-classification",
         )
         output = _bytes_to_dict(response)
-        return [
+        return [
+            ZeroShotClassificationOutputElement.parse_obj_as_instance({"label": label, "score": score})
+            for label, score in zip(output["labels"], output["scores"])
+        ]

     def zero_shot_image_classification(
         self, image: ContentT, labels: List[str], *, model: Optional[str] = None
-    ) -> List[
+    ) -> List[ZeroShotImageClassificationOutputElement]:
         """
         Provide input image and text labels to predict text labels for the image.

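For completeness, a sketch of the zero-shot classification call whose return type is changed above; the labels mirror the docstring example, while the input text is an arbitrary stand-in.

```py
from huggingface_hub import InferenceClient

client = InferenceClient()

text = "A new model offers an explanation for how the Galilean satellites formed."
labels = ["space & cosmos", "scientific discovery", "microbiology", "robots", "archeology"]

results = client.zero_shot_classification(text, labels, multi_label=True)
# results is a list of ZeroShotClassificationOutputElement(label=..., score=...)
for element in sorted(results, key=lambda e: e.score, reverse=True):
    print(f"{element.label}: {element.score:.3f}")
```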
@@ -1900,7 +2215,7 @@ class InferenceClient:
                 Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.

         Returns:
-            `List[
+            `List[ZeroShotImageClassificationOutputElement]`: List of [`ZeroShotImageClassificationOutputElement`] items containing the predicted labels and their confidence.

         Raises:
             [`InferenceTimeoutError`]:
@@ -1917,7 +2232,7 @@ class InferenceClient:
         ...     "https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg",
         ...     labels=["dog", "cat", "horse"],
         ... )
-        [
+        [ZeroShotImageClassificationOutputElement(label='dog', score=0.956),...]
         ```
         """
         # Raise ValueError if input is less than 2 labels
@@ -1929,7 +2244,7 @@ class InferenceClient:
             model=model,
             task="zero-shot-image-classification",
         )
-        return
+        return ZeroShotImageClassificationOutputElement.parse_obj_as_list(response)

     def _resolve_url(self, model: Optional[str] = None, task: Optional[str] = None) -> str:
         model = model or self.model