huggingface-hub 0.27.1__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of huggingface-hub might be problematic.

Files changed (39)
  1. huggingface_hub/__init__.py +418 -12
  2. huggingface_hub/_commit_api.py +33 -4
  3. huggingface_hub/_inference_endpoints.py +8 -2
  4. huggingface_hub/_local_folder.py +14 -3
  5. huggingface_hub/commands/scan_cache.py +1 -1
  6. huggingface_hub/commands/upload_large_folder.py +1 -1
  7. huggingface_hub/constants.py +7 -2
  8. huggingface_hub/file_download.py +1 -2
  9. huggingface_hub/hf_api.py +65 -84
  10. huggingface_hub/inference/_client.py +706 -450
  11. huggingface_hub/inference/_common.py +32 -64
  12. huggingface_hub/inference/_generated/_async_client.py +722 -470
  13. huggingface_hub/inference/_generated/types/__init__.py +1 -0
  14. huggingface_hub/inference/_generated/types/image_to_image.py +3 -3
  15. huggingface_hub/inference/_generated/types/text_to_audio.py +1 -2
  16. huggingface_hub/inference/_generated/types/text_to_image.py +3 -3
  17. huggingface_hub/inference/_generated/types/text_to_speech.py +3 -6
  18. huggingface_hub/inference/_generated/types/text_to_video.py +47 -0
  19. huggingface_hub/inference/_generated/types/visual_question_answering.py +1 -1
  20. huggingface_hub/inference/_providers/__init__.py +89 -0
  21. huggingface_hub/inference/_providers/fal_ai.py +159 -0
  22. huggingface_hub/inference/_providers/hf_inference.py +202 -0
  23. huggingface_hub/inference/_providers/replicate.py +148 -0
  24. huggingface_hub/inference/_providers/sambanova.py +89 -0
  25. huggingface_hub/inference/_providers/together.py +153 -0
  26. huggingface_hub/py.typed +0 -0
  27. huggingface_hub/repocard.py +1 -1
  28. huggingface_hub/repocard_data.py +2 -1
  29. huggingface_hub/serialization/_base.py +1 -1
  30. huggingface_hub/serialization/_torch.py +1 -1
  31. huggingface_hub/utils/_fixes.py +25 -13
  32. huggingface_hub/utils/_http.py +3 -3
  33. huggingface_hub/utils/logging.py +1 -1
  34. {huggingface_hub-0.27.1.dist-info → huggingface_hub-0.28.0.dist-info}/METADATA +4 -4
  35. {huggingface_hub-0.27.1.dist-info → huggingface_hub-0.28.0.dist-info}/RECORD +39 -31
  36. {huggingface_hub-0.27.1.dist-info → huggingface_hub-0.28.0.dist-info}/LICENSE +0 -0
  37. {huggingface_hub-0.27.1.dist-info → huggingface_hub-0.28.0.dist-info}/WHEEL +0 -0
  38. {huggingface_hub-0.27.1.dist-info → huggingface_hub-0.28.0.dist-info}/entry_points.txt +0 -0
  39. {huggingface_hub-0.27.1.dist-info → huggingface_hub-0.28.0.dist-info}/top_level.txt +0 -0
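
The bulk of this diff is `huggingface_hub/inference/_client.py`: 0.28.0 adds third-party inference providers (the new `huggingface_hub/inference/_providers/` modules and a `provider` argument on `InferenceClient`) and deprecates `InferenceClient.post()` in favor of the task methods. A minimal usage sketch based on the docstring examples added below; the API key and model ID are placeholders:

```py
from huggingface_hub import InferenceClient

# provider selects the backend: "hf-inference" (default), "replicate", "together", "fal-ai" or "sambanova".
# An hf_... token routes the call through Hugging Face (billed to your Hugging Face account);
# a provider-specific key bills the provider account directly.
client = InferenceClient(provider="together", api_key="hf_...")

output = client.chat_completion(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    messages=[{"role": "user", "content": "What is the capital of France?"}],
)
print(output.choices[0].message.content)
```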
@@ -40,7 +40,6 @@ import warnings
  from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Literal, Optional, Union, overload

  from requests import HTTPError
- from requests.structures import CaseInsensitiveDict

  from huggingface_hub.constants import ALL_INFERENCE_API_FRAMEWORKS, INFERENCE_ENDPOINT, MAIN_INFERENCE_API_FRAMEWORKS
  from huggingface_hub.errors import BadRequestError, InferenceTimeoutError
@@ -48,16 +47,15 @@ from huggingface_hub.inference._common import (
  TASKS_EXPECTING_IMAGES,
  ContentT,
  ModelStatus,
+ RequestParameters,
  _b64_encode,
  _b64_to_image,
  _bytes_to_dict,
  _bytes_to_image,
  _bytes_to_list,
- _fetch_recommended_models,
  _get_unsupported_text_generation_kwargs,
  _import_numpy,
  _open_as_binary,
- _prepare_payload,
  _set_unsupported_text_generation_kwargs,
  _stream_chat_completion_response,
  _stream_text_generation_response,
@@ -104,8 +102,9 @@ from huggingface_hub.inference._generated.types import (
  ZeroShotClassificationOutputElement,
  ZeroShotImageClassificationOutputElement,
  )
+ from huggingface_hub.inference._providers import PROVIDER_T, HFInferenceTask, get_provider_helper
  from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status
- from huggingface_hub.utils._deprecation import _deprecate_arguments
+ from huggingface_hub.utils._deprecation import _deprecate_arguments, _deprecate_method


  if TYPE_CHECKING:
@@ -123,7 +122,7 @@ class InferenceClient:
  Initialize a new Inference Client.

  [`InferenceClient`] aims to provide a unified experience to perform inference. The client can be used
- seamlessly with either the (free) Inference API or self-hosted Inference Endpoints.
+ seamlessly with either the (free) Inference API, self-hosted Inference Endpoints, or third-party Inference Providers.

  Args:
  model (`str`, `optional`):
@@ -134,6 +133,10 @@ class InferenceClient:
  arguments are mutually exclusive. If using `base_url` for chat completion, the `/chat/completions` suffix
  path will be appended to the base URL (see the [TGI Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api)
  documentation for details). When passing a URL as `model`, the client will not append any suffix path to it.
+ provider (`str`, *optional*):
+ Name of the provider to use for inference. Can be `"replicate"`, `"together"`, `"fal-ai"`, `"sambanova"` or `"hf-inference"`.
+ defaults to hf-inference (Hugging Face Serverless Inference API).
+ If model is a URL or `base_url` is passed, then `provider` is not used.
  token (`str` or `bool`, *optional*):
  Hugging Face token. Will default to the locally saved token if not provided.
  Pass `token=False` if you don't want to send your token to the server.
@@ -161,7 +164,8 @@ class InferenceClient:
  self,
  model: Optional[str] = None,
  *,
- token: Union[str, bool, None] = None,
+ provider: Optional[PROVIDER_T] = None,
+ token: Optional[str] = None,
  timeout: Optional[float] = None,
  headers: Optional[Dict[str, str]] = None,
  cookies: Optional[Dict[str, str]] = None,
@@ -185,12 +189,12 @@ class InferenceClient:
  )

  self.model: Optional[str] = model
- self.token: Union[str, bool, None] = token if token is not None else api_key
- self.headers: CaseInsensitiveDict[str] = CaseInsensitiveDict(
- build_hf_headers(token=self.token) # 'authorization' + 'user-agent'
- )
- if headers is not None:
- self.headers.update(headers)
+ self.token: Optional[str] = token if token is not None else api_key
+ self.headers = headers if headers is not None else {}
+
+ # Configure provider
+ self.provider = provider if provider is not None else "hf-inference"
+
  self.cookies = cookies
  self.timeout = timeout
  self.proxies = proxies
@@ -234,6 +238,14 @@ class InferenceClient:
  stream: bool = False,
  ) -> Union[bytes, Iterable[bytes]]: ...

+ @_deprecate_method(
+ version="0.31.0",
+ message=(
+ "Making direct POST requests to the inference server is not supported anymore. "
+ "Please use task methods instead (e.g. `InferenceClient.chat_completion`). "
+ "If your use case is not supported, please open an issue in https://github.com/huggingface/huggingface_hub."
+ ),
+ )
  def post(
  self,
  *,
@@ -246,53 +258,62 @@ class InferenceClient:
  """
  Make a POST request to the inference server.

- Args:
- json (`Union[str, Dict, List]`, *optional*):
- The JSON data to send in the request body, specific to each task. Defaults to None.
- data (`Union[str, Path, bytes, BinaryIO]`, *optional*):
- The content to send in the request body, specific to each task.
- It can be raw bytes, a pointer to an opened file, a local file path,
- or a URL to an online resource (image, audio file,...). If both `json` and `data` are passed,
- `data` will take precedence. At least `json` or `data` must be provided. Defaults to None.
- model (`str`, *optional*):
- The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
- Inference Endpoint. Will override the model defined at the instance level. Defaults to None.
- task (`str`, *optional*):
- The task to perform on the inference. All available tasks can be found
- [here](https://huggingface.co/tasks). Used only to default to a recommended model if `model` is not
- provided. At least `model` or `task` must be provided. Defaults to None.
- stream (`bool`, *optional*):
- Whether to iterate over streaming APIs.
+ This method is deprecated and will be removed in the future.
+ Please use task methods instead (e.g. `InferenceClient.chat_completion`).
+ """
+ if self.provider != "hf-inference":
+ raise ValueError(
+ "Cannot use `post` with another provider than `hf-inference`. "
+ "`InferenceClient.post` is deprecated and should not be used directly anymore."
+ )
+ provider_helper = HFInferenceTask(task or "unknown")
+ url = provider_helper.build_url(provider_helper.map_model(model))
+ headers = provider_helper.prepare_headers(headers=self.headers, api_key=self.token)
+ return self._inner_post(
+ request_parameters=RequestParameters(
+ url=url,
+ task=task or "unknown",
+ model=model or "unknown",
+ json=json,
+ data=data,
+ headers=headers,
+ ),
+ stream=stream,
+ )

- Returns:
- bytes: The raw bytes returned by the server.
+ @overload
+ def _inner_post( # type: ignore[misc]
+ self, request_parameters: RequestParameters, *, stream: Literal[False] = ...
+ ) -> bytes: ...

- Raises:
- [`InferenceTimeoutError`]:
- If the model is unavailable or the request times out.
- `HTTPError`:
- If the request fails with an HTTP error status code other than HTTP 503.
- """
- url = self._resolve_url(model, task)
+ @overload
+ def _inner_post( # type: ignore[misc]
+ self, request_parameters: RequestParameters, *, stream: Literal[True] = ...
+ ) -> Iterable[bytes]: ...

- if data is not None and json is not None:
- warnings.warn("Ignoring `json` as `data` is passed as binary.")
+ @overload
+ def _inner_post(
+ self, request_parameters: RequestParameters, *, stream: bool = False
+ ) -> Union[bytes, Iterable[bytes]]: ...

- # Set Accept header if relevant
- headers = self.headers.copy()
- if task in TASKS_EXPECTING_IMAGES and "Accept" not in headers:
- headers["Accept"] = "image/png"
+ def _inner_post(
+ self, request_parameters: RequestParameters, *, stream: bool = False
+ ) -> Union[bytes, Iterable[bytes]]:
+ """Make a request to the inference server."""
+ # TODO: this should be handled in provider helpers directly
+ if request_parameters.task in TASKS_EXPECTING_IMAGES and "Accept" not in request_parameters.headers:
+ request_parameters.headers["Accept"] = "image/png"

  t0 = time.time()
  timeout = self.timeout
  while True:
- with _open_as_binary(data) as data_as_binary:
+ with _open_as_binary(request_parameters.data) as data_as_binary:
  try:
  response = get_session().post(
- url,
- json=json,
+ request_parameters.url,
+ json=request_parameters.json,
  data=data_as_binary,
- headers=headers,
+ headers=request_parameters.headers,
  cookies=self.cookies,
  timeout=self.timeout,
  stream=stream,
@@ -300,21 +321,21 @@ class InferenceClient:
  )
  except TimeoutError as error:
  # Convert any `TimeoutError` to a `InferenceTimeoutError`
- raise InferenceTimeoutError(f"Inference call timed out: {url}") from error # type: ignore
+ raise InferenceTimeoutError(f"Inference call timed out: {request_parameters.url}") from error # type: ignore

  try:
  hf_raise_for_status(response)
  return response.iter_lines() if stream else response.content
  except HTTPError as error:
- if error.response.status_code == 422 and task is not None:
+ if error.response.status_code == 422 and request_parameters.task != "unknown":
  error.args = (
- f"{error.args[0]}\nMake sure '{task}' task is supported by the model.",
+ f"{error.args[0]}\nMake sure '{request_parameters.task}' task is supported by the model.",
  ) + error.args[1:]
  if error.response.status_code == 503:
  # If Model is unavailable, either raise a TimeoutError...
  if timeout is not None and time.time() - t0 > timeout:
  raise InferenceTimeoutError(
- f"Model not loaded on the server: {url}. Please retry with a higher timeout (current:"
+ f"Model not loaded on the server: {request_parameters.url}. Please retry with a higher timeout (current:"
  f" {self.timeout}).",
  request=error.request,
  response=error.response,
@@ -322,8 +343,10 @@ class InferenceClient:
  # ...or wait 1s and retry
  logger.info(f"Waiting for model to be loaded on the server: {error}")
  time.sleep(1)
- if "X-wait-for-model" not in headers and url.startswith(INFERENCE_ENDPOINT):
- headers["X-wait-for-model"] = "1"
+ if "X-wait-for-model" not in request_parameters.headers and request_parameters.url.startswith(
+ INFERENCE_ENDPOINT
+ ):
+ request_parameters.headers["X-wait-for-model"] = "1"
  if timeout is not None:
  timeout = max(self.timeout - (time.time() - t0), 1) # type: ignore
  continue
@@ -374,9 +397,15 @@ class InferenceClient:
  ]
  ```
  """
- parameters = {"function_to_apply": function_to_apply, "top_k": top_k}
- payload = _prepare_payload(audio, parameters=parameters, expect_binary=True)
- response = self.post(**payload, model=model, task="audio-classification")
+ provider_helper = get_provider_helper(self.provider, task="audio-classification")
+ request_parameters = provider_helper.prepare_request(
+ inputs=audio,
+ parameters={"function_to_apply": function_to_apply, "top_k": top_k},
+ headers=self.headers,
+ model=model or self.model,
+ api_key=self.token,
+ )
+ response = self._inner_post(request_parameters)
  return AudioClassificationOutputElement.parse_obj_as_list(response)

  def audio_to_audio(
@@ -416,7 +445,15 @@ class InferenceClient:
  f.write(item.blob)
  ```
  """
- response = self.post(data=audio, model=model, task="audio-to-audio")
+ provider_helper = get_provider_helper(self.provider, task="audio-to-audio")
+ request_parameters = provider_helper.prepare_request(
+ inputs=audio,
+ parameters={},
+ headers=self.headers,
+ model=model or self.model,
+ api_key=self.token,
+ )
+ response = self._inner_post(request_parameters)
  audio_output = AudioToAudioOutputElement.parse_obj_as_list(response)
  for item in audio_output:
  item.blob = base64.b64decode(item.blob)
@@ -437,7 +474,8 @@ class InferenceClient:
  model (`str`, *optional*):
  The model to use for ASR. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
  Inference Endpoint. If not provided, the default recommended model for ASR will be used.
-
+ parameters (Dict[str, Any], *optional*):
+ Additional parameters to pass to the model.
  Returns:
  [`AutomaticSpeechRecognitionOutput`]: An item containing the transcribed text and optionally the timestamp chunks.

@@ -455,7 +493,15 @@ class InferenceClient:
  "hello world"
  ```
  """
- response = self.post(data=audio, model=model, task="automatic-speech-recognition")
+ provider_helper = get_provider_helper(self.provider, task="automatic-speech-recognition")
+ request_parameters = provider_helper.prepare_request(
+ inputs=audio,
+ parameters={},
+ headers=self.headers,
+ model=model or self.model,
+ api_key=self.token,
+ )
+ response = self._inner_post(request_parameters)
  return AutomaticSpeechRecognitionOutput.parse_obj_as_instance(response)

  @overload
@@ -569,6 +615,10 @@ class InferenceClient:

  </Tip>

+ <Tip>
+ Some parameters might not be supported by some providers.
+ </Tip>
+
  Args:
  messages (List of [`ChatCompletionInputMessage`]):
  Conversation history consisting of roles and content pairs.
@@ -576,25 +626,20 @@ class InferenceClient:
  The model to use for chat-completion. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
  Inference Endpoint. If not provided, the default recommended model for chat-based text-generation will be used.
  See https://huggingface.co/tasks/text-generation for more details.
-
  If `model` is a model ID, it is passed to the server as the `model` parameter. If you want to define a
  custom URL while setting `model` in the request payload, you must set `base_url` when initializing [`InferenceClient`].
  frequency_penalty (`float`, *optional*):
  Penalizes new tokens based on their existing frequency
  in the text so far. Range: [-2.0, 2.0]. Defaults to 0.0.
  logit_bias (`List[float]`, *optional*):
- Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens
- (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,
- the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,
- but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should
- result in a ban or exclusive selection of the relevant token. Defaults to None.
+ Adjusts the likelihood of specific tokens appearing in the generated output.
  logprobs (`bool`, *optional*):
  Whether to return log probabilities of the output tokens or not. If true, returns the log
  probabilities of each output token returned in the content of message.
  max_tokens (`int`, *optional*):
  Maximum number of tokens allowed in the response. Defaults to 100.
  n (`int`, *optional*):
- UNUSED.
+ The number of completions to generate for each prompt.
  presence_penalty (`float`, *optional*):
  Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the
  text so far, increasing the model's likelihood to talk about new topics.
@@ -602,7 +647,7 @@ class InferenceClient:
  Grammar constraints. Can be either a JSONSchema or a regex.
  seed (Optional[`int`], *optional*):
  Seed for reproducible control flow. Defaults to None.
- stop (Optional[`str`], *optional*):
+ stop (`List[str]`, *optional*):
  Up to four strings which trigger the end of the response.
  Defaults to None.
  stream (`bool`, *optional*):
@@ -711,6 +756,32 @@ class InferenceClient:
  print(chunk.choices[0].delta.content)
  ```

+ Example using a third-party provider directly. Usage will be billed on your Together AI account.
+ ```py
+ >>> from huggingface_hub import InferenceClient
+ >>> client = InferenceClient(
+ ... provider="together", # Use Together AI provider
+ ... api_key="<together_api_key>", # Pass your Together API key directly
+ ... )
+ >>> client.chat_completion(
+ ... model="meta-llama/Meta-Llama-3-8B-Instruct",
+ ... messages=[{"role": "user", "content": "What is the capital of France?"}],
+ ... )
+ ```
+
+ Example using a third-party provider through Hugging Face Routing. Usage will be billed on your Hugging Face account.
+ ```py
+ >>> from huggingface_hub import InferenceClient
+ >>> client = InferenceClient(
+ ... provider="sambanova", # Use Sambanova provider
+ ... api_key="hf_...", # Pass your HF token
+ ... )
+ >>> client.chat_completion(
+ ... model="meta-llama/Meta-Llama-3-8B-Instruct",
+ ... messages=[{"role": "user", "content": "What is the capital of France?"}],
+ ... )
+ ```
+
  Example using Image + Text as input:
  ```py
  >>> from huggingface_hub import InferenceClient
@@ -859,68 +930,50 @@ class InferenceClient:
  '{\n\n"activity": "bike ride",\n"animals": ["puppy", "cat", "raccoon"],\n"animals_seen": 3,\n"location": "park"}'
  ```
  """
- model_url = self._resolve_chat_completion_url(model)
-
- # `model` is sent in the payload. Not used by the server but can be useful for debugging/routing.
- # If it's a ID on the Hub => use it. Otherwise, we use a random string.
- model_id = model or self.model or "tgi"
- if model_id.startswith(("http://", "https://")):
- model_id = "tgi" # dummy value
-
- payload = dict(
- model=model_id,
- messages=messages,
- frequency_penalty=frequency_penalty,
- logit_bias=logit_bias,
- logprobs=logprobs,
- max_tokens=max_tokens,
- n=n,
- presence_penalty=presence_penalty,
- response_format=response_format,
- seed=seed,
- stop=stop,
- temperature=temperature,
- tool_choice=tool_choice,
- tool_prompt=tool_prompt,
- tools=tools,
- top_logprobs=top_logprobs,
- top_p=top_p,
- stream=stream,
- stream_options=stream_options,
+ # Get the provider helper
+ provider_helper = get_provider_helper(self.provider, task="conversational")
+
+ # Since `chat_completion(..., model=xxx)` is also a payload parameter for the server, we need to handle 'model' differently.
+ # `self.base_url` and `self.model` takes precedence over 'model' argument for building URL.
+ # `model` takes precedence for payload value.
+ model_id_or_url = self.base_url or self.model or model
+ payload_model = model or self.model
+
+ # Prepare the payload
+ parameters = {
+ "model": payload_model,
+ "frequency_penalty": frequency_penalty,
+ "logit_bias": logit_bias,
+ "logprobs": logprobs,
+ "max_tokens": max_tokens,
+ "n": n,
+ "presence_penalty": presence_penalty,
+ "response_format": response_format,
+ "seed": seed,
+ "stop": stop,
+ "temperature": temperature,
+ "tool_choice": tool_choice,
+ "tool_prompt": tool_prompt,
+ "tools": tools,
+ "top_logprobs": top_logprobs,
+ "top_p": top_p,
+ "stream": stream,
+ "stream_options": stream_options,
+ }
+ request_parameters = provider_helper.prepare_request(
+ inputs=messages,
+ parameters=parameters,
+ headers=self.headers,
+ model=model_id_or_url,
+ api_key=self.token,
  )
- payload = {key: value for key, value in payload.items() if value is not None}
- data = self.post(model=model_url, json=payload, stream=stream)
+ data = self._inner_post(request_parameters, stream=stream)

  if stream:
  return _stream_chat_completion_response(data) # type: ignore[arg-type]

  return ChatCompletionOutput.parse_obj_as_instance(data) # type: ignore[arg-type]

- def _resolve_chat_completion_url(self, model: Optional[str] = None) -> str:
- # Since `chat_completion(..., model=xxx)` is also a payload parameter for the server, we need to handle 'model' differently.
- # `self.base_url` and `self.model` takes precedence over 'model' argument only in `chat_completion`.
- model_id_or_url = self.base_url or self.model or model or self.get_recommended_model("text-generation")
-
- # Resolve URL if it's a model ID
- model_url = (
- model_id_or_url
- if model_id_or_url.startswith(("http://", "https://"))
- else self._resolve_url(model_id_or_url, task="text-generation")
- )
-
- # Strip trailing /
- model_url = model_url.rstrip("/")
-
- # Append /chat/completions if not already present
- if model_url.endswith("/v1"):
- model_url += "/chat/completions"
-
- # Append /v1/chat/completions if not already present
- if not model_url.endswith("/chat/completions"):
- model_url += "/v1/chat/completions"
-
- return model_url
-
  def document_question_answering(
  self,
  image: ContentT,
@@ -987,18 +1040,24 @@ class InferenceClient:
  ```
  """
  inputs: Dict[str, Any] = {"question": question, "image": _b64_encode(image)}
- parameters = {
- "doc_stride": doc_stride,
- "handle_impossible_answer": handle_impossible_answer,
- "lang": lang,
- "max_answer_len": max_answer_len,
- "max_question_len": max_question_len,
- "max_seq_len": max_seq_len,
- "top_k": top_k,
- "word_boxes": word_boxes,
- }
- payload = _prepare_payload(inputs, parameters=parameters)
- response = self.post(**payload, model=model, task="document-question-answering")
+ provider_helper = get_provider_helper(self.provider, task="document-question-answering")
+ request_parameters = provider_helper.prepare_request(
+ inputs=inputs,
+ parameters={
+ "doc_stride": doc_stride,
+ "handle_impossible_answer": handle_impossible_answer,
+ "lang": lang,
+ "max_answer_len": max_answer_len,
+ "max_question_len": max_question_len,
+ "max_seq_len": max_seq_len,
+ "top_k": top_k,
+ "word_boxes": word_boxes,
+ },
+ headers=self.headers,
+ model=model or self.model,
+ api_key=self.token,
+ )
+ response = self._inner_post(request_parameters)
  return DocumentQuestionAnsweringOutputElement.parse_obj_as_list(response)

  def feature_extraction(
@@ -1056,14 +1115,20 @@ class InferenceClient:
  [ 0.28552425, -0.928395 , -1.2077185 , ..., 0.76810825, -2.1069427 , 0.6236161 ]], dtype=float32)
  ```
  """
- parameters = {
- "normalize": normalize,
- "prompt_name": prompt_name,
- "truncate": truncate,
- "truncation_direction": truncation_direction,
- }
- payload = _prepare_payload(text, parameters=parameters)
- response = self.post(**payload, model=model, task="feature-extraction")
+ provider_helper = get_provider_helper(self.provider, task="feature-extraction")
+ request_parameters = provider_helper.prepare_request(
+ inputs=text,
+ parameters={
+ "normalize": normalize,
+ "prompt_name": prompt_name,
+ "truncate": truncate,
+ "truncation_direction": truncation_direction,
+ },
+ headers=self.headers,
+ model=model or self.model,
+ api_key=self.token,
+ )
+ response = self._inner_post(request_parameters)
  np = _import_numpy()
  return np.array(_bytes_to_dict(response), dtype="float32")

@@ -1111,9 +1176,15 @@ class InferenceClient:
  ]
  ```
  """
- parameters = {"targets": targets, "top_k": top_k}
- payload = _prepare_payload(text, parameters=parameters)
- response = self.post(**payload, model=model, task="fill-mask")
+ provider_helper = get_provider_helper(self.provider, task="fill-mask")
+ request_parameters = provider_helper.prepare_request(
+ inputs=text,
+ parameters={"targets": targets, "top_k": top_k},
+ headers=self.headers,
+ model=model or self.model,
+ api_key=self.token,
+ )
+ response = self._inner_post(request_parameters)
  return FillMaskOutputElement.parse_obj_as_list(response)

  def image_classification(
@@ -1154,9 +1225,15 @@ class InferenceClient:
  [ImageClassificationOutputElement(label='Blenheim spaniel', score=0.9779096841812134), ...]
  ```
  """
- parameters = {"function_to_apply": function_to_apply, "top_k": top_k}
- payload = _prepare_payload(image, parameters=parameters, expect_binary=True)
- response = self.post(**payload, model=model, task="image-classification")
+ provider_helper = get_provider_helper(self.provider, task="image-classification")
+ request_parameters = provider_helper.prepare_request(
+ inputs=image,
+ parameters={"function_to_apply": function_to_apply, "top_k": top_k},
+ headers=self.headers,
+ model=model or self.model,
+ api_key=self.token,
+ )
+ response = self._inner_post(request_parameters)
  return ImageClassificationOutputElement.parse_obj_as_list(response)

  def image_segmentation(
@@ -1209,14 +1286,20 @@ class InferenceClient:
  [ImageSegmentationOutputElement(score=0.989008, label='LABEL_184', mask=<PIL.PngImagePlugin.PngImageFile image mode=L size=400x300 at 0x7FDD2B129CC0>), ...]
  ```
  """
- parameters = {
- "mask_threshold": mask_threshold,
- "overlap_mask_area_threshold": overlap_mask_area_threshold,
- "subtask": subtask,
- "threshold": threshold,
- }
- payload = _prepare_payload(image, parameters=parameters, expect_binary=True)
- response = self.post(**payload, model=model, task="image-segmentation")
+ provider_helper = get_provider_helper(self.provider, task="audio-classification")
+ request_parameters = provider_helper.prepare_request(
+ inputs=image,
+ parameters={
+ "mask_threshold": mask_threshold,
+ "overlap_mask_area_threshold": overlap_mask_area_threshold,
+ "subtask": subtask,
+ "threshold": threshold,
+ },
+ headers=self.headers,
+ model=model or self.model,
+ api_key=self.token,
+ )
+ response = self._inner_post(request_parameters)
  output = ImageSegmentationOutputElement.parse_obj_as_list(response)
  for item in output:
  item.mask = _b64_to_image(item.mask) # type: ignore [assignment]
@@ -1227,7 +1310,7 @@ class InferenceClient:
  image: ContentT,
  prompt: Optional[str] = None,
  *,
- negative_prompt: Optional[List[str]] = None,
+ negative_prompt: Optional[str] = None,
  num_inference_steps: Optional[int] = None,
  guidance_scale: Optional[float] = None,
  model: Optional[str] = None,
@@ -1248,8 +1331,8 @@ class InferenceClient:
  The input image for translation. It can be raw bytes, an image file, or a URL to an online image.
  prompt (`str`, *optional*):
  The text prompt to guide the image generation.
- negative_prompt (`List[str]`, *optional*):
- One or several prompt to guide what NOT to include in image generation.
+ negative_prompt (`str`, *optional*):
+ One prompt to guide what NOT to include in image generation.
  num_inference_steps (`int`, *optional*):
  For diffusion models. The number of denoising steps. More denoising steps usually lead to a higher
  quality image at the expense of slower inference.
@@ -1279,16 +1362,22 @@ class InferenceClient:
  >>> image.save("tiger.jpg")
  ```
  """
- parameters = {
- "prompt": prompt,
- "negative_prompt": negative_prompt,
- "target_size": target_size,
- "num_inference_steps": num_inference_steps,
- "guidance_scale": guidance_scale,
- **kwargs,
- }
- payload = _prepare_payload(image, parameters=parameters, expect_binary=True)
- response = self.post(**payload, model=model, task="image-to-image")
+ provider_helper = get_provider_helper(self.provider, task="image-to-image")
+ request_parameters = provider_helper.prepare_request(
+ inputs=image,
+ parameters={
+ "prompt": prompt,
+ "negative_prompt": negative_prompt,
+ "target_size": target_size,
+ "num_inference_steps": num_inference_steps,
+ "guidance_scale": guidance_scale,
+ **kwargs,
+ },
+ headers=self.headers,
+ model=model or self.model,
+ api_key=self.token,
+ )
+ response = self._inner_post(request_parameters)
  return _bytes_to_image(response)

  def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -> ImageToTextOutput:
@@ -1324,93 +1413,18 @@ class InferenceClient:
  'a dog laying on the grass next to a flower pot '
  ```
  """
- response = self.post(data=image, model=model, task="image-to-text")
+ provider_helper = get_provider_helper(self.provider, task="image-to-text")
+ request_parameters = provider_helper.prepare_request(
+ inputs=image,
+ parameters={},
+ headers=self.headers,
+ model=model or self.model,
+ api_key=self.token,
+ )
+ response = self._inner_post(request_parameters)
  output = ImageToTextOutput.parse_obj(response)
  return output[0] if isinstance(output, list) else output

- def list_deployed_models(
- self, frameworks: Union[None, str, Literal["all"], List[str]] = None
- ) -> Dict[str, List[str]]:
- """
- List models deployed on the Serverless Inference API service.
-
- This helper checks deployed models framework by framework. By default, it will check the 4 main frameworks that
- are supported and account for 95% of the hosted models. However, if you want a complete list of models you can
- specify `frameworks="all"` as input. Alternatively, if you know before-hand which framework you are interested
- in, you can also restrict to search to this one (e.g. `frameworks="text-generation-inference"`). The more
- frameworks are checked, the more time it will take.
-
- <Tip warning={true}>
-
- This endpoint method does not return a live list of all models available for the Serverless Inference API service.
- It searches over a cached list of models that were recently available and the list may not be up to date.
- If you want to know the live status of a specific model, use [`~InferenceClient.get_model_status`].
-
- </Tip>
-
- <Tip>
-
- This endpoint method is mostly useful for discoverability. If you already know which model you want to use and want to
- check its availability, you can directly use [`~InferenceClient.get_model_status`].
-
- </Tip>
-
- Args:
- frameworks (`Literal["all"]` or `List[str]` or `str`, *optional*):
- The frameworks to filter on. By default only a subset of the available frameworks are tested. If set to
- "all", all available frameworks will be tested. It is also possible to provide a single framework or a
- custom set of frameworks to check.
-
- Returns:
- `Dict[str, List[str]]`: A dictionary mapping task names to a sorted list of model IDs.
-
- Example:
- ```python
- >>> from huggingface_hub import InferenceClient
- >>> client = InferenceClient()
-
- # Discover zero-shot-classification models currently deployed
- >>> models = client.list_deployed_models()
- >>> models["zero-shot-classification"]
- ['Narsil/deberta-large-mnli-zero-cls', 'facebook/bart-large-mnli', ...]
-
- # List from only 1 framework
- >>> client.list_deployed_models("text-generation-inference")
- {'text-generation': ['bigcode/starcoder', 'meta-llama/Llama-2-70b-chat-hf', ...], ...}
- ```
- """
- # Resolve which frameworks to check
- if frameworks is None:
- frameworks = MAIN_INFERENCE_API_FRAMEWORKS
- elif frameworks == "all":
- frameworks = ALL_INFERENCE_API_FRAMEWORKS
- elif isinstance(frameworks, str):
- frameworks = [frameworks]
- frameworks = list(set(frameworks))
-
- # Fetch them iteratively
- models_by_task: Dict[str, List[str]] = {}
-
- def _unpack_response(framework: str, items: List[Dict]) -> None:
- for model in items:
- if framework == "sentence-transformers":
- # Model running with the `sentence-transformers` framework can work with both tasks even if not
- # branded as such in the API response
- models_by_task.setdefault("feature-extraction", []).append(model["model_id"])
- models_by_task.setdefault("sentence-similarity", []).append(model["model_id"])
- else:
- models_by_task.setdefault(model["task"], []).append(model["model_id"])
-
- for framework in frameworks:
- response = get_session().get(f"{INFERENCE_ENDPOINT}/framework/{framework}", headers=self.headers)
- hf_raise_for_status(response)
- _unpack_response(framework, response.json())
-
- # Sort alphabetically for discoverability and return
- for task, models in models_by_task.items():
- models_by_task[task] = sorted(set(models), key=lambda x: x.lower())
- return models_by_task
-
  def object_detection(
  self, image: ContentT, *, model: Optional[str] = None, threshold: Optional[float] = None
  ) -> List[ObjectDetectionOutputElement]:
@@ -1450,11 +1464,15 @@ class InferenceClient:
  [ObjectDetectionOutputElement(score=0.9486683011054993, label='person', box=ObjectDetectionBoundingBox(xmin=59, ymin=39, xmax=420, ymax=510)), ...]
  ```
  """
- parameters = {
- "threshold": threshold,
- }
- payload = _prepare_payload(image, parameters=parameters, expect_binary=True)
- response = self.post(**payload, model=model, task="object-detection")
+ provider_helper = get_provider_helper(self.provider, task="object-detection")
+ request_parameters = provider_helper.prepare_request(
+ inputs=image,
+ parameters={"threshold": threshold},
+ headers=self.headers,
+ model=model or self.model,
+ api_key=self.token,
+ )
+ response = self._inner_post(request_parameters)
  return ObjectDetectionOutputElement.parse_obj_as_list(response)

  def question_answering(
@@ -1519,22 +1537,24 @@ class InferenceClient:
  QuestionAnsweringOutputElement(answer='Clara', end=16, score=0.9326565265655518, start=11)
  ```
  """
- parameters = {
- "align_to_words": align_to_words,
- "doc_stride": doc_stride,
- "handle_impossible_answer": handle_impossible_answer,
- "max_answer_len": max_answer_len,
- "max_question_len": max_question_len,
- "max_seq_len": max_seq_len,
- "top_k": top_k,
- }
- inputs: Dict[str, Any] = {"question": question, "context": context}
- payload = _prepare_payload(inputs, parameters=parameters)
- response = self.post(
- **payload,
- model=model,
- task="question-answering",
+ provider_helper = get_provider_helper(self.provider, task="question-answering")
+ request_parameters = provider_helper.prepare_request(
+ inputs=None,
+ parameters={
+ "align_to_words": align_to_words,
+ "doc_stride": doc_stride,
+ "handle_impossible_answer": handle_impossible_answer,
+ "max_answer_len": max_answer_len,
+ "max_question_len": max_question_len,
+ "max_seq_len": max_seq_len,
+ "top_k": top_k,
+ },
+ extra_payload={"question": question, "context": context},
+ headers=self.headers,
+ model=model or self.model,
+ api_key=self.token,
  )
+ response = self._inner_post(request_parameters)
  # Parse the response as a single `QuestionAnsweringOutputElement` when top_k is 1 or not provided, or a list of `QuestionAnsweringOutputElement` to ensure backward compatibility.
  output = QuestionAnsweringOutputElement.parse_obj(response)
  return output
@@ -1579,11 +1599,16 @@ class InferenceClient:
  [0.7785726189613342, 0.45876261591911316, 0.2906220555305481]
  ```
  """
- response = self.post(
- json={"inputs": {"source_sentence": sentence, "sentences": other_sentences}},
- model=model,
- task="sentence-similarity",
+ provider_helper = get_provider_helper(self.provider, task="sentence-similarity")
+ request_parameters = provider_helper.prepare_request(
+ inputs=None,
+ parameters={},
+ extra_payload={"source_sentence": sentence, "sentences": other_sentences},
+ headers=self.headers,
+ model=model or self.model,
+ api_key=self.token,
  )
+ response = self._inner_post(request_parameters)
  return _bytes_to_list(response)

  @_deprecate_arguments(
@@ -1645,8 +1670,15 @@ class InferenceClient:
  "generate_parameters": generate_parameters,
  "truncation": truncation,
  }
- payload = _prepare_payload(text, parameters=parameters)
- response = self.post(**payload, model=model, task="summarization")
+ provider_helper = get_provider_helper(self.provider, task="summarization")
+ request_parameters = provider_helper.prepare_request(
+ inputs=text,
+ parameters=parameters,
+ headers=self.headers,
+ model=model or self.model,
+ api_key=self.token,
+ )
+ response = self._inner_post(request_parameters)
  return SummarizationOutput.parse_obj_as_list(response)[0]

  def table_question_answering(
@@ -1699,21 +1731,16 @@ class InferenceClient:
  TableQuestionAnsweringOutputElement(answer='36542', coordinates=[[0, 1]], cells=['36542'], aggregator='AVERAGE')
  ```
  """
- parameters = {
- "padding": padding,
- "sequential": sequential,
- "truncation": truncation,
- }
- inputs = {
- "query": query,
- "table": table,
- }
- payload = _prepare_payload(inputs, parameters=parameters)
- response = self.post(
- **payload,
- model=model,
- task="table-question-answering",
+ provider_helper = get_provider_helper(self.provider, task="table-question-answering")
+ request_parameters = provider_helper.prepare_request(
+ inputs=None,
+ parameters={"model": model, "padding": padding, "sequential": sequential, "truncation": truncation},
+ extra_payload={"query": query, "table": table},
+ headers=self.headers,
+ model=model or self.model,
+ api_key=self.token,
  )
+ response = self._inner_post(request_parameters)
  return TableQuestionAnsweringOutputElement.parse_obj_as_instance(response)

  def tabular_classification(self, table: Dict[str, Any], *, model: Optional[str] = None) -> List[str]:
@@ -1758,11 +1785,16 @@ class InferenceClient:
  ["5", "5", "5"]
  ```
  """
- response = self.post(
- json={"table": table},
- model=model,
- task="tabular-classification",
+ provider_helper = get_provider_helper(self.provider, task="tabular-classification")
+ request_parameters = provider_helper.prepare_request(
+ inputs=None,
+ extra_payload={"table": table},
+ parameters={},
+ headers=self.headers,
+ model=model or self.model,
+ api_key=self.token,
  )
+ response = self._inner_post(request_parameters)
  return _bytes_to_list(response)

  def tabular_regression(self, table: Dict[str, Any], *, model: Optional[str] = None) -> List[float]:
@@ -1802,7 +1834,16 @@ class InferenceClient:
  [110, 120, 130]
  ```
  """
- response = self.post(json={"table": table}, model=model, task="tabular-regression")
+ provider_helper = get_provider_helper(self.provider, task="tabular-regression")
+ request_parameters = provider_helper.prepare_request(
+ inputs=None,
+ parameters={},
+ extra_payload={"table": table},
+ headers=self.headers,
+ model=model or self.model,
+ api_key=self.token,
+ )
+ response = self._inner_post(request_parameters)
  return _bytes_to_list(response)

  def text_classification(
@@ -1848,16 +1889,18 @@ class InferenceClient:
  ]
  ```
  """
- parameters = {
- "function_to_apply": function_to_apply,
- "top_k": top_k,
- }
- payload = _prepare_payload(text, parameters=parameters)
- response = self.post(
- **payload,
- model=model,
- task="text-classification",
+ provider_helper = get_provider_helper(self.provider, task="text-classification")
+ request_parameters = provider_helper.prepare_request(
+ inputs=text,
+ parameters={
+ "function_to_apply": function_to_apply,
+ "top_k": top_k,
+ },
+ headers=self.headers,
+ model=model or self.model,
+ api_key=self.token,
  )
+ response = self._inner_post(request_parameters)
  return TextClassificationOutputElement.parse_obj_as_list(response)[0] # type: ignore [return-value]

  @overload
@@ -2041,15 +2084,6 @@ class InferenceClient:
  """
  Given a prompt, generate the following text.

- API endpoint is supposed to run with the `text-generation-inference` backend (TGI). This backend is the
- go-to solution to run large language models at scale. However, for some smaller models (e.g. "gpt2") the
- default `transformers` + `api-inference` solution is still in use. Both approaches have very similar APIs, but
- not exactly the same. This method is compatible with both approaches but some parameters are only available for
- `text-generation-inference`. If some parameters are ignored, a warning message is triggered but the process
- continues correctly.
-
- To learn more about the TGI project, please refer to https://github.com/huggingface/text-generation-inference.
-
  <Tip>

  If you want to generate a response from chat messages, you should use the [`InferenceClient.chat_completion`] method.
@@ -2272,12 +2306,6 @@ class InferenceClient:
  "typical_p": typical_p,
  "watermark": watermark,
  }
- parameters = {k: v for k, v in parameters.items() if v is not None}
- payload = {
- "inputs": prompt,
- "parameters": parameters,
- "stream": stream,
- }

  # Remove some parameters if not a TGI server
  unsupported_kwargs = _get_unsupported_text_generation_kwargs(model)
@@ -2310,9 +2338,19 @@ class InferenceClient:
  " Please pass `stream=False` as input."
  )

+ provider_helper = get_provider_helper(self.provider, task="text-generation")
+ request_parameters = provider_helper.prepare_request(
+ inputs=prompt,
+ parameters=parameters,
+ extra_payload={"stream": stream},
+ headers=self.headers,
+ model=model or self.model,
+ api_key=self.token,
+ )
+
  # Handle errors separately for more precise error messages
  try:
- bytes_output = self.post(json=payload, model=model, task="text-generation", stream=stream) # type: ignore
+ bytes_output = self._inner_post(request_parameters, stream=stream)
  except HTTPError as e:
  match = MODEL_KWARGS_NOT_USED_REGEX.search(str(e))
  if isinstance(e, BadRequestError) and match:
@@ -2322,7 +2360,7 @@ class InferenceClient:
  prompt=prompt,
  details=details,
  stream=stream,
- model=model,
+ model=model or self.model,
  adapter_id=adapter_id,
  best_of=best_of,
  decoder_input_details=decoder_input_details,
@@ -2360,7 +2398,7 @@ class InferenceClient:
  self,
  prompt: str,
  *,
- negative_prompt: Optional[List[str]] = None,
+ negative_prompt: Optional[str] = None,
  height: Optional[float] = None,
  width: Optional[float] = None,
  num_inference_steps: Optional[int] = None,
@@ -2383,8 +2421,8 @@ class InferenceClient:
  Args:
  prompt (`str`):
  The prompt to generate an image from.
- negative_prompt (`List[str`, *optional*):
- One or several prompt to guide what NOT to include in image generation.
+ negative_prompt (`str`, *optional*):
+ One prompt to guide what NOT to include in image generation.
  height (`float`, *optional*):
  The height in pixels of the image to generate.
  width (`float`, *optional*):
@@ -2430,23 +2468,143 @@ class InferenceClient:
  ... )
  >>> image.save("better_astronaut.png")
  ```
- """
+ Example using a third-party provider directly. Usage will be billed on your fal.ai account.
+ ```py
+ >>> from huggingface_hub import InferenceClient
+ >>> client = InferenceClient(
+ ... provider="fal-ai", # Use fal.ai provider
+ ... api_key="fal-ai-api-key", # Pass your fal.ai API key
+ ... )
+ >>> image = client.text_to_image(
+ ... "A majestic lion in a fantasy forest",
+ ... model="black-forest-labs/FLUX.1-schnell",
+ ... )
+ >>> image.save("lion.png")
+ ```

- parameters = {
- "negative_prompt": negative_prompt,
- "height": height,
- "width": width,
- "num_inference_steps": num_inference_steps,
- "guidance_scale": guidance_scale,
- "scheduler": scheduler,
- "target_size": target_size,
- "seed": seed,
- **kwargs,
- }
- payload = _prepare_payload(prompt, parameters=parameters)
- response = self.post(**payload, model=model, task="text-to-image")
+ Example using a third-party provider through Hugging Face Routing. Usage will be billed on your Hugging Face account.
+ ```py
+ >>> from huggingface_hub import InferenceClient
+ >>> client = InferenceClient(
+ ... provider="replicate", # Use replicate provider
+ ... api_key="hf_...", # Pass your HF token
+ ... )
+ >>> image = client.text_to_image(
+ ... "An astronaut riding a horse on the moon.",
+ ... model="black-forest-labs/FLUX.1-dev",
+ ... )
+ >>> image.save("astronaut.png")
+ ```
+ """
+ provider_helper = get_provider_helper(self.provider, task="text-to-image")
+ request_parameters = provider_helper.prepare_request(
+ inputs=prompt,
+ parameters={
+ "negative_prompt": negative_prompt,
+ "height": height,
+ "width": width,
+ "num_inference_steps": num_inference_steps,
+ "guidance_scale": guidance_scale,
+ "scheduler": scheduler,
+ "target_size": target_size,
+ "seed": seed,
+ **kwargs,
+ },
+ headers=self.headers,
+ model=model or self.model,
+ api_key=self.token,
+ )
+ response = self._inner_post(request_parameters)
+ response = provider_helper.get_response(response)
  return _bytes_to_image(response)

+ def text_to_video(
+ self,
+ prompt: str,
+ *,
+ model: Optional[str] = None,
+ guidance_scale: Optional[float] = None,
+ negative_prompt: Optional[List[str]] = None,
+ num_frames: Optional[float] = None,
+ num_inference_steps: Optional[int] = None,
+ seed: Optional[int] = None,
+ ) -> bytes:
+ """
+ Generate a video based on a given text.
+
+ Args:
+ prompt (`str`):
+ The prompt to generate a video from.
+ model (`str`, *optional*):
+ The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
+ Inference Endpoint. If not provided, the default recommended text-to-video model will be used.
+ Defaults to None.
+ guidance_scale (`float`, *optional*):
+ A higher guidance scale value encourages the model to generate videos closely linked to the text
+ prompt, but values too high may cause saturation and other artifacts.
+ negative_prompt (`List[str]`, *optional*):
+ One or several prompt to guide what NOT to include in video generation.
+ num_frames (`float`, *optional*):
+ The num_frames parameter determines how many video frames are generated.
+ num_inference_steps (`int`, *optional*):
+ The number of denoising steps. More denoising steps usually lead to a higher quality video at the
+ expense of slower inference.
+ seed (`int`, *optional*):
+ Seed for the random number generator.
+
+ Returns:
+ `bytes`: The generated video.
+
+ Example:
+
+ Example using a third-party provider directly. Usage will be billed on your fal.ai account.
+ ```py
+ >>> from huggingface_hub import InferenceClient
+ >>> client = InferenceClient(
+ ... provider="fal-ai", # Using fal.ai provider
+ ... api_key="fal-ai-api-key", # Pass your fal.ai API key
+ ... )
+ >>> video = client.text_to_video(
+ ... "A majestic lion running in a fantasy forest",
+ ... model="tencent/HunyuanVideo",
+ ... )
+ >>> with open("lion.mp4", "wb") as file:
+ ... file.write(video)
+ ```
+
+ Example using a third-party provider through Hugging Face Routing. Usage will be billed on your Hugging Face account.
+ ```py
+ >>> from huggingface_hub import InferenceClient
+ >>> client = InferenceClient(
+ ... provider="replicate", # Using replicate provider
+ ... api_key="hf_...", # Pass your HF token
+ ... )
+ >>> video = client.text_to_video(
+ ... "A cat running in a park",
+ ... model="genmo/mochi-1-preview",
+ ... )
+ >>> with open("cat.mp4", "wb") as file:
+ ... file.write(video)
+ ```
+ """
+ provider_helper = get_provider_helper(self.provider, task="text-to-video")
+ request_parameters = provider_helper.prepare_request(
+ inputs=prompt,
+ parameters={
+ "guidance_scale": guidance_scale,
+ "negative_prompt": negative_prompt,
+ "num_frames": num_frames,
+ "num_inference_steps": num_inference_steps,
+ "seed": seed,
+ },
+ headers=self.headers,
+ model=model or self.model,
+ api_key=self.token,
+ )
+ response = self._inner_post(request_parameters)
+ response = provider_helper.get_response(response)
+ return response
+
  def text_to_speech(
  self,
  text: str,
@@ -2544,27 +2702,62 @@ class InferenceClient:
2544
2702
  >>> audio = client.text_to_speech("Hello world")
2545
2703
  >>> Path("hello_world.flac").write_bytes(audio)
2546
2704
  ```
2705
+
2706
+ Example using a third-party provider directly. Usage will be billed on your Replicate account.
2707
+ ```py
2708
+ >>> from huggingface_hub import InferenceClient
2709
+ >>> client = InferenceClient(
2710
+ ... provider="replicate",
2711
+ ... api_key="your-replicate-api-key", # Pass your Replicate API key directly
2712
+ ... )
2713
+ >>> audio = client.text_to_speech(
2714
+ ... text="Hello world",
2715
+ ... model="OuteAI/OuteTTS-0.3-500M",
2716
+ ... )
2717
+ >>> Path("hello_world.flac").write_bytes(audio)
2718
+ ```
2719
+
2720
+ Example using a third-party provider through Hugging Face Routing. Usage will be billed on your Hugging Face account.
2721
+ ```py
2722
+ >>> from huggingface_hub import InferenceClient
2723
+ >>> client = InferenceClient(
2724
+ ... provider="replicate",
2725
+ ... api_key="hf_...", # Pass your HF token
2726
+ ... )
2727
+ >>> audio =client.text_to_speech(
2728
+ ... text="Hello world",
2729
+ ... model="OuteAI/OuteTTS-0.3-500M",
2730
+ ... )
2731
+ >>> Path("hello_world.flac").write_bytes(audio)
2732
+ ```
2547
2733
  """
2548
- parameters = {
2549
- "do_sample": do_sample,
2550
- "early_stopping": early_stopping,
2551
- "epsilon_cutoff": epsilon_cutoff,
2552
- "eta_cutoff": eta_cutoff,
2553
- "max_length": max_length,
2554
- "max_new_tokens": max_new_tokens,
2555
- "min_length": min_length,
2556
- "min_new_tokens": min_new_tokens,
2557
- "num_beam_groups": num_beam_groups,
2558
- "num_beams": num_beams,
2559
- "penalty_alpha": penalty_alpha,
2560
- "temperature": temperature,
2561
- "top_k": top_k,
2562
- "top_p": top_p,
2563
- "typical_p": typical_p,
2564
- "use_cache": use_cache,
2565
- }
2566
- payload = _prepare_payload(text, parameters=parameters)
2567
- response = self.post(**payload, model=model, task="text-to-speech")
2734
+ provider_helper = get_provider_helper(self.provider, task="text-to-speech")
2735
+ request_parameters = provider_helper.prepare_request(
2736
+ inputs=text,
2737
+ parameters={
2738
+ "do_sample": do_sample,
2739
+ "early_stopping": early_stopping,
2740
+ "epsilon_cutoff": epsilon_cutoff,
2741
+ "eta_cutoff": eta_cutoff,
2742
+ "max_length": max_length,
2743
+ "max_new_tokens": max_new_tokens,
2744
+ "min_length": min_length,
2745
+ "min_new_tokens": min_new_tokens,
2746
+ "num_beam_groups": num_beam_groups,
2747
+ "num_beams": num_beams,
2748
+ "penalty_alpha": penalty_alpha,
2749
+ "temperature": temperature,
2750
+ "top_k": top_k,
2751
+ "top_p": top_p,
2752
+ "typical_p": typical_p,
2753
+ "use_cache": use_cache,
2754
+ },
2755
+ headers=self.headers,
2756
+ model=model or self.model,
2757
+ api_key=self.token,
2758
+ )
2759
+ response = self._inner_post(request_parameters)
2760
+ response = provider_helper.get_response(response)
2568
2761
  return response
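The refactor above changes only the transport: the transformers-style generation parameters are now forwarded through the provider helper's `prepare_request` instead of a raw `self.post` call, so existing `text_to_speech` invocations keep working. A minimal sketch against the default `hf-inference` provider; the model id is illustrative and not every backend honours every generation parameter:

```py
from huggingface_hub import InferenceClient

client = InferenceClient()  # defaults to the "hf-inference" provider
audio = client.text_to_speech(
    "Hello world",
    model="espnet/kan-bayashi_ljspeech_vits",  # illustrative TTS model id
    do_sample=False,  # forwarded via the parameters dict shown above
)
with open("hello_world.flac", "wb") as f:
    f.write(audio)
```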
2569
2762
 
2570
2763
  def token_classification(
@@ -2626,18 +2819,19 @@ class InferenceClient:
2626
2819
  ]
2627
2820
  ```
2628
2821
  """
2629
-
2630
- parameters = {
2631
- "aggregation_strategy": aggregation_strategy,
2632
- "ignore_labels": ignore_labels,
2633
- "stride": stride,
2634
- }
2635
- payload = _prepare_payload(text, parameters=parameters)
2636
- response = self.post(
2637
- **payload,
2638
- model=model,
2639
- task="token-classification",
2822
+ provider_helper = get_provider_helper(self.provider, task="token-classification")
2823
+ request_parameters = provider_helper.prepare_request(
2824
+ inputs=text,
2825
+ parameters={
2826
+ "aggregation_strategy": aggregation_strategy,
2827
+ "ignore_labels": ignore_labels,
2828
+ "stride": stride,
2829
+ },
2830
+ headers=self.headers,
2831
+ model=model or self.model,
2832
+ api_key=self.token,
2640
2833
  )
2834
+ response = self._inner_post(request_parameters)
2641
2835
  return TokenClassificationOutputElement.parse_obj_as_list(response)
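Same pattern for `token_classification`: the public signature is untouched and `aggregation_strategy`, `ignore_labels` and `stride` are simply routed through the provider helper. A minimal sketch; the model id is illustrative and the printed fields assume the usual `TokenClassificationOutputElement` attributes:

```py
from huggingface_hub import InferenceClient

client = InferenceClient()
elements = client.token_classification(
    "My name is Sarah Jessica Parker but you can call me Jessica",
    model="dslim/bert-base-NER",    # illustrative NER model id
    aggregation_strategy="simple",  # forwarded via the parameters dict shown above
)
for element in elements:
    print(element.entity_group, element.word, element.score)
```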
2642
2836
 
2643
2837
  def translation(
@@ -2710,15 +2904,22 @@ class InferenceClient:
2710
2904
 
2711
2905
  if src_lang is None and tgt_lang is not None:
2712
2906
  raise ValueError("You cannot specify `tgt_lang` without specifying `src_lang`.")
2713
- parameters = {
2714
- "src_lang": src_lang,
2715
- "tgt_lang": tgt_lang,
2716
- "clean_up_tokenization_spaces": clean_up_tokenization_spaces,
2717
- "truncation": truncation,
2718
- "generate_parameters": generate_parameters,
2719
- }
2720
- payload = _prepare_payload(text, parameters=parameters)
2721
- response = self.post(**payload, model=model, task="translation")
2907
+
2908
+ provider_helper = get_provider_helper(self.provider, task="translation")
2909
+ request_parameters = provider_helper.prepare_request(
2910
+ inputs=text,
2911
+ parameters={
2912
+ "src_lang": src_lang,
2913
+ "tgt_lang": tgt_lang,
2914
+ "clean_up_tokenization_spaces": clean_up_tokenization_spaces,
2915
+ "truncation": truncation,
2916
+ "generate_parameters": generate_parameters,
2917
+ },
2918
+ headers=self.headers,
2919
+ model=model or self.model,
2920
+ api_key=self.token,
2921
+ )
2922
+ response = self._inner_post(request_parameters)
2722
2923
  return TranslationOutput.parse_obj_as_list(response)[0]
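`translation` follows the same flow, and the pre-existing guard kept above still rejects `tgt_lang` without `src_lang` before any request is sent. A minimal sketch; both model ids and the NLLB language codes are illustrative:

```py
from huggingface_hub import InferenceClient

client = InferenceClient()

# Model implies the language pair (illustrative model id).
result = client.translation("Hello world", model="Helsinki-NLP/opus-mt-en-fr")
print(result.translation_text)

# Multilingual model: per the check above, `tgt_lang` requires `src_lang`.
result = client.translation(
    "Hello world",
    model="facebook/nllb-200-distilled-600M",
    src_lang="eng_Latn",
    tgt_lang="fra_Latn",
)
print(result.translation_text)
```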
2723
2924
 
2724
2925
  def visual_question_answering(
@@ -2767,10 +2968,16 @@ class InferenceClient:
2767
2968
  ]
2768
2969
  ```
2769
2970
  """
2770
- payload: Dict[str, Any] = {"question": question, "image": _b64_encode(image)}
2771
- if top_k is not None:
2772
- payload.setdefault("parameters", {})["top_k"] = top_k
2773
- response = self.post(json=payload, model=model, task="visual-question-answering")
2971
+ provider_helper = get_provider_helper(self.provider, task="visual-question-answering")
2972
+ request_parameters = provider_helper.prepare_request(
2973
+ inputs=image,
2974
+ parameters={"top_k": top_k},
2975
+ headers=self.headers,
2976
+ model=model or self.model,
2977
+ api_key=self.token,
2978
+ extra_payload={"question": question, "image": _b64_encode(image)},
2979
+ )
2980
+ response = self._inner_post(request_parameters)
2774
2981
  return VisualQuestionAnsweringOutputElement.parse_obj_as_list(response)
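For `visual_question_answering`, the image is passed both as `inputs` and, base64-encoded, inside `extra_payload` next to the question, as the call above shows. A minimal sketch; the file path and model id are illustrative (raw bytes are also accepted for the image):

```py
from huggingface_hub import InferenceClient

client = InferenceClient()
answers = client.visual_question_answering(
    image="cat.jpg",  # illustrative local path; raw bytes work as well
    question="What animal is in the picture?",
    model="dandelin/vilt-b32-finetuned-vqa",  # illustrative VQA model id
)
for answer in answers:
    print(answer.answer, answer.score)
```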
2775
2982
 
2776
2983
  @_deprecate_arguments(
@@ -2876,17 +3083,20 @@ class InferenceClient:
2876
3083
  candidate_labels = labels
2877
3084
  elif candidate_labels is None:
2878
3085
  raise ValueError("Must specify `candidate_labels`")
2879
- parameters = {
2880
- "candidate_labels": candidate_labels,
2881
- "multi_label": multi_label,
2882
- "hypothesis_template": hypothesis_template,
2883
- }
2884
- payload = _prepare_payload(text, parameters=parameters)
2885
- response = self.post(
2886
- **payload,
2887
- task="zero-shot-classification",
2888
- model=model,
3086
+
3087
+ provider_helper = get_provider_helper(self.provider, task="zero-shot-classification")
3088
+ request_parameters = provider_helper.prepare_request(
3089
+ inputs=text,
3090
+ parameters={
3091
+ "candidate_labels": candidate_labels,
3092
+ "multi_label": multi_label,
3093
+ "hypothesis_template": hypothesis_template,
3094
+ },
3095
+ headers=self.headers,
3096
+ model=model or self.model,
3097
+ api_key=self.token,
2889
3098
  )
3099
+ response = self._inner_post(request_parameters)
2890
3100
  output = _bytes_to_dict(response)
2891
3101
  return [
2892
3102
  ZeroShotClassificationOutputElement.parse_obj_as_instance({"label": label, "score": score})
@@ -2959,71 +3169,108 @@ class InferenceClient:
2959
3169
  # Raise ValueError if input is less than 2 labels
2960
3170
  if len(candidate_labels) < 2:
2961
3171
  raise ValueError("You must specify at least 2 classes to compare.")
2962
- parameters = {
2963
- "candidate_labels": candidate_labels,
2964
- "hypothesis_template": hypothesis_template,
2965
- }
2966
- payload = _prepare_payload(image, parameters=parameters, expect_binary=True)
2967
- response = self.post(
2968
- **payload,
2969
- model=model,
2970
- task="zero-shot-image-classification",
3172
+
3173
+ provider_helper = get_provider_helper(self.provider, task="zero-shot-image-classification")
3174
+ request_parameters = provider_helper.prepare_request(
3175
+ inputs=image,
3176
+ parameters={
3177
+ "candidate_labels": candidate_labels,
3178
+ "hypothesis_template": hypothesis_template,
3179
+ },
3180
+ headers=self.headers,
3181
+ model=model or self.model,
3182
+ api_key=self.token,
2971
3183
  )
3184
+ response = self._inner_post(request_parameters)
2972
3185
  return ZeroShotImageClassificationOutputElement.parse_obj_as_list(response)
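Both zero-shot variants now go through the same provider-helper flow, with `candidate_labels` as the canonical keyword (the deprecated `labels` alias is still mapped onto it above) and a minimum of two labels enforced for images. A minimal sketch; model ids and the image path are illustrative:

```py
from huggingface_hub import InferenceClient

client = InferenceClient()

# Zero-shot text classification (illustrative model id)
text_results = client.zero_shot_classification(
    "I really enjoyed this movie, the acting was superb.",
    candidate_labels=["positive", "negative", "neutral"],
    model="facebook/bart-large-mnli",
)

# Zero-shot image classification: at least 2 labels, per the check above
image_results = client.zero_shot_image_classification(
    "cat.jpg",  # illustrative local path
    candidate_labels=["cat", "dog", "fish"],
    model="openai/clip-vit-base-patch32",  # illustrative model id
)
print(text_results[0].label, image_results[0].label)
```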
2973
3186
 
2974
- def _resolve_url(self, model: Optional[str] = None, task: Optional[str] = None) -> str:
2975
- model = model or self.model or self.base_url
3187
+ def list_deployed_models(
3188
+ self, frameworks: Union[None, str, Literal["all"], List[str]] = None
3189
+ ) -> Dict[str, List[str]]:
3190
+ """
3191
+ List models deployed on the Serverless Inference API service.
2976
3192
 
2977
- # If model is already a URL, ignore `task` and return directly
2978
- if model is not None and (model.startswith("http://") or model.startswith("https://")):
2979
- return model
3193
+ This helper checks deployed models framework by framework. By default, it will check the 4 main frameworks that
3194
+ are supported and account for 95% of the hosted models. However, if you want a complete list of models, you can
3195
+ specify `frameworks="all"` as input. Alternatively, if you know before-hand which framework you are interested
3196
+ in, you can also restrict the search to this one (e.g. `frameworks="text-generation-inference"`). The more
3197
+ frameworks are checked, the more time it will take.
2980
3198
 
2981
- # # If no model but task is set => fetch the recommended one for this task
2982
- if model is None:
2983
- if task is None:
2984
- raise ValueError(
2985
- "You must specify at least a model (repo_id or URL) or a task, either when instantiating"
2986
- " `InferenceClient` or when making a request."
2987
- )
2988
- model = self.get_recommended_model(task)
2989
- logger.info(
2990
- f"Using recommended model {model} for task {task}. Note that it is"
2991
- f" encouraged to explicitly set `model='{model}'` as the recommended"
2992
- " models list might get updated without prior notice."
2993
- )
3199
+ <Tip warning={true}>
2994
3200
 
2995
- # Compute InferenceAPI url
2996
- return (
2997
- # Feature-extraction and sentence-similarity are the only cases where we handle models with several tasks.
2998
- f"{INFERENCE_ENDPOINT}/pipeline/{task}/{model}"
2999
- if task in ("feature-extraction", "sentence-similarity")
3000
- # Otherwise, we use the default endpoint
3001
- else f"{INFERENCE_ENDPOINT}/models/{model}"
3002
- )
3201
+ This endpoint method does not return a live list of all models available for the Serverless Inference API service.
3202
+ It searches over a cached list of models that were recently available and the list may not be up to date.
3203
+ If you want to know the live status of a specific model, use [`~InferenceClient.get_model_status`].
3003
3204
 
3004
- @staticmethod
3005
- def get_recommended_model(task: str) -> str:
3006
- """
3007
- Get the model Hugging Face recommends for the input task.
3205
+ </Tip>
3206
+
3207
+ <Tip>
3208
+
3209
+ This endpoint method is mostly useful for discoverability. If you already know which model you want to use and want to
3210
+ check its availability, you can directly use [`~InferenceClient.get_model_status`].
3211
+
3212
+ </Tip>
3008
3213
 
3009
3214
  Args:
3010
- task (`str`):
3011
- The Hugging Face task to get which model Hugging Face recommends.
3012
- All available tasks can be found [here](https://huggingface.co/tasks).
3215
+ frameworks (`Literal["all"]` or `List[str]` or `str`, *optional*):
3216
+ The frameworks to filter on. By default only a subset of the available frameworks are tested. If set to
3217
+ "all", all available frameworks will be tested. It is also possible to provide a single framework or a
3218
+ custom set of frameworks to check.
3013
3219
 
3014
3220
  Returns:
3015
- `str`: Name of the model recommended for the input task.
3221
+ `Dict[str, List[str]]`: A dictionary mapping task names to a sorted list of model IDs.
3016
3222
 
3017
- Raises:
3018
- `ValueError`: If Hugging Face has no recommendation for the input task.
3223
+ Example:
3224
+ ```python
3225
+ >>> from huggingface_hub import InferenceClient
3226
+ >>> client = InferenceClient()
3227
+
3228
+ # Discover zero-shot-classification models currently deployed
3229
+ >>> models = client.list_deployed_models()
3230
+ >>> models["zero-shot-classification"]
3231
+ ['Narsil/deberta-large-mnli-zero-cls', 'facebook/bart-large-mnli', ...]
3232
+
3233
+ # List from only 1 framework
3234
+ >>> client.list_deployed_models("text-generation-inference")
3235
+ {'text-generation': ['bigcode/starcoder', 'meta-llama/Llama-2-70b-chat-hf', ...], ...}
3236
+ ```
3019
3237
  """
3020
- model = _fetch_recommended_models().get(task)
3021
- if model is None:
3022
- raise ValueError(
3023
- f"Task {task} has no recommended model. Please specify a model"
3024
- " explicitly. Visit https://huggingface.co/tasks for more info."
3238
+ if self.provider != "hf-inference":
3239
+ raise ValueError(f"Listing deployed models is not supported on '{self.provider}'.")
3240
+
3241
+ # Resolve which frameworks to check
3242
+ if frameworks is None:
3243
+ frameworks = MAIN_INFERENCE_API_FRAMEWORKS
3244
+ elif frameworks == "all":
3245
+ frameworks = ALL_INFERENCE_API_FRAMEWORKS
3246
+ elif isinstance(frameworks, str):
3247
+ frameworks = [frameworks]
3248
+ frameworks = list(set(frameworks))
3249
+
3250
+ # Fetch them iteratively
3251
+ models_by_task: Dict[str, List[str]] = {}
3252
+
3253
+ def _unpack_response(framework: str, items: List[Dict]) -> None:
3254
+ for model in items:
3255
+ if framework == "sentence-transformers":
3256
+ # Models running with the `sentence-transformers` framework can work with both tasks even if not
3257
+ # branded as such in the API response
3258
+ models_by_task.setdefault("feature-extraction", []).append(model["model_id"])
3259
+ models_by_task.setdefault("sentence-similarity", []).append(model["model_id"])
3260
+ else:
3261
+ models_by_task.setdefault(model["task"], []).append(model["model_id"])
3262
+
3263
+ for framework in frameworks:
3264
+ response = get_session().get(
3265
+ f"{INFERENCE_ENDPOINT}/framework/{framework}", headers=build_hf_headers(token=self.token)
3025
3266
  )
3026
- return model
3267
+ hf_raise_for_status(response)
3268
+ _unpack_response(framework, response.json())
3269
+
3270
+ # Sort alphabetically for discoverability and return
3271
+ for task, models in models_by_task.items():
3272
+ models_by_task[task] = sorted(set(models), key=lambda x: x.lower())
3273
+ return models_by_task
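A minimal sketch of the discovery flow implemented above, including the new provider guard (only the default `hf-inference` provider exposes the framework listing endpoint); the framework name matches the docstring example:

```py
from huggingface_hub import InferenceClient

client = InferenceClient()  # provider defaults to "hf-inference"

# Restrict the search to a single framework (cheaper than the default set)
deployed = client.list_deployed_models("text-generation-inference")
for task, models in deployed.items():
    print(task, len(models))

# With a third-party provider the helper raises, per the guard above
try:
    InferenceClient(provider="replicate", api_key="hf_...").list_deployed_models()
except ValueError as err:
    print(err)  # Listing deployed models is not supported on 'replicate'.
```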
3027
3274
 
3028
3275
  def get_endpoint_info(self, *, model: Optional[str] = None) -> Dict[str, Any]:
3029
3276
  """
@@ -3068,6 +3315,9 @@ class InferenceClient:
3068
3315
  }
3069
3316
  ```
3070
3317
  """
3318
+ if self.provider != "hf-inference":
3319
+ raise ValueError(f"Getting endpoint info is not supported on '{self.provider}'.")
3320
+
3071
3321
  model = model or self.model
3072
3322
  if model is None:
3073
3323
  raise ValueError("Model id not provided.")
@@ -3076,7 +3326,7 @@ class InferenceClient:
3076
3326
  else:
3077
3327
  url = f"{INFERENCE_ENDPOINT}/models/{model}/info"
3078
3328
 
3079
- response = get_session().get(url, headers=self.headers)
3329
+ response = get_session().get(url, headers=build_hf_headers(token=self.token))
3080
3330
  hf_raise_for_status(response)
3081
3331
  return response.json()
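`get_endpoint_info` now refuses to run on third-party providers and builds its headers with `build_hf_headers` instead of reusing `self.headers`. A minimal sketch; the model id is illustrative and the returned payload is whatever the `/info` route shown above serves:

```py
from huggingface_hub import InferenceClient

client = InferenceClient()
info = client.get_endpoint_info(model="meta-llama/Llama-3.1-8B-Instruct")  # illustrative model id
print(info)  # JSON dict fetched from .../models/<model>/info, per the URL built above
```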
3082
3332
 
@@ -3102,6 +3352,9 @@ class InferenceClient:
3102
3352
  True
3103
3353
  ```
3104
3354
  """
3355
+ if self.provider != "hf-inference":
3356
+ raise ValueError(f"Health check is not supported on '{self.provider}'.")
3357
+
3105
3358
  model = model or self.model
3106
3359
  if model is None:
3107
3360
  raise ValueError("Model id not provided.")
@@ -3111,7 +3364,7 @@ class InferenceClient:
3111
3364
  )
3112
3365
  url = model.rstrip("/") + "/health"
3113
3366
 
3114
- response = get_session().get(url, headers=self.headers)
3367
+ response = get_session().get(url, headers=build_hf_headers(token=self.token))
3115
3368
  return response.status_code == 200
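`health_check` gets the same provider guard and header change; it targets an Inference Endpoint URL directly (`.../health`), while serverless model ids should go through `get_model_status`, shown next in this diff. A minimal sketch; the endpoint URL and model id are illustrative:

```py
from huggingface_hub import InferenceClient

client = InferenceClient()  # these helpers require the "hf-inference" provider

# Dedicated Inference Endpoint: hit its /health route
is_up = client.health_check(model="https://my-endpoint.endpoints.huggingface.cloud")
print(is_up)

# Serverless model id: use the status helper instead
status = client.get_model_status(model="meta-llama/Llama-3.1-8B-Instruct")
print(status.loaded, status.state, status.compute_type, status.framework)
```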
3116
3369
 
3117
3370
  def get_model_status(self, model: Optional[str] = None) -> ModelStatus:
@@ -3144,6 +3397,9 @@ class InferenceClient:
3144
3397
  ModelStatus(loaded=True, state='Loaded', compute_type='gpu', framework='text-generation-inference')
3145
3398
  ```
3146
3399
  """
3400
+ if self.provider != "hf-inference":
3401
+ raise ValueError(f"Getting model status is not supported on '{self.provider}'.")
3402
+
3147
3403
  model = model or self.model
3148
3404
  if model is None:
3149
3405
  raise ValueError("Model id not provided.")
@@ -3151,7 +3407,7 @@ class InferenceClient:
3151
3407
  raise NotImplementedError("Model status is only available for Inference API endpoints.")
3152
3408
  url = f"{INFERENCE_ENDPOINT}/status/{model}"
3153
3409
 
3154
- response = get_session().get(url, headers=self.headers)
3410
+ response = get_session().get(url, headers=build_hf_headers(token=self.token))
3155
3411
  hf_raise_for_status(response)
3156
3412
  response_data = response.json()
3157
3413