huggingface-hub 0.21.4__py3-none-any.whl → 0.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of huggingface-hub might be problematic.

Files changed (97)
  1. huggingface_hub/__init__.py +217 -1
  2. huggingface_hub/_commit_api.py +14 -15
  3. huggingface_hub/_inference_endpoints.py +12 -11
  4. huggingface_hub/_login.py +1 -0
  5. huggingface_hub/_multi_commits.py +1 -0
  6. huggingface_hub/_snapshot_download.py +9 -1
  7. huggingface_hub/_tensorboard_logger.py +1 -0
  8. huggingface_hub/_webhooks_payload.py +1 -0
  9. huggingface_hub/_webhooks_server.py +1 -0
  10. huggingface_hub/commands/_cli_utils.py +1 -0
  11. huggingface_hub/commands/delete_cache.py +1 -0
  12. huggingface_hub/commands/download.py +1 -0
  13. huggingface_hub/commands/env.py +1 -0
  14. huggingface_hub/commands/scan_cache.py +1 -0
  15. huggingface_hub/commands/upload.py +1 -0
  16. huggingface_hub/community.py +1 -0
  17. huggingface_hub/constants.py +3 -1
  18. huggingface_hub/errors.py +38 -0
  19. huggingface_hub/file_download.py +102 -95
  20. huggingface_hub/hf_api.py +47 -35
  21. huggingface_hub/hf_file_system.py +77 -3
  22. huggingface_hub/hub_mixin.py +215 -54
  23. huggingface_hub/inference/_client.py +554 -239
  24. huggingface_hub/inference/_common.py +195 -41
  25. huggingface_hub/inference/_generated/_async_client.py +558 -239
  26. huggingface_hub/inference/_generated/types/__init__.py +115 -0
  27. huggingface_hub/inference/_generated/types/audio_classification.py +43 -0
  28. huggingface_hub/inference/_generated/types/audio_to_audio.py +31 -0
  29. huggingface_hub/inference/_generated/types/automatic_speech_recognition.py +116 -0
  30. huggingface_hub/inference/_generated/types/base.py +149 -0
  31. huggingface_hub/inference/_generated/types/chat_completion.py +106 -0
  32. huggingface_hub/inference/_generated/types/depth_estimation.py +29 -0
  33. huggingface_hub/inference/_generated/types/document_question_answering.py +85 -0
  34. huggingface_hub/inference/_generated/types/feature_extraction.py +19 -0
  35. huggingface_hub/inference/_generated/types/fill_mask.py +50 -0
  36. huggingface_hub/inference/_generated/types/image_classification.py +43 -0
  37. huggingface_hub/inference/_generated/types/image_segmentation.py +52 -0
  38. huggingface_hub/inference/_generated/types/image_to_image.py +55 -0
  39. huggingface_hub/inference/_generated/types/image_to_text.py +105 -0
  40. huggingface_hub/inference/_generated/types/object_detection.py +55 -0
  41. huggingface_hub/inference/_generated/types/question_answering.py +77 -0
  42. huggingface_hub/inference/_generated/types/sentence_similarity.py +28 -0
  43. huggingface_hub/inference/_generated/types/summarization.py +46 -0
  44. huggingface_hub/inference/_generated/types/table_question_answering.py +45 -0
  45. huggingface_hub/inference/_generated/types/text2text_generation.py +45 -0
  46. huggingface_hub/inference/_generated/types/text_classification.py +43 -0
  47. huggingface_hub/inference/_generated/types/text_generation.py +161 -0
  48. huggingface_hub/inference/_generated/types/text_to_audio.py +105 -0
  49. huggingface_hub/inference/_generated/types/text_to_image.py +57 -0
  50. huggingface_hub/inference/_generated/types/token_classification.py +53 -0
  51. huggingface_hub/inference/_generated/types/translation.py +46 -0
  52. huggingface_hub/inference/_generated/types/video_classification.py +47 -0
  53. huggingface_hub/inference/_generated/types/visual_question_answering.py +53 -0
  54. huggingface_hub/inference/_generated/types/zero_shot_classification.py +56 -0
  55. huggingface_hub/inference/_generated/types/zero_shot_image_classification.py +51 -0
  56. huggingface_hub/inference/_generated/types/zero_shot_object_detection.py +55 -0
  57. huggingface_hub/inference/_templating.py +105 -0
  58. huggingface_hub/inference/_types.py +4 -152
  59. huggingface_hub/keras_mixin.py +39 -17
  60. huggingface_hub/lfs.py +20 -8
  61. huggingface_hub/repocard.py +11 -3
  62. huggingface_hub/repocard_data.py +12 -2
  63. huggingface_hub/serialization/__init__.py +1 -0
  64. huggingface_hub/serialization/_base.py +1 -0
  65. huggingface_hub/serialization/_numpy.py +1 -0
  66. huggingface_hub/serialization/_tensorflow.py +1 -0
  67. huggingface_hub/serialization/_torch.py +1 -0
  68. huggingface_hub/utils/__init__.py +4 -1
  69. huggingface_hub/utils/_cache_manager.py +7 -0
  70. huggingface_hub/utils/_chunk_utils.py +1 -0
  71. huggingface_hub/utils/_datetime.py +1 -0
  72. huggingface_hub/utils/_errors.py +10 -1
  73. huggingface_hub/utils/_experimental.py +1 -0
  74. huggingface_hub/utils/_fixes.py +19 -3
  75. huggingface_hub/utils/_git_credential.py +1 -0
  76. huggingface_hub/utils/_headers.py +10 -3
  77. huggingface_hub/utils/_hf_folder.py +1 -0
  78. huggingface_hub/utils/_http.py +1 -0
  79. huggingface_hub/utils/_pagination.py +1 -0
  80. huggingface_hub/utils/_paths.py +1 -0
  81. huggingface_hub/utils/_runtime.py +22 -0
  82. huggingface_hub/utils/_subprocess.py +1 -0
  83. huggingface_hub/utils/_token.py +1 -0
  84. huggingface_hub/utils/_typing.py +29 -1
  85. huggingface_hub/utils/_validators.py +1 -0
  86. huggingface_hub/utils/endpoint_helpers.py +1 -0
  87. huggingface_hub/utils/logging.py +1 -1
  88. huggingface_hub/utils/sha.py +1 -0
  89. huggingface_hub/utils/tqdm.py +1 -0
  90. {huggingface_hub-0.21.4.dist-info → huggingface_hub-0.22.0.dist-info}/METADATA +14 -15
  91. huggingface_hub-0.22.0.dist-info/RECORD +113 -0
  92. {huggingface_hub-0.21.4.dist-info → huggingface_hub-0.22.0.dist-info}/WHEEL +1 -1
  93. huggingface_hub/inference/_text_generation.py +0 -551
  94. huggingface_hub-0.21.4.dist-info/RECORD +0 -81
  95. {huggingface_hub-0.21.4.dist-info → huggingface_hub-0.22.0.dist-info}/LICENSE +0 -0
  96. {huggingface_hub-0.21.4.dist-info → huggingface_hub-0.22.0.dist-info}/entry_points.txt +0 -0
  97. {huggingface_hub-0.21.4.dist-info → huggingface_hub-0.22.0.dist-info}/top_level.txt +0 -0
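
The hunks that follow are from huggingface_hub/inference/_client.py and show the two headline changes of this release: a new `chat_completion` method on `InferenceClient`, and task methods that now return typed dataclasses (e.g. `AudioClassificationOutputElement`, `TextGenerationOutput`) parsed with `parse_obj_as_list`/`parse_obj_as_instance` instead of raw dicts and strings. Below is a minimal usage sketch of the new chat-completion API, based only on the signatures and docstring examples visible in this diff; it assumes huggingface_hub 0.22.0 is installed and that the chosen model is reachable through the Inference API.

```py
# Sketch of the chat_completion API added in 0.22.0 (assumed environment: huggingface_hub==0.22.0).
from huggingface_hub import InferenceClient

client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")  # model id taken from the docstring example below
messages = [{"role": "user", "content": "What is the capital of France?"}]

# Non-streaming call: returns a ChatCompletionOutput dataclass rather than a raw dict.
output = client.chat_completion(messages, max_tokens=100)
print(output.choices[0].message.content)

# Streaming call: yields ChatCompletionStreamOutput chunks token by token.
for chunk in client.chat_completion(messages, max_tokens=10, stream=True):
    delta = chunk.choices[0].delta
    if delta.content:  # the final chunk carries only finish_reason, with content=None
        print(delta.content, end="")
```

The same pattern applies to the other task methods changed in this diff: attribute access (item.label, item.score, output.generated_text) replaces the dictionary lookups used in 0.21.4.
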
@@ -23,7 +23,6 @@
  # https://github.com/huggingface/unity-api#tasks
  #
  # Some TODO:
- # - validate inputs/options/parameters? with Pydantic for instance? or only optionally?
  # - add all tasks
  #
  # NOTE: the philosophy of this client is "let's make it as easy as possible to use it, even if less optimized". Some
@@ -37,7 +36,6 @@ import base64
  import logging
  import time
  import warnings
- from dataclasses import asdict
  from typing import (
  TYPE_CHECKING,
  Any,
@@ -54,10 +52,10 @@ from requests import HTTPError
  from requests.structures import CaseInsensitiveDict

  from huggingface_hub.constants import ALL_INFERENCE_API_FRAMEWORKS, INFERENCE_ENDPOINT, MAIN_INFERENCE_API_FRAMEWORKS
+ from huggingface_hub.errors import InferenceTimeoutError
  from huggingface_hub.inference._common import (
  TASKS_EXPECTING_IMAGES,
  ContentT,
- InferenceTimeoutError,
  ModelStatus,
  _b64_encode,
  _b64_to_image,
@@ -66,28 +64,45 @@ from huggingface_hub.inference._common import (
  _bytes_to_list,
  _fetch_recommended_models,
  _import_numpy,
+ _is_chat_completion_server,
  _is_tgi_server,
  _open_as_binary,
+ _set_as_non_chat_completion_server,
  _set_as_non_tgi,
+ _stream_chat_completion_response_from_bytes,
+ _stream_chat_completion_response_from_text_generation,
  _stream_text_generation_response,
- )
- from huggingface_hub.inference._text_generation import (
- TextGenerationParameters,
- TextGenerationRequest,
- TextGenerationResponse,
- TextGenerationStreamResponse,
  raise_text_generation_error,
  )
+ from huggingface_hub.inference._generated.types import (
+ AudioClassificationOutputElement,
+ AudioToAudioOutputElement,
+ AutomaticSpeechRecognitionOutput,
+ ChatCompletionOutput,
+ ChatCompletionOutputChoice,
+ ChatCompletionOutputChoiceMessage,
+ ChatCompletionStreamOutput,
+ DocumentQuestionAnsweringOutputElement,
+ FillMaskOutputElement,
+ ImageClassificationOutputElement,
+ ImageSegmentationOutputElement,
+ ImageToTextOutput,
+ ObjectDetectionOutputElement,
+ QuestionAnsweringOutputElement,
+ SummarizationOutput,
+ TableQuestionAnsweringOutputElement,
+ TextClassificationOutputElement,
+ TextGenerationOutput,
+ TextGenerationStreamOutput,
+ TokenClassificationOutputElement,
+ TranslationOutput,
+ VisualQuestionAnsweringOutputElement,
+ ZeroShotClassificationOutputElement,
+ ZeroShotImageClassificationOutputElement,
+ )
+ from huggingface_hub.inference._templating import render_chat_prompt
  from huggingface_hub.inference._types import (
- AudioToAudioOutput,
- ClassificationOutput,
- ConversationalOutput,
- FillMaskOutput,
- ImageSegmentationOutput,
- ObjectDetectionOutput,
- QuestionAnsweringOutput,
- TableQuestionAnsweringOutput,
- TokenClassificationOutput,
+ ConversationalOutput, # soon to be removed
  )
  from huggingface_hub.utils import (
  BadRequestError,
@@ -116,9 +131,9 @@ class InferenceClient:
  The model to run inference with. Can be a model id hosted on the Hugging Face Hub, e.g. `bigcode/starcoder`
  or a URL to a deployed Inference Endpoint. Defaults to None, in which case a recommended model is
  automatically selected for the task.
- token (`str`, *optional*):
- Hugging Face token. Will default to the locally saved token. Pass `token=False` if you don't want to send
- your token to the server.
+ token (`str` or `bool`, *optional*):
+ Hugging Face token. Will default to the locally saved token if not provided.
+ Pass `token=False` if you don't want to send your token to the server.
  timeout (`float`, `optional`):
  The maximum number of seconds to wait for a response from the server. Loading a new model in Inference
  API can take up to several minutes. Defaults to None, meaning it will loop until the server is available.
@@ -138,6 +153,7 @@ class InferenceClient:
  cookies: Optional[Dict[str, str]] = None,
  ) -> None:
  self.model: Optional[str] = model
+ self.token: Union[str, bool, None] = token
  self.headers = CaseInsensitiveDict(build_hf_headers(token=token)) # contains 'authorization' + 'user-agent'
  if headers is not None:
  self.headers.update(headers)
@@ -156,11 +172,10 @@ class InferenceClient:
  model: Optional[str] = None,
  task: Optional[str] = None,
  stream: Literal[False] = ...,
- ) -> bytes:
- pass
+ ) -> bytes: ...

  @overload
- def post(
+ def post( # type: ignore[misc]
  self,
  *,
  json: Optional[Union[str, Dict, List]] = None,
@@ -168,8 +183,18 @@ class InferenceClient:
  data: Optional[ContentT] = None,
  model: Optional[str] = None,
  task: Optional[str] = None,
  stream: Literal[True] = ...,
- ) -> Iterable[bytes]:
- pass
+ ) -> Iterable[bytes]: ...
+
+ @overload
+ def post(
+ self,
+ *,
+ json: Optional[Union[str, Dict, List]] = None,
+ data: Optional[ContentT] = None,
+ model: Optional[str] = None,
+ task: Optional[str] = None,
+ stream: bool = False,
+ ) -> Union[bytes, Iterable[bytes]]: ...

  def post(
  self,
@@ -268,7 +293,7 @@ class InferenceClient:
  audio: ContentT,
  *,
  model: Optional[str] = None,
- ) -> List[ClassificationOutput]:
+ ) -> List[AudioClassificationOutputElement]:
  """
  Perform audio classification on the provided audio content.

@@ -282,7 +307,7 @@ class InferenceClient:
  audio classification will be used.

  Returns:
- `List[Dict]`: The classification output containing the predicted label and its confidence.
+ `List[AudioClassificationOutputElement]`: List of [`AudioClassificationOutputElement`] items containing the predicted labels and their confidence.

  Raises:
  [`InferenceTimeoutError`]:
@@ -295,18 +320,22 @@ class InferenceClient:
  >>> from huggingface_hub import InferenceClient
  >>> client = InferenceClient()
  >>> client.audio_classification("audio.flac")
- [{'score': 0.4976358711719513, 'label': 'hap'}, {'score': 0.3677836060523987, 'label': 'neu'},...]
+ [
+ AudioClassificationOutputElement(score=0.4976358711719513, label='hap'),
+ AudioClassificationOutputElement(score=0.3677836060523987, label='neu'),
+ ...
+ ]
  ```
  """
  response = self.post(data=audio, model=model, task="audio-classification")
- return _bytes_to_list(response)
+ return AudioClassificationOutputElement.parse_obj_as_list(response)

  def audio_to_audio(
  self,
  audio: ContentT,
  *,
  model: Optional[str] = None,
- ) -> List[AudioToAudioOutput]:
+ ) -> List[AudioToAudioOutputElement]:
  """
  Performs multiple tasks related to audio-to-audio depending on the model (eg: speech enhancement, source separation).

@@ -320,7 +349,7 @@ class InferenceClient:
  audio_to_audio will be used.

  Returns:
- `List[Dict]`: A list of dictionary where each index contains audios label, content-type, and audio content in blob.
+ `List[AudioToAudioOutputElement]`: A list of [`AudioToAudioOutputElement`] items containing audios label, content-type, and audio content in blob.

  Raises:
  `InferenceTimeoutError`:
@@ -335,13 +364,13 @@ class InferenceClient:
  >>> audio_output = client.audio_to_audio("audio.flac")
  >>> for i, item in enumerate(audio_output):
  >>> with open(f"output_{i}.flac", "wb") as f:
- f.write(item["blob"])
+ f.write(item.blob)
  ```
  """
  response = self.post(data=audio, model=model, task="audio-to-audio")
- audio_output = _bytes_to_list(response)
+ audio_output = AudioToAudioOutputElement.parse_obj_as_list(response)
  for item in audio_output:
- item["blob"] = base64.b64decode(item["blob"])
+ item.blob = base64.b64decode(item.blob)
  return audio_output

  def automatic_speech_recognition(
@@ -349,7 +378,7 @@ class InferenceClient:
  audio: ContentT,
  *,
  model: Optional[str] = None,
- ) -> str:
+ ) -> AutomaticSpeechRecognitionOutput:
  """
  Perform automatic speech recognition (ASR or audio-to-text) on the given audio content.

@@ -361,7 +390,7 @@ class InferenceClient:
  Inference Endpoint. If not provided, the default recommended model for ASR will be used.

  Returns:
- str: The transcribed text.
+ [`AutomaticSpeechRecognitionOutput`]: An item containing the transcribed text and optionally the timestamp chunks.

  Raises:
  [`InferenceTimeoutError`]:
@@ -373,12 +402,265 @@ class InferenceClient:
  ```py
  >>> from huggingface_hub import InferenceClient
  >>> client = InferenceClient()
- >>> client.automatic_speech_recognition("hello_world.flac")
+ >>> client.automatic_speech_recognition("hello_world.flac").text
  "hello world"
  ```
  """
  response = self.post(data=audio, model=model, task="automatic-speech-recognition")
- return _bytes_to_dict(response)["text"]
+ return AutomaticSpeechRecognitionOutput.parse_obj_as_instance(response)
+
+ @overload
+ def chat_completion( # type: ignore
+ self,
+ messages: List[Dict[str, str]],
+ *,
+ model: Optional[str] = None,
+ stream: Literal[False] = False,
+ max_tokens: int = 20,
+ seed: Optional[int] = None,
+ stop: Optional[Union[List[str], str]] = None,
+ temperature: float = 1.0,
+ top_p: Optional[float] = None,
+ ) -> ChatCompletionOutput: ...
+
+ @overload
+ def chat_completion( # type: ignore
+ self,
+ messages: List[Dict[str, str]],
+ *,
+ model: Optional[str] = None,
+ stream: Literal[True] = True,
+ max_tokens: int = 20,
+ seed: Optional[int] = None,
+ stop: Optional[Union[List[str], str]] = None,
+ temperature: float = 1.0,
+ top_p: Optional[float] = None,
+ ) -> Iterable[ChatCompletionStreamOutput]: ...
+
+ @overload
+ def chat_completion(
+ self,
+ messages: List[Dict[str, str]],
+ *,
+ model: Optional[str] = None,
+ stream: bool = False,
+ max_tokens: int = 20,
+ seed: Optional[int] = None,
+ stop: Optional[Union[List[str], str]] = None,
+ temperature: float = 1.0,
+ top_p: Optional[float] = None,
+ ) -> Union[ChatCompletionOutput, Iterable[ChatCompletionStreamOutput]]: ...
+
+ def chat_completion(
+ self,
+ messages: List[Dict[str, str]],
+ *,
+ model: Optional[str] = None,
+ stream: bool = False,
+ max_tokens: int = 20,
+ seed: Optional[int] = None,
+ stop: Optional[Union[List[str], str]] = None,
+ temperature: float = 1.0,
+ top_p: Optional[float] = None,
+ ) -> Union[ChatCompletionOutput, Iterable[ChatCompletionStreamOutput]]:
+ """
+ A method for completing conversations using a specified language model.
+
+ <Tip>
+
+ If the model is served by a server supporting chat-completion, the method will directly call the server's
+ `/v1/chat/completions` endpoint. If the server does not support chat-completion, the method will render the
+ chat template client-side based on the information fetched from the Hub API. In this case, you will need to
+ have `minijinja` template engine installed. Run `pip install "huggingface_hub[inference]"` or `pip install minijinja`
+ to install it.
+
+ </Tip>
+
+ Args:
+ messages (List[Union[`SystemMessage`, `UserMessage`, `AssistantMessage`]]):
+ Conversation history consisting of roles and content pairs.
+ model (`str`, *optional*):
+ The model to use for chat-completion. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
+ Inference Endpoint. If not provided, the default recommended model for chat-based text-generation will be used.
+ See https://huggingface.co/tasks/text-generation for more details.
+ frequency_penalty (`float`, optional):
+ Penalizes new tokens based on their existing frequency
+ in the text so far. Range: [-2.0, 2.0]. Defaults to 0.0.
+ max_tokens (`int`, optional):
+ Maximum number of tokens allowed in the response. Defaults to 20.
+ seed (Optional[`int`], optional):
+ Seed for reproducible control flow. Defaults to None.
+ stop (Optional[`str`], optional):
+ Up to four strings which trigger the end of the response.
+ Defaults to None.
+ stream (`bool`, optional):
+ Enable realtime streaming of responses. Defaults to False.
+ temperature (`float`, optional):
+ Controls randomness of the generations. Lower values ensure
+ less random completions. Range: [0, 2]. Defaults to 1.0.
+ top_p (`float`, optional):
+ Fraction of the most likely next words to sample from.
+ Must be between 0 and 1. Defaults to 1.0.
+
+ Returns:
+ `Union[ChatCompletionOutput, Iterable[ChatCompletionStreamOutput]]`:
+ Generated text returned from the server:
+ - if `stream=False`, the generated text is returned as a [`ChatCompletionOutput`] (default).
+ - if `stream=True`, the generated text is returned token by token as a sequence of [`ChatCompletionStreamOutput`].
+
+ Raises:
+ [`InferenceTimeoutError`]:
+ If the model is unavailable or the request times out.
+ `HTTPError`:
+ If the request fails with an HTTP error status code other than HTTP 503.
+
+ Example:
+ ```py
+ >>> from huggingface_hub import InferenceClient
+ >>> messages = [{"role": "user", "content": "What is the capital of France?"}]
+ >>> client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
+ >>> client.chat_completion(messages, max_tokens=100)
+ ChatCompletionOutput(
+ choices=[
+ ChatCompletionOutputChoice(
+ finish_reason='eos_token',
+ index=0,
+ message=ChatCompletionOutputChoiceMessage(
+ content='The capital of France is Paris. The official name of the city is "Ville de Paris" (City of Paris) and the name of the country\'s governing body, which is located in Paris, is "La République française" (The French Republic). \nI hope that helps! Let me know if you need any further information.'
+ )
+ )
+ ],
+ created=1710498360
+ )
+
+ >>> for token in client.chat_completion(messages, max_tokens=10, stream=True):
+ ... print(token)
+ ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content='The', role='assistant'), index=0, finish_reason=None)], created=1710498504)
+ ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content=' capital', role='assistant'), index=0, finish_reason=None)], created=1710498504)
+ (...)
+ ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content=' may', role='assistant'), index=0, finish_reason=None)], created=1710498504)
+ ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content=None, role=None), index=0, finish_reason='length')], created=1710498504)
+ ```
+ """
+ # determine model
+ model = model or self.model or self.get_recommended_model("text-generation")
+
+ if _is_chat_completion_server(model):
+ # First, let's consider the server has a `/v1/chat/completions` endpoint.
+ # If that's the case, we don't have to render the chat template client-side.
+ model_url = self._resolve_url(model) + "/v1/chat/completions"
+
+ try:
+ data = self.post(
+ model=model_url,
+ json=dict(
+ model="tgi", # random string
+ messages=messages,
+ max_tokens=max_tokens,
+ seed=seed,
+ stop=stop,
+ temperature=temperature,
+ top_p=top_p,
+ stream=stream,
+ ),
+ stream=stream,
+ )
+ except HTTPError:
+ # Let's consider the server is not a chat completion server.
+ # Then we call again `chat_completion` which will render the chat template client side.
+ # (can be HTTP 500, HTTP 400, HTTP 404 depending on the server)
+ _set_as_non_chat_completion_server(model)
+ return self.chat_completion(
+ messages=messages,
+ model=model,
+ stream=stream,
+ max_tokens=max_tokens,
+ seed=seed,
+ stop=stop,
+ temperature=temperature,
+ top_p=top_p,
+ )
+
+ if stream:
+ return _stream_chat_completion_response_from_bytes(data) # type: ignore[arg-type]
+
+ return ChatCompletionOutput.parse_obj_as_instance(data) # type: ignore[arg-type]
+
+ # At this point, we know the server is not a chat completion server.
+ # We need to render the chat template client side based on the information we can fetch from
+ # the Hub API.
+
+ model_id = None
+ if model.startswith(("http://", "https://")):
+ # If URL, we need to know which model is served. This is not always possible.
+ # A workaround is to list the user Inference Endpoints and check if one of them correspond to the model URL.
+ # If not, we raise an error.
+ # TODO: fix when we have a proper API for this (at least for Inference Endpoints)
+ # TODO: what if Sagemaker URL?
+ # TODO: what if Azure URL?
+ from ..hf_api import HfApi
+
+ for endpoint in HfApi(token=self.token).list_inference_endpoints():
+ if endpoint.url == model:
+ model_id = endpoint.repository
+ break
+ else:
+ model_id = model
+
+ if model_id is None:
+ # If we don't have the model ID, we can't fetch the chat template.
+ # We raise an error.
+ raise ValueError(
+ "Request can't be processed as the model ID can't be inferred from model URL. "
+ "This is needed to fetch the chat template from the Hub since the model is not "
+ "served with a Chat-completion API."
+ )
+
+ # fetch chat template + tokens
+ prompt = render_chat_prompt(model_id=model_id, token=self.token, messages=messages)
+
+ # generate response
+ stop_sequences = [stop] if isinstance(stop, str) else stop
+ text_generation_output = self.text_generation(
+ prompt=prompt,
+ details=True,
+ stream=stream,
+ model=model,
+ max_new_tokens=max_tokens,
+ seed=seed,
+ stop_sequences=stop_sequences,
+ temperature=temperature,
+ top_p=top_p,
+ )
+
+ created = int(time.time())
+
+ if stream:
+ return _stream_chat_completion_response_from_text_generation(text_generation_output) # type: ignore [arg-type]
+
+ if isinstance(text_generation_output, TextGenerationOutput):
+ # General use case => format ChatCompletionOutput from text generation details
+ content: str = text_generation_output.generated_text
+ finish_reason: str = text_generation_output.details.finish_reason # type: ignore[union-attr]
+ else:
+ # Corner case: if server doesn't support details (e.g. if not a TGI server), we only receive an output string.
+ # In such a case, `finish_reason` is set to `"unk"`.
+ content = text_generation_output # type: ignore[assignment]
+ finish_reason = "unk"
+
+ return ChatCompletionOutput(
+ created=created,
+ choices=[
+ ChatCompletionOutputChoice(
+ finish_reason=finish_reason, # type: ignore
+ index=0,
+ message=ChatCompletionOutputChoiceMessage(
+ content=content,
+ role="assistant",
+ ),
+ )
+ ],
+ )

  def conversational(
  self,
@@ -392,6 +674,13 @@ class InferenceClient:
  """
  Generate conversational responses based on the given input text (i.e. chat with the API).

+ <Tip warning={true}>
+
+ [`InferenceClient.conversational`] API is deprecated and will be removed in a future release. Please use
+ [`InferenceClient.chat_completion`] instead.
+
+ </Tip>
+
  Args:
  text (`str`):
  The last input from the user in the conversation.
@@ -431,6 +720,11 @@ class InferenceClient:
  ... )
  ```
  """
+ warnings.warn(
+ "'InferenceClient.conversational' is deprecated and will be removed starting from huggingface_hub>=0.25. "
+ "Please use the more appropriate 'InferenceClient.chat_completion' API instead.",
+ FutureWarning,
+ )
  payload: Dict[str, Any] = {"inputs": {"text": text}}
  if generated_responses is not None:
  payload["inputs"]["generated_responses"] = generated_responses
@@ -441,57 +735,13 @@ class InferenceClient:
  response = self.post(json=payload, model=model, task="conversational")
  return _bytes_to_dict(response) # type: ignore

- def visual_question_answering(
- self,
- image: ContentT,
- question: str,
- *,
- model: Optional[str] = None,
- ) -> List[str]:
- """
- Answering open-ended questions based on an image.
-
- Args:
- image (`Union[str, Path, bytes, BinaryIO]`):
- The input image for the context. It can be raw bytes, an image file, or a URL to an online image.
- question (`str`):
- Question to be answered.
- model (`str`, *optional*):
- The model to use for the visual question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to
- a deployed Inference Endpoint. If not provided, the default recommended visual question answering model will be used.
- Defaults to None.
-
- Returns:
- `List[Dict]`: a list of dictionaries containing the predicted label and associated probability.
-
- Raises:
- `InferenceTimeoutError`:
- If the model is unavailable or the request times out.
- `HTTPError`:
- If the request fails with an HTTP error status code other than HTTP 503.
-
- Example:
- ```py
- >>> from huggingface_hub import InferenceClient
- >>> client = InferenceClient()
- >>> client.visual_question_answering(
- ... image="https://huggingface.co/datasets/mishig/sample_images/resolve/main/tiger.jpg",
- ... question="What is the animal doing?"
- ... )
- [{'score': 0.778609573841095, 'answer': 'laying down'},{'score': 0.6957435607910156, 'answer': 'sitting'}, ...]
- ```
- """
- payload: Dict[str, Any] = {"question": question, "image": _b64_encode(image)}
- response = self.post(json=payload, model=model, task="visual-question-answering")
- return _bytes_to_list(response)
-
  def document_question_answering(
  self,
  image: ContentT,
  question: str,
  *,
  model: Optional[str] = None,
- ) -> List[QuestionAnsweringOutput]:
+ ) -> List[DocumentQuestionAnsweringOutputElement]:
  """
  Answer questions on document images.

@@ -506,7 +756,7 @@ class InferenceClient:
  Defaults to None.

  Returns:
- `List[Dict]`: a list of dictionaries containing the predicted label, associated probability, word ids, and page number.
+ `List[DocumentQuestionAnsweringOutputElement]`: a list of [`DocumentQuestionAnsweringOutputElement`] items containing the predicted label, associated probability, word ids, and page number.

  Raises:
  [`InferenceTimeoutError`]:
@@ -519,12 +769,12 @@ class InferenceClient:
  >>> from huggingface_hub import InferenceClient
  >>> client = InferenceClient()
  >>> client.document_question_answering(image="https://huggingface.co/spaces/impira/docquery/resolve/2359223c1837a7587402bda0f2643382a6eefeab/invoice.png", question="What is the invoice number?")
- [{'score': 0.42515629529953003, 'answer': 'us-001', 'start': 16, 'end': 16}]
+ [DocumentQuestionAnsweringOutputElement(score=0.42515629529953003, answer='us-001', start=16, end=16)]
  ```
  """
  payload: Dict[str, Any] = {"question": question, "image": _b64_encode(image)}
  response = self.post(json=payload, model=model, task="document-question-answering")
- return _bytes_to_list(response)
+ return DocumentQuestionAnsweringOutputElement.parse_obj_as_list(response)

  def feature_extraction(self, text: str, *, model: Optional[str] = None) -> "np.ndarray":
  """
@@ -562,7 +812,7 @@ class InferenceClient:
  np = _import_numpy()
  return np.array(_bytes_to_dict(response), dtype="float32")

- def fill_mask(self, text: str, *, model: Optional[str] = None) -> List[FillMaskOutput]:
+ def fill_mask(self, text: str, *, model: Optional[str] = None) -> List[FillMaskOutputElement]:
  """
  Fill in a hole with a missing word (token to be precise).

@@ -575,7 +825,7 @@ class InferenceClient:
  Defaults to None.

  Returns:
- `List[Dict]`: a list of fill mask output dictionaries containing the predicted label, associated
+ `List[FillMaskOutputElement]`: a list of [`FillMaskOutputElement`] items containing the predicted label, associated
  probability, token reference, and completed text.

  Raises:
@@ -589,25 +839,21 @@ class InferenceClient:
  >>> from huggingface_hub import InferenceClient
  >>> client = InferenceClient()
  >>> client.fill_mask("The goal of life is <mask>.")
- [{'score': 0.06897063553333282,
- 'token': 11098,
- 'token_str': ' happiness',
- 'sequence': 'The goal of life is happiness.'},
- {'score': 0.06554922461509705,
- 'token': 45075,
- 'token_str': ' immortality',
- 'sequence': 'The goal of life is immortality.'}]
+ [
+ FillMaskOutputElement(score=0.06897063553333282, token=11098, token_str=' happiness', sequence='The goal of life is happiness.'),
+ FillMaskOutputElement(score=0.06554922461509705, token=45075, token_str=' immortality', sequence='The goal of life is immortality.')
+ ]
  ```
  """
  response = self.post(json={"inputs": text}, model=model, task="fill-mask")
- return _bytes_to_list(response)
+ return FillMaskOutputElement.parse_obj_as_list(response)

  def image_classification(
  self,
  image: ContentT,
  *,
  model: Optional[str] = None,
- ) -> List[ClassificationOutput]:
+ ) -> List[ImageClassificationOutputElement]:
  """
  Perform image classification on the given image using the specified model.

@@ -619,7 +865,7 @@ class InferenceClient:
  deployed Inference Endpoint. If not provided, the default recommended model for image classification will be used.

  Returns:
- `List[Dict]`: a list of dictionaries containing the predicted label and associated probability.
+ `List[ImageClassificationOutputElement]`: a list of [`ImageClassificationOutputElement`] items containing the predicted label and associated probability.

  Raises:
  [`InferenceTimeoutError`]:
@@ -632,18 +878,18 @@ class InferenceClient:
  >>> from huggingface_hub import InferenceClient
  >>> client = InferenceClient()
  >>> client.image_classification("https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg")
- [{'score': 0.9779096841812134, 'label': 'Blenheim spaniel'}, ...]
+ [ImageClassificationOutputElement(score=0.9779096841812134, label='Blenheim spaniel'), ...]
  ```
  """
  response = self.post(data=image, model=model, task="image-classification")
- return _bytes_to_list(response)
+ return ImageClassificationOutputElement.parse_obj_as_list(response)

  def image_segmentation(
  self,
  image: ContentT,
  *,
  model: Optional[str] = None,
- ) -> List[ImageSegmentationOutput]:
+ ) -> List[ImageSegmentationOutputElement]:
  """
  Perform image segmentation on the given image using the specified model.

@@ -661,7 +907,7 @@ class InferenceClient:
  deployed Inference Endpoint. If not provided, the default recommended model for image segmentation will be used.

  Returns:
- `List[Dict]`: A list of dictionaries containing the segmented masks and associated attributes.
+ `List[ImageSegmentationOutputElement]`: A list of [`ImageSegmentationOutputElement`] items containing the segmented masks and associated attributes.

  Raises:
  [`InferenceTimeoutError`]:
@@ -674,19 +920,13 @@ class InferenceClient:
  >>> from huggingface_hub import InferenceClient
  >>> client = InferenceClient()
  >>> client.image_segmentation("cat.jpg"):
- [{'score': 0.989008, 'label': 'LABEL_184', 'mask': <PIL.PngImagePlugin.PngImageFile image mode=L size=400x300 at 0x7FDD2B129CC0>}, ...]
+ [ImageSegmentationOutputElement(score=0.989008, label='LABEL_184', mask=<PIL.PngImagePlugin.PngImageFile image mode=L size=400x300 at 0x7FDD2B129CC0>), ...]
  ```
  """
-
- # Segment
  response = self.post(data=image, model=model, task="image-segmentation")
- output = _bytes_to_dict(response)
-
- # Parse masks as PIL Image
- if not isinstance(output, list):
- raise ValueError(f"Server output must be a list. Got {type(output)}: {str(output)[:200]}...")
+ output = ImageSegmentationOutputElement.parse_obj_as_list(response)
  for item in output:
- item["mask"] = _b64_to_image(item["mask"])
+ item.mask = _b64_to_image(item.mask)
  return output

  def image_to_image(
@@ -773,7 +1013,7 @@ class InferenceClient:
  response = self.post(json=payload, data=data, model=model, task="image-to-image")
  return _bytes_to_image(response)

- def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -> str:
+ def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -> ImageToTextOutput:
  """
  Takes an input image and return text.

@@ -788,7 +1028,7 @@ class InferenceClient:
  Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.

  Returns:
- `str`: The generated text.
+ [`ImageToTextOutput`]: The generated text.

  Raises:
  [`InferenceTimeoutError`]:
@@ -807,7 +1047,7 @@ class InferenceClient:
  ```
  """
  response = self.post(data=image, model=model, task="image-to-text")
- return _bytes_to_dict(response)[0]["generated_text"]
+ return ImageToTextOutput.parse_obj_as_instance(response)

  def list_deployed_models(
  self, frameworks: Union[None, str, Literal["all"], List[str]] = None
@@ -889,7 +1129,7 @@ class InferenceClient:
  image: ContentT,
  *,
  model: Optional[str] = None,
- ) -> List[ObjectDetectionOutput]:
+ ) -> List[ObjectDetectionOutputElement]:
  """
  Perform object detection on the given image using the specified model.

@@ -907,7 +1147,7 @@ class InferenceClient:
  deployed Inference Endpoint. If not provided, the default recommended model for object detection (DETR) will be used.

  Returns:
- `List[ObjectDetectionOutput]`: A list of dictionaries containing the bounding boxes and associated attributes.
+ `List[ObjectDetectionOutputElement]`: A list of [`ObjectDetectionOutputElement`] items containing the bounding boxes and associated attributes.

  Raises:
  [`InferenceTimeoutError`]:
@@ -922,19 +1162,16 @@ class InferenceClient:
  >>> from huggingface_hub import InferenceClient
  >>> client = InferenceClient()
  >>> client.object_detection("people.jpg"):
- [{"score":0.9486683011054993,"label":"person","box":{"xmin":59,"ymin":39,"xmax":420,"ymax":510}}, ... ]
+ [ObjectDetectionOutputElement(score=0.9486683011054993, label='person', box=ObjectDetectionBoundingBox(xmin=59, ymin=39, xmax=420, ymax=510)), ...]
  ```
  """
  # detect objects
  response = self.post(data=image, model=model, task="object-detection")
- output = _bytes_to_dict(response)
- if not isinstance(output, list):
- raise ValueError(f"Server output must be a list. Got {type(output)}: {str(output)[:200]}...")
- return output
+ return ObjectDetectionOutputElement.parse_obj_as_list(response)

  def question_answering(
  self, question: str, context: str, *, model: Optional[str] = None
- ) -> QuestionAnsweringOutput:
+ ) -> QuestionAnsweringOutputElement:
  """
  Retrieve the answer to a question from a given text.

@@ -948,7 +1185,7 @@ class InferenceClient:
  a deployed Inference Endpoint.

  Returns:
- `Dict`: a dictionary of question answering output containing the score, start index, end index, and answer.
+ [`QuestionAnsweringOutputElement`]: an question answering output containing the score, start index, end index, and answer.

  Raises:
  [`InferenceTimeoutError`]:
@@ -961,7 +1198,7 @@ class InferenceClient:
  >>> from huggingface_hub import InferenceClient
  >>> client = InferenceClient()
  >>> client.question_answering(question="What's my name?", context="My name is Clara and I live in Berkeley.")
- {'score': 0.9326562285423279, 'start': 11, 'end': 16, 'answer': 'Clara'}
+ QuestionAnsweringOutputElement(score=0.9326562285423279, start=11, end=16, answer='Clara')
  ```
  """

@@ -971,7 +1208,7 @@ class InferenceClient:
  model=model,
  task="question-answering",
  )
- return _bytes_to_dict(response) # type: ignore
+ return QuestionAnsweringOutputElement.parse_obj_as_instance(response)

  def sentence_similarity(
  self, sentence: str, other_sentences: List[str], *, model: Optional[str] = None
@@ -1026,7 +1263,7 @@ class InferenceClient:
  *,
  parameters: Optional[Dict[str, Any]] = None,
  model: Optional[str] = None,
- ) -> str:
+ ) -> SummarizationOutput:
  """
  Generate a summary of a given text using a specified model.

@@ -1041,7 +1278,7 @@ class InferenceClient:
  Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.

  Returns:
- `str`: The generated summary text.
+ [`SummarizationOutput`]: The generated summary text.

  Raises:
  [`InferenceTimeoutError`]:
@@ -1054,18 +1291,18 @@ class InferenceClient:
  >>> from huggingface_hub import InferenceClient
  >>> client = InferenceClient()
  >>> client.summarization("The Eiffel tower...")
- 'The Eiffel tower is one of the most famous landmarks in the world....'
+ SummarizationOutput(generated_text="The Eiffel tower is one of the most famous landmarks in the world....")
  ```
  """
  payload: Dict[str, Any] = {"inputs": text}
  if parameters is not None:
  payload["parameters"] = parameters
  response = self.post(json=payload, model=model, task="summarization")
- return _bytes_to_dict(response)[0]["summary_text"]
+ return SummarizationOutput.parse_obj_as_list(response)[0]

  def table_question_answering(
  self, table: Dict[str, Any], query: str, *, model: Optional[str] = None
- ) -> TableQuestionAnsweringOutput:
+ ) -> TableQuestionAnsweringOutputElement:
  """
  Retrieve the answer to a question from information given in a table.

@@ -1080,7 +1317,7 @@ class InferenceClient:
  Hub or a URL to a deployed Inference Endpoint.

  Returns:
- `Dict`: a dictionary of table question answering output containing the answer, coordinates, cells and the aggregator used.
+ [`TableQuestionAnsweringOutputElement`]: a table question answering output containing the answer, coordinates, cells and the aggregator used.

  Raises:
  [`InferenceTimeoutError`]:
@@ -1095,7 +1332,7 @@ class InferenceClient:
  >>> query = "How many stars does the transformers repository have?"
  >>> table = {"Repository": ["Transformers", "Datasets", "Tokenizers"], "Stars": ["36542", "4512", "3934"]}
  >>> client.table_question_answering(table, query, model="google/tapas-base-finetuned-wtq")
- {'answer': 'AVERAGE > 36542', 'coordinates': [[0, 1]], 'cells': ['36542'], 'aggregator': 'AVERAGE'}
+ TableQuestionAnsweringOutputElement(answer='36542', coordinates=[[0, 1]], cells=['36542'], aggregator='AVERAGE')
  ```
  """
  response = self.post(
@@ -1106,7 +1343,7 @@ class InferenceClient:
  model=model,
  task="table-question-answering",
  )
- return _bytes_to_dict(response) # type: ignore
+ return TableQuestionAnsweringOutputElement.parse_obj_as_instance(response)

  def tabular_classification(self, table: Dict[str, Any], *, model: Optional[str] = None) -> List[str]:
  """
@@ -1193,7 +1430,7 @@ class InferenceClient:
  response = self.post(json={"table": table}, model=model, task="tabular-regression")
  return _bytes_to_list(response)

- def text_classification(self, text: str, *, model: Optional[str] = None) -> List[ClassificationOutput]:
+ def text_classification(self, text: str, *, model: Optional[str] = None) -> List[TextClassificationOutputElement]:
  """
  Perform text classification (e.g. sentiment-analysis) on the given text.

@@ -1206,7 +1443,7 @@ class InferenceClient:
  Defaults to None.

  Returns:
- `List[Dict]`: a list of dictionaries containing the predicted label and associated probability.
+ `List[TextClassificationOutputElement]`: a list of [`TextClassificationOutputElement`] items containing the predicted label and associated probability.

  Raises:
  [`InferenceTimeoutError`]:
@@ -1219,11 +1456,14 @@ class InferenceClient:
  >>> from huggingface_hub import InferenceClient
  >>> client = InferenceClient()
  >>> client.text_classification("I like you")
- [{'label': 'POSITIVE', 'score': 0.9998695850372314}, {'label': 'NEGATIVE', 'score': 0.0001304351753788069}]
+ [
+ TextClassificationOutputElement(label='POSITIVE', score=0.9998695850372314),
+ TextClassificationOutputElement(label='NEGATIVE', score=0.0001304351753788069),
+ ]
  ```
  """
  response = self.post(json={"inputs": text}, model=model, task="text-classification")
- return _bytes_to_list(response)[0]
+ return TextClassificationOutputElement.parse_obj_as_list(response)[0] # type: ignore [return-value]

  @overload
  def text_generation( # type: ignore
@@ -1246,8 +1486,7 @@ class InferenceClient:
  truncate: Optional[int] = None,
  typical_p: Optional[float] = None,
  watermark: bool = False,
- ) -> str:
- ...
+ ) -> str: ...

  @overload
  def text_generation( # type: ignore
@@ -1270,8 +1509,7 @@ class InferenceClient:
  truncate: Optional[int] = None,
  typical_p: Optional[float] = None,
  watermark: bool = False,
- ) -> TextGenerationResponse:
- ...
+ ) -> TextGenerationOutput: ...

  @overload
  def text_generation( # type: ignore
@@ -1294,11 +1532,10 @@ class InferenceClient:
  truncate: Optional[int] = None,
  typical_p: Optional[float] = None,
  watermark: bool = False,
- ) -> Iterable[str]:
- ...
+ ) -> Iterable[str]: ...

  @overload
- def text_generation(
+ def text_generation( # type: ignore
  self,
  prompt: str,
  *,
@@ -1318,8 +1555,30 @@ class InferenceClient:
1318
1555
  truncate: Optional[int] = None,
1319
1556
  typical_p: Optional[float] = None,
1320
1557
  watermark: bool = False,
1321
- ) -> Iterable[TextGenerationStreamResponse]:
1322
- ...
1558
+ ) -> Iterable[TextGenerationStreamOutput]: ...
1559
+
1560
+ @overload
1561
+ def text_generation(
1562
+ self,
1563
+ prompt: str,
1564
+ *,
1565
+ details: Literal[True] = ...,
1566
+ stream: bool = ...,
1567
+ model: Optional[str] = None,
1568
+ do_sample: bool = False,
1569
+ max_new_tokens: int = 20,
1570
+ best_of: Optional[int] = None,
1571
+ repetition_penalty: Optional[float] = None,
1572
+ return_full_text: bool = False,
1573
+ seed: Optional[int] = None,
1574
+ stop_sequences: Optional[List[str]] = None,
1575
+ temperature: Optional[float] = None,
1576
+ top_k: Optional[int] = None,
1577
+ top_p: Optional[float] = None,
1578
+ truncate: Optional[int] = None,
1579
+ typical_p: Optional[float] = None,
1580
+ watermark: bool = False,
1581
+ ) -> Union[TextGenerationOutput, Iterable[TextGenerationStreamOutput]]: ...
1323
1582
 
1324
1583
  def text_generation(
1325
1584
  self,
@@ -1342,13 +1601,10 @@ class InferenceClient:
1342
1601
  typical_p: Optional[float] = None,
1343
1602
  watermark: bool = False,
1344
1603
  decoder_input_details: bool = False,
1345
- ) -> Union[str, TextGenerationResponse, Iterable[str], Iterable[TextGenerationStreamResponse]]:
1604
+ ) -> Union[str, TextGenerationOutput, Iterable[str], Iterable[TextGenerationStreamOutput]]:
1346
1605
  """
1347
1606
  Given a prompt, generate the following text.
1348
1607
 
1349
- It is recommended to have Pydantic installed in order to get inputs validated. This is preferable as it allow
1350
- early failures.
1351
-
1352
1608
  API endpoint is supposed to run with the `text-generation-inference` backend (TGI). This backend is the
1353
1609
  go-to solution to run large language models at scale. However, for some smaller models (e.g. "gpt2") the
1354
1610
  default `transformers` + `api-inference` solution is still in use. Both approaches have very similar APIs, but
@@ -1406,12 +1662,12 @@ class InferenceClient:
1406
1662
  into account. Defaults to `False`.
1407
1663
 
1408
1664
  Returns:
1409
- `Union[str, TextGenerationResponse, Iterable[str], Iterable[TextGenerationStreamResponse]]`:
1665
+ `Union[str, TextGenerationOutput, Iterable[str], Iterable[TextGenerationStreamOutput]]`:
1410
1666
  Generated text returned from the server:
1411
1667
  - if `stream=False` and `details=False`, the generated text is returned as a `str` (default)
1412
1668
  - if `stream=True` and `details=False`, the generated text is returned token by token as a `Iterable[str]`
1413
- - if `stream=False` and `details=True`, the generated text is returned with more details as a [`~huggingface_hub.inference._text_generation.TextGenerationResponse`]
1414
- - if `details=True` and `stream=True`, the generated text is returned token by token as a iterable of [`~huggingface_hub.inference._text_generation.TextGenerationStreamResponse`]
1669
+ - if `stream=False` and `details=True`, the generated text is returned with more details as a [`~huggingface_hub.TextGenerationOutput`]
1670
+ - if `details=True` and `stream=True`, the generated text is returned token by token as a iterable of [`~huggingface_hub.TextGenerationStreamOutput`]
1415
1671
 
1416
1672
  Raises:
1417
1673
  `ValidationError`:
@@ -1448,23 +1704,23 @@ class InferenceClient:
1448
1704
 
1449
1705
  # Case 3: get more details about the generation process.
1450
1706
  >>> client.text_generation("The huggingface_hub library is ", max_new_tokens=12, details=True)
1451
- TextGenerationResponse(
1707
+ TextGenerationOutput(
1452
1708
  generated_text='100% open source and built to be easy to use.',
1453
- details=Details(
1454
- finish_reason=<FinishReason.Length: 'length'>,
1709
+ details=TextGenerationDetails(
1710
+ finish_reason='length',
1455
1711
  generated_tokens=12,
1456
1712
  seed=None,
1457
1713
  prefill=[
1458
- InputToken(id=487, text='The', logprob=None),
1459
- InputToken(id=53789, text=' hugging', logprob=-13.171875),
1714
+ TextGenerationPrefillToken(id=487, text='The', logprob=None),
1715
+ TextGenerationPrefillToken(id=53789, text=' hugging', logprob=-13.171875),
1460
1716
  (...)
1461
- InputToken(id=204, text=' ', logprob=-7.0390625)
1717
+ TextGenerationPrefillToken(id=204, text=' ', logprob=-7.0390625)
1462
1718
  ],
1463
1719
  tokens=[
1464
- Token(id=1425, text='100', logprob=-1.0175781, special=False),
1465
- Token(id=16, text='%', logprob=-0.0463562, special=False),
1720
+ TokenElement(id=1425, text='100', logprob=-1.0175781, special=False),
1721
+ TokenElement(id=16, text='%', logprob=-0.0463562, special=False),
1466
1722
  (...)
1467
- Token(id=25, text='.', logprob=-0.5703125, special=False)
1723
+ TokenElement(id=25, text='.', logprob=-0.5703125, special=False)
1468
1724
  ],
1469
1725
  best_of_sequences=None
1470
1726
  )
@@ -1475,30 +1731,27 @@ class InferenceClient:
1475
1731
  >>> for details in client.text_generation("The huggingface_hub library is ", max_new_tokens=12, details=True, stream=True):
1476
1732
  ... print(details)
1477
1733
  ...
1478
- TextGenerationStreamResponse(token=Token(id=1425, text='100', logprob=-1.0175781, special=False), generated_text=None, details=None)
1479
- TextGenerationStreamResponse(token=Token(id=16, text='%', logprob=-0.0463562, special=False), generated_text=None, details=None)
1480
- TextGenerationStreamResponse(token=Token(id=1314, text=' open', logprob=-1.3359375, special=False), generated_text=None, details=None)
1481
- TextGenerationStreamResponse(token=Token(id=3178, text=' source', logprob=-0.28100586, special=False), generated_text=None, details=None)
1482
- TextGenerationStreamResponse(token=Token(id=273, text=' and', logprob=-0.5961914, special=False), generated_text=None, details=None)
1483
- TextGenerationStreamResponse(token=Token(id=3426, text=' built', logprob=-1.9423828, special=False), generated_text=None, details=None)
1484
- TextGenerationStreamResponse(token=Token(id=271, text=' to', logprob=-1.4121094, special=False), generated_text=None, details=None)
1485
- TextGenerationStreamResponse(token=Token(id=314, text=' be', logprob=-1.5224609, special=False), generated_text=None, details=None)
1486
- TextGenerationStreamResponse(token=Token(id=1833, text=' easy', logprob=-2.1132812, special=False), generated_text=None, details=None)
1487
- TextGenerationStreamResponse(token=Token(id=271, text=' to', logprob=-0.08520508, special=False), generated_text=None, details=None)
1488
- TextGenerationStreamResponse(token=Token(id=745, text=' use', logprob=-0.39453125, special=False), generated_text=None, details=None)
1489
- TextGenerationStreamResponse(token=Token(
1734
+ TextGenerationStreamOutput(token=TokenElement(id=1425, text='100', logprob=-1.0175781, special=False), generated_text=None, details=None)
1735
+ TextGenerationStreamOutput(token=TokenElement(id=16, text='%', logprob=-0.0463562, special=False), generated_text=None, details=None)
1736
+ TextGenerationStreamOutput(token=TokenElement(id=1314, text=' open', logprob=-1.3359375, special=False), generated_text=None, details=None)
1737
+ TextGenerationStreamOutput(token=TokenElement(id=3178, text=' source', logprob=-0.28100586, special=False), generated_text=None, details=None)
1738
+ TextGenerationStreamOutput(token=TokenElement(id=273, text=' and', logprob=-0.5961914, special=False), generated_text=None, details=None)
1739
+ TextGenerationStreamOutput(token=TokenElement(id=3426, text=' built', logprob=-1.9423828, special=False), generated_text=None, details=None)
1740
+ TextGenerationStreamOutput(token=TokenElement(id=271, text=' to', logprob=-1.4121094, special=False), generated_text=None, details=None)
1741
+ TextGenerationStreamOutput(token=TokenElement(id=314, text=' be', logprob=-1.5224609, special=False), generated_text=None, details=None)
1742
+ TextGenerationStreamOutput(token=TokenElement(id=1833, text=' easy', logprob=-2.1132812, special=False), generated_text=None, details=None)
1743
+ TextGenerationStreamOutput(token=TokenElement(id=271, text=' to', logprob=-0.08520508, special=False), generated_text=None, details=None)
1744
+ TextGenerationStreamOutput(token=TokenElement(id=745, text=' use', logprob=-0.39453125, special=False), generated_text=None, details=None)
1745
+ TextGenerationStreamOutput(token=TokenElement(
1490
1746
  id=25,
1491
1747
  text='.',
1492
1748
  logprob=-0.5703125,
1493
1749
  special=False),
1494
1750
  generated_text='100% open source and built to be easy to use.',
1495
- details=StreamDetails(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=12, seed=None)
1751
+ details=TextGenerationStreamDetails(finish_reason='length', generated_tokens=12, seed=None)
1496
1752
  )
1497
1753
  ```
1498
1754
  """
1499
- # NOTE: Text-generation integration is taken from the text-generation-inference project. It has more features
1500
- # like input/output validation (if Pydantic is installed). See `_text_generation.py` header for more details.
1501
-
1502
1755
  if decoder_input_details and not details:
1503
1756
  warnings.warn(
1504
1757
  "`decoder_input_details=True` has been passed to the server but `details=False` is set meaning that"
@@ -1506,34 +1759,38 @@ class InferenceClient:
  )
  decoder_input_details = False
 
- # Validate parameters
- parameters = TextGenerationParameters(
- best_of=best_of,
- details=details,
- do_sample=do_sample,
- max_new_tokens=max_new_tokens,
- repetition_penalty=repetition_penalty,
- return_full_text=return_full_text,
- seed=seed,
- stop=stop_sequences if stop_sequences is not None else [],
- temperature=temperature,
- top_k=top_k,
- top_p=top_p,
- truncate=truncate,
- typical_p=typical_p,
- watermark=watermark,
- decoder_input_details=decoder_input_details,
- )
- request = TextGenerationRequest(inputs=prompt, stream=stream, parameters=parameters)
- payload = asdict(request)
+ # Build payload
+ payload = {
+ "inputs": prompt,
+ "parameters": {
+ "best_of": best_of,
+ "decoder_input_details": decoder_input_details,
+ "details": details,
+ "do_sample": do_sample,
+ "max_new_tokens": max_new_tokens,
+ "repetition_penalty": repetition_penalty,
+ "return_full_text": return_full_text,
+ "seed": seed,
+ "stop": stop_sequences if stop_sequences is not None else [],
+ "temperature": temperature,
+ "top_k": top_k,
+ "top_p": top_p,
+ "truncate": truncate,
+ "typical_p": typical_p,
+ "watermark": watermark,
+ },
+ "stream": stream,
+ }
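The request body is now assembled as a plain dict (serialized as-is) instead of going through `TextGenerationParameters`/`TextGenerationRequest` and `asdict`. A standalone sketch of this construction plus the non-TGI pruning step shown just below; all parameter values here are made up for illustration:

```py
import warnings

# Illustrative values only; the structure mirrors the payload built above.
prompt, stream = "How do you make cheese?", False
payload = {
    "inputs": prompt,
    "parameters": {"max_new_tokens": 12, "details": True, "watermark": None, "stop": [],
                   "decoder_input_details": None, "best_of": None, "return_full_text": None},
    "stream": stream,
}

is_tgi_server = False  # assumption: a plain transformers-backed endpoint, not TGI
if not is_tgi_server:
    parameters = payload["parameters"]
    ignored = [k for k in ("watermark", "details", "decoder_input_details", "best_of", "stop", "return_full_text")
               if parameters[k] is not None]
    for k in ignored:
        del parameters[k]
    if ignored:
        warnings.warn(f"Not a TGI server, ignoring parameters {ignored}.")
```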
 
  # Remove some parameters if not a TGI server
  if not _is_tgi_server(model):
+ parameters: Dict = payload["parameters"]  # type: ignore [assignment]
+
  ignored_parameters = []
- for key in "watermark", "stop", "details", "decoder_input_details", "best_of":
- if payload["parameters"][key] is not None:
+ for key in "watermark", "details", "decoder_input_details", "best_of", "stop", "return_full_text":
+ if parameters[key] is not None:
  ignored_parameters.append(key)
- del payload["parameters"][key]
+ del parameters[key]
  if len(ignored_parameters) > 0:
  warnings.warn(
  "API endpoint/model for text-generation is not served via TGI. Ignoring parameters"
@@ -1585,8 +1842,8 @@ class InferenceClient:
  if stream:
  return _stream_text_generation_response(bytes_output, details)  # type: ignore
 
- data = _bytes_to_dict(bytes_output)[0]
- return TextGenerationResponse(**data) if details else data["generated_text"]
+ data = _bytes_to_dict(bytes_output)[0]  # type: ignore[arg-type]
+ return TextGenerationOutput.parse_obj_as_instance(data) if details else data["generated_text"]
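For the non-streaming path, `details=True` now parses the server response into a `TextGenerationOutput` via `parse_obj_as_instance` (previously a `TextGenerationResponse` built with `**data`), while `details=False` still returns the bare string. A hedged usage sketch; the field names follow the example output earlier in the docstring:

```py
from huggingface_hub import InferenceClient

client = InferenceClient()

# details=False (default): behaviour unchanged, a plain string is returned.
text = client.text_generation("How do you make cheese?", max_new_tokens=12)

# details=True: a TextGenerationOutput instance instead of the old TextGenerationResponse.
output = client.text_generation("How do you make cheese?", max_new_tokens=12, details=True)
print(output.generated_text)
print(output.details.generated_tokens)  # assumed field, mirroring the streaming details shown above
```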
 
  def text_to_image(
  self,
@@ -1700,7 +1957,9 @@ class InferenceClient:
  """
  return self.post(json={"inputs": text}, model=model, task="text-to-speech")
 
- def token_classification(self, text: str, *, model: Optional[str] = None) -> List[TokenClassificationOutput]:
+ def token_classification(
+ self, text: str, *, model: Optional[str] = None
+ ) -> List[TokenClassificationOutputElement]:
  """
  Perform token classification on the given text.
  Usually used for sentence parsing, either grammatical, or Named Entity Recognition (NER) to understand keywords contained within text.
@@ -1714,7 +1973,7 @@ class InferenceClient:
  Defaults to None.
 
  Returns:
- `List[Dict]`: List of token classification outputs containing the entity group, confidence score, word, start and end index.
+ `List[TokenClassificationOutputElement]`: List of [`TokenClassificationOutputElement`] items containing the entity group, confidence score, word, start and end index.
 
  Raises:
  [`InferenceTimeoutError`]:
@@ -1727,16 +1986,22 @@ class InferenceClient:
  >>> from huggingface_hub import InferenceClient
  >>> client = InferenceClient()
  >>> client.token_classification("My name is Sarah Jessica Parker but you can call me Jessica")
- [{'entity_group': 'PER',
- 'score': 0.9971321225166321,
- 'word': 'Sarah Jessica Parker',
- 'start': 11,
- 'end': 31},
- {'entity_group': 'PER',
- 'score': 0.9773476123809814,
- 'word': 'Jessica',
- 'start': 52,
- 'end': 59}]
+ [
+ TokenClassificationOutputElement(
+ entity_group='PER',
+ score=0.9971321225166321,
+ word='Sarah Jessica Parker',
+ start=11,
+ end=31,
+ ),
+ TokenClassificationOutputElement(
+ entity_group='PER',
+ score=0.9773476123809814,
+ word='Jessica',
+ start=52,
+ end=59,
+ )
+ ]
  ```
  """
  payload: Dict[str, Any] = {"inputs": text}
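Since the call now returns `TokenClassificationOutputElement` dataclasses rather than raw dicts, fields are read as attributes. A short hedged sketch based on the example above:

```py
from huggingface_hub import InferenceClient

client = InferenceClient()
entities = client.token_classification("My name is Sarah Jessica Parker but you can call me Jessica")
for entity in entities:
    # entity_group / word / score / start / end mirror the repr shown in the example
    print(f"{entity.entity_group}: {entity.word!r} ({entity.score:.3f}) [{entity.start}:{entity.end}]")
```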
@@ -1745,11 +2010,11 @@ class InferenceClient:
  model=model,
  task="token-classification",
  )
- return _bytes_to_list(response)
+ return TokenClassificationOutputElement.parse_obj_as_list(response)
 
  def translation(
  self, text: str, *, model: Optional[str] = None, src_lang: Optional[str] = None, tgt_lang: Optional[str] = None
- ) -> str:
+ ) -> TranslationOutput:
  """
  Convert text from one language to another.
 
@@ -1772,7 +2037,7 @@ class InferenceClient:
  Target language of the translation task, i.e. output language. Cannot be passed without `src_lang`.
 
  Returns:
- `str`: The generated translated text.
+ [`TranslationOutput`]: The generated translated text.
 
  Raises:
  [`InferenceTimeoutError`]:
@@ -1789,7 +2054,7 @@ class InferenceClient:
  >>> client.translation("My name is Wolfgang and I live in Berlin")
  'Mein Name ist Wolfgang und ich lebe in Berlin.'
  >>> client.translation("My name is Wolfgang and I live in Berlin", model="Helsinki-NLP/opus-mt-en-fr")
- "Je m'appelle Wolfgang et je vis à Berlin."
+ TranslationOutput(translation_text='Je m\'appelle Wolfgang et je vis à Berlin.')
  ```
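Callers that previously used the returned string directly now receive a `TranslationOutput` and read its `translation_text` field, as in this hedged sketch:

```py
from huggingface_hub import InferenceClient

client = InferenceClient()
result = client.translation("My name is Wolfgang and I live in Berlin", model="Helsinki-NLP/opus-mt-en-fr")
print(result.translation_text)  # field name taken from the repr shown above
```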
 
  Specifying languages:
@@ -1810,11 +2075,58 @@ class InferenceClient:
  if src_lang and tgt_lang:
  payload["parameters"] = {"src_lang": src_lang, "tgt_lang": tgt_lang}
  response = self.post(json=payload, model=model, task="translation")
- return _bytes_to_dict(response)[0]["translation_text"]
+ return TranslationOutput.parse_obj_as_list(response)[0]
+
+ def visual_question_answering(
+ self,
+ image: ContentT,
+ question: str,
+ *,
+ model: Optional[str] = None,
+ ) -> List[VisualQuestionAnsweringOutputElement]:
+ """
+ Answer open-ended questions based on an image.
+
+ Args:
+ image (`Union[str, Path, bytes, BinaryIO]`):
+ The input image for the context. It can be raw bytes, an image file, or a URL to an online image.
+ question (`str`):
+ Question to be answered.
+ model (`str`, *optional*):
+ The model to use for the visual question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to
+ a deployed Inference Endpoint. If not provided, the default recommended visual question answering model will be used.
+ Defaults to None.
+
+ Returns:
+ `List[VisualQuestionAnsweringOutputElement]`: A list of [`VisualQuestionAnsweringOutputElement`] items containing the predicted label and associated probability.
+
+ Raises:
+ `InferenceTimeoutError`:
+ If the model is unavailable or the request times out.
+ `HTTPError`:
+ If the request fails with an HTTP error status code other than HTTP 503.
+
+ Example:
+ ```py
+ >>> from huggingface_hub import InferenceClient
+ >>> client = InferenceClient()
+ >>> client.visual_question_answering(
+ ... image="https://huggingface.co/datasets/mishig/sample_images/resolve/main/tiger.jpg",
+ ... question="What is the animal doing?"
+ ... )
+ [
+ VisualQuestionAnsweringOutputElement(score=0.778609573841095, answer='laying down'),
+ VisualQuestionAnsweringOutputElement(score=0.6957435607910156, answer='sitting'),
+ ]
+ ```
+ """
+ payload: Dict[str, Any] = {"question": question, "image": _b64_encode(image)}
+ response = self.post(json=payload, model=model, task="visual-question-answering")
+ return VisualQuestionAnsweringOutputElement.parse_obj_as_list(response)
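A hedged usage sketch for the newly added method; the image URL and question come from the docstring example, and the `answer`/`score` fields follow the repr shown there:

```py
from huggingface_hub import InferenceClient

client = InferenceClient()
answers = client.visual_question_answering(
    image="https://huggingface.co/datasets/mishig/sample_images/resolve/main/tiger.jpg",
    question="What is the animal doing?",
)
best = max(answers, key=lambda a: a.score)  # candidates are scored independently
print(best.answer, round(best.score, 3))
```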
 
  def zero_shot_classification(
  self, text: str, labels: List[str], *, multi_label: bool = False, model: Optional[str] = None
- ) -> List[ClassificationOutput]:
+ ) -> List[ZeroShotClassificationOutputElement]:
  """
  Provide as input a text and a set of candidate labels to classify the input text.
 
@@ -1830,7 +2142,7 @@ class InferenceClient:
  Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
 
  Returns:
- `List[Dict]`: List of classification outputs containing the predicted labels and their confidence.
+ `List[ZeroShotClassificationOutputElement]`: List of [`ZeroShotClassificationOutputElement`] items containing the predicted labels and their confidence.
 
  Raises:
  [`InferenceTimeoutError`]:
@@ -1850,19 +2162,19 @@ class InferenceClient:
  >>> labels = ["space & cosmos", "scientific discovery", "microbiology", "robots", "archeology"]
  >>> client.zero_shot_classification(text, labels)
  [
- {"label": "scientific discovery", "score": 0.7961668968200684},
- {"label": "space & cosmos", "score": 0.18570658564567566},
- {"label": "microbiology", "score": 0.00730885099619627},
- {"label": "archeology", "score": 0.006258360575884581},
- {"label": "robots", "score": 0.004559356719255447},
+ ZeroShotClassificationOutputElement(label='scientific discovery', score=0.7961668968200684),
+ ZeroShotClassificationOutputElement(label='space & cosmos', score=0.18570658564567566),
+ ZeroShotClassificationOutputElement(label='microbiology', score=0.00730885099619627),
+ ZeroShotClassificationOutputElement(label='archeology', score=0.006258360575884581),
+ ZeroShotClassificationOutputElement(label='robots', score=0.004559356719255447),
  ]
  >>> client.zero_shot_classification(text, labels, multi_label=True)
  [
- {"label": "scientific discovery", "score": 0.9829297661781311},
- {"label": "space & cosmos", "score": 0.755190908908844},
- {"label": "microbiology", "score": 0.0005462635890580714},
- {"label": "archeology", "score": 0.00047131875180639327},
- {"label": "robots", "score": 0.00030448526376858354},
+ ZeroShotClassificationOutputElement(label='scientific discovery', score=0.9829297661781311),
+ ZeroShotClassificationOutputElement(label='space & cosmos', score=0.755190908908844),
+ ZeroShotClassificationOutputElement(label='microbiology', score=0.0005462635890580714),
+ ZeroShotClassificationOutputElement(label='archeology', score=0.00047131875180639327),
+ ZeroShotClassificationOutputElement(label='robots', score=0.00030448526376858354),
  ]
  ```
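Because each prediction is now a `ZeroShotClassificationOutputElement` rather than a dict, `label` and `score` are accessed as attributes. A hedged sketch picking the top label (the input text here is shortened for illustration):

```py
from huggingface_hub import InferenceClient

client = InferenceClient()
text = "A new model offers an explanation for how the Galilean satellites formed."  # illustrative input
labels = ["space & cosmos", "scientific discovery", "microbiology", "robots", "archeology"]
predictions = client.zero_shot_classification(text, labels)
top = max(predictions, key=lambda p: p.score)
print(top.label, round(top.score, 3))
```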
  """
@@ -1882,11 +2194,14 @@ class InferenceClient:
  task="zero-shot-classification",
  )
  output = _bytes_to_dict(response)
- return [{"label": label, "score": score} for label, score in zip(output["labels"], output["scores"])]
+ return [
+ ZeroShotClassificationOutputElement.parse_obj_as_instance({"label": label, "score": score})
+ for label, score in zip(output["labels"], output["scores"])
+ ]
 
  def zero_shot_image_classification(
  self, image: ContentT, labels: List[str], *, model: Optional[str] = None
- ) -> List[ClassificationOutput]:
+ ) -> List[ZeroShotImageClassificationOutputElement]:
  """
  Provide input image and text labels to predict text labels for the image.
 
@@ -1900,7 +2215,7 @@ class InferenceClient:
  Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
 
  Returns:
- `List[Dict]`: List of classification outputs containing the predicted labels and their confidence.
+ `List[ZeroShotImageClassificationOutputElement]`: List of [`ZeroShotImageClassificationOutputElement`] items containing the predicted labels and their confidence.
 
  Raises:
  [`InferenceTimeoutError`]:
@@ -1917,7 +2232,7 @@ class InferenceClient:
  ... "https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg",
  ... labels=["dog", "cat", "horse"],
  ... )
- [{"label": "dog", "score": 0.956}, ...]
+ [ZeroShotImageClassificationOutputElement(label='dog', score=0.956), ...]
  ```
  """
  # Raise ValueError if input is less than 2 labels
@@ -1929,7 +2244,7 @@ class InferenceClient:
  model=model,
  task="zero-shot-image-classification",
  )
- return _bytes_to_list(response)
+ return ZeroShotImageClassificationOutputElement.parse_obj_as_list(response)
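The image variant follows the same pattern: the raw response is parsed into `ZeroShotImageClassificationOutputElement` items, so a hedged consumption sketch looks like:

```py
from huggingface_hub import InferenceClient

client = InferenceClient()
predictions = client.zero_shot_image_classification(
    "https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg",
    labels=["dog", "cat", "horse"],
)
for prediction in predictions:
    print(prediction.label, round(prediction.score, 3))
```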
 
  def _resolve_url(self, model: Optional[str] = None, task: Optional[str] = None) -> str:
  model = model or self.model