huggingface-hub 0.22.1__py3-none-any.whl → 0.23.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (45)
  1. huggingface_hub/__init__.py +51 -19
  2. huggingface_hub/_commit_api.py +10 -9
  3. huggingface_hub/_commit_scheduler.py +2 -2
  4. huggingface_hub/_inference_endpoints.py +10 -17
  5. huggingface_hub/_local_folder.py +229 -0
  6. huggingface_hub/_login.py +4 -3
  7. huggingface_hub/_multi_commits.py +1 -1
  8. huggingface_hub/_snapshot_download.py +16 -38
  9. huggingface_hub/_tensorboard_logger.py +16 -6
  10. huggingface_hub/_webhooks_payload.py +22 -1
  11. huggingface_hub/_webhooks_server.py +24 -20
  12. huggingface_hub/commands/download.py +11 -34
  13. huggingface_hub/commands/huggingface_cli.py +2 -0
  14. huggingface_hub/commands/tag.py +159 -0
  15. huggingface_hub/constants.py +3 -5
  16. huggingface_hub/errors.py +58 -0
  17. huggingface_hub/file_download.py +545 -376
  18. huggingface_hub/hf_api.py +758 -629
  19. huggingface_hub/hf_file_system.py +14 -5
  20. huggingface_hub/hub_mixin.py +127 -43
  21. huggingface_hub/inference/_client.py +402 -183
  22. huggingface_hub/inference/_common.py +19 -29
  23. huggingface_hub/inference/_generated/_async_client.py +402 -184
  24. huggingface_hub/inference/_generated/types/__init__.py +23 -6
  25. huggingface_hub/inference/_generated/types/chat_completion.py +197 -43
  26. huggingface_hub/inference/_generated/types/text_generation.py +57 -79
  27. huggingface_hub/inference/_templating.py +2 -4
  28. huggingface_hub/keras_mixin.py +0 -3
  29. huggingface_hub/lfs.py +16 -4
  30. huggingface_hub/repository.py +1 -0
  31. huggingface_hub/utils/__init__.py +19 -6
  32. huggingface_hub/utils/_fixes.py +1 -0
  33. huggingface_hub/utils/_headers.py +2 -4
  34. huggingface_hub/utils/_http.py +16 -5
  35. huggingface_hub/utils/_paths.py +13 -1
  36. huggingface_hub/utils/_runtime.py +10 -0
  37. huggingface_hub/utils/_safetensors.py +0 -13
  38. huggingface_hub/utils/_validators.py +2 -7
  39. huggingface_hub/utils/tqdm.py +124 -46
  40. {huggingface_hub-0.22.1.dist-info → huggingface_hub-0.23.0.dist-info}/METADATA +5 -1
  41. {huggingface_hub-0.22.1.dist-info → huggingface_hub-0.23.0.dist-info}/RECORD +45 -43
  42. {huggingface_hub-0.22.1.dist-info → huggingface_hub-0.23.0.dist-info}/LICENSE +0 -0
  43. {huggingface_hub-0.22.1.dist-info → huggingface_hub-0.23.0.dist-info}/WHEEL +0 -0
  44. {huggingface_hub-0.22.1.dist-info → huggingface_hub-0.23.0.dist-info}/entry_points.txt +0 -0
  45. {huggingface_hub-0.22.1.dist-info → huggingface_hub-0.23.0.dist-info}/top_level.txt +0 -0
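The hunks below are from `huggingface_hub/inference/_generated/_async_client.py` (entry 23 above). As a quick orientation, here is a minimal sketch of how the tool-calling parameters added to `AsyncInferenceClient.chat_completion` in 0.23.0 might be used; it is assembled from the signatures and docstring examples in the hunks that follow, and the model ID and tool schema are illustrative only, not prescriptive.

```py
# Sketch only: illustrates the new `tools` / `tool_choice` parameters shown in the diff below.
import asyncio
from huggingface_hub import AsyncInferenceClient

async def main():
    client = AsyncInferenceClient("meta-llama/Meta-Llama-3-70B-Instruct")  # example model
    messages = [{"role": "user", "content": "What's the weather like in San Francisco, CA?"}]
    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",  # hypothetical tool for illustration
                "description": "Get the current weather",
                "parameters": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                    "required": ["location"],
                },
            },
        }
    ]
    # `tools`, `tool_choice` and the Optional `max_tokens` are among the parameters added in 0.23.0
    response = await client.chat_completion(
        messages, tools=tools, tool_choice="auto", max_tokens=500
    )
    print(response.choices[0].message)

asyncio.run(main())
```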
@@ -21,6 +21,7 @@
  import asyncio
  import base64
  import logging
+ import re
  import time
  import warnings
  from typing import (
@@ -44,7 +45,6 @@ from huggingface_hub.inference._common import (
  ContentT,
  ModelStatus,
  _async_stream_chat_completion_response_from_bytes,
- _async_stream_chat_completion_response_from_text_generation,
  _async_stream_text_generation_response,
  _b64_encode,
  _b64_to_image,
@@ -52,21 +52,23 @@ from huggingface_hub.inference._common import (
  _bytes_to_image,
  _bytes_to_list,
  _fetch_recommended_models,
+ _get_unsupported_text_generation_kwargs,
  _import_numpy,
  _is_chat_completion_server,
- _is_tgi_server,
  _open_as_binary,
  _set_as_non_chat_completion_server,
- _set_as_non_tgi,
+ _set_unsupported_text_generation_kwargs,
  raise_text_generation_error,
  )
  from huggingface_hub.inference._generated.types import (
  AudioClassificationOutputElement,
  AudioToAudioOutputElement,
  AutomaticSpeechRecognitionOutput,
+ ChatCompletionInputTool,
+ ChatCompletionInputToolTypeClass,
  ChatCompletionOutput,
- ChatCompletionOutputChoice,
- ChatCompletionOutputChoiceMessage,
+ ChatCompletionOutputComplete,
+ ChatCompletionOutputMessage,
  ChatCompletionStreamOutput,
  DocumentQuestionAnsweringOutputElement,
  FillMaskOutputElement,
@@ -78,6 +80,7 @@ from huggingface_hub.inference._generated.types import (
  SummarizationOutput,
  TableQuestionAnsweringOutputElement,
  TextClassificationOutputElement,
+ TextGenerationInputGrammarType,
  TextGenerationOutput,
  TextGenerationStreamOutput,
  TokenClassificationOutputElement,
@@ -86,7 +89,7 @@ from huggingface_hub.inference._generated.types import (
  ZeroShotClassificationOutputElement,
  ZeroShotImageClassificationOutputElement,
  )
- from huggingface_hub.inference._templating import render_chat_prompt
+ from huggingface_hub.inference._generated.types.chat_completion import ChatCompletionInputToolTypeEnum
  from huggingface_hub.inference._types import (
  ConversationalOutput, # soon to be removed
  )
@@ -99,11 +102,14 @@ from .._common import _async_yield_from, _import_aiohttp

  if TYPE_CHECKING:
  import numpy as np
- from PIL import Image
+ from PIL.Image import Image

  logger = logging.getLogger(__name__)


+ MODEL_KWARGS_NOT_USED_REGEX = re.compile(r"The following `model_kwargs` are not used by the model: \[(.*?)\]")
+
+
  class AsyncInferenceClient:
  """
  Initialize a new Inference Client.
@@ -415,10 +421,19 @@ class AsyncInferenceClient:
  *,
  model: Optional[str] = None,
  stream: Literal[False] = False,
- max_tokens: int = 20,
+ frequency_penalty: Optional[float] = None,
+ logit_bias: Optional[List[float]] = None,
+ logprobs: Optional[bool] = None,
+ max_tokens: Optional[int] = None,
+ n: Optional[int] = None,
+ presence_penalty: Optional[float] = None,
  seed: Optional[int] = None,
- stop: Optional[Union[List[str], str]] = None,
- temperature: float = 1.0,
+ stop: Optional[List[str]] = None,
+ temperature: Optional[float] = None,
+ tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, ChatCompletionInputToolTypeEnum]] = None,
+ tool_prompt: Optional[str] = None,
+ tools: Optional[List[ChatCompletionInputTool]] = None,
+ top_logprobs: Optional[int] = None,
  top_p: Optional[float] = None,
  ) -> ChatCompletionOutput: ...

@@ -429,10 +444,19 @@ class AsyncInferenceClient:
  *,
  model: Optional[str] = None,
  stream: Literal[True] = True,
- max_tokens: int = 20,
+ frequency_penalty: Optional[float] = None,
+ logit_bias: Optional[List[float]] = None,
+ logprobs: Optional[bool] = None,
+ max_tokens: Optional[int] = None,
+ n: Optional[int] = None,
+ presence_penalty: Optional[float] = None,
  seed: Optional[int] = None,
- stop: Optional[Union[List[str], str]] = None,
- temperature: float = 1.0,
+ stop: Optional[List[str]] = None,
+ temperature: Optional[float] = None,
+ tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, ChatCompletionInputToolTypeEnum]] = None,
+ tool_prompt: Optional[str] = None,
+ tools: Optional[List[ChatCompletionInputTool]] = None,
+ top_logprobs: Optional[int] = None,
  top_p: Optional[float] = None,
  ) -> AsyncIterable[ChatCompletionStreamOutput]: ...

@@ -443,10 +467,19 @@ class AsyncInferenceClient:
  *,
  model: Optional[str] = None,
  stream: bool = False,
- max_tokens: int = 20,
+ frequency_penalty: Optional[float] = None,
+ logit_bias: Optional[List[float]] = None,
+ logprobs: Optional[bool] = None,
+ max_tokens: Optional[int] = None,
+ n: Optional[int] = None,
+ presence_penalty: Optional[float] = None,
  seed: Optional[int] = None,
- stop: Optional[Union[List[str], str]] = None,
- temperature: float = 1.0,
+ stop: Optional[List[str]] = None,
+ temperature: Optional[float] = None,
+ tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, ChatCompletionInputToolTypeEnum]] = None,
+ tool_prompt: Optional[str] = None,
+ tools: Optional[List[ChatCompletionInputTool]] = None,
+ top_logprobs: Optional[int] = None,
  top_p: Optional[float] = None,
  ) -> Union[ChatCompletionOutput, AsyncIterable[ChatCompletionStreamOutput]]: ...

@@ -456,10 +489,20 @@ class AsyncInferenceClient:
  *,
  model: Optional[str] = None,
  stream: bool = False,
- max_tokens: int = 20,
+ # Parameters from ChatCompletionInput (handled manually)
+ frequency_penalty: Optional[float] = None,
+ logit_bias: Optional[List[float]] = None,
+ logprobs: Optional[bool] = None,
+ max_tokens: Optional[int] = None,
+ n: Optional[int] = None,
+ presence_penalty: Optional[float] = None,
  seed: Optional[int] = None,
- stop: Optional[Union[List[str], str]] = None,
- temperature: float = 1.0,
+ stop: Optional[List[str]] = None,
+ temperature: Optional[float] = None,
+ tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, ChatCompletionInputToolTypeEnum]] = None,
+ tool_prompt: Optional[str] = None,
+ tools: Optional[List[ChatCompletionInputTool]] = None,
+ top_logprobs: Optional[int] = None,
  top_p: Optional[float] = None,
  ) -> Union[ChatCompletionOutput, AsyncIterable[ChatCompletionStreamOutput]]:
  """
@@ -482,27 +525,52 @@ class AsyncInferenceClient:
  The model to use for chat-completion. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
  Inference Endpoint. If not provided, the default recommended model for chat-based text-generation will be used.
  See https://huggingface.co/tasks/text-generation for more details.
- frequency_penalty (`float`, optional):
+ frequency_penalty (`float`, *optional*):
  Penalizes new tokens based on their existing frequency
  in the text so far. Range: [-2.0, 2.0]. Defaults to 0.0.
- max_tokens (`int`, optional):
+ logit_bias (`List[float]`, *optional*):
+ Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens
+ (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,
+ the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,
+ but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should
+ result in a ban or exclusive selection of the relevant token. Defaults to None.
+ logprobs (`bool`, *optional*):
+ Whether to return log probabilities of the output tokens or not. If true, returns the log
+ probabilities of each output token returned in the content of message.
+ max_tokens (`int`, *optional*):
  Maximum number of tokens allowed in the response. Defaults to 20.
- seed (Optional[`int`], optional):
+ n (`int`, *optional*):
+ UNUSED.
+ presence_penalty (`float`, *optional*):
+ Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the
+ text so far, increasing the model's likelihood to talk about new topics.
+ seed (Optional[`int`], *optional*):
  Seed for reproducible control flow. Defaults to None.
- stop (Optional[`str`], optional):
+ stop (Optional[`str`], *optional*):
  Up to four strings which trigger the end of the response.
  Defaults to None.
- stream (`bool`, optional):
+ stream (`bool`, *optional*):
  Enable realtime streaming of responses. Defaults to False.
- temperature (`float`, optional):
+ temperature (`float`, *optional*):
  Controls randomness of the generations. Lower values ensure
  less random completions. Range: [0, 2]. Defaults to 1.0.
- top_p (`float`, optional):
+ top_logprobs (`int`, *optional*):
+ An integer between 0 and 5 specifying the number of most likely tokens to return at each token
+ position, each with an associated log probability. logprobs must be set to true if this parameter is
+ used.
+ top_p (`float`, *optional*):
  Fraction of the most likely next words to sample from.
  Must be between 0 and 1. Defaults to 1.0.
+ tool_choice ([`ChatCompletionInputToolTypeClass`] or [`ChatCompletionInputToolTypeEnum`], *optional*):
+ The tool to use for the completion. Defaults to "auto".
+ tool_prompt (`str`, *optional*):
+ A prompt to be appended before the tools.
+ tools (List of [`ChatCompletionInputTool`], *optional*):
+ A list of tools the model may call. Currently, only functions are supported as a tool. Use this to
+ provide a list of functions the model may generate JSON inputs for.

  Returns:
- `Union[ChatCompletionOutput, Iterable[ChatCompletionStreamOutput]]`:
+ [`ChatCompletionOutput] or Iterable of [`ChatCompletionStreamOutput`]:
  Generated text returned from the server:
  - if `stream=False`, the generated text is returned as a [`ChatCompletionOutput`] (default).
  - if `stream=True`, the generated text is returned token by token as a sequence of [`ChatCompletionStreamOutput`].
@@ -514,19 +582,21 @@ class AsyncInferenceClient:
  If the request fails with an HTTP error status code other than HTTP 503.

  Example:
+
  ```py
  # Must be run in an async context
+ # Chat example
  >>> from huggingface_hub import AsyncInferenceClient
  >>> messages = [{"role": "user", "content": "What is the capital of France?"}]
  >>> client = AsyncInferenceClient("HuggingFaceH4/zephyr-7b-beta")
  >>> await client.chat_completion(messages, max_tokens=100)
  ChatCompletionOutput(
  choices=[
- ChatCompletionOutputChoice(
+ ChatCompletionOutputComplete(
  finish_reason='eos_token',
  index=0,
- message=ChatCompletionOutputChoiceMessage(
- content='The capital of France is Paris. The official name of the city is "Ville de Paris" (City of Paris) and the name of the country\'s governing body, which is located in Paris, is "La République française" (The French Republic). \nI hope that helps! Let me know if you need any further information.'
+ message=ChatCompletionOutputMessage(
+ content='The capital of France is Paris. The official name of the city is Ville de Paris (City of Paris) and the name of the country governing body, which is located in Paris, is La République française (The French Republic). \nI hope that helps! Let me know if you need any further information.'
  )
  )
  ],
@@ -539,7 +609,87 @@ class AsyncInferenceClient:
  ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content=' capital', role='assistant'), index=0, finish_reason=None)], created=1710498504)
  (...)
  ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content=' may', role='assistant'), index=0, finish_reason=None)], created=1710498504)
- ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content=None, role=None), index=0, finish_reason='length')], created=1710498504)
+
+ # Chat example with tools
+ >>> client = AsyncInferenceClient("meta-llama/Meta-Llama-3-70B-Instruct")
+ >>> messages = [
+ ... {
+ ... "role": "system",
+ ... "content": "Don't make assumptions about what values to plug into functions. Ask async for clarification if a user request is ambiguous.",
+ ... },
+ ... {
+ ... "role": "user",
+ ... "content": "What's the weather like the next 3 days in San Francisco, CA?",
+ ... },
+ ... ]
+ >>> tools = [
+ ... {
+ ... "type": "function",
+ ... "function": {
+ ... "name": "get_current_weather",
+ ... "description": "Get the current weather",
+ ... "parameters": {
+ ... "type": "object",
+ ... "properties": {
+ ... "location": {
+ ... "type": "string",
+ ... "description": "The city and state, e.g. San Francisco, CA",
+ ... },
+ ... "format": {
+ ... "type": "string",
+ ... "enum": ["celsius", "fahrenheit"],
+ ... "description": "The temperature unit to use. Infer this from the users location.",
+ ... },
+ ... },
+ ... "required": ["location", "format"],
+ ... },
+ ... },
+ ... },
+ ... {
+ ... "type": "function",
+ ... "function": {
+ ... "name": "get_n_day_weather_forecast",
+ ... "description": "Get an N-day weather forecast",
+ ... "parameters": {
+ ... "type": "object",
+ ... "properties": {
+ ... "location": {
+ ... "type": "string",
+ ... "description": "The city and state, e.g. San Francisco, CA",
+ ... },
+ ... "format": {
+ ... "type": "string",
+ ... "enum": ["celsius", "fahrenheit"],
+ ... "description": "The temperature unit to use. Infer this from the users location.",
+ ... },
+ ... "num_days": {
+ ... "type": "integer",
+ ... "description": "The number of days to forecast",
+ ... },
+ ... },
+ ... "required": ["location", "format", "num_days"],
+ ... },
+ ... },
+ ... },
+ ... ]
+
+ >>> response = await client.chat_completion(
+ ... model="meta-llama/Meta-Llama-3-70B-Instruct",
+ ... messages=messages,
+ ... tools=tools,
+ ... tool_choice="auto",
+ ... max_tokens=500,
+ ... )
+ >>> response.choices[0].message.tool_calls[0].function
+ ChatCompletionOutputFunctionDefinition(
+ arguments={
+ 'location': 'San Francisco, CA',
+ 'format': 'fahrenheit',
+ 'num_days': 3
+ },
+ name='get_n_day_weather_forecast',
+ description=None
+ )
  ```
  """
  # determine model
@@ -558,30 +708,44 @@ class AsyncInferenceClient:
  json=dict(
  model="tgi", # random string
  messages=messages,
+ frequency_penalty=frequency_penalty,
+ logit_bias=logit_bias,
+ logprobs=logprobs,
  max_tokens=max_tokens,
+ n=n,
+ presence_penalty=presence_penalty,
  seed=seed,
  stop=stop,
  temperature=temperature,
+ tool_choice=tool_choice,
+ tool_prompt=tool_prompt,
+ tools=tools,
+ top_logprobs=top_logprobs,
  top_p=top_p,
  stream=stream,
  ),
  stream=stream,
  )
- except _import_aiohttp().ClientResponseError:
- # Let's consider the server is not a chat completion server.
- # Then we call again `chat_completion` which will render the chat template client side.
- # (can be HTTP 500, HTTP 400, HTTP 404 depending on the server)
- _set_as_non_chat_completion_server(model)
- return await self.chat_completion(
- messages=messages,
- model=model,
- stream=stream,
- max_tokens=max_tokens,
- seed=seed,
- stop=stop,
- temperature=temperature,
- top_p=top_p,
- )
+ except _import_aiohttp().ClientResponseError as e:
+ if e.status in (400, 404, 500):
+ # Let's consider the server is not a chat completion server.
+ # Then we call again `chat_completion` which will render the chat template client side.
+ # (can be HTTP 500, HTTP 400, HTTP 404 depending on the server)
+ _set_as_non_chat_completion_server(model)
+ logger.warning(
+ f"Server {model_url} does not seem to support chat completion. Falling back to text generation. Error: {e}"
+ )
+ return await self.chat_completion(
+ messages=messages,
+ model=model,
+ stream=stream,
+ max_tokens=max_tokens,
+ seed=seed,
+ stop=stop,
+ temperature=temperature,
+ top_p=top_p,
+ )
+ raise

  if stream:
  return _async_stream_chat_completion_response_from_bytes(data) # type: ignore[arg-type]
@@ -589,75 +753,46 @@ class AsyncInferenceClient:
  return ChatCompletionOutput.parse_obj_as_instance(data) # type: ignore[arg-type]

  # At this point, we know the server is not a chat completion server.
- # We need to render the chat template client side based on the information we can fetch from
- # the Hub API.
-
- model_id = None
- if model.startswith(("http://", "https://")):
- # If URL, we need to know which model is served. This is not always possible.
- # A workaround is to list the user Inference Endpoints and check if one of them correspond to the model URL.
- # If not, we raise an error.
- # TODO: fix when we have a proper API for this (at least for Inference Endpoints)
- # TODO: what if Sagemaker URL?
- # TODO: what if Azure URL?
- from ..hf_api import HfApi
-
- for endpoint in HfApi(token=self.token).list_inference_endpoints():
- if endpoint.url == model:
- model_id = endpoint.repository
- break
- else:
- model_id = model
-
- if model_id is None:
- # If we don't have the model ID, we can't fetch the chat template.
- # We raise an error.
+ # It means it's a transformers-backed server for which we can send a list of messages directly to the
+ # `text-generation` pipeline. We won't receive a detailed response but only the generated text.
+ if stream:
  raise ValueError(
- "Request can't be processed as the model ID can't be inferred from model URL. "
- "This is needed to fetch the chat template from the Hub since the model is not "
- "served with a Chat-completion API."
+ "Streaming token is not supported by the model. This is due to the model not been served by a "
+ "Text-Generation-Inference server. Please pass `stream=False` as input."
+ )
+ if tool_choice is not None or tool_prompt is not None or tools is not None:
+ warnings.warn(
+ "Tools are not supported by the model. This is due to the model not been served by a "
+ "Text-Generation-Inference server. The provided tool parameters will be ignored."
  )
-
- # fetch chat template + tokens
- prompt = render_chat_prompt(model_id=model_id, token=self.token, messages=messages)

  # generate response
- stop_sequences = [stop] if isinstance(stop, str) else stop
  text_generation_output = await self.text_generation(
- prompt=prompt,
- details=True,
- stream=stream,
+ prompt=messages, # type: ignore # Not correct type but works implicitly
  model=model,
+ stream=False,
+ details=False,
  max_new_tokens=max_tokens,
  seed=seed,
- stop_sequences=stop_sequences,
+ stop_sequences=stop,
  temperature=temperature,
  top_p=top_p,
  )

- created = int(time.time())
-
- if stream:
- return _async_stream_chat_completion_response_from_text_generation(text_generation_output) # type: ignore [arg-type]
-
- if isinstance(text_generation_output, TextGenerationOutput):
- # General use case => format ChatCompletionOutput from text generation details
- content: str = text_generation_output.generated_text
- finish_reason: str = text_generation_output.details.finish_reason # type: ignore[union-attr]
- else:
- # Corner case: if server doesn't support details (e.g. if not a TGI server), we only receive an output string.
- # In such a case, `finish_reason` is set to `"unk"`.
- content = text_generation_output # type: ignore[assignment]
- finish_reason = "unk"
-
+ # Format as a ChatCompletionOutput with dummy values for fields we can't provide
  return ChatCompletionOutput(
- created=created,
+ id="dummy",
+ model="dummy",
+ object="dummy",
+ system_fingerprint="dummy",
+ usage=None, # type: ignore # set to `None` as we don't want to provide false information
+ created=int(time.time()),
  choices=[
- ChatCompletionOutputChoice(
- finish_reason=finish_reason, # type: ignore
+ ChatCompletionOutputComplete(
+ finish_reason="unk", # type: ignore # set to `unk` as we don't want to provide false information
  index=0,
- message=ChatCompletionOutputChoiceMessage(
- content=content,
+ message=ChatCompletionOutputMessage(
+ content=text_generation_output,
  role="assistant",
  ),
  )
@@ -1063,7 +1198,7 @@ class AsyncInferenceClient:
  self, frameworks: Union[None, str, Literal["all"], List[str]] = None
  ) -> Dict[str, List[str]]:
  """
- List models currently deployed on the Inference API service.
+ List models deployed on the Serverless Inference API service.

  This helper checks deployed models framework by framework. By default, it will check the 4 main frameworks that
  are supported and account for 95% of the hosted models. However, if you want a complete list of models you can
@@ -1071,9 +1206,17 @@ class AsyncInferenceClient:
  in, you can also restrict to search to this one (e.g. `frameworks="text-generation-inference"`). The more
  frameworks are checked, the more time it will take.

+ <Tip warning={true}>
+
+ This endpoint method does not return a live list of all models available for the Serverless Inference API service.
+ It searches over a cached list of models that were recently available and the list may not be up to date.
+ If you want to know the live status of a specific model, use [`~InferenceClient.get_model_status`].
+
+ </Tip>
+
  <Tip>

- This endpoint is mostly useful for discoverability. If you already know which model you want to use and want to
+ This endpoint method is mostly useful for discoverability. If you already know which model you want to use and want to
  check its availability, you can directly use [`~InferenceClient.get_model_status`].

  </Tip>
@@ -1499,19 +1642,24 @@ class AsyncInferenceClient:
  details: Literal[False] = ...,
  stream: Literal[False] = ...,
  model: Optional[str] = None,
- do_sample: bool = False,
- max_new_tokens: int = 20,
+ # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
  best_of: Optional[int] = None,
+ decoder_input_details: Optional[bool] = None,
+ do_sample: Optional[bool] = False, # Manual default value
+ frequency_penalty: Optional[float] = None,
+ grammar: Optional[TextGenerationInputGrammarType] = None,
+ max_new_tokens: Optional[int] = None,
  repetition_penalty: Optional[float] = None,
- return_full_text: bool = False,
+ return_full_text: Optional[bool] = False, # Manual default value
  seed: Optional[int] = None,
- stop_sequences: Optional[List[str]] = None,
+ stop_sequences: Optional[List[str]] = None, # Same as `stop`
  temperature: Optional[float] = None,
  top_k: Optional[int] = None,
+ top_n_tokens: Optional[int] = None,
  top_p: Optional[float] = None,
  truncate: Optional[int] = None,
  typical_p: Optional[float] = None,
- watermark: bool = False,
+ watermark: Optional[bool] = None,
  ) -> str: ...

  @overload
@@ -1522,19 +1670,24 @@ class AsyncInferenceClient:
  details: Literal[True] = ...,
  stream: Literal[False] = ...,
  model: Optional[str] = None,
- do_sample: bool = False,
- max_new_tokens: int = 20,
+ # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
  best_of: Optional[int] = None,
+ decoder_input_details: Optional[bool] = None,
+ do_sample: Optional[bool] = False, # Manual default value
+ frequency_penalty: Optional[float] = None,
+ grammar: Optional[TextGenerationInputGrammarType] = None,
+ max_new_tokens: Optional[int] = None,
  repetition_penalty: Optional[float] = None,
- return_full_text: bool = False,
+ return_full_text: Optional[bool] = False, # Manual default value
  seed: Optional[int] = None,
- stop_sequences: Optional[List[str]] = None,
+ stop_sequences: Optional[List[str]] = None, # Same as `stop`
  temperature: Optional[float] = None,
  top_k: Optional[int] = None,
+ top_n_tokens: Optional[int] = None,
  top_p: Optional[float] = None,
  truncate: Optional[int] = None,
  typical_p: Optional[float] = None,
- watermark: bool = False,
+ watermark: Optional[bool] = None,
  ) -> TextGenerationOutput: ...

  @overload
@@ -1545,19 +1698,24 @@ class AsyncInferenceClient:
  details: Literal[False] = ...,
  stream: Literal[True] = ...,
  model: Optional[str] = None,
- do_sample: bool = False,
- max_new_tokens: int = 20,
+ # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
  best_of: Optional[int] = None,
+ decoder_input_details: Optional[bool] = None,
+ do_sample: Optional[bool] = False, # Manual default value
+ frequency_penalty: Optional[float] = None,
+ grammar: Optional[TextGenerationInputGrammarType] = None,
+ max_new_tokens: Optional[int] = None,
  repetition_penalty: Optional[float] = None,
- return_full_text: bool = False,
+ return_full_text: Optional[bool] = False, # Manual default value
  seed: Optional[int] = None,
- stop_sequences: Optional[List[str]] = None,
+ stop_sequences: Optional[List[str]] = None, # Same as `stop`
  temperature: Optional[float] = None,
  top_k: Optional[int] = None,
+ top_n_tokens: Optional[int] = None,
  top_p: Optional[float] = None,
  truncate: Optional[int] = None,
  typical_p: Optional[float] = None,
- watermark: bool = False,
+ watermark: Optional[bool] = None,
  ) -> AsyncIterable[str]: ...

  @overload
@@ -1568,19 +1726,24 @@ class AsyncInferenceClient:
  details: Literal[True] = ...,
  stream: Literal[True] = ...,
  model: Optional[str] = None,
- do_sample: bool = False,
- max_new_tokens: int = 20,
+ # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
  best_of: Optional[int] = None,
+ decoder_input_details: Optional[bool] = None,
+ do_sample: Optional[bool] = False, # Manual default value
+ frequency_penalty: Optional[float] = None,
+ grammar: Optional[TextGenerationInputGrammarType] = None,
+ max_new_tokens: Optional[int] = None,
  repetition_penalty: Optional[float] = None,
- return_full_text: bool = False,
+ return_full_text: Optional[bool] = False, # Manual default value
  seed: Optional[int] = None,
- stop_sequences: Optional[List[str]] = None,
+ stop_sequences: Optional[List[str]] = None, # Same as `stop`
  temperature: Optional[float] = None,
  top_k: Optional[int] = None,
+ top_n_tokens: Optional[int] = None,
  top_p: Optional[float] = None,
  truncate: Optional[int] = None,
  typical_p: Optional[float] = None,
- watermark: bool = False,
+ watermark: Optional[bool] = None,
  ) -> AsyncIterable[TextGenerationStreamOutput]: ...

  @overload
@@ -1591,19 +1754,24 @@ class AsyncInferenceClient:
  details: Literal[True] = ...,
  stream: bool = ...,
  model: Optional[str] = None,
- do_sample: bool = False,
- max_new_tokens: int = 20,
+ # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
  best_of: Optional[int] = None,
+ decoder_input_details: Optional[bool] = None,
+ do_sample: Optional[bool] = False, # Manual default value
+ frequency_penalty: Optional[float] = None,
+ grammar: Optional[TextGenerationInputGrammarType] = None,
+ max_new_tokens: Optional[int] = None,
  repetition_penalty: Optional[float] = None,
- return_full_text: bool = False,
+ return_full_text: Optional[bool] = False, # Manual default value
  seed: Optional[int] = None,
- stop_sequences: Optional[List[str]] = None,
+ stop_sequences: Optional[List[str]] = None, # Same as `stop`
  temperature: Optional[float] = None,
  top_k: Optional[int] = None,
+ top_n_tokens: Optional[int] = None,
  top_p: Optional[float] = None,
  truncate: Optional[int] = None,
  typical_p: Optional[float] = None,
- watermark: bool = False,
+ watermark: Optional[bool] = None,
  ) -> Union[TextGenerationOutput, AsyncIterable[TextGenerationStreamOutput]]: ...

  async def text_generation(
@@ -1613,20 +1781,24 @@ class AsyncInferenceClient:
  details: bool = False,
  stream: bool = False,
  model: Optional[str] = None,
- do_sample: bool = False,
- max_new_tokens: int = 20,
+ # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
  best_of: Optional[int] = None,
+ decoder_input_details: Optional[bool] = None,
+ do_sample: Optional[bool] = False, # Manual default value
+ frequency_penalty: Optional[float] = None,
+ grammar: Optional[TextGenerationInputGrammarType] = None,
+ max_new_tokens: Optional[int] = None,
  repetition_penalty: Optional[float] = None,
- return_full_text: bool = False,
+ return_full_text: Optional[bool] = False, # Manual default value
  seed: Optional[int] = None,
- stop_sequences: Optional[List[str]] = None,
+ stop_sequences: Optional[List[str]] = None, # Same as `stop`
  temperature: Optional[float] = None,
  top_k: Optional[int] = None,
+ top_n_tokens: Optional[int] = None,
  top_p: Optional[float] = None,
  truncate: Optional[int] = None,
  typical_p: Optional[float] = None,
- watermark: bool = False,
- decoder_input_details: bool = False,
+ watermark: Optional[bool] = None,
  ) -> Union[str, TextGenerationOutput, AsyncIterable[str], AsyncIterable[TextGenerationStreamOutput]]:
  """
  Given a prompt, generate the following text.
@@ -1654,38 +1826,46 @@ class AsyncInferenceClient:
  model (`str`, *optional*):
  The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
  Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
- do_sample (`bool`):
+ best_of (`int`, *optional*):
+ Generate best_of sequences and return the one if the highest token logprobs.
+ decoder_input_details (`bool`, *optional*):
+ Return the decoder input token logprobs and ids. You must set `details=True` as well for it to be taken
+ into account. Defaults to `False`.
+ do_sample (`bool`, *optional*):
  Activate logits sampling
- max_new_tokens (`int`):
+ frequency_penalty (`float`, *optional*):
+ Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in
+ the text so far, decreasing the model's likelihood to repeat the same line verbatim.
+ grammar ([`TextGenerationInputGrammarType`], *optional*):
+ Grammar constraints. Can be either a JSONSchema or a regex.
+ max_new_tokens (`int`, *optional*):
  Maximum number of generated tokens
- best_of (`int`):
- Generate best_of sequences and return the one if the highest token logprobs
- repetition_penalty (`float`):
+ repetition_penalty (`float`, *optional*):
  The parameter for repetition penalty. 1.0 means no penalty. See [this
  paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
- return_full_text (`bool`):
+ return_full_text (`bool`, *optional*):
  Whether to prepend the prompt to the generated text
- seed (`int`):
+ seed (`int`, *optional*):
  Random sampling seed
- stop_sequences (`List[str]`):
+ stop_sequences (`List[str]`, *optional*):
  Stop generating tokens if a member of `stop_sequences` is generated
- temperature (`float`):
+ temperature (`float`, *optional*):
  The value used to module the logits distribution.
- top_k (`int`):
+ top_n_tokens (`int`, *optional*):
+ Return information about the `top_n_tokens` most likely tokens at each generation step, instead of
+ just the sampled token.
+ top_k (`int`, *optional`):
  The number of highest probability vocabulary tokens to keep for top-k-filtering.
- top_p (`float`):
+ top_p (`float`, *optional`):
  If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
  higher are kept for generation.
- truncate (`int`):
- Truncate inputs tokens to the given size
- typical_p (`float`):
+ truncate (`int`, *optional`):
+ Truncate inputs tokens to the given size.
+ typical_p (`float`, *optional`):
  Typical Decoding mass
  See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
- watermark (`bool`):
+ watermark (`bool`, *optional`):
  Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
- decoder_input_details (`bool`):
- Return the decoder input token logprobs and ids. You must set `details=True` as well for it to be taken
- into account. Defaults to `False`.

  Returns:
  `Union[str, TextGenerationOutput, Iterable[str], Iterable[TextGenerationStreamOutput]]`:
@@ -1738,10 +1918,10 @@ class AsyncInferenceClient:
  generated_tokens=12,
  seed=None,
  prefill=[
- TextGenerationPrefillToken(id=487, text='The', logprob=None),
- TextGenerationPrefillToken(id=53789, text=' hugging', logprob=-13.171875),
+ TextGenerationPrefillOutputToken(id=487, text='The', logprob=None),
+ TextGenerationPrefillOutputToken(id=53789, text=' hugging', logprob=-13.171875),
  (...)
- TextGenerationPrefillToken(id=204, text=' ', logprob=-7.0390625)
+ TextGenerationPrefillOutputToken(id=204, text=' ', logprob=-7.0390625)
  ],
  tokens=[
  TokenElement(id=1425, text='100', logprob=-1.0175781, special=False),
@@ -1775,8 +1955,35 @@ class AsyncInferenceClient:
  logprob=-0.5703125,
  special=False),
  generated_text='100% open source and built to be easy to use.',
- details=TextGenerationStreamDetails(finish_reason='length', generated_tokens=12, seed=None)
+ details=TextGenerationStreamOutputStreamDetails(finish_reason='length', generated_tokens=12, seed=None)
  )
+
+ # Case 5: generate constrained output using grammar
+ >>> response = await client.text_generation(
+ ... prompt="I saw a puppy a cat and a raccoon during my bike ride in the park",
+ ... model="HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1",
+ ... max_new_tokens=100,
+ ... repetition_penalty=1.3,
+ ... grammar={
+ ... "type": "json",
+ ... "value": {
+ ... "properties": {
+ ... "location": {"type": "string"},
+ ... "activity": {"type": "string"},
+ ... "animals_seen": {"type": "integer", "minimum": 1, "maximum": 5},
+ ... "animals": {"type": "array", "items": {"type": "string"}},
+ ... },
+ ... "required": ["location", "activity", "animals_seen", "animals"],
+ ... },
+ ... },
+ ... )
+ >>> json.loads(response)
+ {
+ "activity": "bike riding",
+ "animals": ["puppy", "cat", "raccoon"],
+ "animals_seen": 3,
+ "location": "park"
+ }
  ```
  """
  if decoder_input_details and not details:
@@ -1787,41 +1994,48 @@ class AsyncInferenceClient:
  decoder_input_details = False

  # Build payload
+ parameters = {
+ "best_of": best_of,
+ "decoder_input_details": decoder_input_details,
+ "do_sample": do_sample,
+ "frequency_penalty": frequency_penalty,
+ "grammar": grammar,
+ "max_new_tokens": max_new_tokens,
+ "repetition_penalty": repetition_penalty,
+ "return_full_text": return_full_text,
+ "seed": seed,
+ "stop": stop_sequences if stop_sequences is not None else [],
+ "temperature": temperature,
+ "top_k": top_k,
+ "top_n_tokens": top_n_tokens,
+ "top_p": top_p,
+ "truncate": truncate,
+ "typical_p": typical_p,
+ "watermark": watermark,
+ }
+ parameters = {k: v for k, v in parameters.items() if v is not None}
  payload = {
  "inputs": prompt,
- "parameters": {
- "best_of": best_of,
- "decoder_input_details": decoder_input_details,
- "details": details,
- "do_sample": do_sample,
- "max_new_tokens": max_new_tokens,
- "repetition_penalty": repetition_penalty,
- "return_full_text": return_full_text,
- "seed": seed,
- "stop": stop_sequences if stop_sequences is not None else [],
- "temperature": temperature,
- "top_k": top_k,
- "top_p": top_p,
- "truncate": truncate,
- "typical_p": typical_p,
- "watermark": watermark,
- },
+ "parameters": parameters,
  "stream": stream,
  }

  # Remove some parameters if not a TGI server
- if not _is_tgi_server(model):
- parameters: Dict = payload["parameters"] # type: ignore [assignment]
+ unsupported_kwargs = _get_unsupported_text_generation_kwargs(model)
+ if len(unsupported_kwargs) > 0:
+ # The server does not support some parameters
+ # => means it is not a TGI server
+ # => remove unsupported parameters and warn the user

  ignored_parameters = []
- for key in "watermark", "details", "decoder_input_details", "best_of", "stop", "return_full_text":
- if parameters[key] is not None:
+ for key in unsupported_kwargs:
+ if parameters.get(key):
  ignored_parameters.append(key)
- del parameters[key]
+ parameters.pop(key, None)
  if len(ignored_parameters) > 0:
  warnings.warn(
- "API endpoint/model for text-generation is not served via TGI. Ignoring parameters"
- f" {ignored_parameters}.",
+ "API endpoint/model for text-generation is not served via TGI. Ignoring following parameters:"
+ f" {', '.join(ignored_parameters)}.",
  UserWarning,
  )
  if details:
@@ -1841,28 +2055,32 @@ class AsyncInferenceClient:
  try:
  bytes_output = await self.post(json=payload, model=model, task="text-generation", stream=stream) # type: ignore
  except _import_aiohttp().ClientResponseError as e:
- error_message = getattr(e, "response_error_payload", {}).get("error", "")
- if e.code == 400 and "The following `model_kwargs` are not used by the model" in error_message:
- _set_as_non_tgi(model)
+ match = MODEL_KWARGS_NOT_USED_REGEX.search(e.response_error_payload["error"])
+ if e.status == 400 and match:
+ unused_params = [kwarg.strip("' ") for kwarg in match.group(1).split(",")]
+ _set_unsupported_text_generation_kwargs(model, unused_params)
  return await self.text_generation( # type: ignore
  prompt=prompt,
  details=details,
  stream=stream,
  model=model,
+ best_of=best_of,
+ decoder_input_details=decoder_input_details,
  do_sample=do_sample,
+ frequency_penalty=frequency_penalty,
+ grammar=grammar,
  max_new_tokens=max_new_tokens,
- best_of=best_of,
  repetition_penalty=repetition_penalty,
  return_full_text=return_full_text,
  seed=seed,
  stop_sequences=stop_sequences,
  temperature=temperature,
  top_k=top_k,
+ top_n_tokens=top_n_tokens,
  top_p=top_p,
  truncate=truncate,
  typical_p=typical_p,
  watermark=watermark,
- decoder_input_details=decoder_input_details,
  )
  raise_text_generation_error(e)
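For context on the last hunk: a standalone sketch of how the `MODEL_KWARGS_NOT_USED_REGEX` pattern and the kwarg parsing introduced above behave. The regex and the `strip("' ")` parsing are taken directly from the diff; the error string is a made-up example of the transformers message the regex targets.

```py
import re

# Same pattern as the one added at the top of _async_client.py in this diff
MODEL_KWARGS_NOT_USED_REGEX = re.compile(
    r"The following `model_kwargs` are not used by the model: \[(.*?)\]"
)

# Hypothetical error payload returned by a transformers-backed (non-TGI) endpoint
error_message = "The following `model_kwargs` are not used by the model: ['watermark', 'stop']"

match = MODEL_KWARGS_NOT_USED_REGEX.search(error_message)
if match:
    # Same parsing as in the except-branch above: strip quotes and spaces around each name
    unused_params = [kwarg.strip("' ") for kwarg in match.group(1).split(",")]
    print(unused_params)  # ['watermark', 'stop']
```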