huggingface-hub 0.27.1__py3-none-any.whl → 0.28.0rc0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release.

This version of huggingface-hub might be problematic.

Files changed (39)
  1. huggingface_hub/__init__.py +418 -12
  2. huggingface_hub/_commit_api.py +33 -4
  3. huggingface_hub/_inference_endpoints.py +8 -2
  4. huggingface_hub/_local_folder.py +14 -3
  5. huggingface_hub/commands/scan_cache.py +1 -1
  6. huggingface_hub/commands/upload_large_folder.py +1 -1
  7. huggingface_hub/constants.py +7 -2
  8. huggingface_hub/file_download.py +1 -2
  9. huggingface_hub/hf_api.py +64 -83
  10. huggingface_hub/inference/_client.py +706 -450
  11. huggingface_hub/inference/_common.py +32 -64
  12. huggingface_hub/inference/_generated/_async_client.py +722 -470
  13. huggingface_hub/inference/_generated/types/__init__.py +1 -0
  14. huggingface_hub/inference/_generated/types/image_to_image.py +3 -3
  15. huggingface_hub/inference/_generated/types/text_to_audio.py +1 -2
  16. huggingface_hub/inference/_generated/types/text_to_image.py +3 -3
  17. huggingface_hub/inference/_generated/types/text_to_speech.py +3 -6
  18. huggingface_hub/inference/_generated/types/text_to_video.py +47 -0
  19. huggingface_hub/inference/_generated/types/visual_question_answering.py +1 -1
  20. huggingface_hub/inference/_providers/__init__.py +89 -0
  21. huggingface_hub/inference/_providers/fal_ai.py +155 -0
  22. huggingface_hub/inference/_providers/hf_inference.py +202 -0
  23. huggingface_hub/inference/_providers/replicate.py +144 -0
  24. huggingface_hub/inference/_providers/sambanova.py +85 -0
  25. huggingface_hub/inference/_providers/together.py +148 -0
  26. huggingface_hub/py.typed +0 -0
  27. huggingface_hub/repocard.py +1 -1
  28. huggingface_hub/repocard_data.py +2 -1
  29. huggingface_hub/serialization/_base.py +1 -1
  30. huggingface_hub/serialization/_torch.py +1 -1
  31. huggingface_hub/utils/_fixes.py +25 -13
  32. huggingface_hub/utils/_http.py +2 -2
  33. huggingface_hub/utils/logging.py +1 -1
  34. {huggingface_hub-0.27.1.dist-info → huggingface_hub-0.28.0rc0.dist-info}/METADATA +4 -4
  35. {huggingface_hub-0.27.1.dist-info → huggingface_hub-0.28.0rc0.dist-info}/RECORD +39 -31
  36. {huggingface_hub-0.27.1.dist-info → huggingface_hub-0.28.0rc0.dist-info}/LICENSE +0 -0
  37. {huggingface_hub-0.27.1.dist-info → huggingface_hub-0.28.0rc0.dist-info}/WHEEL +0 -0
  38. {huggingface_hub-0.27.1.dist-info → huggingface_hub-0.28.0rc0.dist-info}/entry_points.txt +0 -0
  39. {huggingface_hub-0.27.1.dist-info → huggingface_hub-0.28.0rc0.dist-info}/top_level.txt +0 -0
@@ -26,14 +26,13 @@ import time
26
26
  import warnings
27
27
  from typing import TYPE_CHECKING, Any, AsyncIterable, Dict, List, Literal, Optional, Set, Union, overload
28
28
 
29
- from requests.structures import CaseInsensitiveDict
30
-
31
29
  from huggingface_hub.constants import ALL_INFERENCE_API_FRAMEWORKS, INFERENCE_ENDPOINT, MAIN_INFERENCE_API_FRAMEWORKS
32
30
  from huggingface_hub.errors import InferenceTimeoutError
33
31
  from huggingface_hub.inference._common import (
34
32
  TASKS_EXPECTING_IMAGES,
35
33
  ContentT,
36
34
  ModelStatus,
35
+ RequestParameters,
37
36
  _async_stream_chat_completion_response,
38
37
  _async_stream_text_generation_response,
39
38
  _b64_encode,
@@ -41,11 +40,9 @@ from huggingface_hub.inference._common import (
41
40
  _bytes_to_dict,
42
41
  _bytes_to_image,
43
42
  _bytes_to_list,
44
- _fetch_recommended_models,
45
43
  _get_unsupported_text_generation_kwargs,
46
44
  _import_numpy,
47
45
  _open_as_binary,
48
- _prepare_payload,
49
46
  _set_unsupported_text_generation_kwargs,
50
47
  raise_text_generation_error,
51
48
  )
@@ -90,8 +87,9 @@ from huggingface_hub.inference._generated.types import (
90
87
  ZeroShotClassificationOutputElement,
91
88
  ZeroShotImageClassificationOutputElement,
92
89
  )
93
- from huggingface_hub.utils import build_hf_headers
94
- from huggingface_hub.utils._deprecation import _deprecate_arguments
90
+ from huggingface_hub.inference._providers import PROVIDER_T, HFInferenceTask, get_provider_helper
91
+ from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status
92
+ from huggingface_hub.utils._deprecation import _deprecate_arguments, _deprecate_method
95
93
 
96
94
  from .._common import _async_yield_from, _import_aiohttp
97
95
 
@@ -112,7 +110,7 @@ class AsyncInferenceClient:
112
110
  Initialize a new Inference Client.
113
111
 
114
112
  [`InferenceClient`] aims to provide a unified experience to perform inference. The client can be used
115
- seamlessly with either the (free) Inference API or self-hosted Inference Endpoints.
113
+ seamlessly with either the (free) Inference API, self-hosted Inference Endpoints, or third-party Inference Providers.
116
114
 
117
115
  Args:
118
116
  model (`str`, `optional`):
@@ -123,6 +121,10 @@ class AsyncInferenceClient:
123
121
  arguments are mutually exclusive. If using `base_url` for chat completion, the `/chat/completions` suffix
124
122
  path will be appended to the base URL (see the [TGI Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api)
125
123
  documentation for details). When passing a URL as `model`, the client will not append any suffix path to it.
124
+ provider (`str`, *optional*):
125
+ Name of the provider to use for inference. Can be `"replicate"`, `"together"`, `"fal-ai"`, `"sambanova"` or `"hf-inference"`.
126
+ Defaults to `"hf-inference"` (Hugging Face Serverless Inference API).
127
+ If `model` is a URL or `base_url` is passed, then `provider` is not used.
126
128
  token (`str` or `bool`, *optional*):
127
129
  Hugging Face token. Will default to the locally saved token if not provided.
128
130
  Pass `token=False` if you don't want to send your token to the server.
@@ -152,7 +154,8 @@ class AsyncInferenceClient:
152
154
  self,
153
155
  model: Optional[str] = None,
154
156
  *,
155
- token: Union[str, bool, None] = None,
157
+ provider: Optional[PROVIDER_T] = None,
158
+ token: Optional[str] = None,
156
159
  timeout: Optional[float] = None,
157
160
  headers: Optional[Dict[str, str]] = None,
158
161
  cookies: Optional[Dict[str, str]] = None,
@@ -177,12 +180,12 @@ class AsyncInferenceClient:
177
180
  )
178
181
 
179
182
  self.model: Optional[str] = model
180
- self.token: Union[str, bool, None] = token if token is not None else api_key
181
- self.headers: CaseInsensitiveDict[str] = CaseInsensitiveDict(
182
- build_hf_headers(token=self.token) # 'authorization' + 'user-agent'
183
- )
184
- if headers is not None:
185
- self.headers.update(headers)
183
+ self.token: Optional[str] = token if token is not None else api_key
184
+ self.headers = headers if headers is not None else {}
185
+
186
+ # Configure provider
187
+ self.provider = provider if provider is not None else "hf-inference"
188
+
186
189
  self.cookies = cookies
187
190
  self.timeout = timeout
188
191
  self.trust_env = trust_env
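For orientation, here is a minimal sketch of how the new `provider` and `api_key` arguments added in this hunk are meant to be combined. Provider names come from the docstring above; the model ID is the illustrative one used in the docstring examples further down.

```py
# Sketch only — exercises the new `provider` / `api_key` constructor arguments.
# Must be run in an async context
from huggingface_hub import AsyncInferenceClient

client = AsyncInferenceClient(
    provider="together",            # "replicate", "together", "fal-ai", "sambanova" or "hf-inference" (default)
    api_key="<provider-or-hf-token>",
)
response = await client.chat_completion(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    messages=[{"role": "user", "content": "What is the capital of France?"}],
)
print(response.choices[0].message.content)
```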
@@ -230,6 +233,14 @@ class AsyncInferenceClient:
230
233
  stream: bool = False,
231
234
  ) -> Union[bytes, AsyncIterable[bytes]]: ...
232
235
 
236
+ @_deprecate_method(
237
+ version="0.31.0",
238
+ message=(
239
+ "Making direct POST requests to the inference server is not supported anymore. "
240
+ "Please use task methods instead (e.g. `InferenceClient.chat_completion`). "
241
+ "If your use case is not supported, please open an issue in https://github.com/huggingface/huggingface_hub."
242
+ ),
243
+ )
233
244
  async def post(
234
245
  self,
235
246
  *,
@@ -242,56 +253,67 @@ class AsyncInferenceClient:
242
253
  """
243
254
  Make a POST request to the inference server.
244
255
 
245
- Args:
246
- json (`Union[str, Dict, List]`, *optional*):
247
- The JSON data to send in the request body, specific to each task. Defaults to None.
248
- data (`Union[str, Path, bytes, BinaryIO]`, *optional*):
249
- The content to send in the request body, specific to each task.
250
- It can be raw bytes, a pointer to an opened file, a local file path,
251
- or a URL to an online resource (image, audio file,...). If both `json` and `data` are passed,
252
- `data` will take precedence. At least `json` or `data` must be provided. Defaults to None.
253
- model (`str`, *optional*):
254
- The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
255
- Inference Endpoint. Will override the model defined at the instance level. Defaults to None.
256
- task (`str`, *optional*):
257
- The task to perform on the inference. All available tasks can be found
258
- [here](https://huggingface.co/tasks). Used only to default to a recommended model if `model` is not
259
- provided. At least `model` or `task` must be provided. Defaults to None.
260
- stream (`bool`, *optional*):
261
- Whether to iterate over streaming APIs.
256
+ This method is deprecated and will be removed in the future.
257
+ Please use task methods instead (e.g. `InferenceClient.chat_completion`).
258
+ """
259
+ if self.provider != "hf-inference":
260
+ raise ValueError(
261
+ "Cannot use `post` with another provider than `hf-inference`. "
262
+ "`InferenceClient.post` is deprecated and should not be used directly anymore."
263
+ )
264
+ provider_helper = HFInferenceTask(task or "unknown")
265
+ url = provider_helper.build_url(provider_helper.map_model(model))
266
+ headers = provider_helper.prepare_headers(headers=self.headers, api_key=self.token)
267
+ return await self._inner_post(
268
+ request_parameters=RequestParameters(
269
+ url=url,
270
+ task=task or "unknown",
271
+ model=model or "unknown",
272
+ json=json,
273
+ data=data,
274
+ headers=headers,
275
+ ),
276
+ stream=stream,
277
+ )
262
278
 
263
- Returns:
264
- bytes: The raw bytes returned by the server.
279
+ @overload
280
+ async def _inner_post( # type: ignore[misc]
281
+ self, request_parameters: RequestParameters, *, stream: Literal[False] = ...
282
+ ) -> bytes: ...
265
283
 
266
- Raises:
267
- [`InferenceTimeoutError`]:
268
- If the model is unavailable or the request times out.
269
- `aiohttp.ClientResponseError`:
270
- If the request fails with an HTTP error status code other than HTTP 503.
271
- """
284
+ @overload
285
+ async def _inner_post( # type: ignore[misc]
286
+ self, request_parameters: RequestParameters, *, stream: Literal[True] = ...
287
+ ) -> AsyncIterable[bytes]: ...
272
288
 
273
- aiohttp = _import_aiohttp()
289
+ @overload
290
+ async def _inner_post(
291
+ self, request_parameters: RequestParameters, *, stream: bool = False
292
+ ) -> Union[bytes, AsyncIterable[bytes]]: ...
274
293
 
275
- url = self._resolve_url(model, task)
294
+ async def _inner_post(
295
+ self, request_parameters: RequestParameters, *, stream: bool = False
296
+ ) -> Union[bytes, AsyncIterable[bytes]]:
297
+ """Make a request to the inference server."""
276
298
 
277
- if data is not None and json is not None:
278
- warnings.warn("Ignoring `json` as `data` is passed as binary.")
299
+ aiohttp = _import_aiohttp()
279
300
 
280
- # Set Accept header if relevant
281
- headers = dict()
282
- if task in TASKS_EXPECTING_IMAGES and "Accept" not in headers:
283
- headers["Accept"] = "image/png"
301
+ # TODO: this should be handled in provider helpers directly
302
+ if request_parameters.task in TASKS_EXPECTING_IMAGES and "Accept" not in request_parameters.headers:
303
+ request_parameters.headers["Accept"] = "image/png"
284
304
 
285
305
  t0 = time.time()
286
306
  timeout = self.timeout
287
307
  while True:
288
- with _open_as_binary(data) as data_as_binary:
308
+ with _open_as_binary(request_parameters.data) as data_as_binary:
289
309
  # Do not use context manager as we don't want to close the connection immediately when returning
290
310
  # a stream
291
- session = self._get_client_session(headers=headers)
311
+ session = self._get_client_session(headers=request_parameters.headers)
292
312
 
293
313
  try:
294
- response = await session.post(url, json=json, data=data_as_binary, proxy=self.proxies)
314
+ response = await session.post(
315
+ request_parameters.url, json=request_parameters.json, data=data_as_binary, proxy=self.proxies
316
+ )
295
317
  response_error_payload = None
296
318
  if response.status != 200:
297
319
  try:
@@ -308,25 +330,27 @@ class AsyncInferenceClient:
308
330
  except asyncio.TimeoutError as error:
309
331
  await session.close()
310
332
  # Convert any `TimeoutError` to a `InferenceTimeoutError`
311
- raise InferenceTimeoutError(f"Inference call timed out: {url}") from error # type: ignore
333
+ raise InferenceTimeoutError(f"Inference call timed out: {request_parameters.url}") from error # type: ignore
312
334
  except aiohttp.ClientResponseError as error:
313
335
  error.response_error_payload = response_error_payload
314
336
  await session.close()
315
- if response.status == 422 and task is not None:
316
- error.message += f". Make sure '{task}' task is supported by the model."
337
+ if response.status == 422 and request_parameters.task != "unknown":
338
+ error.message += f". Make sure '{request_parameters.task}' task is supported by the model."
317
339
  if response.status == 503:
318
340
  # If Model is unavailable, either raise a TimeoutError...
319
341
  if timeout is not None and time.time() - t0 > timeout:
320
342
  raise InferenceTimeoutError(
321
- f"Model not loaded on the server: {url}. Please retry with a higher timeout"
343
+ f"Model not loaded on the server: {request_parameters.url}. Please retry with a higher timeout"
322
344
  f" (current: {self.timeout}).",
323
345
  request=error.request,
324
346
  response=error.response,
325
347
  ) from error
326
348
  # ...or wait 1s and retry
327
349
  logger.info(f"Waiting for model to be loaded on the server: {error}")
328
- if "X-wait-for-model" not in headers and url.startswith(INFERENCE_ENDPOINT):
329
- headers["X-wait-for-model"] = "1"
350
+ if "X-wait-for-model" not in request_parameters.headers and request_parameters.url.startswith(
351
+ INFERENCE_ENDPOINT
352
+ ):
353
+ request_parameters.headers["X-wait-for-model"] = "1"
330
354
  await asyncio.sleep(1)
331
355
  if timeout is not None:
332
356
  timeout = max(self.timeout - (time.time() - t0), 1) # type: ignore
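As a hedged before/after illustration of the deprecation above: `post()` keeps working for the default `hf-inference` provider until it is removed in 0.31.0, but new code is expected to call the task methods, which now build the request through the provider helpers and `_inner_post()`. The model ID and input text are illustrative.

```py
# Must be run in an async context
from huggingface_hub import AsyncInferenceClient

client = AsyncInferenceClient()

# Before: raw POST (now emits a deprecation warning, removal planned for 0.31.0)
raw = await client.post(
    json={"inputs": "I like you."},
    model="distilbert-base-uncased-finetuned-sst-2-english",
    task="text-classification",
)

# After: the equivalent task method
result = await client.text_classification(
    "I like you.",
    model="distilbert-base-uncased-finetuned-sst-2-english",
)
```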
@@ -408,9 +432,15 @@ class AsyncInferenceClient:
408
432
  ]
409
433
  ```
410
434
  """
411
- parameters = {"function_to_apply": function_to_apply, "top_k": top_k}
412
- payload = _prepare_payload(audio, parameters=parameters, expect_binary=True)
413
- response = await self.post(**payload, model=model, task="audio-classification")
435
+ provider_helper = get_provider_helper(self.provider, task="audio-classification")
436
+ request_parameters = provider_helper.prepare_request(
437
+ inputs=audio,
438
+ parameters={"function_to_apply": function_to_apply, "top_k": top_k},
439
+ headers=self.headers,
440
+ model=model or self.model,
441
+ api_key=self.token,
442
+ )
443
+ response = await self._inner_post(request_parameters)
414
444
  return AudioClassificationOutputElement.parse_obj_as_list(response)
415
445
 
416
446
  async def audio_to_audio(
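Every task method in this file is refactored along the same lines as `audio_classification` above: resolve a provider helper via `get_provider_helper()`, let it produce a `RequestParameters` object with `prepare_request()`, then send it through the private `_inner_post()`. From the caller's side nothing changes; a short usage sketch with an illustrative audio file and `top_k`:

```py
# Must be run in an async context
from huggingface_hub import AsyncInferenceClient

client = AsyncInferenceClient()  # provider defaults to "hf-inference"

# Internally: get_provider_helper("hf-inference", task="audio-classification"),
# then provider_helper.prepare_request(...) -> RequestParameters -> _inner_post().
labels = await client.audio_classification("sample.flac", top_k=3)
for item in labels:
    print(item.label, item.score)
```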
@@ -451,7 +481,15 @@ class AsyncInferenceClient:
451
481
  f.write(item.blob)
452
482
  ```
453
483
  """
454
- response = await self.post(data=audio, model=model, task="audio-to-audio")
484
+ provider_helper = get_provider_helper(self.provider, task="audio-to-audio")
485
+ request_parameters = provider_helper.prepare_request(
486
+ inputs=audio,
487
+ parameters={},
488
+ headers=self.headers,
489
+ model=model or self.model,
490
+ api_key=self.token,
491
+ )
492
+ response = await self._inner_post(request_parameters)
455
493
  audio_output = AudioToAudioOutputElement.parse_obj_as_list(response)
456
494
  for item in audio_output:
457
495
  item.blob = base64.b64decode(item.blob)
@@ -472,7 +510,8 @@ class AsyncInferenceClient:
472
510
  model (`str`, *optional*):
473
511
  The model to use for ASR. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
474
512
  Inference Endpoint. If not provided, the default recommended model for ASR will be used.
475
-
513
+ parameters (Dict[str, Any], *optional*):
514
+ Additional parameters to pass to the model.
476
515
  Returns:
477
516
  [`AutomaticSpeechRecognitionOutput`]: An item containing the transcribed text and optionally the timestamp chunks.
478
517
 
@@ -491,7 +530,15 @@ class AsyncInferenceClient:
491
530
  "hello world"
492
531
  ```
493
532
  """
494
- response = await self.post(data=audio, model=model, task="automatic-speech-recognition")
533
+ provider_helper = get_provider_helper(self.provider, task="automatic-speech-recognition")
534
+ request_parameters = provider_helper.prepare_request(
535
+ inputs=audio,
536
+ parameters={},
537
+ headers=self.headers,
538
+ model=model or self.model,
539
+ api_key=self.token,
540
+ )
541
+ response = await self._inner_post(request_parameters)
495
542
  return AutomaticSpeechRecognitionOutput.parse_obj_as_instance(response)
496
543
 
497
544
  @overload
@@ -605,6 +652,10 @@ class AsyncInferenceClient:
605
652
 
606
653
  </Tip>
607
654
 
655
+ <Tip>
656
+ Some parameters might not be supported by some providers.
657
+ </Tip>
658
+
608
659
  Args:
609
660
  messages (List of [`ChatCompletionInputMessage`]):
610
661
  Conversation history consisting of roles and content pairs.
@@ -612,25 +663,20 @@ class AsyncInferenceClient:
612
663
  The model to use for chat-completion. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
613
664
  Inference Endpoint. If not provided, the default recommended model for chat-based text-generation will be used.
614
665
  See https://huggingface.co/tasks/text-generation for more details.
615
-
616
666
  If `model` is a model ID, it is passed to the server as the `model` parameter. If you want to define a
617
667
  custom URL while setting `model` in the request payload, you must set `base_url` when initializing [`InferenceClient`].
618
668
  frequency_penalty (`float`, *optional*):
619
669
  Penalizes new tokens based on their existing frequency
620
670
  in the text so far. Range: [-2.0, 2.0]. Defaults to 0.0.
621
671
  logit_bias (`List[float]`, *optional*):
622
- Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens
623
- (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,
624
- the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,
625
- but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should
626
- result in a ban or exclusive selection of the relevant token. Defaults to None.
672
+ Adjusts the likelihood of specific tokens appearing in the generated output.
627
673
  logprobs (`bool`, *optional*):
628
674
  Whether to return log probabilities of the output tokens or not. If true, returns the log
629
675
  probabilities of each output token returned in the content of message.
630
676
  max_tokens (`int`, *optional*):
631
677
  Maximum number of tokens allowed in the response. Defaults to 100.
632
678
  n (`int`, *optional*):
633
- UNUSED.
679
+ The number of completions to generate for each prompt.
634
680
  presence_penalty (`float`, *optional*):
635
681
  Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the
636
682
  text so far, increasing the model's likelihood to talk about new topics.
@@ -638,7 +684,7 @@ class AsyncInferenceClient:
638
684
  Grammar constraints. Can be either a JSONSchema or a regex.
639
685
  seed (Optional[`int`], *optional*):
640
686
  Seed for reproducible control flow. Defaults to None.
641
- stop (Optional[`str`], *optional*):
687
+ stop (`List[str]`, *optional*):
642
688
  Up to four strings which trigger the end of the response.
643
689
  Defaults to None.
644
690
  stream (`bool`, *optional*):
@@ -750,6 +796,32 @@ class AsyncInferenceClient:
750
796
  print(chunk.choices[0].delta.content)
751
797
  ```
752
798
 
799
+ Example using a third-party provider directly. Usage will be billed on your Together AI account.
800
+ ```py
801
+ >>> from huggingface_hub import InferenceClient
802
+ >>> client = InferenceClient(
803
+ ... provider="together", # Use Together AI provider
804
+ ... api_key="<together_api_key>", # Pass your Together API key directly
805
+ ... )
806
+ >>> client.chat_completion(
807
+ ... model="meta-llama/Meta-Llama-3-8B-Instruct",
808
+ ... messages=[{"role": "user", "content": "What is the capital of France?"}],
809
+ ... )
810
+ ```
811
+
812
+ Example using a third-party provider through Hugging Face Routing. Usage will be billed on your Hugging Face account.
813
+ ```py
814
+ >>> from huggingface_hub import InferenceClient
815
+ >>> client = InferenceClient(
816
+ ... provider="sambanova", # Use Sambanova provider
817
+ ... api_key="hf_...", # Pass your HF token
818
+ ... )
819
+ >>> client.chat_completion(
820
+ ... model="meta-llama/Meta-Llama-3-8B-Instruct",
821
+ ... messages=[{"role": "user", "content": "What is the capital of France?"}],
822
+ ... )
823
+ ```
824
+
753
825
  Example using Image + Text as input:
754
826
  ```py
755
827
  # Must be run in an async context
@@ -901,68 +973,50 @@ class AsyncInferenceClient:
901
973
  '{\n\n"activity": "bike ride",\n"animals": ["puppy", "cat", "raccoon"],\n"animals_seen": 3,\n"location": "park"}'
902
974
  ```
903
975
  """
904
- model_url = self._resolve_chat_completion_url(model)
905
-
906
- # `model` is sent in the payload. Not used by the server but can be useful for debugging/routing.
907
- # If it's a ID on the Hub => use it. Otherwise, we use a random string.
908
- model_id = model or self.model or "tgi"
909
- if model_id.startswith(("http://", "https://")):
910
- model_id = "tgi" # dummy value
911
-
912
- payload = dict(
913
- model=model_id,
914
- messages=messages,
915
- frequency_penalty=frequency_penalty,
916
- logit_bias=logit_bias,
917
- logprobs=logprobs,
918
- max_tokens=max_tokens,
919
- n=n,
920
- presence_penalty=presence_penalty,
921
- response_format=response_format,
922
- seed=seed,
923
- stop=stop,
924
- temperature=temperature,
925
- tool_choice=tool_choice,
926
- tool_prompt=tool_prompt,
927
- tools=tools,
928
- top_logprobs=top_logprobs,
929
- top_p=top_p,
930
- stream=stream,
931
- stream_options=stream_options,
976
+ # Get the provider helper
977
+ provider_helper = get_provider_helper(self.provider, task="conversational")
978
+
979
+ # Since `chat_completion(..., model=xxx)` is also a payload parameter for the server, we need to handle 'model' differently.
980
+ # `self.base_url` and `self.model` take precedence over 'model' argument for building URL.
981
+ # `model` takes precedence for payload value.
982
+ model_id_or_url = self.base_url or self.model or model
983
+ payload_model = model or self.model
984
+
985
+ # Prepare the payload
986
+ parameters = {
987
+ "model": payload_model,
988
+ "frequency_penalty": frequency_penalty,
989
+ "logit_bias": logit_bias,
990
+ "logprobs": logprobs,
991
+ "max_tokens": max_tokens,
992
+ "n": n,
993
+ "presence_penalty": presence_penalty,
994
+ "response_format": response_format,
995
+ "seed": seed,
996
+ "stop": stop,
997
+ "temperature": temperature,
998
+ "tool_choice": tool_choice,
999
+ "tool_prompt": tool_prompt,
1000
+ "tools": tools,
1001
+ "top_logprobs": top_logprobs,
1002
+ "top_p": top_p,
1003
+ "stream": stream,
1004
+ "stream_options": stream_options,
1005
+ }
1006
+ request_parameters = provider_helper.prepare_request(
1007
+ inputs=messages,
1008
+ parameters=parameters,
1009
+ headers=self.headers,
1010
+ model=model_id_or_url,
1011
+ api_key=self.token,
932
1012
  )
933
- payload = {key: value for key, value in payload.items() if value is not None}
934
- data = await self.post(model=model_url, json=payload, stream=stream)
1013
+ data = await self._inner_post(request_parameters, stream=stream)
935
1014
 
936
1015
  if stream:
937
1016
  return _async_stream_chat_completion_response(data) # type: ignore[arg-type]
938
1017
 
939
1018
  return ChatCompletionOutput.parse_obj_as_instance(data) # type: ignore[arg-type]
940
1019
 
941
- def _resolve_chat_completion_url(self, model: Optional[str] = None) -> str:
942
- # Since `chat_completion(..., model=xxx)` is also a payload parameter for the server, we need to handle 'model' differently.
943
- # `self.base_url` and `self.model` takes precedence over 'model' argument only in `chat_completion`.
944
- model_id_or_url = self.base_url or self.model or model or self.get_recommended_model("text-generation")
945
-
946
- # Resolve URL if it's a model ID
947
- model_url = (
948
- model_id_or_url
949
- if model_id_or_url.startswith(("http://", "https://"))
950
- else self._resolve_url(model_id_or_url, task="text-generation")
951
- )
952
-
953
- # Strip trailing /
954
- model_url = model_url.rstrip("/")
955
-
956
- # Append /chat/completions if not already present
957
- if model_url.endswith("/v1"):
958
- model_url += "/chat/completions"
959
-
960
- # Append /v1/chat/completions if not already present
961
- if not model_url.endswith("/chat/completions"):
962
- model_url += "/v1/chat/completions"
963
-
964
- return model_url
965
-
966
1020
  async def document_question_answering(
967
1021
  self,
968
1022
  image: ContentT,
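The precedence rules spelled out in the comments of the `chat_completion` hunk above (`base_url`/`self.model` win for routing, the `model` argument wins for the payload) are easier to see in a quick sketch; the endpoint URL and model ID are illustrative.

```py
# Must be run in an async context
from huggingface_hub import AsyncInferenceClient

# `base_url` decides where the request is sent...
client = AsyncInferenceClient(base_url="http://localhost:8080")
await client.chat_completion(
    # ...while `model` is still forwarded in the JSON payload (useful for debugging/routing server-side).
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    messages=[{"role": "user", "content": "Hello!"}],
)
```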
@@ -1030,18 +1084,24 @@ class AsyncInferenceClient:
1030
1084
  ```
1031
1085
  """
1032
1086
  inputs: Dict[str, Any] = {"question": question, "image": _b64_encode(image)}
1033
- parameters = {
1034
- "doc_stride": doc_stride,
1035
- "handle_impossible_answer": handle_impossible_answer,
1036
- "lang": lang,
1037
- "max_answer_len": max_answer_len,
1038
- "max_question_len": max_question_len,
1039
- "max_seq_len": max_seq_len,
1040
- "top_k": top_k,
1041
- "word_boxes": word_boxes,
1042
- }
1043
- payload = _prepare_payload(inputs, parameters=parameters)
1044
- response = await self.post(**payload, model=model, task="document-question-answering")
1087
+ provider_helper = get_provider_helper(self.provider, task="document-question-answering")
1088
+ request_parameters = provider_helper.prepare_request(
1089
+ inputs=inputs,
1090
+ parameters={
1091
+ "doc_stride": doc_stride,
1092
+ "handle_impossible_answer": handle_impossible_answer,
1093
+ "lang": lang,
1094
+ "max_answer_len": max_answer_len,
1095
+ "max_question_len": max_question_len,
1096
+ "max_seq_len": max_seq_len,
1097
+ "top_k": top_k,
1098
+ "word_boxes": word_boxes,
1099
+ },
1100
+ headers=self.headers,
1101
+ model=model or self.model,
1102
+ api_key=self.token,
1103
+ )
1104
+ response = await self._inner_post(request_parameters)
1045
1105
  return DocumentQuestionAnsweringOutputElement.parse_obj_as_list(response)
1046
1106
 
1047
1107
  async def feature_extraction(
@@ -1100,14 +1160,20 @@ class AsyncInferenceClient:
1100
1160
  [ 0.28552425, -0.928395 , -1.2077185 , ..., 0.76810825, -2.1069427 , 0.6236161 ]], dtype=float32)
1101
1161
  ```
1102
1162
  """
1103
- parameters = {
1104
- "normalize": normalize,
1105
- "prompt_name": prompt_name,
1106
- "truncate": truncate,
1107
- "truncation_direction": truncation_direction,
1108
- }
1109
- payload = _prepare_payload(text, parameters=parameters)
1110
- response = await self.post(**payload, model=model, task="feature-extraction")
1163
+ provider_helper = get_provider_helper(self.provider, task="feature-extraction")
1164
+ request_parameters = provider_helper.prepare_request(
1165
+ inputs=text,
1166
+ parameters={
1167
+ "normalize": normalize,
1168
+ "prompt_name": prompt_name,
1169
+ "truncate": truncate,
1170
+ "truncation_direction": truncation_direction,
1171
+ },
1172
+ headers=self.headers,
1173
+ model=model or self.model,
1174
+ api_key=self.token,
1175
+ )
1176
+ response = await self._inner_post(request_parameters)
1111
1177
  np = _import_numpy()
1112
1178
  return np.array(_bytes_to_dict(response), dtype="float32")
1113
1179
 
@@ -1156,9 +1222,15 @@ class AsyncInferenceClient:
1156
1222
  ]
1157
1223
  ```
1158
1224
  """
1159
- parameters = {"targets": targets, "top_k": top_k}
1160
- payload = _prepare_payload(text, parameters=parameters)
1161
- response = await self.post(**payload, model=model, task="fill-mask")
1225
+ provider_helper = get_provider_helper(self.provider, task="fill-mask")
1226
+ request_parameters = provider_helper.prepare_request(
1227
+ inputs=text,
1228
+ parameters={"targets": targets, "top_k": top_k},
1229
+ headers=self.headers,
1230
+ model=model or self.model,
1231
+ api_key=self.token,
1232
+ )
1233
+ response = await self._inner_post(request_parameters)
1162
1234
  return FillMaskOutputElement.parse_obj_as_list(response)
1163
1235
 
1164
1236
  async def image_classification(
@@ -1200,9 +1272,15 @@ class AsyncInferenceClient:
1200
1272
  [ImageClassificationOutputElement(label='Blenheim spaniel', score=0.9779096841812134), ...]
1201
1273
  ```
1202
1274
  """
1203
- parameters = {"function_to_apply": function_to_apply, "top_k": top_k}
1204
- payload = _prepare_payload(image, parameters=parameters, expect_binary=True)
1205
- response = await self.post(**payload, model=model, task="image-classification")
1275
+ provider_helper = get_provider_helper(self.provider, task="image-classification")
1276
+ request_parameters = provider_helper.prepare_request(
1277
+ inputs=image,
1278
+ parameters={"function_to_apply": function_to_apply, "top_k": top_k},
1279
+ headers=self.headers,
1280
+ model=model or self.model,
1281
+ api_key=self.token,
1282
+ )
1283
+ response = await self._inner_post(request_parameters)
1206
1284
  return ImageClassificationOutputElement.parse_obj_as_list(response)
1207
1285
 
1208
1286
  async def image_segmentation(
@@ -1256,14 +1334,20 @@ class AsyncInferenceClient:
1256
1334
  [ImageSegmentationOutputElement(score=0.989008, label='LABEL_184', mask=<PIL.PngImagePlugin.PngImageFile image mode=L size=400x300 at 0x7FDD2B129CC0>), ...]
1257
1335
  ```
1258
1336
  """
1259
- parameters = {
1260
- "mask_threshold": mask_threshold,
1261
- "overlap_mask_area_threshold": overlap_mask_area_threshold,
1262
- "subtask": subtask,
1263
- "threshold": threshold,
1264
- }
1265
- payload = _prepare_payload(image, parameters=parameters, expect_binary=True)
1266
- response = await self.post(**payload, model=model, task="image-segmentation")
1337
+ provider_helper = get_provider_helper(self.provider, task="audio-classification")
1338
+ request_parameters = provider_helper.prepare_request(
1339
+ inputs=image,
1340
+ parameters={
1341
+ "mask_threshold": mask_threshold,
1342
+ "overlap_mask_area_threshold": overlap_mask_area_threshold,
1343
+ "subtask": subtask,
1344
+ "threshold": threshold,
1345
+ },
1346
+ headers=self.headers,
1347
+ model=model or self.model,
1348
+ api_key=self.token,
1349
+ )
1350
+ response = await self._inner_post(request_parameters)
1267
1351
  output = ImageSegmentationOutputElement.parse_obj_as_list(response)
1268
1352
  for item in output:
1269
1353
  item.mask = _b64_to_image(item.mask) # type: ignore [assignment]
@@ -1274,7 +1358,7 @@ class AsyncInferenceClient:
1274
1358
  image: ContentT,
1275
1359
  prompt: Optional[str] = None,
1276
1360
  *,
1277
- negative_prompt: Optional[List[str]] = None,
1361
+ negative_prompt: Optional[str] = None,
1278
1362
  num_inference_steps: Optional[int] = None,
1279
1363
  guidance_scale: Optional[float] = None,
1280
1364
  model: Optional[str] = None,
@@ -1295,8 +1379,8 @@ class AsyncInferenceClient:
1295
1379
  The input image for translation. It can be raw bytes, an image file, or a URL to an online image.
1296
1380
  prompt (`str`, *optional*):
1297
1381
  The text prompt to guide the image generation.
1298
- negative_prompt (`List[str]`, *optional*):
1299
- One or several prompt to guide what NOT to include in image generation.
1382
+ negative_prompt (`str`, *optional*):
1383
+ One prompt to guide what NOT to include in image generation.
1300
1384
  num_inference_steps (`int`, *optional*):
1301
1385
  For diffusion models. The number of denoising steps. More denoising steps usually lead to a higher
1302
1386
  quality image at the expense of slower inference.
@@ -1327,16 +1411,22 @@ class AsyncInferenceClient:
1327
1411
  >>> image.save("tiger.jpg")
1328
1412
  ```
1329
1413
  """
1330
- parameters = {
1331
- "prompt": prompt,
1332
- "negative_prompt": negative_prompt,
1333
- "target_size": target_size,
1334
- "num_inference_steps": num_inference_steps,
1335
- "guidance_scale": guidance_scale,
1336
- **kwargs,
1337
- }
1338
- payload = _prepare_payload(image, parameters=parameters, expect_binary=True)
1339
- response = await self.post(**payload, model=model, task="image-to-image")
1414
+ provider_helper = get_provider_helper(self.provider, task="image-to-image")
1415
+ request_parameters = provider_helper.prepare_request(
1416
+ inputs=image,
1417
+ parameters={
1418
+ "prompt": prompt,
1419
+ "negative_prompt": negative_prompt,
1420
+ "target_size": target_size,
1421
+ "num_inference_steps": num_inference_steps,
1422
+ "guidance_scale": guidance_scale,
1423
+ **kwargs,
1424
+ },
1425
+ headers=self.headers,
1426
+ model=model or self.model,
1427
+ api_key=self.token,
1428
+ )
1429
+ response = await self._inner_post(request_parameters)
1340
1430
  return _bytes_to_image(response)
1341
1431
 
1342
1432
  async def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -> ImageToTextOutput:
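One user-visible change in `image_to_image` above: `negative_prompt` is now a single string instead of a list. A short sketch mirroring the docstring example (file names are illustrative):

```py
# Must be run in an async context
from huggingface_hub import AsyncInferenceClient

client = AsyncInferenceClient()
image = await client.image_to_image(
    "cat.jpg",
    prompt="turn the cat into a tiger",
    negative_prompt="blurry, low quality",  # was Optional[List[str]], now Optional[str]
)
image.save("tiger.jpg")
```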
@@ -1373,99 +1463,18 @@ class AsyncInferenceClient:
1373
1463
  'a dog laying on the grass next to a flower pot '
1374
1464
  ```
1375
1465
  """
1376
- response = await self.post(data=image, model=model, task="image-to-text")
1466
+ provider_helper = get_provider_helper(self.provider, task="image-to-text")
1467
+ request_parameters = provider_helper.prepare_request(
1468
+ inputs=image,
1469
+ parameters={},
1470
+ headers=self.headers,
1471
+ model=model or self.model,
1472
+ api_key=self.token,
1473
+ )
1474
+ response = await self._inner_post(request_parameters)
1377
1475
  output = ImageToTextOutput.parse_obj(response)
1378
1476
  return output[0] if isinstance(output, list) else output
1379
1477
 
1380
- async def list_deployed_models(
1381
- self, frameworks: Union[None, str, Literal["all"], List[str]] = None
1382
- ) -> Dict[str, List[str]]:
1383
- """
1384
- List models deployed on the Serverless Inference API service.
1385
-
1386
- This helper checks deployed models framework by framework. By default, it will check the 4 main frameworks that
1387
- are supported and account for 95% of the hosted models. However, if you want a complete list of models you can
1388
- specify `frameworks="all"` as input. Alternatively, if you know before-hand which framework you are interested
1389
- in, you can also restrict to search to this one (e.g. `frameworks="text-generation-inference"`). The more
1390
- frameworks are checked, the more time it will take.
1391
-
1392
- <Tip warning={true}>
1393
-
1394
- This endpoint method does not return a live list of all models available for the Serverless Inference API service.
1395
- It searches over a cached list of models that were recently available and the list may not be up to date.
1396
- If you want to know the live status of a specific model, use [`~InferenceClient.get_model_status`].
1397
-
1398
- </Tip>
1399
-
1400
- <Tip>
1401
-
1402
- This endpoint method is mostly useful for discoverability. If you already know which model you want to use and want to
1403
- check its availability, you can directly use [`~InferenceClient.get_model_status`].
1404
-
1405
- </Tip>
1406
-
1407
- Args:
1408
- frameworks (`Literal["all"]` or `List[str]` or `str`, *optional*):
1409
- The frameworks to filter on. By default only a subset of the available frameworks are tested. If set to
1410
- "all", all available frameworks will be tested. It is also possible to provide a single framework or a
1411
- custom set of frameworks to check.
1412
-
1413
- Returns:
1414
- `Dict[str, List[str]]`: A dictionary mapping task names to a sorted list of model IDs.
1415
-
1416
- Example:
1417
- ```py
1418
- # Must be run in an async context
1419
- >>> from huggingface_hub import AsyncInferenceClient
1420
- >>> client = AsyncInferenceClient()
1421
-
1422
- # Discover zero-shot-classification models currently deployed
1423
- >>> models = await client.list_deployed_models()
1424
- >>> models["zero-shot-classification"]
1425
- ['Narsil/deberta-large-mnli-zero-cls', 'facebook/bart-large-mnli', ...]
1426
-
1427
- # List from only 1 framework
1428
- >>> await client.list_deployed_models("text-generation-inference")
1429
- {'text-generation': ['bigcode/starcoder', 'meta-llama/Llama-2-70b-chat-hf', ...], ...}
1430
- ```
1431
- """
1432
- # Resolve which frameworks to check
1433
- if frameworks is None:
1434
- frameworks = MAIN_INFERENCE_API_FRAMEWORKS
1435
- elif frameworks == "all":
1436
- frameworks = ALL_INFERENCE_API_FRAMEWORKS
1437
- elif isinstance(frameworks, str):
1438
- frameworks = [frameworks]
1439
- frameworks = list(set(frameworks))
1440
-
1441
- # Fetch them iteratively
1442
- models_by_task: Dict[str, List[str]] = {}
1443
-
1444
- def _unpack_response(framework: str, items: List[Dict]) -> None:
1445
- for model in items:
1446
- if framework == "sentence-transformers":
1447
- # Model running with the `sentence-transformers` framework can work with both tasks even if not
1448
- # branded as such in the API response
1449
- models_by_task.setdefault("feature-extraction", []).append(model["model_id"])
1450
- models_by_task.setdefault("sentence-similarity", []).append(model["model_id"])
1451
- else:
1452
- models_by_task.setdefault(model["task"], []).append(model["model_id"])
1453
-
1454
- async def _fetch_framework(framework: str) -> None:
1455
- async with self._get_client_session() as client:
1456
- response = await client.get(f"{INFERENCE_ENDPOINT}/framework/{framework}", proxy=self.proxies)
1457
- response.raise_for_status()
1458
- _unpack_response(framework, await response.json())
1459
-
1460
- import asyncio
1461
-
1462
- await asyncio.gather(*[_fetch_framework(framework) for framework in frameworks])
1463
-
1464
- # Sort alphabetically for discoverability and return
1465
- for task, models in models_by_task.items():
1466
- models_by_task[task] = sorted(set(models), key=lambda x: x.lower())
1467
- return models_by_task
1468
-
1469
1478
  async def object_detection(
1470
1479
  self, image: ContentT, *, model: Optional[str] = None, threshold: Optional[float] = None
1471
1480
  ) -> List[ObjectDetectionOutputElement]:
@@ -1506,11 +1515,15 @@ class AsyncInferenceClient:
1506
1515
  [ObjectDetectionOutputElement(score=0.9486683011054993, label='person', box=ObjectDetectionBoundingBox(xmin=59, ymin=39, xmax=420, ymax=510)), ...]
1507
1516
  ```
1508
1517
  """
1509
- parameters = {
1510
- "threshold": threshold,
1511
- }
1512
- payload = _prepare_payload(image, parameters=parameters, expect_binary=True)
1513
- response = await self.post(**payload, model=model, task="object-detection")
1518
+ provider_helper = get_provider_helper(self.provider, task="object-detection")
1519
+ request_parameters = provider_helper.prepare_request(
1520
+ inputs=image,
1521
+ parameters={"threshold": threshold},
1522
+ headers=self.headers,
1523
+ model=model or self.model,
1524
+ api_key=self.token,
1525
+ )
1526
+ response = await self._inner_post(request_parameters)
1514
1527
  return ObjectDetectionOutputElement.parse_obj_as_list(response)
1515
1528
 
1516
1529
  async def question_answering(
@@ -1576,22 +1589,24 @@ class AsyncInferenceClient:
1576
1589
  QuestionAnsweringOutputElement(answer='Clara', end=16, score=0.9326565265655518, start=11)
1577
1590
  ```
1578
1591
  """
1579
- parameters = {
1580
- "align_to_words": align_to_words,
1581
- "doc_stride": doc_stride,
1582
- "handle_impossible_answer": handle_impossible_answer,
1583
- "max_answer_len": max_answer_len,
1584
- "max_question_len": max_question_len,
1585
- "max_seq_len": max_seq_len,
1586
- "top_k": top_k,
1587
- }
1588
- inputs: Dict[str, Any] = {"question": question, "context": context}
1589
- payload = _prepare_payload(inputs, parameters=parameters)
1590
- response = await self.post(
1591
- **payload,
1592
- model=model,
1593
- task="question-answering",
1592
+ provider_helper = get_provider_helper(self.provider, task="question-answering")
1593
+ request_parameters = provider_helper.prepare_request(
1594
+ inputs=None,
1595
+ parameters={
1596
+ "align_to_words": align_to_words,
1597
+ "doc_stride": doc_stride,
1598
+ "handle_impossible_answer": handle_impossible_answer,
1599
+ "max_answer_len": max_answer_len,
1600
+ "max_question_len": max_question_len,
1601
+ "max_seq_len": max_seq_len,
1602
+ "top_k": top_k,
1603
+ },
1604
+ extra_payload={"question": question, "context": context},
1605
+ headers=self.headers,
1606
+ model=model or self.model,
1607
+ api_key=self.token,
1594
1608
  )
1609
+ response = await self._inner_post(request_parameters)
1595
1610
  # Parse the response as a single `QuestionAnsweringOutputElement` when top_k is 1 or not provided, or a list of `QuestionAnsweringOutputElement` to ensure backward compatibility.
1596
1611
  output = QuestionAnsweringOutputElement.parse_obj(response)
1597
1612
  return output
@@ -1637,11 +1652,16 @@ class AsyncInferenceClient:
1637
1652
  [0.7785726189613342, 0.45876261591911316, 0.2906220555305481]
1638
1653
  ```
1639
1654
  """
1640
- response = await self.post(
1641
- json={"inputs": {"source_sentence": sentence, "sentences": other_sentences}},
1642
- model=model,
1643
- task="sentence-similarity",
1655
+ provider_helper = get_provider_helper(self.provider, task="sentence-similarity")
1656
+ request_parameters = provider_helper.prepare_request(
1657
+ inputs=None,
1658
+ parameters={},
1659
+ extra_payload={"source_sentence": sentence, "sentences": other_sentences},
1660
+ headers=self.headers,
1661
+ model=model or self.model,
1662
+ api_key=self.token,
1644
1663
  )
1664
+ response = await self._inner_post(request_parameters)
1645
1665
  return _bytes_to_list(response)
1646
1666
 
1647
1667
  @_deprecate_arguments(
@@ -1704,8 +1724,15 @@ class AsyncInferenceClient:
1704
1724
  "generate_parameters": generate_parameters,
1705
1725
  "truncation": truncation,
1706
1726
  }
1707
- payload = _prepare_payload(text, parameters=parameters)
1708
- response = await self.post(**payload, model=model, task="summarization")
1727
+ provider_helper = get_provider_helper(self.provider, task="summarization")
1728
+ request_parameters = provider_helper.prepare_request(
1729
+ inputs=text,
1730
+ parameters=parameters,
1731
+ headers=self.headers,
1732
+ model=model or self.model,
1733
+ api_key=self.token,
1734
+ )
1735
+ response = await self._inner_post(request_parameters)
1709
1736
  return SummarizationOutput.parse_obj_as_list(response)[0]
1710
1737
 
1711
1738
  async def table_question_answering(
@@ -1759,21 +1786,16 @@ class AsyncInferenceClient:
1759
1786
  TableQuestionAnsweringOutputElement(answer='36542', coordinates=[[0, 1]], cells=['36542'], aggregator='AVERAGE')
1760
1787
  ```
1761
1788
  """
1762
- parameters = {
1763
- "padding": padding,
1764
- "sequential": sequential,
1765
- "truncation": truncation,
1766
- }
1767
- inputs = {
1768
- "query": query,
1769
- "table": table,
1770
- }
1771
- payload = _prepare_payload(inputs, parameters=parameters)
1772
- response = await self.post(
1773
- **payload,
1774
- model=model,
1775
- task="table-question-answering",
1789
+ provider_helper = get_provider_helper(self.provider, task="table-question-answering")
1790
+ request_parameters = provider_helper.prepare_request(
1791
+ inputs=None,
1792
+ parameters={"model": model, "padding": padding, "sequential": sequential, "truncation": truncation},
1793
+ extra_payload={"query": query, "table": table},
1794
+ headers=self.headers,
1795
+ model=model or self.model,
1796
+ api_key=self.token,
1776
1797
  )
1798
+ response = await self._inner_post(request_parameters)
1777
1799
  return TableQuestionAnsweringOutputElement.parse_obj_as_instance(response)
1778
1800
 
1779
1801
  async def tabular_classification(self, table: Dict[str, Any], *, model: Optional[str] = None) -> List[str]:
@@ -1819,11 +1841,16 @@ class AsyncInferenceClient:
1819
1841
  ["5", "5", "5"]
1820
1842
  ```
1821
1843
  """
1822
- response = await self.post(
1823
- json={"table": table},
1824
- model=model,
1825
- task="tabular-classification",
1844
+ provider_helper = get_provider_helper(self.provider, task="tabular-classification")
1845
+ request_parameters = provider_helper.prepare_request(
1846
+ inputs=None,
1847
+ extra_payload={"table": table},
1848
+ parameters={},
1849
+ headers=self.headers,
1850
+ model=model or self.model,
1851
+ api_key=self.token,
1826
1852
  )
1853
+ response = await self._inner_post(request_parameters)
1827
1854
  return _bytes_to_list(response)
1828
1855
 
1829
1856
  async def tabular_regression(self, table: Dict[str, Any], *, model: Optional[str] = None) -> List[float]:
@@ -1864,7 +1891,16 @@ class AsyncInferenceClient:
1864
1891
  [110, 120, 130]
1865
1892
  ```
1866
1893
  """
1867
- response = await self.post(json={"table": table}, model=model, task="tabular-regression")
1894
+ provider_helper = get_provider_helper(self.provider, task="tabular-regression")
1895
+ request_parameters = provider_helper.prepare_request(
1896
+ inputs=None,
1897
+ parameters={},
1898
+ extra_payload={"table": table},
1899
+ headers=self.headers,
1900
+ model=model or self.model,
1901
+ api_key=self.token,
1902
+ )
1903
+ response = await self._inner_post(request_parameters)
1868
1904
  return _bytes_to_list(response)
1869
1905
 
1870
1906
  async def text_classification(
@@ -1911,16 +1947,18 @@ class AsyncInferenceClient:
1911
1947
  ]
1912
1948
  ```
1913
1949
  """
1914
- parameters = {
1915
- "function_to_apply": function_to_apply,
1916
- "top_k": top_k,
1917
- }
1918
- payload = _prepare_payload(text, parameters=parameters)
1919
- response = await self.post(
1920
- **payload,
1921
- model=model,
1922
- task="text-classification",
1950
+ provider_helper = get_provider_helper(self.provider, task="text-classification")
1951
+ request_parameters = provider_helper.prepare_request(
1952
+ inputs=text,
1953
+ parameters={
1954
+ "function_to_apply": function_to_apply,
1955
+ "top_k": top_k,
1956
+ },
1957
+ headers=self.headers,
1958
+ model=model or self.model,
1959
+ api_key=self.token,
1923
1960
  )
1961
+ response = await self._inner_post(request_parameters)
1924
1962
  return TextClassificationOutputElement.parse_obj_as_list(response)[0] # type: ignore [return-value]
1925
1963
 
1926
1964
  @overload
@@ -2104,15 +2142,6 @@ class AsyncInferenceClient:
2104
2142
  """
2105
2143
  Given a prompt, generate the following text.
2106
2144
 
2107
- API endpoint is supposed to run with the `text-generation-inference` backend (TGI). This backend is the
2108
- go-to solution to run large language models at scale. However, for some smaller models (e.g. "gpt2") the
2109
- default `transformers` + `api-inference` solution is still in use. Both approaches have very similar APIs, but
2110
- not exactly the same. This method is compatible with both approaches but some parameters are only available for
2111
- `text-generation-inference`. If some parameters are ignored, a warning message is triggered but the process
2112
- continues correctly.
2113
-
2114
- To learn more about the TGI project, please refer to https://github.com/huggingface/text-generation-inference.
2115
-
2116
2145
  <Tip>
2117
2146
 
2118
2147
  If you want to generate a response from chat messages, you should use the [`InferenceClient.chat_completion`] method.
@@ -2336,12 +2365,6 @@ class AsyncInferenceClient:
2336
2365
  "typical_p": typical_p,
2337
2366
  "watermark": watermark,
2338
2367
  }
2339
- parameters = {k: v for k, v in parameters.items() if v is not None}
2340
- payload = {
2341
- "inputs": prompt,
2342
- "parameters": parameters,
2343
- "stream": stream,
2344
- }
2345
2368
 
2346
2369
  # Remove some parameters if not a TGI server
2347
2370
  unsupported_kwargs = _get_unsupported_text_generation_kwargs(model)
@@ -2374,9 +2397,19 @@ class AsyncInferenceClient:
2374
2397
  " Please pass `stream=False` as input."
2375
2398
  )
2376
2399
 
2400
+ provider_helper = get_provider_helper(self.provider, task="text-generation")
2401
+ request_parameters = provider_helper.prepare_request(
2402
+ inputs=prompt,
2403
+ parameters=parameters,
2404
+ extra_payload={"stream": stream},
2405
+ headers=self.headers,
2406
+ model=model or self.model,
2407
+ api_key=self.token,
2408
+ )
2409
+
2377
2410
  # Handle errors separately for more precise error messages
2378
2411
  try:
2379
- bytes_output = await self.post(json=payload, model=model, task="text-generation", stream=stream) # type: ignore
2412
+ bytes_output = await self._inner_post(request_parameters, stream=stream)
2380
2413
  except _import_aiohttp().ClientResponseError as e:
2381
2414
  match = MODEL_KWARGS_NOT_USED_REGEX.search(e.response_error_payload["error"])
2382
2415
  if e.status == 400 and match:
@@ -2386,7 +2419,7 @@ class AsyncInferenceClient:
2386
2419
  prompt=prompt,
2387
2420
  details=details,
2388
2421
  stream=stream,
2389
- model=model,
2422
+ model=model or self.model,
2390
2423
  adapter_id=adapter_id,
2391
2424
  best_of=best_of,
2392
2425
  decoder_input_details=decoder_input_details,
@@ -2424,7 +2457,7 @@ class AsyncInferenceClient:
2424
2457
  self,
2425
2458
  prompt: str,
2426
2459
  *,
2427
- negative_prompt: Optional[List[str]] = None,
2460
+ negative_prompt: Optional[str] = None,
2428
2461
  height: Optional[float] = None,
2429
2462
  width: Optional[float] = None,
2430
2463
  num_inference_steps: Optional[int] = None,
@@ -2447,8 +2480,8 @@ class AsyncInferenceClient:
2447
2480
  Args:
2448
2481
  prompt (`str`):
2449
2482
  The prompt to generate an image from.
2450
- negative_prompt (`List[str`, *optional*):
2451
- One or several prompt to guide what NOT to include in image generation.
2483
+ negative_prompt (`str`, *optional*):
2484
+ One prompt to guide what NOT to include in image generation.
2452
2485
  height (`float`, *optional*):
2453
2486
  The height in pixels of the image to generate.
2454
2487
  width (`float`, *optional*):
@@ -2495,23 +2528,143 @@ class AsyncInferenceClient:
2495
2528
  ... )
2496
2529
  >>> image.save("better_astronaut.png")
2497
2530
  ```
2498
- """
2531
+ Example using a third-party provider directly. Usage will be billed on your fal.ai account.
2532
+ ```py
2533
+ >>> from huggingface_hub import InferenceClient
2534
+ >>> client = InferenceClient(
2535
+ ... provider="fal-ai", # Use fal.ai provider
2536
+ ... api_key="fal-ai-api-key", # Pass your fal.ai API key
2537
+ ... )
2538
+ >>> image = client.text_to_image(
2539
+ ... "A majestic lion in a fantasy forest",
2540
+ ... model="black-forest-labs/FLUX.1-schnell",
2541
+ ... )
2542
+ >>> image.save("lion.png")
2543
+ ```
2499
2544
 
2500
- parameters = {
2501
- "negative_prompt": negative_prompt,
2502
- "height": height,
2503
- "width": width,
2504
- "num_inference_steps": num_inference_steps,
2505
- "guidance_scale": guidance_scale,
2506
- "scheduler": scheduler,
2507
- "target_size": target_size,
2508
- "seed": seed,
2509
- **kwargs,
2510
- }
2511
- payload = _prepare_payload(prompt, parameters=parameters)
2512
- response = await self.post(**payload, model=model, task="text-to-image")
2545
+ Example using a third-party provider through Hugging Face Routing. Usage will be billed on your Hugging Face account.
2546
+ ```py
2547
+ >>> from huggingface_hub import InferenceClient
2548
+ >>> client = InferenceClient(
2549
+ ... provider="replicate", # Use replicate provider
2550
+ ... api_key="hf_...", # Pass your HF token
2551
+ ... )
2552
+ >>> image = client.text_to_image(
2553
+ ... "An astronaut riding a horse on the moon.",
2554
+ ... model="black-forest-labs/FLUX.1-dev",
2555
+ ... )
2556
+ >>> image.save("astronaut.png")
2557
+ ```
2558
+ """
2559
+ provider_helper = get_provider_helper(self.provider, task="text-to-image")
2560
+ request_parameters = provider_helper.prepare_request(
2561
+ inputs=prompt,
2562
+ parameters={
2563
+ "negative_prompt": negative_prompt,
2564
+ "height": height,
2565
+ "width": width,
2566
+ "num_inference_steps": num_inference_steps,
2567
+ "guidance_scale": guidance_scale,
2568
+ "scheduler": scheduler,
2569
+ "target_size": target_size,
2570
+ "seed": seed,
2571
+ **kwargs,
2572
+ },
2573
+ headers=self.headers,
2574
+ model=model or self.model,
2575
+ api_key=self.token,
2576
+ )
2577
+ response = await self._inner_post(request_parameters)
2578
+ response = provider_helper.get_response(response)
2513
2579
  return _bytes_to_image(response)
2514
2580
 
2581
+ async def text_to_video(
2582
+ self,
2583
+ prompt: str,
2584
+ *,
2585
+ model: Optional[str] = None,
2586
+ guidance_scale: Optional[float] = None,
2587
+ negative_prompt: Optional[List[str]] = None,
2588
+ num_frames: Optional[float] = None,
2589
+ num_inference_steps: Optional[int] = None,
2590
+ seed: Optional[int] = None,
2591
+ ) -> bytes:
2592
+ """
2593
+ Generate a video based on a given text.
2594
+
2595
+ Args:
2596
+ prompt (`str`):
2597
+ The prompt to generate a video from.
2598
+ model (`str`, *optional*):
2599
+ The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
2600
+ Inference Endpoint. If not provided, the default recommended text-to-video model will be used.
2601
+ Defaults to None.
2602
+ guidance_scale (`float`, *optional*):
2603
+ A higher guidance scale value encourages the model to generate videos closely linked to the text
2604
+ prompt, but values too high may cause saturation and other artifacts.
2605
+ negative_prompt (`List[str]`, *optional*):
2606
+ One or several prompts to guide what NOT to include in video generation.
2607
+ num_frames (`float`, *optional*):
2608
+ The num_frames parameter determines how many video frames are generated.
2609
+ num_inference_steps (`int`, *optional*):
2610
+ The number of denoising steps. More denoising steps usually lead to a higher quality video at the
2611
+ expense of slower inference.
2612
+ seed (`int`, *optional*):
2613
+ Seed for the random number generator.
2614
+
2615
+ Returns:
2616
+ `bytes`: The generated video.
2617
+
2618
+ Example:
2619
+
2620
+ Example using a third-party provider directly. Usage will be billed on your fal.ai account.
2621
+ ```py
2622
+ >>> from huggingface_hub import InferenceClient
2623
+ >>> client = InferenceClient(
2624
+ ... provider="fal-ai", # Using fal.ai provider
2625
+ ... api_key="fal-ai-api-key", # Pass your fal.ai API key
2626
+ ... )
2627
+ >>> video = client.text_to_video(
2628
+ ... "A majestic lion running in a fantasy forest",
2629
+ ... model="tencent/HunyuanVideo",
2630
+ ... )
2631
+ >>> with open("lion.mp4", "wb") as file:
2632
+ ... file.write(video)
2633
+ ```
2634
+
2635
+ Example using a third-party provider through Hugging Face Routing. Usage will be billed on your Hugging Face account.
2636
+ ```py
2637
+ >>> from huggingface_hub import InferenceClient
2638
+ >>> client = InferenceClient(
2639
+ ... provider="replicate", # Using replicate provider
2640
+ ... api_key="hf_...", # Pass your HF token
2641
+ ... )
2642
+ >>> video = client.text_to_video(
2643
+ ... "A cat running in a park",
2644
+ ... model="genmo/mochi-1-preview",
2645
+ ... )
2646
+ >>> with open("cat.mp4", "wb") as file:
2647
+ ... file.write(video)
2648
+ ```
2649
+ """
2650
+ provider_helper = get_provider_helper(self.provider, task="text-to-video")
2651
+ request_parameters = provider_helper.prepare_request(
2652
+ inputs=prompt,
2653
+ parameters={
2654
+ "guidance_scale": guidance_scale,
2655
+ "negative_prompt": negative_prompt,
2656
+ "num_frames": num_frames,
2657
+ "num_inference_steps": num_inference_steps,
2658
+ "seed": seed,
2659
+ },
2660
+ headers=self.headers,
2661
+ model=model or self.model,
2662
+ api_key=self.token,
2663
+ )
2664
+ response = await self._inner_post(request_parameters)
2665
+ response = provider_helper.get_response(response)
2666
+ return response
2667
+
2515
2668
  async def text_to_speech(
2516
2669
  self,
2517
2670
  text: str,
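The `text_to_video` examples above use the synchronous `InferenceClient`; an equivalent sketch for this async client, reusing the same provider, model and prompt:

```py
# Must be run in an async context
from huggingface_hub import AsyncInferenceClient

client = AsyncInferenceClient(provider="fal-ai", api_key="fal-ai-api-key")
video = await client.text_to_video(
    "A majestic lion running in a fantasy forest",
    model="tencent/HunyuanVideo",
)
with open("lion.mp4", "wb") as f:
    f.write(video)
```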
@@ -2610,27 +2763,62 @@ class AsyncInferenceClient:
2610
2763
  >>> audio = await client.text_to_speech("Hello world")
2611
2764
  >>> Path("hello_world.flac").write_bytes(audio)
2612
2765
  ```
2766
+
2767
+ Example using a third-party provider directly. Usage will be billed on your Replicate account.
2768
+ ```py
2769
+ >>> from huggingface_hub import InferenceClient
2770
+ >>> client = InferenceClient(
2771
+ ... provider="replicate",
2772
+ ... api_key="your-replicate-api-key", # Pass your Replicate API key directly
2773
+ ... )
2774
+ >>> audio = client.text_to_speech(
2775
+ ... text="Hello world",
2776
+ ... model="OuteAI/OuteTTS-0.3-500M",
2777
+ ... )
2778
+ >>> Path("hello_world.flac").write_bytes(audio)
2779
+ ```
2780
+
2781
+ Example using a third-party provider through Hugging Face Routing. Usage will be billed on your Hugging Face account.
2782
+ ```py
2783
+ >>> from huggingface_hub import InferenceClient
2784
+ >>> client = InferenceClient(
2785
+ ... provider="replicate",
2786
+ ... api_key="hf_...", # Pass your HF token
2787
+ ... )
2788
+ >>> audio = client.text_to_speech(
2789
+ ... text="Hello world",
2790
+ ... model="OuteAI/OuteTTS-0.3-500M",
2791
+ ... )
2792
+ >>> Path("hello_world.flac").write_bytes(audio)
2793
+ ```
2613
2794
  """
2614
- parameters = {
2615
- "do_sample": do_sample,
2616
- "early_stopping": early_stopping,
2617
- "epsilon_cutoff": epsilon_cutoff,
2618
- "eta_cutoff": eta_cutoff,
2619
- "max_length": max_length,
2620
- "max_new_tokens": max_new_tokens,
2621
- "min_length": min_length,
2622
- "min_new_tokens": min_new_tokens,
2623
- "num_beam_groups": num_beam_groups,
2624
- "num_beams": num_beams,
2625
- "penalty_alpha": penalty_alpha,
2626
- "temperature": temperature,
2627
- "top_k": top_k,
2628
- "top_p": top_p,
2629
- "typical_p": typical_p,
2630
- "use_cache": use_cache,
2631
- }
2632
- payload = _prepare_payload(text, parameters=parameters)
2633
- response = await self.post(**payload, model=model, task="text-to-speech")
2795
+ provider_helper = get_provider_helper(self.provider, task="text-to-speech")
2796
+ request_parameters = provider_helper.prepare_request(
2797
+ inputs=text,
2798
+ parameters={
2799
+ "do_sample": do_sample,
2800
+ "early_stopping": early_stopping,
2801
+ "epsilon_cutoff": epsilon_cutoff,
2802
+ "eta_cutoff": eta_cutoff,
2803
+ "max_length": max_length,
2804
+ "max_new_tokens": max_new_tokens,
2805
+ "min_length": min_length,
2806
+ "min_new_tokens": min_new_tokens,
2807
+ "num_beam_groups": num_beam_groups,
2808
+ "num_beams": num_beams,
2809
+ "penalty_alpha": penalty_alpha,
2810
+ "temperature": temperature,
2811
+ "top_k": top_k,
2812
+ "top_p": top_p,
2813
+ "typical_p": typical_p,
2814
+ "use_cache": use_cache,
2815
+ },
2816
+ headers=self.headers,
2817
+ model=model or self.model,
2818
+ api_key=self.token,
2819
+ )
2820
+ response = await self._inner_post(request_parameters)
2821
+ response = provider_helper.get_response(response)
2634
2822
  return response
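The body above follows the same provider-helper pattern used for every refactored task method in this diff: resolve a helper for the provider/task pair, let it build the request, send it with `_inner_post`, and let the helper post-process the raw response. A minimal sketch of that call sequence follows; the import path for `get_provider_helper` is assumed, and the helper object is treated as opaque.

```py
# Illustrative sketch (not library code) of the shared call sequence introduced
# by this diff for bytes-returning tasks such as text-to-speech.
from typing import Any, Dict, Optional

# Import path assumed from the new `_providers` package referenced in this diff.
from huggingface_hub.inference._providers import get_provider_helper


async def run_task_sketch(client, task: str, inputs: Any, parameters: Dict[str, Optional[Any]]) -> bytes:
    provider_helper = get_provider_helper(client.provider, task=task)
    request_parameters = provider_helper.prepare_request(
        inputs=inputs,
        parameters=parameters,      # unset (None) values are assumed to be dropped by the helper
        headers=client.headers,
        model=client.model,
        api_key=client.token,
    )
    raw = await client._inner_post(request_parameters)  # private method used by the diffed code
    return provider_helper.get_response(raw)
```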
2635
2823
 
2636
2824
  async def token_classification(
@@ -2693,18 +2881,19 @@ class AsyncInferenceClient:
2693
2881
  ]
2694
2882
  ```
2695
2883
  """
2696
-
2697
- parameters = {
2698
- "aggregation_strategy": aggregation_strategy,
2699
- "ignore_labels": ignore_labels,
2700
- "stride": stride,
2701
- }
2702
- payload = _prepare_payload(text, parameters=parameters)
2703
- response = await self.post(
2704
- **payload,
2705
- model=model,
2706
- task="token-classification",
2884
+ provider_helper = get_provider_helper(self.provider, task="token-classification")
2885
+ request_parameters = provider_helper.prepare_request(
2886
+ inputs=text,
2887
+ parameters={
2888
+ "aggregation_strategy": aggregation_strategy,
2889
+ "ignore_labels": ignore_labels,
2890
+ "stride": stride,
2891
+ },
2892
+ headers=self.headers,
2893
+ model=model or self.model,
2894
+ api_key=self.token,
2707
2895
  )
2896
+ response = await self._inner_post(request_parameters)
2708
2897
  return TokenClassificationOutputElement.parse_obj_as_list(response)
2709
2898
 
2710
2899
  async def translation(
@@ -2778,15 +2967,22 @@ class AsyncInferenceClient:
2778
2967
 
2779
2968
  if src_lang is None and tgt_lang is not None:
2780
2969
  raise ValueError("You cannot specify `tgt_lang` without specifying `src_lang`.")
2781
- parameters = {
2782
- "src_lang": src_lang,
2783
- "tgt_lang": tgt_lang,
2784
- "clean_up_tokenization_spaces": clean_up_tokenization_spaces,
2785
- "truncation": truncation,
2786
- "generate_parameters": generate_parameters,
2787
- }
2788
- payload = _prepare_payload(text, parameters=parameters)
2789
- response = await self.post(**payload, model=model, task="translation")
2970
+
2971
+ provider_helper = get_provider_helper(self.provider, task="translation")
2972
+ request_parameters = provider_helper.prepare_request(
2973
+ inputs=text,
2974
+ parameters={
2975
+ "src_lang": src_lang,
2976
+ "tgt_lang": tgt_lang,
2977
+ "clean_up_tokenization_spaces": clean_up_tokenization_spaces,
2978
+ "truncation": truncation,
2979
+ "generate_parameters": generate_parameters,
2980
+ },
2981
+ headers=self.headers,
2982
+ model=model or self.model,
2983
+ api_key=self.token,
2984
+ )
2985
+ response = await self._inner_post(request_parameters)
2790
2986
  return TranslationOutput.parse_obj_as_list(response)[0]
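A hedged async usage sketch for the refactored `translation` method; the model ID and language codes are assumptions chosen for illustration. Recall from the check above that `tgt_lang` cannot be passed without `src_lang`.

```py
# Hypothetical async translation call; model ID and FLORES-style language codes
# are placeholders, not recommendations.
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient()
    result = await client.translation(
        "My name is Wolfgang and I live in Berlin.",
        model="facebook/nllb-200-distilled-600M",  # illustrative model ID
        src_lang="eng_Latn",
        tgt_lang="fra_Latn",
    )
    print(result.translation_text)


asyncio.run(main())
```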
2791
2987
 
2792
2988
  async def visual_question_answering(
@@ -2836,10 +3032,16 @@ class AsyncInferenceClient:
2836
3032
  ]
2837
3033
  ```
2838
3034
  """
2839
- payload: Dict[str, Any] = {"question": question, "image": _b64_encode(image)}
2840
- if top_k is not None:
2841
- payload.setdefault("parameters", {})["top_k"] = top_k
2842
- response = await self.post(json=payload, model=model, task="visual-question-answering")
3035
+ provider_helper = get_provider_helper(self.provider, task="visual-question-answering")
3036
+ request_parameters = provider_helper.prepare_request(
3037
+ inputs=image,
3038
+ parameters={"top_k": top_k},
3039
+ headers=self.headers,
3040
+ model=model or self.model,
3041
+ api_key=self.token,
3042
+ extra_payload={"question": question, "image": _b64_encode(image)},
3043
+ )
3044
+ response = await self._inner_post(request_parameters)
2843
3045
  return VisualQuestionAnsweringOutputElement.parse_obj_as_list(response)
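A hedged async sketch of the refactored `visual_question_answering` call. Note that, per the code above, the question and base64-encoded image are sent via `extra_payload` rather than `parameters`; the image path and model ID below are placeholders.

```py
# Hypothetical async VQA call; the local file path and model ID are placeholders.
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient()
    answers = await client.visual_question_answering(
        image="invoice.png",                      # local path, URL, or raw bytes
        question="What is the total amount due?",
        model="dandelin/vilt-b32-finetuned-vqa",  # illustrative model ID
    )
    for answer in answers:
        print(answer)


asyncio.run(main())
```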
2844
3046
 
2845
3047
  @_deprecate_arguments(
@@ -2947,17 +3149,20 @@ class AsyncInferenceClient:
2947
3149
  candidate_labels = labels
2948
3150
  elif candidate_labels is None:
2949
3151
  raise ValueError("Must specify `candidate_labels`")
2950
- parameters = {
2951
- "candidate_labels": candidate_labels,
2952
- "multi_label": multi_label,
2953
- "hypothesis_template": hypothesis_template,
2954
- }
2955
- payload = _prepare_payload(text, parameters=parameters)
2956
- response = await self.post(
2957
- **payload,
2958
- task="zero-shot-classification",
2959
- model=model,
3152
+
3153
+ provider_helper = get_provider_helper(self.provider, task="zero-shot-classification")
3154
+ request_parameters = provider_helper.prepare_request(
3155
+ inputs=text,
3156
+ parameters={
3157
+ "candidate_labels": candidate_labels,
3158
+ "multi_label": multi_label,
3159
+ "hypothesis_template": hypothesis_template,
3160
+ },
3161
+ headers=self.headers,
3162
+ model=model or self.model,
3163
+ api_key=self.token,
2960
3164
  )
3165
+ response = await self._inner_post(request_parameters)
2961
3166
  output = _bytes_to_dict(response)
2962
3167
  return [
2963
3168
  ZeroShotClassificationOutputElement.parse_obj_as_instance({"label": label, "score": score})
@@ -3031,18 +3236,110 @@ class AsyncInferenceClient:
3031
3236
  # Raise ValueError if input is less than 2 labels
3032
3237
  if len(candidate_labels) < 2:
3033
3238
  raise ValueError("You must specify at least 2 classes to compare.")
3034
- parameters = {
3035
- "candidate_labels": candidate_labels,
3036
- "hypothesis_template": hypothesis_template,
3037
- }
3038
- payload = _prepare_payload(image, parameters=parameters, expect_binary=True)
3039
- response = await self.post(
3040
- **payload,
3041
- model=model,
3042
- task="zero-shot-image-classification",
3239
+
3240
+ provider_helper = get_provider_helper(self.provider, task="zero-shot-image-classification")
3241
+ request_parameters = provider_helper.prepare_request(
3242
+ inputs=image,
3243
+ parameters={
3244
+ "candidate_labels": candidate_labels,
3245
+ "hypothesis_template": hypothesis_template,
3246
+ },
3247
+ headers=self.headers,
3248
+ model=model or self.model,
3249
+ api_key=self.token,
3043
3250
  )
3251
+ response = await self._inner_post(request_parameters)
3044
3252
  return ZeroShotImageClassificationOutputElement.parse_obj_as_list(response)
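A hedged async sketch of the refactored `zero_shot_image_classification` call; as enforced above, at least two candidate labels are required. The image URL and model ID are placeholders.

```py
# Hypothetical async zero-shot image classification call; URL and model ID
# are placeholders. At least two candidate labels are required.
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient()
    results = await client.zero_shot_image_classification(
        image="https://example.com/cat.jpg",
        candidate_labels=["cat", "dog", "bird"],
        model="openai/clip-vit-base-patch32",  # illustrative model ID
    )
    for item in results:
        print(item.label, item.score)


asyncio.run(main())
```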
3045
3253
 
3254
+ async def list_deployed_models(
3255
+ self, frameworks: Union[None, str, Literal["all"], List[str]] = None
3256
+ ) -> Dict[str, List[str]]:
3257
+ """
3258
+ List models deployed on the Serverless Inference API service.
3259
+
3260
+ This helper checks deployed models framework by framework. By default, it will check the 4 main frameworks that
3261
+ are supported and account for 95% of the hosted models. However, if you want a complete list of models you can
3262
+ specify `frameworks="all"` as input. Alternatively, if you know beforehand which framework you are interested
3263
+ in, you can also restrict the search to this one (e.g. `frameworks="text-generation-inference"`). The more
3264
+ frameworks are checked, the more time it will take.
3265
+
3266
+ <Tip warning={true}>
3267
+
3268
+ This endpoint method does not return a live list of all models available for the Serverless Inference API service.
3269
+ It searches over a cached list of models that were recently available and the list may not be up to date.
3270
+ If you want to know the live status of a specific model, use [`~InferenceClient.get_model_status`].
3271
+
3272
+ </Tip>
3273
+
3274
+ <Tip>
3275
+
3276
+ This endpoint method is mostly useful for discoverability. If you already know which model you want to use and want to
3277
+ check its availability, you can directly use [`~InferenceClient.get_model_status`].
3278
+
3279
+ </Tip>
3280
+
3281
+ Args:
3282
+ frameworks (`Literal["all"]` or `List[str]` or `str`, *optional*):
3283
+ The frameworks to filter on. By default only a subset of the available frameworks are tested. If set to
3284
+ "all", all available frameworks will be tested. It is also possible to provide a single framework or a
3285
+ custom set of frameworks to check.
3286
+
3287
+ Returns:
3288
+ `Dict[str, List[str]]`: A dictionary mapping task names to a sorted list of model IDs.
3289
+
3290
+ Example:
3291
+ ```py
3292
+ # Must be run in an async context
3293
+ >>> from huggingface_hub import AsyncInferenceClient
3294
+ >>> client = AsyncInferenceClient()
3295
+
3296
+ # Discover zero-shot-classification models currently deployed
3297
+ >>> models = await client.list_deployed_models()
3298
+ >>> models["zero-shot-classification"]
3299
+ ['Narsil/deberta-large-mnli-zero-cls', 'facebook/bart-large-mnli', ...]
3300
+
3301
+ # List from only 1 framework
3302
+ >>> await client.list_deployed_models("text-generation-inference")
3303
+ {'text-generation': ['bigcode/starcoder', 'meta-llama/Llama-2-70b-chat-hf', ...], ...}
3304
+ ```
3305
+ """
3306
+ if self.provider != "hf-inference":
3307
+ raise ValueError(f"Listing deployed models is not supported on '{self.provider}'.")
3308
+
3309
+ # Resolve which frameworks to check
3310
+ if frameworks is None:
3311
+ frameworks = MAIN_INFERENCE_API_FRAMEWORKS
3312
+ elif frameworks == "all":
3313
+ frameworks = ALL_INFERENCE_API_FRAMEWORKS
3314
+ elif isinstance(frameworks, str):
3315
+ frameworks = [frameworks]
3316
+ frameworks = list(set(frameworks))
3317
+
3318
+ # Fetch them iteratively
3319
+ models_by_task: Dict[str, List[str]] = {}
3320
+
3321
+ def _unpack_response(framework: str, items: List[Dict]) -> None:
3322
+ for model in items:
3323
+ if framework == "sentence-transformers":
3324
+ # Models running with the `sentence-transformers` framework can work with both tasks even if not
3325
+ # branded as such in the API response
3326
+ models_by_task.setdefault("feature-extraction", []).append(model["model_id"])
3327
+ models_by_task.setdefault("sentence-similarity", []).append(model["model_id"])
3328
+ else:
3329
+ models_by_task.setdefault(model["task"], []).append(model["model_id"])
3330
+
3331
+ for framework in frameworks:
3332
+ response = get_session().get(
3333
+ f"{INFERENCE_ENDPOINT}/framework/{framework}", headers=build_hf_headers(token=self.token)
3334
+ )
3335
+ hf_raise_for_status(response)
3336
+ _unpack_response(framework, response.json())
3337
+
3338
+ # Sort alphabetically for discoverability and return
3339
+ for task, models in models_by_task.items():
3340
+ models_by_task[task] = sorted(set(models), key=lambda x: x.lower())
3341
+ return models_by_task
3342
+
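The grouping performed by `_unpack_response` above can be illustrated in isolation. The payload below is fabricated for the example; only the `model_id` and `task` keys matter, since those are the only fields the diffed code reads.

```py
# Standalone illustration of the grouping logic in `_unpack_response` above.
# The items are invented sample data, not real API responses.
from typing import Dict, List

models_by_task: Dict[str, List[str]] = {}


def _unpack_response(framework: str, items: List[Dict]) -> None:
    for model in items:
        if framework == "sentence-transformers":
            # sentence-transformers models are registered under both tasks.
            models_by_task.setdefault("feature-extraction", []).append(model["model_id"])
            models_by_task.setdefault("sentence-similarity", []).append(model["model_id"])
        else:
            models_by_task.setdefault(model["task"], []).append(model["model_id"])


_unpack_response("text-generation-inference", [{"model_id": "bigcode/starcoder", "task": "text-generation"}])
_unpack_response("sentence-transformers", [{"model_id": "sentence-transformers/all-MiniLM-L6-v2", "task": "feature-extraction"}])
print(models_by_task)
# {'text-generation': ['bigcode/starcoder'],
#  'feature-extraction': ['sentence-transformers/all-MiniLM-L6-v2'],
#  'sentence-similarity': ['sentence-transformers/all-MiniLM-L6-v2']}
```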
3046
3343
  def _get_client_session(self, headers: Optional[Dict] = None) -> "ClientSession":
3047
3344
  aiohttp = _import_aiohttp()
3048
3345
  client_headers = self.headers.copy()
@@ -3084,60 +3381,6 @@ class AsyncInferenceClient:
3084
3381
  session.close = close_session
3085
3382
  return session
3086
3383
 
3087
- def _resolve_url(self, model: Optional[str] = None, task: Optional[str] = None) -> str:
3088
- model = model or self.model or self.base_url
3089
-
3090
- # If model is already a URL, ignore `task` and return directly
3091
- if model is not None and (model.startswith("http://") or model.startswith("https://")):
3092
- return model
3093
-
3094
- # # If no model but task is set => fetch the recommended one for this task
3095
- if model is None:
3096
- if task is None:
3097
- raise ValueError(
3098
- "You must specify at least a model (repo_id or URL) or a task, either when instantiating"
3099
- " `InferenceClient` or when making a request."
3100
- )
3101
- model = self.get_recommended_model(task)
3102
- logger.info(
3103
- f"Using recommended model {model} for task {task}. Note that it is"
3104
- f" encouraged to explicitly set `model='{model}'` as the recommended"
3105
- " models list might get updated without prior notice."
3106
- )
3107
-
3108
- # Compute InferenceAPI url
3109
- return (
3110
- # Feature-extraction and sentence-similarity are the only cases where we handle models with several tasks.
3111
- f"{INFERENCE_ENDPOINT}/pipeline/{task}/{model}"
3112
- if task in ("feature-extraction", "sentence-similarity")
3113
- # Otherwise, we use the default endpoint
3114
- else f"{INFERENCE_ENDPOINT}/models/{model}"
3115
- )
3116
-
3117
- @staticmethod
3118
- def get_recommended_model(task: str) -> str:
3119
- """
3120
- Get the model Hugging Face recommends for the input task.
3121
-
3122
- Args:
3123
- task (`str`):
3124
- The Hugging Face task to get which model Hugging Face recommends.
3125
- All available tasks can be found [here](https://huggingface.co/tasks).
3126
-
3127
- Returns:
3128
- `str`: Name of the model recommended for the input task.
3129
-
3130
- Raises:
3131
- `ValueError`: If Hugging Face has no recommendation for the input task.
3132
- """
3133
- model = _fetch_recommended_models().get(task)
3134
- if model is None:
3135
- raise ValueError(
3136
- f"Task {task} has no recommended model. Please specify a model"
3137
- " explicitly. Visit https://huggingface.co/tasks for more info."
3138
- )
3139
- return model
3140
-
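The removed `_resolve_url` and `get_recommended_model` helpers presumably now live behind the provider helpers used throughout this file. For reference, a minimal sketch of the routing rule the removed code encoded is shown below; the standalone function name is made up for illustration.

```py
# Minimal sketch of the URL rule the removed `_resolve_url` implemented:
# feature-extraction and sentence-similarity use the /pipeline/ route, every
# other task uses /models/. The function name is illustrative, not library code.
from huggingface_hub.constants import INFERENCE_ENDPOINT


def resolve_hf_inference_url(model: str, task: str) -> str:
    if model.startswith(("http://", "https://")):
        return model  # already a full Inference Endpoint URL
    if task in ("feature-extraction", "sentence-similarity"):
        return f"{INFERENCE_ENDPOINT}/pipeline/{task}/{model}"
    return f"{INFERENCE_ENDPOINT}/models/{model}"


print(resolve_hf_inference_url("sentence-transformers/all-MiniLM-L6-v2", "feature-extraction"))
print(resolve_hf_inference_url("gpt2", "text-generation"))
```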
3141
3384
  async def get_endpoint_info(self, *, model: Optional[str] = None) -> Dict[str, Any]:
3142
3385
  """
3143
3386
  Get information about the deployed endpoint.
@@ -3182,6 +3425,9 @@ class AsyncInferenceClient:
3182
3425
  }
3183
3426
  ```
3184
3427
  """
3428
+ if self.provider != "hf-inference":
3429
+ raise ValueError(f"Getting endpoint info is not supported on '{self.provider}'.")
3430
+
3185
3431
  model = model or self.model
3186
3432
  if model is None:
3187
3433
  raise ValueError("Model id not provided.")
@@ -3190,7 +3436,7 @@ class AsyncInferenceClient:
3190
3436
  else:
3191
3437
  url = f"{INFERENCE_ENDPOINT}/models/{model}/info"
3192
3438
 
3193
- async with self._get_client_session() as client:
3439
+ async with self._get_client_session(headers=build_hf_headers(token=self.token)) as client:
3194
3440
  response = await client.get(url, proxy=self.proxies)
3195
3441
  response.raise_for_status()
3196
3442
  return await response.json()
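The introspection helpers touched in this part of the diff (`get_endpoint_info`, `health_check`, `get_model_status`) now refuse to run unless the client is configured with the `hf-inference` provider, and they authenticate their session with the user's token. A hedged sketch of what a caller sees with a third-party provider:

```py
# Hypothetical guard behaviour: with a non-"hf-inference" provider, the
# introspection helpers raise ValueError before any request is made.
# The token string and model ID are placeholders.
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient(provider="replicate", api_key="hf_...")
    try:
        info = await client.get_endpoint_info(model="meta-llama/Llama-2-70b-chat-hf")
        print(info)
    except ValueError as err:
        print(f"Not supported: {err}")


asyncio.run(main())
```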
@@ -3218,6 +3464,9 @@ class AsyncInferenceClient:
3218
3464
  True
3219
3465
  ```
3220
3466
  """
3467
+ if self.provider != "hf-inference":
3468
+ raise ValueError(f"Health check is not supported on '{self.provider}'.")
3469
+
3221
3470
  model = model or self.model
3222
3471
  if model is None:
3223
3472
  raise ValueError("Model id not provided.")
@@ -3227,7 +3476,7 @@ class AsyncInferenceClient:
3227
3476
  )
3228
3477
  url = model.rstrip("/") + "/health"
3229
3478
 
3230
- async with self._get_client_session() as client:
3479
+ async with self._get_client_session(headers=build_hf_headers(token=self.token)) as client:
3231
3480
  response = await client.get(url, proxy=self.proxies)
3232
3481
  return response.status == 200
3233
3482
 
@@ -3262,6 +3511,9 @@ class AsyncInferenceClient:
3262
3511
  ModelStatus(loaded=True, state='Loaded', compute_type='gpu', framework='text-generation-inference')
3263
3512
  ```
3264
3513
  """
3514
+ if self.provider != "hf-inference":
3515
+ raise ValueError(f"Getting model status is not supported on '{self.provider}'.")
3516
+
3265
3517
  model = model or self.model
3266
3518
  if model is None:
3267
3519
  raise ValueError("Model id not provided.")
@@ -3269,7 +3521,7 @@ class AsyncInferenceClient:
3269
3521
  raise NotImplementedError("Model status is only available for Inference API endpoints.")
3270
3522
  url = f"{INFERENCE_ENDPOINT}/status/{model}"
3271
3523
 
3272
- async with self._get_client_session() as client:
3524
+ async with self._get_client_session(headers=build_hf_headers(token=self.token)) as client:
3273
3525
  response = await client.get(url, proxy=self.proxies)
3274
3526
  response.raise_for_status()
3275
3527
  response_data = await response.json()