huggingface-hub 0.21.4__py3-none-any.whl → 0.22.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of huggingface-hub might be problematic.

Files changed (96)
  1. huggingface_hub/__init__.py +217 -1
  2. huggingface_hub/_commit_api.py +14 -15
  3. huggingface_hub/_inference_endpoints.py +12 -11
  4. huggingface_hub/_login.py +1 -0
  5. huggingface_hub/_multi_commits.py +1 -0
  6. huggingface_hub/_snapshot_download.py +9 -1
  7. huggingface_hub/_tensorboard_logger.py +1 -0
  8. huggingface_hub/_webhooks_payload.py +1 -0
  9. huggingface_hub/_webhooks_server.py +1 -0
  10. huggingface_hub/commands/_cli_utils.py +1 -0
  11. huggingface_hub/commands/delete_cache.py +1 -0
  12. huggingface_hub/commands/download.py +1 -0
  13. huggingface_hub/commands/env.py +1 -0
  14. huggingface_hub/commands/scan_cache.py +1 -0
  15. huggingface_hub/commands/upload.py +1 -0
  16. huggingface_hub/community.py +1 -0
  17. huggingface_hub/constants.py +3 -1
  18. huggingface_hub/errors.py +38 -0
  19. huggingface_hub/file_download.py +24 -24
  20. huggingface_hub/hf_api.py +47 -35
  21. huggingface_hub/hub_mixin.py +210 -54
  22. huggingface_hub/inference/_client.py +554 -239
  23. huggingface_hub/inference/_common.py +195 -41
  24. huggingface_hub/inference/_generated/_async_client.py +558 -239
  25. huggingface_hub/inference/_generated/types/__init__.py +115 -0
  26. huggingface_hub/inference/_generated/types/audio_classification.py +43 -0
  27. huggingface_hub/inference/_generated/types/audio_to_audio.py +31 -0
  28. huggingface_hub/inference/_generated/types/automatic_speech_recognition.py +116 -0
  29. huggingface_hub/inference/_generated/types/base.py +149 -0
  30. huggingface_hub/inference/_generated/types/chat_completion.py +106 -0
  31. huggingface_hub/inference/_generated/types/depth_estimation.py +29 -0
  32. huggingface_hub/inference/_generated/types/document_question_answering.py +85 -0
  33. huggingface_hub/inference/_generated/types/feature_extraction.py +19 -0
  34. huggingface_hub/inference/_generated/types/fill_mask.py +50 -0
  35. huggingface_hub/inference/_generated/types/image_classification.py +43 -0
  36. huggingface_hub/inference/_generated/types/image_segmentation.py +52 -0
  37. huggingface_hub/inference/_generated/types/image_to_image.py +55 -0
  38. huggingface_hub/inference/_generated/types/image_to_text.py +105 -0
  39. huggingface_hub/inference/_generated/types/object_detection.py +55 -0
  40. huggingface_hub/inference/_generated/types/question_answering.py +77 -0
  41. huggingface_hub/inference/_generated/types/sentence_similarity.py +28 -0
  42. huggingface_hub/inference/_generated/types/summarization.py +46 -0
  43. huggingface_hub/inference/_generated/types/table_question_answering.py +45 -0
  44. huggingface_hub/inference/_generated/types/text2text_generation.py +45 -0
  45. huggingface_hub/inference/_generated/types/text_classification.py +43 -0
  46. huggingface_hub/inference/_generated/types/text_generation.py +161 -0
  47. huggingface_hub/inference/_generated/types/text_to_audio.py +105 -0
  48. huggingface_hub/inference/_generated/types/text_to_image.py +57 -0
  49. huggingface_hub/inference/_generated/types/token_classification.py +53 -0
  50. huggingface_hub/inference/_generated/types/translation.py +46 -0
  51. huggingface_hub/inference/_generated/types/video_classification.py +47 -0
  52. huggingface_hub/inference/_generated/types/visual_question_answering.py +53 -0
  53. huggingface_hub/inference/_generated/types/zero_shot_classification.py +56 -0
  54. huggingface_hub/inference/_generated/types/zero_shot_image_classification.py +51 -0
  55. huggingface_hub/inference/_generated/types/zero_shot_object_detection.py +55 -0
  56. huggingface_hub/inference/_templating.py +105 -0
  57. huggingface_hub/inference/_types.py +4 -152
  58. huggingface_hub/keras_mixin.py +39 -17
  59. huggingface_hub/lfs.py +20 -8
  60. huggingface_hub/repocard.py +11 -3
  61. huggingface_hub/repocard_data.py +12 -2
  62. huggingface_hub/serialization/__init__.py +1 -0
  63. huggingface_hub/serialization/_base.py +1 -0
  64. huggingface_hub/serialization/_numpy.py +1 -0
  65. huggingface_hub/serialization/_tensorflow.py +1 -0
  66. huggingface_hub/serialization/_torch.py +1 -0
  67. huggingface_hub/utils/__init__.py +4 -1
  68. huggingface_hub/utils/_cache_manager.py +7 -0
  69. huggingface_hub/utils/_chunk_utils.py +1 -0
  70. huggingface_hub/utils/_datetime.py +1 -0
  71. huggingface_hub/utils/_errors.py +10 -1
  72. huggingface_hub/utils/_experimental.py +1 -0
  73. huggingface_hub/utils/_fixes.py +19 -3
  74. huggingface_hub/utils/_git_credential.py +1 -0
  75. huggingface_hub/utils/_headers.py +10 -3
  76. huggingface_hub/utils/_hf_folder.py +1 -0
  77. huggingface_hub/utils/_http.py +1 -0
  78. huggingface_hub/utils/_pagination.py +1 -0
  79. huggingface_hub/utils/_paths.py +1 -0
  80. huggingface_hub/utils/_runtime.py +22 -0
  81. huggingface_hub/utils/_subprocess.py +1 -0
  82. huggingface_hub/utils/_token.py +1 -0
  83. huggingface_hub/utils/_typing.py +29 -1
  84. huggingface_hub/utils/_validators.py +1 -0
  85. huggingface_hub/utils/endpoint_helpers.py +1 -0
  86. huggingface_hub/utils/logging.py +1 -1
  87. huggingface_hub/utils/sha.py +1 -0
  88. huggingface_hub/utils/tqdm.py +1 -0
  89. {huggingface_hub-0.21.4.dist-info → huggingface_hub-0.22.0rc0.dist-info}/METADATA +14 -15
  90. huggingface_hub-0.22.0rc0.dist-info/RECORD +113 -0
  91. {huggingface_hub-0.21.4.dist-info → huggingface_hub-0.22.0rc0.dist-info}/WHEEL +1 -1
  92. huggingface_hub/inference/_text_generation.py +0 -551
  93. huggingface_hub-0.21.4.dist-info/RECORD +0 -81
  94. {huggingface_hub-0.21.4.dist-info → huggingface_hub-0.22.0rc0.dist-info}/LICENSE +0 -0
  95. {huggingface_hub-0.21.4.dist-info → huggingface_hub-0.22.0rc0.dist-info}/entry_points.txt +0 -0
  96. {huggingface_hub-0.21.4.dist-info → huggingface_hub-0.22.0rc0.dist-info}/top_level.txt +0 -0
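The headline change in this release is in the inference client (`huggingface_hub/inference/_client.py` and `_generated/_async_client.py` above): a new `chat_completion` method and dataclass-typed task outputs parsed from the server response. A minimal usage sketch, adapted from the docstring examples in the diff below; the model name and attribute access are taken from those docstrings, not from testing the released wheel, and a locally configured HF token is assumed:

# Sketch of the new chat_completion API in 0.22.0rc0, based on the docstrings in the diff below.
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient("HuggingFaceH4/zephyr-7b-beta")
    messages = [{"role": "user", "content": "What is the capital of France?"}]

    # stream=False returns a ChatCompletionOutput dataclass instead of a raw dict.
    output = await client.chat_completion(messages, max_tokens=100)
    print(output.choices[0].message.content)

    # stream=True yields ChatCompletionStreamOutput items token by token.
    async for chunk in await client.chat_completion(messages, max_tokens=10, stream=True):
        print(chunk.choices[0].delta.content)


asyncio.run(main())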
huggingface_hub/inference/_generated/_async_client.py
@@ -23,7 +23,6 @@ import base64
 import logging
 import time
 import warnings
-from dataclasses import asdict
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -39,11 +38,13 @@ from typing import (
 from requests.structures import CaseInsensitiveDict
 
 from huggingface_hub.constants import ALL_INFERENCE_API_FRAMEWORKS, INFERENCE_ENDPOINT, MAIN_INFERENCE_API_FRAMEWORKS
+from huggingface_hub.errors import InferenceTimeoutError
 from huggingface_hub.inference._common import (
     TASKS_EXPECTING_IMAGES,
     ContentT,
-    InferenceTimeoutError,
     ModelStatus,
+    _async_stream_chat_completion_response_from_bytes,
+    _async_stream_chat_completion_response_from_text_generation,
     _async_stream_text_generation_response,
     _b64_encode,
     _b64_to_image,
@@ -52,27 +53,42 @@ from huggingface_hub.inference._common import (
     _bytes_to_list,
     _fetch_recommended_models,
     _import_numpy,
+    _is_chat_completion_server,
     _is_tgi_server,
     _open_as_binary,
+    _set_as_non_chat_completion_server,
     _set_as_non_tgi,
-)
-from huggingface_hub.inference._text_generation import (
-    TextGenerationParameters,
-    TextGenerationRequest,
-    TextGenerationResponse,
-    TextGenerationStreamResponse,
     raise_text_generation_error,
 )
+from huggingface_hub.inference._generated.types import (
+    AudioClassificationOutputElement,
+    AudioToAudioOutputElement,
+    AutomaticSpeechRecognitionOutput,
+    ChatCompletionOutput,
+    ChatCompletionOutputChoice,
+    ChatCompletionOutputChoiceMessage,
+    ChatCompletionStreamOutput,
+    DocumentQuestionAnsweringOutputElement,
+    FillMaskOutputElement,
+    ImageClassificationOutputElement,
+    ImageSegmentationOutputElement,
+    ImageToTextOutput,
+    ObjectDetectionOutputElement,
+    QuestionAnsweringOutputElement,
+    SummarizationOutput,
+    TableQuestionAnsweringOutputElement,
+    TextClassificationOutputElement,
+    TextGenerationOutput,
+    TextGenerationStreamOutput,
+    TokenClassificationOutputElement,
+    TranslationOutput,
+    VisualQuestionAnsweringOutputElement,
+    ZeroShotClassificationOutputElement,
+    ZeroShotImageClassificationOutputElement,
+)
+from huggingface_hub.inference._templating import render_chat_prompt
 from huggingface_hub.inference._types import (
-    AudioToAudioOutput,
-    ClassificationOutput,
-    ConversationalOutput,
-    FillMaskOutput,
-    ImageSegmentationOutput,
-    ObjectDetectionOutput,
-    QuestionAnsweringOutput,
-    TableQuestionAnsweringOutput,
-    TokenClassificationOutput,
+    ConversationalOutput,  # soon to be removed
 )
 from huggingface_hub.utils import (
     build_hf_headers,
@@ -100,9 +116,9 @@ class AsyncInferenceClient:
             The model to run inference with. Can be a model id hosted on the Hugging Face Hub, e.g. `bigcode/starcoder`
             or a URL to a deployed Inference Endpoint. Defaults to None, in which case a recommended model is
             automatically selected for the task.
-        token (`str`, *optional*):
-            Hugging Face token. Will default to the locally saved token. Pass `token=False` if you don't want to send
-            your token to the server.
+        token (`str` or `bool`, *optional*):
+            Hugging Face token. Will default to the locally saved token if not provided.
+            Pass `token=False` if you don't want to send your token to the server.
         timeout (`float`, `optional`):
             The maximum number of seconds to wait for a response from the server. Loading a new model in Inference
             API can take up to several minutes. Defaults to None, meaning it will loop until the server is available.
@@ -122,6 +138,7 @@ class AsyncInferenceClient:
         cookies: Optional[Dict[str, str]] = None,
     ) -> None:
         self.model: Optional[str] = model
+        self.token: Union[str, bool, None] = token
         self.headers = CaseInsensitiveDict(build_hf_headers(token=token))  # contains 'authorization' + 'user-agent'
         if headers is not None:
             self.headers.update(headers)
@@ -140,11 +157,10 @@ class AsyncInferenceClient:
         model: Optional[str] = None,
         task: Optional[str] = None,
         stream: Literal[False] = ...,
-    ) -> bytes:
-        pass
+    ) -> bytes: ...
 
     @overload
-    async def post(
+    async def post(  # type: ignore[misc]
         self,
         *,
         json: Optional[Union[str, Dict, List]] = None,
@@ -152,8 +168,18 @@ class AsyncInferenceClient:
         model: Optional[str] = None,
         task: Optional[str] = None,
         stream: Literal[True] = ...,
-    ) -> AsyncIterable[bytes]:
-        pass
+    ) -> AsyncIterable[bytes]: ...
+
+    @overload
+    async def post(
+        self,
+        *,
+        json: Optional[Union[str, Dict, List]] = None,
+        data: Optional[ContentT] = None,
+        model: Optional[str] = None,
+        task: Optional[str] = None,
+        stream: bool = False,
+    ) -> Union[bytes, AsyncIterable[bytes]]: ...
 
     async def post(
         self,
@@ -263,7 +289,7 @@ class AsyncInferenceClient:
         audio: ContentT,
         *,
         model: Optional[str] = None,
-    ) -> List[ClassificationOutput]:
+    ) -> List[AudioClassificationOutputElement]:
         """
         Perform audio classification on the provided audio content.
 
@@ -277,7 +303,7 @@ class AsyncInferenceClient:
                 audio classification will be used.
 
         Returns:
-            `List[Dict]`: The classification output containing the predicted label and its confidence.
+            `List[AudioClassificationOutputElement]`: List of [`AudioClassificationOutputElement`] items containing the predicted labels and their confidence.
 
         Raises:
             [`InferenceTimeoutError`]:
@@ -291,18 +317,22 @@ class AsyncInferenceClient:
         >>> from huggingface_hub import AsyncInferenceClient
         >>> client = AsyncInferenceClient()
         >>> await client.audio_classification("audio.flac")
-        [{'score': 0.4976358711719513, 'label': 'hap'}, {'score': 0.3677836060523987, 'label': 'neu'},...]
+        [
+            AudioClassificationOutputElement(score=0.4976358711719513, label='hap'),
+            AudioClassificationOutputElement(score=0.3677836060523987, label='neu'),
+            ...
+        ]
         ```
         """
         response = await self.post(data=audio, model=model, task="audio-classification")
-        return _bytes_to_list(response)
+        return AudioClassificationOutputElement.parse_obj_as_list(response)
 
     async def audio_to_audio(
         self,
         audio: ContentT,
         *,
         model: Optional[str] = None,
-    ) -> List[AudioToAudioOutput]:
+    ) -> List[AudioToAudioOutputElement]:
         """
         Performs multiple tasks related to audio-to-audio depending on the model (eg: speech enhancement, source separation).
 
@@ -316,7 +346,7 @@ class AsyncInferenceClient:
                 audio_to_audio will be used.
 
         Returns:
-            `List[Dict]`: A list of dictionary where each index contains audios label, content-type, and audio content in blob.
+            `List[AudioToAudioOutputElement]`: A list of [`AudioToAudioOutputElement`] items containing audios label, content-type, and audio content in blob.
 
         Raises:
             `InferenceTimeoutError`:
@@ -332,13 +362,13 @@ class AsyncInferenceClient:
         >>> audio_output = await client.audio_to_audio("audio.flac")
         >>> async for i, item in enumerate(audio_output):
         >>>     with open(f"output_{i}.flac", "wb") as f:
-                    f.write(item["blob"])
+                    f.write(item.blob)
         ```
         """
         response = await self.post(data=audio, model=model, task="audio-to-audio")
-        audio_output = _bytes_to_list(response)
+        audio_output = AudioToAudioOutputElement.parse_obj_as_list(response)
         for item in audio_output:
-            item["blob"] = base64.b64decode(item["blob"])
+            item.blob = base64.b64decode(item.blob)
         return audio_output
 
     async def automatic_speech_recognition(
@@ -346,7 +376,7 @@ class AsyncInferenceClient:
         audio: ContentT,
         *,
         model: Optional[str] = None,
-    ) -> str:
+    ) -> AutomaticSpeechRecognitionOutput:
         """
         Perform automatic speech recognition (ASR or audio-to-text) on the given audio content.
 
@@ -358,7 +388,7 @@ class AsyncInferenceClient:
                 Inference Endpoint. If not provided, the default recommended model for ASR will be used.
 
         Returns:
-            str: The transcribed text.
+            [`AutomaticSpeechRecognitionOutput`]: An item containing the transcribed text and optionally the timestamp chunks.
 
         Raises:
             [`InferenceTimeoutError`]:
@@ -371,12 +401,266 @@ class AsyncInferenceClient:
         # Must be run in an async context
         >>> from huggingface_hub import AsyncInferenceClient
         >>> client = AsyncInferenceClient()
-        >>> await client.automatic_speech_recognition("hello_world.flac")
+        >>> await client.automatic_speech_recognition("hello_world.flac").text
         "hello world"
         ```
         """
         response = await self.post(data=audio, model=model, task="automatic-speech-recognition")
-        return _bytes_to_dict(response)["text"]
+        return AutomaticSpeechRecognitionOutput.parse_obj_as_instance(response)
+
+    @overload
+    async def chat_completion(  # type: ignore
+        self,
+        messages: List[Dict[str, str]],
+        *,
+        model: Optional[str] = None,
+        stream: Literal[False] = False,
+        max_tokens: int = 20,
+        seed: Optional[int] = None,
+        stop: Optional[Union[List[str], str]] = None,
+        temperature: float = 1.0,
+        top_p: Optional[float] = None,
+    ) -> ChatCompletionOutput: ...
+
+    @overload
+    async def chat_completion(  # type: ignore
+        self,
+        messages: List[Dict[str, str]],
+        *,
+        model: Optional[str] = None,
+        stream: Literal[True] = True,
+        max_tokens: int = 20,
+        seed: Optional[int] = None,
+        stop: Optional[Union[List[str], str]] = None,
+        temperature: float = 1.0,
+        top_p: Optional[float] = None,
+    ) -> AsyncIterable[ChatCompletionStreamOutput]: ...
+
+    @overload
+    async def chat_completion(
+        self,
+        messages: List[Dict[str, str]],
+        *,
+        model: Optional[str] = None,
+        stream: bool = False,
+        max_tokens: int = 20,
+        seed: Optional[int] = None,
+        stop: Optional[Union[List[str], str]] = None,
+        temperature: float = 1.0,
+        top_p: Optional[float] = None,
+    ) -> Union[ChatCompletionOutput, AsyncIterable[ChatCompletionStreamOutput]]: ...
+
+    async def chat_completion(
+        self,
+        messages: List[Dict[str, str]],
+        *,
+        model: Optional[str] = None,
+        stream: bool = False,
+        max_tokens: int = 20,
+        seed: Optional[int] = None,
+        stop: Optional[Union[List[str], str]] = None,
+        temperature: float = 1.0,
+        top_p: Optional[float] = None,
+    ) -> Union[ChatCompletionOutput, AsyncIterable[ChatCompletionStreamOutput]]:
+        """
+        A method for completing conversations using a specified language model.
+
+        <Tip>
+
+        If the model is served by a server supporting chat-completion, the method will directly call the server's
+        `/v1/chat/completions` endpoint. If the server does not support chat-completion, the method will render the
+        chat template client-side based on the information fetched from the Hub API. In this case, you will need to
+        have `minijinja` template engine installed. Run `pip install "huggingface_hub[inference]"` or `pip install minijinja`
+        to install it.
+
+        </Tip>
+
+        Args:
+            messages (List[Union[`SystemMessage`, `UserMessage`, `AssistantMessage`]]):
+                Conversation history consisting of roles and content pairs.
+            model (`str`, *optional*):
+                The model to use for chat-completion. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
+                Inference Endpoint. If not provided, the default recommended model for chat-based text-generation will be used.
+                See https://huggingface.co/tasks/text-generation for more details.
+            frequency_penalty (`float`, optional):
+                Penalizes new tokens based on their existing frequency
+                in the text so far. Range: [-2.0, 2.0]. Defaults to 0.0.
+            max_tokens (`int`, optional):
+                Maximum number of tokens allowed in the response. Defaults to 20.
+            seed (Optional[`int`], optional):
+                Seed for reproducible control flow. Defaults to None.
+            stop (Optional[`str`], optional):
+                Up to four strings which trigger the end of the response.
+                Defaults to None.
+            stream (`bool`, optional):
+                Enable realtime streaming of responses. Defaults to False.
+            temperature (`float`, optional):
+                Controls randomness of the generations. Lower values ensure
+                less random completions. Range: [0, 2]. Defaults to 1.0.
+            top_p (`float`, optional):
+                Fraction of the most likely next words to sample from.
+                Must be between 0 and 1. Defaults to 1.0.
+
+        Returns:
+            `Union[ChatCompletionOutput, Iterable[ChatCompletionStreamOutput]]`:
+            Generated text returned from the server:
+            - if `stream=False`, the generated text is returned as a [`ChatCompletionOutput`] (default).
+            - if `stream=True`, the generated text is returned token by token as a sequence of [`ChatCompletionStreamOutput`].
+
+        Raises:
+            [`InferenceTimeoutError`]:
+                If the model is unavailable or the request times out.
+            `aiohttp.ClientResponseError`:
+                If the request fails with an HTTP error status code other than HTTP 503.
+
+        Example:
+        ```py
+        # Must be run in an async context
+        >>> from huggingface_hub import AsyncInferenceClient
+        >>> messages = [{"role": "user", "content": "What is the capital of France?"}]
+        >>> client = AsyncInferenceClient("HuggingFaceH4/zephyr-7b-beta")
+        >>> await client.chat_completion(messages, max_tokens=100)
+        ChatCompletionOutput(
+            choices=[
+                ChatCompletionOutputChoice(
+                    finish_reason='eos_token',
+                    index=0,
+                    message=ChatCompletionOutputChoiceMessage(
+                        content='The capital of France is Paris. The official name of the city is "Ville de Paris" (City of Paris) and the name of the country\'s governing body, which is located in Paris, is "La République française" (The French Republic). \nI hope that helps! Let me know if you need any further information.'
+                    )
+                )
+            ],
+            created=1710498360
+        )
+
+        >>> async for token in await client.chat_completion(messages, max_tokens=10, stream=True):
+        ...     print(token)
+        ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content='The', role='assistant'), index=0, finish_reason=None)], created=1710498504)
+        ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content=' capital', role='assistant'), index=0, finish_reason=None)], created=1710498504)
+        (...)
+        ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content=' may', role='assistant'), index=0, finish_reason=None)], created=1710498504)
+        ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content=None, role=None), index=0, finish_reason='length')], created=1710498504)
+        ```
+        """
+        # determine model
+        model = model or self.model or self.get_recommended_model("text-generation")
+
+        if _is_chat_completion_server(model):
+            # First, let's consider the server has a `/v1/chat/completions` endpoint.
+            # If that's the case, we don't have to render the chat template client-side.
+            model_url = self._resolve_url(model) + "/v1/chat/completions"
+
+            try:
+                data = await self.post(
+                    model=model_url,
+                    json=dict(
+                        model="tgi",  # random string
+                        messages=messages,
+                        max_tokens=max_tokens,
+                        seed=seed,
+                        stop=stop,
+                        temperature=temperature,
+                        top_p=top_p,
+                        stream=stream,
+                    ),
+                    stream=stream,
+                )
+            except _import_aiohttp().ClientResponseError:
+                # Let's consider the server is not a chat completion server.
+                # Then we call again `chat_completion` which will render the chat template client side.
+                # (can be HTTP 500, HTTP 400, HTTP 404 depending on the server)
+                _set_as_non_chat_completion_server(model)
+                return await self.chat_completion(
+                    messages=messages,
+                    model=model,
+                    stream=stream,
+                    max_tokens=max_tokens,
+                    seed=seed,
+                    stop=stop,
+                    temperature=temperature,
+                    top_p=top_p,
+                )
+
+            if stream:
+                return _async_stream_chat_completion_response_from_bytes(data)  # type: ignore[arg-type]
+
+            return ChatCompletionOutput.parse_obj_as_instance(data)  # type: ignore[arg-type]
+
+        # At this point, we know the server is not a chat completion server.
+        # We need to render the chat template client side based on the information we can fetch from
+        # the Hub API.
+
+        model_id = None
+        if model.startswith(("http://", "https://")):
+            # If URL, we need to know which model is served. This is not always possible.
+            # A workaround is to list the user Inference Endpoints and check if one of them correspond to the model URL.
+            # If not, we raise an error.
+            # TODO: fix when we have a proper API for this (at least for Inference Endpoints)
+            # TODO: what if Sagemaker URL?
+            # TODO: what if Azure URL?
+            from ..hf_api import HfApi
+
+            for endpoint in HfApi(token=self.token).list_inference_endpoints():
+                if endpoint.url == model:
+                    model_id = endpoint.repository
+                    break
+        else:
+            model_id = model
+
+        if model_id is None:
+            # If we don't have the model ID, we can't fetch the chat template.
+            # We raise an error.
+            raise ValueError(
+                "Request can't be processed as the model ID can't be inferred from model URL. "
+                "This is needed to fetch the chat template from the Hub since the model is not "
+                "served with a Chat-completion API."
+            )
+
+        # fetch chat template + tokens
+        prompt = render_chat_prompt(model_id=model_id, token=self.token, messages=messages)
+
+        # generate response
+        stop_sequences = [stop] if isinstance(stop, str) else stop
+        text_generation_output = await self.text_generation(
+            prompt=prompt,
+            details=True,
+            stream=stream,
+            model=model,
+            max_new_tokens=max_tokens,
+            seed=seed,
+            stop_sequences=stop_sequences,
+            temperature=temperature,
+            top_p=top_p,
+        )
+
+        created = int(time.time())
+
+        if stream:
+            return _async_stream_chat_completion_response_from_text_generation(text_generation_output)  # type: ignore [arg-type]
+
+        if isinstance(text_generation_output, TextGenerationOutput):
+            # General use case => format ChatCompletionOutput from text generation details
+            content: str = text_generation_output.generated_text
+            finish_reason: str = text_generation_output.details.finish_reason  # type: ignore[union-attr]
+        else:
+            # Corner case: if server doesn't support details (e.g. if not a TGI server), we only receive an output string.
+            # In such a case, `finish_reason` is set to `"unk"`.
+            content = text_generation_output  # type: ignore[assignment]
+            finish_reason = "unk"
+
+        return ChatCompletionOutput(
+            created=created,
+            choices=[
+                ChatCompletionOutputChoice(
+                    finish_reason=finish_reason,  # type: ignore
+                    index=0,
+                    message=ChatCompletionOutputChoiceMessage(
+                        content=content,
+                        role="assistant",
+                    ),
+                )
+            ],
+        )
 
     async def conversational(
         self,
@@ -390,6 +674,13 @@ class AsyncInferenceClient:
         """
         Generate conversational responses based on the given input text (i.e. chat with the API).
 
+        <Tip warning={true}>
+
+        [`InferenceClient.conversational`] API is deprecated and will be removed in a future release. Please use
+        [`InferenceClient.chat_completion`] instead.
+
+        </Tip>
+
         Args:
             text (`str`):
                 The last input from the user in the conversation.
@@ -430,6 +721,11 @@ class AsyncInferenceClient:
         ... )
         ```
         """
+        warnings.warn(
+            "'InferenceClient.conversational' is deprecated and will be removed starting from huggingface_hub>=0.25. "
+            "Please use the more appropriate 'InferenceClient.chat_completion' API instead.",
+            FutureWarning,
+        )
         payload: Dict[str, Any] = {"inputs": {"text": text}}
         if generated_responses is not None:
             payload["inputs"]["generated_responses"] = generated_responses
@@ -440,58 +736,13 @@ class AsyncInferenceClient:
         response = await self.post(json=payload, model=model, task="conversational")
         return _bytes_to_dict(response)  # type: ignore
 
-    async def visual_question_answering(
-        self,
-        image: ContentT,
-        question: str,
-        *,
-        model: Optional[str] = None,
-    ) -> List[str]:
-        """
-        Answering open-ended questions based on an image.
-
-        Args:
-            image (`Union[str, Path, bytes, BinaryIO]`):
-                The input image for the context. It can be raw bytes, an image file, or a URL to an online image.
-            question (`str`):
-                Question to be answered.
-            model (`str`, *optional*):
-                The model to use for the visual question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to
-                a deployed Inference Endpoint. If not provided, the default recommended visual question answering model will be used.
-                Defaults to None.
-
-        Returns:
-            `List[Dict]`: a list of dictionaries containing the predicted label and associated probability.
-
-        Raises:
-            `InferenceTimeoutError`:
-                If the model is unavailable or the request times out.
-            `aiohttp.ClientResponseError`:
-                If the request fails with an HTTP error status code other than HTTP 503.
-
-        Example:
-        ```py
-        # Must be run in an async context
-        >>> from huggingface_hub import AsyncInferenceClient
-        >>> client = AsyncInferenceClient()
-        >>> await client.visual_question_answering(
-        ...     image="https://huggingface.co/datasets/mishig/sample_images/resolve/main/tiger.jpg",
-        ...     question="What is the animal doing?"
-        ... )
-        [{'score': 0.778609573841095, 'answer': 'laying down'},{'score': 0.6957435607910156, 'answer': 'sitting'}, ...]
-        ```
-        """
-        payload: Dict[str, Any] = {"question": question, "image": _b64_encode(image)}
-        response = await self.post(json=payload, model=model, task="visual-question-answering")
-        return _bytes_to_list(response)
-
     async def document_question_answering(
         self,
         image: ContentT,
         question: str,
         *,
         model: Optional[str] = None,
-    ) -> List[QuestionAnsweringOutput]:
+    ) -> List[DocumentQuestionAnsweringOutputElement]:
         """
         Answer questions on document images.
 
@@ -506,7 +757,7 @@ class AsyncInferenceClient:
                 Defaults to None.
 
         Returns:
-            `List[Dict]`: a list of dictionaries containing the predicted label, associated probability, word ids, and page number.
+            `List[DocumentQuestionAnsweringOutputElement]`: a list of [`DocumentQuestionAnsweringOutputElement`] items containing the predicted label, associated probability, word ids, and page number.
 
         Raises:
             [`InferenceTimeoutError`]:
@@ -520,12 +771,12 @@ class AsyncInferenceClient:
         >>> from huggingface_hub import AsyncInferenceClient
         >>> client = AsyncInferenceClient()
         >>> await client.document_question_answering(image="https://huggingface.co/spaces/impira/docquery/resolve/2359223c1837a7587402bda0f2643382a6eefeab/invoice.png", question="What is the invoice number?")
-        [{'score': 0.42515629529953003, 'answer': 'us-001', 'start': 16, 'end': 16}]
+        [DocumentQuestionAnsweringOutputElement(score=0.42515629529953003, answer='us-001', start=16, end=16)]
         ```
         """
         payload: Dict[str, Any] = {"question": question, "image": _b64_encode(image)}
         response = await self.post(json=payload, model=model, task="document-question-answering")
-        return _bytes_to_list(response)
+        return DocumentQuestionAnsweringOutputElement.parse_obj_as_list(response)
 
     async def feature_extraction(self, text: str, *, model: Optional[str] = None) -> "np.ndarray":
         """
@@ -564,7 +815,7 @@ class AsyncInferenceClient:
         np = _import_numpy()
         return np.array(_bytes_to_dict(response), dtype="float32")
 
-    async def fill_mask(self, text: str, *, model: Optional[str] = None) -> List[FillMaskOutput]:
+    async def fill_mask(self, text: str, *, model: Optional[str] = None) -> List[FillMaskOutputElement]:
         """
         Fill in a hole with a missing word (token to be precise).
 
@@ -577,7 +828,7 @@ class AsyncInferenceClient:
                 Defaults to None.
 
         Returns:
-            `List[Dict]`: a list of fill mask output dictionaries containing the predicted label, associated
+            `List[FillMaskOutputElement]`: a list of [`FillMaskOutputElement`] items containing the predicted label, associated
             probability, token reference, and completed text.
 
         Raises:
@@ -592,25 +843,21 @@ class AsyncInferenceClient:
         >>> from huggingface_hub import AsyncInferenceClient
         >>> client = AsyncInferenceClient()
         >>> await client.fill_mask("The goal of life is <mask>.")
-        [{'score': 0.06897063553333282,
-         'token': 11098,
-         'token_str': ' happiness',
-         'sequence': 'The goal of life is happiness.'},
-        {'score': 0.06554922461509705,
-         'token': 45075,
-         'token_str': ' immortality',
-         'sequence': 'The goal of life is immortality.'}]
+        [
+            FillMaskOutputElement(score=0.06897063553333282, token=11098, token_str=' happiness', sequence='The goal of life is happiness.'),
+            FillMaskOutputElement(score=0.06554922461509705, token=45075, token_str=' immortality', sequence='The goal of life is immortality.')
+        ]
         ```
         """
         response = await self.post(json={"inputs": text}, model=model, task="fill-mask")
-        return _bytes_to_list(response)
+        return FillMaskOutputElement.parse_obj_as_list(response)
 
     async def image_classification(
         self,
         image: ContentT,
         *,
         model: Optional[str] = None,
-    ) -> List[ClassificationOutput]:
+    ) -> List[ImageClassificationOutputElement]:
         """
         Perform image classification on the given image using the specified model.
 
@@ -622,7 +869,7 @@ class AsyncInferenceClient:
                 deployed Inference Endpoint. If not provided, the default recommended model for image classification will be used.
 
         Returns:
-            `List[Dict]`: a list of dictionaries containing the predicted label and associated probability.
+            `List[ImageClassificationOutputElement]`: a list of [`ImageClassificationOutputElement`] items containing the predicted label and associated probability.
 
         Raises:
             [`InferenceTimeoutError`]:
@@ -636,18 +883,18 @@ class AsyncInferenceClient:
         >>> from huggingface_hub import AsyncInferenceClient
         >>> client = AsyncInferenceClient()
         >>> await client.image_classification("https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg")
-        [{'score': 0.9779096841812134, 'label': 'Blenheim spaniel'}, ...]
+        [ImageClassificationOutputElement(score=0.9779096841812134, label='Blenheim spaniel'), ...]
         ```
         """
         response = await self.post(data=image, model=model, task="image-classification")
-        return _bytes_to_list(response)
+        return ImageClassificationOutputElement.parse_obj_as_list(response)
 
     async def image_segmentation(
         self,
         image: ContentT,
         *,
         model: Optional[str] = None,
-    ) -> List[ImageSegmentationOutput]:
+    ) -> List[ImageSegmentationOutputElement]:
         """
         Perform image segmentation on the given image using the specified model.
 
@@ -665,7 +912,7 @@ class AsyncInferenceClient:
                 deployed Inference Endpoint. If not provided, the default recommended model for image segmentation will be used.
 
         Returns:
-            `List[Dict]`: A list of dictionaries containing the segmented masks and associated attributes.
+            `List[ImageSegmentationOutputElement]`: A list of [`ImageSegmentationOutputElement`] items containing the segmented masks and associated attributes.
 
         Raises:
             [`InferenceTimeoutError`]:
@@ -679,19 +926,13 @@ class AsyncInferenceClient:
         >>> from huggingface_hub import AsyncInferenceClient
         >>> client = AsyncInferenceClient()
         >>> await client.image_segmentation("cat.jpg"):
-        [{'score': 0.989008, 'label': 'LABEL_184', 'mask': <PIL.PngImagePlugin.PngImageFile image mode=L size=400x300 at 0x7FDD2B129CC0>}, ...]
+        [ImageSegmentationOutputElement(score=0.989008, label='LABEL_184', mask=<PIL.PngImagePlugin.PngImageFile image mode=L size=400x300 at 0x7FDD2B129CC0>), ...]
         ```
         """
-
-        # Segment
         response = await self.post(data=image, model=model, task="image-segmentation")
-        output = _bytes_to_dict(response)
-
-        # Parse masks as PIL Image
-        if not isinstance(output, list):
-            raise ValueError(f"Server output must be a list. Got {type(output)}: {str(output)[:200]}...")
+        output = ImageSegmentationOutputElement.parse_obj_as_list(response)
         for item in output:
-            item["mask"] = _b64_to_image(item["mask"])
+            item.mask = _b64_to_image(item.mask)
         return output
 
     async def image_to_image(
@@ -779,7 +1020,7 @@ class AsyncInferenceClient:
         response = await self.post(json=payload, data=data, model=model, task="image-to-image")
         return _bytes_to_image(response)
 
-    async def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -> str:
+    async def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -> ImageToTextOutput:
         """
         Takes an input image and return text.
 
@@ -794,7 +1035,7 @@ class AsyncInferenceClient:
                 Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
 
         Returns:
-            `str`: The generated text.
+            [`ImageToTextOutput`]: The generated text.
 
         Raises:
             [`InferenceTimeoutError`]:
@@ -814,7 +1055,7 @@ class AsyncInferenceClient:
         ```
         """
         response = await self.post(data=image, model=model, task="image-to-text")
-        return _bytes_to_dict(response)[0]["generated_text"]
+        return ImageToTextOutput.parse_obj_as_instance(response)
 
     async def list_deployed_models(
         self, frameworks: Union[None, str, Literal["all"], List[str]] = None
@@ -902,7 +1143,7 @@ class AsyncInferenceClient:
         image: ContentT,
         *,
         model: Optional[str] = None,
-    ) -> List[ObjectDetectionOutput]:
+    ) -> List[ObjectDetectionOutputElement]:
         """
         Perform object detection on the given image using the specified model.
 
@@ -920,7 +1161,7 @@ class AsyncInferenceClient:
                 deployed Inference Endpoint. If not provided, the default recommended model for object detection (DETR) will be used.
 
         Returns:
-            `List[ObjectDetectionOutput]`: A list of dictionaries containing the bounding boxes and associated attributes.
+            `List[ObjectDetectionOutputElement]`: A list of [`ObjectDetectionOutputElement`] items containing the bounding boxes and associated attributes.
 
         Raises:
             [`InferenceTimeoutError`]:
@@ -936,19 +1177,16 @@ class AsyncInferenceClient:
         >>> from huggingface_hub import AsyncInferenceClient
         >>> client = AsyncInferenceClient()
         >>> await client.object_detection("people.jpg"):
-        [{"score":0.9486683011054993,"label":"person","box":{"xmin":59,"ymin":39,"xmax":420,"ymax":510}}, ... ]
+        [ObjectDetectionOutputElement(score=0.9486683011054993, label='person', box=ObjectDetectionBoundingBox(xmin=59, ymin=39, xmax=420, ymax=510)), ...]
        ```
        """
        # detect objects
        response = await self.post(data=image, model=model, task="object-detection")
-        output = _bytes_to_dict(response)
-        if not isinstance(output, list):
-            raise ValueError(f"Server output must be a list. Got {type(output)}: {str(output)[:200]}...")
-        return output
+        return ObjectDetectionOutputElement.parse_obj_as_list(response)
 
     async def question_answering(
         self, question: str, context: str, *, model: Optional[str] = None
-    ) -> QuestionAnsweringOutput:
+    ) -> QuestionAnsweringOutputElement:
         """
         Retrieve the answer to a question from a given text.
 
@@ -962,7 +1200,7 @@ class AsyncInferenceClient:
                 a deployed Inference Endpoint.
 
         Returns:
-            `Dict`: a dictionary of question answering output containing the score, start index, end index, and answer.
+            [`QuestionAnsweringOutputElement`]: an question answering output containing the score, start index, end index, and answer.
 
         Raises:
             [`InferenceTimeoutError`]:
@@ -976,7 +1214,7 @@ class AsyncInferenceClient:
         >>> from huggingface_hub import AsyncInferenceClient
         >>> client = AsyncInferenceClient()
         >>> await client.question_answering(question="What's my name?", context="My name is Clara and I live in Berkeley.")
-        {'score': 0.9326562285423279, 'start': 11, 'end': 16, 'answer': 'Clara'}
+        QuestionAnsweringOutputElement(score=0.9326562285423279, start=11, end=16, answer='Clara')
         ```
         """
 
@@ -986,7 +1224,7 @@ class AsyncInferenceClient:
             model=model,
             task="question-answering",
         )
-        return _bytes_to_dict(response)  # type: ignore
+        return QuestionAnsweringOutputElement.parse_obj_as_instance(response)
 
     async def sentence_similarity(
         self, sentence: str, other_sentences: List[str], *, model: Optional[str] = None
@@ -1042,7 +1280,7 @@ class AsyncInferenceClient:
         *,
         parameters: Optional[Dict[str, Any]] = None,
         model: Optional[str] = None,
-    ) -> str:
+    ) -> SummarizationOutput:
         """
         Generate a summary of a given text using a specified model.
 
@@ -1057,7 +1295,7 @@ class AsyncInferenceClient:
                 Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
 
         Returns:
-            `str`: The generated summary text.
+            [`SummarizationOutput`]: The generated summary text.
 
         Raises:
             [`InferenceTimeoutError`]:
@@ -1071,18 +1309,18 @@ class AsyncInferenceClient:
         >>> from huggingface_hub import AsyncInferenceClient
         >>> client = AsyncInferenceClient()
         >>> await client.summarization("The Eiffel tower...")
-        'The Eiffel tower is one of the most famous landmarks in the world....'
+        SummarizationOutput(generated_text="The Eiffel tower is one of the most famous landmarks in the world....")
         ```
         """
         payload: Dict[str, Any] = {"inputs": text}
         if parameters is not None:
             payload["parameters"] = parameters
         response = await self.post(json=payload, model=model, task="summarization")
-        return _bytes_to_dict(response)[0]["summary_text"]
+        return SummarizationOutput.parse_obj_as_list(response)[0]
 
     async def table_question_answering(
         self, table: Dict[str, Any], query: str, *, model: Optional[str] = None
-    ) -> TableQuestionAnsweringOutput:
+    ) -> TableQuestionAnsweringOutputElement:
         """
         Retrieve the answer to a question from information given in a table.
 
@@ -1097,7 +1335,7 @@ class AsyncInferenceClient:
                 Hub or a URL to a deployed Inference Endpoint.
 
         Returns:
-            `Dict`: a dictionary of table question answering output containing the answer, coordinates, cells and the aggregator used.
+            [`TableQuestionAnsweringOutputElement`]: a table question answering output containing the answer, coordinates, cells and the aggregator used.
 
         Raises:
             [`InferenceTimeoutError`]:
@@ -1113,7 +1351,7 @@ class AsyncInferenceClient:
         >>> query = "How many stars does the transformers repository have?"
         >>> table = {"Repository": ["Transformers", "Datasets", "Tokenizers"], "Stars": ["36542", "4512", "3934"]}
         >>> await client.table_question_answering(table, query, model="google/tapas-base-finetuned-wtq")
-        {'answer': 'AVERAGE > 36542', 'coordinates': [[0, 1]], 'cells': ['36542'], 'aggregator': 'AVERAGE'}
+        TableQuestionAnsweringOutputElement(answer='36542', coordinates=[[0, 1]], cells=['36542'], aggregator='AVERAGE')
         ```
         """
         response = await self.post(
@@ -1124,7 +1362,7 @@ class AsyncInferenceClient:
             model=model,
             task="table-question-answering",
         )
-        return _bytes_to_dict(response)  # type: ignore
+        return TableQuestionAnsweringOutputElement.parse_obj_as_instance(response)
 
     async def tabular_classification(self, table: Dict[str, Any], *, model: Optional[str] = None) -> List[str]:
         """
@@ -1213,7 +1451,9 @@ class AsyncInferenceClient:
         response = await self.post(json={"table": table}, model=model, task="tabular-regression")
         return _bytes_to_list(response)
 
-    async def text_classification(self, text: str, *, model: Optional[str] = None) -> List[ClassificationOutput]:
+    async def text_classification(
+        self, text: str, *, model: Optional[str] = None
+    ) -> List[TextClassificationOutputElement]:
         """
         Perform text classification (e.g. sentiment-analysis) on the given text.
 
@@ -1226,7 +1466,7 @@ class AsyncInferenceClient:
                 Defaults to None.
 
         Returns:
-            `List[Dict]`: a list of dictionaries containing the predicted label and associated probability.
+            `List[TextClassificationOutputElement]`: a list of [`TextClassificationOutputElement`] items containing the predicted label and associated probability.
 
         Raises:
             [`InferenceTimeoutError`]:
@@ -1240,11 +1480,14 @@ class AsyncInferenceClient:
         >>> from huggingface_hub import AsyncInferenceClient
         >>> client = AsyncInferenceClient()
         >>> await client.text_classification("I like you")
-        [{'label': 'POSITIVE', 'score': 0.9998695850372314}, {'label': 'NEGATIVE', 'score': 0.0001304351753788069}]
+        [
+            TextClassificationOutputElement(label='POSITIVE', score=0.9998695850372314),
+            TextClassificationOutputElement(label='NEGATIVE', score=0.0001304351753788069),
+        ]
         ```
         """
         response = await self.post(json={"inputs": text}, model=model, task="text-classification")
-        return _bytes_to_list(response)[0]
+        return TextClassificationOutputElement.parse_obj_as_list(response)[0]  # type: ignore [return-value]
 
     @overload
     async def text_generation(  # type: ignore
@@ -1267,8 +1510,7 @@ class AsyncInferenceClient:
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
         watermark: bool = False,
-    ) -> str:
-        ...
+    ) -> str: ...
 
     @overload
     async def text_generation(  # type: ignore
@@ -1291,8 +1533,7 @@ class AsyncInferenceClient:
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
         watermark: bool = False,
-    ) -> TextGenerationResponse:
-        ...
+    ) -> TextGenerationOutput: ...
 
     @overload
     async def text_generation(  # type: ignore
@@ -1315,11 +1556,10 @@ class AsyncInferenceClient:
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
         watermark: bool = False,
-    ) -> AsyncIterable[str]:
-        ...
+    ) -> AsyncIterable[str]: ...
 
     @overload
-    async def text_generation(
+    async def text_generation(  # type: ignore
         self,
         prompt: str,
         *,
@@ -1339,8 +1579,30 @@ class AsyncInferenceClient:
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
         watermark: bool = False,
-    ) -> AsyncIterable[TextGenerationStreamResponse]:
-        ...
+    ) -> AsyncIterable[TextGenerationStreamOutput]: ...
+
+    @overload
+    async def text_generation(
+        self,
+        prompt: str,
+        *,
+        details: Literal[True] = ...,
+        stream: bool = ...,
+        model: Optional[str] = None,
+        do_sample: bool = False,
+        max_new_tokens: int = 20,
+        best_of: Optional[int] = None,
+        repetition_penalty: Optional[float] = None,
+        return_full_text: bool = False,
+        seed: Optional[int] = None,
+        stop_sequences: Optional[List[str]] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        truncate: Optional[int] = None,
+        typical_p: Optional[float] = None,
+        watermark: bool = False,
+    ) -> Union[TextGenerationOutput, AsyncIterable[TextGenerationStreamOutput]]: ...
 
     async def text_generation(
         self,
@@ -1363,13 +1625,10 @@ class AsyncInferenceClient:
         typical_p: Optional[float] = None,
         watermark: bool = False,
         decoder_input_details: bool = False,
-    ) -> Union[str, TextGenerationResponse, AsyncIterable[str], AsyncIterable[TextGenerationStreamResponse]]:
+    ) -> Union[str, TextGenerationOutput, AsyncIterable[str], AsyncIterable[TextGenerationStreamOutput]]:
         """
         Given a prompt, generate the following text.
 
-        It is recommended to have Pydantic installed in order to get inputs validated. This is preferable as it allow
-        early failures.
-
         API endpoint is supposed to run with the `text-generation-inference` backend (TGI). This backend is the
         go-to solution to run large language models at scale. However, for some smaller models (e.g. "gpt2") the
         default `transformers` + `api-inference` solution is still in use. Both approaches have very similar APIs, but
@@ -1427,12 +1686,12 @@ class AsyncInferenceClient:
                 into account. Defaults to `False`.
 
         Returns:
-            `Union[str, TextGenerationResponse, Iterable[str], Iterable[TextGenerationStreamResponse]]`:
+            `Union[str, TextGenerationOutput, Iterable[str], Iterable[TextGenerationStreamOutput]]`:
             Generated text returned from the server:
            - if `stream=False` and `details=False`, the generated text is returned as a `str` (default)
            - if `stream=True` and `details=False`, the generated text is returned token by token as a `Iterable[str]`
-            - if `stream=False` and `details=True`, the generated text is returned with more details as a [`~huggingface_hub.inference._text_generation.TextGenerationResponse`]
-            - if `details=True` and `stream=True`, the generated text is returned token by token as a iterable of [`~huggingface_hub.inference._text_generation.TextGenerationStreamResponse`]
+            - if `stream=False` and `details=True`, the generated text is returned with more details as a [`~huggingface_hub.TextGenerationOutput`]
+            - if `details=True` and `stream=True`, the generated text is returned token by token as a iterable of [`~huggingface_hub.TextGenerationStreamOutput`]
 
         Raises:
             `ValidationError`:
@@ -1470,23 +1729,23 @@ class AsyncInferenceClient:
 
         # Case 3: get more details about the generation process.
         >>> await client.text_generation("The huggingface_hub library is ", max_new_tokens=12, details=True)
-        TextGenerationResponse(
+        TextGenerationOutput(
             generated_text='100% open source and built to be easy to use.',
-            details=Details(
-                finish_reason=<FinishReason.Length: 'length'>,
+            details=TextGenerationDetails(
+                finish_reason='length',
                 generated_tokens=12,
                 seed=None,
                 prefill=[
-                    InputToken(id=487, text='The', logprob=None),
-                    InputToken(id=53789, text=' hugging', logprob=-13.171875),
+                    TextGenerationPrefillToken(id=487, text='The', logprob=None),
+                    TextGenerationPrefillToken(id=53789, text=' hugging', logprob=-13.171875),
                     (...)
-                    InputToken(id=204, text=' ', logprob=-7.0390625)
+                    TextGenerationPrefillToken(id=204, text=' ', logprob=-7.0390625)
                 ],
                 tokens=[
-                    Token(id=1425, text='100', logprob=-1.0175781, special=False),
-                    Token(id=16, text='%', logprob=-0.0463562, special=False),
+                    TokenElement(id=1425, text='100', logprob=-1.0175781, special=False),
+                    TokenElement(id=16, text='%', logprob=-0.0463562, special=False),
                     (...)
-                    Token(id=25, text='.', logprob=-0.5703125, special=False)
+                    TokenElement(id=25, text='.', logprob=-0.5703125, special=False)
                 ],
                 best_of_sequences=None
             )
@@ -1497,30 +1756,27 @@ class AsyncInferenceClient:
         >>> async for details in await client.text_generation("The huggingface_hub library is ", max_new_tokens=12, details=True, stream=True):
         ...     print(details)
         ...
-        TextGenerationStreamResponse(token=Token(id=1425, text='100', logprob=-1.0175781, special=False), generated_text=None, details=None)
-        TextGenerationStreamResponse(token=Token(id=16, text='%', logprob=-0.0463562, special=False), generated_text=None, details=None)
-        TextGenerationStreamResponse(token=Token(id=1314, text=' open', logprob=-1.3359375, special=False), generated_text=None, details=None)
-        TextGenerationStreamResponse(token=Token(id=3178, text=' source', logprob=-0.28100586, special=False), generated_text=None, details=None)
-        TextGenerationStreamResponse(token=Token(id=273, text=' and', logprob=-0.5961914, special=False), generated_text=None, details=None)
-        TextGenerationStreamResponse(token=Token(id=3426, text=' built', logprob=-1.9423828, special=False), generated_text=None, details=None)
-        TextGenerationStreamResponse(token=Token(id=271, text=' to', logprob=-1.4121094, special=False), generated_text=None, details=None)
-        TextGenerationStreamResponse(token=Token(id=314, text=' be', logprob=-1.5224609, special=False), generated_text=None, details=None)
-        TextGenerationStreamResponse(token=Token(id=1833, text=' easy', logprob=-2.1132812, special=False), generated_text=None, details=None)
-        TextGenerationStreamResponse(token=Token(id=271, text=' to', logprob=-0.08520508, special=False), generated_text=None, details=None)
-        TextGenerationStreamResponse(token=Token(id=745, text=' use', logprob=-0.39453125, special=False), generated_text=None, details=None)
-        TextGenerationStreamResponse(token=Token(
+        TextGenerationStreamOutput(token=TokenElement(id=1425, text='100', logprob=-1.0175781, special=False), generated_text=None, details=None)
+        TextGenerationStreamOutput(token=TokenElement(id=16, text='%', logprob=-0.0463562, special=False), generated_text=None, details=None)
+        TextGenerationStreamOutput(token=TokenElement(id=1314, text=' open', logprob=-1.3359375, special=False), generated_text=None, details=None)
+        TextGenerationStreamOutput(token=TokenElement(id=3178, text=' source', logprob=-0.28100586, special=False), generated_text=None, details=None)
+        TextGenerationStreamOutput(token=TokenElement(id=273, text=' and', logprob=-0.5961914, special=False), generated_text=None, details=None)
+        TextGenerationStreamOutput(token=TokenElement(id=3426, text=' built', logprob=-1.9423828, special=False), generated_text=None, details=None)
+        TextGenerationStreamOutput(token=TokenElement(id=271, text=' to', logprob=-1.4121094, special=False), generated_text=None, details=None)
+        TextGenerationStreamOutput(token=TokenElement(id=314, text=' be', logprob=-1.5224609, special=False), generated_text=None, details=None)
+        TextGenerationStreamOutput(token=TokenElement(id=1833, text=' easy', logprob=-2.1132812, special=False), generated_text=None, details=None)
+        TextGenerationStreamOutput(token=TokenElement(id=271, text=' to', logprob=-0.08520508, special=False), generated_text=None, details=None)
+        TextGenerationStreamOutput(token=TokenElement(id=745, text=' use', logprob=-0.39453125, special=False), generated_text=None, details=None)
+        TextGenerationStreamOutput(token=TokenElement(
             id=25,
             text='.',
             logprob=-0.5703125,
             special=False),
             generated_text='100% open source and built to be easy to use.',
-            details=StreamDetails(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=12, seed=None)
+            details=TextGenerationStreamDetails(finish_reason='length', generated_tokens=12, seed=None)
         )
         ```
         """
-        # NOTE: Text-generation integration is taken from the text-generation-inference project. It has more features
-        # like input/output validation (if Pydantic is installed). See `_text_generation.py` header for more details.
-
         if decoder_input_details and not details:
             warnings.warn(
                 "`decoder_input_details=True` has been passed to the server but `details=False` is set meaning that"
@@ -1528,34 +1784,38 @@ class AsyncInferenceClient:
  )
  decoder_input_details = False

- # Validate parameters
- parameters = TextGenerationParameters(
- best_of=best_of,
- details=details,
- do_sample=do_sample,
- max_new_tokens=max_new_tokens,
- repetition_penalty=repetition_penalty,
- return_full_text=return_full_text,
- seed=seed,
- stop=stop_sequences if stop_sequences is not None else [],
- temperature=temperature,
- top_k=top_k,
- top_p=top_p,
- truncate=truncate,
- typical_p=typical_p,
- watermark=watermark,
- decoder_input_details=decoder_input_details,
- )
- request = TextGenerationRequest(inputs=prompt, stream=stream, parameters=parameters)
- payload = asdict(request)
+ # Build payload
+ payload = {
+ "inputs": prompt,
+ "parameters": {
+ "best_of": best_of,
+ "decoder_input_details": decoder_input_details,
+ "details": details,
+ "do_sample": do_sample,
+ "max_new_tokens": max_new_tokens,
+ "repetition_penalty": repetition_penalty,
+ "return_full_text": return_full_text,
+ "seed": seed,
+ "stop": stop_sequences if stop_sequences is not None else [],
+ "temperature": temperature,
+ "top_k": top_k,
+ "top_p": top_p,
+ "truncate": truncate,
+ "typical_p": typical_p,
+ "watermark": watermark,
+ },
+ "stream": stream,
+ }

  # Remove some parameters if not a TGI server
  if not _is_tgi_server(model):
+ parameters: Dict = payload["parameters"] # type: ignore [assignment]
+
  ignored_parameters = []
- for key in "watermark", "stop", "details", "decoder_input_details", "best_of":
- if payload["parameters"][key] is not None:
+ for key in "watermark", "details", "decoder_input_details", "best_of", "stop", "return_full_text":
+ if parameters[key] is not None:
  ignored_parameters.append(key)
- del payload["parameters"][key]
+ del parameters[key]
  if len(ignored_parameters) > 0:
  warnings.warn(
  "API endpoint/model for text-generation is not served via TGI. Ignoring parameters"
@@ -1608,8 +1868,8 @@ class AsyncInferenceClient:
  if stream:
  return _async_stream_text_generation_response(bytes_output, details) # type: ignore

- data = _bytes_to_dict(bytes_output)[0]
- return TextGenerationResponse(**data) if details else data["generated_text"]
+ data = _bytes_to_dict(bytes_output)[0] # type: ignore[arg-type]
+ return TextGenerationOutput.parse_obj_as_instance(data) if details else data["generated_text"]

  async def text_to_image(
  self,
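Taken together, the `text_generation` hunks above swap the old text-generation-inference dataclasses (`TextGenerationResponse`, `TextGenerationStreamResponse`, `Token`, `StreamDetails`) for the newly generated types (`TextGenerationOutput`, `TextGenerationStreamOutput`, `TokenElement`, `TextGenerationStreamDetails`), with `finish_reason` now a plain string. The sketch below shows what calling code against the new types could look like; it is a minimal illustration, and the model ID and prompt are assumptions, not taken from the diff.

```py
# Minimal sketch of text_generation usage against the new output types (0.22.0rc0).
# The model ID and prompt are illustrative only.
import asyncio
from huggingface_hub import AsyncInferenceClient

async def main() -> None:
    client = AsyncInferenceClient(model="HuggingFaceH4/zephyr-7b-beta")  # hypothetical choice

    # Non-streaming with details=True now returns a TextGenerationOutput dataclass.
    output = await client.text_generation("The huggingface_hub library is ", details=True, max_new_tokens=12)
    print(output.generated_text, output.details.finish_reason)

    # Streaming yields TextGenerationStreamOutput items; details.finish_reason is a plain string.
    async for chunk in await client.text_generation(
        "The huggingface_hub library is ", stream=True, details=True, max_new_tokens=12
    ):
        print(chunk.token.text, end="")
        if chunk.details is not None:
            print(f"\n(finished: {chunk.details.finish_reason})")

asyncio.run(main())
```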
@@ -1725,7 +1985,9 @@ class AsyncInferenceClient:
  """
  return await self.post(json={"inputs": text}, model=model, task="text-to-speech")

- async def token_classification(self, text: str, *, model: Optional[str] = None) -> List[TokenClassificationOutput]:
+ async def token_classification(
+ self, text: str, *, model: Optional[str] = None
+ ) -> List[TokenClassificationOutputElement]:
  """
  Perform token classification on the given text.
  Usually used for sentence parsing, either grammatical, or Named Entity Recognition (NER) to understand keywords contained within text.
@@ -1739,7 +2001,7 @@ class AsyncInferenceClient:
  Defaults to None.

  Returns:
- `List[Dict]`: List of token classification outputs containing the entity group, confidence score, word, start and end index.
+ `List[TokenClassificationOutputElement]`: List of [`TokenClassificationOutputElement`] items containing the entity group, confidence score, word, start and end index.

  Raises:
  [`InferenceTimeoutError`]:
@@ -1753,16 +2015,22 @@ class AsyncInferenceClient:
  >>> from huggingface_hub import AsyncInferenceClient
  >>> client = AsyncInferenceClient()
  >>> await client.token_classification("My name is Sarah Jessica Parker but you can call me Jessica")
- [{'entity_group': 'PER',
- 'score': 0.9971321225166321,
- 'word': 'Sarah Jessica Parker',
- 'start': 11,
- 'end': 31},
- {'entity_group': 'PER',
- 'score': 0.9773476123809814,
- 'word': 'Jessica',
- 'start': 52,
- 'end': 59}]
+ [
+ TokenClassificationOutputElement(
+ entity_group='PER',
+ score=0.9971321225166321,
+ word='Sarah Jessica Parker',
+ start=11,
+ end=31,
+ ),
+ TokenClassificationOutputElement(
+ entity_group='PER',
+ score=0.9773476123809814,
+ word='Jessica',
+ start=52,
+ end=59,
+ )
+ ]
  ```
  """
  payload: Dict[str, Any] = {"inputs": text}
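As the updated docstring example above shows, `token_classification` now returns `TokenClassificationOutputElement` items instead of raw dicts (the return statement itself changes in the next hunk), so callers read attributes rather than keys. A hedged sketch, reusing the sentence from the example:

```py
# Sketch only: iterate the dataclass items returned by the updated method,
# reading attributes where older code read dict keys.
from huggingface_hub import AsyncInferenceClient

async def show_entities() -> None:
    client = AsyncInferenceClient()
    entities = await client.token_classification(
        "My name is Sarah Jessica Parker but you can call me Jessica"
    )
    for entity in entities:
        # Old style: entity["entity_group"], entity["score"]
        print(f"{entity.word!r} -> {entity.entity_group} ({entity.score:.3f})")
```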
@@ -1771,11 +2039,11 @@ class AsyncInferenceClient:
  model=model,
  task="token-classification",
  )
- return _bytes_to_list(response)
+ return TokenClassificationOutputElement.parse_obj_as_list(response)

  async def translation(
  self, text: str, *, model: Optional[str] = None, src_lang: Optional[str] = None, tgt_lang: Optional[str] = None
- ) -> str:
+ ) -> TranslationOutput:
  """
  Convert text from one language to another.

@@ -1798,7 +2066,7 @@ class AsyncInferenceClient:
  Target language of the translation task, i.e. output language. Cannot be passed without `src_lang`.

  Returns:
- `str`: The generated translated text.
+ [`TranslationOutput`]: The generated translated text.

  Raises:
  [`InferenceTimeoutError`]:
@@ -1816,7 +2084,7 @@ class AsyncInferenceClient:
  >>> await client.translation("My name is Wolfgang and I live in Berlin")
  'Mein Name ist Wolfgang und ich lebe in Berlin.'
  >>> await client.translation("My name is Wolfgang and I live in Berlin", model="Helsinki-NLP/opus-mt-en-fr")
- "Je m'appelle Wolfgang et je vis à Berlin."
+ TranslationOutput(translation_text='Je m\'appelle Wolfgang et je vis à Berlin.')
  ```

  Specifying languages:
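Beyond the docstring update above, the hunks around `translation` change its return type from a bare string to a [`TranslationOutput`], so code that expects plain text now reads `.translation_text`. A hedged migration sketch, reusing the model ID from the example:

```py
# Sketch only: the updated translation method returns a dataclass,
# so the translated string now lives on the .translation_text attribute.
from huggingface_hub import AsyncInferenceClient

async def translate() -> str:
    client = AsyncInferenceClient()
    result = await client.translation(
        "My name is Wolfgang and I live in Berlin", model="Helsinki-NLP/opus-mt-en-fr"
    )
    return result.translation_text  # previously the method returned this string directly
```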
@@ -1837,11 +2105,59 @@ class AsyncInferenceClient:
  if src_lang and tgt_lang:
  payload["parameters"] = {"src_lang": src_lang, "tgt_lang": tgt_lang}
  response = await self.post(json=payload, model=model, task="translation")
- return _bytes_to_dict(response)[0]["translation_text"]
+ return TranslationOutput.parse_obj_as_list(response)[0]
+
+ async def visual_question_answering(
+ self,
+ image: ContentT,
+ question: str,
+ *,
+ model: Optional[str] = None,
+ ) -> List[VisualQuestionAnsweringOutputElement]:
+ """
+ Answering open-ended questions based on an image.
+
+ Args:
+ image (`Union[str, Path, bytes, BinaryIO]`):
+ The input image for the context. It can be raw bytes, an image file, or a URL to an online image.
+ question (`str`):
+ Question to be answered.
+ model (`str`, *optional*):
+ The model to use for the visual question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to
+ a deployed Inference Endpoint. If not provided, the default recommended visual question answering model will be used.
+ Defaults to None.
+
+ Returns:
+ `List[VisualQuestionAnsweringOutputElement]`: a list of [`VisualQuestionAnsweringOutputElement`] items containing the predicted label and associated probability.
+
+ Raises:
+ `InferenceTimeoutError`:
+ If the model is unavailable or the request times out.
+ `aiohttp.ClientResponseError`:
+ If the request fails with an HTTP error status code other than HTTP 503.
+
+ Example:
+ ```py
+ # Must be run in an async context
+ >>> from huggingface_hub import AsyncInferenceClient
+ >>> client = AsyncInferenceClient()
+ >>> await client.visual_question_answering(
+ ... image="https://huggingface.co/datasets/mishig/sample_images/resolve/main/tiger.jpg",
+ ... question="What is the animal doing?"
+ ... )
+ [
+ VisualQuestionAnsweringOutputElement(score=0.778609573841095, answer='laying down'),
+ VisualQuestionAnsweringOutputElement(score=0.6957435607910156, answer='sitting'),
+ ]
+ ```
+ """
+ payload: Dict[str, Any] = {"question": question, "image": _b64_encode(image)}
+ response = await self.post(json=payload, model=model, task="visual-question-answering")
+ return VisualQuestionAnsweringOutputElement.parse_obj_as_list(response)
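The `visual_question_answering` method added above accepts the same `ContentT` image inputs as the other vision tasks (raw bytes, a local file path, or a URL). The sketch below is a hedged variation on the docstring example that uses a local file and keeps only the top answer; the file name is illustrative.

```py
# Sketch only: call the new visual_question_answering method with a local image
# file and return the highest-scoring answer.
from huggingface_hub import AsyncInferenceClient

async def best_answer() -> str:
    client = AsyncInferenceClient()
    answers = await client.visual_question_answering(
        image="cat.jpg",  # illustrative local path; raw bytes or a URL also work
        question="What is the animal doing?",
    )
    return answers[0].answer  # elements expose .answer and .score
```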

  async def zero_shot_classification(
  self, text: str, labels: List[str], *, multi_label: bool = False, model: Optional[str] = None
- ) -> List[ClassificationOutput]:
+ ) -> List[ZeroShotClassificationOutputElement]:
  """
  Provide as input a text and a set of candidate labels to classify the input text.

@@ -1857,7 +2173,7 @@ class AsyncInferenceClient:
  Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.

  Returns:
- `List[Dict]`: List of classification outputs containing the predicted labels and their confidence.
+ `List[ZeroShotClassificationOutputElement]`: List of [`ZeroShotClassificationOutputElement`] items containing the predicted labels and their confidence.

  Raises:
  [`InferenceTimeoutError`]:
@@ -1878,19 +2194,19 @@ class AsyncInferenceClient:
  >>> labels = ["space & cosmos", "scientific discovery", "microbiology", "robots", "archeology"]
  >>> await client.zero_shot_classification(text, labels)
  [
- {"label": "scientific discovery", "score": 0.7961668968200684},
- {"label": "space & cosmos", "score": 0.18570658564567566},
- {"label": "microbiology", "score": 0.00730885099619627},
- {"label": "archeology", "score": 0.006258360575884581},
- {"label": "robots", "score": 0.004559356719255447},
+ ZeroShotClassificationOutputElement(label='scientific discovery', score=0.7961668968200684),
+ ZeroShotClassificationOutputElement(label='space & cosmos', score=0.18570658564567566),
+ ZeroShotClassificationOutputElement(label='microbiology', score=0.00730885099619627),
+ ZeroShotClassificationOutputElement(label='archeology', score=0.006258360575884581),
+ ZeroShotClassificationOutputElement(label='robots', score=0.004559356719255447),
  ]
  >>> await client.zero_shot_classification(text, labels, multi_label=True)
  [
- {"label": "scientific discovery", "score": 0.9829297661781311},
- {"label": "space & cosmos", "score": 0.755190908908844},
- {"label": "microbiology", "score": 0.0005462635890580714},
- {"label": "archeology", "score": 0.00047131875180639327},
- {"label": "robots", "score": 0.00030448526376858354},
+ ZeroShotClassificationOutputElement(label='scientific discovery', score=0.9829297661781311),
+ ZeroShotClassificationOutputElement(label='space & cosmos', score=0.755190908908844),
+ ZeroShotClassificationOutputElement(label='microbiology', score=0.0005462635890580714),
+ ZeroShotClassificationOutputElement(label='archeology', score=0.00047131875180639327),
+ ZeroShotClassificationOutputElement(label='robots', score=0.00030448526376858354),
  ]
  ```
  """
@@ -1910,11 +2226,14 @@ class AsyncInferenceClient:
  task="zero-shot-classification",
  )
  output = _bytes_to_dict(response)
- return [{"label": label, "score": score} for label, score in zip(output["labels"], output["scores"])]
+ return [
+ ZeroShotClassificationOutputElement.parse_obj_as_instance({"label": label, "score": score})
+ for label, score in zip(output["labels"], output["scores"])
+ ]

  async def zero_shot_image_classification(
  self, image: ContentT, labels: List[str], *, model: Optional[str] = None
- ) -> List[ClassificationOutput]:
+ ) -> List[ZeroShotImageClassificationOutputElement]:
  """
  Provide input image and text labels to predict text labels for the image.

@@ -1928,7 +2247,7 @@ class AsyncInferenceClient:
  Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.

  Returns:
- `List[Dict]`: List of classification outputs containing the predicted labels and their confidence.
+ `List[ZeroShotImageClassificationOutputElement]`: List of [`ZeroShotImageClassificationOutputElement`] items containing the predicted labels and their confidence.

  Raises:
  [`InferenceTimeoutError`]:
@@ -1946,7 +2265,7 @@ class AsyncInferenceClient:
  ... "https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg",
  ... labels=["dog", "cat", "horse"],
  ... )
- [{"label": "dog", "score": 0.956}, ...]
+ [ZeroShotImageClassificationOutputElement(label='dog', score=0.956),...]
  ```
  """
  # Raise ValueError if input is less than 2 labels
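The same dict-to-dataclass change applies to `zero_shot_image_classification`. A hedged sketch reusing the image URL from the docstring example above:

```py
# Sketch only: pick the highest-scoring label from the dataclass results.
from huggingface_hub import AsyncInferenceClient

async def best_label() -> str:
    client = AsyncInferenceClient()
    results = await client.zero_shot_image_classification(
        "https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg",
        labels=["dog", "cat", "horse"],
    )
    return max(results, key=lambda r: r.score).label
```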
@@ -1958,7 +2277,7 @@ class AsyncInferenceClient:
  model=model,
  task="zero-shot-image-classification",
  )
- return _bytes_to_list(response)
+ return ZeroShotImageClassificationOutputElement.parse_obj_as_list(response)

  def _resolve_url(self, model: Optional[str] = None, task: Optional[str] = None) -> str:
  model = model or self.model