huggingface-hub 0.23.5__py3-none-any.whl → 0.24.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of huggingface-hub might be problematic. Click here for more details.

Files changed (42)
  1. huggingface_hub/__init__.py +47 -15
  2. huggingface_hub/_commit_api.py +38 -8
  3. huggingface_hub/_inference_endpoints.py +11 -4
  4. huggingface_hub/_local_folder.py +22 -13
  5. huggingface_hub/_snapshot_download.py +12 -7
  6. huggingface_hub/_webhooks_server.py +3 -1
  7. huggingface_hub/commands/huggingface_cli.py +4 -3
  8. huggingface_hub/commands/repo_files.py +128 -0
  9. huggingface_hub/constants.py +12 -0
  10. huggingface_hub/file_download.py +127 -91
  11. huggingface_hub/hf_api.py +976 -341
  12. huggingface_hub/hf_file_system.py +30 -3
  13. huggingface_hub/inference/_client.py +408 -147
  14. huggingface_hub/inference/_common.py +25 -63
  15. huggingface_hub/inference/_generated/_async_client.py +425 -153
  16. huggingface_hub/inference/_generated/types/__init__.py +4 -1
  17. huggingface_hub/inference/_generated/types/chat_completion.py +41 -21
  18. huggingface_hub/inference/_generated/types/feature_extraction.py +23 -5
  19. huggingface_hub/inference/_generated/types/text_generation.py +29 -0
  20. huggingface_hub/lfs.py +11 -6
  21. huggingface_hub/repocard_data.py +3 -3
  22. huggingface_hub/repository.py +6 -6
  23. huggingface_hub/serialization/__init__.py +8 -3
  24. huggingface_hub/serialization/_base.py +13 -16
  25. huggingface_hub/serialization/_tensorflow.py +4 -3
  26. huggingface_hub/serialization/_torch.py +399 -22
  27. huggingface_hub/utils/__init__.py +0 -1
  28. huggingface_hub/utils/_errors.py +1 -1
  29. huggingface_hub/utils/_fixes.py +14 -3
  30. huggingface_hub/utils/_paths.py +17 -6
  31. huggingface_hub/utils/_subprocess.py +0 -1
  32. huggingface_hub/utils/_telemetry.py +9 -1
  33. huggingface_hub/utils/endpoint_helpers.py +2 -186
  34. huggingface_hub/utils/sha.py +36 -1
  35. huggingface_hub/utils/tqdm.py +0 -1
  36. {huggingface_hub-0.23.5.dist-info → huggingface_hub-0.24.1.dist-info}/METADATA +12 -9
  37. {huggingface_hub-0.23.5.dist-info → huggingface_hub-0.24.1.dist-info}/RECORD +41 -41
  38. huggingface_hub/serialization/_numpy.py +0 -68
  39. {huggingface_hub-0.23.5.dist-info → huggingface_hub-0.24.1.dist-info}/LICENSE +0 -0
  40. {huggingface_hub-0.23.5.dist-info → huggingface_hub-0.24.1.dist-info}/WHEEL +0 -0
  41. {huggingface_hub-0.23.5.dist-info → huggingface_hub-0.24.1.dist-info}/entry_points.txt +0 -0
  42. {huggingface_hub-0.23.5.dist-info → huggingface_hub-0.24.1.dist-info}/top_level.txt +0 -0
@@ -44,7 +44,7 @@ from huggingface_hub.inference._common import (
44
44
  TASKS_EXPECTING_IMAGES,
45
45
  ContentT,
46
46
  ModelStatus,
47
- _async_stream_chat_completion_response_from_bytes,
47
+ _async_stream_chat_completion_response,
48
48
  _async_stream_text_generation_response,
49
49
  _b64_encode,
50
50
  _b64_to_image,
@@ -54,9 +54,7 @@ from huggingface_hub.inference._common import (
54
54
  _fetch_recommended_models,
55
55
  _get_unsupported_text_generation_kwargs,
56
56
  _import_numpy,
57
- _is_chat_completion_server,
58
57
  _open_as_binary,
59
- _set_as_non_chat_completion_server,
60
58
  _set_unsupported_text_generation_kwargs,
61
59
  raise_text_generation_error,
62
60
  )
@@ -64,11 +62,10 @@ from huggingface_hub.inference._generated.types import (
64
62
  AudioClassificationOutputElement,
65
63
  AudioToAudioOutputElement,
66
64
  AutomaticSpeechRecognitionOutput,
65
+ ChatCompletionInputGrammarType,
67
66
  ChatCompletionInputTool,
68
67
  ChatCompletionInputToolTypeClass,
69
68
  ChatCompletionOutput,
70
- ChatCompletionOutputComplete,
71
- ChatCompletionOutputMessage,
72
69
  ChatCompletionStreamOutput,
73
70
  DocumentQuestionAnsweringOutputElement,
74
71
  FillMaskOutputElement,
@@ -89,13 +86,13 @@ from huggingface_hub.inference._generated.types import (
89
86
  ZeroShotClassificationOutputElement,
90
87
  ZeroShotImageClassificationOutputElement,
91
88
  )
92
- from huggingface_hub.inference._generated.types.chat_completion import ChatCompletionInputToolTypeEnum
93
89
  from huggingface_hub.inference._types import (
94
90
  ConversationalOutput, # soon to be removed
95
91
  )
96
92
  from huggingface_hub.utils import (
97
93
  build_hf_headers,
98
94
  )
95
+ from huggingface_hub.utils._deprecation import _deprecate_positional_args
99
96
 
100
97
  from .._common import _async_yield_from, _import_aiohttp
101
98
 
@@ -119,12 +116,16 @@ class AsyncInferenceClient:
119
116
 
120
117
  Args:
121
118
  model (`str`, `optional`):
122
- The model to run inference with. Can be a model id hosted on the Hugging Face Hub, e.g. `bigcode/starcoder`
119
+ The model to run inference with. Can be a model id hosted on the Hugging Face Hub, e.g. `meta-llama/Meta-Llama-3-8B-Instruct`
123
120
  or a URL to a deployed Inference Endpoint. Defaults to None, in which case a recommended model is
124
121
  automatically selected for the task.
122
+ Note: for better compatibility with OpenAI's client, `model` has been aliased as `base_url`. Those 2
123
+ arguments are mutually exclusive and have the exact same behavior.
125
124
  token (`str` or `bool`, *optional*):
126
125
  Hugging Face token. Will default to the locally saved token if not provided.
127
126
  Pass `token=False` if you don't want to send your token to the server.
127
+ Note: for better compatibility with OpenAI's client, `token` has been aliased as `api_key`. Those 2
128
+ arguments are mutually exclusive and have the exact same behavior.
128
129
  timeout (`float`, `optional`):
129
130
  The maximum number of seconds to wait for a response from the server. Loading a new model in Inference
130
131
  API can take up to several minutes. Defaults to None, meaning it will loop until the server is available.
@@ -133,23 +134,52 @@ class AsyncInferenceClient:
133
134
  Values in this dictionary will override the default values.
134
135
  cookies (`Dict[str, str]`, `optional`):
135
136
  Additional cookies to send to the server.
137
+ base_url (`str`, `optional`):
138
+ Base URL to run inference. This is a duplicated argument from `model` to make [`InferenceClient`]
139
+ follow the same pattern as `openai.OpenAI` client. Cannot be used if `model` is set. Defaults to None.
140
+ api_key (`str`, `optional`):
141
+ Token to use for authentication. This is a duplicated argument from `token` to make [`InferenceClient`]
142
+ follow the same pattern as `openai.OpenAI` client. Cannot be used if `token` is set. Defaults to None.
136
143
  """
137
144
 
145
+ @_deprecate_positional_args(version="0.26")
138
146
  def __init__(
139
147
  self,
140
148
  model: Optional[str] = None,
149
+ *,
141
150
  token: Union[str, bool, None] = None,
142
151
  timeout: Optional[float] = None,
143
152
  headers: Optional[Dict[str, str]] = None,
144
153
  cookies: Optional[Dict[str, str]] = None,
154
+ proxies: Optional[Any] = None,
155
+ # OpenAI compatibility
156
+ base_url: Optional[str] = None,
157
+ api_key: Optional[str] = None,
145
158
  ) -> None:
159
+ if model is not None and base_url is not None:
160
+ raise ValueError(
161
+ "Received both `model` and `base_url` arguments. Please provide only one of them."
162
+ " `base_url` is an alias for `model` to make the API compatible with OpenAI's client."
163
+ " It has the exact same behavior as `model`."
164
+ )
165
+ if token is not None and api_key is not None:
166
+ raise ValueError(
167
+ "Received both `token` and `api_key` arguments. Please provide only one of them."
168
+ " `api_key` is an alias for `token` to make the API compatible with OpenAI's client."
169
+ " It has the exact same behavior as `token`."
170
+ )
171
+
146
172
  self.model: Optional[str] = model
147
- self.token: Union[str, bool, None] = token
148
- self.headers = CaseInsensitiveDict(build_hf_headers(token=token)) # contains 'authorization' + 'user-agent'
173
+ self.token: Union[str, bool, None] = token if token is not None else api_key
174
+ self.headers = CaseInsensitiveDict(build_hf_headers(token=self.token)) # 'authorization' + 'user-agent'
149
175
  if headers is not None:
150
176
  self.headers.update(headers)
151
177
  self.cookies = cookies
152
178
  self.timeout = timeout
179
+ self.proxies = proxies
180
+
181
+ # OpenAI compatibility
182
+ self.base_url = base_url
153
183
 
154
184
  def __repr__(self):
155
185
  return f"<InferenceClient(model='{self.model if self.model else ''}', timeout={self.timeout})>"
@@ -250,7 +280,7 @@ class AsyncInferenceClient:
250
280
  )
251
281
 
252
282
  try:
253
- response = await client.post(url, json=json, data=data_as_binary)
283
+ response = await client.post(url, json=json, data=data_as_binary, proxy=self.proxies)
254
284
  response_error_payload = None
255
285
  if response.status != 200:
256
286
  try:
@@ -284,11 +314,16 @@ class AsyncInferenceClient:
284
314
  ) from error
285
315
  # ...or wait 1s and retry
286
316
  logger.info(f"Waiting for model to be loaded on the server: {error}")
317
+ if "X-wait-for-model" not in headers and url.startswith(INFERENCE_ENDPOINT):
318
+ headers["X-wait-for-model"] = "1"
287
319
  time.sleep(1)
288
320
  if timeout is not None:
289
321
  timeout = max(self.timeout - (time.time() - t0), 1) # type: ignore
290
322
  continue
291
323
  raise error
324
+ except Exception:
325
+ await client.close()
326
+ raise
292
327
 
293
328
  async def audio_classification(
294
329
  self,
@@ -427,10 +462,11 @@ class AsyncInferenceClient:
427
462
  max_tokens: Optional[int] = None,
428
463
  n: Optional[int] = None,
429
464
  presence_penalty: Optional[float] = None,
465
+ response_format: Optional[ChatCompletionInputGrammarType] = None,
430
466
  seed: Optional[int] = None,
431
467
  stop: Optional[List[str]] = None,
432
468
  temperature: Optional[float] = None,
433
- tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, ChatCompletionInputToolTypeEnum]] = None,
469
+ tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, str]] = None,
434
470
  tool_prompt: Optional[str] = None,
435
471
  tools: Optional[List[ChatCompletionInputTool]] = None,
436
472
  top_logprobs: Optional[int] = None,
@@ -450,10 +486,11 @@ class AsyncInferenceClient:
450
486
  max_tokens: Optional[int] = None,
451
487
  n: Optional[int] = None,
452
488
  presence_penalty: Optional[float] = None,
489
+ response_format: Optional[ChatCompletionInputGrammarType] = None,
453
490
  seed: Optional[int] = None,
454
491
  stop: Optional[List[str]] = None,
455
492
  temperature: Optional[float] = None,
456
- tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, ChatCompletionInputToolTypeEnum]] = None,
493
+ tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, str]] = None,
457
494
  tool_prompt: Optional[str] = None,
458
495
  tools: Optional[List[ChatCompletionInputTool]] = None,
459
496
  top_logprobs: Optional[int] = None,
@@ -473,10 +510,11 @@ class AsyncInferenceClient:
473
510
  max_tokens: Optional[int] = None,
474
511
  n: Optional[int] = None,
475
512
  presence_penalty: Optional[float] = None,
513
+ response_format: Optional[ChatCompletionInputGrammarType] = None,
476
514
  seed: Optional[int] = None,
477
515
  stop: Optional[List[str]] = None,
478
516
  temperature: Optional[float] = None,
479
- tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, ChatCompletionInputToolTypeEnum]] = None,
517
+ tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, str]] = None,
480
518
  tool_prompt: Optional[str] = None,
481
519
  tools: Optional[List[ChatCompletionInputTool]] = None,
482
520
  top_logprobs: Optional[int] = None,
@@ -496,10 +534,11 @@ class AsyncInferenceClient:
496
534
  max_tokens: Optional[int] = None,
497
535
  n: Optional[int] = None,
498
536
  presence_penalty: Optional[float] = None,
537
+ response_format: Optional[ChatCompletionInputGrammarType] = None,
499
538
  seed: Optional[int] = None,
500
539
  stop: Optional[List[str]] = None,
501
540
  temperature: Optional[float] = None,
502
- tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, ChatCompletionInputToolTypeEnum]] = None,
541
+ tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, str]] = None,
503
542
  tool_prompt: Optional[str] = None,
504
543
  tools: Optional[List[ChatCompletionInputTool]] = None,
505
544
  top_logprobs: Optional[int] = None,
@@ -510,11 +549,10 @@ class AsyncInferenceClient:
510
549
 
511
550
  <Tip>
512
551
 
513
- If the model is served by a server supporting chat-completion, the method will directly call the server's
514
- `/v1/chat/completions` endpoint. If the server does not support chat-completion, the method will render the
515
- chat template client-side based on the information fetched from the Hub API. In this case, you will need to
516
- have `minijinja` template engine installed. Run `pip install "huggingface_hub[inference]"` or `pip install minijinja`
517
- to install it.
552
+ The `client.chat_completion` method is aliased as `client.chat.completions.create` for compatibility with OpenAI's client.
553
+ Inputs and outputs are strictly the same and using either syntax will yield the same results.
554
+ Check out the [Inference guide](https://huggingface.co/docs/huggingface_hub/guides/inference#openai-compatibility)
555
+ for more details about OpenAI's compatibility.
518
556
 
519
557
  </Tip>
520
558
 
@@ -525,6 +563,9 @@ class AsyncInferenceClient:
525
563
  The model to use for chat-completion. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
526
564
  Inference Endpoint. If not provided, the default recommended model for chat-based text-generation will be used.
527
565
  See https://huggingface.co/tasks/text-generation for more details.
566
+
567
+ If `model` is a model ID, it is passed to the server as the `model` parameter. If you want to define a
568
+ custom URL while setting `model` in the request payload, you must set `base_url` when initializing [`InferenceClient`].
528
569
  frequency_penalty (`float`, *optional*):
529
570
  Penalizes new tokens based on their existing frequency
530
571
  in the text so far. Range: [-2.0, 2.0]. Defaults to 0.0.
@@ -544,6 +585,8 @@ class AsyncInferenceClient:
544
585
  presence_penalty (`float`, *optional*):
545
586
  Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the
546
587
  text so far, increasing the model's likelihood to talk about new topics.
588
+ response_format ([`ChatCompletionInputGrammarType`], *optional*):
589
+ Grammar constraints. Can be either a JSONSchema or a regex.
547
590
  seed (Optional[`int`], *optional*):
548
591
  Seed for reproducible control flow. Defaults to None.
549
592
  stop (Optional[`str`], *optional*):
@@ -561,7 +604,7 @@ class AsyncInferenceClient:
561
604
  top_p (`float`, *optional*):
562
605
  Fraction of the most likely next words to sample from.
563
606
  Must be between 0 and 1. Defaults to 1.0.
564
- tool_choice ([`ChatCompletionInputToolTypeClass`] or [`ChatCompletionInputToolTypeEnum`], *optional*):
607
+ tool_choice ([`ChatCompletionInputToolTypeClass`] or `str`, *optional*):
565
608
  The tool to use for the completion. Defaults to "auto".
566
609
  tool_prompt (`str`, *optional*):
567
610
  A prompt to be appended before the tools.
@@ -570,7 +613,7 @@ class AsyncInferenceClient:
570
613
  provide a list of functions the model may generate JSON inputs for.
571
614
 
572
615
  Returns:
573
- [`ChatCompletionOutput] or Iterable of [`ChatCompletionStreamOutput`]:
616
+ [`ChatCompletionOutput`] or Iterable of [`ChatCompletionStreamOutput`]:
574
617
  Generated text returned from the server:
575
618
  - if `stream=False`, the generated text is returned as a [`ChatCompletionOutput`] (default).
576
619
  - if `stream=True`, the generated text is returned token by token as a sequence of [`ChatCompletionStreamOutput`].
@@ -585,10 +628,9 @@ class AsyncInferenceClient:
585
628
 
586
629
  ```py
587
630
  # Must be run in an async context
588
- # Chat example
589
631
  >>> from huggingface_hub import AsyncInferenceClient
590
632
  >>> messages = [{"role": "user", "content": "What is the capital of France?"}]
591
- >>> client = AsyncInferenceClient("HuggingFaceH4/zephyr-7b-beta")
633
+ >>> client = AsyncInferenceClient("meta-llama/Meta-Llama-3-8B-Instruct")
592
634
  >>> await client.chat_completion(messages, max_tokens=100)
593
635
  ChatCompletionOutput(
594
636
  choices=[
@@ -596,26 +638,75 @@ class AsyncInferenceClient:
596
638
  finish_reason='eos_token',
597
639
  index=0,
598
640
  message=ChatCompletionOutputMessage(
599
- content='The capital of France is Paris. The official name of the city is Ville de Paris (City of Paris) and the name of the country governing body, which is located in Paris, is La République française (The French Republic). \nI hope that helps! Let me know if you need any further information.'
600
- )
641
+ role='assistant',
642
+ content='The capital of France is Paris.',
643
+ name=None,
644
+ tool_calls=None
645
+ ),
646
+ logprobs=None
601
647
  )
602
648
  ],
603
- created=1710498360
649
+ created=1719907176,
650
+ id='',
651
+ model='meta-llama/Meta-Llama-3-8B-Instruct',
652
+ object='text_completion',
653
+ system_fingerprint='2.0.4-sha-f426a33',
654
+ usage=ChatCompletionOutputUsage(
655
+ completion_tokens=8,
656
+ prompt_tokens=17,
657
+ total_tokens=25
658
+ )
604
659
  )
660
+ ```
605
661
 
662
+ Example (stream=True):
663
+ ```py
664
+ # Must be run in an async context
665
+ >>> from huggingface_hub import AsyncInferenceClient
666
+ >>> messages = [{"role": "user", "content": "What is the capital of France?"}]
667
+ >>> client = AsyncInferenceClient("meta-llama/Meta-Llama-3-8B-Instruct")
606
668
  >>> async for token in await client.chat_completion(messages, max_tokens=10, stream=True):
607
669
  ... print(token)
608
670
  ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content='The', role='assistant'), index=0, finish_reason=None)], created=1710498504)
609
671
  ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content=' capital', role='assistant'), index=0, finish_reason=None)], created=1710498504)
610
672
  (...)
611
673
  ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content=' may', role='assistant'), index=0, finish_reason=None)], created=1710498504)
674
+ ```
675
+
676
+ Example using OpenAI's syntax:
677
+ ```py
678
+ # Must be run in an async context
679
+ # instead of `from openai import OpenAI`
680
+ from huggingface_hub import AsyncInferenceClient
612
681
 
613
- # Chat example with tools
682
+ # instead of `client = OpenAI(...)`
683
+ client = AsyncInferenceClient(
684
+ base_url=...,
685
+ api_key=...,
686
+ )
687
+
688
+ output = await client.chat.completions.create(
689
+ model="meta-llama/Meta-Llama-3-8B-Instruct",
690
+ messages=[
691
+ {"role": "system", "content": "You are a helpful assistant."},
692
+ {"role": "user", "content": "Count to 10"},
693
+ ],
694
+ stream=True,
695
+ max_tokens=1024,
696
+ )
697
+
698
+ for chunk in output:
699
+ print(chunk.choices[0].delta.content)
700
+ ```
701
+
702
+ Example using tools:
703
+ ```py
704
+ # Must be run in an async context
614
705
  >>> client = AsyncInferenceClient("meta-llama/Meta-Llama-3-70B-Instruct")
615
706
  >>> messages = [
616
707
  ... {
617
708
  ... "role": "system",
618
- ... "content": "Don't make assumptions about what values to plug into functions. Ask async for clarification if a user request is ambiguous.",
709
+ ... "content": "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.",
619
710
  ... },
620
711
  ... {
621
712
  ... "role": "user",
@@ -691,113 +782,90 @@ class AsyncInferenceClient:
691
782
  description=None
692
783
  )
693
784
  ```
694
- """
695
- # determine model
696
- model = model or self.model or self.get_recommended_model("text-generation")
697
-
698
- if _is_chat_completion_server(model):
699
- # First, let's consider the server has a `/v1/chat/completions` endpoint.
700
- # If that's the case, we don't have to render the chat template client-side.
701
- model_url = self._resolve_url(model)
702
- if not model_url.endswith("/chat/completions"):
703
- model_url += "/v1/chat/completions"
704
-
705
- try:
706
- data = await self.post(
707
- model=model_url,
708
- json=dict(
709
- model="tgi", # random string
710
- messages=messages,
711
- frequency_penalty=frequency_penalty,
712
- logit_bias=logit_bias,
713
- logprobs=logprobs,
714
- max_tokens=max_tokens,
715
- n=n,
716
- presence_penalty=presence_penalty,
717
- seed=seed,
718
- stop=stop,
719
- temperature=temperature,
720
- tool_choice=tool_choice,
721
- tool_prompt=tool_prompt,
722
- tools=tools,
723
- top_logprobs=top_logprobs,
724
- top_p=top_p,
725
- stream=stream,
726
- ),
727
- stream=stream,
728
- )
729
- except _import_aiohttp().ClientResponseError as e:
730
- if e.status in (400, 404, 500):
731
- # Let's consider the server is not a chat completion server.
732
- # Then we call again `chat_completion` which will render the chat template client side.
733
- # (can be HTTP 500, HTTP 400, HTTP 404 depending on the server)
734
- _set_as_non_chat_completion_server(model)
735
- logger.warning(
736
- f"Server {model_url} does not seem to support chat completion. Falling back to text generation. Error: {e}"
737
- )
738
- return await self.chat_completion(
739
- messages=messages,
740
- model=model,
741
- stream=stream,
742
- max_tokens=max_tokens,
743
- seed=seed,
744
- stop=stop,
745
- temperature=temperature,
746
- top_p=top_p,
747
- )
748
- raise
749
785
 
750
- if stream:
751
- return _async_stream_chat_completion_response_from_bytes(data) # type: ignore[arg-type]
752
-
753
- return ChatCompletionOutput.parse_obj_as_instance(data) # type: ignore[arg-type]
786
+ Example using response_format:
787
+ ```py
788
+ # Must be run in an async context
789
+ >>> from huggingface_hub import AsyncInferenceClient
790
+ >>> client = AsyncInferenceClient("meta-llama/Meta-Llama-3-70B-Instruct")
791
+ >>> messages = [
792
+ ... {
793
+ ... "role": "user",
794
+ ... "content": "I saw a puppy a cat and a raccoon during my bike ride in the park. What did I saw and when?",
795
+ ... },
796
+ ... ]
797
+ >>> response_format = {
798
+ ... "type": "json",
799
+ ... "value": {
800
+ ... "properties": {
801
+ ... "location": {"type": "string"},
802
+ ... "activity": {"type": "string"},
803
+ ... "animals_seen": {"type": "integer", "minimum": 1, "maximum": 5},
804
+ ... "animals": {"type": "array", "items": {"type": "string"}},
805
+ ... },
806
+ ... "required": ["location", "activity", "animals_seen", "animals"],
807
+ ... },
808
+ ... }
809
+ >>> response = await client.chat_completion(
810
+ ... messages=messages,
811
+ ... response_format=response_format,
812
+ ... max_tokens=500,
813
+ )
814
+ >>> response.choices[0].message.content
815
+ '{\n\n"activity": "bike ride",\n"animals": ["puppy", "cat", "raccoon"],\n"animals_seen": 3,\n"location": "park"}'
816
+ ```
817
+ """
818
+ # Determine model
819
+ # `self.xxx` takes precedence over the method argument only in `chat_completion`
820
+ # since `chat_completion(..., model=xxx)` is also a payload parameter for the
821
+ # server, we need to handle it differently
822
+ model = self.base_url or self.model or model or self.get_recommended_model("text-generation")
823
+ is_url = model.startswith(("http://", "https://"))
824
+
825
+ # First, resolve the model chat completions URL
826
+ if model == self.base_url:
827
+ # base_url passed => add server route
828
+ model_url = model + "/v1/chat/completions"
829
+ elif is_url:
830
+ # model is a URL => use it directly
831
+ model_url = model
832
+ else:
833
+ # model is a model ID => resolve it + add server route
834
+ model_url = self._resolve_url(model) + "/v1/chat/completions"
835
+
836
+ # `model` is sent in the payload. Not used by the server but can be useful for debugging/routing.
837
+ # If it's a ID on the Hub => use it. Otherwise, we use a random string.
838
+ model_id = model if not is_url and model.count("/") == 1 else "tgi"
839
+
840
+ data = await self.post(
841
+ model=model_url,
842
+ json=dict(
843
+ model=model_id,
844
+ messages=messages,
845
+ frequency_penalty=frequency_penalty,
846
+ logit_bias=logit_bias,
847
+ logprobs=logprobs,
848
+ max_tokens=max_tokens,
849
+ n=n,
850
+ presence_penalty=presence_penalty,
851
+ response_format=response_format,
852
+ seed=seed,
853
+ stop=stop,
854
+ temperature=temperature,
855
+ tool_choice=tool_choice,
856
+ tool_prompt=tool_prompt,
857
+ tools=tools,
858
+ top_logprobs=top_logprobs,
859
+ top_p=top_p,
860
+ stream=stream,
861
+ ),
862
+ stream=stream,
863
+ )
754
864
 
755
- # At this point, we know the server is not a chat completion server.
756
- # It means it's a transformers-backed server for which we can send a list of messages directly to the
757
- # `text-generation` pipeline. We won't receive a detailed response but only the generated text.
758
865
  if stream:
759
- raise ValueError(
760
- "Streaming token is not supported by the model. This is due to the model not been served by a "
761
- "Text-Generation-Inference server. Please pass `stream=False` as input."
762
- )
763
- if tool_choice is not None or tool_prompt is not None or tools is not None:
764
- warnings.warn(
765
- "Tools are not supported by the model. This is due to the model not been served by a "
766
- "Text-Generation-Inference server. The provided tool parameters will be ignored."
767
- )
768
-
769
- # generate response
770
- text_generation_output = await self.text_generation(
771
- prompt=messages, # type: ignore # Not correct type but works implicitly
772
- model=model,
773
- stream=False,
774
- details=False,
775
- max_new_tokens=max_tokens,
776
- seed=seed,
777
- stop_sequences=stop,
778
- temperature=temperature,
779
- top_p=top_p,
780
- )
866
+ return _async_stream_chat_completion_response(data) # type: ignore[arg-type]
781
867
 
782
- # Format as a ChatCompletionOutput with dummy values for fields we can't provide
783
- return ChatCompletionOutput(
784
- id="dummy",
785
- model="dummy",
786
- object="dummy",
787
- system_fingerprint="dummy",
788
- usage=None, # type: ignore # set to `None` as we don't want to provide false information
789
- created=int(time.time()),
790
- choices=[
791
- ChatCompletionOutputComplete(
792
- finish_reason="unk", # type: ignore # set to `unk` as we don't want to provide false information
793
- index=0,
794
- message=ChatCompletionOutputMessage(
795
- content=text_generation_output,
796
- role="assistant",
797
- ),
798
- )
799
- ],
800
- )
868
+ return ChatCompletionOutput.parse_obj_as_instance(data) # type: ignore[arg-type]
801
869
 
802
870
  async def conversational(
803
871
  self,
@@ -850,7 +918,7 @@ class AsyncInferenceClient:
850
918
  >>> client = AsyncInferenceClient()
851
919
  >>> output = await client.conversational("Hi, who are you?")
852
920
  >>> output
853
- {'generated_text': 'I am the one who knocks.', 'conversation': {'generated_responses': ['I am the one who knocks.'], 'past_user_inputs': ['Hi, who are you?']}, 'warnings': ['Setting `pad_token_id` to `eos_token_id`:50256 async for open-end generation.']}
921
+ {'generated_text': 'I am the one who knocks.', 'conversation': {'generated_responses': ['I am the one who knocks.'], 'past_user_inputs': ['Hi, who are you?']}, 'warnings': ['Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.']}
854
922
  >>> await client.conversational(
855
923
  ... "Wow, that's scary!",
856
924
  ... generated_responses=output["conversation"]["generated_responses"],
@@ -915,7 +983,16 @@ class AsyncInferenceClient:
915
983
  response = await self.post(json=payload, model=model, task="document-question-answering")
916
984
  return DocumentQuestionAnsweringOutputElement.parse_obj_as_list(response)
917
985
 
918
- async def feature_extraction(self, text: str, *, model: Optional[str] = None) -> "np.ndarray":
986
+ async def feature_extraction(
987
+ self,
988
+ text: str,
989
+ *,
990
+ normalize: Optional[bool] = None,
991
+ prompt_name: Optional[str] = None,
992
+ truncate: Optional[bool] = None,
993
+ truncation_direction: Optional[Literal["Left", "Right"]] = None,
994
+ model: Optional[str] = None,
995
+ ) -> "np.ndarray":
919
996
  """
920
997
  Generate embeddings for a given text.
921
998
 
@@ -926,6 +1003,20 @@ class AsyncInferenceClient:
926
1003
  The model to use for the conversational task. Can be a model ID hosted on the Hugging Face Hub or a URL to
927
1004
  a deployed Inference Endpoint. If not provided, the default recommended conversational model will be used.
928
1005
  Defaults to None.
1006
+ normalize (`bool`, *optional*):
1007
+ Whether to normalize the embeddings or not. Defaults to None.
1008
+ Only available on server powered by Text-Embedding-Inference.
1009
+ prompt_name (`str`, *optional*):
1010
+ The name of the prompt that should be used by for encoding. If not set, no prompt will be applied.
1011
+ Must be a key in the `Sentence Transformers` configuration `prompts` dictionary.
1012
+ For example if ``prompt_name`` is "query" and the ``prompts`` is {"query": "query: ",...},
1013
+ then the sentence "What is the capital of France?" will be encoded as "query: What is the capital of France?"
1014
+ because the prompt text will be prepended before any text to encode.
1015
+ truncate (`bool`, *optional*):
1016
+ Whether to truncate the embeddings or not. Defaults to None.
1017
+ Only available on server powered by Text-Embedding-Inference.
1018
+ truncation_direction (`Literal["Left", "Right"]`, *optional*):
1019
+ Which side of the input should be truncated when `truncate=True` is passed.
929
1020
 
930
1021
  Returns:
931
1022
  `np.ndarray`: The embedding representing the input text as a float32 numpy array.
@@ -948,7 +1039,16 @@ class AsyncInferenceClient:
948
1039
  [ 0.28552425, -0.928395 , -1.2077185 , ..., 0.76810825, -2.1069427 , 0.6236161 ]], dtype=float32)
949
1040
  ```
950
1041
  """
951
- response = await self.post(json={"inputs": text}, model=model, task="feature-extraction")
1042
+ payload: Dict = {"inputs": text}
1043
+ if normalize is not None:
1044
+ payload["normalize"] = normalize
1045
+ if prompt_name is not None:
1046
+ payload["prompt_name"] = prompt_name
1047
+ if truncate is not None:
1048
+ payload["truncate"] = truncate
1049
+ if truncation_direction is not None:
1050
+ payload["truncation_direction"] = truncation_direction
1051
+ response = await self.post(json=payload, model=model, task="feature-extraction")
952
1052
  np = _import_numpy()
953
1053
  return np.array(_bytes_to_dict(response), dtype="float32")
954
1054
 
@@ -1192,7 +1292,8 @@ class AsyncInferenceClient:
1192
1292
  ```
1193
1293
  """
1194
1294
  response = await self.post(data=image, model=model, task="image-to-text")
1195
- return ImageToTextOutput.parse_obj_as_instance(response)
1295
+ output = ImageToTextOutput.parse_obj(response)
1296
+ return output[0] if isinstance(output, list) else output
1196
1297
 
1197
1298
  async def list_deployed_models(
1198
1299
  self, frameworks: Union[None, str, Literal["all"], List[str]] = None
@@ -1643,6 +1744,7 @@ class AsyncInferenceClient:
1643
1744
  stream: Literal[False] = ...,
1644
1745
  model: Optional[str] = None,
1645
1746
  # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
1747
+ adapter_id: Optional[str] = None,
1646
1748
  best_of: Optional[int] = None,
1647
1749
  decoder_input_details: Optional[bool] = None,
1648
1750
  do_sample: Optional[bool] = False, # Manual default value
@@ -1671,6 +1773,7 @@ class AsyncInferenceClient:
1671
1773
  stream: Literal[False] = ...,
1672
1774
  model: Optional[str] = None,
1673
1775
  # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
1776
+ adapter_id: Optional[str] = None,
1674
1777
  best_of: Optional[int] = None,
1675
1778
  decoder_input_details: Optional[bool] = None,
1676
1779
  do_sample: Optional[bool] = False, # Manual default value
@@ -1699,6 +1802,7 @@ class AsyncInferenceClient:
1699
1802
  stream: Literal[True] = ...,
1700
1803
  model: Optional[str] = None,
1701
1804
  # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
1805
+ adapter_id: Optional[str] = None,
1702
1806
  best_of: Optional[int] = None,
1703
1807
  decoder_input_details: Optional[bool] = None,
1704
1808
  do_sample: Optional[bool] = False, # Manual default value
@@ -1727,6 +1831,7 @@ class AsyncInferenceClient:
1727
1831
  stream: Literal[True] = ...,
1728
1832
  model: Optional[str] = None,
1729
1833
  # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
1834
+ adapter_id: Optional[str] = None,
1730
1835
  best_of: Optional[int] = None,
1731
1836
  decoder_input_details: Optional[bool] = None,
1732
1837
  do_sample: Optional[bool] = False, # Manual default value
@@ -1755,6 +1860,7 @@ class AsyncInferenceClient:
1755
1860
  stream: bool = ...,
1756
1861
  model: Optional[str] = None,
1757
1862
  # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
1863
+ adapter_id: Optional[str] = None,
1758
1864
  best_of: Optional[int] = None,
1759
1865
  decoder_input_details: Optional[bool] = None,
1760
1866
  do_sample: Optional[bool] = False, # Manual default value
@@ -1782,6 +1888,7 @@ class AsyncInferenceClient:
1782
1888
  stream: bool = False,
1783
1889
  model: Optional[str] = None,
1784
1890
  # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
1891
+ adapter_id: Optional[str] = None,
1785
1892
  best_of: Optional[int] = None,
1786
1893
  decoder_input_details: Optional[bool] = None,
1787
1894
  do_sample: Optional[bool] = False, # Manual default value
@@ -1812,6 +1919,13 @@ class AsyncInferenceClient:
1812
1919
 
1813
1920
  To learn more about the TGI project, please refer to https://github.com/huggingface/text-generation-inference.
1814
1921
 
1922
+ <Tip>
1923
+
1924
+ If you want to generate a response from chat messages, you should use the [`InferenceClient.chat_completion`] method.
1925
+ It accepts a list of messages instead of a single text prompt and handles the chat templating for you.
1926
+
1927
+ </Tip>
1928
+
1815
1929
  Args:
1816
1930
  prompt (`str`):
1817
1931
  Input text.
@@ -1826,6 +1940,8 @@ class AsyncInferenceClient:
1826
1940
  model (`str`, *optional*):
1827
1941
  The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
1828
1942
  Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
1943
+ adapter_id (`str`, *optional*):
1944
+ Lora adapter id.
1829
1945
  best_of (`int`, *optional*):
1830
1946
  Generate best_of sequences and return the one if the highest token logprobs.
1831
1947
  decoder_input_details (`bool`, *optional*):
@@ -1893,7 +2009,7 @@ class AsyncInferenceClient:
1893
2009
  >>> await client.text_generation("The huggingface_hub library is ", max_new_tokens=12)
1894
2010
  '100% open source and built to be easy to use.'
1895
2011
 
1896
- # Case 2: iterate over the generated tokens. Useful async for large generation.
2012
+ # Case 2: iterate over the generated tokens. Useful for large generation.
1897
2013
  >>> async for token in await client.text_generation("The huggingface_hub library is ", max_new_tokens=12, stream=True):
1898
2014
  ... print(token)
1899
2015
  100
@@ -1995,6 +2111,7 @@ class AsyncInferenceClient:
1995
2111
 
1996
2112
  # Build payload
1997
2113
  parameters = {
2114
+ "adapter_id": adapter_id,
1998
2115
  "best_of": best_of,
1999
2116
  "decoder_input_details": decoder_input_details,
2000
2117
  "details": details,
@@ -2065,6 +2182,7 @@ class AsyncInferenceClient:
2065
2182
  details=details,
2066
2183
  stream=stream,
2067
2184
  model=model,
2185
+ adapter_id=adapter_id,
2068
2186
  best_of=best_of,
2069
2187
  decoder_input_details=decoder_input_details,
2070
2188
  do_sample=do_sample,
@@ -2089,7 +2207,12 @@ class AsyncInferenceClient:
2089
2207
  if stream:
2090
2208
  return _async_stream_text_generation_response(bytes_output, details) # type: ignore
2091
2209
 
2092
- data = _bytes_to_dict(bytes_output)[0] # type: ignore[arg-type]
2210
+ data = _bytes_to_dict(bytes_output) # type: ignore[arg-type]
2211
+
2212
+ # Data can be a single element (dict) or an iterable of dicts where we select the first element of.
2213
+ if isinstance(data, list):
2214
+ data = data[0]
2215
+
2093
2216
  return TextGenerationOutput.parse_obj_as_instance(data) if details else data["generated_text"]
2094
2217
 
2095
2218
  async def text_to_image(
@@ -2377,7 +2500,13 @@ class AsyncInferenceClient:
2377
2500
  return VisualQuestionAnsweringOutputElement.parse_obj_as_list(response)
2378
2501
 
2379
2502
  async def zero_shot_classification(
2380
- self, text: str, labels: List[str], *, multi_label: bool = False, model: Optional[str] = None
2503
+ self,
2504
+ text: str,
2505
+ labels: List[str],
2506
+ *,
2507
+ multi_label: bool = False,
2508
+ hypothesis_template: Optional[str] = None,
2509
+ model: Optional[str] = None,
2381
2510
  ) -> List[ZeroShotClassificationOutputElement]:
2382
2511
  """
2383
2512
  Provide as input a text and a set of candidate labels to classify the input text.
@@ -2386,9 +2515,15 @@ class AsyncInferenceClient:
2386
2515
  text (`str`):
2387
2516
  The input text to classify.
2388
2517
  labels (`List[str]`):
2389
- List of string possible labels. There must be at least 2 labels.
2518
+ List of strings. Each string is the verbalization of a possible label for the input text.
2390
2519
  multi_label (`bool`):
2391
- Boolean that is set to True if classes can overlap.
2520
+ Boolean. If True, the probability for each label is evaluated independently and multiple labels can have a probability close to 1 simultaneously or all probabilities can be close to 0.
2521
+ If False, the labels are considered mutually exclusive and the probability over all labels always sums to 1. Defaults to False.
2522
+ hypothesis_template (`str`, *optional*):
2523
+ A template sentence string with curly brackets to which the label strings are added. The label strings are added at the position of the curly brackets "{}".
2524
+ Zero-shot classifiers are based on NLI models, which evaluate if a hypothesis is entailed in another text or not.
2525
+ For example, with hypothesis_template="This text is about {}." and labels=["economics", "politics"], the system internally creates the two hypotheses "This text is about economics." and "This text is about politics.".
2526
+ The model then evaluates for both hypotheses if they are entailed in the provided `text` or not.
2392
2527
  model (`str`, *optional*):
2393
2528
  The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
2394
2529
  Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
@@ -2402,15 +2537,15 @@ class AsyncInferenceClient:
2402
2537
  `aiohttp.ClientResponseError`:
2403
2538
  If the request fails with an HTTP error status code other than HTTP 503.
2404
2539
 
2405
- Example:
2540
+ Example with `multi_label=False`:
2406
2541
  ```py
2407
2542
  # Must be run in an async context
2408
2543
  >>> from huggingface_hub import AsyncInferenceClient
2409
2544
  >>> client = AsyncInferenceClient()
2410
2545
  >>> text = (
2411
- ... "A new model offers an explanation async for how the Galilean satellites formed around the solar system's"
2546
+ ... "A new model offers an explanation for how the Galilean satellites formed around the solar system's"
2412
2547
  ... "largest world. Konstantin Batygin did not set out to solve one of the solar system's most puzzling"
2413
- ... " mysteries when he went async for a run up a hill in Nice, France."
2548
+ ... " mysteries when he went for a run up a hill in Nice, France."
2414
2549
  ... )
2415
2550
  >>> labels = ["space & cosmos", "scientific discovery", "microbiology", "robots", "archeology"]
2416
2551
  >>> await client.zero_shot_classification(text, labels)
@@ -2430,21 +2565,38 @@ class AsyncInferenceClient:
2430
2565
  ZeroShotClassificationOutputElement(label='robots', score=0.00030448526376858354),
2431
2566
  ]
2432
2567
  ```
2568
+
2569
+ Example with `multi_label=True` and a custom `hypothesis_template`:
2570
+ ```py
2571
+ # Must be run in an async context
2572
+ >>> from huggingface_hub import AsyncInferenceClient
2573
+ >>> client = AsyncInferenceClient()
2574
+ >>> await client.zero_shot_classification(
2575
+ ... text="I really like our dinner and I'm very happy. I don't like the weather though.",
2576
+ ... labels=["positive", "negative", "pessimistic", "optimistic"],
2577
+ ... multi_label=True,
2578
+ ... hypothesis_template="This text is {} towards the weather"
2579
+ ... )
2580
+ [
2581
+ ZeroShotClassificationOutputElement(label='negative', score=0.9231801629066467),
2582
+ ZeroShotClassificationOutputElement(label='pessimistic', score=0.8760990500450134),
2583
+ ZeroShotClassificationOutputElement(label='optimistic', score=0.0008674879791215062),
2584
+ ZeroShotClassificationOutputElement(label='positive', score=0.0005250611575320363)
2585
+ ]
2586
+ ```
2433
2587
  """
2434
- # Raise ValueError if input is less than 2 labels
2435
- if len(labels) < 2:
2436
- raise ValueError("You must specify at least 2 classes to compare.")
2588
+
2589
+ parameters = {"candidate_labels": labels, "multi_label": multi_label}
2590
+ if hypothesis_template is not None:
2591
+ parameters["hypothesis_template"] = hypothesis_template
2437
2592
 
2438
2593
  response = await self.post(
2439
2594
  json={
2440
2595
  "inputs": text,
2441
- "parameters": {
2442
- "candidate_labels": ",".join(labels),
2443
- "multi_label": multi_label,
2444
- },
2596
+ "parameters": parameters,
2445
2597
  },
2446
- model=model,
2447
2598
  task="zero-shot-classification",
2599
+ model=model,
2448
2600
  )
2449
2601
  output = _bytes_to_dict(response)
2450
2602
  return [
@@ -2501,7 +2653,7 @@ class AsyncInferenceClient:
2501
2653
  return ZeroShotImageClassificationOutputElement.parse_obj_as_list(response)
2502
2654
 
2503
2655
  def _resolve_url(self, model: Optional[str] = None, task: Optional[str] = None) -> str:
2504
- model = model or self.model
2656
+ model = model or self.model or self.base_url
2505
2657
 
2506
2658
  # If model is already a URL, ignore `task` and return directly
2507
2659
  if model is not None and (model.startswith("http://") or model.startswith("https://")):
@@ -2554,6 +2706,99 @@ class AsyncInferenceClient:
2554
2706
  )
2555
2707
  return model
2556
2708
 
2709
+ async def get_endpoint_info(self, *, model: Optional[str] = None) -> Dict[str, Any]:
2710
+ """
2711
+ Get information about the deployed endpoint.
2712
+
2713
+ This endpoint is only available on endpoints powered by Text-Generation-Inference (TGI) or Text-Embedding-Inference (TEI).
2714
+ Endpoints powered by `transformers` return an empty payload.
2715
+
2716
+ Args:
2717
+ model (`str`, *optional*):
2718
+ The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
2719
+ Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
2720
+
2721
+ Returns:
2722
+ `Dict[str, Any]`: Information about the endpoint.
2723
+
2724
+ Example:
2725
+ ```py
2726
+ # Must be run in an async context
2727
+ >>> from huggingface_hub import AsyncInferenceClient
2728
+ >>> client = AsyncInferenceClient("meta-llama/Meta-Llama-3-70B-Instruct")
2729
+ >>> await client.get_endpoint_info()
2730
+ {
2731
+ 'model_id': 'meta-llama/Meta-Llama-3-70B-Instruct',
2732
+ 'model_sha': None,
2733
+ 'model_dtype': 'torch.float16',
2734
+ 'model_device_type': 'cuda',
2735
+ 'model_pipeline_tag': None,
2736
+ 'max_concurrent_requests': 128,
2737
+ 'max_best_of': 2,
2738
+ 'max_stop_sequences': 4,
2739
+ 'max_input_length': 8191,
2740
+ 'max_total_tokens': 8192,
2741
+ 'waiting_served_ratio': 0.3,
2742
+ 'max_batch_total_tokens': 1259392,
2743
+ 'max_waiting_tokens': 20,
2744
+ 'max_batch_size': None,
2745
+ 'validation_workers': 32,
2746
+ 'max_client_batch_size': 4,
2747
+ 'version': '2.0.2',
2748
+ 'sha': 'dccab72549635c7eb5ddb17f43f0b7cdff07c214',
2749
+ 'docker_label': 'sha-dccab72'
2750
+ }
2751
+ ```
2752
+ """
2753
+ model = model or self.model
2754
+ if model is None:
2755
+ raise ValueError("Model id not provided.")
2756
+ if model.startswith(("http://", "https://")):
2757
+ url = model.rstrip("/") + "/info"
2758
+ else:
2759
+ url = f"{INFERENCE_ENDPOINT}/models/{model}/info"
2760
+
2761
+ async with _import_aiohttp().ClientSession(headers=self.headers) as client:
2762
+ response = await client.get(url)
2763
+ response.raise_for_status()
2764
+ return await response.json()
2765
+
2766
+ async def health_check(self, model: Optional[str] = None) -> bool:
2767
+ """
2768
+ Check the health of the deployed endpoint.
2769
+
2770
+ Health check is only available with Inference Endpoints powered by Text-Generation-Inference (TGI) or Text-Embedding-Inference (TEI).
2771
+ For Inference API, please use [`InferenceClient.get_model_status`] instead.
2772
+
2773
+ Args:
2774
+ model (`str`, *optional*):
2775
+ URL of the Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
2776
+
2777
+ Returns:
2778
+ `bool`: True if everything is working fine.
2779
+
2780
+ Example:
2781
+ ```py
2782
+ # Must be run in an async context
2783
+ >>> from huggingface_hub import AsyncInferenceClient
2784
+ >>> client = AsyncInferenceClient("https://jzgu0buei5.us-east-1.aws.endpoints.huggingface.cloud")
2785
+ >>> await client.health_check()
2786
+ True
2787
+ ```
2788
+ """
2789
+ model = model or self.model
2790
+ if model is None:
2791
+ raise ValueError("Model id not provided.")
2792
+ if not model.startswith(("http://", "https://")):
2793
+ raise ValueError(
2794
+ "Model must be an Inference Endpoint URL. For serverless Inference API, please use `InferenceClient.get_model_status`."
2795
+ )
2796
+ url = model.rstrip("/") + "/health"
2797
+
2798
+ async with _import_aiohttp().ClientSession(headers=self.headers) as client:
2799
+ response = await client.get(url)
2800
+ return response.status == 200
2801
+
2557
2802
  async def get_model_status(self, model: Optional[str] = None) -> ModelStatus:
2558
2803
  """
2559
2804
  Get the status of a model hosted on the Inference API.
@@ -2581,7 +2826,7 @@ class AsyncInferenceClient:
2581
2826
  # Must be run in an async context
2582
2827
  >>> from huggingface_hub import AsyncInferenceClient
2583
2828
  >>> client = AsyncInferenceClient()
2584
- >>> await client.get_model_status("bigcode/starcoder")
2829
+ >>> await client.get_model_status("meta-llama/Meta-Llama-3-8B-Instruct")
2585
2830
  ModelStatus(loaded=True, state='Loaded', compute_type='gpu', framework='text-generation-inference')
2586
2831
  ```
2587
2832
  """
@@ -2606,3 +2851,30 @@ class AsyncInferenceClient:
2606
2851
  compute_type=response_data["compute_type"],
2607
2852
  framework=response_data["framework"],
2608
2853
  )
2854
+
2855
+ @property
2856
+ def chat(self) -> "ProxyClientChat":
2857
+ return ProxyClientChat(self)
2858
+
2859
+
2860
+ class _ProxyClient:
2861
+ """Proxy class to be able to call `client.chat.completion.create(...)` as OpenAI client."""
2862
+
2863
+ def __init__(self, client: AsyncInferenceClient):
2864
+ self._client = client
2865
+
2866
+
2867
+ class ProxyClientChat(_ProxyClient):
2868
+ """Proxy class to be able to call `client.chat.completion.create(...)` as OpenAI client."""
2869
+
2870
+ @property
2871
+ def completions(self) -> "ProxyClientChatCompletions":
2872
+ return ProxyClientChatCompletions(self._client)
2873
+
2874
+
2875
+ class ProxyClientChatCompletions(_ProxyClient):
2876
+ """Proxy class to be able to call `client.chat.completion.create(...)` as OpenAI client."""
2877
+
2878
+ @property
2879
+ def create(self):
2880
+ return self._client.chat_completion