huggingface-hub 0.23.5__py3-none-any.whl → 0.24.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (42)
  1. huggingface_hub/__init__.py +47 -15
  2. huggingface_hub/_commit_api.py +38 -8
  3. huggingface_hub/_inference_endpoints.py +11 -4
  4. huggingface_hub/_local_folder.py +22 -13
  5. huggingface_hub/_snapshot_download.py +12 -7
  6. huggingface_hub/_webhooks_server.py +3 -1
  7. huggingface_hub/commands/huggingface_cli.py +4 -3
  8. huggingface_hub/commands/repo_files.py +128 -0
  9. huggingface_hub/constants.py +12 -0
  10. huggingface_hub/file_download.py +127 -91
  11. huggingface_hub/hf_api.py +976 -341
  12. huggingface_hub/hf_file_system.py +30 -3
  13. huggingface_hub/inference/_client.py +408 -147
  14. huggingface_hub/inference/_common.py +25 -63
  15. huggingface_hub/inference/_generated/_async_client.py +425 -153
  16. huggingface_hub/inference/_generated/types/__init__.py +4 -1
  17. huggingface_hub/inference/_generated/types/chat_completion.py +41 -21
  18. huggingface_hub/inference/_generated/types/feature_extraction.py +23 -5
  19. huggingface_hub/inference/_generated/types/text_generation.py +29 -0
  20. huggingface_hub/lfs.py +11 -6
  21. huggingface_hub/repocard_data.py +3 -3
  22. huggingface_hub/repository.py +6 -6
  23. huggingface_hub/serialization/__init__.py +8 -3
  24. huggingface_hub/serialization/_base.py +13 -16
  25. huggingface_hub/serialization/_tensorflow.py +4 -3
  26. huggingface_hub/serialization/_torch.py +399 -22
  27. huggingface_hub/utils/__init__.py +0 -1
  28. huggingface_hub/utils/_errors.py +1 -1
  29. huggingface_hub/utils/_fixes.py +14 -3
  30. huggingface_hub/utils/_paths.py +17 -6
  31. huggingface_hub/utils/_subprocess.py +0 -1
  32. huggingface_hub/utils/_telemetry.py +9 -1
  33. huggingface_hub/utils/endpoint_helpers.py +2 -186
  34. huggingface_hub/utils/sha.py +36 -1
  35. huggingface_hub/utils/tqdm.py +0 -1
  36. {huggingface_hub-0.23.5.dist-info → huggingface_hub-0.24.1.dist-info}/METADATA +12 -9
  37. {huggingface_hub-0.23.5.dist-info → huggingface_hub-0.24.1.dist-info}/RECORD +41 -41
  38. huggingface_hub/serialization/_numpy.py +0 -68
  39. {huggingface_hub-0.23.5.dist-info → huggingface_hub-0.24.1.dist-info}/LICENSE +0 -0
  40. {huggingface_hub-0.23.5.dist-info → huggingface_hub-0.24.1.dist-info}/WHEEL +0 -0
  41. {huggingface_hub-0.23.5.dist-info → huggingface_hub-0.24.1.dist-info}/entry_points.txt +0 -0
  42. {huggingface_hub-0.23.5.dist-info → huggingface_hub-0.24.1.dist-info}/top_level.txt +0 -0
@@ -66,11 +66,9 @@ from huggingface_hub.inference._common import (
  _fetch_recommended_models,
  _get_unsupported_text_generation_kwargs,
  _import_numpy,
- _is_chat_completion_server,
  _open_as_binary,
- _set_as_non_chat_completion_server,
  _set_unsupported_text_generation_kwargs,
- _stream_chat_completion_response_from_bytes,
+ _stream_chat_completion_response,
  _stream_text_generation_response,
  raise_text_generation_error,
  )
@@ -78,11 +76,10 @@ from huggingface_hub.inference._generated.types import (
  AudioClassificationOutputElement,
  AudioToAudioOutputElement,
  AutomaticSpeechRecognitionOutput,
+ ChatCompletionInputGrammarType,
  ChatCompletionInputTool,
  ChatCompletionInputToolTypeClass,
  ChatCompletionOutput,
- ChatCompletionOutputComplete,
- ChatCompletionOutputMessage,
  ChatCompletionStreamOutput,
  DocumentQuestionAnsweringOutputElement,
  FillMaskOutputElement,
@@ -103,7 +100,6 @@ from huggingface_hub.inference._generated.types import (
  ZeroShotClassificationOutputElement,
  ZeroShotImageClassificationOutputElement,
  )
- from huggingface_hub.inference._generated.types.chat_completion import ChatCompletionInputToolTypeEnum
  from huggingface_hub.inference._types import (
  ConversationalOutput, # soon to be removed
  )
@@ -113,6 +109,7 @@ from huggingface_hub.utils import (
  get_session,
  hf_raise_for_status,
  )
+ from huggingface_hub.utils._deprecation import _deprecate_positional_args


  if TYPE_CHECKING:
@@ -134,12 +131,16 @@ class InferenceClient:

  Args:
  model (`str`, `optional`):
- The model to run inference with. Can be a model id hosted on the Hugging Face Hub, e.g. `bigcode/starcoder`
+ The model to run inference with. Can be a model id hosted on the Hugging Face Hub, e.g. `meta-llama/Meta-Llama-3-8B-Instruct`
  or a URL to a deployed Inference Endpoint. Defaults to None, in which case a recommended model is
  automatically selected for the task.
+ Note: for better compatibility with OpenAI's client, `model` has been aliased as `base_url`. Those 2
+ arguments are mutually exclusive and have the exact same behavior.
  token (`str` or `bool`, *optional*):
  Hugging Face token. Will default to the locally saved token if not provided.
  Pass `token=False` if you don't want to send your token to the server.
+ Note: for better compatibility with OpenAI's client, `token` has been aliased as `api_key`. Those 2
+ arguments are mutually exclusive and have the exact same behavior.
  timeout (`float`, `optional`):
  The maximum number of seconds to wait for a response from the server. Loading a new model in Inference
  API can take up to several minutes. Defaults to None, meaning it will loop until the server is available.
@@ -148,23 +149,52 @@ class InferenceClient:
  Values in this dictionary will override the default values.
  cookies (`Dict[str, str]`, `optional`):
  Additional cookies to send to the server.
+ base_url (`str`, `optional`):
+ Base URL to run inference. This is a duplicated argument from `model` to make [`InferenceClient`]
+ follow the same pattern as `openai.OpenAI` client. Cannot be used if `model` is set. Defaults to None.
+ api_key (`str`, `optional`):
+ Token to use for authentication. This is a duplicated argument from `token` to make [`InferenceClient`]
+ follow the same pattern as `openai.OpenAI` client. Cannot be used if `token` is set. Defaults to None.
  """

+ @_deprecate_positional_args(version="0.26")
  def __init__(
  self,
  model: Optional[str] = None,
+ *,
  token: Union[str, bool, None] = None,
  timeout: Optional[float] = None,
  headers: Optional[Dict[str, str]] = None,
  cookies: Optional[Dict[str, str]] = None,
+ proxies: Optional[Any] = None,
+ # OpenAI compatibility
+ base_url: Optional[str] = None,
+ api_key: Optional[str] = None,
  ) -> None:
+ if model is not None and base_url is not None:
+ raise ValueError(
+ "Received both `model` and `base_url` arguments. Please provide only one of them."
+ " `base_url` is an alias for `model` to make the API compatible with OpenAI's client."
+ " It has the exact same behavior as `model`."
+ )
+ if token is not None and api_key is not None:
+ raise ValueError(
+ "Received both `token` and `api_key` arguments. Please provide only one of them."
+ " `api_key` is an alias for `token` to make the API compatible with OpenAI's client."
+ " It has the exact same behavior as `token`."
+ )
+
  self.model: Optional[str] = model
- self.token: Union[str, bool, None] = token
- self.headers = CaseInsensitiveDict(build_hf_headers(token=token)) # contains 'authorization' + 'user-agent'
+ self.token: Union[str, bool, None] = token if token is not None else api_key
+ self.headers = CaseInsensitiveDict(build_hf_headers(token=self.token)) # 'authorization' + 'user-agent'
  if headers is not None:
  self.headers.update(headers)
  self.cookies = cookies
  self.timeout = timeout
+ self.proxies = proxies
+
+ # OpenAI compatibility
+ self.base_url = base_url

  def __repr__(self):
  return f"<InferenceClient(model='{self.model if self.model else ''}', timeout={self.timeout})>"
@@ -264,6 +294,7 @@ class InferenceClient:
  cookies=self.cookies,
  timeout=self.timeout,
  stream=stream,
+ proxies=self.proxies,
  )
  except TimeoutError as error:
  # Convert any `TimeoutError` to a `InferenceTimeoutError`
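
A hedged usage sketch for the new `proxies` argument: it is stored on the client and forwarded to the underlying HTTP call as shown above. A requests-style proxy mapping is the natural shape to pass, but that exact format is an assumption here since the parameter is typed `Optional[Any]`, and the proxy address is a placeholder.

```py
from huggingface_hub import InferenceClient

# Route all inference traffic through a corporate proxy (hypothetical address).
client = InferenceClient(
    "meta-llama/Meta-Llama-3-8B-Instruct",
    proxies={"http": "http://proxy.local:3128", "https": "http://proxy.local:3128"},
)
```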
@@ -289,6 +320,8 @@ class InferenceClient:
  # ...or wait 1s and retry
  logger.info(f"Waiting for model to be loaded on the server: {error}")
  time.sleep(1)
+ if "X-wait-for-model" not in headers and url.startswith(INFERENCE_ENDPOINT):
+ headers["X-wait-for-model"] = "1"
  if timeout is not None:
  timeout = max(self.timeout - (time.time() - t0), 1) # type: ignore
  continue
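
In plain HTTP terms, the retry above is roughly equivalent to the following raw call against the serverless Inference API (model id and token are placeholders): the `X-wait-for-model` header asks the server to hold the request until the model is loaded instead of answering 503.

```py
import requests

response = requests.post(
    "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-8B-Instruct",
    headers={
        "Authorization": "Bearer hf_xxx",  # placeholder token
        "X-wait-for-model": "1",           # block until the model is loaded rather than returning 503
    },
    json={"inputs": "Hello"},
)
print(response.status_code)
```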
@@ -428,10 +461,11 @@ class InferenceClient:
  max_tokens: Optional[int] = None,
  n: Optional[int] = None,
  presence_penalty: Optional[float] = None,
+ response_format: Optional[ChatCompletionInputGrammarType] = None,
  seed: Optional[int] = None,
  stop: Optional[List[str]] = None,
  temperature: Optional[float] = None,
- tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, ChatCompletionInputToolTypeEnum]] = None,
+ tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, str]] = None,
  tool_prompt: Optional[str] = None,
  tools: Optional[List[ChatCompletionInputTool]] = None,
  top_logprobs: Optional[int] = None,
@@ -451,10 +485,11 @@ class InferenceClient:
  max_tokens: Optional[int] = None,
  n: Optional[int] = None,
  presence_penalty: Optional[float] = None,
+ response_format: Optional[ChatCompletionInputGrammarType] = None,
  seed: Optional[int] = None,
  stop: Optional[List[str]] = None,
  temperature: Optional[float] = None,
- tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, ChatCompletionInputToolTypeEnum]] = None,
+ tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, str]] = None,
  tool_prompt: Optional[str] = None,
  tools: Optional[List[ChatCompletionInputTool]] = None,
  top_logprobs: Optional[int] = None,
@@ -474,10 +509,11 @@ class InferenceClient:
  max_tokens: Optional[int] = None,
  n: Optional[int] = None,
  presence_penalty: Optional[float] = None,
+ response_format: Optional[ChatCompletionInputGrammarType] = None,
  seed: Optional[int] = None,
  stop: Optional[List[str]] = None,
  temperature: Optional[float] = None,
- tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, ChatCompletionInputToolTypeEnum]] = None,
+ tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, str]] = None,
  tool_prompt: Optional[str] = None,
  tools: Optional[List[ChatCompletionInputTool]] = None,
  top_logprobs: Optional[int] = None,
@@ -497,10 +533,11 @@ class InferenceClient:
  max_tokens: Optional[int] = None,
  n: Optional[int] = None,
  presence_penalty: Optional[float] = None,
+ response_format: Optional[ChatCompletionInputGrammarType] = None,
  seed: Optional[int] = None,
  stop: Optional[List[str]] = None,
  temperature: Optional[float] = None,
- tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, ChatCompletionInputToolTypeEnum]] = None,
+ tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, str]] = None,
  tool_prompt: Optional[str] = None,
  tools: Optional[List[ChatCompletionInputTool]] = None,
  top_logprobs: Optional[int] = None,
@@ -511,11 +548,10 @@ class InferenceClient:

  <Tip>

- If the model is served by a server supporting chat-completion, the method will directly call the server's
- `/v1/chat/completions` endpoint. If the server does not support chat-completion, the method will render the
- chat template client-side based on the information fetched from the Hub API. In this case, you will need to
- have `minijinja` template engine installed. Run `pip install "huggingface_hub[inference]"` or `pip install minijinja`
- to install it.
+ The `client.chat_completion` method is aliased as `client.chat.completions.create` for compatibility with OpenAI's client.
+ Inputs and outputs are strictly the same and using either syntax will yield the same results.
+ Check out the [Inference guide](https://huggingface.co/docs/huggingface_hub/guides/inference#openai-compatibility)
+ for more details about OpenAI's compatibility.

  </Tip>

@@ -526,6 +562,9 @@ class InferenceClient:
  The model to use for chat-completion. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
  Inference Endpoint. If not provided, the default recommended model for chat-based text-generation will be used.
  See https://huggingface.co/tasks/text-generation for more details.
+
+ If `model` is a model ID, it is passed to the server as the `model` parameter. If you want to define a
+ custom URL while setting `model` in the request payload, you must set `base_url` when initializing [`InferenceClient`].
  frequency_penalty (`float`, *optional*):
  Penalizes new tokens based on their existing frequency
  in the text so far. Range: [-2.0, 2.0]. Defaults to 0.0.
@@ -545,6 +584,8 @@ class InferenceClient:
  presence_penalty (`float`, *optional*):
  Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the
  text so far, increasing the model's likelihood to talk about new topics.
+ response_format ([`ChatCompletionInputGrammarType`], *optional*):
+ Grammar constraints. Can be either a JSONSchema or a regex.
  seed (Optional[`int`], *optional*):
  Seed for reproducible control flow. Defaults to None.
  stop (Optional[`str`], *optional*):
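
The docstring example later in this diff covers the JSON-schema form of `response_format`; for completeness, here is a hedged sketch of the regex form, assuming the `{"type": "regex", "value": ...}` shape implied by [`ChatCompletionInputGrammarType`] (the model id is a placeholder).

```py
from huggingface_hub import InferenceClient

client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct")
response = client.chat_completion(
    messages=[{"role": "user", "content": "Pick a 4-digit PIN."}],
    response_format={"type": "regex", "value": r"\d{4}"},  # constrain the output to four digits
    max_tokens=10,
)
print(response.choices[0].message.content)
```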
@@ -562,7 +603,7 @@ class InferenceClient:
  top_p (`float`, *optional*):
  Fraction of the most likely next words to sample from.
  Must be between 0 and 1. Defaults to 1.0.
- tool_choice ([`ChatCompletionInputToolTypeClass`] or [`ChatCompletionInputToolTypeEnum`], *optional*):
+ tool_choice ([`ChatCompletionInputToolTypeClass`] or `str`, *optional*):
  The tool to use for the completion. Defaults to "auto".
  tool_prompt (`str`, *optional*):
  A prompt to be appended before the tools.
@@ -571,7 +612,7 @@ class InferenceClient:
  provide a list of functions the model may generate JSON inputs for.

  Returns:
- [`ChatCompletionOutput] or Iterable of [`ChatCompletionStreamOutput`]:
+ [`ChatCompletionOutput`] or Iterable of [`ChatCompletionStreamOutput`]:
  Generated text returned from the server:
  - if `stream=False`, the generated text is returned as a [`ChatCompletionOutput`] (default).
  - if `stream=True`, the generated text is returned token by token as a sequence of [`ChatCompletionStreamOutput`].
@@ -585,10 +626,9 @@ class InferenceClient:
  Example:

  ```py
- # Chat example
  >>> from huggingface_hub import InferenceClient
  >>> messages = [{"role": "user", "content": "What is the capital of France?"}]
- >>> client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
+ >>> client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct")
  >>> client.chat_completion(messages, max_tokens=100)
  ChatCompletionOutput(
  choices=[
@@ -596,21 +636,67 @@ class InferenceClient:
  ChatCompletionOutputComplete(
  finish_reason='eos_token',
  index=0,
  message=ChatCompletionOutputMessage(
- content='The capital of France is Paris. The official name of the city is Ville de Paris (City of Paris) and the name of the country governing body, which is located in Paris, is La République française (The French Republic). \nI hope that helps! Let me know if you need any further information.'
- )
+ role='assistant',
+ content='The capital of France is Paris.',
+ name=None,
+ tool_calls=None
+ ),
+ logprobs=None
  )
  ],
- created=1710498360
+ created=1719907176,
+ id='',
+ model='meta-llama/Meta-Llama-3-8B-Instruct',
+ object='text_completion',
+ system_fingerprint='2.0.4-sha-f426a33',
+ usage=ChatCompletionOutputUsage(
+ completion_tokens=8,
+ prompt_tokens=17,
+ total_tokens=25
+ )
  )
+ ```

+ Example (stream=True):
+ ```py
+ >>> from huggingface_hub import InferenceClient
+ >>> messages = [{"role": "user", "content": "What is the capital of France?"}]
+ >>> client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct")
  >>> for token in client.chat_completion(messages, max_tokens=10, stream=True):
  ... print(token)
  ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content='The', role='assistant'), index=0, finish_reason=None)], created=1710498504)
  ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content=' capital', role='assistant'), index=0, finish_reason=None)], created=1710498504)
  (...)
  ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content=' may', role='assistant'), index=0, finish_reason=None)], created=1710498504)
+ ```
+
+ Example using OpenAI's syntax:
+ ```py
+ # instead of `from openai import OpenAI`
+ from huggingface_hub import InferenceClient
+
+ # instead of `client = OpenAI(...)`
+ client = InferenceClient(
+ base_url=...,
+ api_key=...,
+ )

- # Chat example with tools
+ output = client.chat.completions.create(
+ model="meta-llama/Meta-Llama-3-8B-Instruct",
+ messages=[
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "Count to 10"},
+ ],
+ stream=True,
+ max_tokens=1024,
+ )
+
+ for chunk in output:
+ print(chunk.choices[0].delta.content)
+ ```
+
+ Example using tools:
+ ```py
  >>> client = InferenceClient("meta-llama/Meta-Llama-3-70B-Instruct")
  >>> messages = [
  ... {
@@ -691,113 +777,89 @@ class InferenceClient:
  description=None
  )
  ```
- """
- # determine model
- model = model or self.model or self.get_recommended_model("text-generation")
-
- if _is_chat_completion_server(model):
- # First, let's consider the server has a `/v1/chat/completions` endpoint.
- # If that's the case, we don't have to render the chat template client-side.
- model_url = self._resolve_url(model)
- if not model_url.endswith("/chat/completions"):
- model_url += "/v1/chat/completions"

- try:
- data = self.post(
- model=model_url,
- json=dict(
- model="tgi", # random string
- messages=messages,
- frequency_penalty=frequency_penalty,
- logit_bias=logit_bias,
- logprobs=logprobs,
- max_tokens=max_tokens,
- n=n,
- presence_penalty=presence_penalty,
- seed=seed,
- stop=stop,
- temperature=temperature,
- tool_choice=tool_choice,
- tool_prompt=tool_prompt,
- tools=tools,
- top_logprobs=top_logprobs,
- top_p=top_p,
- stream=stream,
- ),
- stream=stream,
- )
- except HTTPError as e:
- if e.response.status_code in (400, 404, 500):
- # Let's consider the server is not a chat completion server.
- # Then we call again `chat_completion` which will render the chat template client side.
- # (can be HTTP 500, HTTP 400, HTTP 404 depending on the server)
- _set_as_non_chat_completion_server(model)
- logger.warning(
- f"Server {model_url} does not seem to support chat completion. Falling back to text generation. Error: {e}"
- )
- return self.chat_completion(
- messages=messages,
- model=model,
- stream=stream,
- max_tokens=max_tokens,
- seed=seed,
- stop=stop,
- temperature=temperature,
- top_p=top_p,
- )
- raise
-
- if stream:
- return _stream_chat_completion_response_from_bytes(data) # type: ignore[arg-type]
-
- return ChatCompletionOutput.parse_obj_as_instance(data) # type: ignore[arg-type]
+ Example using response_format:
+ ```py
+ >>> from huggingface_hub import InferenceClient
+ >>> client = InferenceClient("meta-llama/Meta-Llama-3-70B-Instruct")
+ >>> messages = [
+ ... {
+ ... "role": "user",
+ ... "content": "I saw a puppy a cat and a raccoon during my bike ride in the park. What did I saw and when?",
+ ... },
+ ... ]
+ >>> response_format = {
+ ... "type": "json",
+ ... "value": {
+ ... "properties": {
+ ... "location": {"type": "string"},
+ ... "activity": {"type": "string"},
+ ... "animals_seen": {"type": "integer", "minimum": 1, "maximum": 5},
+ ... "animals": {"type": "array", "items": {"type": "string"}},
+ ... },
+ ... "required": ["location", "activity", "animals_seen", "animals"],
+ ... },
+ ... }
+ >>> response = client.chat_completion(
+ ... messages=messages,
+ ... response_format=response_format,
+ ... max_tokens=500,
+ )
+ >>> response.choices[0].message.content
+ '{\n\n"activity": "bike ride",\n"animals": ["puppy", "cat", "raccoon"],\n"animals_seen": 3,\n"location": "park"}'
+ ```
+ """
+ # Determine model
+ # `self.xxx` takes precedence over the method argument only in `chat_completion`
+ # since `chat_completion(..., model=xxx)` is also a payload parameter for the
+ # server, we need to handle it differently
+ model = self.base_url or self.model or model or self.get_recommended_model("text-generation")
+ is_url = model.startswith(("http://", "https://"))
+
+ # First, resolve the model chat completions URL
+ if model == self.base_url:
+ # base_url passed => add server route
+ model_url = model + "/v1/chat/completions"
+ elif is_url:
+ # model is a URL => use it directly
+ model_url = model
+ else:
+ # model is a model ID => resolve it + add server route
+ model_url = self._resolve_url(model) + "/v1/chat/completions"
+
+ # `model` is sent in the payload. Not used by the server but can be useful for debugging/routing.
+ # If it's a ID on the Hub => use it. Otherwise, we use a random string.
+ model_id = model if not is_url and model.count("/") == 1 else "tgi"
+
+ data = self.post(
+ model=model_url,
+ json=dict(
+ model=model_id,
+ messages=messages,
+ frequency_penalty=frequency_penalty,
+ logit_bias=logit_bias,
+ logprobs=logprobs,
+ max_tokens=max_tokens,
+ n=n,
+ presence_penalty=presence_penalty,
+ response_format=response_format,
+ seed=seed,
+ stop=stop,
+ temperature=temperature,
+ tool_choice=tool_choice,
+ tool_prompt=tool_prompt,
+ tools=tools,
+ top_logprobs=top_logprobs,
+ top_p=top_p,
+ stream=stream,
+ ),
+ stream=stream,
+ )

- # At this point, we know the server is not a chat completion server.
- # It means it's a transformers-backed server for which we can send a list of messages directly to the
- # `text-generation` pipeline. We won't receive a detailed response but only the generated text.
  if stream:
- raise ValueError(
- "Streaming token is not supported by the model. This is due to the model not been served by a "
- "Text-Generation-Inference server. Please pass `stream=False` as input."
- )
- if tool_choice is not None or tool_prompt is not None or tools is not None:
- warnings.warn(
- "Tools are not supported by the model. This is due to the model not been served by a "
- "Text-Generation-Inference server. The provided tool parameters will be ignored."
- )
-
- # generate response
- text_generation_output = self.text_generation(
- prompt=messages, # type: ignore # Not correct type but works implicitly
- model=model,
- stream=False,
- details=False,
- max_new_tokens=max_tokens,
- seed=seed,
- stop_sequences=stop,
- temperature=temperature,
- top_p=top_p,
- )
+ return _stream_chat_completion_response(data) # type: ignore[arg-type]

- # Format as a ChatCompletionOutput with dummy values for fields we can't provide
- return ChatCompletionOutput(
- id="dummy",
- model="dummy",
- object="dummy",
- system_fingerprint="dummy",
- usage=None, # type: ignore # set to `None` as we don't want to provide false information
- created=int(time.time()),
- choices=[
- ChatCompletionOutputComplete(
- finish_reason="unk", # type: ignore # set to `unk` as we don't want to provide false information
- index=0,
- message=ChatCompletionOutputMessage(
- content=text_generation_output,
- role="assistant",
- ),
- )
- ],
- )
+ return ChatCompletionOutput.parse_obj_as_instance(data) # type: ignore[arg-type]

  def conversational(
  self,
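
To summarize the new resolution logic above in runnable form, here is a hedged, standalone re-implementation (not the library's internal helper). It assumes the serverless Inference API base URL for plain model ids, where the real code goes through `self._resolve_url`, and mirrors the `"tgi"` placeholder used for the payload's `model` field.

```py
from typing import Optional, Tuple

INFERENCE_ENDPOINT = "https://api-inference.huggingface.co"

def resolve_chat_completion_target(model: str, base_url: Optional[str]) -> Tuple[str, str]:
    """Return (request URL, payload `model` field) following the rules in the hunk above."""
    is_url = model.startswith(("http://", "https://"))
    if base_url is not None and model == base_url:
        url = model + "/v1/chat/completions"   # base_url passed => append the chat route
    elif is_url:
        url = model                            # full URL passed => use it as-is
    else:
        url = f"{INFERENCE_ENDPOINT}/models/{model}/v1/chat/completions"
    # Hub ids (exactly one "/") are forwarded in the payload; anything else falls back to "tgi".
    model_id = model if not is_url and model.count("/") == 1 else "tgi"
    return url, model_id

print(resolve_chat_completion_target("meta-llama/Meta-Llama-3-8B-Instruct", None))
print(resolve_chat_completion_target("https://my-endpoint.example", "https://my-endpoint.example"))
```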
@@ -913,7 +975,16 @@ class InferenceClient:
  response = self.post(json=payload, model=model, task="document-question-answering")
  return DocumentQuestionAnsweringOutputElement.parse_obj_as_list(response)

- def feature_extraction(self, text: str, *, model: Optional[str] = None) -> "np.ndarray":
+ def feature_extraction(
+ self,
+ text: str,
+ *,
+ normalize: Optional[bool] = None,
+ prompt_name: Optional[str] = None,
+ truncate: Optional[bool] = None,
+ truncation_direction: Optional[Literal["Left", "Right"]] = None,
+ model: Optional[str] = None,
+ ) -> "np.ndarray":
  """
  Generate embeddings for a given text.

@@ -924,6 +995,20 @@ class InferenceClient:
  The model to use for the conversational task. Can be a model ID hosted on the Hugging Face Hub or a URL to
  a deployed Inference Endpoint. If not provided, the default recommended conversational model will be used.
  Defaults to None.
+ normalize (`bool`, *optional*):
+ Whether to normalize the embeddings or not. Defaults to None.
+ Only available on server powered by Text-Embedding-Inference.
+ prompt_name (`str`, *optional*):
+ The name of the prompt that should be used by for encoding. If not set, no prompt will be applied.
+ Must be a key in the `Sentence Transformers` configuration `prompts` dictionary.
+ For example if ``prompt_name`` is "query" and the ``prompts`` is {"query": "query: ",...},
+ then the sentence "What is the capital of France?" will be encoded as "query: What is the capital of France?"
+ because the prompt text will be prepended before any text to encode.
+ truncate (`bool`, *optional*):
+ Whether to truncate the embeddings or not. Defaults to None.
+ Only available on server powered by Text-Embedding-Inference.
+ truncation_direction (`Literal["Left", "Right"]`, *optional*):
+ Which side of the input should be truncated when `truncate=True` is passed.

  Returns:
  `np.ndarray`: The embedding representing the input text as a float32 numpy array.
@@ -945,7 +1030,16 @@ class InferenceClient:
  [ 0.28552425, -0.928395 , -1.2077185 , ..., 0.76810825, -2.1069427 , 0.6236161 ]], dtype=float32)
  ```
  """
- response = self.post(json={"inputs": text}, model=model, task="feature-extraction")
+ payload: Dict = {"inputs": text}
+ if normalize is not None:
+ payload["normalize"] = normalize
+ if prompt_name is not None:
+ payload["prompt_name"] = prompt_name
+ if truncate is not None:
+ payload["truncate"] = truncate
+ if truncation_direction is not None:
+ payload["truncation_direction"] = truncation_direction
+ response = self.post(json=payload, model=model, task="feature-extraction")
  np = _import_numpy()
  return np.array(_bytes_to_dict(response), dtype="float32")

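A hedged usage sketch for the new `feature_extraction` parameters; the model id is a placeholder and the extra options are only honored by Text-Embedding-Inference-backed servers, as the docstring above notes.

```py
from huggingface_hub import InferenceClient

client = InferenceClient("intfloat/multilingual-e5-large")  # placeholder embedding model id
embedding = client.feature_extraction(
    "What is the capital of France?",
    normalize=True,                  # TEI-only option
    prompt_name="query",             # must exist in the model's Sentence Transformers `prompts` config
    truncate=True,
    truncation_direction="Right",
)
print(embedding.shape, embedding.dtype)
```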
@@ -1184,7 +1278,8 @@ class InferenceClient:
  ```
  """
  response = self.post(data=image, model=model, task="image-to-text")
- return ImageToTextOutput.parse_obj_as_instance(response)
+ output = ImageToTextOutput.parse_obj(response)
+ return output[0] if isinstance(output, list) else output

  def list_deployed_models(
  self, frameworks: Union[None, str, Literal["all"], List[str]] = None
@@ -1619,6 +1714,7 @@ class InferenceClient:
  stream: Literal[False] = ...,
  model: Optional[str] = None,
  # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
+ adapter_id: Optional[str] = None,
  best_of: Optional[int] = None,
  decoder_input_details: Optional[bool] = None,
  do_sample: Optional[bool] = False, # Manual default value
@@ -1647,6 +1743,7 @@ class InferenceClient:
  stream: Literal[False] = ...,
  model: Optional[str] = None,
  # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
+ adapter_id: Optional[str] = None,
  best_of: Optional[int] = None,
  decoder_input_details: Optional[bool] = None,
  do_sample: Optional[bool] = False, # Manual default value
@@ -1675,6 +1772,7 @@ class InferenceClient:
  stream: Literal[True] = ...,
  model: Optional[str] = None,
  # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
+ adapter_id: Optional[str] = None,
  best_of: Optional[int] = None,
  decoder_input_details: Optional[bool] = None,
  do_sample: Optional[bool] = False, # Manual default value
@@ -1703,6 +1801,7 @@ class InferenceClient:
  stream: Literal[True] = ...,
  model: Optional[str] = None,
  # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
+ adapter_id: Optional[str] = None,
  best_of: Optional[int] = None,
  decoder_input_details: Optional[bool] = None,
  do_sample: Optional[bool] = False, # Manual default value
@@ -1731,6 +1830,7 @@ class InferenceClient:
  stream: bool = ...,
  model: Optional[str] = None,
  # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
+ adapter_id: Optional[str] = None,
  best_of: Optional[int] = None,
  decoder_input_details: Optional[bool] = None,
  do_sample: Optional[bool] = False, # Manual default value
@@ -1758,6 +1858,7 @@ class InferenceClient:
  stream: bool = False,
  model: Optional[str] = None,
  # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
+ adapter_id: Optional[str] = None,
  best_of: Optional[int] = None,
  decoder_input_details: Optional[bool] = None,
  do_sample: Optional[bool] = False, # Manual default value
@@ -1788,6 +1889,13 @@ class InferenceClient:

  To learn more about the TGI project, please refer to https://github.com/huggingface/text-generation-inference.

+ <Tip>
+
+ If you want to generate a response from chat messages, you should use the [`InferenceClient.chat_completion`] method.
+ It accepts a list of messages instead of a single text prompt and handles the chat templating for you.
+
+ </Tip>
+
  Args:
  prompt (`str`):
  Input text.
@@ -1802,6 +1910,8 @@ class InferenceClient:
  model (`str`, *optional*):
  The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
  Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
+ adapter_id (`str`, *optional*):
+ Lora adapter id.
  best_of (`int`, *optional*):
  Generate best_of sequences and return the one if the highest token logprobs.
  decoder_input_details (`bool`, *optional*):
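
A hedged sketch of the new `adapter_id` parameter in `text_generation`; the endpoint URL and adapter id are placeholders, and the call assumes a TGI server launched with that LoRA adapter available.

```py
from huggingface_hub import InferenceClient

client = InferenceClient("https://my-tgi-endpoint.example")  # placeholder TGI endpoint URL
output = client.text_generation(
    "What is Deep Learning?",
    adapter_id="my-username/my-lora-adapter",  # placeholder LoRA adapter id
    max_new_tokens=64,
)
print(output)
```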
@@ -1970,6 +2080,7 @@ class InferenceClient:

  # Build payload
  parameters = {
+ "adapter_id": adapter_id,
  "best_of": best_of,
  "decoder_input_details": decoder_input_details,
  "details": details,
@@ -2040,6 +2151,7 @@ class InferenceClient:
  details=details,
  stream=stream,
  model=model,
+ adapter_id=adapter_id,
  best_of=best_of,
  decoder_input_details=decoder_input_details,
  do_sample=do_sample,
@@ -2064,7 +2176,12 @@ class InferenceClient:
  if stream:
  return _stream_text_generation_response(bytes_output, details) # type: ignore

- data = _bytes_to_dict(bytes_output)[0] # type: ignore[arg-type]
+ data = _bytes_to_dict(bytes_output) # type: ignore[arg-type]
+
+ # Data can be a single element (dict) or an iterable of dicts where we select the first element of.
+ if isinstance(data, list):
+ data = data[0]
+
  return TextGenerationOutput.parse_obj_as_instance(data) if details else data["generated_text"]

  def text_to_image(
@@ -2347,7 +2464,13 @@ class InferenceClient:
  return VisualQuestionAnsweringOutputElement.parse_obj_as_list(response)

  def zero_shot_classification(
- self, text: str, labels: List[str], *, multi_label: bool = False, model: Optional[str] = None
+ self,
+ text: str,
+ labels: List[str],
+ *,
+ multi_label: bool = False,
+ hypothesis_template: Optional[str] = None,
+ model: Optional[str] = None,
  ) -> List[ZeroShotClassificationOutputElement]:
  """
  Provide as input a text and a set of candidate labels to classify the input text.
@@ -2356,9 +2479,15 @@ class InferenceClient:
  text (`str`):
  The input text to classify.
  labels (`List[str]`):
- List of string possible labels. There must be at least 2 labels.
+ List of strings. Each string is the verbalization of a possible label for the input text.
  multi_label (`bool`):
- Boolean that is set to True if classes can overlap.
+ Boolean. If True, the probability for each label is evaluated independently and multiple labels can have a probability close to 1 simultaneously or all probabilities can be close to 0.
+ If False, the labels are considered mutually exclusive and the probability over all labels always sums to 1. Defaults to False.
+ hypothesis_template (`str`, *optional*):
+ A template sentence string with curly brackets to which the label strings are added. The label strings are added at the position of the curly brackets "{}".
+ Zero-shot classifiers are based on NLI models, which evaluate if a hypothesis is entailed in another text or not.
+ For example, with hypothesis_template="This text is about {}." and labels=["economics", "politics"], the system internally creates the two hypotheses "This text is about economics." and "This text is about politics.".
+ The model then evaluates for both hypotheses if they are entailed in the provided `text` or not.
  model (`str`, *optional*):
  The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
  Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
@@ -2372,7 +2501,7 @@ class InferenceClient:
  `HTTPError`:
  If the request fails with an HTTP error status code other than HTTP 503.

- Example:
+ Example with `multi_label=False`:
  ```py
  >>> from huggingface_hub import InferenceClient
  >>> client = InferenceClient()
@@ -2399,21 +2528,37 @@ class InferenceClient:
  ZeroShotClassificationOutputElement(label='robots', score=0.00030448526376858354),
  ]
  ```
+
+ Example with `multi_label=True` and a custom `hypothesis_template`:
+ ```py
+ >>> from huggingface_hub import InferenceClient
+ >>> client = InferenceClient()
+ >>> client.zero_shot_classification(
+ ... text="I really like our dinner and I'm very happy. I don't like the weather though.",
+ ... labels=["positive", "negative", "pessimistic", "optimistic"],
+ ... multi_label=True,
+ ... hypothesis_template="This text is {} towards the weather"
+ ... )
+ [
+ ZeroShotClassificationOutputElement(label='negative', score=0.9231801629066467),
+ ZeroShotClassificationOutputElement(label='pessimistic', score=0.8760990500450134),
+ ZeroShotClassificationOutputElement(label='optimistic', score=0.0008674879791215062),
+ ZeroShotClassificationOutputElement(label='positive', score=0.0005250611575320363)
+ ]
+ ```
  """
- # Raise ValueError if input is less than 2 labels
- if len(labels) < 2:
- raise ValueError("You must specify at least 2 classes to compare.")
+
+ parameters = {"candidate_labels": labels, "multi_label": multi_label}
+ if hypothesis_template is not None:
+ parameters["hypothesis_template"] = hypothesis_template

  response = self.post(
  json={
  "inputs": text,
- "parameters": {
- "candidate_labels": ",".join(labels),
- "multi_label": multi_label,
- },
+ "parameters": parameters,
  },
- model=model,
  task="zero-shot-classification",
+ model=model,
  )
  output = _bytes_to_dict(response)
  return [
@@ -2469,7 +2614,7 @@ class InferenceClient:
  return ZeroShotImageClassificationOutputElement.parse_obj_as_list(response)

  def _resolve_url(self, model: Optional[str] = None, task: Optional[str] = None) -> str:
- model = model or self.model
+ model = model or self.model or self.base_url

  # If model is already a URL, ignore `task` and return directly
  if model is not None and (model.startswith("http://") or model.startswith("https://")):
@@ -2522,6 +2667,95 @@ class InferenceClient:
  )
  return model

+ def get_endpoint_info(self, *, model: Optional[str] = None) -> Dict[str, Any]:
+ """
+ Get information about the deployed endpoint.
+
+ This endpoint is only available on endpoints powered by Text-Generation-Inference (TGI) or Text-Embedding-Inference (TEI).
+ Endpoints powered by `transformers` return an empty payload.
+
+ Args:
+ model (`str`, *optional*):
+ The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
+ Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
+
+ Returns:
+ `Dict[str, Any]`: Information about the endpoint.
+
+ Example:
+ ```py
+ >>> from huggingface_hub import InferenceClient
+ >>> client = InferenceClient("meta-llama/Meta-Llama-3-70B-Instruct")
+ >>> client.get_endpoint_info()
+ {
+ 'model_id': 'meta-llama/Meta-Llama-3-70B-Instruct',
+ 'model_sha': None,
+ 'model_dtype': 'torch.float16',
+ 'model_device_type': 'cuda',
+ 'model_pipeline_tag': None,
+ 'max_concurrent_requests': 128,
+ 'max_best_of': 2,
+ 'max_stop_sequences': 4,
+ 'max_input_length': 8191,
+ 'max_total_tokens': 8192,
+ 'waiting_served_ratio': 0.3,
+ 'max_batch_total_tokens': 1259392,
+ 'max_waiting_tokens': 20,
+ 'max_batch_size': None,
+ 'validation_workers': 32,
+ 'max_client_batch_size': 4,
+ 'version': '2.0.2',
+ 'sha': 'dccab72549635c7eb5ddb17f43f0b7cdff07c214',
+ 'docker_label': 'sha-dccab72'
+ }
+ ```
+ """
+ model = model or self.model
+ if model is None:
+ raise ValueError("Model id not provided.")
+ if model.startswith(("http://", "https://")):
+ url = model.rstrip("/") + "/info"
+ else:
+ url = f"{INFERENCE_ENDPOINT}/models/{model}/info"
+
+ response = get_session().get(url, headers=self.headers)
+ hf_raise_for_status(response)
+ return response.json()
+
+ def health_check(self, model: Optional[str] = None) -> bool:
+ """
+ Check the health of the deployed endpoint.
+
+ Health check is only available with Inference Endpoints powered by Text-Generation-Inference (TGI) or Text-Embedding-Inference (TEI).
+ For Inference API, please use [`InferenceClient.get_model_status`] instead.
+
+ Args:
+ model (`str`, *optional*):
+ URL of the Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
+
+ Returns:
+ `bool`: True if everything is working fine.
+
+ Example:
+ ```py
+ >>> from huggingface_hub import InferenceClient
+ >>> client = InferenceClient("https://jzgu0buei5.us-east-1.aws.endpoints.huggingface.cloud")
+ >>> client.health_check()
+ True
+ ```
+ """
+ model = model or self.model
+ if model is None:
+ raise ValueError("Model id not provided.")
+ if not model.startswith(("http://", "https://")):
+ raise ValueError(
+ "Model must be an Inference Endpoint URL. For serverless Inference API, please use `InferenceClient.get_model_status`."
+ )
+ url = model.rstrip("/") + "/health"
+
+ response = get_session().get(url, headers=self.headers)
+ return response.status_code == 200
+
  def get_model_status(self, model: Optional[str] = None) -> ModelStatus:
  """
  Get the status of a model hosted on the Inference API.
@@ -2548,7 +2782,7 @@ class InferenceClient:
  ```py
  >>> from huggingface_hub import InferenceClient
  >>> client = InferenceClient()
- >>> client.get_model_status("bigcode/starcoder")
+ >>> client.get_model_status("meta-llama/Meta-Llama-3-8B-Instruct")
  ModelStatus(loaded=True, state='Loaded', compute_type='gpu', framework='text-generation-inference')
  ```
  """
@@ -2572,3 +2806,30 @@
  compute_type=response_data["compute_type"],
  framework=response_data["framework"],
  )
+
+ @property
+ def chat(self) -> "ProxyClientChat":
+ return ProxyClientChat(self)
+
+
+ class _ProxyClient:
+ """Proxy class to be able to call `client.chat.completion.create(...)` as OpenAI client."""
+
+ def __init__(self, client: InferenceClient):
+ self._client = client
+
+
+ class ProxyClientChat(_ProxyClient):
+ """Proxy class to be able to call `client.chat.completion.create(...)` as OpenAI client."""
+
+ @property
+ def completions(self) -> "ProxyClientChatCompletions":
+ return ProxyClientChatCompletions(self._client)
+
+
+ class ProxyClientChatCompletions(_ProxyClient):
+ """Proxy class to be able to call `client.chat.completion.create(...)` as OpenAI client."""
+
+ @property
+ def create(self):
+ return self._client.chat_completion
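
To make the effect of the proxy classes above concrete, here is a hedged sketch showing that the OpenAI-style chain resolves to the same bound method (the model id is a placeholder).

```py
from huggingface_hub import InferenceClient

client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct")
messages = [{"role": "user", "content": "Hello!"}]

# `client.chat.completions.create` is the `client.chat_completion` method returned by the proxies above,
# so both spellings hit the same /v1/chat/completions route.
assert client.chat.completions.create == client.chat_completion
out = client.chat.completions.create(messages=messages, max_tokens=16)
print(out.choices[0].message.content)
```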