huggingface-hub 0.22.2__py3-none-any.whl → 0.23.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this version of huggingface-hub has been flagged as a potentially problematic release.
Files changed (45)
  1. huggingface_hub/__init__.py +51 -19
  2. huggingface_hub/_commit_api.py +9 -8
  3. huggingface_hub/_commit_scheduler.py +2 -2
  4. huggingface_hub/_inference_endpoints.py +10 -17
  5. huggingface_hub/_local_folder.py +229 -0
  6. huggingface_hub/_login.py +4 -3
  7. huggingface_hub/_multi_commits.py +1 -1
  8. huggingface_hub/_snapshot_download.py +16 -38
  9. huggingface_hub/_tensorboard_logger.py +16 -6
  10. huggingface_hub/_webhooks_payload.py +22 -1
  11. huggingface_hub/_webhooks_server.py +24 -20
  12. huggingface_hub/commands/download.py +11 -34
  13. huggingface_hub/commands/huggingface_cli.py +2 -0
  14. huggingface_hub/commands/tag.py +159 -0
  15. huggingface_hub/constants.py +3 -5
  16. huggingface_hub/errors.py +58 -0
  17. huggingface_hub/file_download.py +545 -376
  18. huggingface_hub/hf_api.py +756 -622
  19. huggingface_hub/hf_file_system.py +14 -5
  20. huggingface_hub/hub_mixin.py +127 -43
  21. huggingface_hub/inference/_client.py +402 -183
  22. huggingface_hub/inference/_common.py +19 -29
  23. huggingface_hub/inference/_generated/_async_client.py +402 -184
  24. huggingface_hub/inference/_generated/types/__init__.py +23 -6
  25. huggingface_hub/inference/_generated/types/chat_completion.py +197 -43
  26. huggingface_hub/inference/_generated/types/text_generation.py +57 -79
  27. huggingface_hub/inference/_templating.py +2 -4
  28. huggingface_hub/keras_mixin.py +0 -3
  29. huggingface_hub/lfs.py +9 -1
  30. huggingface_hub/repository.py +1 -0
  31. huggingface_hub/utils/__init__.py +12 -6
  32. huggingface_hub/utils/_fixes.py +1 -0
  33. huggingface_hub/utils/_headers.py +2 -4
  34. huggingface_hub/utils/_http.py +2 -4
  35. huggingface_hub/utils/_paths.py +13 -1
  36. huggingface_hub/utils/_runtime.py +10 -0
  37. huggingface_hub/utils/_safetensors.py +0 -13
  38. huggingface_hub/utils/_validators.py +2 -7
  39. huggingface_hub/utils/tqdm.py +124 -46
  40. {huggingface_hub-0.22.2.dist-info → huggingface_hub-0.23.0rc0.dist-info}/METADATA +5 -1
  41. {huggingface_hub-0.22.2.dist-info → huggingface_hub-0.23.0rc0.dist-info}/RECORD +45 -43
  42. {huggingface_hub-0.22.2.dist-info → huggingface_hub-0.23.0rc0.dist-info}/LICENSE +0 -0
  43. {huggingface_hub-0.22.2.dist-info → huggingface_hub-0.23.0rc0.dist-info}/WHEEL +0 -0
  44. {huggingface_hub-0.22.2.dist-info → huggingface_hub-0.23.0rc0.dist-info}/entry_points.txt +0 -0
  45. {huggingface_hub-0.22.2.dist-info → huggingface_hub-0.23.0rc0.dist-info}/top_level.txt +0 -0
@@ -34,6 +34,7 @@
  # - Only the main parameters are publicly exposed. Power users can always read the docs for more options.
  import base64
  import logging
+ import re
  import time
  import warnings
  from typing import (
@@ -63,14 +64,13 @@ from huggingface_hub.inference._common import (
  _bytes_to_image,
  _bytes_to_list,
  _fetch_recommended_models,
+ _get_unsupported_text_generation_kwargs,
  _import_numpy,
  _is_chat_completion_server,
- _is_tgi_server,
  _open_as_binary,
  _set_as_non_chat_completion_server,
- _set_as_non_tgi,
+ _set_unsupported_text_generation_kwargs,
  _stream_chat_completion_response_from_bytes,
- _stream_chat_completion_response_from_text_generation,
  _stream_text_generation_response,
  raise_text_generation_error,
  )
@@ -78,9 +78,11 @@ from huggingface_hub.inference._generated.types import (
  AudioClassificationOutputElement,
  AudioToAudioOutputElement,
  AutomaticSpeechRecognitionOutput,
+ ChatCompletionInputTool,
+ ChatCompletionInputToolTypeClass,
  ChatCompletionOutput,
- ChatCompletionOutputChoice,
- ChatCompletionOutputChoiceMessage,
+ ChatCompletionOutputComplete,
+ ChatCompletionOutputMessage,
  ChatCompletionStreamOutput,
  DocumentQuestionAnsweringOutputElement,
  FillMaskOutputElement,
@@ -92,6 +94,7 @@ from huggingface_hub.inference._generated.types import (
  SummarizationOutput,
  TableQuestionAnsweringOutputElement,
  TextClassificationOutputElement,
+ TextGenerationInputGrammarType,
  TextGenerationOutput,
  TextGenerationStreamOutput,
  TokenClassificationOutputElement,
@@ -100,7 +103,7 @@ from huggingface_hub.inference._generated.types import (
  ZeroShotClassificationOutputElement,
  ZeroShotImageClassificationOutputElement,
  )
- from huggingface_hub.inference._templating import render_chat_prompt
+ from huggingface_hub.inference._generated.types.chat_completion import ChatCompletionInputToolTypeEnum
  from huggingface_hub.inference._types import (
  ConversationalOutput, # soon to be removed
  )
@@ -114,11 +117,14 @@ from huggingface_hub.utils import (

  if TYPE_CHECKING:
  import numpy as np
- from PIL import Image
+ from PIL.Image import Image

  logger = logging.getLogger(__name__)


+ MODEL_KWARGS_NOT_USED_REGEX = re.compile(r"The following `model_kwargs` are not used by the model: \[(.*?)\]")
+
+
  class InferenceClient:
  """
  Initialize a new Inference Client.
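Editorial note (not part of the diff): the new `MODEL_KWARGS_NOT_USED_REGEX` constant appears to target the error message raised when unsupported `model_kwargs` reach a plain transformers pipeline. A minimal, self-contained sketch of what such a pattern captures, using a hypothetical error string of that shape:

```python
import re

# Same pattern as the new module-level constant in the hunk above.
MODEL_KWARGS_NOT_USED_REGEX = re.compile(
    r"The following `model_kwargs` are not used by the model: \[(.*?)\]"
)

# Hypothetical error text of the shape the regex targets (illustrative only).
error_text = "The following `model_kwargs` are not used by the model: ['watermark', 'stop']"

match = MODEL_KWARGS_NOT_USED_REGEX.search(error_text)
if match:
    # Split the captured list and strip quotes/spaces, mirroring the retry logic later in the diff.
    unused = [kwarg.strip("' ") for kwarg in match.group(1).split(",")]
    print(unused)  # ['watermark', 'stop']
```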
@@ -416,10 +422,19 @@ class InferenceClient:
  *,
  model: Optional[str] = None,
  stream: Literal[False] = False,
- max_tokens: int = 20,
+ frequency_penalty: Optional[float] = None,
+ logit_bias: Optional[List[float]] = None,
+ logprobs: Optional[bool] = None,
+ max_tokens: Optional[int] = None,
+ n: Optional[int] = None,
+ presence_penalty: Optional[float] = None,
  seed: Optional[int] = None,
- stop: Optional[Union[List[str], str]] = None,
- temperature: float = 1.0,
+ stop: Optional[List[str]] = None,
+ temperature: Optional[float] = None,
+ tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, ChatCompletionInputToolTypeEnum]] = None,
+ tool_prompt: Optional[str] = None,
+ tools: Optional[List[ChatCompletionInputTool]] = None,
+ top_logprobs: Optional[int] = None,
  top_p: Optional[float] = None,
  ) -> ChatCompletionOutput: ...

@@ -430,10 +445,19 @@
  *,
  model: Optional[str] = None,
  stream: Literal[True] = True,
- max_tokens: int = 20,
+ frequency_penalty: Optional[float] = None,
+ logit_bias: Optional[List[float]] = None,
+ logprobs: Optional[bool] = None,
+ max_tokens: Optional[int] = None,
+ n: Optional[int] = None,
+ presence_penalty: Optional[float] = None,
  seed: Optional[int] = None,
- stop: Optional[Union[List[str], str]] = None,
- temperature: float = 1.0,
+ stop: Optional[List[str]] = None,
+ temperature: Optional[float] = None,
+ tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, ChatCompletionInputToolTypeEnum]] = None,
+ tool_prompt: Optional[str] = None,
+ tools: Optional[List[ChatCompletionInputTool]] = None,
+ top_logprobs: Optional[int] = None,
  top_p: Optional[float] = None,
  ) -> Iterable[ChatCompletionStreamOutput]: ...

@@ -444,10 +468,19 @@
  *,
  model: Optional[str] = None,
  stream: bool = False,
- max_tokens: int = 20,
+ frequency_penalty: Optional[float] = None,
+ logit_bias: Optional[List[float]] = None,
+ logprobs: Optional[bool] = None,
+ max_tokens: Optional[int] = None,
+ n: Optional[int] = None,
+ presence_penalty: Optional[float] = None,
  seed: Optional[int] = None,
- stop: Optional[Union[List[str], str]] = None,
- temperature: float = 1.0,
+ stop: Optional[List[str]] = None,
+ temperature: Optional[float] = None,
+ tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, ChatCompletionInputToolTypeEnum]] = None,
+ tool_prompt: Optional[str] = None,
+ tools: Optional[List[ChatCompletionInputTool]] = None,
+ top_logprobs: Optional[int] = None,
  top_p: Optional[float] = None,
  ) -> Union[ChatCompletionOutput, Iterable[ChatCompletionStreamOutput]]: ...

@@ -457,10 +490,20 @@
  *,
  model: Optional[str] = None,
  stream: bool = False,
- max_tokens: int = 20,
+ # Parameters from ChatCompletionInput (handled manually)
+ frequency_penalty: Optional[float] = None,
+ logit_bias: Optional[List[float]] = None,
+ logprobs: Optional[bool] = None,
+ max_tokens: Optional[int] = None,
+ n: Optional[int] = None,
+ presence_penalty: Optional[float] = None,
  seed: Optional[int] = None,
- stop: Optional[Union[List[str], str]] = None,
- temperature: float = 1.0,
+ stop: Optional[List[str]] = None,
+ temperature: Optional[float] = None,
+ tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, ChatCompletionInputToolTypeEnum]] = None,
+ tool_prompt: Optional[str] = None,
+ tools: Optional[List[ChatCompletionInputTool]] = None,
+ top_logprobs: Optional[int] = None,
  top_p: Optional[float] = None,
  ) -> Union[ChatCompletionOutput, Iterable[ChatCompletionStreamOutput]]:
  """
@@ -483,27 +526,52 @@
  The model to use for chat-completion. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
  Inference Endpoint. If not provided, the default recommended model for chat-based text-generation will be used.
  See https://huggingface.co/tasks/text-generation for more details.
- frequency_penalty (`float`, optional):
+ frequency_penalty (`float`, *optional*):
  Penalizes new tokens based on their existing frequency
  in the text so far. Range: [-2.0, 2.0]. Defaults to 0.0.
- max_tokens (`int`, optional):
+ logit_bias (`List[float]`, *optional*):
+ Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens
+ (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,
+ the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,
+ but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should
+ result in a ban or exclusive selection of the relevant token. Defaults to None.
+ logprobs (`bool`, *optional*):
+ Whether to return log probabilities of the output tokens or not. If true, returns the log
+ probabilities of each output token returned in the content of message.
+ max_tokens (`int`, *optional*):
  Maximum number of tokens allowed in the response. Defaults to 20.
- seed (Optional[`int`], optional):
+ n (`int`, *optional*):
+ UNUSED.
+ presence_penalty (`float`, *optional*):
+ Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the
+ text so far, increasing the model's likelihood to talk about new topics.
+ seed (Optional[`int`], *optional*):
  Seed for reproducible control flow. Defaults to None.
- stop (Optional[`str`], optional):
+ stop (Optional[`str`], *optional*):
  Up to four strings which trigger the end of the response.
  Defaults to None.
- stream (`bool`, optional):
+ stream (`bool`, *optional*):
  Enable realtime streaming of responses. Defaults to False.
- temperature (`float`, optional):
+ temperature (`float`, *optional*):
  Controls randomness of the generations. Lower values ensure
  less random completions. Range: [0, 2]. Defaults to 1.0.
- top_p (`float`, optional):
+ top_logprobs (`int`, *optional*):
+ An integer between 0 and 5 specifying the number of most likely tokens to return at each token
+ position, each with an associated log probability. logprobs must be set to true if this parameter is
+ used.
+ top_p (`float`, *optional*):
  Fraction of the most likely next words to sample from.
  Must be between 0 and 1. Defaults to 1.0.
+ tool_choice ([`ChatCompletionInputToolTypeClass`] or [`ChatCompletionInputToolTypeEnum`], *optional*):
+ The tool to use for the completion. Defaults to "auto".
+ tool_prompt (`str`, *optional*):
+ A prompt to be appended before the tools.
+ tools (List of [`ChatCompletionInputTool`], *optional*):
+ A list of tools the model may call. Currently, only functions are supported as a tool. Use this to
+ provide a list of functions the model may generate JSON inputs for.

  Returns:
- `Union[ChatCompletionOutput, Iterable[ChatCompletionStreamOutput]]`:
+ [`ChatCompletionOutput] or Iterable of [`ChatCompletionStreamOutput`]:
  Generated text returned from the server:
  - if `stream=False`, the generated text is returned as a [`ChatCompletionOutput`] (default).
  - if `stream=True`, the generated text is returned token by token as a sequence of [`ChatCompletionStreamOutput`].
@@ -515,18 +583,20 @@
  If the request fails with an HTTP error status code other than HTTP 503.

  Example:
+
  ```py
+ # Chat example
  >>> from huggingface_hub import InferenceClient
  >>> messages = [{"role": "user", "content": "What is the capital of France?"}]
  >>> client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
  >>> client.chat_completion(messages, max_tokens=100)
  ChatCompletionOutput(
  choices=[
- ChatCompletionOutputChoice(
+ ChatCompletionOutputComplete(
  finish_reason='eos_token',
  index=0,
- message=ChatCompletionOutputChoiceMessage(
- content='The capital of France is Paris. The official name of the city is "Ville de Paris" (City of Paris) and the name of the country\'s governing body, which is located in Paris, is "La République française" (The French Republic). \nI hope that helps! Let me know if you need any further information.'
+ message=ChatCompletionOutputMessage(
+ content='The capital of France is Paris. The official name of the city is Ville de Paris (City of Paris) and the name of the country governing body, which is located in Paris, is La République française (The French Republic). \nI hope that helps! Let me know if you need any further information.'
  )
  )
  ],
@@ -539,7 +609,87 @@
  ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content=' capital', role='assistant'), index=0, finish_reason=None)], created=1710498504)
  (...)
  ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content=' may', role='assistant'), index=0, finish_reason=None)], created=1710498504)
- ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content=None, role=None), index=0, finish_reason='length')], created=1710498504)
+
+ # Chat example with tools
+ >>> client = InferenceClient("meta-llama/Meta-Llama-3-70B-Instruct")
+ >>> messages = [
+ ... {
+ ... "role": "system",
+ ... "content": "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.",
+ ... },
+ ... {
+ ... "role": "user",
+ ... "content": "What's the weather like the next 3 days in San Francisco, CA?",
+ ... },
+ ... ]
+ >>> tools = [
+ ... {
+ ... "type": "function",
+ ... "function": {
+ ... "name": "get_current_weather",
+ ... "description": "Get the current weather",
+ ... "parameters": {
+ ... "type": "object",
+ ... "properties": {
+ ... "location": {
+ ... "type": "string",
+ ... "description": "The city and state, e.g. San Francisco, CA",
+ ... },
+ ... "format": {
+ ... "type": "string",
+ ... "enum": ["celsius", "fahrenheit"],
+ ... "description": "The temperature unit to use. Infer this from the users location.",
+ ... },
+ ... },
+ ... "required": ["location", "format"],
+ ... },
+ ... },
+ ... },
+ ... {
+ ... "type": "function",
+ ... "function": {
+ ... "name": "get_n_day_weather_forecast",
+ ... "description": "Get an N-day weather forecast",
+ ... "parameters": {
+ ... "type": "object",
+ ... "properties": {
+ ... "location": {
+ ... "type": "string",
+ ... "description": "The city and state, e.g. San Francisco, CA",
+ ... },
+ ... "format": {
+ ... "type": "string",
+ ... "enum": ["celsius", "fahrenheit"],
+ ... "description": "The temperature unit to use. Infer this from the users location.",
+ ... },
+ ... "num_days": {
+ ... "type": "integer",
+ ... "description": "The number of days to forecast",
+ ... },
+ ... },
+ ... "required": ["location", "format", "num_days"],
+ ... },
+ ... },
+ ... },
+ ... ]
+
+ >>> response = client.chat_completion(
+ ... model="meta-llama/Meta-Llama-3-70B-Instruct",
+ ... messages=messages,
+ ... tools=tools,
+ ... tool_choice="auto",
+ ... max_tokens=500,
+ ... )
+ >>> response.choices[0].message.tool_calls[0].function
+ ChatCompletionOutputFunctionDefinition(
+ arguments={
+ 'location': 'San Francisco, CA',
+ 'format': 'fahrenheit',
+ 'num_days': 3
+ },
+ name='get_n_day_weather_forecast',
+ description=None
+ )
  ```
  """
  # determine model
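For orientation (not part of the diff), the newly exposed `chat_completion` parameters can be combined in an ordinary call. A minimal sketch, assuming a TGI-backed chat model that supports them (the model name is illustrative):

```python
from huggingface_hub import InferenceClient

client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")  # illustrative model

# Sketch of a call using several of the parameters added in this release.
output = client.chat_completion(
    messages=[{"role": "user", "content": "Name three French cheeses."}],
    max_tokens=64,
    temperature=0.7,
    frequency_penalty=0.5,   # new in 0.23.0
    logprobs=True,           # new in 0.23.0
    top_logprobs=3,          # requires logprobs=True
    stop=["\n\n"],           # now typed as Optional[List[str]]
)
print(output.choices[0].message.content)
```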
@@ -558,30 +708,44 @@ class InferenceClient:
  json=dict(
  model="tgi", # random string
  messages=messages,
+ frequency_penalty=frequency_penalty,
+ logit_bias=logit_bias,
+ logprobs=logprobs,
  max_tokens=max_tokens,
+ n=n,
+ presence_penalty=presence_penalty,
  seed=seed,
  stop=stop,
  temperature=temperature,
+ tool_choice=tool_choice,
+ tool_prompt=tool_prompt,
+ tools=tools,
+ top_logprobs=top_logprobs,
  top_p=top_p,
  stream=stream,
  ),
  stream=stream,
  )
- except HTTPError:
- # Let's consider the server is not a chat completion server.
- # Then we call again `chat_completion` which will render the chat template client side.
- # (can be HTTP 500, HTTP 400, HTTP 404 depending on the server)
- _set_as_non_chat_completion_server(model)
- return self.chat_completion(
- messages=messages,
- model=model,
- stream=stream,
- max_tokens=max_tokens,
- seed=seed,
- stop=stop,
- temperature=temperature,
- top_p=top_p,
- )
+ except HTTPError as e:
+ if e.response.status_code in (400, 404, 500):
+ # Let's consider the server is not a chat completion server.
+ # Then we call again `chat_completion` which will render the chat template client side.
+ # (can be HTTP 500, HTTP 400, HTTP 404 depending on the server)
+ _set_as_non_chat_completion_server(model)
+ logger.warning(
+ f"Server {model_url} does not seem to support chat completion. Falling back to text generation. Error: {e}"
+ )
+ return self.chat_completion(
+ messages=messages,
+ model=model,
+ stream=stream,
+ max_tokens=max_tokens,
+ seed=seed,
+ stop=stop,
+ temperature=temperature,
+ top_p=top_p,
+ )
+ raise

  if stream:
  return _stream_chat_completion_response_from_bytes(data) # type: ignore[arg-type]
@@ -589,75 +753,46 @@
  return ChatCompletionOutput.parse_obj_as_instance(data) # type: ignore[arg-type]

  # At this point, we know the server is not a chat completion server.
- # We need to render the chat template client side based on the information we can fetch from
- # the Hub API.
-
- model_id = None
- if model.startswith(("http://", "https://")):
- # If URL, we need to know which model is served. This is not always possible.
- # A workaround is to list the user Inference Endpoints and check if one of them correspond to the model URL.
- # If not, we raise an error.
- # TODO: fix when we have a proper API for this (at least for Inference Endpoints)
- # TODO: what if Sagemaker URL?
- # TODO: what if Azure URL?
- from ..hf_api import HfApi
-
- for endpoint in HfApi(token=self.token).list_inference_endpoints():
- if endpoint.url == model:
- model_id = endpoint.repository
- break
- else:
- model_id = model
-
- if model_id is None:
- # If we don't have the model ID, we can't fetch the chat template.
- # We raise an error.
+ # It means it's a transformers-backed server for which we can send a list of messages directly to the
+ # `text-generation` pipeline. We won't receive a detailed response but only the generated text.
+ if stream:
  raise ValueError(
- "Request can't be processed as the model ID can't be inferred from model URL. "
- "This is needed to fetch the chat template from the Hub since the model is not "
- "served with a Chat-completion API."
+ "Streaming token is not supported by the model. This is due to the model not been served by a "
+ "Text-Generation-Inference server. Please pass `stream=False` as input."
+ )
+ if tool_choice is not None or tool_prompt is not None or tools is not None:
+ warnings.warn(
+ "Tools are not supported by the model. This is due to the model not been served by a "
+ "Text-Generation-Inference server. The provided tool parameters will be ignored."
  )
-
- # fetch chat template + tokens
- prompt = render_chat_prompt(model_id=model_id, token=self.token, messages=messages)

  # generate response
- stop_sequences = [stop] if isinstance(stop, str) else stop
  text_generation_output = self.text_generation(
- prompt=prompt,
- details=True,
- stream=stream,
+ prompt=messages, # type: ignore # Not correct type but works implicitly
  model=model,
+ stream=False,
+ details=False,
  max_new_tokens=max_tokens,
  seed=seed,
- stop_sequences=stop_sequences,
+ stop_sequences=stop,
  temperature=temperature,
  top_p=top_p,
  )

- created = int(time.time())
-
- if stream:
- return _stream_chat_completion_response_from_text_generation(text_generation_output) # type: ignore [arg-type]
-
- if isinstance(text_generation_output, TextGenerationOutput):
- # General use case => format ChatCompletionOutput from text generation details
- content: str = text_generation_output.generated_text
- finish_reason: str = text_generation_output.details.finish_reason # type: ignore[union-attr]
- else:
- # Corner case: if server doesn't support details (e.g. if not a TGI server), we only receive an output string.
- # In such a case, `finish_reason` is set to `"unk"`.
- content = text_generation_output # type: ignore[assignment]
- finish_reason = "unk"
-
+ # Format as a ChatCompletionOutput with dummy values for fields we can't provide
  return ChatCompletionOutput(
- created=created,
+ id="dummy",
+ model="dummy",
+ object="dummy",
+ system_fingerprint="dummy",
+ usage=None, # type: ignore # set to `None` as we don't want to provide false information
+ created=int(time.time()),
  choices=[
- ChatCompletionOutputChoice(
- finish_reason=finish_reason, # type: ignore
+ ChatCompletionOutputComplete(
+ finish_reason="unk", # type: ignore # set to `unk` as we don't want to provide false information
  index=0,
- message=ChatCompletionOutputChoiceMessage(
- content=content,
+ message=ChatCompletionOutputMessage(
+ content=text_generation_output,
  role="assistant",
  ),
  )
@@ -1055,7 +1190,7 @@
  self, frameworks: Union[None, str, Literal["all"], List[str]] = None
  ) -> Dict[str, List[str]]:
  """
- List models currently deployed on the Inference API service.
+ List models deployed on the Serverless Inference API service.

  This helper checks deployed models framework by framework. By default, it will check the 4 main frameworks that
  are supported and account for 95% of the hosted models. However, if you want a complete list of models you can
@@ -1063,9 +1198,17 @@
  in, you can also restrict to search to this one (e.g. `frameworks="text-generation-inference"`). The more
  frameworks are checked, the more time it will take.

+ <Tip warning={true}>
+
+ This endpoint method does not return a live list of all models available for the Serverless Inference API service.
+ It searches over a cached list of models that were recently available and the list may not be up to date.
+ If you want to know the live status of a specific model, use [`~InferenceClient.get_model_status`].
+
+ </Tip>
+
  <Tip>

- This endpoint is mostly useful for discoverability. If you already know which model you want to use and want to
+ This endpoint method is mostly useful for discoverability. If you already know which model you want to use and want to
  check its availability, you can directly use [`~InferenceClient.get_model_status`].

  </Tip>
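As the new tip above recommends, discovery and live status checks are separate calls. A brief usage sketch (model name illustrative):

```python
from huggingface_hub import InferenceClient

client = InferenceClient()

# Cached discovery list, restricted to one framework as documented above.
deployed = client.list_deployed_models(frameworks="text-generation-inference")

# Live status check for one specific model, as the new <Tip> recommends.
status = client.get_model_status("HuggingFaceH4/zephyr-7b-beta")
print(status.state, status.loaded)
```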
@@ -1475,19 +1618,24 @@ class InferenceClient:
  details: Literal[False] = ...,
  stream: Literal[False] = ...,
  model: Optional[str] = None,
- do_sample: bool = False,
- max_new_tokens: int = 20,
+ # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
  best_of: Optional[int] = None,
+ decoder_input_details: Optional[bool] = None,
+ do_sample: Optional[bool] = False, # Manual default value
+ frequency_penalty: Optional[float] = None,
+ grammar: Optional[TextGenerationInputGrammarType] = None,
+ max_new_tokens: Optional[int] = None,
  repetition_penalty: Optional[float] = None,
- return_full_text: bool = False,
+ return_full_text: Optional[bool] = False, # Manual default value
  seed: Optional[int] = None,
- stop_sequences: Optional[List[str]] = None,
+ stop_sequences: Optional[List[str]] = None, # Same as `stop`
  temperature: Optional[float] = None,
  top_k: Optional[int] = None,
+ top_n_tokens: Optional[int] = None,
  top_p: Optional[float] = None,
  truncate: Optional[int] = None,
  typical_p: Optional[float] = None,
- watermark: bool = False,
+ watermark: Optional[bool] = None,
  ) -> str: ...

  @overload
@@ -1498,19 +1646,24 @@
  details: Literal[True] = ...,
  stream: Literal[False] = ...,
  model: Optional[str] = None,
- do_sample: bool = False,
- max_new_tokens: int = 20,
+ # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
  best_of: Optional[int] = None,
+ decoder_input_details: Optional[bool] = None,
+ do_sample: Optional[bool] = False, # Manual default value
+ frequency_penalty: Optional[float] = None,
+ grammar: Optional[TextGenerationInputGrammarType] = None,
+ max_new_tokens: Optional[int] = None,
  repetition_penalty: Optional[float] = None,
- return_full_text: bool = False,
+ return_full_text: Optional[bool] = False, # Manual default value
  seed: Optional[int] = None,
- stop_sequences: Optional[List[str]] = None,
+ stop_sequences: Optional[List[str]] = None, # Same as `stop`
  temperature: Optional[float] = None,
  top_k: Optional[int] = None,
+ top_n_tokens: Optional[int] = None,
  top_p: Optional[float] = None,
  truncate: Optional[int] = None,
  typical_p: Optional[float] = None,
- watermark: bool = False,
+ watermark: Optional[bool] = None,
  ) -> TextGenerationOutput: ...

  @overload
@@ -1521,19 +1674,24 @@
  details: Literal[False] = ...,
  stream: Literal[True] = ...,
  model: Optional[str] = None,
- do_sample: bool = False,
- max_new_tokens: int = 20,
+ # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
  best_of: Optional[int] = None,
+ decoder_input_details: Optional[bool] = None,
+ do_sample: Optional[bool] = False, # Manual default value
+ frequency_penalty: Optional[float] = None,
+ grammar: Optional[TextGenerationInputGrammarType] = None,
+ max_new_tokens: Optional[int] = None,
  repetition_penalty: Optional[float] = None,
- return_full_text: bool = False,
+ return_full_text: Optional[bool] = False, # Manual default value
  seed: Optional[int] = None,
- stop_sequences: Optional[List[str]] = None,
+ stop_sequences: Optional[List[str]] = None, # Same as `stop`
  temperature: Optional[float] = None,
  top_k: Optional[int] = None,
+ top_n_tokens: Optional[int] = None,
  top_p: Optional[float] = None,
  truncate: Optional[int] = None,
  typical_p: Optional[float] = None,
- watermark: bool = False,
+ watermark: Optional[bool] = None,
  ) -> Iterable[str]: ...

  @overload
@@ -1544,19 +1702,24 @@
  details: Literal[True] = ...,
  stream: Literal[True] = ...,
  model: Optional[str] = None,
- do_sample: bool = False,
- max_new_tokens: int = 20,
+ # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
  best_of: Optional[int] = None,
+ decoder_input_details: Optional[bool] = None,
+ do_sample: Optional[bool] = False, # Manual default value
+ frequency_penalty: Optional[float] = None,
+ grammar: Optional[TextGenerationInputGrammarType] = None,
+ max_new_tokens: Optional[int] = None,
  repetition_penalty: Optional[float] = None,
- return_full_text: bool = False,
+ return_full_text: Optional[bool] = False, # Manual default value
  seed: Optional[int] = None,
- stop_sequences: Optional[List[str]] = None,
+ stop_sequences: Optional[List[str]] = None, # Same as `stop`
  temperature: Optional[float] = None,
  top_k: Optional[int] = None,
+ top_n_tokens: Optional[int] = None,
  top_p: Optional[float] = None,
  truncate: Optional[int] = None,
  typical_p: Optional[float] = None,
- watermark: bool = False,
+ watermark: Optional[bool] = None,
  ) -> Iterable[TextGenerationStreamOutput]: ...

  @overload
@@ -1567,19 +1730,24 @@
  details: Literal[True] = ...,
  stream: bool = ...,
  model: Optional[str] = None,
- do_sample: bool = False,
- max_new_tokens: int = 20,
+ # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
  best_of: Optional[int] = None,
+ decoder_input_details: Optional[bool] = None,
+ do_sample: Optional[bool] = False, # Manual default value
+ frequency_penalty: Optional[float] = None,
+ grammar: Optional[TextGenerationInputGrammarType] = None,
+ max_new_tokens: Optional[int] = None,
  repetition_penalty: Optional[float] = None,
- return_full_text: bool = False,
+ return_full_text: Optional[bool] = False, # Manual default value
  seed: Optional[int] = None,
- stop_sequences: Optional[List[str]] = None,
+ stop_sequences: Optional[List[str]] = None, # Same as `stop`
  temperature: Optional[float] = None,
  top_k: Optional[int] = None,
+ top_n_tokens: Optional[int] = None,
  top_p: Optional[float] = None,
  truncate: Optional[int] = None,
  typical_p: Optional[float] = None,
- watermark: bool = False,
+ watermark: Optional[bool] = None,
  ) -> Union[TextGenerationOutput, Iterable[TextGenerationStreamOutput]]: ...

  def text_generation(
@@ -1589,20 +1757,24 @@
  details: bool = False,
  stream: bool = False,
  model: Optional[str] = None,
- do_sample: bool = False,
- max_new_tokens: int = 20,
+ # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
  best_of: Optional[int] = None,
+ decoder_input_details: Optional[bool] = None,
+ do_sample: Optional[bool] = False, # Manual default value
+ frequency_penalty: Optional[float] = None,
+ grammar: Optional[TextGenerationInputGrammarType] = None,
+ max_new_tokens: Optional[int] = None,
  repetition_penalty: Optional[float] = None,
- return_full_text: bool = False,
+ return_full_text: Optional[bool] = False, # Manual default value
  seed: Optional[int] = None,
- stop_sequences: Optional[List[str]] = None,
+ stop_sequences: Optional[List[str]] = None, # Same as `stop`
  temperature: Optional[float] = None,
  top_k: Optional[int] = None,
+ top_n_tokens: Optional[int] = None,
  top_p: Optional[float] = None,
  truncate: Optional[int] = None,
  typical_p: Optional[float] = None,
- watermark: bool = False,
- decoder_input_details: bool = False,
+ watermark: Optional[bool] = None,
  ) -> Union[str, TextGenerationOutput, Iterable[str], Iterable[TextGenerationStreamOutput]]:
  """
  Given a prompt, generate the following text.
@@ -1630,38 +1802,46 @@
  model (`str`, *optional*):
  The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
  Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
- do_sample (`bool`):
+ best_of (`int`, *optional*):
+ Generate best_of sequences and return the one if the highest token logprobs.
+ decoder_input_details (`bool`, *optional*):
+ Return the decoder input token logprobs and ids. You must set `details=True` as well for it to be taken
+ into account. Defaults to `False`.
+ do_sample (`bool`, *optional*):
  Activate logits sampling
- max_new_tokens (`int`):
+ frequency_penalty (`float`, *optional*):
+ Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in
+ the text so far, decreasing the model's likelihood to repeat the same line verbatim.
+ grammar ([`TextGenerationInputGrammarType`], *optional*):
+ Grammar constraints. Can be either a JSONSchema or a regex.
+ max_new_tokens (`int`, *optional*):
  Maximum number of generated tokens
- best_of (`int`):
- Generate best_of sequences and return the one if the highest token logprobs
- repetition_penalty (`float`):
+ repetition_penalty (`float`, *optional*):
  The parameter for repetition penalty. 1.0 means no penalty. See [this
  paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
- return_full_text (`bool`):
+ return_full_text (`bool`, *optional*):
  Whether to prepend the prompt to the generated text
- seed (`int`):
+ seed (`int`, *optional*):
  Random sampling seed
- stop_sequences (`List[str]`):
+ stop_sequences (`List[str]`, *optional*):
  Stop generating tokens if a member of `stop_sequences` is generated
- temperature (`float`):
+ temperature (`float`, *optional*):
  The value used to module the logits distribution.
- top_k (`int`):
+ top_n_tokens (`int`, *optional*):
+ Return information about the `top_n_tokens` most likely tokens at each generation step, instead of
+ just the sampled token.
+ top_k (`int`, *optional`):
  The number of highest probability vocabulary tokens to keep for top-k-filtering.
- top_p (`float`):
+ top_p (`float`, *optional`):
  If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
  higher are kept for generation.
- truncate (`int`):
- Truncate inputs tokens to the given size
- typical_p (`float`):
+ truncate (`int`, *optional`):
+ Truncate inputs tokens to the given size.
+ typical_p (`float`, *optional`):
  Typical Decoding mass
  See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
- watermark (`bool`):
+ watermark (`bool`, *optional`):
  Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
- decoder_input_details (`bool`):
- Return the decoder input token logprobs and ids. You must set `details=True` as well for it to be taken
- into account. Defaults to `False`.

  Returns:
  `Union[str, TextGenerationOutput, Iterable[str], Iterable[TextGenerationStreamOutput]]`:
@@ -1713,10 +1893,10 @@
  generated_tokens=12,
  seed=None,
  prefill=[
- TextGenerationPrefillToken(id=487, text='The', logprob=None),
- TextGenerationPrefillToken(id=53789, text=' hugging', logprob=-13.171875),
+ TextGenerationPrefillOutputToken(id=487, text='The', logprob=None),
+ TextGenerationPrefillOutputToken(id=53789, text=' hugging', logprob=-13.171875),
  (...)
- TextGenerationPrefillToken(id=204, text=' ', logprob=-7.0390625)
+ TextGenerationPrefillOutputToken(id=204, text=' ', logprob=-7.0390625)
  ],
  tokens=[
  TokenElement(id=1425, text='100', logprob=-1.0175781, special=False),
@@ -1750,8 +1930,35 @@
  logprob=-0.5703125,
  special=False),
  generated_text='100% open source and built to be easy to use.',
- details=TextGenerationStreamDetails(finish_reason='length', generated_tokens=12, seed=None)
+ details=TextGenerationStreamOutputStreamDetails(finish_reason='length', generated_tokens=12, seed=None)
  )
+
+ # Case 5: generate constrained output using grammar
+ >>> response = client.text_generation(
+ ... prompt="I saw a puppy a cat and a raccoon during my bike ride in the park",
+ ... model="HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1",
+ ... max_new_tokens=100,
+ ... repetition_penalty=1.3,
+ ... grammar={
+ ... "type": "json",
+ ... "value": {
+ ... "properties": {
+ ... "location": {"type": "string"},
+ ... "activity": {"type": "string"},
+ ... "animals_seen": {"type": "integer", "minimum": 1, "maximum": 5},
+ ... "animals": {"type": "array", "items": {"type": "string"}},
+ ... },
+ ... "required": ["location", "activity", "animals_seen", "animals"],
+ ... },
+ ... },
+ ... )
+ >>> json.loads(response)
+ {
+ "activity": "bike riding",
+ "animals": ["puppy", "cat", "raccoon"],
+ "animals_seen": 3,
+ "location": "park"
+ }
  ```
  """
  if decoder_input_details and not details:
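Besides the JSON-schema grammar shown in the docstring example above, the `grammar` parameter is documented as accepting either a JSONSchema or a regex. A hedged sketch of the regex form, assuming the endpoint runs a TGI version with guidance support (the model name and the `{"type": "regex", ...}` payload shape are assumptions, not confirmed by this diff):

```python
from huggingface_hub import InferenceClient

client = InferenceClient("HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1")  # illustrative model

# Constrain the output with a regex instead of a JSON schema.
# Treat the dict shape as an assumption if your endpoint runs an older TGI version.
response = client.text_generation(
    prompt="What is the capital of France? Answer with the city name only.",
    max_new_tokens=10,
    grammar={"type": "regex", "value": "[A-Z][a-z]+"},
)
print(response)
```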
@@ -1762,41 +1969,48 @@ class InferenceClient:
  decoder_input_details = False

  # Build payload
+ parameters = {
+ "best_of": best_of,
+ "decoder_input_details": decoder_input_details,
+ "do_sample": do_sample,
+ "frequency_penalty": frequency_penalty,
+ "grammar": grammar,
+ "max_new_tokens": max_new_tokens,
+ "repetition_penalty": repetition_penalty,
+ "return_full_text": return_full_text,
+ "seed": seed,
+ "stop": stop_sequences if stop_sequences is not None else [],
+ "temperature": temperature,
+ "top_k": top_k,
+ "top_n_tokens": top_n_tokens,
+ "top_p": top_p,
+ "truncate": truncate,
+ "typical_p": typical_p,
+ "watermark": watermark,
+ }
+ parameters = {k: v for k, v in parameters.items() if v is not None}
  payload = {
  "inputs": prompt,
- "parameters": {
- "best_of": best_of,
- "decoder_input_details": decoder_input_details,
- "details": details,
- "do_sample": do_sample,
- "max_new_tokens": max_new_tokens,
- "repetition_penalty": repetition_penalty,
- "return_full_text": return_full_text,
- "seed": seed,
- "stop": stop_sequences if stop_sequences is not None else [],
- "temperature": temperature,
- "top_k": top_k,
- "top_p": top_p,
- "truncate": truncate,
- "typical_p": typical_p,
- "watermark": watermark,
- },
+ "parameters": parameters,
  "stream": stream,
  }

  # Remove some parameters if not a TGI server
- if not _is_tgi_server(model):
- parameters: Dict = payload["parameters"] # type: ignore [assignment]
+ unsupported_kwargs = _get_unsupported_text_generation_kwargs(model)
+ if len(unsupported_kwargs) > 0:
+ # The server does not support some parameters
+ # => means it is not a TGI server
+ # => remove unsupported parameters and warn the user

  ignored_parameters = []
- for key in "watermark", "details", "decoder_input_details", "best_of", "stop", "return_full_text":
- if parameters[key] is not None:
+ for key in unsupported_kwargs:
+ if parameters.get(key):
  ignored_parameters.append(key)
- del parameters[key]
+ parameters.pop(key, None)
  if len(ignored_parameters) > 0:
  warnings.warn(
- "API endpoint/model for text-generation is not served via TGI. Ignoring parameters"
- f" {ignored_parameters}.",
+ "API endpoint/model for text-generation is not served via TGI. Ignoring following parameters:"
+ f" {', '.join(ignored_parameters)}.",
  UserWarning,
  )
  if details:
@@ -1816,27 +2030,32 @@
  try:
  bytes_output = self.post(json=payload, model=model, task="text-generation", stream=stream) # type: ignore
  except HTTPError as e:
- if isinstance(e, BadRequestError) and "The following `model_kwargs` are not used by the model" in str(e):
- _set_as_non_tgi(model)
+ match = MODEL_KWARGS_NOT_USED_REGEX.search(str(e))
+ if isinstance(e, BadRequestError) and match:
+ unused_params = [kwarg.strip("' ") for kwarg in match.group(1).split(",")]
+ _set_unsupported_text_generation_kwargs(model, unused_params)
  return self.text_generation( # type: ignore
  prompt=prompt,
  details=details,
  stream=stream,
  model=model,
+ best_of=best_of,
+ decoder_input_details=decoder_input_details,
  do_sample=do_sample,
+ frequency_penalty=frequency_penalty,
+ grammar=grammar,
  max_new_tokens=max_new_tokens,
- best_of=best_of,
  repetition_penalty=repetition_penalty,
  return_full_text=return_full_text,
  seed=seed,
  stop_sequences=stop_sequences,
  temperature=temperature,
  top_k=top_k,
+ top_n_tokens=top_n_tokens,
  top_p=top_p,
  truncate=truncate,
  typical_p=typical_p,
  watermark=watermark,
- decoder_input_details=decoder_input_details,
  )
  raise_text_generation_error(e)
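To illustrate the retry path above (not part of the diff): when a non-TGI backend rejects TGI-only kwargs, the client now parses the offending parameter names from the error message, records them, retries without them, and emits a UserWarning. A rough sketch of what a caller might observe, assuming `gpt2` is served by a plain transformers backend rather than TGI:

```python
import warnings
from huggingface_hub import InferenceClient

client = InferenceClient()

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # `gpt2` being served without TGI is an assumption here; any non-TGI
    # text-generation model would behave similarly.
    text = client.text_generation(
        "The huggingface_hub library is ",
        model="gpt2",
        max_new_tokens=12,
        watermark=True,  # a TGI-only parameter a plain backend may reject
    )

# If the first request failed with the `model_kwargs` error, the call was retried
# without the unsupported parameters and a UserWarning was emitted.
print(text)
print([str(w.message) for w in caught])
```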