huggingface-hub 0.29.0rc2__py3-none-any.whl → 1.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153)
  1. huggingface_hub/__init__.py +160 -46
  2. huggingface_hub/_commit_api.py +277 -71
  3. huggingface_hub/_commit_scheduler.py +15 -15
  4. huggingface_hub/_inference_endpoints.py +33 -22
  5. huggingface_hub/_jobs_api.py +301 -0
  6. huggingface_hub/_local_folder.py +18 -3
  7. huggingface_hub/_login.py +31 -63
  8. huggingface_hub/_oauth.py +460 -0
  9. huggingface_hub/_snapshot_download.py +241 -81
  10. huggingface_hub/_space_api.py +18 -10
  11. huggingface_hub/_tensorboard_logger.py +15 -19
  12. huggingface_hub/_upload_large_folder.py +196 -76
  13. huggingface_hub/_webhooks_payload.py +3 -3
  14. huggingface_hub/_webhooks_server.py +15 -25
  15. huggingface_hub/{commands → cli}/__init__.py +1 -15
  16. huggingface_hub/cli/_cli_utils.py +173 -0
  17. huggingface_hub/cli/auth.py +147 -0
  18. huggingface_hub/cli/cache.py +841 -0
  19. huggingface_hub/cli/download.py +189 -0
  20. huggingface_hub/cli/hf.py +60 -0
  21. huggingface_hub/cli/inference_endpoints.py +377 -0
  22. huggingface_hub/cli/jobs.py +772 -0
  23. huggingface_hub/cli/lfs.py +175 -0
  24. huggingface_hub/cli/repo.py +315 -0
  25. huggingface_hub/cli/repo_files.py +94 -0
  26. huggingface_hub/{commands/env.py → cli/system.py} +10 -13
  27. huggingface_hub/cli/upload.py +294 -0
  28. huggingface_hub/cli/upload_large_folder.py +117 -0
  29. huggingface_hub/community.py +20 -12
  30. huggingface_hub/constants.py +83 -59
  31. huggingface_hub/dataclasses.py +609 -0
  32. huggingface_hub/errors.py +99 -30
  33. huggingface_hub/fastai_utils.py +30 -41
  34. huggingface_hub/file_download.py +606 -346
  35. huggingface_hub/hf_api.py +2445 -1132
  36. huggingface_hub/hf_file_system.py +269 -152
  37. huggingface_hub/hub_mixin.py +61 -66
  38. huggingface_hub/inference/_client.py +501 -630
  39. huggingface_hub/inference/_common.py +133 -121
  40. huggingface_hub/inference/_generated/_async_client.py +536 -722
  41. huggingface_hub/inference/_generated/types/__init__.py +6 -1
  42. huggingface_hub/inference/_generated/types/automatic_speech_recognition.py +5 -6
  43. huggingface_hub/inference/_generated/types/base.py +10 -7
  44. huggingface_hub/inference/_generated/types/chat_completion.py +77 -31
  45. huggingface_hub/inference/_generated/types/depth_estimation.py +2 -2
  46. huggingface_hub/inference/_generated/types/document_question_answering.py +2 -2
  47. huggingface_hub/inference/_generated/types/feature_extraction.py +2 -2
  48. huggingface_hub/inference/_generated/types/fill_mask.py +2 -2
  49. huggingface_hub/inference/_generated/types/image_to_image.py +8 -2
  50. huggingface_hub/inference/_generated/types/image_to_text.py +2 -3
  51. huggingface_hub/inference/_generated/types/image_to_video.py +60 -0
  52. huggingface_hub/inference/_generated/types/sentence_similarity.py +3 -3
  53. huggingface_hub/inference/_generated/types/summarization.py +2 -2
  54. huggingface_hub/inference/_generated/types/table_question_answering.py +5 -5
  55. huggingface_hub/inference/_generated/types/text2text_generation.py +2 -2
  56. huggingface_hub/inference/_generated/types/text_generation.py +11 -11
  57. huggingface_hub/inference/_generated/types/text_to_audio.py +1 -2
  58. huggingface_hub/inference/_generated/types/text_to_speech.py +1 -2
  59. huggingface_hub/inference/_generated/types/text_to_video.py +2 -2
  60. huggingface_hub/inference/_generated/types/token_classification.py +2 -2
  61. huggingface_hub/inference/_generated/types/translation.py +2 -2
  62. huggingface_hub/inference/_generated/types/zero_shot_classification.py +2 -2
  63. huggingface_hub/inference/_generated/types/zero_shot_image_classification.py +2 -2
  64. huggingface_hub/inference/_generated/types/zero_shot_object_detection.py +1 -3
  65. huggingface_hub/inference/_mcp/__init__.py +0 -0
  66. huggingface_hub/inference/_mcp/_cli_hacks.py +88 -0
  67. huggingface_hub/inference/_mcp/agent.py +100 -0
  68. huggingface_hub/inference/_mcp/cli.py +247 -0
  69. huggingface_hub/inference/_mcp/constants.py +81 -0
  70. huggingface_hub/inference/_mcp/mcp_client.py +395 -0
  71. huggingface_hub/inference/_mcp/types.py +45 -0
  72. huggingface_hub/inference/_mcp/utils.py +128 -0
  73. huggingface_hub/inference/_providers/__init__.py +149 -20
  74. huggingface_hub/inference/_providers/_common.py +160 -37
  75. huggingface_hub/inference/_providers/black_forest_labs.py +12 -9
  76. huggingface_hub/inference/_providers/cerebras.py +6 -0
  77. huggingface_hub/inference/_providers/clarifai.py +13 -0
  78. huggingface_hub/inference/_providers/cohere.py +32 -0
  79. huggingface_hub/inference/_providers/fal_ai.py +231 -22
  80. huggingface_hub/inference/_providers/featherless_ai.py +38 -0
  81. huggingface_hub/inference/_providers/fireworks_ai.py +22 -1
  82. huggingface_hub/inference/_providers/groq.py +9 -0
  83. huggingface_hub/inference/_providers/hf_inference.py +143 -33
  84. huggingface_hub/inference/_providers/hyperbolic.py +9 -5
  85. huggingface_hub/inference/_providers/nebius.py +47 -5
  86. huggingface_hub/inference/_providers/novita.py +48 -5
  87. huggingface_hub/inference/_providers/nscale.py +44 -0
  88. huggingface_hub/inference/_providers/openai.py +25 -0
  89. huggingface_hub/inference/_providers/publicai.py +6 -0
  90. huggingface_hub/inference/_providers/replicate.py +46 -9
  91. huggingface_hub/inference/_providers/sambanova.py +37 -1
  92. huggingface_hub/inference/_providers/scaleway.py +28 -0
  93. huggingface_hub/inference/_providers/together.py +34 -5
  94. huggingface_hub/inference/_providers/wavespeed.py +138 -0
  95. huggingface_hub/inference/_providers/zai_org.py +17 -0
  96. huggingface_hub/lfs.py +33 -100
  97. huggingface_hub/repocard.py +34 -38
  98. huggingface_hub/repocard_data.py +79 -59
  99. huggingface_hub/serialization/__init__.py +0 -1
  100. huggingface_hub/serialization/_base.py +12 -15
  101. huggingface_hub/serialization/_dduf.py +8 -8
  102. huggingface_hub/serialization/_torch.py +69 -69
  103. huggingface_hub/utils/__init__.py +27 -8
  104. huggingface_hub/utils/_auth.py +7 -7
  105. huggingface_hub/utils/_cache_manager.py +92 -147
  106. huggingface_hub/utils/_chunk_utils.py +2 -3
  107. huggingface_hub/utils/_deprecation.py +1 -1
  108. huggingface_hub/utils/_dotenv.py +55 -0
  109. huggingface_hub/utils/_experimental.py +7 -5
  110. huggingface_hub/utils/_fixes.py +0 -10
  111. huggingface_hub/utils/_git_credential.py +5 -5
  112. huggingface_hub/utils/_headers.py +8 -30
  113. huggingface_hub/utils/_http.py +399 -237
  114. huggingface_hub/utils/_pagination.py +6 -6
  115. huggingface_hub/utils/_parsing.py +98 -0
  116. huggingface_hub/utils/_paths.py +5 -5
  117. huggingface_hub/utils/_runtime.py +74 -22
  118. huggingface_hub/utils/_safetensors.py +21 -21
  119. huggingface_hub/utils/_subprocess.py +13 -11
  120. huggingface_hub/utils/_telemetry.py +4 -4
  121. huggingface_hub/{commands/_cli_utils.py → utils/_terminal.py} +4 -4
  122. huggingface_hub/utils/_typing.py +25 -5
  123. huggingface_hub/utils/_validators.py +55 -74
  124. huggingface_hub/utils/_verification.py +167 -0
  125. huggingface_hub/utils/_xet.py +235 -0
  126. huggingface_hub/utils/_xet_progress_reporting.py +162 -0
  127. huggingface_hub/utils/insecure_hashlib.py +3 -5
  128. huggingface_hub/utils/logging.py +8 -11
  129. huggingface_hub/utils/tqdm.py +33 -4
  130. {huggingface_hub-0.29.0rc2.dist-info → huggingface_hub-1.1.3.dist-info}/METADATA +94 -82
  131. huggingface_hub-1.1.3.dist-info/RECORD +155 -0
  132. {huggingface_hub-0.29.0rc2.dist-info → huggingface_hub-1.1.3.dist-info}/WHEEL +1 -1
  133. huggingface_hub-1.1.3.dist-info/entry_points.txt +6 -0
  134. huggingface_hub/commands/delete_cache.py +0 -428
  135. huggingface_hub/commands/download.py +0 -200
  136. huggingface_hub/commands/huggingface_cli.py +0 -61
  137. huggingface_hub/commands/lfs.py +0 -200
  138. huggingface_hub/commands/repo_files.py +0 -128
  139. huggingface_hub/commands/scan_cache.py +0 -181
  140. huggingface_hub/commands/tag.py +0 -159
  141. huggingface_hub/commands/upload.py +0 -299
  142. huggingface_hub/commands/upload_large_folder.py +0 -129
  143. huggingface_hub/commands/user.py +0 -304
  144. huggingface_hub/commands/version.py +0 -37
  145. huggingface_hub/inference_api.py +0 -217
  146. huggingface_hub/keras_mixin.py +0 -500
  147. huggingface_hub/repository.py +0 -1477
  148. huggingface_hub/serialization/_tensorflow.py +0 -95
  149. huggingface_hub/utils/_hf_folder.py +0 -68
  150. huggingface_hub-0.29.0rc2.dist-info/RECORD +0 -131
  151. huggingface_hub-0.29.0rc2.dist-info/entry_points.txt +0 -6
  152. {huggingface_hub-0.29.0rc2.dist-info → huggingface_hub-1.1.3.dist-info/licenses}/LICENSE +0 -0
  153. {huggingface_hub-0.29.0rc2.dist-info → huggingface_hub-1.1.3.dist-info}/top_level.txt +0 -0
@@ -21,16 +21,19 @@
  import asyncio
  import base64
  import logging
+ import os
  import re
  import warnings
- from typing import TYPE_CHECKING, Any, AsyncIterable, Dict, List, Literal, Optional, Set, Union, overload
+ from contextlib import AsyncExitStack
+ from typing import TYPE_CHECKING, Any, AsyncIterable, Literal, Optional, Union, overload
+
+ import httpx

  from huggingface_hub import constants
- from huggingface_hub.errors import InferenceTimeoutError
+ from huggingface_hub.errors import BadRequestError, HfHubHTTPError, InferenceTimeoutError
  from huggingface_hub.inference._common import (
  TASKS_EXPECTING_IMAGES,
  ContentT,
- ModelStatus,
  RequestParameters,
  _async_stream_chat_completion_response,
  _async_stream_text_generation_response,
@@ -41,7 +44,6 @@ from huggingface_hub.inference._common import (
  _bytes_to_list,
  _get_unsupported_text_generation_kwargs,
  _import_numpy,
- _open_as_binary,
  _set_unsupported_text_generation_kwargs,
  raise_text_generation_error,
  )
@@ -51,6 +53,7 @@ from huggingface_hub.inference._generated.types import (
  AudioToAudioOutputElement,
  AutomaticSpeechRecognitionOutput,
  ChatCompletionInputGrammarType,
+ ChatCompletionInputMessage,
  ChatCompletionInputStreamOptions,
  ChatCompletionInputTool,
  ChatCompletionInputToolChoiceClass,
@@ -65,6 +68,7 @@ from huggingface_hub.inference._generated.types import (
  ImageSegmentationSubtask,
  ImageToImageTargetSize,
  ImageToTextOutput,
+ ImageToVideoTargetSize,
  ObjectDetectionOutputElement,
  Padding,
  QuestionAnsweringOutputElement,
@@ -85,16 +89,20 @@ from huggingface_hub.inference._generated.types import (
  ZeroShotClassificationOutputElement,
  ZeroShotImageClassificationOutputElement,
  )
- from huggingface_hub.inference._providers import PROVIDER_T, HFInferenceTask, get_provider_helper
- from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status
- from huggingface_hub.utils._deprecation import _deprecate_arguments, _deprecate_method
+ from huggingface_hub.inference._providers import PROVIDER_OR_POLICY_T, get_provider_helper
+ from huggingface_hub.utils import (
+ build_hf_headers,
+ get_async_session,
+ hf_raise_for_status,
+ validate_hf_hub_args,
+ )
+ from huggingface_hub.utils._auth import get_token

- from .._common import _async_yield_from, _import_aiohttp
+ from .._common import _async_yield_from


  if TYPE_CHECKING:
  import numpy as np
- from aiohttp import ClientResponse, ClientSession
  from PIL.Image import Image

  logger = logging.getLogger(__name__)
@@ -116,30 +124,25 @@ class AsyncInferenceClient:
  or a URL to a deployed Inference Endpoint. Defaults to None, in which case a recommended model is
  automatically selected for the task.
  Note: for better compatibility with OpenAI's client, `model` has been aliased as `base_url`. Those 2
- arguments are mutually exclusive. If using `base_url` for chat completion, the `/chat/completions` suffix
- path will be appended to the base URL (see the [TGI Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api)
- documentation for details). When passing a URL as `model`, the client will not append any suffix path to it.
+ arguments are mutually exclusive. If a URL is passed as `model` or `base_url` for chat completion, the `(/v1)/chat/completions` suffix path will be appended to the URL.
  provider (`str`, *optional*):
- Name of the provider to use for inference. Can be `"black-forest-labs"`, `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"replicate"`, "sambanova"` or `"together"`.
- defaults to hf-inference (Hugging Face Serverless Inference API).
+ Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"clarifai"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"publicai"`, `"replicate"`, `"sambanova"`, `"scaleway"`, `"together"`, `"wavespeed"` or `"zai-org"`.
+ Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
  If model is a URL or `base_url` is passed, then `provider` is not used.
- token (`str` or `bool`, *optional*):
+ token (`str`, *optional*):
  Hugging Face token. Will default to the locally saved token if not provided.
- Pass `token=False` if you don't want to send your token to the server.
  Note: for better compatibility with OpenAI's client, `token` has been aliased as `api_key`. Those 2
  arguments are mutually exclusive and have the exact same behavior.
  timeout (`float`, `optional`):
- The maximum number of seconds to wait for a response from the server. Loading a new model in Inference
- API can take up to several minutes. Defaults to None, meaning it will loop until the server is available.
- headers (`Dict[str, str]`, `optional`):
+ The maximum number of seconds to wait for a response from the server. Defaults to None, meaning it will loop until the server is available.
+ headers (`dict[str, str]`, `optional`):
  Additional headers to send to the server. By default only the authorization and user-agent headers are sent.
  Values in this dictionary will override the default values.
- cookies (`Dict[str, str]`, `optional`):
+ bill_to (`str`, `optional`):
+ The billing account to use for the requests. By default the requests are billed on the user's account.
+ Requests can only be billed to an organization the user is a member of, and which has subscribed to Enterprise Hub.
+ cookies (`dict[str, str]`, `optional`):
  Additional cookies to send to the server.
- trust_env ('bool', 'optional'):
- Trust environment settings for proxy configuration if the parameter is `True` (`False` by default).
- proxies (`Any`, `optional`):
- Proxies to use for the request.
  base_url (`str`, `optional`):
  Base URL to run inference. This is a duplicated argument from `model` to make [`InferenceClient`]
  follow the same pattern as `openai.OpenAI` client. Cannot be used if `model` is set. Defaults to None.
@@ -148,17 +151,17 @@ class AsyncInferenceClient:
  follow the same pattern as `openai.OpenAI` client. Cannot be used if `token` is set. Defaults to None.
  """

+ @validate_hf_hub_args
  def __init__(
  self,
  model: Optional[str] = None,
  *,
- provider: Optional[PROVIDER_T] = None,
+ provider: Optional[PROVIDER_OR_POLICY_T] = None,
  token: Optional[str] = None,
  timeout: Optional[float] = None,
- headers: Optional[Dict[str, str]] = None,
- cookies: Optional[Dict[str, str]] = None,
- trust_env: bool = False,
- proxies: Optional[Any] = None,
+ headers: Optional[dict[str, str]] = None,
+ cookies: Optional[dict[str, str]] = None,
+ bill_to: Optional[str] = None,
  # OpenAI compatibility
  base_url: Optional[str] = None,
  api_key: Optional[str] = None,
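For illustration, a minimal construction sketch matching the updated signature above. The model ID and organization name are assumptions, not values taken from this diff:

```python
from huggingface_hub import AsyncInferenceClient

# Hedged sketch: model ID and org name below are illustrative assumptions.
client = AsyncInferenceClient(
    model="meta-llama/Llama-3.1-8B-Instruct",  # assumed model ID
    provider="auto",      # default policy: first provider enabled for this model on the Hub
    bill_to="my-org",     # new in 1.x: bill requests to an Enterprise Hub organization (assumed name)
    timeout=60,
    headers={"x-custom-header": "demo"},  # merged on top of the default auth/user-agent headers
)
```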
@@ -176,101 +179,78 @@ class AsyncInferenceClient:
  " `api_key` is an alias for `token` to make the API compatible with OpenAI's client."
  " It has the exact same behavior as `token`."
  )
+ token = token if token is not None else api_key
+ if isinstance(token, bool):
+ # Legacy behavior: previously it was possible to pass `token=False` to disable authentication. This is not
+ # supported anymore as authentication is required. Better to explicitly raise here rather than risking
+ # sending the locally saved token without the user knowing about it.
+ if token is False:
+ raise ValueError(
+ "Cannot use `token=False` to disable authentication as authentication is required to run Inference."
+ )
+ warnings.warn(
+ "Using `token=True` to automatically use the locally saved token is deprecated and will be removed in a future release. "
+ "Please use `token=None` instead (default).",
+ DeprecationWarning,
+ )
+ token = get_token()

  self.model: Optional[str] = base_url or model
- self.token: Optional[str] = token if token is not None else api_key
- self.headers = headers if headers is not None else {}
+ self.token: Optional[str] = token
+
+ self.headers = {**headers} if headers is not None else {}
+ if bill_to is not None:
+ if (
+ constants.HUGGINGFACE_HEADER_X_BILL_TO in self.headers
+ and self.headers[constants.HUGGINGFACE_HEADER_X_BILL_TO] != bill_to
+ ):
+ warnings.warn(
+ f"Overriding existing '{self.headers[constants.HUGGINGFACE_HEADER_X_BILL_TO]}' value in headers with '{bill_to}'.",
+ UserWarning,
+ )
+ self.headers[constants.HUGGINGFACE_HEADER_X_BILL_TO] = bill_to
+
+ if token is not None and not token.startswith("hf_"):
+ warnings.warn(
+ "You've provided an external provider's API key, so requests will be billed directly by the provider. "
+ "The `bill_to` parameter is only applicable for Hugging Face billing and will be ignored.",
+ UserWarning,
+ )

  # Configure provider
- self.provider = provider if provider is not None else "hf-inference"
+ self.provider = provider

  self.cookies = cookies
  self.timeout = timeout
- self.trust_env = trust_env
- self.proxies = proxies

- # Keep track of the sessions to close them properly
- self._sessions: Dict["ClientSession", Set["ClientResponse"]] = dict()
+ self.exit_stack = AsyncExitStack()
+ self._async_client: Optional[httpx.AsyncClient] = None

  def __repr__(self):
  return f"<InferenceClient(model='{self.model if self.model else ''}', timeout={self.timeout})>"

- @overload
- async def post( # type: ignore[misc]
- self,
- *,
- json: Optional[Union[str, Dict, List]] = None,
- data: Optional[ContentT] = None,
- model: Optional[str] = None,
- task: Optional[str] = None,
- stream: Literal[False] = ...,
- ) -> bytes: ...
+ async def __aenter__(self):
+ return self

- @overload
- async def post( # type: ignore[misc]
- self,
- *,
- json: Optional[Union[str, Dict, List]] = None,
- data: Optional[ContentT] = None,
- model: Optional[str] = None,
- task: Optional[str] = None,
- stream: Literal[True] = ...,
- ) -> AsyncIterable[bytes]: ...
+ async def __aexit__(self, exc_type, exc_value, traceback):
+ await self.close()

- @overload
- async def post(
- self,
- *,
- json: Optional[Union[str, Dict, List]] = None,
- data: Optional[ContentT] = None,
- model: Optional[str] = None,
- task: Optional[str] = None,
- stream: bool = False,
- ) -> Union[bytes, AsyncIterable[bytes]]: ...
-
- @_deprecate_method(
- version="0.31.0",
- message=(
- "Making direct POST requests to the inference server is not supported anymore. "
- "Please use task methods instead (e.g. `InferenceClient.chat_completion`). "
- "If your use case is not supported, please open an issue in https://github.com/huggingface/huggingface_hub."
- ),
- )
- async def post(
- self,
- *,
- json: Optional[Union[str, Dict, List]] = None,
- data: Optional[ContentT] = None,
- model: Optional[str] = None,
- task: Optional[str] = None,
- stream: bool = False,
- ) -> Union[bytes, AsyncIterable[bytes]]:
+ async def close(self):
+ """Close the client.
+
+ This method is automatically called when using the client as a context manager.
  """
- Make a POST request to the inference server.
+ await self.exit_stack.aclose()

- This method is deprecated and will be removed in the future.
- Please use task methods instead (e.g. `InferenceClient.chat_completion`).
+ async def _get_async_client(self):
+ """Get a unique async client for this AsyncInferenceClient instance.
+
+ Returns the same client instance on subsequent calls, ensuring proper
+ connection reuse and resource management through the exit stack.
  """
- if self.provider != "hf-inference":
- raise ValueError(
- "Cannot use `post` with another provider than `hf-inference`. "
- "`InferenceClient.post` is deprecated and should not be used directly anymore."
- )
- provider_helper = HFInferenceTask(task or "unknown")
- mapped_model = provider_helper._prepare_mapped_model(model or self.model)
- url = provider_helper._prepare_url(self.token, mapped_model) # type: ignore[arg-type]
- headers = provider_helper._prepare_headers(self.headers, self.token) # type: ignore[arg-type]
- return await self._inner_post(
- request_parameters=RequestParameters(
- url=url,
- task=task or "unknown",
- model=model or "unknown",
- json=json,
- data=data,
- headers=headers,
- ),
- stream=stream,
- )
+ if self._async_client is None:
+ self._async_client = await self.exit_stack.enter_async_context(get_async_session())
+ return self._async_client

  @overload
  async def _inner_post( # type: ignore[misc]
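A minimal lifecycle sketch matching the exit-stack logic above: the shared httpx client is created lazily by `_get_async_client()` and released by `close()` or the async context manager. The audio file name is an assumption:

```python
import asyncio

from huggingface_hub import AsyncInferenceClient

async def main():
    # Hedged usage sketch; "sample.flac" is an assumed local file.
    async with AsyncInferenceClient() as client:
        labels = await client.audio_classification("sample.flac")
        print(labels)
    # __aexit__ awaited client.close(), which closes the AsyncExitStack
    # (and with it the underlying httpx.AsyncClient and any open streams).

asyncio.run(main())
```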
@@ -280,84 +260,59 @@ class AsyncInferenceClient:
  @overload
  async def _inner_post( # type: ignore[misc]
  self, request_parameters: RequestParameters, *, stream: Literal[True] = ...
- ) -> AsyncIterable[bytes]: ...
+ ) -> AsyncIterable[str]: ...

  @overload
  async def _inner_post(
  self, request_parameters: RequestParameters, *, stream: bool = False
- ) -> Union[bytes, AsyncIterable[bytes]]: ...
+ ) -> Union[bytes, AsyncIterable[str]]: ...

  async def _inner_post(
  self, request_parameters: RequestParameters, *, stream: bool = False
- ) -> Union[bytes, AsyncIterable[bytes]]:
+ ) -> Union[bytes, AsyncIterable[str]]:
  """Make a request to the inference server."""

- aiohttp = _import_aiohttp()
-
  # TODO: this should be handled in provider helpers directly
  if request_parameters.task in TASKS_EXPECTING_IMAGES and "Accept" not in request_parameters.headers:
  request_parameters.headers["Accept"] = "image/png"

- while True:
- with _open_as_binary(request_parameters.data) as data_as_binary:
- # Do not use context manager as we don't want to close the connection immediately when returning
- # a stream
- session = self._get_client_session(headers=request_parameters.headers)
-
- try:
- response = await session.post(
- request_parameters.url, json=request_parameters.json, data=data_as_binary, proxy=self.proxies
+ try:
+ client = await self._get_async_client()
+ if stream:
+ response = await self.exit_stack.enter_async_context(
+ client.stream(
+ "POST",
+ request_parameters.url,
+ json=request_parameters.json,
+ data=request_parameters.data,
+ headers=request_parameters.headers,
+ cookies=self.cookies,
+ timeout=self.timeout,
  )
- response_error_payload = None
- if response.status != 200:
- try:
- response_error_payload = await response.json() # get payload before connection closed
- except Exception:
- pass
- response.raise_for_status()
- if stream:
- return _async_yield_from(session, response)
- else:
- content = await response.read()
- await session.close()
- return content
- except asyncio.TimeoutError as error:
- await session.close()
- # Convert any `TimeoutError` to a `InferenceTimeoutError`
- raise InferenceTimeoutError(f"Inference call timed out: {request_parameters.url}") from error # type: ignore
- except aiohttp.ClientResponseError as error:
- error.response_error_payload = response_error_payload
- await session.close()
- raise error
- except Exception:
- await session.close()
- raise
-
- async def __aenter__(self):
- return self
-
- async def __aexit__(self, exc_type, exc_value, traceback):
- await self.close()
-
- def __del__(self):
- if len(self._sessions) > 0:
- warnings.warn(
- "Deleting 'AsyncInferenceClient' client but some sessions are still open. "
- "This can happen if you've stopped streaming data from the server before the stream was complete. "
- "To close the client properly, you must call `await client.close()` "
- "or use an async context (e.g. `async with AsyncInferenceClient(): ...`."
- )
-
- async def close(self):
- """Close all open sessions.
-
- By default, 'aiohttp.ClientSession' objects are closed automatically when a call is completed. However, if you
- are streaming data from the server and you stop before the stream is complete, you must call this method to
- close the session properly.
-
- Another possibility is to use an async context (e.g. `async with AsyncInferenceClient(): ...`).
- """
- await asyncio.gather(*[session.close() for session in self._sessions.keys()])
+ )
+ hf_raise_for_status(response)
+ return _async_yield_from(client, response)
+ else:
+ response = await client.post(
+ request_parameters.url,
+ json=request_parameters.json,
+ data=request_parameters.data,
+ headers=request_parameters.headers,
+ cookies=self.cookies,
+ timeout=self.timeout,
+ )
+ hf_raise_for_status(response)
+ return response.content
+ except asyncio.TimeoutError as error:
+ # Convert any `TimeoutError` to a `InferenceTimeoutError`
+ raise InferenceTimeoutError(f"Inference call timed out: {request_parameters.url}") from error # type: ignore
+ except HfHubHTTPError as error:
+ if error.response.status_code == 422 and request_parameters.task != "unknown":
+ msg = str(error.args[0])
+ if len(error.response.text) > 0:
+ msg += f"{os.linesep}{error.response.text}{os.linesep}"
+ error.args = (msg,) + error.args[1:]
+ raise

  async def audio_classification(
  self,
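A sketch of the new error surface in `_inner_post` above: HTTP failures now surface as `HfHubHTTPError` (with the httpx response attached) instead of `aiohttp.ClientResponseError`, and timeouts as `InferenceTimeoutError`. The file path is an assumption:

```python
from huggingface_hub import AsyncInferenceClient
from huggingface_hub.errors import HfHubHTTPError, InferenceTimeoutError

async def classify_safely(path: str):
    client = AsyncInferenceClient()
    try:
        return await client.audio_classification(path)
    except InferenceTimeoutError:
        return None  # model unavailable or the request timed out
    except HfHubHTTPError as err:
        # the underlying httpx response is attached, as used by the 422 handling above
        print(err.response.status_code, err.response.text)
        raise
```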
@@ -366,7 +321,7 @@ class AsyncInferenceClient:
  model: Optional[str] = None,
  top_k: Optional[int] = None,
  function_to_apply: Optional["AudioClassificationOutputTransform"] = None,
- ) -> List[AudioClassificationOutputElement]:
+ ) -> list[AudioClassificationOutputElement]:
  """
  Perform audio classification on the provided audio content.

@@ -384,12 +339,12 @@ class AsyncInferenceClient:
  The function to apply to the model outputs in order to retrieve the scores.

  Returns:
- `List[AudioClassificationOutputElement]`: List of [`AudioClassificationOutputElement`] items containing the predicted labels and their confidence.
+ `list[AudioClassificationOutputElement]`: List of [`AudioClassificationOutputElement`] items containing the predicted labels and their confidence.

  Raises:
  [`InferenceTimeoutError`]:
  If the model is unavailable or the request times out.
- `aiohttp.ClientResponseError`:
+ [`HfHubHTTPError`]:
  If the request fails with an HTTP error status code other than HTTP 503.

  Example:
@@ -405,12 +360,13 @@ class AsyncInferenceClient:
  ]
  ```
  """
- provider_helper = get_provider_helper(self.provider, task="audio-classification")
+ model_id = model or self.model
+ provider_helper = get_provider_helper(self.provider, task="audio-classification", model=model_id)
  request_parameters = provider_helper.prepare_request(
  inputs=audio,
  parameters={"function_to_apply": function_to_apply, "top_k": top_k},
  headers=self.headers,
- model=model or self.model,
+ model=model_id,
  api_key=self.token,
  )
  response = await self._inner_post(request_parameters)
@@ -421,7 +377,7 @@ class AsyncInferenceClient:
  audio: ContentT,
  *,
  model: Optional[str] = None,
- ) -> List[AudioToAudioOutputElement]:
+ ) -> list[AudioToAudioOutputElement]:
  """
  Performs multiple tasks related to audio-to-audio depending on the model (eg: speech enhancement, source separation).

@@ -435,12 +391,12 @@ class AsyncInferenceClient:
  audio_to_audio will be used.

  Returns:
- `List[AudioToAudioOutputElement]`: A list of [`AudioToAudioOutputElement`] items containing audios label, content-type, and audio content in blob.
+ `list[AudioToAudioOutputElement]`: A list of [`AudioToAudioOutputElement`] items containing audios label, content-type, and audio content in blob.

  Raises:
  `InferenceTimeoutError`:
  If the model is unavailable or the request times out.
- `aiohttp.ClientResponseError`:
+ [`HfHubHTTPError`]:
  If the request fails with an HTTP error status code other than HTTP 503.

  Example:
@@ -454,12 +410,13 @@ class AsyncInferenceClient:
  f.write(item.blob)
  ```
  """
- provider_helper = get_provider_helper(self.provider, task="audio-to-audio")
+ model_id = model or self.model
+ provider_helper = get_provider_helper(self.provider, task="audio-to-audio", model=model_id)
  request_parameters = provider_helper.prepare_request(
  inputs=audio,
  parameters={},
  headers=self.headers,
- model=model or self.model,
+ model=model_id,
  api_key=self.token,
  )
  response = await self._inner_post(request_parameters)
@@ -473,7 +430,7 @@ class AsyncInferenceClient:
  audio: ContentT,
  *,
  model: Optional[str] = None,
- extra_body: Optional[Dict] = None,
+ extra_body: Optional[dict] = None,
  ) -> AutomaticSpeechRecognitionOutput:
  """
  Perform automatic speech recognition (ASR or audio-to-text) on the given audio content.
@@ -484,7 +441,7 @@ class AsyncInferenceClient:
  model (`str`, *optional*):
  The model to use for ASR. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
  Inference Endpoint. If not provided, the default recommended model for ASR will be used.
- extra_body (`Dict`, *optional*):
+ extra_body (`dict`, *optional*):
  Additional provider-specific parameters to pass to the model. Refer to the provider's documentation
  for supported parameters.
  Returns:
@@ -493,7 +450,7 @@ class AsyncInferenceClient:
  Raises:
  [`InferenceTimeoutError`]:
  If the model is unavailable or the request times out.
- `aiohttp.ClientResponseError`:
+ [`HfHubHTTPError`]:
  If the request fails with an HTTP error status code other than HTTP 503.

  Example:
@@ -505,12 +462,13 @@ class AsyncInferenceClient:
  "hello world"
  ```
  """
- provider_helper = get_provider_helper(self.provider, task="automatic-speech-recognition")
+ model_id = model or self.model
+ provider_helper = get_provider_helper(self.provider, task="automatic-speech-recognition", model=model_id)
  request_parameters = provider_helper.prepare_request(
  inputs=audio,
  parameters={**(extra_body or {})},
  headers=self.headers,
- model=model or self.model,
+ model=model_id,
  api_key=self.token,
  )
  response = await self._inner_post(request_parameters)
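A hedged sketch of the `extra_body` pass-through added to `automatic_speech_recognition` above; the keys inside `extra_body` are provider-specific and purely illustrative:

```python
from huggingface_hub import AsyncInferenceClient

async def transcribe():
    # Provider choice and extra_body keys are assumptions; check your provider's documentation.
    client = AsyncInferenceClient(provider="fal-ai")
    output = await client.automatic_speech_recognition(
        "hello_world.flac",             # assumed local audio file
        extra_body={"language": "en"},  # hypothetical provider-specific parameter
    )
    return output.text
```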
@@ -519,121 +477,117 @@ class AsyncInferenceClient:
  @overload
  async def chat_completion( # type: ignore
  self,
- messages: List[Dict],
+ messages: list[Union[dict, ChatCompletionInputMessage]],
  *,
  model: Optional[str] = None,
  stream: Literal[False] = False,
  frequency_penalty: Optional[float] = None,
- logit_bias: Optional[List[float]] = None,
+ logit_bias: Optional[list[float]] = None,
  logprobs: Optional[bool] = None,
  max_tokens: Optional[int] = None,
  n: Optional[int] = None,
  presence_penalty: Optional[float] = None,
  response_format: Optional[ChatCompletionInputGrammarType] = None,
  seed: Optional[int] = None,
- stop: Optional[List[str]] = None,
+ stop: Optional[list[str]] = None,
  stream_options: Optional[ChatCompletionInputStreamOptions] = None,
  temperature: Optional[float] = None,
  tool_choice: Optional[Union[ChatCompletionInputToolChoiceClass, "ChatCompletionInputToolChoiceEnum"]] = None,
  tool_prompt: Optional[str] = None,
- tools: Optional[List[ChatCompletionInputTool]] = None,
+ tools: Optional[list[ChatCompletionInputTool]] = None,
  top_logprobs: Optional[int] = None,
  top_p: Optional[float] = None,
- extra_body: Optional[Dict] = None,
+ extra_body: Optional[dict] = None,
  ) -> ChatCompletionOutput: ...

  @overload
  async def chat_completion( # type: ignore
  self,
- messages: List[Dict],
+ messages: list[Union[dict, ChatCompletionInputMessage]],
  *,
  model: Optional[str] = None,
  stream: Literal[True] = True,
  frequency_penalty: Optional[float] = None,
- logit_bias: Optional[List[float]] = None,
+ logit_bias: Optional[list[float]] = None,
  logprobs: Optional[bool] = None,
  max_tokens: Optional[int] = None,
  n: Optional[int] = None,
  presence_penalty: Optional[float] = None,
  response_format: Optional[ChatCompletionInputGrammarType] = None,
  seed: Optional[int] = None,
- stop: Optional[List[str]] = None,
+ stop: Optional[list[str]] = None,
  stream_options: Optional[ChatCompletionInputStreamOptions] = None,
  temperature: Optional[float] = None,
  tool_choice: Optional[Union[ChatCompletionInputToolChoiceClass, "ChatCompletionInputToolChoiceEnum"]] = None,
  tool_prompt: Optional[str] = None,
- tools: Optional[List[ChatCompletionInputTool]] = None,
+ tools: Optional[list[ChatCompletionInputTool]] = None,
  top_logprobs: Optional[int] = None,
  top_p: Optional[float] = None,
- extra_body: Optional[Dict] = None,
+ extra_body: Optional[dict] = None,
  ) -> AsyncIterable[ChatCompletionStreamOutput]: ...

  @overload
  async def chat_completion(
  self,
- messages: List[Dict],
+ messages: list[Union[dict, ChatCompletionInputMessage]],
  *,
  model: Optional[str] = None,
  stream: bool = False,
  frequency_penalty: Optional[float] = None,
- logit_bias: Optional[List[float]] = None,
+ logit_bias: Optional[list[float]] = None,
  logprobs: Optional[bool] = None,
  max_tokens: Optional[int] = None,
  n: Optional[int] = None,
  presence_penalty: Optional[float] = None,
  response_format: Optional[ChatCompletionInputGrammarType] = None,
  seed: Optional[int] = None,
- stop: Optional[List[str]] = None,
+ stop: Optional[list[str]] = None,
  stream_options: Optional[ChatCompletionInputStreamOptions] = None,
  temperature: Optional[float] = None,
  tool_choice: Optional[Union[ChatCompletionInputToolChoiceClass, "ChatCompletionInputToolChoiceEnum"]] = None,
  tool_prompt: Optional[str] = None,
- tools: Optional[List[ChatCompletionInputTool]] = None,
+ tools: Optional[list[ChatCompletionInputTool]] = None,
  top_logprobs: Optional[int] = None,
  top_p: Optional[float] = None,
- extra_body: Optional[Dict] = None,
+ extra_body: Optional[dict] = None,
  ) -> Union[ChatCompletionOutput, AsyncIterable[ChatCompletionStreamOutput]]: ...

  async def chat_completion(
  self,
- messages: List[Dict],
+ messages: list[Union[dict, ChatCompletionInputMessage]],
  *,
  model: Optional[str] = None,
  stream: bool = False,
  # Parameters from ChatCompletionInput (handled manually)
  frequency_penalty: Optional[float] = None,
- logit_bias: Optional[List[float]] = None,
+ logit_bias: Optional[list[float]] = None,
  logprobs: Optional[bool] = None,
  max_tokens: Optional[int] = None,
  n: Optional[int] = None,
  presence_penalty: Optional[float] = None,
  response_format: Optional[ChatCompletionInputGrammarType] = None,
  seed: Optional[int] = None,
- stop: Optional[List[str]] = None,
+ stop: Optional[list[str]] = None,
  stream_options: Optional[ChatCompletionInputStreamOptions] = None,
  temperature: Optional[float] = None,
  tool_choice: Optional[Union[ChatCompletionInputToolChoiceClass, "ChatCompletionInputToolChoiceEnum"]] = None,
  tool_prompt: Optional[str] = None,
- tools: Optional[List[ChatCompletionInputTool]] = None,
+ tools: Optional[list[ChatCompletionInputTool]] = None,
  top_logprobs: Optional[int] = None,
  top_p: Optional[float] = None,
- extra_body: Optional[Dict] = None,
+ extra_body: Optional[dict] = None,
  ) -> Union[ChatCompletionOutput, AsyncIterable[ChatCompletionStreamOutput]]:
  """
  A method for completing conversations using a specified language model.

- <Tip>
-
- The `client.chat_completion` method is aliased as `client.chat.completions.create` for compatibility with OpenAI's client.
- Inputs and outputs are strictly the same and using either syntax will yield the same results.
- Check out the [Inference guide](https://huggingface.co/docs/huggingface_hub/guides/inference#openai-compatibility)
- for more details about OpenAI's compatibility.
+ > [!TIP]
+ > The `client.chat_completion` method is aliased as `client.chat.completions.create` for compatibility with OpenAI's client.
+ > Inputs and outputs are strictly the same and using either syntax will yield the same results.
+ > Check out the [Inference guide](https://huggingface.co/docs/huggingface_hub/guides/inference#openai-compatibility)
+ > for more details about OpenAI's compatibility.

- </Tip>
-
- <Tip>
- You can pass provider-specific parameters to the model by using the `extra_body` argument.
- </Tip>
+ > [!TIP]
+ > You can pass provider-specific parameters to the model by using the `extra_body` argument.

  Args:
  messages (List of [`ChatCompletionInputMessage`]):
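For reference, a sketch of the OpenAI-style alias mentioned in the tip above (`client.chat.completions.create`), here with streaming; the model ID is an illustrative assumption:

```python
from huggingface_hub import AsyncInferenceClient

async def stream_reply():
    client = AsyncInferenceClient()
    # Hedged sketch: model ID is an assumption, not taken from this diff.
    stream = await client.chat.completions.create(
        model="meta-llama/Llama-3.1-8B-Instruct",
        messages=[{"role": "user", "content": "Say hello"}],
        stream=True,
        max_tokens=64,
    )
    async for chunk in stream:
        print(chunk.choices[0].delta.content or "", end="")
```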
@@ -647,7 +601,7 @@ class AsyncInferenceClient:
  frequency_penalty (`float`, *optional*):
  Penalizes new tokens based on their existing frequency
  in the text so far. Range: [-2.0, 2.0]. Defaults to 0.0.
- logit_bias (`List[float]`, *optional*):
+ logit_bias (`list[float]`, *optional*):
  Adjusts the likelihood of specific tokens appearing in the generated output.
  logprobs (`bool`, *optional*):
  Whether to return log probabilities of the output tokens or not. If true, returns the log
@@ -663,7 +617,7 @@ class AsyncInferenceClient:
  Grammar constraints. Can be either a JSONSchema or a regex.
  seed (Optional[`int`], *optional*):
  Seed for reproducible control flow. Defaults to None.
- stop (`List[str]`, *optional*):
+ stop (`list[str]`, *optional*):
  Up to four strings which trigger the end of the response.
  Defaults to None.
  stream (`bool`, *optional*):
@@ -687,7 +641,7 @@ class AsyncInferenceClient:
  tools (List of [`ChatCompletionInputTool`], *optional*):
  A list of tools the model may call. Currently, only functions are supported as a tool. Use this to
  provide a list of functions the model may generate JSON inputs for.
- extra_body (`Dict`, *optional*):
+ extra_body (`dict`, *optional*):
  Additional provider-specific parameters to pass to the model. Refer to the provider's documentation
  for supported parameters.
  Returns:
@@ -699,7 +653,7 @@ class AsyncInferenceClient:
  Raises:
  [`InferenceTimeoutError`]:
  If the model is unavailable or the request times out.
- `aiohttp.ClientResponseError`:
+ [`HfHubHTTPError`]:
  If the request fails with an HTTP error status code other than HTTP 503.

  Example:
@@ -931,7 +885,7 @@ class AsyncInferenceClient:
  >>> messages = [
  ... {
  ... "role": "user",
- ... "content": "I saw a puppy a cat and a raccoon during my bike ride in the park. What did I saw and when?",
+ ... "content": "I saw a puppy a cat and a raccoon during my bike ride in the park. What did I see and when?",
  ... },
  ... ]
  >>> response_format = {
@@ -950,20 +904,26 @@ class AsyncInferenceClient:
  ... messages=messages,
  ... response_format=response_format,
  ... max_tokens=500,
- )
+ ... )
  >>> response.choices[0].message.content
  '{\n\n"activity": "bike ride",\n"animals": ["puppy", "cat", "raccoon"],\n"animals_seen": 3,\n"location": "park"}'
  ```
  """
- # Get the provider helper
- provider_helper = get_provider_helper(self.provider, task="conversational")
-
  # Since `chat_completion(..., model=xxx)` is also a payload parameter for the server, we need to handle 'model' differently.
  # `self.model` takes precedence over 'model' argument for building URL.
  # `model` takes precedence for payload value.
  model_id_or_url = self.model or model
  payload_model = model or self.model

+ # Get the provider helper
+ provider_helper = get_provider_helper(
+ self.provider,
+ task="conversational",
+ model=model_id_or_url
+ if model_id_or_url is not None and model_id_or_url.startswith(("http://", "https://"))
+ else payload_model,
+ )
+
  # Prepare the payload
  parameters = {
  "model": payload_model,
@@ -1013,8 +973,8 @@ class AsyncInferenceClient:
1013
973
  max_question_len: Optional[int] = None,
1014
974
  max_seq_len: Optional[int] = None,
1015
975
  top_k: Optional[int] = None,
1016
- word_boxes: Optional[List[Union[List[float], str]]] = None,
1017
- ) -> List[DocumentQuestionAnsweringOutputElement]:
976
+ word_boxes: Optional[list[Union[list[float], str]]] = None,
977
+ ) -> list[DocumentQuestionAnsweringOutputElement]:
1018
978
  """
1019
979
  Answer questions on document images.
1020
980
 
@@ -1044,16 +1004,16 @@ class AsyncInferenceClient:
1044
1004
  top_k (`int`, *optional*):
1045
1005
  The number of answers to return (will be chosen by order of likelihood). Can return less than top_k
1046
1006
  answers if there are not enough options available within the context.
1047
- word_boxes (`List[Union[List[float], str`, *optional*):
1007
+ word_boxes (`list[Union[list[float], str`, *optional*):
1048
1008
  A list of words and bounding boxes (normalized 0->1000). If provided, the inference will skip the OCR
1049
1009
  step and use the provided bounding boxes instead.
1050
1010
  Returns:
1051
- `List[DocumentQuestionAnsweringOutputElement]`: a list of [`DocumentQuestionAnsweringOutputElement`] items containing the predicted label, associated probability, word ids, and page number.
1011
+ `list[DocumentQuestionAnsweringOutputElement]`: a list of [`DocumentQuestionAnsweringOutputElement`] items containing the predicted label, associated probability, word ids, and page number.
1052
1012
 
1053
1013
  Raises:
1054
1014
  [`InferenceTimeoutError`]:
1055
1015
  If the model is unavailable or the request times out.
1056
- `aiohttp.ClientResponseError`:
1016
+ [`HfHubHTTPError`]:
1057
1017
  If the request fails with an HTTP error status code other than HTTP 503.
1058
1018
 
1059
1019
 
@@ -1066,8 +1026,9 @@ class AsyncInferenceClient:
1066
1026
  [DocumentQuestionAnsweringOutputElement(answer='us-001', end=16, score=0.9999666213989258, start=16)]
1067
1027
  ```
1068
1028
  """
1069
- inputs: Dict[str, Any] = {"question": question, "image": _b64_encode(image)}
1070
- provider_helper = get_provider_helper(self.provider, task="document-question-answering")
1029
+ model_id = model or self.model
1030
+ provider_helper = get_provider_helper(self.provider, task="document-question-answering", model=model_id)
1031
+ inputs: dict[str, Any] = {"question": question, "image": _b64_encode(image)}
1071
1032
  request_parameters = provider_helper.prepare_request(
1072
1033
  inputs=inputs,
1073
1034
  parameters={
@@ -1081,7 +1042,7 @@ class AsyncInferenceClient:
1081
1042
  "word_boxes": word_boxes,
1082
1043
  },
1083
1044
  headers=self.headers,
1084
- model=model or self.model,
1045
+ model=model_id,
1085
1046
  api_key=self.token,
1086
1047
  )
1087
1048
  response = await self._inner_post(request_parameters)
@@ -1104,8 +1065,8 @@ class AsyncInferenceClient:
1104
1065
  text (`str`):
1105
1066
  The text to embed.
1106
1067
  model (`str`, *optional*):
1107
- The model to use for the conversational task. Can be a model ID hosted on the Hugging Face Hub or a URL to
1108
- a deployed Inference Endpoint. If not provided, the default recommended conversational model will be used.
1068
+ The model to use for the feature extraction task. Can be a model ID hosted on the Hugging Face Hub or a URL to
1069
+ a deployed Inference Endpoint. If not provided, the default recommended feature extraction model will be used.
1109
1070
  Defaults to None.
1110
1071
  normalize (`bool`, *optional*):
1111
1072
  Whether to normalize the embeddings or not.
@@ -1128,7 +1089,7 @@ class AsyncInferenceClient:
1128
1089
  Raises:
1129
1090
  [`InferenceTimeoutError`]:
1130
1091
  If the model is unavailable or the request times out.
1131
- `aiohttp.ClientResponseError`:
1092
+ [`HfHubHTTPError`]:
1132
1093
  If the request fails with an HTTP error status code other than HTTP 503.
1133
1094
 
1134
1095
  Example:
@@ -1143,7 +1104,8 @@ class AsyncInferenceClient:
1143
1104
  [ 0.28552425, -0.928395 , -1.2077185 , ..., 0.76810825, -2.1069427 , 0.6236161 ]], dtype=float32)
1144
1105
  ```
1145
1106
  """
1146
- provider_helper = get_provider_helper(self.provider, task="feature-extraction")
1107
+ model_id = model or self.model
1108
+ provider_helper = get_provider_helper(self.provider, task="feature-extraction", model=model_id)
1147
1109
  request_parameters = provider_helper.prepare_request(
1148
1110
  inputs=text,
1149
1111
  parameters={
@@ -1153,21 +1115,21 @@ class AsyncInferenceClient:
1153
1115
  "truncation_direction": truncation_direction,
1154
1116
  },
1155
1117
  headers=self.headers,
1156
- model=model or self.model,
1118
+ model=model_id,
1157
1119
  api_key=self.token,
1158
1120
  )
1159
1121
  response = await self._inner_post(request_parameters)
1160
1122
  np = _import_numpy()
1161
- return np.array(_bytes_to_dict(response), dtype="float32")
1123
+ return np.array(provider_helper.get_response(response), dtype="float32")
1162
1124
 
1163
1125
  async def fill_mask(
1164
1126
  self,
1165
1127
  text: str,
1166
1128
  *,
1167
1129
  model: Optional[str] = None,
1168
- targets: Optional[List[str]] = None,
1130
+ targets: Optional[list[str]] = None,
1169
1131
  top_k: Optional[int] = None,
1170
- ) -> List[FillMaskOutputElement]:
1132
+ ) -> list[FillMaskOutputElement]:
1171
1133
  """
1172
1134
  Fill in a hole with a missing word (token to be precise).
1173
1135
 
@@ -1177,20 +1139,20 @@ class AsyncInferenceClient:
1177
1139
  model (`str`, *optional*):
1178
1140
  The model to use for the fill mask task. Can be a model ID hosted on the Hugging Face Hub or a URL to
1179
1141
  a deployed Inference Endpoint. If not provided, the default recommended fill mask model will be used.
1180
- targets (`List[str`, *optional*):
1142
+ targets (`list[str`, *optional*):
1181
1143
  When passed, the model will limit the scores to the passed targets instead of looking up in the whole
1182
1144
  vocabulary. If the provided targets are not in the model vocab, they will be tokenized and the first
1183
1145
  resulting token will be used (with a warning, and that might be slower).
1184
1146
  top_k (`int`, *optional*):
1185
1147
  When passed, overrides the number of predictions to return.
1186
1148
  Returns:
1187
- `List[FillMaskOutputElement]`: a list of [`FillMaskOutputElement`] items containing the predicted label, associated
1149
+ `list[FillMaskOutputElement]`: a list of [`FillMaskOutputElement`] items containing the predicted label, associated
1188
1150
  probability, token reference, and completed text.
1189
1151
 
1190
1152
  Raises:
1191
1153
  [`InferenceTimeoutError`]:
1192
1154
  If the model is unavailable or the request times out.
1193
- `aiohttp.ClientResponseError`:
1155
+ [`HfHubHTTPError`]:
1194
1156
  If the request fails with an HTTP error status code other than HTTP 503.
1195
1157
 
1196
1158
  Example:
@@ -1205,12 +1167,13 @@ class AsyncInferenceClient:
1205
1167
  ]
1206
1168
  ```
1207
1169
  """
1208
- provider_helper = get_provider_helper(self.provider, task="fill-mask")
1170
+ model_id = model or self.model
1171
+ provider_helper = get_provider_helper(self.provider, task="fill-mask", model=model_id)
1209
1172
  request_parameters = provider_helper.prepare_request(
1210
1173
  inputs=text,
1211
1174
  parameters={"targets": targets, "top_k": top_k},
1212
1175
  headers=self.headers,
1213
- model=model or self.model,
1176
+ model=model_id,
1214
1177
  api_key=self.token,
1215
1178
  )
1216
1179
  response = await self._inner_post(request_parameters)
@@ -1223,13 +1186,13 @@ class AsyncInferenceClient:
1223
1186
  model: Optional[str] = None,
1224
1187
  function_to_apply: Optional["ImageClassificationOutputTransform"] = None,
1225
1188
  top_k: Optional[int] = None,
1226
- ) -> List[ImageClassificationOutputElement]:
1189
+ ) -> list[ImageClassificationOutputElement]:
1227
1190
  """
1228
1191
  Perform image classification on the given image using the specified model.
1229
1192
 
1230
1193
  Args:
1231
- image (`Union[str, Path, bytes, BinaryIO]`):
1232
- The image to classify. It can be raw bytes, an image file, or a URL to an online image.
1194
+ image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`):
1195
+ The image to classify. It can be raw bytes, an image file, a URL to an online image, or a PIL Image.
1233
1196
  model (`str`, *optional*):
1234
1197
  The model to use for image classification. Can be a model ID hosted on the Hugging Face Hub or a URL to a
1235
1198
  deployed Inference Endpoint. If not provided, the default recommended model for image classification will be used.
@@ -1238,12 +1201,12 @@ class AsyncInferenceClient:
1238
1201
  top_k (`int`, *optional*):
1239
1202
  When specified, limits the output to the top K most probable classes.
1240
1203
  Returns:
1241
- `List[ImageClassificationOutputElement]`: a list of [`ImageClassificationOutputElement`] items containing the predicted label and associated probability.
1204
+ `list[ImageClassificationOutputElement]`: a list of [`ImageClassificationOutputElement`] items containing the predicted label and associated probability.
1242
1205
 
1243
1206
  Raises:
1244
1207
  [`InferenceTimeoutError`]:
1245
1208
  If the model is unavailable or the request times out.
1246
- `aiohttp.ClientResponseError`:
1209
+ [`HfHubHTTPError`]:
1247
1210
  If the request fails with an HTTP error status code other than HTTP 503.
1248
1211
 
1249
1212
  Example:
@@ -1255,12 +1218,13 @@ class AsyncInferenceClient:
1255
1218
  [ImageClassificationOutputElement(label='Blenheim spaniel', score=0.9779096841812134), ...]
1256
1219
  ```
1257
1220
  """
1258
- provider_helper = get_provider_helper(self.provider, task="image-classification")
1221
+ model_id = model or self.model
1222
+ provider_helper = get_provider_helper(self.provider, task="image-classification", model=model_id)
1259
1223
  request_parameters = provider_helper.prepare_request(
1260
1224
  inputs=image,
1261
1225
  parameters={"function_to_apply": function_to_apply, "top_k": top_k},
1262
1226
  headers=self.headers,
1263
- model=model or self.model,
1227
+ model=model_id,
1264
1228
  api_key=self.token,
1265
1229
  )
1266
1230
  response = await self._inner_post(request_parameters)
@@ -1275,19 +1239,16 @@ class AsyncInferenceClient:
1275
1239
  overlap_mask_area_threshold: Optional[float] = None,
1276
1240
  subtask: Optional["ImageSegmentationSubtask"] = None,
1277
1241
  threshold: Optional[float] = None,
1278
- ) -> List[ImageSegmentationOutputElement]:
1242
+ ) -> list[ImageSegmentationOutputElement]:
1279
1243
  """
1280
1244
  Perform image segmentation on the given image using the specified model.
1281
1245
 
1282
- <Tip warning={true}>
1283
-
1284
- You must have `PIL` installed if you want to work with images (`pip install Pillow`).
1285
-
1286
- </Tip>
1246
+ > [!WARNING]
1247
+ > You must have `PIL` installed if you want to work with images (`pip install Pillow`).
1287
1248
 
1288
1249
  Args:
1289
- image (`Union[str, Path, bytes, BinaryIO]`):
1290
- The image to segment. It can be raw bytes, an image file, or a URL to an online image.
1250
+ image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`):
1251
+ The image to segment. It can be raw bytes, an image file, a URL to an online image, or a PIL Image.
1291
1252
  model (`str`, *optional*):
1292
1253
  The model to use for image segmentation. Can be a model ID hosted on the Hugging Face Hub or a URL to a
1293
1254
  deployed Inference Endpoint. If not provided, the default recommended model for image segmentation will be used.
@@ -1300,12 +1261,12 @@ class AsyncInferenceClient:
1300
1261
  threshold (`float`, *optional*):
1301
1262
  Probability threshold to filter out predicted masks.
1302
1263
  Returns:
1303
- `List[ImageSegmentationOutputElement]`: A list of [`ImageSegmentationOutputElement`] items containing the segmented masks and associated attributes.
1264
+ `list[ImageSegmentationOutputElement]`: A list of [`ImageSegmentationOutputElement`] items containing the segmented masks and associated attributes.
1304
1265
 
1305
1266
  Raises:
1306
1267
  [`InferenceTimeoutError`]:
1307
1268
  If the model is unavailable or the request times out.
1308
- `aiohttp.ClientResponseError`:
1269
+ [`HfHubHTTPError`]:
1309
1270
  If the request fails with an HTTP error status code other than HTTP 503.
1310
1271
 
1311
1272
  Example:
@@ -1317,7 +1278,8 @@ class AsyncInferenceClient:
1317
1278
  [ImageSegmentationOutputElement(score=0.989008, label='LABEL_184', mask=<PIL.PngImagePlugin.PngImageFile image mode=L size=400x300 at 0x7FDD2B129CC0>), ...]
1318
1279
  ```
1319
1280
  """
1320
- provider_helper = get_provider_helper(self.provider, task="audio-classification")
1281
+ model_id = model or self.model
1282
+ provider_helper = get_provider_helper(self.provider, task="image-segmentation", model=model_id)
1321
1283
  request_parameters = provider_helper.prepare_request(
1322
1284
  inputs=image,
1323
1285
  parameters={
@@ -1327,10 +1289,11 @@ class AsyncInferenceClient:
1327
1289
  "threshold": threshold,
1328
1290
  },
1329
1291
  headers=self.headers,
1330
- model=model or self.model,
1292
+ model=model_id,
1331
1293
  api_key=self.token,
1332
1294
  )
1333
1295
  response = await self._inner_post(request_parameters)
1296
+ response = provider_helper.get_response(response, request_parameters)
1334
1297
  output = ImageSegmentationOutputElement.parse_obj_as_list(response)
1335
1298
  for item in output:
1336
1299
  item.mask = _b64_to_image(item.mask) # type: ignore [assignment]
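Because each `mask` is decoded with `_b64_to_image` right above, callers receive ready-to-use PIL images. A short usage sketch, assuming Pillow is installed and an arbitrary local file:

```py
import asyncio

from huggingface_hub import AsyncInferenceClient


async def save_masks(path: str) -> None:
    client = AsyncInferenceClient()
    segments = await client.image_segmentation(path)
    for i, segment in enumerate(segments):
        # `segment.mask` is already a PIL image after the base64 decoding above.
        segment.mask.save(f"mask_{i}_{segment.label}.png")
        print(segment.label, segment.score)


asyncio.run(save_masks("street.jpg"))
```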
@@ -1351,15 +1314,12 @@ class AsyncInferenceClient:
1351
1314
  """
1352
1315
  Perform image-to-image translation using a specified model.
1353
1316
 
1354
- <Tip warning={true}>
1355
-
1356
- You must have `PIL` installed if you want to work with images (`pip install Pillow`).
1357
-
1358
- </Tip>
1317
+ > [!WARNING]
1318
+ > You must have `PIL` installed if you want to work with images (`pip install Pillow`).
1359
1319
 
1360
1320
  Args:
1361
- image (`Union[str, Path, bytes, BinaryIO]`):
1362
- The input image for translation. It can be raw bytes, an image file, or a URL to an online image.
1321
+ image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`):
1322
+ The input image for translation. It can be raw bytes, an image file, a URL to an online image, or a PIL Image.
1363
1323
  prompt (`str`, *optional*):
1364
1324
  The text prompt to guide the image generation.
1365
1325
  negative_prompt (`str`, *optional*):
@@ -1374,7 +1334,8 @@ class AsyncInferenceClient:
1374
1334
  The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
1375
1335
  Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
1376
1336
  target_size (`ImageToImageTargetSize`, *optional*):
1377
- The size in pixel of the output image.
1337
+ The size in pixels of the output image. This parameter is only supported by some providers and for
1338
+ specific models. It will be ignored when unsupported.
1378
1339
 
1379
1340
  Returns:
1380
1341
  `Image`: The translated image.
@@ -1382,7 +1343,7 @@ class AsyncInferenceClient:
1382
1343
  Raises:
1383
1344
  [`InferenceTimeoutError`]:
1384
1345
  If the model is unavailable or the request times out.
1385
- `aiohttp.ClientResponseError`:
1346
+ [`HfHubHTTPError`]:
1386
1347
  If the request fails with an HTTP error status code other than HTTP 503.
1387
1348
 
1388
1349
  Example:
@@ -1393,8 +1354,10 @@ class AsyncInferenceClient:
1393
1354
  >>> image = await client.image_to_image("cat.jpg", prompt="turn the cat into a tiger")
1394
1355
  >>> image.save("tiger.jpg")
1395
1356
  ```
1357
+
1396
1358
  """
1397
- provider_helper = get_provider_helper(self.provider, task="image-to-image")
1359
+ model_id = model or self.model
1360
+ provider_helper = get_provider_helper(self.provider, task="image-to-image", model=model_id)
1398
1361
  request_parameters = provider_helper.prepare_request(
1399
1362
  inputs=image,
1400
1363
  parameters={
@@ -1406,22 +1369,103 @@ class AsyncInferenceClient:
1406
1369
  **kwargs,
1407
1370
  },
1408
1371
  headers=self.headers,
1409
- model=model or self.model,
1372
+ model=model_id,
1410
1373
  api_key=self.token,
1411
1374
  )
1412
1375
  response = await self._inner_post(request_parameters)
1376
+ response = provider_helper.get_response(response, request_parameters)
1413
1377
  return _bytes_to_image(response)
1414
1378
 
1379
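The widened input hint (`PIL.Image.Image` is now accepted alongside paths, bytes, and URLs) means an image can be opened locally and passed straight to `image_to_image`. A sketch under that assumption:

```py
import asyncio

from PIL import Image

from huggingface_hub import AsyncInferenceClient


async def stylize(path: str) -> None:
    client = AsyncInferenceClient()
    source = Image.open(path)  # a PIL image is now a valid input, no manual byte conversion needed
    result = await client.image_to_image(source, prompt="turn the cat into a tiger")
    result.save("tiger.jpg")


asyncio.run(stylize("cat.jpg"))
```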
+ async def image_to_video(
1380
+ self,
1381
+ image: ContentT,
1382
+ *,
1383
+ model: Optional[str] = None,
1384
+ prompt: Optional[str] = None,
1385
+ negative_prompt: Optional[str] = None,
1386
+ num_frames: Optional[float] = None,
1387
+ num_inference_steps: Optional[int] = None,
1388
+ guidance_scale: Optional[float] = None,
1389
+ seed: Optional[int] = None,
1390
+ target_size: Optional[ImageToVideoTargetSize] = None,
1391
+ **kwargs,
1392
+ ) -> bytes:
1393
+ """
1394
+ Generate a video from an input image.
1395
+
1396
+ Args:
1397
+ image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`):
1398
+ The input image to generate a video from. It can be raw bytes, an image file, a URL to an online image, or a PIL Image.
1399
+ model (`str`, *optional*):
1400
+ The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
1401
+ Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
1402
+ prompt (`str`, *optional*):
1403
+ The text prompt to guide the video generation.
1404
+ negative_prompt (`str`, *optional*):
1405
+ One prompt to guide what NOT to include in video generation.
1406
+ num_frames (`float`, *optional*):
1407
+ The num_frames parameter determines how many video frames are generated.
1408
+ num_inference_steps (`int`, *optional*):
1409
+ For diffusion models. The number of denoising steps. More denoising steps usually lead to a higher
1410
+ quality image at the expense of slower inference.
1411
+ guidance_scale (`float`, *optional*):
1412
+ For diffusion models. A higher guidance scale value encourages the model to generate videos closely
1413
+ linked to the text prompt at the expense of lower image quality.
1414
+ seed (`int`, *optional*):
1415
+ The seed to use for the video generation.
1416
+ target_size (`ImageToVideoTargetSize`, *optional*):
1417
+ The size in pixels of the output video frames.
1423
+
1424
+ Returns:
1425
+ `bytes`: The generated video.
1426
+
1427
+ Examples:
1428
+ ```py
1429
+ # Must be run in an async context
1430
+ >>> from huggingface_hub import AsyncInferenceClient
1431
+ >>> client = AsyncInferenceClient()
1432
+ >>> video = await client.image_to_video("cat.jpg", model="Wan-AI/Wan2.2-I2V-A14B", prompt="turn the cat into a tiger")
1433
+ >>> with open("tiger.mp4", "wb") as f:
1434
+ ... f.write(video)
1435
+ ```
1436
+ """
1437
+ model_id = model or self.model
1438
+ provider_helper = get_provider_helper(self.provider, task="image-to-video", model=model_id)
1439
+ request_parameters = provider_helper.prepare_request(
1440
+ inputs=image,
1441
+ parameters={
1442
+ "prompt": prompt,
1443
+ "negative_prompt": negative_prompt,
1444
+ "num_frames": num_frames,
1445
+ "num_inference_steps": num_inference_steps,
1446
+ "guidance_scale": guidance_scale,
1447
+ "seed": seed,
1448
+ "target_size": target_size,
1449
+ **kwargs,
1450
+ },
1451
+ headers=self.headers,
1452
+ model=model_id,
1453
+ api_key=self.token,
1454
+ )
1455
+ response = await self._inner_post(request_parameters)
1456
+ response = provider_helper.get_response(response, request_parameters)
1457
+ return response
1458
+
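Complementing the docstring example of the new `image_to_video` method, the optional generation knobs can be combined; whether a given provider honors `num_frames`, `seed`, or `target_size` varies, so the values below are illustrative only:

```py
import asyncio

from huggingface_hub import AsyncInferenceClient


async def animate(path: str) -> None:
    client = AsyncInferenceClient()
    video = await client.image_to_video(
        path,
        model="Wan-AI/Wan2.2-I2V-A14B",  # same model as the docstring example
        prompt="the cat slowly turns its head",
        num_frames=49,                   # provider support for these knobs varies
        seed=42,
    )
    with open("cat.mp4", "wb") as f:
        f.write(video)                   # raw bytes, as returned by the method


asyncio.run(animate("cat.jpg"))
```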
1415
1459
  async def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -> ImageToTextOutput:
1416
1460
  """
1417
1461
  Takes an input image and return text.
1418
1462
 
1419
1463
  Models can have very different outputs depending on your use case (image captioning, optical character recognition
1420
- (OCR), Pix2Struct, etc). Please have a look to the model card to learn more about a model's specificities.
1464
+ (OCR), Pix2Struct, etc.). Please have a look at the model card to learn more about a model's specificities.
1421
1465
 
1422
1466
  Args:
1423
- image (`Union[str, Path, bytes, BinaryIO]`):
1424
- The input image to caption. It can be raw bytes, an image file, or a URL to an online image..
1467
+ image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`):
1468
+ The input image to caption. It can be raw bytes, an image file, a URL to an online image, or a PIL Image.
1425
1469
  model (`str`, *optional*):
1426
1470
  The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
1427
1471
  Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
@@ -1432,7 +1476,7 @@ class AsyncInferenceClient:
1432
1476
  Raises:
1433
1477
  [`InferenceTimeoutError`]:
1434
1478
  If the model is unavailable or the request times out.
1435
- `aiohttp.ClientResponseError`:
1479
+ [`HfHubHTTPError`]:
1436
1480
  If the request fails with an HTTP error status code other than HTTP 503.
1437
1481
 
1438
1482
  Example:
@@ -1446,45 +1490,43 @@ class AsyncInferenceClient:
1446
1490
  'a dog laying on the grass next to a flower pot '
1447
1491
  ```
1448
1492
  """
1449
- provider_helper = get_provider_helper(self.provider, task="image-to-text")
1493
+ model_id = model or self.model
1494
+ provider_helper = get_provider_helper(self.provider, task="image-to-text", model=model_id)
1450
1495
  request_parameters = provider_helper.prepare_request(
1451
1496
  inputs=image,
1452
1497
  parameters={},
1453
1498
  headers=self.headers,
1454
- model=model or self.model,
1499
+ model=model_id,
1455
1500
  api_key=self.token,
1456
1501
  )
1457
1502
  response = await self._inner_post(request_parameters)
1458
- output = ImageToTextOutput.parse_obj(response)
1459
- return output[0] if isinstance(output, list) else output
1503
+ output_list: list[ImageToTextOutput] = ImageToTextOutput.parse_obj_as_list(response)
1504
+ return output_list[0]
1460
1505
 
1461
1506
  async def object_detection(
1462
1507
  self, image: ContentT, *, model: Optional[str] = None, threshold: Optional[float] = None
1463
- ) -> List[ObjectDetectionOutputElement]:
1508
+ ) -> list[ObjectDetectionOutputElement]:
1464
1509
  """
1465
1510
  Perform object detection on the given image using the specified model.
1466
1511
 
1467
- <Tip warning={true}>
1468
-
1469
- You must have `PIL` installed if you want to work with images (`pip install Pillow`).
1470
-
1471
- </Tip>
1512
+ > [!WARNING]
1513
+ > You must have `PIL` installed if you want to work with images (`pip install Pillow`).
1472
1514
 
1473
1515
  Args:
1474
- image (`Union[str, Path, bytes, BinaryIO]`):
1475
- The image to detect objects on. It can be raw bytes, an image file, or a URL to an online image.
1516
+ image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`):
1517
+ The image to detect objects on. It can be raw bytes, an image file, a URL to an online image, or a PIL Image.
1476
1518
  model (`str`, *optional*):
1477
1519
  The model to use for object detection. Can be a model ID hosted on the Hugging Face Hub or a URL to a
1478
1520
  deployed Inference Endpoint. If not provided, the default recommended model for object detection (DETR) will be used.
1479
1521
  threshold (`float`, *optional*):
1480
1522
  The probability necessary to make a prediction.
1481
1523
  Returns:
1482
- `List[ObjectDetectionOutputElement]`: A list of [`ObjectDetectionOutputElement`] items containing the bounding boxes and associated attributes.
1524
+ `list[ObjectDetectionOutputElement]`: A list of [`ObjectDetectionOutputElement`] items containing the bounding boxes and associated attributes.
1483
1525
 
1484
1526
  Raises:
1485
1527
  [`InferenceTimeoutError`]:
1486
1528
  If the model is unavailable or the request times out.
1487
- `aiohttp.ClientResponseError`:
1529
+ [`HfHubHTTPError`]:
1488
1530
  If the request fails with an HTTP error status code other than HTTP 503.
1489
1531
  `ValueError`:
1490
1532
  If the request output is not a List.
@@ -1498,12 +1540,13 @@ class AsyncInferenceClient:
1498
1540
  [ObjectDetectionOutputElement(score=0.9486683011054993, label='person', box=ObjectDetectionBoundingBox(xmin=59, ymin=39, xmax=420, ymax=510)), ...]
1499
1541
  ```
1500
1542
  """
1501
- provider_helper = get_provider_helper(self.provider, task="object-detection")
1543
+ model_id = model or self.model
1544
+ provider_helper = get_provider_helper(self.provider, task="object-detection", model=model_id)
1502
1545
  request_parameters = provider_helper.prepare_request(
1503
1546
  inputs=image,
1504
1547
  parameters={"threshold": threshold},
1505
1548
  headers=self.headers,
1506
- model=model or self.model,
1549
+ model=model_id,
1507
1550
  api_key=self.token,
1508
1551
  )
1509
1552
  response = await self._inner_post(request_parameters)
@@ -1522,7 +1565,7 @@ class AsyncInferenceClient:
1522
1565
  max_question_len: Optional[int] = None,
1523
1566
  max_seq_len: Optional[int] = None,
1524
1567
  top_k: Optional[int] = None,
1525
- ) -> Union[QuestionAnsweringOutputElement, List[QuestionAnsweringOutputElement]]:
1568
+ ) -> Union[QuestionAnsweringOutputElement, list[QuestionAnsweringOutputElement]]:
1526
1569
  """
1527
1570
  Retrieve the answer to a question from a given text.
1528
1571
 
@@ -1554,13 +1597,13 @@ class AsyncInferenceClient:
1554
1597
  topk answers if there are not enough options available within the context.
1555
1598
 
1556
1599
  Returns:
1557
- Union[`QuestionAnsweringOutputElement`, List[`QuestionAnsweringOutputElement`]]:
1600
+ Union[`QuestionAnsweringOutputElement`, list[`QuestionAnsweringOutputElement`]]:
1558
1601
  When top_k is 1 or not provided, it returns a single `QuestionAnsweringOutputElement`.
1559
1602
  When top_k is greater than 1, it returns a list of `QuestionAnsweringOutputElement`.
1560
1603
  Raises:
1561
1604
  [`InferenceTimeoutError`]:
1562
1605
  If the model is unavailable or the request times out.
1563
- `aiohttp.ClientResponseError`:
1606
+ [`HfHubHTTPError`]:
1564
1607
  If the request fails with an HTTP error status code other than HTTP 503.
1565
1608
 
1566
1609
  Example:
@@ -1572,9 +1615,10 @@ class AsyncInferenceClient:
1572
1615
  QuestionAnsweringOutputElement(answer='Clara', end=16, score=0.9326565265655518, start=11)
1573
1616
  ```
1574
1617
  """
1575
- provider_helper = get_provider_helper(self.provider, task="question-answering")
1618
+ model_id = model or self.model
1619
+ provider_helper = get_provider_helper(self.provider, task="question-answering", model=model_id)
1576
1620
  request_parameters = provider_helper.prepare_request(
1577
- inputs=None,
1621
+ inputs={"question": question, "context": context},
1578
1622
  parameters={
1579
1623
  "align_to_words": align_to_words,
1580
1624
  "doc_stride": doc_stride,
@@ -1584,9 +1628,8 @@ class AsyncInferenceClient:
1584
1628
  "max_seq_len": max_seq_len,
1585
1629
  "top_k": top_k,
1586
1630
  },
1587
- extra_payload={"question": question, "context": context},
1588
1631
  headers=self.headers,
1589
- model=model or self.model,
1632
+ model=model_id,
1590
1633
  api_key=self.token,
1591
1634
  )
1592
1635
  response = await self._inner_post(request_parameters)
@@ -1595,28 +1638,28 @@ class AsyncInferenceClient:
1595
1638
  return output
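Since the return type of `question_answering` depends on `top_k`, callers requesting several answers should expect a list. A minimal sketch:

```py
import asyncio

from huggingface_hub import AsyncInferenceClient


async def ask() -> None:
    client = AsyncInferenceClient()
    context = "My name is Clara and I live in Berkeley."

    best = await client.question_answering(question="What's my name?", context=context)
    print(best.answer, best.score)  # single QuestionAnsweringOutputElement when top_k is 1 or unset

    top3 = await client.question_answering(question="What's my name?", context=context, top_k=3)
    for candidate in top3:          # a list when top_k > 1
        print(candidate.answer, candidate.score)


asyncio.run(ask())
```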
1596
1639
 
1597
1640
  async def sentence_similarity(
1598
- self, sentence: str, other_sentences: List[str], *, model: Optional[str] = None
1599
- ) -> List[float]:
1641
+ self, sentence: str, other_sentences: list[str], *, model: Optional[str] = None
1642
+ ) -> list[float]:
1600
1643
  """
1601
1644
  Compute the semantic similarity between a sentence and a list of other sentences by comparing their embeddings.
1602
1645
 
1603
1646
  Args:
1604
1647
  sentence (`str`):
1605
1648
  The main sentence to compare to others.
1606
- other_sentences (`List[str]`):
1649
+ other_sentences (`list[str]`):
1607
1650
  The list of sentences to compare to.
1608
1651
  model (`str`, *optional*):
1609
- The model to use for the conversational task. Can be a model ID hosted on the Hugging Face Hub or a URL to
1610
- a deployed Inference Endpoint. If not provided, the default recommended conversational model will be used.
1652
+ The model to use for the sentence similarity task. Can be a model ID hosted on the Hugging Face Hub or a URL to
1653
+ a deployed Inference Endpoint. If not provided, the default recommended sentence similarity model will be used.
1611
1654
  Defaults to None.
1612
1655
 
1613
1656
  Returns:
1614
- `List[float]`: The embedding representing the input text.
1657
+ `list[float]`: The embedding representing the input text.
1615
1658
 
1616
1659
  Raises:
1617
1660
  [`InferenceTimeoutError`]:
1618
1661
  If the model is unavailable or the request times out.
1619
- `aiohttp.ClientResponseError`:
1662
+ [`HfHubHTTPError`]:
1620
1663
  If the request fails with an HTTP error status code other than HTTP 503.
1621
1664
 
1622
1665
  Example:
@@ -1635,13 +1678,14 @@ class AsyncInferenceClient:
1635
1678
  [0.7785726189613342, 0.45876261591911316, 0.2906220555305481]
1636
1679
  ```
1637
1680
  """
1638
- provider_helper = get_provider_helper(self.provider, task="sentence-similarity")
1681
+ model_id = model or self.model
1682
+ provider_helper = get_provider_helper(self.provider, task="sentence-similarity", model=model_id)
1639
1683
  request_parameters = provider_helper.prepare_request(
1640
- inputs=None,
1684
+ inputs={"source_sentence": sentence, "sentences": other_sentences},
1641
1685
  parameters={},
1642
- extra_payload={"source_sentence": sentence, "sentences": other_sentences},
1686
+ extra_payload={},
1643
1687
  headers=self.headers,
1644
- model=model or self.model,
1688
+ model=model_id,
1645
1689
  api_key=self.token,
1646
1690
  )
1647
1691
  response = await self._inner_post(request_parameters)
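The similarity scores are returned in the same order as `other_sentences`, so picking the closest candidate is straightforward; a small sketch:

```py
import asyncio

from huggingface_hub import AsyncInferenceClient


async def closest_match() -> None:
    client = AsyncInferenceClient()
    candidates = ["That is a happy dog", "That is a very happy person", "Today is a sunny day"]
    scores = await client.sentence_similarity("That is a happy person", candidates)
    best = max(range(len(scores)), key=scores.__getitem__)  # scores align with `candidates`
    print(candidates[best], scores[best])


asyncio.run(closest_match())
```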
@@ -1653,7 +1697,7 @@ class AsyncInferenceClient:
1653
1697
  *,
1654
1698
  model: Optional[str] = None,
1655
1699
  clean_up_tokenization_spaces: Optional[bool] = None,
1656
- generate_parameters: Optional[Dict[str, Any]] = None,
1700
+ generate_parameters: Optional[dict[str, Any]] = None,
1657
1701
  truncation: Optional["SummarizationTruncationStrategy"] = None,
1658
1702
  ) -> SummarizationOutput:
1659
1703
  """
@@ -1667,7 +1711,7 @@ class AsyncInferenceClient:
1667
1711
  Inference Endpoint. If not provided, the default recommended model for summarization will be used.
1668
1712
  clean_up_tokenization_spaces (`bool`, *optional*):
1669
1713
  Whether to clean up the potential extra spaces in the text output.
1670
- generate_parameters (`Dict[str, Any]`, *optional*):
1714
+ generate_parameters (`dict[str, Any]`, *optional*):
1671
1715
  Additional parametrization of the text generation algorithm.
1672
1716
  truncation (`"SummarizationTruncationStrategy"`, *optional*):
1673
1717
  The truncation strategy to use.
@@ -1677,7 +1721,7 @@ class AsyncInferenceClient:
1677
1721
  Raises:
1678
1722
  [`InferenceTimeoutError`]:
1679
1723
  If the model is unavailable or the request times out.
1680
- `aiohttp.ClientResponseError`:
1724
+ [`HfHubHTTPError`]:
1681
1725
  If the request fails with an HTTP error status code other than HTTP 503.
1682
1726
 
1683
1727
  Example:
@@ -1694,12 +1738,13 @@ class AsyncInferenceClient:
1694
1738
  "generate_parameters": generate_parameters,
1695
1739
  "truncation": truncation,
1696
1740
  }
1697
- provider_helper = get_provider_helper(self.provider, task="summarization")
1741
+ model_id = model or self.model
1742
+ provider_helper = get_provider_helper(self.provider, task="summarization", model=model_id)
1698
1743
  request_parameters = provider_helper.prepare_request(
1699
1744
  inputs=text,
1700
1745
  parameters=parameters,
1701
1746
  headers=self.headers,
1702
- model=model or self.model,
1747
+ model=model_id,
1703
1748
  api_key=self.token,
1704
1749
  )
1705
1750
  response = await self._inner_post(request_parameters)
@@ -1707,7 +1752,7 @@ class AsyncInferenceClient:
1707
1752
 
1708
1753
  async def table_question_answering(
1709
1754
  self,
1710
- table: Dict[str, Any],
1755
+ table: dict[str, Any],
1711
1756
  query: str,
1712
1757
  *,
1713
1758
  model: Optional[str] = None,
@@ -1742,7 +1787,7 @@ class AsyncInferenceClient:
1742
1787
  Raises:
1743
1788
  [`InferenceTimeoutError`]:
1744
1789
  If the model is unavailable or the request times out.
1745
- `aiohttp.ClientResponseError`:
1790
+ [`HfHubHTTPError`]:
1746
1791
  If the request fails with an HTTP error status code other than HTTP 503.
1747
1792
 
1748
1793
  Example:
@@ -1756,24 +1801,24 @@ class AsyncInferenceClient:
1756
1801
  TableQuestionAnsweringOutputElement(answer='36542', coordinates=[[0, 1]], cells=['36542'], aggregator='AVERAGE')
1757
1802
  ```
1758
1803
  """
1759
- provider_helper = get_provider_helper(self.provider, task="table-question-answering")
1804
+ model_id = model or self.model
1805
+ provider_helper = get_provider_helper(self.provider, task="table-question-answering", model=model_id)
1760
1806
  request_parameters = provider_helper.prepare_request(
1761
- inputs=None,
1807
+ inputs={"query": query, "table": table},
1762
1808
  parameters={"model": model, "padding": padding, "sequential": sequential, "truncation": truncation},
1763
- extra_payload={"query": query, "table": table},
1764
1809
  headers=self.headers,
1765
- model=model or self.model,
1810
+ model=model_id,
1766
1811
  api_key=self.token,
1767
1812
  )
1768
1813
  response = await self._inner_post(request_parameters)
1769
1814
  return TableQuestionAnsweringOutputElement.parse_obj_as_instance(response)
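The `table` argument is a plain dict mapping column names to equal-length lists of string cells, now sent as part of `inputs` rather than `extra_payload`. The user-facing call is unchanged; a sketch:

```py
import asyncio

from huggingface_hub import AsyncInferenceClient


async def query_table() -> None:
    client = AsyncInferenceClient()
    table = {
        "Repository": ["Transformers", "Datasets", "Tokenizers"],
        "Stars": ["36542", "4512", "3934"],  # cell values are strings
    }
    answer = await client.table_question_answering(
        table, query="How many stars does the transformers repository have?"
    )
    print(answer.answer, answer.cells, answer.aggregator)


asyncio.run(query_table())
```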
1770
1815
 
1771
- async def tabular_classification(self, table: Dict[str, Any], *, model: Optional[str] = None) -> List[str]:
1816
+ async def tabular_classification(self, table: dict[str, Any], *, model: Optional[str] = None) -> list[str]:
1772
1817
  """
1773
1818
  Classifying a target category (a group) based on a set of attributes.
1774
1819
 
1775
1820
  Args:
1776
- table (`Dict[str, Any]`):
1821
+ table (`dict[str, Any]`):
1777
1822
  Set of attributes to classify.
1778
1823
  model (`str`, *optional*):
1779
1824
  The model to use for the tabular classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to
@@ -1786,7 +1831,7 @@ class AsyncInferenceClient:
1786
1831
  Raises:
1787
1832
  [`InferenceTimeoutError`]:
1788
1833
  If the model is unavailable or the request times out.
1789
- `aiohttp.ClientResponseError`:
1834
+ [`HfHubHTTPError`]:
1790
1835
  If the request fails with an HTTP error status code other than HTTP 503.
1791
1836
 
1792
1837
  Example:
@@ -1811,24 +1856,25 @@ class AsyncInferenceClient:
1811
1856
  ["5", "5", "5"]
1812
1857
  ```
1813
1858
  """
1814
- provider_helper = get_provider_helper(self.provider, task="tabular-classification")
1859
+ model_id = model or self.model
1860
+ provider_helper = get_provider_helper(self.provider, task="tabular-classification", model=model_id)
1815
1861
  request_parameters = provider_helper.prepare_request(
1816
1862
  inputs=None,
1817
1863
  extra_payload={"table": table},
1818
1864
  parameters={},
1819
1865
  headers=self.headers,
1820
- model=model or self.model,
1866
+ model=model_id,
1821
1867
  api_key=self.token,
1822
1868
  )
1823
1869
  response = await self._inner_post(request_parameters)
1824
1870
  return _bytes_to_list(response)
1825
1871
 
1826
- async def tabular_regression(self, table: Dict[str, Any], *, model: Optional[str] = None) -> List[float]:
1872
+ async def tabular_regression(self, table: dict[str, Any], *, model: Optional[str] = None) -> list[float]:
1827
1873
  """
1828
1874
  Predicting a numerical target value given a set of attributes/features in a table.
1829
1875
 
1830
1876
  Args:
1831
- table (`Dict[str, Any]`):
1877
+ table (`dict[str, Any]`):
1832
1878
  Set of attributes stored in a table. The attributes used to predict the target can be both numerical and categorical.
1833
1879
  model (`str`, *optional*):
1834
1880
  The model to use for the tabular regression task. Can be a model ID hosted on the Hugging Face Hub or a URL to
@@ -1841,7 +1887,7 @@ class AsyncInferenceClient:
1841
1887
  Raises:
1842
1888
  [`InferenceTimeoutError`]:
1843
1889
  If the model is unavailable or the request times out.
1844
- `aiohttp.ClientResponseError`:
1890
+ [`HfHubHTTPError`]:
1845
1891
  If the request fails with an HTTP error status code other than HTTP 503.
1846
1892
 
1847
1893
  Example:
@@ -1861,13 +1907,14 @@ class AsyncInferenceClient:
1861
1907
  [110, 120, 130]
1862
1908
  ```
1863
1909
  """
1864
- provider_helper = get_provider_helper(self.provider, task="tabular-regression")
1910
+ model_id = model or self.model
1911
+ provider_helper = get_provider_helper(self.provider, task="tabular-regression", model=model_id)
1865
1912
  request_parameters = provider_helper.prepare_request(
1866
1913
  inputs=None,
1867
1914
  parameters={},
1868
1915
  extra_payload={"table": table},
1869
1916
  headers=self.headers,
1870
- model=model or self.model,
1917
+ model=model_id,
1871
1918
  api_key=self.token,
1872
1919
  )
1873
1920
  response = await self._inner_post(request_parameters)
@@ -1880,7 +1927,7 @@ class AsyncInferenceClient:
1880
1927
  model: Optional[str] = None,
1881
1928
  top_k: Optional[int] = None,
1882
1929
  function_to_apply: Optional["TextClassificationOutputTransform"] = None,
1883
- ) -> List[TextClassificationOutputElement]:
1930
+ ) -> list[TextClassificationOutputElement]:
1884
1931
  """
1885
1932
  Perform text classification (e.g. sentiment-analysis) on the given text.
1886
1933
 
@@ -1897,12 +1944,12 @@ class AsyncInferenceClient:
1897
1944
  The function to apply to the model outputs in order to retrieve the scores.
1898
1945
 
1899
1946
  Returns:
1900
- `List[TextClassificationOutputElement]`: a list of [`TextClassificationOutputElement`] items containing the predicted label and associated probability.
1947
+ `list[TextClassificationOutputElement]`: a list of [`TextClassificationOutputElement`] items containing the predicted label and associated probability.
1901
1948
 
1902
1949
  Raises:
1903
1950
  [`InferenceTimeoutError`]:
1904
1951
  If the model is unavailable or the request times out.
1905
- `aiohttp.ClientResponseError`:
1952
+ [`HfHubHTTPError`]:
1906
1953
  If the request fails with an HTTP error status code other than HTTP 503.
1907
1954
 
1908
1955
  Example:
@@ -1917,7 +1964,8 @@ class AsyncInferenceClient:
1917
1964
  ]
1918
1965
  ```
1919
1966
  """
1920
- provider_helper = get_provider_helper(self.provider, task="text-classification")
1967
+ model_id = model or self.model
1968
+ provider_helper = get_provider_helper(self.provider, task="text-classification", model=model_id)
1921
1969
  request_parameters = provider_helper.prepare_request(
1922
1970
  inputs=text,
1923
1971
  parameters={
@@ -1925,33 +1973,33 @@ class AsyncInferenceClient:
1925
1973
  "top_k": top_k,
1926
1974
  },
1927
1975
  headers=self.headers,
1928
- model=model or self.model,
1976
+ model=model_id,
1929
1977
  api_key=self.token,
1930
1978
  )
1931
1979
  response = await self._inner_post(request_parameters)
1932
1980
  return TextClassificationOutputElement.parse_obj_as_list(response)[0] # type: ignore [return-value]
1933
1981
 
1934
1982
  @overload
1935
- async def text_generation( # type: ignore
1983
+ async def text_generation(
1936
1984
  self,
1937
1985
  prompt: str,
1938
1986
  *,
1939
- details: Literal[False] = ...,
1940
- stream: Literal[False] = ...,
1987
+ details: Literal[True],
1988
+ stream: Literal[True],
1941
1989
  model: Optional[str] = None,
1942
1990
  # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
1943
1991
  adapter_id: Optional[str] = None,
1944
1992
  best_of: Optional[int] = None,
1945
1993
  decoder_input_details: Optional[bool] = None,
1946
- do_sample: Optional[bool] = False, # Manual default value
1994
+ do_sample: Optional[bool] = None,
1947
1995
  frequency_penalty: Optional[float] = None,
1948
1996
  grammar: Optional[TextGenerationInputGrammarType] = None,
1949
1997
  max_new_tokens: Optional[int] = None,
1950
1998
  repetition_penalty: Optional[float] = None,
1951
- return_full_text: Optional[bool] = False, # Manual default value
1999
+ return_full_text: Optional[bool] = None,
1952
2000
  seed: Optional[int] = None,
1953
- stop: Optional[List[str]] = None,
1954
- stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead
2001
+ stop: Optional[list[str]] = None,
2002
+ stop_sequences: Optional[list[str]] = None, # Deprecated, use `stop` instead
1955
2003
  temperature: Optional[float] = None,
1956
2004
  top_k: Optional[int] = None,
1957
2005
  top_n_tokens: Optional[int] = None,
@@ -1959,29 +2007,29 @@ class AsyncInferenceClient:
1959
2007
  truncate: Optional[int] = None,
1960
2008
  typical_p: Optional[float] = None,
1961
2009
  watermark: Optional[bool] = None,
1962
- ) -> str: ...
2010
+ ) -> AsyncIterable[TextGenerationStreamOutput]: ...
1963
2011
 
1964
2012
  @overload
1965
- async def text_generation( # type: ignore
2013
+ async def text_generation(
1966
2014
  self,
1967
2015
  prompt: str,
1968
2016
  *,
1969
- details: Literal[True] = ...,
1970
- stream: Literal[False] = ...,
2017
+ details: Literal[True],
2018
+ stream: Optional[Literal[False]] = None,
1971
2019
  model: Optional[str] = None,
1972
2020
  # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
1973
2021
  adapter_id: Optional[str] = None,
1974
2022
  best_of: Optional[int] = None,
1975
2023
  decoder_input_details: Optional[bool] = None,
1976
- do_sample: Optional[bool] = False, # Manual default value
2024
+ do_sample: Optional[bool] = None,
1977
2025
  frequency_penalty: Optional[float] = None,
1978
2026
  grammar: Optional[TextGenerationInputGrammarType] = None,
1979
2027
  max_new_tokens: Optional[int] = None,
1980
2028
  repetition_penalty: Optional[float] = None,
1981
- return_full_text: Optional[bool] = False, # Manual default value
2029
+ return_full_text: Optional[bool] = None,
1982
2030
  seed: Optional[int] = None,
1983
- stop: Optional[List[str]] = None,
1984
- stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead
2031
+ stop: Optional[list[str]] = None,
2032
+ stop_sequences: Optional[list[str]] = None, # Deprecated, use `stop` instead
1985
2033
  temperature: Optional[float] = None,
1986
2034
  top_k: Optional[int] = None,
1987
2035
  top_n_tokens: Optional[int] = None,
@@ -1992,26 +2040,26 @@ class AsyncInferenceClient:
1992
2040
  ) -> TextGenerationOutput: ...
1993
2041
 
1994
2042
  @overload
1995
- async def text_generation( # type: ignore
2043
+ async def text_generation(
1996
2044
  self,
1997
2045
  prompt: str,
1998
2046
  *,
1999
- details: Literal[False] = ...,
2000
- stream: Literal[True] = ...,
2047
+ details: Optional[Literal[False]] = None,
2048
+ stream: Literal[True],
2001
2049
  model: Optional[str] = None,
2002
2050
  # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
2003
2051
  adapter_id: Optional[str] = None,
2004
2052
  best_of: Optional[int] = None,
2005
2053
  decoder_input_details: Optional[bool] = None,
2006
- do_sample: Optional[bool] = False, # Manual default value
2054
+ do_sample: Optional[bool] = None,
2007
2055
  frequency_penalty: Optional[float] = None,
2008
2056
  grammar: Optional[TextGenerationInputGrammarType] = None,
2009
2057
  max_new_tokens: Optional[int] = None,
2010
2058
  repetition_penalty: Optional[float] = None,
2011
- return_full_text: Optional[bool] = False, # Manual default value
2059
+ return_full_text: Optional[bool] = None,
2012
2060
  seed: Optional[int] = None,
2013
- stop: Optional[List[str]] = None,
2014
- stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead
2061
+ stop: Optional[list[str]] = None,
2062
+ stop_sequences: Optional[list[str]] = None, # Deprecated, use `stop` instead
2015
2063
  temperature: Optional[float] = None,
2016
2064
  top_k: Optional[int] = None,
2017
2065
  top_n_tokens: Optional[int] = None,
@@ -2022,26 +2070,26 @@ class AsyncInferenceClient:
2022
2070
  ) -> AsyncIterable[str]: ...
2023
2071
 
2024
2072
  @overload
2025
- async def text_generation( # type: ignore
2073
+ async def text_generation(
2026
2074
  self,
2027
2075
  prompt: str,
2028
2076
  *,
2029
- details: Literal[True] = ...,
2030
- stream: Literal[True] = ...,
2077
+ details: Optional[Literal[False]] = None,
2078
+ stream: Optional[Literal[False]] = None,
2031
2079
  model: Optional[str] = None,
2032
2080
  # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
2033
2081
  adapter_id: Optional[str] = None,
2034
2082
  best_of: Optional[int] = None,
2035
2083
  decoder_input_details: Optional[bool] = None,
2036
- do_sample: Optional[bool] = False, # Manual default value
2084
+ do_sample: Optional[bool] = None,
2037
2085
  frequency_penalty: Optional[float] = None,
2038
2086
  grammar: Optional[TextGenerationInputGrammarType] = None,
2039
2087
  max_new_tokens: Optional[int] = None,
2040
2088
  repetition_penalty: Optional[float] = None,
2041
- return_full_text: Optional[bool] = False, # Manual default value
2089
+ return_full_text: Optional[bool] = None,
2042
2090
  seed: Optional[int] = None,
2043
- stop: Optional[List[str]] = None,
2044
- stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead
2091
+ stop: Optional[list[str]] = None,
2092
+ stop_sequences: Optional[list[str]] = None, # Deprecated, use `stop` instead
2045
2093
  temperature: Optional[float] = None,
2046
2094
  top_k: Optional[int] = None,
2047
2095
  top_n_tokens: Optional[int] = None,
@@ -2049,29 +2097,29 @@ class AsyncInferenceClient:
2049
2097
  truncate: Optional[int] = None,
2050
2098
  typical_p: Optional[float] = None,
2051
2099
  watermark: Optional[bool] = None,
2052
- ) -> AsyncIterable[TextGenerationStreamOutput]: ...
2100
+ ) -> str: ...
2053
2101
 
2054
2102
  @overload
2055
2103
  async def text_generation(
2056
2104
  self,
2057
2105
  prompt: str,
2058
2106
  *,
2059
- details: Literal[True] = ...,
2060
- stream: bool = ...,
2107
+ details: Optional[bool] = None,
2108
+ stream: Optional[bool] = None,
2061
2109
  model: Optional[str] = None,
2062
2110
  # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
2063
2111
  adapter_id: Optional[str] = None,
2064
2112
  best_of: Optional[int] = None,
2065
2113
  decoder_input_details: Optional[bool] = None,
2066
- do_sample: Optional[bool] = False, # Manual default value
2114
+ do_sample: Optional[bool] = None,
2067
2115
  frequency_penalty: Optional[float] = None,
2068
2116
  grammar: Optional[TextGenerationInputGrammarType] = None,
2069
2117
  max_new_tokens: Optional[int] = None,
2070
2118
  repetition_penalty: Optional[float] = None,
2071
- return_full_text: Optional[bool] = False, # Manual default value
2119
+ return_full_text: Optional[bool] = None,
2072
2120
  seed: Optional[int] = None,
2073
- stop: Optional[List[str]] = None,
2074
- stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead
2121
+ stop: Optional[list[str]] = None,
2122
+ stop_sequences: Optional[list[str]] = None, # Deprecated, use `stop` instead
2075
2123
  temperature: Optional[float] = None,
2076
2124
  top_k: Optional[int] = None,
2077
2125
  top_n_tokens: Optional[int] = None,
@@ -2079,28 +2127,28 @@ class AsyncInferenceClient:
2079
2127
  truncate: Optional[int] = None,
2080
2128
  typical_p: Optional[float] = None,
2081
2129
  watermark: Optional[bool] = None,
2082
- ) -> Union[TextGenerationOutput, AsyncIterable[TextGenerationStreamOutput]]: ...
2130
+ ) -> Union[str, TextGenerationOutput, AsyncIterable[str], AsyncIterable[TextGenerationStreamOutput]]: ...
2083
2131
 
2084
2132
  async def text_generation(
2085
2133
  self,
2086
2134
  prompt: str,
2087
2135
  *,
2088
- details: bool = False,
2089
- stream: bool = False,
2136
+ details: Optional[bool] = None,
2137
+ stream: Optional[bool] = None,
2090
2138
  model: Optional[str] = None,
2091
2139
  # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
2092
2140
  adapter_id: Optional[str] = None,
2093
2141
  best_of: Optional[int] = None,
2094
2142
  decoder_input_details: Optional[bool] = None,
2095
- do_sample: Optional[bool] = False, # Manual default value
2143
+ do_sample: Optional[bool] = None,
2096
2144
  frequency_penalty: Optional[float] = None,
2097
2145
  grammar: Optional[TextGenerationInputGrammarType] = None,
2098
2146
  max_new_tokens: Optional[int] = None,
2099
2147
  repetition_penalty: Optional[float] = None,
2100
- return_full_text: Optional[bool] = False, # Manual default value
2148
+ return_full_text: Optional[bool] = None,
2101
2149
  seed: Optional[int] = None,
2102
- stop: Optional[List[str]] = None,
2103
- stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead
2150
+ stop: Optional[list[str]] = None,
2151
+ stop_sequences: Optional[list[str]] = None, # Deprecated, use `stop` instead
2104
2152
  temperature: Optional[float] = None,
2105
2153
  top_k: Optional[int] = None,
2106
2154
  top_n_tokens: Optional[int] = None,
@@ -2112,12 +2160,9 @@ class AsyncInferenceClient:
2112
2160
  """
2113
2161
  Given a prompt, generate the following text.
2114
2162
 
2115
- <Tip>
2116
-
2117
- If you want to generate a response from chat messages, you should use the [`InferenceClient.chat_completion`] method.
2118
- It accepts a list of messages instead of a single text prompt and handles the chat templating for you.
2119
-
2120
- </Tip>
2163
+ > [!TIP]
2164
+ > If you want to generate a response from chat messages, you should use the [`InferenceClient.chat_completion`] method.
2165
+ > It accepts a list of messages instead of a single text prompt and handles the chat templating for you.
2121
2166
 
2122
2167
  Args:
2123
2168
  prompt (`str`):
@@ -2156,9 +2201,9 @@ class AsyncInferenceClient:
2156
2201
  Whether to prepend the prompt to the generated text
2157
2202
  seed (`int`, *optional*):
2158
2203
  Random sampling seed
2159
- stop (`List[str]`, *optional*):
2204
+ stop (`list[str]`, *optional*):
2160
2205
  Stop generating tokens if a member of `stop` is generated.
2161
- stop_sequences (`List[str]`, *optional*):
2206
+ stop_sequences (`list[str]`, *optional*):
2162
2207
  Deprecated argument. Use `stop` instead.
2163
2208
  temperature (`float`, *optional*):
2164
2209
  The value used to module the logits distribution.
@@ -2175,14 +2220,14 @@ class AsyncInferenceClient:
2175
2220
  typical_p (`float`, *optional`):
2176
2221
  Typical Decoding mass
2177
2222
  See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
2178
- watermark (`bool`, *optional`):
2223
+ watermark (`bool`, *optional*):
2179
2224
  Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
2180
2225
 
2181
2226
  Returns:
2182
- `Union[str, TextGenerationOutput, Iterable[str], Iterable[TextGenerationStreamOutput]]`:
2227
+ `Union[str, TextGenerationOutput, AsyncIterable[str], AsyncIterable[TextGenerationStreamOutput]]`:
2183
2228
  Generated text returned from the server:
2184
2229
  - if `stream=False` and `details=False`, the generated text is returned as a `str` (default)
2185
- - if `stream=True` and `details=False`, the generated text is returned token by token as a `Iterable[str]`
2230
+ - if `stream=True` and `details=False`, the generated text is returned token by token as an `AsyncIterable[str]`
2186
2231
  - if `stream=False` and `details=True`, the generated text is returned with more details as a [`~huggingface_hub.TextGenerationOutput`]
2187
2232
  - if `details=True` and `stream=True`, the generated text is returned token by token as a iterable of [`~huggingface_hub.TextGenerationStreamOutput`]
2188
2233
 
@@ -2191,7 +2236,7 @@ class AsyncInferenceClient:
2191
2236
  If input values are not valid. No HTTP call is made to the server.
2192
2237
  [`InferenceTimeoutError`]:
2193
2238
  If the model is unavailable or the request times out.
2194
- `aiohttp.ClientResponseError`:
2239
+ [`HfHubHTTPError`]:
2195
2240
  If the request fails with an HTTP error status code other than HTTP 503.
2196
2241
 
2197
2242
  Example:
@@ -2326,7 +2371,7 @@ class AsyncInferenceClient:
2326
2371
  "repetition_penalty": repetition_penalty,
2327
2372
  "return_full_text": return_full_text,
2328
2373
  "seed": seed,
2329
- "stop": stop if stop is not None else [],
2374
+ "stop": stop,
2330
2375
  "temperature": temperature,
2331
2376
  "top_k": top_k,
2332
2377
  "top_n_tokens": top_n_tokens,
@@ -2367,29 +2412,30 @@ class AsyncInferenceClient:
2367
2412
  " Please pass `stream=False` as input."
2368
2413
  )
2369
2414
 
2370
- provider_helper = get_provider_helper(self.provider, task="text-generation")
2415
+ model_id = model or self.model
2416
+ provider_helper = get_provider_helper(self.provider, task="text-generation", model=model_id)
2371
2417
  request_parameters = provider_helper.prepare_request(
2372
2418
  inputs=prompt,
2373
2419
  parameters=parameters,
2374
2420
  extra_payload={"stream": stream},
2375
2421
  headers=self.headers,
2376
- model=model or self.model,
2422
+ model=model_id,
2377
2423
  api_key=self.token,
2378
2424
  )
2379
2425
 
2380
2426
  # Handle errors separately for more precise error messages
2381
2427
  try:
2382
- bytes_output = await self._inner_post(request_parameters, stream=stream)
2383
- except _import_aiohttp().ClientResponseError as e:
2384
- match = MODEL_KWARGS_NOT_USED_REGEX.search(e.response_error_payload["error"])
2385
- if e.status == 400 and match:
2428
+ bytes_output = await self._inner_post(request_parameters, stream=stream or False)
2429
+ except HfHubHTTPError as e:
2430
+ match = MODEL_KWARGS_NOT_USED_REGEX.search(str(e))
2431
+ if isinstance(e, BadRequestError) and match:
2386
2432
  unused_params = [kwarg.strip("' ") for kwarg in match.group(1).split(",")]
2387
2433
  _set_unsupported_text_generation_kwargs(model, unused_params)
2388
2434
  return await self.text_generation( # type: ignore
2389
2435
  prompt=prompt,
2390
2436
  details=details,
2391
2437
  stream=stream,
2392
- model=model or self.model,
2438
+ model=model_id,
2393
2439
  adapter_id=adapter_id,
2394
2440
  best_of=best_of,
2395
2441
  decoder_input_details=decoder_input_details,
@@ -2420,8 +2466,8 @@ class AsyncInferenceClient:
2420
2466
  # Data can be a single element (dict) or an iterable of dicts where we select the first element of.
2421
2467
  if isinstance(data, list):
2422
2468
  data = data[0]
2423
-
2424
- return TextGenerationOutput.parse_obj_as_instance(data) if details else data["generated_text"]
2469
+ response = provider_helper.get_response(data, request_parameters)
2470
+ return TextGenerationOutput.parse_obj_as_instance(response) if details else response["generated_text"]
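The overloads above encode four return shapes for `text_generation`, driven by `details` and `stream`. A sketch of each combination against the client's default model (output handling kept minimal):

```py
import asyncio

from huggingface_hub import AsyncInferenceClient


async def demo(prompt: str) -> None:
    client = AsyncInferenceClient()

    text = await client.text_generation(prompt, max_new_tokens=12)         # str
    detailed = await client.text_generation(prompt, details=True)          # TextGenerationOutput
    print(text)
    print(detailed.generated_text)

    async for token in await client.text_generation(prompt, stream=True):  # AsyncIterable[str]
        print(token, end="")

    async for chunk in await client.text_generation(prompt, stream=True, details=True):
        print(chunk.token.text, end="")                                    # TextGenerationStreamOutput


asyncio.run(demo("The huggingface_hub library is "))
```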
2425
2471
 
2426
2472
  async def text_to_image(
2427
2473
  self,
@@ -2435,20 +2481,16 @@ class AsyncInferenceClient:
2435
2481
  model: Optional[str] = None,
2436
2482
  scheduler: Optional[str] = None,
2437
2483
  seed: Optional[int] = None,
2438
- extra_body: Optional[Dict[str, Any]] = None,
2484
+ extra_body: Optional[dict[str, Any]] = None,
2439
2485
  ) -> "Image":
2440
2486
  """
2441
2487
  Generate an image based on a given text using a specified model.
2442
2488
 
2443
- <Tip warning={true}>
2444
-
2445
- You must have `PIL` installed if you want to work with images (`pip install Pillow`).
2489
+ > [!WARNING]
2490
+ > You must have `PIL` installed if you want to work with images (`pip install Pillow`).
2446
2491
 
2447
- </Tip>
2448
-
2449
- <Tip>
2450
- You can pass provider-specific parameters to the model by using the `extra_body` argument.
2451
- </Tip>
2492
+ > [!TIP]
2493
+ > You can pass provider-specific parameters to the model by using the `extra_body` argument.
2452
2494
 
2453
2495
  Args:
2454
2496
  prompt (`str`):
@@ -2473,7 +2515,7 @@ class AsyncInferenceClient:
2473
2515
  Override the scheduler with a compatible one.
2474
2516
  seed (`int`, *optional*):
2475
2517
  Seed for the random number generator.
2476
- extra_body (`Dict[str, Any]`, *optional*):
2518
+ extra_body (`dict[str, Any]`, *optional*):
2477
2519
  Additional provider-specific parameters to pass to the model. Refer to the provider's documentation
2478
2520
  for supported parameters.
2479
2521
 
@@ -2483,7 +2525,7 @@ class AsyncInferenceClient:
2483
2525
  Raises:
2484
2526
  [`InferenceTimeoutError`]:
2485
2527
  If the model is unavailable or the request times out.
2486
- `aiohttp.ClientResponseError`:
2528
+ [`HfHubHTTPError`]:
2487
2529
  If the request fails with an HTTP error status code other than HTTP 503.
2488
2530
 
2489
2531
  Example:
@@ -2544,8 +2586,10 @@ class AsyncInferenceClient:
2544
2586
  ... )
2545
2587
  >>> image.save("astronaut.png")
2546
2588
  ```
2589
+
2547
2590
  """
2548
- provider_helper = get_provider_helper(self.provider, task="text-to-image")
2591
+ model_id = model or self.model
2592
+ provider_helper = get_provider_helper(self.provider, task="text-to-image", model=model_id)
2549
2593
  request_parameters = provider_helper.prepare_request(
2550
2594
  inputs=prompt,
2551
2595
  parameters={
@@ -2559,11 +2603,11 @@ class AsyncInferenceClient:
2559
2603
  **(extra_body or {}),
2560
2604
  },
2561
2605
  headers=self.headers,
2562
- model=model or self.model,
2606
+ model=model_id,
2563
2607
  api_key=self.token,
2564
2608
  )
2565
2609
  response = await self._inner_post(request_parameters)
2566
- response = provider_helper.get_response(response)
2610
+ response = provider_helper.get_response(response, request_parameters)
2567
2611
  return _bytes_to_image(response)
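`extra_body` is merged into the payload after the named parameters above, which is the documented way to pass provider-specific options. The provider, the default model resolution, and the key used in `extra_body` below are assumptions for illustration:

```py
import asyncio

from huggingface_hub import AsyncInferenceClient


async def paint() -> None:
    # "fal-ai" is just an example provider; any supported text-to-image setup works the same way.
    client = AsyncInferenceClient(provider="fal-ai")
    image = await client.text_to_image(
        "An astronaut riding a horse on the moon.",
        negative_prompt="blurry, low quality",
        seed=0,
        extra_body={"safety_tolerance": 2},  # hypothetical provider-specific option
    )
    image.save("astronaut.png")


asyncio.run(paint())
```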
2568
2612
 
2569
2613
  async def text_to_video(
@@ -2572,18 +2616,17 @@ class AsyncInferenceClient:
2572
2616
  *,
2573
2617
  model: Optional[str] = None,
2574
2618
  guidance_scale: Optional[float] = None,
2575
- negative_prompt: Optional[List[str]] = None,
2619
+ negative_prompt: Optional[list[str]] = None,
2576
2620
  num_frames: Optional[float] = None,
2577
2621
  num_inference_steps: Optional[int] = None,
2578
2622
  seed: Optional[int] = None,
2579
- extra_body: Optional[Dict[str, Any]] = None,
2623
+ extra_body: Optional[dict[str, Any]] = None,
2580
2624
  ) -> bytes:
2581
2625
  """
2582
2626
  Generate a video based on a given text.
2583
2627
 
2584
- <Tip>
2585
- You can pass provider-specific parameters to the model by using the `extra_body` argument.
2586
- </Tip>
2628
+ > [!TIP]
2629
+ > You can pass provider-specific parameters to the model by using the `extra_body` argument.
2587
2630
 
2588
2631
  Args:
2589
2632
  prompt (`str`):
@@ -2595,7 +2638,7 @@ class AsyncInferenceClient:
2595
2638
  guidance_scale (`float`, *optional*):
2596
2639
  A higher guidance scale value encourages the model to generate videos closely linked to the text
2597
2640
  prompt, but values too high may cause saturation and other artifacts.
2598
- negative_prompt (`List[str]`, *optional*):
2641
+ negative_prompt (`list[str]`, *optional*):
2599
2642
  One or several prompt to guide what NOT to include in video generation.
2600
2643
  num_frames (`float`, *optional*):
2601
2644
  The num_frames parameter determines how many video frames are generated.
@@ -2604,7 +2647,7 @@ class AsyncInferenceClient:
2604
2647
  expense of slower inference.
2605
2648
  seed (`int`, *optional*):
2606
2649
  Seed for the random number generator.
2607
- extra_body (`Dict[str, Any]`, *optional*):
2650
+ extra_body (`dict[str, Any]`, *optional*):
2608
2651
  Additional provider-specific parameters to pass to the model. Refer to the provider's documentation
2609
2652
  for supported parameters.
2610
2653
 
@@ -2642,8 +2685,10 @@ class AsyncInferenceClient:
2642
2685
  >>> with open("cat.mp4", "wb") as file:
2643
2686
  ... file.write(video)
2644
2687
  ```
2688
+
2645
2689
  """
2646
- provider_helper = get_provider_helper(self.provider, task="text-to-video")
2690
+ model_id = model or self.model
2691
+ provider_helper = get_provider_helper(self.provider, task="text-to-video", model=model_id)
2647
2692
  request_parameters = provider_helper.prepare_request(
2648
2693
  inputs=prompt,
2649
2694
  parameters={
@@ -2655,11 +2700,11 @@ class AsyncInferenceClient:
2655
2700
  **(extra_body or {}),
2656
2701
  },
2657
2702
  headers=self.headers,
2658
- model=model or self.model,
2703
+ model=model_id,
2659
2704
  api_key=self.token,
2660
2705
  )
2661
2706
  response = await self._inner_post(request_parameters)
2662
- response = provider_helper.get_response(response)
2707
+ response = provider_helper.get_response(response, request_parameters)
2663
2708
  return response
2664
2709
 
2665
2710
  async def text_to_speech(
@@ -2683,14 +2728,13 @@ class AsyncInferenceClient:
2683
2728
  top_p: Optional[float] = None,
2684
2729
  typical_p: Optional[float] = None,
2685
2730
  use_cache: Optional[bool] = None,
2686
- extra_body: Optional[Dict[str, Any]] = None,
2731
+ extra_body: Optional[dict[str, Any]] = None,
2687
2732
  ) -> bytes:
2688
2733
  """
2689
2734
  Synthesize an audio of a voice pronouncing a given text.
2690
2735
 
2691
- <Tip>
2692
- You can pass provider-specific parameters to the model by using the `extra_body` argument.
2693
- </Tip>
2736
+ > [!TIP]
2737
+ > You can pass provider-specific parameters to the model by using the `extra_body` argument.
2694
2738
 
2695
2739
  Args:
2696
2740
  text (`str`):
@@ -2745,7 +2789,7 @@ class AsyncInferenceClient:
2745
2789
  paper](https://hf.co/papers/2202.00666) for more details.
2746
2790
  use_cache (`bool`, *optional*):
2747
2791
  Whether the model should use the past last key/values attentions to speed up decoding
2748
- extra_body (`Dict[str, Any]`, *optional*):
2792
+ extra_body (`dict[str, Any]`, *optional*):
2749
2793
  Additional provider-specific parameters to pass to the model. Refer to the provider's documentation
2750
2794
  for supported parameters.
2751
2795
  Returns:
@@ -2754,7 +2798,7 @@ class AsyncInferenceClient:
2754
2798
  Raises:
2755
2799
  [`InferenceTimeoutError`]:
2756
2800
  If the model is unavailable or the request times out.
2757
- `aiohttp.ClientResponseError`:
2801
+ [`HfHubHTTPError`]:
2758
2802
  If the request fails with an HTTP error status code other than HTTP 503.
2759
2803
 
2760
2804
  Example:
@@ -2841,7 +2885,8 @@ class AsyncInferenceClient:
2841
2885
  ... f.write(audio)
2842
2886
  ```
2843
2887
  """
2844
- provider_helper = get_provider_helper(self.provider, task="text-to-speech")
2888
+ model_id = model or self.model
2889
+ provider_helper = get_provider_helper(self.provider, task="text-to-speech", model=model_id)
2845
2890
  request_parameters = provider_helper.prepare_request(
2846
2891
  inputs=text,
2847
2892
  parameters={
@@ -2864,7 +2909,7 @@ class AsyncInferenceClient:
2864
2909
  **(extra_body or {}),
2865
2910
  },
2866
2911
  headers=self.headers,
2867
- model=model or self.model,
2912
+ model=model_id,
2868
2913
  api_key=self.token,
2869
2914
  )
2870
2915
  response = await self._inner_post(request_parameters)
@@ -2877,9 +2922,9 @@ class AsyncInferenceClient:
2877
2922
  *,
2878
2923
  model: Optional[str] = None,
2879
2924
  aggregation_strategy: Optional["TokenClassificationAggregationStrategy"] = None,
2880
- ignore_labels: Optional[List[str]] = None,
2925
+ ignore_labels: Optional[list[str]] = None,
2881
2926
  stride: Optional[int] = None,
2882
- ) -> List[TokenClassificationOutputElement]:
2927
+ ) -> list[TokenClassificationOutputElement]:
2883
2928
  """
2884
2929
  Perform token classification on the given text.
2885
2930
  Usually used for sentence parsing, either grammatical, or Named Entity Recognition (NER) to understand keywords contained within text.
@@ -2893,18 +2938,18 @@ class AsyncInferenceClient:
2893
2938
  Defaults to None.
2894
2939
  aggregation_strategy (`"TokenClassificationAggregationStrategy"`, *optional*):
2895
2940
  The strategy used to fuse tokens based on model predictions
2896
- ignore_labels (`List[str`, *optional*):
2941
+ ignore_labels (`list[str`, *optional*):
2897
2942
  A list of labels to ignore
2898
2943
  stride (`int`, *optional*):
2899
2944
  The number of overlapping tokens between chunks when splitting the input text.
2900
2945
 
2901
2946
  Returns:
2902
- `List[TokenClassificationOutputElement]`: List of [`TokenClassificationOutputElement`] items containing the entity group, confidence score, word, start and end index.
2947
+ `list[TokenClassificationOutputElement]`: List of [`TokenClassificationOutputElement`] items containing the entity group, confidence score, word, start and end index.
2903
2948
 
2904
2949
  Raises:
2905
2950
  [`InferenceTimeoutError`]:
2906
2951
  If the model is unavailable or the request times out.
2907
- `aiohttp.ClientResponseError`:
2952
+ [`HfHubHTTPError`]:
2908
2953
  If the request fails with an HTTP error status code other than HTTP 503.
2909
2954
 
2910
2955
  Example:
@@ -2931,7 +2976,8 @@ class AsyncInferenceClient:
2931
2976
  ]
2932
2977
  ```
2933
2978
  """
2934
- provider_helper = get_provider_helper(self.provider, task="token-classification")
2979
+ model_id = model or self.model
2980
+ provider_helper = get_provider_helper(self.provider, task="token-classification", model=model_id)
2935
2981
  request_parameters = provider_helper.prepare_request(
2936
2982
  inputs=text,
2937
2983
  parameters={
@@ -2940,7 +2986,7 @@ class AsyncInferenceClient:
2940
2986
  "stride": stride,
2941
2987
  },
2942
2988
  headers=self.headers,
2943
- model=model or self.model,
2989
+ model=model_id,
2944
2990
  api_key=self.token,
2945
2991
  )
2946
2992
  response = await self._inner_post(request_parameters)
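A short named-entity-recognition sketch on top of the call above; each returned element exposes the entity group, confidence score, and character offsets:

```py
import asyncio

from huggingface_hub import AsyncInferenceClient


async def extract_entities(text: str) -> None:
    client = AsyncInferenceClient()
    entities = await client.token_classification(text, aggregation_strategy="simple")
    for entity in entities:
        # start/end are character offsets into the original text
        print(f"{entity.entity_group}: {text[entity.start:entity.end]} ({entity.score:.2f})")


asyncio.run(extract_entities("My name is Sarah Jessica Parker but you can call me Jessica"))
```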
@@ -2955,7 +3001,7 @@ class AsyncInferenceClient:
2955
3001
  tgt_lang: Optional[str] = None,
2956
3002
  clean_up_tokenization_spaces: Optional[bool] = None,
2957
3003
  truncation: Optional["TranslationTruncationStrategy"] = None,
2958
- generate_parameters: Optional[Dict[str, Any]] = None,
3004
+ generate_parameters: Optional[dict[str, Any]] = None,
2959
3005
  ) -> TranslationOutput:
2960
3006
  """
2961
3007
  Convert text from one language to another.
@@ -2980,7 +3026,7 @@ class AsyncInferenceClient:
2980
3026
  Whether to clean up the potential extra spaces in the text output.
2981
3027
  truncation (`"TranslationTruncationStrategy"`, *optional*):
2982
3028
  The truncation strategy to use.
2983
- generate_parameters (`Dict[str, Any]`, *optional*):
3029
+ generate_parameters (`dict[str, Any]`, *optional*):
2984
3030
  Additional parametrization of the text generation algorithm.
2985
3031
 
2986
3032
  Returns:
@@ -2989,7 +3035,7 @@ class AsyncInferenceClient:
2989
3035
  Raises:
2990
3036
  [`InferenceTimeoutError`]:
2991
3037
  If the model is unavailable or the request times out.
2992
- `aiohttp.ClientResponseError`:
3038
+ [`HfHubHTTPError`]:
2993
3039
  If the request fails with an HTTP error status code other than HTTP 503.
2994
3040
  `ValueError`:
2995
3041
  If only one of the `src_lang` and `tgt_lang` arguments are provided.
@@ -3018,7 +3064,8 @@ class AsyncInferenceClient:
3018
3064
  if src_lang is None and tgt_lang is not None:
3019
3065
  raise ValueError("You cannot specify `tgt_lang` without specifying `src_lang`.")
3020
3066
 
3021
- provider_helper = get_provider_helper(self.provider, task="translation")
3067
+ model_id = model or self.model
3068
+ provider_helper = get_provider_helper(self.provider, task="translation", model=model_id)
3022
3069
  request_parameters = provider_helper.prepare_request(
3023
3070
  inputs=text,
3024
3071
  parameters={
@@ -3029,7 +3076,7 @@ class AsyncInferenceClient:
3029
3076
  "generate_parameters": generate_parameters,
3030
3077
  },
3031
3078
  headers=self.headers,
3032
- model=model or self.model,
3079
+ model=model_id,
3033
3080
  api_key=self.token,
3034
3081
  )
3035
3082
  response = await self._inner_post(request_parameters)
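As the `ValueError` above enforces, `src_lang` and `tgt_lang` must be given together or not at all. A sketch using a multilingual model, where the model ID and its language codes are assumptions:

```py
import asyncio

from huggingface_hub import AsyncInferenceClient


async def translate() -> None:
    client = AsyncInferenceClient()
    result = await client.translation(
        "My name is Wolfgang and I live in Berlin",
        model="facebook/nllb-200-distilled-600M",  # example model; language codes follow its convention
        src_lang="eng_Latn",
        tgt_lang="fra_Latn",
    )
    print(result.translation_text)


asyncio.run(translate())
```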
@@ -3042,13 +3089,13 @@ class AsyncInferenceClient:
         *,
         model: Optional[str] = None,
         top_k: Optional[int] = None,
-    ) -> List[VisualQuestionAnsweringOutputElement]:
+    ) -> list[VisualQuestionAnsweringOutputElement]:
         """
         Answering open-ended questions based on an image.

         Args:
-            image (`Union[str, Path, bytes, BinaryIO]`):
-                The input image for the context. It can be raw bytes, an image file, or a URL to an online image.
+            image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`):
+                The input image for the context. It can be raw bytes, an image file, a URL to an online image, or a PIL Image.
             question (`str`):
                 Question to be answered.
             model (`str`, *optional*):
@@ -3059,12 +3106,12 @@ class AsyncInferenceClient:
                 The number of answers to return (will be chosen by order of likelihood). Note that we return less than
                 topk answers if there are not enough options available within the context.
         Returns:
-            `List[VisualQuestionAnsweringOutputElement]`: a list of [`VisualQuestionAnsweringOutputElement`] items containing the predicted label and associated probability.
+            `list[VisualQuestionAnsweringOutputElement]`: a list of [`VisualQuestionAnsweringOutputElement`] items containing the predicted label and associated probability.

         Raises:
             `InferenceTimeoutError`:
                 If the model is unavailable or the request times out.
-            `aiohttp.ClientResponseError`:
+            [`HfHubHTTPError`]:
                 If the request fails with an HTTP error status code other than HTTP 503.

         Example:
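Per the updated signature and docstring above, `visual_question_answering` now also accepts a `PIL.Image.Image` and returns a plain `list`. A short sketch under those assumptions; the image path, question, and model ID are placeholders:

```py
# Sketch of the updated visual_question_answering call; the image path and
# model ID are placeholders, and PIL input follows the new docstring.
import asyncio

from huggingface_hub import AsyncInferenceClient
from PIL import Image


async def main() -> None:
    client = AsyncInferenceClient()
    image = Image.open("cat.jpg")  # PIL images are now accepted directly
    answers = await client.visual_question_answering(
        image=image,
        question="What is the animal doing?",
        model="dandelin/vilt-b32-finetuned-vqa",  # placeholder model ID
        top_k=3,
    )
    for answer in answers:  # list[VisualQuestionAnsweringOutputElement]
        print(answer)


asyncio.run(main())
```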
@@ -3082,44 +3129,37 @@ class AsyncInferenceClient:
         ]
         ```
         """
-        provider_helper = get_provider_helper(self.provider, task="visual-question-answering")
+        model_id = model or self.model
+        provider_helper = get_provider_helper(self.provider, task="visual-question-answering", model=model_id)
         request_parameters = provider_helper.prepare_request(
             inputs=image,
             parameters={"top_k": top_k},
             headers=self.headers,
-            model=model or self.model,
+            model=model_id,
             api_key=self.token,
             extra_payload={"question": question, "image": _b64_encode(image)},
         )
         response = await self._inner_post(request_parameters)
         return VisualQuestionAnsweringOutputElement.parse_obj_as_list(response)

-    @_deprecate_arguments(
-        version="0.30.0",
-        deprecated_args=["labels"],
-        custom_message="`labels`has been renamed to `candidate_labels` and will be removed in huggingface_hub>=0.30.0.",
-    )
     async def zero_shot_classification(
         self,
         text: str,
-        # temporarily keeping it optional for backward compatibility.
-        candidate_labels: List[str] = None,  # type: ignore
+        candidate_labels: list[str],
         *,
         multi_label: Optional[bool] = False,
         hypothesis_template: Optional[str] = None,
         model: Optional[str] = None,
-        # deprecated argument
-        labels: List[str] = None,  # type: ignore
-    ) -> List[ZeroShotClassificationOutputElement]:
+    ) -> list[ZeroShotClassificationOutputElement]:
         """
         Provide as input a text and a set of candidate labels to classify the input text.

         Args:
             text (`str`):
                 The input text to classify.
-            candidate_labels (`List[str]`):
+            candidate_labels (`list[str]`):
                 The set of possible class labels to classify the text into.
-            labels (`List[str]`, *optional*):
+            labels (`list[str]`, *optional*):
                 (deprecated) List of strings. Each string is the verbalization of a possible label for the input text.
             multi_label (`bool`, *optional*):
                 Whether multiple candidate labels can be true. If false, the scores are normalized such that the sum of
@@ -3134,12 +3174,12 @@ class AsyncInferenceClient:


         Returns:
-            `List[ZeroShotClassificationOutputElement]`: List of [`ZeroShotClassificationOutputElement`] items containing the predicted labels and their confidence.
+            `list[ZeroShotClassificationOutputElement]`: List of [`ZeroShotClassificationOutputElement`] items containing the predicted labels and their confidence.

         Raises:
             [`InferenceTimeoutError`]:
                 If the model is unavailable or the request times out.
-            `aiohttp.ClientResponseError`:
+            [`HfHubHTTPError`]:
                 If the request fails with an HTTP error status code other than HTTP 503.

         Example with `multi_label=False`:
@@ -3190,17 +3230,8 @@ class AsyncInferenceClient:
         ]
         ```
         """
-        # handle deprecation
-        if labels is not None:
-            if candidate_labels is not None:
-                raise ValueError(
-                    "Cannot specify both `labels` and `candidate_labels`. Use `candidate_labels` instead."
-                )
-            candidate_labels = labels
-        elif candidate_labels is None:
-            raise ValueError("Must specify `candidate_labels`")
-
-        provider_helper = get_provider_helper(self.provider, task="zero-shot-classification")
+        model_id = model or self.model
+        provider_helper = get_provider_helper(self.provider, task="zero-shot-classification", model=model_id)
         request_parameters = provider_helper.prepare_request(
             inputs=text,
             parameters={
@@ -3209,7 +3240,7 @@ class AsyncInferenceClient:
                 "hypothesis_template": hypothesis_template,
             },
             headers=self.headers,
-            model=model or self.model,
+            model=model_id,
             api_key=self.token,
         )
         response = await self._inner_post(request_parameters)
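With the `@_deprecate_arguments` decorator and the `labels` fallback removed above, `candidate_labels` is now a required argument of `zero_shot_classification`. A minimal sketch of the new calling convention; the labels and model ID are illustrative:

```py
# Sketch of zero_shot_classification after removal of the deprecated `labels`
# argument; the candidate labels and model ID are illustrative choices.
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient()
    results = await client.zero_shot_classification(
        "I really enjoyed this movie, the plot was fantastic.",
        candidate_labels=["positive", "negative", "neutral"],  # now required
        multi_label=False,
        model="facebook/bart-large-mnli",  # placeholder model ID
    )
    for result in results:  # list[ZeroShotClassificationOutputElement]
        print(result.label, result.score)


asyncio.run(main())
```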
@@ -3219,31 +3250,25 @@ class AsyncInferenceClient:
             for label, score in zip(output["labels"], output["scores"])
         ]

-    @_deprecate_arguments(
-        version="0.30.0",
-        deprecated_args=["labels"],
-        custom_message="`labels`has been renamed to `candidate_labels` and will be removed in huggingface_hub>=0.30.0.",
-    )
     async def zero_shot_image_classification(
         self,
         image: ContentT,
-        # temporarily keeping it optional for backward compatibility.
-        candidate_labels: List[str] = None,  # type: ignore
+        candidate_labels: list[str],
         *,
         model: Optional[str] = None,
         hypothesis_template: Optional[str] = None,
         # deprecated argument
-        labels: List[str] = None,  # type: ignore
-    ) -> List[ZeroShotImageClassificationOutputElement]:
+        labels: list[str] = None,  # type: ignore
+    ) -> list[ZeroShotImageClassificationOutputElement]:
         """
         Provide input image and text labels to predict text labels for the image.

         Args:
-            image (`Union[str, Path, bytes, BinaryIO]`):
-                The input image to caption. It can be raw bytes, an image file, or a URL to an online image.
-            candidate_labels (`List[str]`):
+            image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`):
+                The input image to caption. It can be raw bytes, an image file, a URL to an online image, or a PIL Image.
+            candidate_labels (`list[str]`):
                 The candidate labels for this image
-            labels (`List[str]`, *optional*):
+            labels (`list[str]`, *optional*):
                 (deprecated) List of string possible labels. There must be at least 2 labels.
             model (`str`, *optional*):
                 The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
@@ -3253,12 +3278,12 @@ class AsyncInferenceClient:
                 replacing the placeholder with the candidate labels.

         Returns:
-            `List[ZeroShotImageClassificationOutputElement]`: List of [`ZeroShotImageClassificationOutputElement`] items containing the predicted labels and their confidence.
+            `list[ZeroShotImageClassificationOutputElement]`: List of [`ZeroShotImageClassificationOutputElement`] items containing the predicted labels and their confidence.

         Raises:
             [`InferenceTimeoutError`]:
                 If the model is unavailable or the request times out.
-            `aiohttp.ClientResponseError`:
+            [`HfHubHTTPError`]:
                 If the request fails with an HTTP error status code other than HTTP 503.

         Example:
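`zero_shot_image_classification` follows the same pattern: `candidate_labels` is now required and, as the hunk below shows, at least two labels must be passed. A sketch assuming an image URL input; the URL, labels, and `hypothesis_template` value are illustrative placeholders:

```py
# Sketch of zero_shot_image_classification with the now-required
# candidate_labels; the URL, labels, and template are illustrative.
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient()
    results = await client.zero_shot_image_classification(
        "https://upload.wikimedia.org/wikipedia/commons/3/3a/Cat03.jpg",  # placeholder URL
        candidate_labels=["cat", "dog"],  # at least 2 labels are required
        hypothesis_template="A photo of a {}.",  # assumed placeholder syntax
    )
    print(results)  # list[ZeroShotImageClassificationOutputElement]


asyncio.run(main())
```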
@@ -3274,20 +3299,12 @@ class AsyncInferenceClient:
         [ZeroShotImageClassificationOutputElement(label='dog', score=0.956),...]
         ```
         """
-        # handle deprecation
-        if labels is not None:
-            if candidate_labels is not None:
-                raise ValueError(
-                    "Cannot specify both `labels` and `candidate_labels`. Use `candidate_labels` instead."
-                )
-            candidate_labels = labels
-        elif candidate_labels is None:
-            raise ValueError("Must specify `candidate_labels`")
         # Raise ValueError if input is less than 2 labels
         if len(candidate_labels) < 2:
             raise ValueError("You must specify at least 2 classes to compare.")

-        provider_helper = get_provider_helper(self.provider, task="zero-shot-image-classification")
+        model_id = model or self.model
+        provider_helper = get_provider_helper(self.provider, task="zero-shot-image-classification", model=model_id)
         request_parameters = provider_helper.prepare_request(
             inputs=image,
             parameters={
@@ -3295,150 +3312,13 @@ class AsyncInferenceClient:
                 "hypothesis_template": hypothesis_template,
             },
             headers=self.headers,
-            model=model or self.model,
+            model=model_id,
             api_key=self.token,
         )
         response = await self._inner_post(request_parameters)
         return ZeroShotImageClassificationOutputElement.parse_obj_as_list(response)

-    @_deprecate_method(
-        version="0.33.0",
-        message=(
-            "HF Inference API is getting revamped and will only support warm models in the future (no cold start allowed)."
-            " Use `HfApi.list_models(..., inference_provider='...')` to list warm models per provider."
-        ),
-    )
-    async def list_deployed_models(
-        self, frameworks: Union[None, str, Literal["all"], List[str]] = None
-    ) -> Dict[str, List[str]]:
-        """
-        List models deployed on the HF Serverless Inference API service.
-
-        This helper checks deployed models framework by framework. By default, it will check the 4 main frameworks that
-        are supported and account for 95% of the hosted models. However, if you want a complete list of models you can
-        specify `frameworks="all"` as input. Alternatively, if you know before-hand which framework you are interested
-        in, you can also restrict to search to this one (e.g. `frameworks="text-generation-inference"`). The more
-        frameworks are checked, the more time it will take.
-
-        <Tip warning={true}>
-
-        This endpoint method does not return a live list of all models available for the HF Inference API service.
-        It searches over a cached list of models that were recently available and the list may not be up to date.
-        If you want to know the live status of a specific model, use [`~InferenceClient.get_model_status`].
-
-        </Tip>
-
-        <Tip>
-
-        This endpoint method is mostly useful for discoverability. If you already know which model you want to use and want to
-        check its availability, you can directly use [`~InferenceClient.get_model_status`].
-
-        </Tip>
-
-        Args:
-            frameworks (`Literal["all"]` or `List[str]` or `str`, *optional*):
-                The frameworks to filter on. By default only a subset of the available frameworks are tested. If set to
-                "all", all available frameworks will be tested. It is also possible to provide a single framework or a
-                custom set of frameworks to check.
-
-        Returns:
-            `Dict[str, List[str]]`: A dictionary mapping task names to a sorted list of model IDs.
-
-        Example:
-        ```py
-        # Must be run in an async contextthon
-        >>> from huggingface_hub import AsyncInferenceClient
-        >>> client = AsyncInferenceClient()
-
-        # Discover zero-shot-classification models currently deployed
-        >>> models = await client.list_deployed_models()
-        >>> models["zero-shot-classification"]
-        ['Narsil/deberta-large-mnli-zero-cls', 'facebook/bart-large-mnli', ...]
-
-        # List from only 1 framework
-        >>> await client.list_deployed_models("text-generation-inference")
-        {'text-generation': ['bigcode/starcoder', 'meta-llama/Llama-2-70b-chat-hf', ...], ...}
-        ```
-        """
-        if self.provider != "hf-inference":
-            raise ValueError(f"Listing deployed models is not supported on '{self.provider}'.")
-
-        # Resolve which frameworks to check
-        if frameworks is None:
-            frameworks = constants.MAIN_INFERENCE_API_FRAMEWORKS
-        elif frameworks == "all":
-            frameworks = constants.ALL_INFERENCE_API_FRAMEWORKS
-        elif isinstance(frameworks, str):
-            frameworks = [frameworks]
-        frameworks = list(set(frameworks))
-
-        # Fetch them iteratively
-        models_by_task: Dict[str, List[str]] = {}
-
-        def _unpack_response(framework: str, items: List[Dict]) -> None:
-            for model in items:
-                if framework == "sentence-transformers":
-                    # Model running with the `sentence-transformers` framework can work with both tasks even if not
-                    # branded as such in the API response
-                    models_by_task.setdefault("feature-extraction", []).append(model["model_id"])
-                    models_by_task.setdefault("sentence-similarity", []).append(model["model_id"])
-                else:
-                    models_by_task.setdefault(model["task"], []).append(model["model_id"])
-
-        for framework in frameworks:
-            response = get_session().get(
-                f"{constants.INFERENCE_ENDPOINT}/framework/{framework}", headers=build_hf_headers(token=self.token)
-            )
-            hf_raise_for_status(response)
-            _unpack_response(framework, response.json())
-
-        # Sort alphabetically for discoverability and return
-        for task, models in models_by_task.items():
-            models_by_task[task] = sorted(set(models), key=lambda x: x.lower())
-        return models_by_task
-
-    def _get_client_session(self, headers: Optional[Dict] = None) -> "ClientSession":
-        aiohttp = _import_aiohttp()
-        client_headers = self.headers.copy()
-        if headers is not None:
-            client_headers.update(headers)
-
-        # Return a new aiohttp ClientSession with correct settings.
-        session = aiohttp.ClientSession(
-            headers=client_headers,
-            cookies=self.cookies,
-            timeout=aiohttp.ClientTimeout(self.timeout),
-            trust_env=self.trust_env,
-        )
-
-        # Keep track of sessions to close them later
-        self._sessions[session] = set()
-
-        # Override the `._request` method to register responses to be closed
-        session._wrapped_request = session._request
-
-        async def _request(method, url, **kwargs):
-            response = await session._wrapped_request(method, url, **kwargs)
-            self._sessions[session].add(response)
-            return response
-
-        session._request = _request
-
-        # Override the 'close' method to
-        # 1. close ongoing responses
-        # 2. deregister the session when closed
-        session._close = session.close
-
-        async def close_session():
-            for response in self._sessions[session]:
-                response.close()
-            await session._close()
-            self._sessions.pop(session, None)
-
-        session.close = close_session
-        return session
-
-    async def get_endpoint_info(self, *, model: Optional[str] = None) -> Dict[str, Any]:
+    async def get_endpoint_info(self, *, model: Optional[str] = None) -> dict[str, Any]:
         """
         Get information about the deployed endpoint.

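The removed `list_deployed_models` helper (and the aiohttp-based `_get_client_session` plumbing it relied on) pointed users to `HfApi.list_models(..., inference_provider='...')` in its deprecation message. A sketch of that replacement path; the provider name, task filter, and limit are illustrative choices:

```py
# Sketch of the replacement suggested by the removed deprecation message:
# list warm models per provider via HfApi.list_models. Filters are illustrative.
from huggingface_hub import HfApi

api = HfApi()
models = api.list_models(
    inference_provider="hf-inference",  # provider name, assumed to match the client's
    pipeline_tag="zero-shot-classification",
    limit=5,
)
for model in models:
    print(model.id)
```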
@@ -3451,7 +3331,7 @@ class AsyncInferenceClient:
                 Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.

         Returns:
-            `Dict[str, Any]`: Information about the endpoint.
+            `dict[str, Any]`: Information about the endpoint.

         Example:
         ```py
@@ -3493,17 +3373,16 @@ class AsyncInferenceClient:
         else:
             url = f"{constants.INFERENCE_ENDPOINT}/models/{model}/info"

-        async with self._get_client_session(headers=build_hf_headers(token=self.token)) as client:
-            response = await client.get(url, proxy=self.proxies)
-            response.raise_for_status()
-            return await response.json()
+        client = await self._get_async_client()
+        response = await client.get(url, headers=build_hf_headers(token=self.token))
+        hf_raise_for_status(response)
+        return response.json()

     async def health_check(self, model: Optional[str] = None) -> bool:
         """
         Check the health of the deployed endpoint.

         Health check is only available with Inference Endpoints powered by Text-Generation-Inference (TGI) or Text-Embedding-Inference (TEI).
-        For Inference API, please use [`InferenceClient.get_model_status`] instead.

         Args:
             model (`str`, *optional*):
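`get_endpoint_info` and `health_check` now go through the shared async HTTP client and `hf_raise_for_status`, so failures surface as [`HfHubHTTPError`]. A sketch against a dedicated Inference Endpoint; the endpoint URL is a placeholder:

```py
# Sketch of get_endpoint_info/health_check against an Inference Endpoint URL;
# the endpoint URL is a placeholder and errors now surface as HfHubHTTPError.
import asyncio

from huggingface_hub import AsyncInferenceClient
from huggingface_hub.errors import HfHubHTTPError

ENDPOINT_URL = "https://my-endpoint.endpoints.huggingface.cloud"  # placeholder


async def main() -> None:
    client = AsyncInferenceClient(model=ENDPOINT_URL)
    if await client.health_check():
        try:
            info = await client.get_endpoint_info()
            print(info)  # plain dict describing the deployed endpoint
        except HfHubHTTPError as e:
            print(f"Endpoint info request failed: {e}")
    else:
        print("Endpoint is not healthy.")


asyncio.run(main())
```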
@@ -3528,77 +3407,12 @@ class AsyncInferenceClient:
         if model is None:
             raise ValueError("Model id not provided.")
         if not model.startswith(("http://", "https://")):
-            raise ValueError(
-                "Model must be an Inference Endpoint URL. For serverless Inference API, please use `InferenceClient.get_model_status`."
-            )
+            raise ValueError("Model must be an Inference Endpoint URL.")
         url = model.rstrip("/") + "/health"

-        async with self._get_client_session(headers=build_hf_headers(token=self.token)) as client:
-            response = await client.get(url, proxy=self.proxies)
-            return response.status == 200
-
-    @_deprecate_method(
-        version="0.33.0",
-        message=(
-            "HF Inference API is getting revamped and will only support warm models in the future (no cold start allowed)."
-            " Use `HfApi.model_info` to get the model status both with HF Inference API and external providers."
-        ),
-    )
-    async def get_model_status(self, model: Optional[str] = None) -> ModelStatus:
-        """
-        Get the status of a model hosted on the HF Inference API.
-
-        <Tip>
-
-        This endpoint is mostly useful when you already know which model you want to use and want to check its
-        availability. If you want to discover already deployed models, you should rather use [`~InferenceClient.list_deployed_models`].
-
-        </Tip>
-
-        Args:
-            model (`str`, *optional*):
-                Identifier of the model for witch the status gonna be checked. If model is not provided,
-                the model associated with this instance of [`InferenceClient`] will be used. Only HF Inference API service can be checked so the
-                identifier cannot be a URL.
-
-
-        Returns:
-            [`ModelStatus`]: An instance of ModelStatus dataclass, containing information,
-            about the state of the model: load, state, compute type and framework.
-
-        Example:
-        ```py
-        # Must be run in an async context
-        >>> from huggingface_hub import AsyncInferenceClient
-        >>> client = AsyncInferenceClient()
-        >>> await client.get_model_status("meta-llama/Meta-Llama-3-8B-Instruct")
-        ModelStatus(loaded=True, state='Loaded', compute_type='gpu', framework='text-generation-inference')
-        ```
-        """
-        if self.provider != "hf-inference":
-            raise ValueError(f"Getting model status is not supported on '{self.provider}'.")
-
-        model = model or self.model
-        if model is None:
-            raise ValueError("Model id not provided.")
-        if model.startswith("https://"):
-            raise NotImplementedError("Model status is only available for Inference API endpoints.")
-        url = f"{constants.INFERENCE_ENDPOINT}/status/{model}"
-
-        async with self._get_client_session(headers=build_hf_headers(token=self.token)) as client:
-            response = await client.get(url, proxy=self.proxies)
-            response.raise_for_status()
-            response_data = await response.json()
-
-        if "error" in response_data:
-            raise ValueError(response_data["error"])
-
-        return ModelStatus(
-            loaded=response_data["loaded"],
-            state=response_data["state"],
-            compute_type=response_data["compute_type"],
-            framework=response_data["framework"],
-        )
+        client = await self._get_async_client()
+        response = await client.get(url, headers=build_hf_headers(token=self.token))
+        return response.status_code == 200

     @property
     def chat(self) -> "ProxyClientChat":