huggingface-hub 0.23.3__py3-none-any.whl → 0.24.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

The registry flags this version of huggingface-hub as potentially problematic.
Files changed (44)
  1. huggingface_hub/__init__.py +47 -15
  2. huggingface_hub/_commit_api.py +38 -8
  3. huggingface_hub/_inference_endpoints.py +11 -4
  4. huggingface_hub/_local_folder.py +22 -13
  5. huggingface_hub/_snapshot_download.py +12 -7
  6. huggingface_hub/_webhooks_server.py +3 -1
  7. huggingface_hub/commands/huggingface_cli.py +4 -3
  8. huggingface_hub/commands/repo_files.py +128 -0
  9. huggingface_hub/constants.py +12 -0
  10. huggingface_hub/file_download.py +127 -91
  11. huggingface_hub/hf_api.py +979 -341
  12. huggingface_hub/hf_file_system.py +30 -3
  13. huggingface_hub/hub_mixin.py +103 -41
  14. huggingface_hub/inference/_client.py +373 -42
  15. huggingface_hub/inference/_common.py +0 -2
  16. huggingface_hub/inference/_generated/_async_client.py +390 -48
  17. huggingface_hub/inference/_generated/types/__init__.py +4 -1
  18. huggingface_hub/inference/_generated/types/chat_completion.py +41 -21
  19. huggingface_hub/inference/_generated/types/feature_extraction.py +23 -5
  20. huggingface_hub/inference/_generated/types/text_generation.py +29 -0
  21. huggingface_hub/lfs.py +11 -6
  22. huggingface_hub/repocard_data.py +41 -29
  23. huggingface_hub/repository.py +6 -6
  24. huggingface_hub/serialization/__init__.py +8 -3
  25. huggingface_hub/serialization/_base.py +13 -16
  26. huggingface_hub/serialization/_tensorflow.py +4 -3
  27. huggingface_hub/serialization/_torch.py +399 -22
  28. huggingface_hub/utils/__init__.py +1 -2
  29. huggingface_hub/utils/_errors.py +1 -1
  30. huggingface_hub/utils/_fixes.py +14 -3
  31. huggingface_hub/utils/_paths.py +17 -6
  32. huggingface_hub/utils/_subprocess.py +0 -1
  33. huggingface_hub/utils/_telemetry.py +9 -1
  34. huggingface_hub/utils/_typing.py +26 -1
  35. huggingface_hub/utils/endpoint_helpers.py +2 -186
  36. huggingface_hub/utils/sha.py +36 -1
  37. huggingface_hub/utils/tqdm.py +0 -1
  38. {huggingface_hub-0.23.3.dist-info → huggingface_hub-0.24.0rc0.dist-info}/METADATA +12 -9
  39. {huggingface_hub-0.23.3.dist-info → huggingface_hub-0.24.0rc0.dist-info}/RECORD +43 -43
  40. huggingface_hub/serialization/_numpy.py +0 -68
  41. {huggingface_hub-0.23.3.dist-info → huggingface_hub-0.24.0rc0.dist-info}/LICENSE +0 -0
  42. {huggingface_hub-0.23.3.dist-info → huggingface_hub-0.24.0rc0.dist-info}/WHEEL +0 -0
  43. {huggingface_hub-0.23.3.dist-info → huggingface_hub-0.24.0rc0.dist-info}/entry_points.txt +0 -0
  44. {huggingface_hub-0.23.3.dist-info → huggingface_hub-0.24.0rc0.dist-info}/top_level.txt +0 -0
@@ -64,6 +64,7 @@ from huggingface_hub.inference._generated.types import (
  AudioClassificationOutputElement,
  AudioToAudioOutputElement,
  AutomaticSpeechRecognitionOutput,
+ ChatCompletionInputGrammarType,
  ChatCompletionInputTool,
  ChatCompletionInputToolTypeClass,
  ChatCompletionOutput,
@@ -89,13 +90,13 @@ from huggingface_hub.inference._generated.types import (
  ZeroShotClassificationOutputElement,
  ZeroShotImageClassificationOutputElement,
  )
- from huggingface_hub.inference._generated.types.chat_completion import ChatCompletionInputToolTypeEnum
  from huggingface_hub.inference._types import (
  ConversationalOutput, # soon to be removed
  )
  from huggingface_hub.utils import (
  build_hf_headers,
  )
+ from huggingface_hub.utils._deprecation import _deprecate_positional_args
 
  from .._common import _async_yield_from, _import_aiohttp
 
@@ -119,12 +120,16 @@ class AsyncInferenceClient:
 
  Args:
  model (`str`, `optional`):
- The model to run inference with. Can be a model id hosted on the Hugging Face Hub, e.g. `bigcode/starcoder`
+ The model to run inference with. Can be a model id hosted on the Hugging Face Hub, e.g. `meta-llama/Meta-Llama-3-8B-Instruct`
  or a URL to a deployed Inference Endpoint. Defaults to None, in which case a recommended model is
  automatically selected for the task.
+ Note: for better compatibility with OpenAI's client, `model` has been aliased as `base_url`. Those 2
+ arguments are mutually exclusive and have the exact same behavior.
  token (`str` or `bool`, *optional*):
  Hugging Face token. Will default to the locally saved token if not provided.
  Pass `token=False` if you don't want to send your token to the server.
+ Note: for better compatibility with OpenAI's client, `token` has been aliased as `api_key`. Those 2
+ arguments are mutually exclusive and have the exact same behavior.
  timeout (`float`, `optional`):
  The maximum number of seconds to wait for a response from the server. Loading a new model in Inference
  API can take up to several minutes. Defaults to None, meaning it will loop until the server is available.
@@ -133,23 +138,52 @@ class AsyncInferenceClient:
  Values in this dictionary will override the default values.
  cookies (`Dict[str, str]`, `optional`):
  Additional cookies to send to the server.
+ base_url (`str`, `optional`):
+ Base URL to run inference. This is a duplicated argument from `model` to make [`InferenceClient`]
+ follow the same pattern as `openai.OpenAI` client. Cannot be used if `model` is set. Defaults to None.
+ api_key (`str`, `optional`):
+ Token to use for authentication. This is a duplicated argument from `token` to make [`InferenceClient`]
+ follow the same pattern as `openai.OpenAI` client. Cannot be used if `token` is set. Defaults to None.
  """
 
+ @_deprecate_positional_args(version="0.26")
  def __init__(
  self,
  model: Optional[str] = None,
+ *,
  token: Union[str, bool, None] = None,
  timeout: Optional[float] = None,
  headers: Optional[Dict[str, str]] = None,
  cookies: Optional[Dict[str, str]] = None,
+ proxies: Optional[Any] = None,
+ # OpenAI compatibility
+ base_url: Optional[str] = None,
+ api_key: Optional[str] = None,
  ) -> None:
+ if model is not None and base_url is not None:
+ raise ValueError(
+ "Received both `model` and `base_url` arguments. Please provide only one of them."
+ " `base_url` is an alias for `model` to make the API compatible with OpenAI's client."
+ " It has the exact same behavior as `model`."
+ )
+ if token is not None and api_key is not None:
+ raise ValueError(
+ "Received both `token` and `api_key` arguments. Please provide only one of them."
+ " `api_key` is an alias for `token` to make the API compatible with OpenAI's client."
+ " It has the exact same behavior as `token`."
+ )
+
  self.model: Optional[str] = model
- self.token: Union[str, bool, None] = token
- self.headers = CaseInsensitiveDict(build_hf_headers(token=token)) # contains 'authorization' + 'user-agent'
+ self.token: Union[str, bool, None] = token or api_key
+ self.headers = CaseInsensitiveDict(build_hf_headers(token=self.token)) # 'authorization' + 'user-agent'
  if headers is not None:
  self.headers.update(headers)
  self.cookies = cookies
  self.timeout = timeout
+ self.proxies = proxies
+
+ # OpenAI compatibility
+ self.base_url = base_url
 
  def __repr__(self):
  return f"<InferenceClient(model='{self.model if self.model else ''}', timeout={self.timeout})>"
@@ -250,7 +284,7 @@ class AsyncInferenceClient:
  )
 
  try:
- response = await client.post(url, json=json, data=data_as_binary)
+ response = await client.post(url, json=json, data=data_as_binary, proxy=self.proxies)
  response_error_payload = None
  if response.status != 200:
  try:
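The `proxies` value stored at init is forwarded as the `proxy` argument of each `aiohttp` POST in the hunk above, so a single proxy URL string is the expected shape. A brief sketch; the proxy address is a placeholder:

```python
from huggingface_hub import AsyncInferenceClient

# Every request issued by this client is routed through the proxy,
# since `self.proxies` is passed as `proxy=` to aiohttp's post().
client = AsyncInferenceClient(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    proxies="http://127.0.0.1:3128",  # placeholder proxy address
)
```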
@@ -284,11 +318,16 @@ class AsyncInferenceClient:
  ) from error
  # ...or wait 1s and retry
  logger.info(f"Waiting for model to be loaded on the server: {error}")
+ if "X-wait-for-model" not in headers and url.startswith(INFERENCE_ENDPOINT):
+ headers["X-wait-for-model"] = "1"
  time.sleep(1)
  if timeout is not None:
  timeout = max(self.timeout - (time.time() - t0), 1) # type: ignore
  continue
  raise error
+ except Exception:
+ await client.close()
+ raise
 
  async def audio_classification(
  self,
@@ -427,10 +466,11 @@ class AsyncInferenceClient:
  max_tokens: Optional[int] = None,
  n: Optional[int] = None,
  presence_penalty: Optional[float] = None,
+ response_format: Optional[ChatCompletionInputGrammarType] = None,
  seed: Optional[int] = None,
  stop: Optional[List[str]] = None,
  temperature: Optional[float] = None,
- tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, ChatCompletionInputToolTypeEnum]] = None,
+ tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, str]] = None,
  tool_prompt: Optional[str] = None,
  tools: Optional[List[ChatCompletionInputTool]] = None,
  top_logprobs: Optional[int] = None,
@@ -450,10 +490,11 @@ class AsyncInferenceClient:
  max_tokens: Optional[int] = None,
  n: Optional[int] = None,
  presence_penalty: Optional[float] = None,
+ response_format: Optional[ChatCompletionInputGrammarType] = None,
  seed: Optional[int] = None,
  stop: Optional[List[str]] = None,
  temperature: Optional[float] = None,
- tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, ChatCompletionInputToolTypeEnum]] = None,
+ tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, str]] = None,
  tool_prompt: Optional[str] = None,
  tools: Optional[List[ChatCompletionInputTool]] = None,
  top_logprobs: Optional[int] = None,
@@ -473,10 +514,11 @@ class AsyncInferenceClient:
  max_tokens: Optional[int] = None,
  n: Optional[int] = None,
  presence_penalty: Optional[float] = None,
+ response_format: Optional[ChatCompletionInputGrammarType] = None,
  seed: Optional[int] = None,
  stop: Optional[List[str]] = None,
  temperature: Optional[float] = None,
- tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, ChatCompletionInputToolTypeEnum]] = None,
+ tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, str]] = None,
  tool_prompt: Optional[str] = None,
  tools: Optional[List[ChatCompletionInputTool]] = None,
  top_logprobs: Optional[int] = None,
@@ -496,10 +538,11 @@ class AsyncInferenceClient:
  max_tokens: Optional[int] = None,
  n: Optional[int] = None,
  presence_penalty: Optional[float] = None,
+ response_format: Optional[ChatCompletionInputGrammarType] = None,
  seed: Optional[int] = None,
  stop: Optional[List[str]] = None,
  temperature: Optional[float] = None,
- tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, ChatCompletionInputToolTypeEnum]] = None,
+ tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, str]] = None,
  tool_prompt: Optional[str] = None,
  tools: Optional[List[ChatCompletionInputTool]] = None,
  top_logprobs: Optional[int] = None,
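All four `chat_completion` overloads above gain `response_format` and relax `tool_choice` from the removed `ChatCompletionInputToolTypeEnum` to a plain string. A hedged sketch of the string form, with a tool passed as a plain dict in the same style as the docstring example further down:

```python
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient("meta-llama/Meta-Llama-3-70B-Instruct")
    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }
    ]
    # `tool_choice` now accepts a bare string such as "auto" instead of the enum.
    response = await client.chat_completion(
        messages=[{"role": "user", "content": "What's the weather like in Paris?"}],
        tools=tools,
        tool_choice="auto",
        max_tokens=100,
    )
    print(response.choices[0].message)


asyncio.run(main())
```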
@@ -510,11 +553,10 @@ class AsyncInferenceClient:
 
  <Tip>
 
- If the model is served by a server supporting chat-completion, the method will directly call the server's
- `/v1/chat/completions` endpoint. If the server does not support chat-completion, the method will render the
- chat template client-side based on the information fetched from the Hub API. In this case, you will need to
- have `minijinja` template engine installed. Run `pip install "huggingface_hub[inference]"` or `pip install minijinja`
- to install it.
+ The `client.chat_completion` method is aliased as `client.chat.completions.create` for compatibility with OpenAI's client.
+ Inputs and outputs are strictly the same and using either syntax will yield the same results.
+ Check out the [Inference guide](https://huggingface.co/docs/huggingface_hub/guides/inference#openai-compatibility)
+ for more details about OpenAI's compatibility.
 
  </Tip>
 
@@ -525,6 +567,9 @@ class AsyncInferenceClient:
  The model to use for chat-completion. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
  Inference Endpoint. If not provided, the default recommended model for chat-based text-generation will be used.
  See https://huggingface.co/tasks/text-generation for more details.
+
+ If `model` is a model ID, it is passed to the server as the `model` parameter. If you want to define a
+ custom URL while setting `model` in the request payload, you must set `base_url` when initializing [`InferenceClient`].
  frequency_penalty (`float`, *optional*):
  Penalizes new tokens based on their existing frequency
  in the text so far. Range: [-2.0, 2.0]. Defaults to 0.0.
@@ -544,6 +589,8 @@ class AsyncInferenceClient:
  presence_penalty (`float`, *optional*):
  Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the
  text so far, increasing the model's likelihood to talk about new topics.
+ response_format ([`ChatCompletionInputGrammarType`], *optional*):
+ Grammar constraints. Can be either a JSONSchema or a regex.
  seed (Optional[`int`], *optional*):
  Seed for reproducible control flow. Defaults to None.
  stop (Optional[`str`], *optional*):
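The docstring above says `response_format` accepts either a JSON schema or a regex; the JSON variant is shown in the example added later in this hunk series. A sketch of the regex variant, assuming it follows the same `{"type": ..., "value": ...}` shape as the JSON one:

```python
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient("meta-llama/Meta-Llama-3-70B-Instruct")
    # Assumed regex grammar shape, mirroring the JSON-schema example below.
    response = await client.chat_completion(
        messages=[{"role": "user", "content": "Reply with a date in ISO format."}],
        response_format={"type": "regex", "value": r"\d{4}-\d{2}-\d{2}"},
        max_tokens=20,
    )
    print(response.choices[0].message.content)


asyncio.run(main())
```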
@@ -561,7 +608,7 @@ class AsyncInferenceClient:
  top_p (`float`, *optional*):
  Fraction of the most likely next words to sample from.
  Must be between 0 and 1. Defaults to 1.0.
- tool_choice ([`ChatCompletionInputToolTypeClass`] or [`ChatCompletionInputToolTypeEnum`], *optional*):
+ tool_choice ([`ChatCompletionInputToolTypeClass`] or `str`, *optional*):
  The tool to use for the completion. Defaults to "auto".
  tool_prompt (`str`, *optional*):
  A prompt to be appended before the tools.
@@ -570,7 +617,7 @@ class AsyncInferenceClient:
  provide a list of functions the model may generate JSON inputs for.
 
  Returns:
- [`ChatCompletionOutput] or Iterable of [`ChatCompletionStreamOutput`]:
+ [`ChatCompletionOutput`] or Iterable of [`ChatCompletionStreamOutput`]:
  Generated text returned from the server:
  - if `stream=False`, the generated text is returned as a [`ChatCompletionOutput`] (default).
  - if `stream=True`, the generated text is returned token by token as a sequence of [`ChatCompletionStreamOutput`].
@@ -585,10 +632,9 @@ class AsyncInferenceClient:
 
  ```py
  # Must be run in an async context
- # Chat example
  >>> from huggingface_hub import AsyncInferenceClient
  >>> messages = [{"role": "user", "content": "What is the capital of France?"}]
- >>> client = AsyncInferenceClient("HuggingFaceH4/zephyr-7b-beta")
+ >>> client = AsyncInferenceClient("meta-llama/Meta-Llama-3-8B-Instruct")
  >>> await client.chat_completion(messages, max_tokens=100)
  ChatCompletionOutput(
  choices=[
@@ -596,26 +642,75 @@ class AsyncInferenceClient:
  finish_reason='eos_token',
  index=0,
  message=ChatCompletionOutputMessage(
- content='The capital of France is Paris. The official name of the city is Ville de Paris (City of Paris) and the name of the country governing body, which is located in Paris, is La République française (The French Republic). \nI hope that helps! Let me know if you need any further information.'
- )
+ role='assistant',
+ content='The capital of France is Paris.',
+ name=None,
+ tool_calls=None
+ ),
+ logprobs=None
  )
  ],
- created=1710498360
+ created=1719907176,
+ id='',
+ model='meta-llama/Meta-Llama-3-8B-Instruct',
+ object='text_completion',
+ system_fingerprint='2.0.4-sha-f426a33',
+ usage=ChatCompletionOutputUsage(
+ completion_tokens=8,
+ prompt_tokens=17,
+ total_tokens=25
+ )
  )
+ ```
 
+ Example (stream=True):
+ ```py
+ # Must be run in an async context
+ >>> from huggingface_hub import AsyncInferenceClient
+ >>> messages = [{"role": "user", "content": "What is the capital of France?"}]
+ >>> client = AsyncInferenceClient("meta-llama/Meta-Llama-3-8B-Instruct")
  >>> async for token in await client.chat_completion(messages, max_tokens=10, stream=True):
  ... print(token)
  ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content='The', role='assistant'), index=0, finish_reason=None)], created=1710498504)
  ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content=' capital', role='assistant'), index=0, finish_reason=None)], created=1710498504)
  (...)
  ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content=' may', role='assistant'), index=0, finish_reason=None)], created=1710498504)
+ ```
 
- # Chat example with tools
+ Example using OpenAI's syntax:
+ ```py
+ # Must be run in an async context
+ # instead of `from openai import OpenAI`
+ from huggingface_hub import AsyncInferenceClient
+
+ # instead of `client = OpenAI(...)`
+ client = AsyncInferenceClient(
+ base_url=...,
+ api_key=...,
+ )
+
+ output = await client.chat.completions.create(
+ model="meta-llama/Meta-Llama-3-8B-Instruct",
+ messages=[
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "Count to 10"},
+ ],
+ stream=True,
+ max_tokens=1024,
+ )
+
+ for chunk in output:
+ print(chunk.choices[0].delta.content)
+ ```
+
+ Example using tools:
+ ```py
+ # Must be run in an async context
  >>> client = AsyncInferenceClient("meta-llama/Meta-Llama-3-70B-Instruct")
  >>> messages = [
  ... {
  ... "role": "system",
- ... "content": "Don't make assumptions about what values to plug into functions. Ask async for clarification if a user request is ambiguous.",
+ ... "content": "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.",
  ... },
  ... {
  ... "role": "user",
@@ -691,9 +786,44 @@ class AsyncInferenceClient:
  description=None
  )
  ```
+
+ Example using response_format:
+ ```py
+ # Must be run in an async context
+ >>> from huggingface_hub import AsyncInferenceClient
+ >>> client = AsyncInferenceClient("meta-llama/Meta-Llama-3-70B-Instruct")
+ >>> messages = [
+ ... {
+ ... "role": "user",
+ ... "content": "I saw a puppy a cat and a raccoon during my bike ride in the park. What did I saw and when?",
+ ... },
+ ... ]
+ >>> response_format = {
+ ... "type": "json",
+ ... "value": {
+ ... "properties": {
+ ... "location": {"type": "string"},
+ ... "activity": {"type": "string"},
+ ... "animals_seen": {"type": "integer", "minimum": 1, "maximum": 5},
+ ... "animals": {"type": "array", "items": {"type": "string"}},
+ ... },
+ ... "required": ["location", "activity", "animals_seen", "animals"],
+ ... },
+ ... }
+ >>> response = await client.chat_completion(
+ ... messages=messages,
+ ... response_format=response_format,
+ ... max_tokens=500,
+ )
+ >>> response.choices[0].message.content
+ '{\n\n"activity": "bike ride",\n"animals": ["puppy", "cat", "raccoon"],\n"animals_seen": 3,\n"location": "park"}'
+ ```
  """
- # determine model
- model = model or self.model or self.get_recommended_model("text-generation")
+ # Determine model
+ # `self.xxx` takes precedence over the method argument only in `chat_completion`
+ # since `chat_completion(..., model=xxx)` is also a payload parameter for the
+ # server, we need to handle it differently
+ model = self.base_url or self.model or model or self.get_recommended_model("text-generation")
 
  if _is_chat_completion_server(model):
  # First, let's consider the server has a `/v1/chat/completions` endpoint.
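Note the new precedence in the resolution line above: for `chat_completion` only, the instance-level `base_url` (or `model`) wins over the `model=` argument passed to the call. A short sketch of the consequence; the local URL is a placeholder:

```python
from huggingface_hub import AsyncInferenceClient

# With a client-level base_url, chat_completion() keeps talking to that endpoint
# even if a different `model=` is passed to the call; for the other tasks the
# method argument still takes precedence.
client = AsyncInferenceClient(base_url="http://localhost:8080")  # placeholder TGI endpoint

# await client.chat_completion(messages, model="my-org/my-model")
# -> still POSTs to http://localhost:8080/v1/chat/completions
```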
@@ -702,11 +832,19 @@ class AsyncInferenceClient:
  if not model_url.endswith("/chat/completions"):
  model_url += "/v1/chat/completions"
 
+ # `model` is sent in the payload. Not used by the server but can be useful for debugging/routing.
+ if not model.startswith("http") and model.count("/") == 1:
+ # If it's a ID on the Hub => use it
+ model_id = model
+ else:
+ # Otherwise, we use a random string
+ model_id = "tgi"
+
  try:
  data = await self.post(
  model=model_url,
  json=dict(
- model="tgi", # random string
+ model=model_id,
  messages=messages,
  frequency_penalty=frequency_penalty,
  logit_bias=logit_bias,
@@ -714,6 +852,7 @@ class AsyncInferenceClient:
  max_tokens=max_tokens,
  n=n,
  presence_penalty=presence_penalty,
+ response_format=response_format,
  seed=seed,
  stop=stop,
  temperature=temperature,
@@ -765,6 +904,11 @@ class AsyncInferenceClient:
  "Tools are not supported by the model. This is due to the model not been served by a "
  "Text-Generation-Inference server. The provided tool parameters will be ignored."
  )
+ if response_format is not None:
+ warnings.warn(
+ "Response format is not supported by the model. This is due to the model not been served by a "
+ "Text-Generation-Inference server. The provided response format will be ignored."
+ )
 
  # generate response
  text_generation_output = await self.text_generation(
@@ -783,7 +927,6 @@ class AsyncInferenceClient:
  return ChatCompletionOutput(
  id="dummy",
  model="dummy",
- object="dummy",
  system_fingerprint="dummy",
  usage=None, # type: ignore # set to `None` as we don't want to provide false information
  created=int(time.time()),
@@ -850,7 +993,7 @@ class AsyncInferenceClient:
  >>> client = AsyncInferenceClient()
  >>> output = await client.conversational("Hi, who are you?")
  >>> output
- {'generated_text': 'I am the one who knocks.', 'conversation': {'generated_responses': ['I am the one who knocks.'], 'past_user_inputs': ['Hi, who are you?']}, 'warnings': ['Setting `pad_token_id` to `eos_token_id`:50256 async for open-end generation.']}
+ {'generated_text': 'I am the one who knocks.', 'conversation': {'generated_responses': ['I am the one who knocks.'], 'past_user_inputs': ['Hi, who are you?']}, 'warnings': ['Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.']}
  >>> await client.conversational(
  ... "Wow, that's scary!",
  ... generated_responses=output["conversation"]["generated_responses"],
@@ -915,7 +1058,16 @@ class AsyncInferenceClient:
  response = await self.post(json=payload, model=model, task="document-question-answering")
  return DocumentQuestionAnsweringOutputElement.parse_obj_as_list(response)
 
- async def feature_extraction(self, text: str, *, model: Optional[str] = None) -> "np.ndarray":
+ async def feature_extraction(
+ self,
+ text: str,
+ *,
+ normalize: Optional[bool] = None,
+ prompt_name: Optional[str] = None,
+ truncate: Optional[bool] = None,
+ truncation_direction: Optional[Literal["Left", "Right"]] = None,
+ model: Optional[str] = None,
+ ) -> "np.ndarray":
  """
  Generate embeddings for a given text.
 
@@ -926,6 +1078,20 @@ class AsyncInferenceClient:
  The model to use for the conversational task. Can be a model ID hosted on the Hugging Face Hub or a URL to
  a deployed Inference Endpoint. If not provided, the default recommended conversational model will be used.
  Defaults to None.
+ normalize (`bool`, *optional*):
+ Whether to normalize the embeddings or not. Defaults to None.
+ Only available on server powered by Text-Embedding-Inference.
+ prompt_name (`str`, *optional*):
+ The name of the prompt that should be used by for encoding. If not set, no prompt will be applied.
+ Must be a key in the `Sentence Transformers` configuration `prompts` dictionary.
+ For example if ``prompt_name`` is "query" and the ``prompts`` is {"query": "query: ",...},
+ then the sentence "What is the capital of France?" will be encoded as "query: What is the capital of France?"
+ because the prompt text will be prepended before any text to encode.
+ truncate (`bool`, *optional*):
+ Whether to truncate the embeddings or not. Defaults to None.
+ Only available on server powered by Text-Embedding-Inference.
+ truncation_direction (`Literal["Left", "Right"]`, *optional*):
+ Which side of the input should be truncated when `truncate=True` is passed.
 
  Returns:
  `np.ndarray`: The embedding representing the input text as a float32 numpy array.
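The new `feature_extraction` keyword arguments documented above are only honoured by Text-Embedding-Inference servers. A minimal sketch; the embedding model is an illustrative choice, not one named in this diff:

```python
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient("BAAI/bge-base-en-v1.5")  # illustrative TEI-served model
    embedding = await client.feature_extraction(
        "What is the capital of France?",
        normalize=True,
        truncate=True,
        truncation_direction="Right",
    )
    print(embedding.shape)  # float32 numpy array


asyncio.run(main())
```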
@@ -948,7 +1114,16 @@ class AsyncInferenceClient:
  [ 0.28552425, -0.928395 , -1.2077185 , ..., 0.76810825, -2.1069427 , 0.6236161 ]], dtype=float32)
  ```
  """
- response = await self.post(json={"inputs": text}, model=model, task="feature-extraction")
+ payload: Dict = {"inputs": text}
+ if normalize is not None:
+ payload["normalize"] = normalize
+ if prompt_name is not None:
+ payload["prompt_name"] = prompt_name
+ if truncate is not None:
+ payload["truncate"] = truncate
+ if truncation_direction is not None:
+ payload["truncation_direction"] = truncation_direction
+ response = await self.post(json=payload, model=model, task="feature-extraction")
  np = _import_numpy()
  return np.array(_bytes_to_dict(response), dtype="float32")
 
@@ -1192,7 +1367,8 @@ class AsyncInferenceClient:
  ```
  """
  response = await self.post(data=image, model=model, task="image-to-text")
- return ImageToTextOutput.parse_obj_as_instance(response)
+ output = ImageToTextOutput.parse_obj(response)
+ return output[0] if isinstance(output, list) else output
 
  async def list_deployed_models(
  self, frameworks: Union[None, str, Literal["all"], List[str]] = None
@@ -1643,6 +1819,7 @@ class AsyncInferenceClient:
  stream: Literal[False] = ...,
  model: Optional[str] = None,
  # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
+ adapter_id: Optional[str] = None,
  best_of: Optional[int] = None,
  decoder_input_details: Optional[bool] = None,
  do_sample: Optional[bool] = False, # Manual default value
@@ -1671,6 +1848,7 @@ class AsyncInferenceClient:
  stream: Literal[False] = ...,
  model: Optional[str] = None,
  # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
+ adapter_id: Optional[str] = None,
  best_of: Optional[int] = None,
  decoder_input_details: Optional[bool] = None,
  do_sample: Optional[bool] = False, # Manual default value
@@ -1699,6 +1877,7 @@ class AsyncInferenceClient:
  stream: Literal[True] = ...,
  model: Optional[str] = None,
  # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
+ adapter_id: Optional[str] = None,
  best_of: Optional[int] = None,
  decoder_input_details: Optional[bool] = None,
  do_sample: Optional[bool] = False, # Manual default value
@@ -1727,6 +1906,7 @@ class AsyncInferenceClient:
  stream: Literal[True] = ...,
  model: Optional[str] = None,
  # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
+ adapter_id: Optional[str] = None,
  best_of: Optional[int] = None,
  decoder_input_details: Optional[bool] = None,
  do_sample: Optional[bool] = False, # Manual default value
@@ -1755,6 +1935,7 @@ class AsyncInferenceClient:
  stream: bool = ...,
  model: Optional[str] = None,
  # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
+ adapter_id: Optional[str] = None,
  best_of: Optional[int] = None,
  decoder_input_details: Optional[bool] = None,
  do_sample: Optional[bool] = False, # Manual default value
@@ -1782,6 +1963,7 @@ class AsyncInferenceClient:
  stream: bool = False,
  model: Optional[str] = None,
  # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
+ adapter_id: Optional[str] = None,
  best_of: Optional[int] = None,
  decoder_input_details: Optional[bool] = None,
  do_sample: Optional[bool] = False, # Manual default value
@@ -1812,6 +1994,13 @@ class AsyncInferenceClient:
 
  To learn more about the TGI project, please refer to https://github.com/huggingface/text-generation-inference.
 
+ <Tip>
+
+ If you want to generate a response from chat messages, you should use the [`InferenceClient.chat_completion`] method.
+ It accepts a list of messages instead of a single text prompt and handles the chat templating for you.
+
+ </Tip>
+
  Args:
  prompt (`str`):
  Input text.
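The tip added above draws the line between the two text APIs. A quick side-by-side sketch:

```python
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient("meta-llama/Meta-Llama-3-8B-Instruct")
    # Raw prompt: you are responsible for any prompt/chat formatting.
    text = await client.text_generation("The huggingface_hub library is ", max_new_tokens=12)
    # Chat messages: the chat template is handled for you.
    chat = await client.chat_completion(
        [{"role": "user", "content": "What is the huggingface_hub library?"}],
        max_tokens=50,
    )
    print(text)
    print(chat.choices[0].message.content)


asyncio.run(main())
```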
@@ -1826,6 +2015,8 @@ class AsyncInferenceClient:
  model (`str`, *optional*):
  The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
  Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
+ adapter_id (`str`, *optional*):
+ Lora adapter id.
  best_of (`int`, *optional*):
  Generate best_of sequences and return the one if the highest token logprobs.
  decoder_input_details (`bool`, *optional*):
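`adapter_id` is forwarded untouched in the request payload (see the payload hunk further down), so it only has an effect on a TGI server launched with the matching LoRA adapters. A hedged sketch; both the endpoint and the adapter id are placeholders:

```python
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient("http://localhost:8080")  # placeholder TGI endpoint
    output = await client.text_generation(
        "A short poem about open source:",
        adapter_id="my-org/my-lora-adapter",  # placeholder LoRA adapter id
        max_new_tokens=50,
    )
    print(output)


asyncio.run(main())
```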
@@ -1893,7 +2084,7 @@ class AsyncInferenceClient:
  >>> await client.text_generation("The huggingface_hub library is ", max_new_tokens=12)
  '100% open source and built to be easy to use.'
 
- # Case 2: iterate over the generated tokens. Useful async for large generation.
+ # Case 2: iterate over the generated tokens. Useful for large generation.
  >>> async for token in await client.text_generation("The huggingface_hub library is ", max_new_tokens=12, stream=True):
  ... print(token)
  100
@@ -1995,6 +2186,7 @@ class AsyncInferenceClient:
 
  # Build payload
  parameters = {
+ "adapter_id": adapter_id,
  "best_of": best_of,
  "decoder_input_details": decoder_input_details,
  "details": details,
@@ -2065,6 +2257,7 @@ class AsyncInferenceClient:
  details=details,
  stream=stream,
  model=model,
+ adapter_id=adapter_id,
  best_of=best_of,
  decoder_input_details=decoder_input_details,
  do_sample=do_sample,
@@ -2377,7 +2570,13 @@ class AsyncInferenceClient:
  return VisualQuestionAnsweringOutputElement.parse_obj_as_list(response)
 
  async def zero_shot_classification(
- self, text: str, labels: List[str], *, multi_label: bool = False, model: Optional[str] = None
+ self,
+ text: str,
+ labels: List[str],
+ *,
+ multi_label: bool = False,
+ hypothesis_template: Optional[str] = None,
+ model: Optional[str] = None,
  ) -> List[ZeroShotClassificationOutputElement]:
  """
  Provide as input a text and a set of candidate labels to classify the input text.
@@ -2386,9 +2585,15 @@ class AsyncInferenceClient:
  text (`str`):
  The input text to classify.
  labels (`List[str]`):
- List of string possible labels. There must be at least 2 labels.
+ List of strings. Each string is the verbalization of a possible label for the input text.
  multi_label (`bool`):
- Boolean that is set to True if classes can overlap.
+ Boolean. If True, the probability for each label is evaluated independently and multiple labels can have a probability close to 1 simultaneously or all probabilities can be close to 0.
+ If False, the labels are considered mutually exclusive and the probability over all labels always sums to 1. Defaults to False.
+ hypothesis_template (`str`, *optional*):
+ A template sentence string with curly brackets to which the label strings are added. The label strings are added at the position of the curly brackets "{}".
+ Zero-shot classifiers are based on NLI models, which evaluate if a hypothesis is entailed in another text or not.
+ For example, with hypothesis_template="This text is about {}." and labels=["economics", "politics"], the system internally creates the two hypotheses "This text is about economics." and "This text is about politics.".
+ The model then evaluates for both hypotheses if they are entailed in the provided `text` or not.
  model (`str`, *optional*):
  The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
  Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
@@ -2402,15 +2607,15 @@ class AsyncInferenceClient:
  `aiohttp.ClientResponseError`:
  If the request fails with an HTTP error status code other than HTTP 503.
 
- Example:
+ Example with `multi_label=False`:
  ```py
  # Must be run in an async context
  >>> from huggingface_hub import AsyncInferenceClient
  >>> client = AsyncInferenceClient()
  >>> text = (
- ... "A new model offers an explanation async for how the Galilean satellites formed around the solar system's"
+ ... "A new model offers an explanation for how the Galilean satellites formed around the solar system's"
  ... "largest world. Konstantin Batygin did not set out to solve one of the solar system's most puzzling"
- ... " mysteries when he went async for a run up a hill in Nice, France."
+ ... " mysteries when he went for a run up a hill in Nice, France."
  ... )
  >>> labels = ["space & cosmos", "scientific discovery", "microbiology", "robots", "archeology"]
  >>> await client.zero_shot_classification(text, labels)
@@ -2430,21 +2635,38 @@ class AsyncInferenceClient:
  ZeroShotClassificationOutputElement(label='robots', score=0.00030448526376858354),
  ]
  ```
+
+ Example with `multi_label=True` and a custom `hypothesis_template`:
+ ```py
+ # Must be run in an async context
+ >>> from huggingface_hub import AsyncInferenceClient
+ >>> client = AsyncInferenceClient()
+ >>> await client.zero_shot_classification(
+ ... text="I really like our dinner and I'm very happy. I don't like the weather though.",
+ ... labels=["positive", "negative", "pessimistic", "optimistic"],
+ ... multi_label=True,
+ ... hypothesis_template="This text is {} towards the weather"
+ ... )
+ [
+ ZeroShotClassificationOutputElement(label='negative', score=0.9231801629066467),
+ ZeroShotClassificationOutputElement(label='pessimistic', score=0.8760990500450134),
+ ZeroShotClassificationOutputElement(label='optimistic', score=0.0008674879791215062),
+ ZeroShotClassificationOutputElement(label='positive', score=0.0005250611575320363)
+ ]
+ ```
  """
- # Raise ValueError if input is less than 2 labels
- if len(labels) < 2:
- raise ValueError("You must specify at least 2 classes to compare.")
+
+ parameters = {"candidate_labels": labels, "multi_label": multi_label}
+ if hypothesis_template is not None:
+ parameters["hypothesis_template"] = hypothesis_template
 
  response = await self.post(
  json={
  "inputs": text,
- "parameters": {
- "candidate_labels": ",".join(labels),
- "multi_label": multi_label,
- },
+ "parameters": parameters,
  },
- model=model,
  task="zero-shot-classification",
+ model=model,
  )
  output = _bytes_to_dict(response)
  return [
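Besides the new `hypothesis_template` parameter, the hunk above changes the wire format: `candidate_labels` is now sent as a proper list instead of a comma-joined string, and the client-side two-label minimum is gone. A small sketch; the label wording is illustrative:

```python
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient()
    # Labels containing commas should now survive intact, since they are no
    # longer joined into a single comma-separated string client-side.
    result = await client.zero_shot_classification(
        "The quarterly numbers beat expectations.",
        labels=["finance, business & economy", "sports"],
        hypothesis_template="This example is about {}.",
    )
    print(result)


asyncio.run(main())
```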
@@ -2501,7 +2723,7 @@ class AsyncInferenceClient:
  return ZeroShotImageClassificationOutputElement.parse_obj_as_list(response)
 
  def _resolve_url(self, model: Optional[str] = None, task: Optional[str] = None) -> str:
- model = model or self.model
+ model = model or self.model or self.base_url
 
  # If model is already a URL, ignore `task` and return directly
  if model is not None and (model.startswith("http://") or model.startswith("https://")):
@@ -2554,6 +2776,99 @@ class AsyncInferenceClient:
  )
  return model
 
+ async def get_endpoint_info(self, *, model: Optional[str] = None) -> Dict[str, Any]:
+ """
+ Get information about the deployed endpoint.
+
+ This endpoint is only available on endpoints powered by Text-Generation-Inference (TGI) or Text-Embedding-Inference (TEI).
+ Endpoints powered by `transformers` return an empty payload.
+
+ Args:
+ model (`str`, *optional*):
+ The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
+ Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
+
+ Returns:
+ `Dict[str, Any]`: Information about the endpoint.
+
+ Example:
+ ```py
+ # Must be run in an async context
+ >>> from huggingface_hub import AsyncInferenceClient
+ >>> client = AsyncInferenceClient("meta-llama/Meta-Llama-3-70B-Instruct")
+ >>> await client.get_endpoint_info()
+ {
+ 'model_id': 'meta-llama/Meta-Llama-3-70B-Instruct',
+ 'model_sha': None,
+ 'model_dtype': 'torch.float16',
+ 'model_device_type': 'cuda',
+ 'model_pipeline_tag': None,
+ 'max_concurrent_requests': 128,
+ 'max_best_of': 2,
+ 'max_stop_sequences': 4,
+ 'max_input_length': 8191,
+ 'max_total_tokens': 8192,
+ 'waiting_served_ratio': 0.3,
+ 'max_batch_total_tokens': 1259392,
+ 'max_waiting_tokens': 20,
+ 'max_batch_size': None,
+ 'validation_workers': 32,
+ 'max_client_batch_size': 4,
+ 'version': '2.0.2',
+ 'sha': 'dccab72549635c7eb5ddb17f43f0b7cdff07c214',
+ 'docker_label': 'sha-dccab72'
+ }
+ ```
+ """
+ model = model or self.model
+ if model is None:
+ raise ValueError("Model id not provided.")
+ if model.startswith(("http://", "https://")):
+ url = model.rstrip("/") + "/info"
+ else:
+ url = f"{INFERENCE_ENDPOINT}/models/{model}/info"
+
+ async with _import_aiohttp().ClientSession(headers=self.headers) as client:
+ response = await client.get(url)
+ response.raise_for_status()
+ return await response.json()
+
+ async def health_check(self, model: Optional[str] = None) -> bool:
+ """
+ Check the health of the deployed endpoint.
+
+ Health check is only available with Inference Endpoints powered by Text-Generation-Inference (TGI) or Text-Embedding-Inference (TEI).
+ For Inference API, please use [`InferenceClient.get_model_status`] instead.
+
+ Args:
+ model (`str`, *optional*):
+ URL of the Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
+
+ Returns:
+ `bool`: True if everything is working fine.
+
+ Example:
+ ```py
+ # Must be run in an async context
+ >>> from huggingface_hub import AsyncInferenceClient
+ >>> client = AsyncInferenceClient("https://jzgu0buei5.us-east-1.aws.endpoints.huggingface.cloud")
+ >>> await client.health_check()
+ True
+ ```
+ """
+ model = model or self.model
+ if model is None:
+ raise ValueError("Model id not provided.")
+ if not model.startswith(("http://", "https://")):
+ raise ValueError(
+ "Model must be an Inference Endpoint URL. For serverless Inference API, please use `InferenceClient.get_model_status`."
+ )
+ url = model.rstrip("/") + "/health"
+
+ async with _import_aiohttp().ClientSession(headers=self.headers) as client:
+ response = await client.get(url)
+ return response.status == 200
+
  async def get_model_status(self, model: Optional[str] = None) -> ModelStatus:
  """
  Get the status of a model hosted on the Inference API.
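The two new helpers above pair naturally: `health_check` for TGI/TEI endpoints reachable by URL, `get_endpoint_info` for metadata once the endpoint is up. A combined sketch; the endpoint URL is a placeholder:

```python
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient("https://my-endpoint.example")  # placeholder TGI/TEI endpoint URL
    if await client.health_check():
        info = await client.get_endpoint_info()
        print(info.get("model_id"), info.get("max_total_tokens"))


asyncio.run(main())
```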
@@ -2581,7 +2896,7 @@ class AsyncInferenceClient:
  # Must be run in an async context
  >>> from huggingface_hub import AsyncInferenceClient
  >>> client = AsyncInferenceClient()
- >>> await client.get_model_status("bigcode/starcoder")
+ >>> await client.get_model_status("meta-llama/Meta-Llama-3-8B-Instruct")
  ModelStatus(loaded=True, state='Loaded', compute_type='gpu', framework='text-generation-inference')
  ```
  """
@@ -2606,3 +2921,30 @@ class AsyncInferenceClient:
  compute_type=response_data["compute_type"],
  framework=response_data["framework"],
  )
+
+ @property
+ def chat(self) -> "ProxyClientChat":
+ return ProxyClientChat(self)
+
+
+ class _ProxyClient:
+ """Proxy class to be able to call `client.chat.completion.create(...)` as OpenAI client."""
+
+ def __init__(self, client: AsyncInferenceClient):
+ self._client = client
+
+
+ class ProxyClientChat(_ProxyClient):
+ """Proxy class to be able to call `client.chat.completion.create(...)` as OpenAI client."""
+
+ @property
+ def completions(self) -> "ProxyClientChatCompletions":
+ return ProxyClientChatCompletions(self._client)
+
+
+ class ProxyClientChatCompletions(_ProxyClient):
+ """Proxy class to be able to call `client.chat.completion.create(...)` as OpenAI client."""
+
+ @property
+ def create(self):
+ return self._client.chat_completion
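Since `ProxyClientChatCompletions.create` simply returns the bound `chat_completion` method, the OpenAI-style spelling and the native one are interchangeable. A short sketch:

```python
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient("meta-llama/Meta-Llama-3-8B-Instruct")
    messages = [{"role": "user", "content": "What is the capital of France?"}]
    # Both calls go through the exact same method and return the same type.
    native = await client.chat_completion(messages, max_tokens=20)
    openai_style = await client.chat.completions.create(messages=messages, max_tokens=20)
    print(type(native) is type(openai_style))  # True


asyncio.run(main())
```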