huggingface-hub 0.22.1__py3-none-any.whl → 0.23.0__py3-none-any.whl
This diff shows the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- huggingface_hub/__init__.py +51 -19
- huggingface_hub/_commit_api.py +10 -9
- huggingface_hub/_commit_scheduler.py +2 -2
- huggingface_hub/_inference_endpoints.py +10 -17
- huggingface_hub/_local_folder.py +229 -0
- huggingface_hub/_login.py +4 -3
- huggingface_hub/_multi_commits.py +1 -1
- huggingface_hub/_snapshot_download.py +16 -38
- huggingface_hub/_tensorboard_logger.py +16 -6
- huggingface_hub/_webhooks_payload.py +22 -1
- huggingface_hub/_webhooks_server.py +24 -20
- huggingface_hub/commands/download.py +11 -34
- huggingface_hub/commands/huggingface_cli.py +2 -0
- huggingface_hub/commands/tag.py +159 -0
- huggingface_hub/constants.py +3 -5
- huggingface_hub/errors.py +58 -0
- huggingface_hub/file_download.py +545 -376
- huggingface_hub/hf_api.py +758 -629
- huggingface_hub/hf_file_system.py +14 -5
- huggingface_hub/hub_mixin.py +127 -43
- huggingface_hub/inference/_client.py +402 -183
- huggingface_hub/inference/_common.py +19 -29
- huggingface_hub/inference/_generated/_async_client.py +402 -184
- huggingface_hub/inference/_generated/types/__init__.py +23 -6
- huggingface_hub/inference/_generated/types/chat_completion.py +197 -43
- huggingface_hub/inference/_generated/types/text_generation.py +57 -79
- huggingface_hub/inference/_templating.py +2 -4
- huggingface_hub/keras_mixin.py +0 -3
- huggingface_hub/lfs.py +16 -4
- huggingface_hub/repository.py +1 -0
- huggingface_hub/utils/__init__.py +19 -6
- huggingface_hub/utils/_fixes.py +1 -0
- huggingface_hub/utils/_headers.py +2 -4
- huggingface_hub/utils/_http.py +16 -5
- huggingface_hub/utils/_paths.py +13 -1
- huggingface_hub/utils/_runtime.py +10 -0
- huggingface_hub/utils/_safetensors.py +0 -13
- huggingface_hub/utils/_validators.py +2 -7
- huggingface_hub/utils/tqdm.py +124 -46
- {huggingface_hub-0.22.1.dist-info → huggingface_hub-0.23.0.dist-info}/METADATA +5 -1
- {huggingface_hub-0.22.1.dist-info → huggingface_hub-0.23.0.dist-info}/RECORD +45 -43
- {huggingface_hub-0.22.1.dist-info → huggingface_hub-0.23.0.dist-info}/LICENSE +0 -0
- {huggingface_hub-0.22.1.dist-info → huggingface_hub-0.23.0.dist-info}/WHEEL +0 -0
- {huggingface_hub-0.22.1.dist-info → huggingface_hub-0.23.0.dist-info}/entry_points.txt +0 -0
- {huggingface_hub-0.22.1.dist-info → huggingface_hub-0.23.0.dist-info}/top_level.txt +0 -0
huggingface_hub/inference/_client.py

@@ -34,6 +34,7 @@
 # - Only the main parameters are publicly exposed. Power users can always read the docs for more options.
 import base64
 import logging
+import re
 import time
 import warnings
 from typing import (
@@ -63,14 +64,13 @@ from huggingface_hub.inference._common import (
     _bytes_to_image,
     _bytes_to_list,
     _fetch_recommended_models,
+    _get_unsupported_text_generation_kwargs,
     _import_numpy,
     _is_chat_completion_server,
-    _is_tgi_server,
     _open_as_binary,
     _set_as_non_chat_completion_server,
-
+    _set_unsupported_text_generation_kwargs,
     _stream_chat_completion_response_from_bytes,
-    _stream_chat_completion_response_from_text_generation,
     _stream_text_generation_response,
     raise_text_generation_error,
 )
@@ -78,9 +78,11 @@ from huggingface_hub.inference._generated.types import (
     AudioClassificationOutputElement,
     AudioToAudioOutputElement,
     AutomaticSpeechRecognitionOutput,
+    ChatCompletionInputTool,
+    ChatCompletionInputToolTypeClass,
     ChatCompletionOutput,
-
-
+    ChatCompletionOutputComplete,
+    ChatCompletionOutputMessage,
     ChatCompletionStreamOutput,
     DocumentQuestionAnsweringOutputElement,
     FillMaskOutputElement,
@@ -92,6 +94,7 @@ from huggingface_hub.inference._generated.types import (
     SummarizationOutput,
     TableQuestionAnsweringOutputElement,
     TextClassificationOutputElement,
+    TextGenerationInputGrammarType,
     TextGenerationOutput,
     TextGenerationStreamOutput,
     TokenClassificationOutputElement,
@@ -100,7 +103,7 @@ from huggingface_hub.inference._generated.types import (
     ZeroShotClassificationOutputElement,
     ZeroShotImageClassificationOutputElement,
 )
-from huggingface_hub.inference.
+from huggingface_hub.inference._generated.types.chat_completion import ChatCompletionInputToolTypeEnum
 from huggingface_hub.inference._types import (
     ConversationalOutput,  # soon to be removed
 )
@@ -114,11 +117,14 @@ from huggingface_hub.utils import (

 if TYPE_CHECKING:
     import numpy as np
-    from PIL import Image
+    from PIL.Image import Image

 logger = logging.getLogger(__name__)


+MODEL_KWARGS_NOT_USED_REGEX = re.compile(r"The following `model_kwargs` are not used by the model: \[(.*?)\]")
+
+
 class InferenceClient:
     """
     Initialize a new Inference Client.
@@ -416,10 +422,19 @@ class InferenceClient:
         *,
         model: Optional[str] = None,
         stream: Literal[False] = False,
-
+        frequency_penalty: Optional[float] = None,
+        logit_bias: Optional[List[float]] = None,
+        logprobs: Optional[bool] = None,
+        max_tokens: Optional[int] = None,
+        n: Optional[int] = None,
+        presence_penalty: Optional[float] = None,
         seed: Optional[int] = None,
-        stop: Optional[
-        temperature: float =
+        stop: Optional[List[str]] = None,
+        temperature: Optional[float] = None,
+        tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, ChatCompletionInputToolTypeEnum]] = None,
+        tool_prompt: Optional[str] = None,
+        tools: Optional[List[ChatCompletionInputTool]] = None,
+        top_logprobs: Optional[int] = None,
         top_p: Optional[float] = None,
     ) -> ChatCompletionOutput: ...

@@ -430,10 +445,19 @@ class InferenceClient:
         *,
         model: Optional[str] = None,
         stream: Literal[True] = True,
-
+        frequency_penalty: Optional[float] = None,
+        logit_bias: Optional[List[float]] = None,
+        logprobs: Optional[bool] = None,
+        max_tokens: Optional[int] = None,
+        n: Optional[int] = None,
+        presence_penalty: Optional[float] = None,
         seed: Optional[int] = None,
-        stop: Optional[
-        temperature: float =
+        stop: Optional[List[str]] = None,
+        temperature: Optional[float] = None,
+        tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, ChatCompletionInputToolTypeEnum]] = None,
+        tool_prompt: Optional[str] = None,
+        tools: Optional[List[ChatCompletionInputTool]] = None,
+        top_logprobs: Optional[int] = None,
         top_p: Optional[float] = None,
     ) -> Iterable[ChatCompletionStreamOutput]: ...

@@ -444,10 +468,19 @@ class InferenceClient:
         *,
         model: Optional[str] = None,
         stream: bool = False,
-
+        frequency_penalty: Optional[float] = None,
+        logit_bias: Optional[List[float]] = None,
+        logprobs: Optional[bool] = None,
+        max_tokens: Optional[int] = None,
+        n: Optional[int] = None,
+        presence_penalty: Optional[float] = None,
         seed: Optional[int] = None,
-        stop: Optional[
-        temperature: float =
+        stop: Optional[List[str]] = None,
+        temperature: Optional[float] = None,
+        tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, ChatCompletionInputToolTypeEnum]] = None,
+        tool_prompt: Optional[str] = None,
+        tools: Optional[List[ChatCompletionInputTool]] = None,
+        top_logprobs: Optional[int] = None,
         top_p: Optional[float] = None,
     ) -> Union[ChatCompletionOutput, Iterable[ChatCompletionStreamOutput]]: ...

@@ -457,10 +490,20 @@ class InferenceClient:
         *,
         model: Optional[str] = None,
         stream: bool = False,
-
+        # Parameters from ChatCompletionInput (handled manually)
+        frequency_penalty: Optional[float] = None,
+        logit_bias: Optional[List[float]] = None,
+        logprobs: Optional[bool] = None,
+        max_tokens: Optional[int] = None,
+        n: Optional[int] = None,
+        presence_penalty: Optional[float] = None,
         seed: Optional[int] = None,
-        stop: Optional[
-        temperature: float =
+        stop: Optional[List[str]] = None,
+        temperature: Optional[float] = None,
+        tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, ChatCompletionInputToolTypeEnum]] = None,
+        tool_prompt: Optional[str] = None,
+        tools: Optional[List[ChatCompletionInputTool]] = None,
+        top_logprobs: Optional[int] = None,
         top_p: Optional[float] = None,
     ) -> Union[ChatCompletionOutput, Iterable[ChatCompletionStreamOutput]]:
         """
@@ -483,27 +526,52 @@ class InferenceClient:
                The model to use for chat-completion. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
                Inference Endpoint. If not provided, the default recommended model for chat-based text-generation will be used.
                See https://huggingface.co/tasks/text-generation for more details.
-            frequency_penalty (`float`, optional):
+            frequency_penalty (`float`, *optional*):
                Penalizes new tokens based on their existing frequency
                in the text so far. Range: [-2.0, 2.0]. Defaults to 0.0.
-
+            logit_bias (`List[float]`, *optional*):
+                Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens
+                (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,
+                the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,
+                but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should
+                result in a ban or exclusive selection of the relevant token. Defaults to None.
+            logprobs (`bool`, *optional*):
+                Whether to return log probabilities of the output tokens or not. If true, returns the log
+                probabilities of each output token returned in the content of message.
+            max_tokens (`int`, *optional*):
                Maximum number of tokens allowed in the response. Defaults to 20.
-
+            n (`int`, *optional*):
+                UNUSED.
+            presence_penalty (`float`, *optional*):
+                Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the
+                text so far, increasing the model's likelihood to talk about new topics.
+            seed (Optional[`int`], *optional*):
                Seed for reproducible control flow. Defaults to None.
-            stop (Optional[`str`], optional):
+            stop (Optional[`str`], *optional*):
                Up to four strings which trigger the end of the response.
                Defaults to None.
-            stream (`bool`, optional):
+            stream (`bool`, *optional*):
                Enable realtime streaming of responses. Defaults to False.
-            temperature (`float`, optional):
+            temperature (`float`, *optional*):
                Controls randomness of the generations. Lower values ensure
                less random completions. Range: [0, 2]. Defaults to 1.0.
-
+            top_logprobs (`int`, *optional*):
+                An integer between 0 and 5 specifying the number of most likely tokens to return at each token
+                position, each with an associated log probability. logprobs must be set to true if this parameter is
+                used.
+            top_p (`float`, *optional*):
                Fraction of the most likely next words to sample from.
                Must be between 0 and 1. Defaults to 1.0.
+            tool_choice ([`ChatCompletionInputToolTypeClass`] or [`ChatCompletionInputToolTypeEnum`], *optional*):
+                The tool to use for the completion. Defaults to "auto".
+            tool_prompt (`str`, *optional*):
+                A prompt to be appended before the tools.
+            tools (List of [`ChatCompletionInputTool`], *optional*):
+                A list of tools the model may call. Currently, only functions are supported as a tool. Use this to
+                provide a list of functions the model may generate JSON inputs for.

        Returns:
-            `
+            [`ChatCompletionOutput] or Iterable of [`ChatCompletionStreamOutput`]:
                Generated text returned from the server:
                - if `stream=False`, the generated text is returned as a [`ChatCompletionOutput`] (default).
                - if `stream=True`, the generated text is returned token by token as a sequence of [`ChatCompletionStreamOutput`].
@@ -515,18 +583,20 @@ class InferenceClient:
            If the request fails with an HTTP error status code other than HTTP 503.

        Example:
+
        ```py
+        # Chat example
        >>> from huggingface_hub import InferenceClient
        >>> messages = [{"role": "user", "content": "What is the capital of France?"}]
        >>> client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
        >>> client.chat_completion(messages, max_tokens=100)
        ChatCompletionOutput(
            choices=[
-
+                ChatCompletionOutputComplete(
                    finish_reason='eos_token',
                    index=0,
-                message=
-                    content='The capital of France is Paris. The official name of the city is
+                    message=ChatCompletionOutputMessage(
+                        content='The capital of France is Paris. The official name of the city is Ville de Paris (City of Paris) and the name of the country governing body, which is located in Paris, is La République française (The French Republic). \nI hope that helps! Let me know if you need any further information.'
                    )
                )
            ],
@@ -539,7 +609,87 @@ class InferenceClient:
        ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content=' capital', role='assistant'), index=0, finish_reason=None)], created=1710498504)
        (...)
        ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content=' may', role='assistant'), index=0, finish_reason=None)], created=1710498504)
-
+
+        # Chat example with tools
+        >>> client = InferenceClient("meta-llama/Meta-Llama-3-70B-Instruct")
+        >>> messages = [
+        ...     {
+        ...         "role": "system",
+        ...         "content": "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.",
+        ...     },
+        ...     {
+        ...         "role": "user",
+        ...         "content": "What's the weather like the next 3 days in San Francisco, CA?",
+        ...     },
+        ... ]
+        >>> tools = [
+        ...     {
+        ...         "type": "function",
+        ...         "function": {
+        ...             "name": "get_current_weather",
+        ...             "description": "Get the current weather",
+        ...             "parameters": {
+        ...                 "type": "object",
+        ...                 "properties": {
+        ...                     "location": {
+        ...                         "type": "string",
+        ...                         "description": "The city and state, e.g. San Francisco, CA",
+        ...                     },
+        ...                     "format": {
+        ...                         "type": "string",
+        ...                         "enum": ["celsius", "fahrenheit"],
+        ...                         "description": "The temperature unit to use. Infer this from the users location.",
+        ...                     },
+        ...                 },
+        ...                 "required": ["location", "format"],
+        ...             },
+        ...         },
+        ...     },
+        ...     {
+        ...         "type": "function",
+        ...         "function": {
+        ...             "name": "get_n_day_weather_forecast",
+        ...             "description": "Get an N-day weather forecast",
+        ...             "parameters": {
+        ...                 "type": "object",
+        ...                 "properties": {
+        ...                     "location": {
+        ...                         "type": "string",
+        ...                         "description": "The city and state, e.g. San Francisco, CA",
+        ...                     },
+        ...                     "format": {
+        ...                         "type": "string",
+        ...                         "enum": ["celsius", "fahrenheit"],
+        ...                         "description": "The temperature unit to use. Infer this from the users location.",
+        ...                     },
+        ...                     "num_days": {
+        ...                         "type": "integer",
+        ...                         "description": "The number of days to forecast",
+        ...                     },
+        ...                 },
+        ...                 "required": ["location", "format", "num_days"],
+        ...             },
+        ...         },
+        ...     },
+        ... ]
+
+        >>> response = client.chat_completion(
+        ...     model="meta-llama/Meta-Llama-3-70B-Instruct",
+        ...     messages=messages,
+        ...     tools=tools,
+        ...     tool_choice="auto",
+        ...     max_tokens=500,
+        ... )
+        >>> response.choices[0].message.tool_calls[0].function
+        ChatCompletionOutputFunctionDefinition(
+            arguments={
+                'location': 'San Francisco, CA',
+                'format': 'fahrenheit',
+                'num_days': 3
+            },
+            name='get_n_day_weather_forecast',
+            description=None
+        )
        ```
        """
        # determine model
@@ -558,30 +708,44 @@ class InferenceClient:
                json=dict(
                    model="tgi",  # random string
                    messages=messages,
+                    frequency_penalty=frequency_penalty,
+                    logit_bias=logit_bias,
+                    logprobs=logprobs,
                    max_tokens=max_tokens,
+                    n=n,
+                    presence_penalty=presence_penalty,
                    seed=seed,
                    stop=stop,
                    temperature=temperature,
+                    tool_choice=tool_choice,
+                    tool_prompt=tool_prompt,
+                    tools=tools,
+                    top_logprobs=top_logprobs,
                    top_p=top_p,
                    stream=stream,
                ),
                stream=stream,
            )
-        except HTTPError:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        except HTTPError as e:
+            if e.response.status_code in (400, 404, 500):
+                # Let's consider the server is not a chat completion server.
+                # Then we call again `chat_completion` which will render the chat template client side.
+                # (can be HTTP 500, HTTP 400, HTTP 404 depending on the server)
+                _set_as_non_chat_completion_server(model)
+                logger.warning(
+                    f"Server {model_url} does not seem to support chat completion. Falling back to text generation. Error: {e}"
+                )
+                return self.chat_completion(
+                    messages=messages,
+                    model=model,
+                    stream=stream,
+                    max_tokens=max_tokens,
+                    seed=seed,
+                    stop=stop,
+                    temperature=temperature,
+                    top_p=top_p,
+                )
+            raise

        if stream:
            return _stream_chat_completion_response_from_bytes(data)  # type: ignore[arg-type]
@@ -589,75 +753,46 @@ class InferenceClient:
            return ChatCompletionOutput.parse_obj_as_instance(data)  # type: ignore[arg-type]

        # At this point, we know the server is not a chat completion server.
-        #
-        # the
-
-        model_id = None
-        if model.startswith(("http://", "https://")):
-            # If URL, we need to know which model is served. This is not always possible.
-            # A workaround is to list the user Inference Endpoints and check if one of them correspond to the model URL.
-            # If not, we raise an error.
-            # TODO: fix when we have a proper API for this (at least for Inference Endpoints)
-            # TODO: what if Sagemaker URL?
-            # TODO: what if Azure URL?
-            from ..hf_api import HfApi
-
-            for endpoint in HfApi(token=self.token).list_inference_endpoints():
-                if endpoint.url == model:
-                    model_id = endpoint.repository
-                    break
-        else:
-            model_id = model
-
-        if model_id is None:
-            # If we don't have the model ID, we can't fetch the chat template.
-            # We raise an error.
+        # It means it's a transformers-backed server for which we can send a list of messages directly to the
+        # `text-generation` pipeline. We won't receive a detailed response but only the generated text.
+        if stream:
            raise ValueError(
-                "
-                "
-
+                "Streaming token is not supported by the model. This is due to the model not been served by a "
+                "Text-Generation-Inference server. Please pass `stream=False` as input."
+            )
+        if tool_choice is not None or tool_prompt is not None or tools is not None:
+            warnings.warn(
+                "Tools are not supported by the model. This is due to the model not been served by a "
+                "Text-Generation-Inference server. The provided tool parameters will be ignored."
            )
-
-        # fetch chat template + tokens
-        prompt = render_chat_prompt(model_id=model_id, token=self.token, messages=messages)

        # generate response
-        stop_sequences = [stop] if isinstance(stop, str) else stop
        text_generation_output = self.text_generation(
-            prompt=
-            details=True,
-            stream=stream,
+            prompt=messages,  # type: ignore # Not correct type but works implicitly
            model=model,
+            stream=False,
+            details=False,
            max_new_tokens=max_tokens,
            seed=seed,
-            stop_sequences=
+            stop_sequences=stop,
            temperature=temperature,
            top_p=top_p,
        )

-
-
-        if stream:
-            return _stream_chat_completion_response_from_text_generation(text_generation_output)  # type: ignore [arg-type]
-
-        if isinstance(text_generation_output, TextGenerationOutput):
-            # General use case => format ChatCompletionOutput from text generation details
-            content: str = text_generation_output.generated_text
-            finish_reason: str = text_generation_output.details.finish_reason  # type: ignore[union-attr]
-        else:
-            # Corner case: if server doesn't support details (e.g. if not a TGI server), we only receive an output string.
-            # In such a case, `finish_reason` is set to `"unk"`.
-            content = text_generation_output  # type: ignore[assignment]
-            finish_reason = "unk"
-
+        # Format as a ChatCompletionOutput with dummy values for fields we can't provide
        return ChatCompletionOutput(
-
+            id="dummy",
+            model="dummy",
+            object="dummy",
+            system_fingerprint="dummy",
+            usage=None,  # type: ignore # set to `None` as we don't want to provide false information
+            created=int(time.time()),
            choices=[
-
-                finish_reason=
+                ChatCompletionOutputComplete(
+                    finish_reason="unk",  # type: ignore # set to `unk` as we don't want to provide false information
                    index=0,
-                message=
-                    content=
+                    message=ChatCompletionOutputMessage(
+                        content=text_generation_output,
                        role="assistant",
                    ),
                )
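When the target endpoint is not served by a chat-completion-capable (TGI) backend, the rewritten fallback above pipes the messages straight into `text_generation` and wraps the plain string result in a `ChatCompletionOutput` whose metadata fields are placeholders. A minimal caller-side sketch of what that degraded response looks like, assuming a hypothetical endpoint URL that only exposes the plain `text-generation` pipeline:

```py
from huggingface_hub import InferenceClient

# Hypothetical endpoint backed by a plain transformers `text-generation` pipeline (not TGI).
client = InferenceClient(model="https://my-endpoint.example/models/gpt2")

messages = [{"role": "user", "content": "Say hello"}]
response = client.chat_completion(messages, max_tokens=20)

# The generated text is still returned...
print(response.choices[0].message.content)
# ...but fields the backend cannot provide are filled with dummy values.
print(response.id, response.model, response.choices[0].finish_reason)  # dummy dummy unk
```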
@@ -1055,7 +1190,7 @@ class InferenceClient:
        self, frameworks: Union[None, str, Literal["all"], List[str]] = None
    ) -> Dict[str, List[str]]:
        """
-        List models
+        List models deployed on the Serverless Inference API service.

        This helper checks deployed models framework by framework. By default, it will check the 4 main frameworks that
        are supported and account for 95% of the hosted models. However, if you want a complete list of models you can
@@ -1063,9 +1198,17 @@ class InferenceClient:
        in, you can also restrict to search to this one (e.g. `frameworks="text-generation-inference"`). The more
        frameworks are checked, the more time it will take.

+        <Tip warning={true}>
+
+        This endpoint method does not return a live list of all models available for the Serverless Inference API service.
+        It searches over a cached list of models that were recently available and the list may not be up to date.
+        If you want to know the live status of a specific model, use [`~InferenceClient.get_model_status`].
+
+        </Tip>
+
        <Tip>

-        This endpoint is mostly useful for discoverability. If you already know which model you want to use and want to
+        This endpoint method is mostly useful for discoverability. If you already know which model you want to use and want to
        check its availability, you can directly use [`~InferenceClient.get_model_status`].

        </Tip>
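The updated docstring describes a two-step discovery workflow: browse the cached list with `list_deployed_models`, then confirm a specific model with `get_model_status`. A short sketch of that flow, assuming the `text-generation` task key is present in the returned mapping and that the returned model status exposes `state` and `framework` fields:

```py
from huggingface_hub import InferenceClient

client = InferenceClient()

# Discoverability: a cached (possibly stale) snapshot of deployed models, checked framework by framework.
deployed = client.list_deployed_models(frameworks="text-generation-inference")
print(deployed.get("text-generation", [])[:5])

# Live check for one specific model.
status = client.get_model_status("HuggingFaceH4/zephyr-7b-beta")
print(status.state, status.framework)
```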
@@ -1475,19 +1618,24 @@
        details: Literal[False] = ...,
        stream: Literal[False] = ...,
        model: Optional[str] = None,
-
-        max_new_tokens: int = 20,
+        # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
        best_of: Optional[int] = None,
+        decoder_input_details: Optional[bool] = None,
+        do_sample: Optional[bool] = False,  # Manual default value
+        frequency_penalty: Optional[float] = None,
+        grammar: Optional[TextGenerationInputGrammarType] = None,
+        max_new_tokens: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
-        return_full_text: bool = False,
+        return_full_text: Optional[bool] = False,  # Manual default value
        seed: Optional[int] = None,
-        stop_sequences: Optional[List[str]] = None,
+        stop_sequences: Optional[List[str]] = None,  # Same as `stop`
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
+        top_n_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        truncate: Optional[int] = None,
        typical_p: Optional[float] = None,
-        watermark: bool =
+        watermark: Optional[bool] = None,
    ) -> str: ...

    @overload
@@ -1498,19 +1646,24 @@
        details: Literal[True] = ...,
        stream: Literal[False] = ...,
        model: Optional[str] = None,
-
-        max_new_tokens: int = 20,
+        # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
        best_of: Optional[int] = None,
+        decoder_input_details: Optional[bool] = None,
+        do_sample: Optional[bool] = False,  # Manual default value
+        frequency_penalty: Optional[float] = None,
+        grammar: Optional[TextGenerationInputGrammarType] = None,
+        max_new_tokens: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
-        return_full_text: bool = False,
+        return_full_text: Optional[bool] = False,  # Manual default value
        seed: Optional[int] = None,
-        stop_sequences: Optional[List[str]] = None,
+        stop_sequences: Optional[List[str]] = None,  # Same as `stop`
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
+        top_n_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        truncate: Optional[int] = None,
        typical_p: Optional[float] = None,
-        watermark: bool =
+        watermark: Optional[bool] = None,
    ) -> TextGenerationOutput: ...

    @overload
@@ -1521,19 +1674,24 @@
        details: Literal[False] = ...,
        stream: Literal[True] = ...,
        model: Optional[str] = None,
-
-        max_new_tokens: int = 20,
+        # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
        best_of: Optional[int] = None,
+        decoder_input_details: Optional[bool] = None,
+        do_sample: Optional[bool] = False,  # Manual default value
+        frequency_penalty: Optional[float] = None,
+        grammar: Optional[TextGenerationInputGrammarType] = None,
+        max_new_tokens: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
-        return_full_text: bool = False,
+        return_full_text: Optional[bool] = False,  # Manual default value
        seed: Optional[int] = None,
-        stop_sequences: Optional[List[str]] = None,
+        stop_sequences: Optional[List[str]] = None,  # Same as `stop`
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
+        top_n_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        truncate: Optional[int] = None,
        typical_p: Optional[float] = None,
-        watermark: bool =
+        watermark: Optional[bool] = None,
    ) -> Iterable[str]: ...

    @overload
@@ -1544,19 +1702,24 @@
        details: Literal[True] = ...,
        stream: Literal[True] = ...,
        model: Optional[str] = None,
-
-        max_new_tokens: int = 20,
+        # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
        best_of: Optional[int] = None,
+        decoder_input_details: Optional[bool] = None,
+        do_sample: Optional[bool] = False,  # Manual default value
+        frequency_penalty: Optional[float] = None,
+        grammar: Optional[TextGenerationInputGrammarType] = None,
+        max_new_tokens: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
-        return_full_text: bool = False,
+        return_full_text: Optional[bool] = False,  # Manual default value
        seed: Optional[int] = None,
-        stop_sequences: Optional[List[str]] = None,
+        stop_sequences: Optional[List[str]] = None,  # Same as `stop`
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
+        top_n_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        truncate: Optional[int] = None,
        typical_p: Optional[float] = None,
-        watermark: bool =
+        watermark: Optional[bool] = None,
    ) -> Iterable[TextGenerationStreamOutput]: ...

    @overload
@@ -1567,19 +1730,24 @@
        details: Literal[True] = ...,
        stream: bool = ...,
        model: Optional[str] = None,
-
-        max_new_tokens: int = 20,
+        # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
        best_of: Optional[int] = None,
+        decoder_input_details: Optional[bool] = None,
+        do_sample: Optional[bool] = False,  # Manual default value
+        frequency_penalty: Optional[float] = None,
+        grammar: Optional[TextGenerationInputGrammarType] = None,
+        max_new_tokens: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
-        return_full_text: bool = False,
+        return_full_text: Optional[bool] = False,  # Manual default value
        seed: Optional[int] = None,
-        stop_sequences: Optional[List[str]] = None,
+        stop_sequences: Optional[List[str]] = None,  # Same as `stop`
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
+        top_n_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        truncate: Optional[int] = None,
        typical_p: Optional[float] = None,
-        watermark: bool =
+        watermark: Optional[bool] = None,
    ) -> Union[TextGenerationOutput, Iterable[TextGenerationStreamOutput]]: ...

    def text_generation(
@@ -1589,20 +1757,24 @@
        details: bool = False,
        stream: bool = False,
        model: Optional[str] = None,
-
-        max_new_tokens: int = 20,
+        # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
        best_of: Optional[int] = None,
+        decoder_input_details: Optional[bool] = None,
+        do_sample: Optional[bool] = False,  # Manual default value
+        frequency_penalty: Optional[float] = None,
+        grammar: Optional[TextGenerationInputGrammarType] = None,
+        max_new_tokens: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
-        return_full_text: bool = False,
+        return_full_text: Optional[bool] = False,  # Manual default value
        seed: Optional[int] = None,
-        stop_sequences: Optional[List[str]] = None,
+        stop_sequences: Optional[List[str]] = None,  # Same as `stop`
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
+        top_n_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        truncate: Optional[int] = None,
        typical_p: Optional[float] = None,
-        watermark: bool =
-        decoder_input_details: bool = False,
+        watermark: Optional[bool] = None,
    ) -> Union[str, TextGenerationOutput, Iterable[str], Iterable[TextGenerationStreamOutput]]:
        """
        Given a prompt, generate the following text.
@@ -1630,38 +1802,46 @@ class InferenceClient:
            model (`str`, *optional*):
                The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
                Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
-
+            best_of (`int`, *optional*):
+                Generate best_of sequences and return the one if the highest token logprobs.
+            decoder_input_details (`bool`, *optional*):
+                Return the decoder input token logprobs and ids. You must set `details=True` as well for it to be taken
+                into account. Defaults to `False`.
+            do_sample (`bool`, *optional*):
                Activate logits sampling
-
+            frequency_penalty (`float`, *optional*):
+                Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in
+                the text so far, decreasing the model's likelihood to repeat the same line verbatim.
+            grammar ([`TextGenerationInputGrammarType`], *optional*):
+                Grammar constraints. Can be either a JSONSchema or a regex.
+            max_new_tokens (`int`, *optional*):
                Maximum number of generated tokens
-
-                Generate best_of sequences and return the one if the highest token logprobs
-            repetition_penalty (`float`):
+            repetition_penalty (`float`, *optional*):
                The parameter for repetition penalty. 1.0 means no penalty. See [this
                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
-            return_full_text (`bool
+            return_full_text (`bool`, *optional*):
                Whether to prepend the prompt to the generated text
-            seed (`int
+            seed (`int`, *optional*):
                Random sampling seed
-            stop_sequences (`List[str]
+            stop_sequences (`List[str]`, *optional*):
                Stop generating tokens if a member of `stop_sequences` is generated
-            temperature (`float
+            temperature (`float`, *optional*):
                The value used to module the logits distribution.
-
+            top_n_tokens (`int`, *optional*):
+                Return information about the `top_n_tokens` most likely tokens at each generation step, instead of
+                just the sampled token.
+            top_k (`int`, *optional`):
                The number of highest probability vocabulary tokens to keep for top-k-filtering.
-            top_p (`float`):
+            top_p (`float`, *optional`):
                If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
                higher are kept for generation.
-            truncate (`int`):
-                Truncate inputs tokens to the given size
-            typical_p (`float`):
+            truncate (`int`, *optional`):
+                Truncate inputs tokens to the given size.
+            typical_p (`float`, *optional`):
                Typical Decoding mass
                See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
-            watermark (`bool`):
+            watermark (`bool`, *optional`):
                Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
-            decoder_input_details (`bool`):
-                Return the decoder input token logprobs and ids. You must set `details=True` as well for it to be taken
-                into account. Defaults to `False`.

        Returns:
            `Union[str, TextGenerationOutput, Iterable[str], Iterable[TextGenerationStreamOutput]]`:
@@ -1713,10 +1893,10 @@ class InferenceClient:
            generated_tokens=12,
            seed=None,
            prefill=[
-
-
+                TextGenerationPrefillOutputToken(id=487, text='The', logprob=None),
+                TextGenerationPrefillOutputToken(id=53789, text=' hugging', logprob=-13.171875),
            (...)
-
+                TextGenerationPrefillOutputToken(id=204, text=' ', logprob=-7.0390625)
            ],
            tokens=[
                TokenElement(id=1425, text='100', logprob=-1.0175781, special=False),
@@ -1750,8 +1930,35 @@ class InferenceClient:
                logprob=-0.5703125,
                special=False),
            generated_text='100% open source and built to be easy to use.',
-            details=
+            details=TextGenerationStreamOutputStreamDetails(finish_reason='length', generated_tokens=12, seed=None)
        )
+
+        # Case 5: generate constrained output using grammar
+        >>> response = client.text_generation(
+        ...     prompt="I saw a puppy a cat and a raccoon during my bike ride in the park",
+        ...     model="HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1",
+        ...     max_new_tokens=100,
+        ...     repetition_penalty=1.3,
+        ...     grammar={
+        ...         "type": "json",
+        ...         "value": {
+        ...             "properties": {
+        ...                 "location": {"type": "string"},
+        ...                 "activity": {"type": "string"},
+        ...                 "animals_seen": {"type": "integer", "minimum": 1, "maximum": 5},
+        ...                 "animals": {"type": "array", "items": {"type": "string"}},
+        ...             },
+        ...             "required": ["location", "activity", "animals_seen", "animals"],
+        ...         },
+        ...     },
+        ... )
+        >>> json.loads(response)
+        {
+            "activity": "bike riding",
+            "animals": ["puppy", "cat", "raccoon"],
+            "animals_seen": 3,
+            "location": "park"
+        }
        ```
        """
        if decoder_input_details and not details:
@@ -1762,41 +1969,48 @@ class InferenceClient:
            decoder_input_details = False

        # Build payload
+        parameters = {
+            "best_of": best_of,
+            "decoder_input_details": decoder_input_details,
+            "do_sample": do_sample,
+            "frequency_penalty": frequency_penalty,
+            "grammar": grammar,
+            "max_new_tokens": max_new_tokens,
+            "repetition_penalty": repetition_penalty,
+            "return_full_text": return_full_text,
+            "seed": seed,
+            "stop": stop_sequences if stop_sequences is not None else [],
+            "temperature": temperature,
+            "top_k": top_k,
+            "top_n_tokens": top_n_tokens,
+            "top_p": top_p,
+            "truncate": truncate,
+            "typical_p": typical_p,
+            "watermark": watermark,
+        }
+        parameters = {k: v for k, v in parameters.items() if v is not None}
        payload = {
            "inputs": prompt,
-            "parameters":
-                "best_of": best_of,
-                "decoder_input_details": decoder_input_details,
-                "details": details,
-                "do_sample": do_sample,
-                "max_new_tokens": max_new_tokens,
-                "repetition_penalty": repetition_penalty,
-                "return_full_text": return_full_text,
-                "seed": seed,
-                "stop": stop_sequences if stop_sequences is not None else [],
-                "temperature": temperature,
-                "top_k": top_k,
-                "top_p": top_p,
-                "truncate": truncate,
-                "typical_p": typical_p,
-                "watermark": watermark,
-            },
+            "parameters": parameters,
            "stream": stream,
        }

        # Remove some parameters if not a TGI server
-
-
+        unsupported_kwargs = _get_unsupported_text_generation_kwargs(model)
+        if len(unsupported_kwargs) > 0:
+            # The server does not support some parameters
+            # => means it is not a TGI server
+            # => remove unsupported parameters and warn the user

            ignored_parameters = []
-            for key in
-                if parameters
+            for key in unsupported_kwargs:
+                if parameters.get(key):
                    ignored_parameters.append(key)
-
+                parameters.pop(key, None)
            if len(ignored_parameters) > 0:
                warnings.warn(
-                    "API endpoint/model for text-generation is not served via TGI. Ignoring parameters"
-                    f" {ignored_parameters}.",
+                    "API endpoint/model for text-generation is not served via TGI. Ignoring following parameters:"
+                    f" {', '.join(ignored_parameters)}.",
                    UserWarning,
                )
        if details:
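A consequence of the new payload construction above is that only explicitly provided parameters reach the server: every `None` entry is dropped before the request is sent. A tiny standalone illustration of that filtering pattern:

```py
# Standalone illustration of the None-filtering applied to the text-generation parameters.
parameters = {"max_new_tokens": 100, "temperature": None, "seed": None, "do_sample": False}
parameters = {k: v for k, v in parameters.items() if v is not None}
print(parameters)  # {'max_new_tokens': 100, 'do_sample': False}
```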
@@ -1816,27 +2030,32 @@ class InferenceClient:
        try:
            bytes_output = self.post(json=payload, model=model, task="text-generation", stream=stream)  # type: ignore
        except HTTPError as e:
-
-
+            match = MODEL_KWARGS_NOT_USED_REGEX.search(str(e))
+            if isinstance(e, BadRequestError) and match:
+                unused_params = [kwarg.strip("' ") for kwarg in match.group(1).split(",")]
+                _set_unsupported_text_generation_kwargs(model, unused_params)
                return self.text_generation(  # type: ignore
                    prompt=prompt,
                    details=details,
                    stream=stream,
                    model=model,
+                    best_of=best_of,
+                    decoder_input_details=decoder_input_details,
                    do_sample=do_sample,
+                    frequency_penalty=frequency_penalty,
+                    grammar=grammar,
                    max_new_tokens=max_new_tokens,
-                    best_of=best_of,
                    repetition_penalty=repetition_penalty,
                    return_full_text=return_full_text,
                    seed=seed,
                    stop_sequences=stop_sequences,
                    temperature=temperature,
                    top_k=top_k,
+                    top_n_tokens=top_n_tokens,
                    top_p=top_p,
                    truncate=truncate,
                    typical_p=typical_p,
                    watermark=watermark,
-                    decoder_input_details=decoder_input_details,
                )
            raise_text_generation_error(e)
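The retry path above uses `MODEL_KWARGS_NOT_USED_REGEX` to recover the rejected kwargs from the server's error message before caching them with `_set_unsupported_text_generation_kwargs`. A small self-contained sketch of that parsing step, using a made-up error string for illustration:

```py
import re

# Same pattern as the one added at the top of the module.
MODEL_KWARGS_NOT_USED_REGEX = re.compile(r"The following `model_kwargs` are not used by the model: \[(.*?)\]")

# Hypothetical error text, shaped like what a transformers-backed endpoint returns.
error_text = "The following `model_kwargs` are not used by the model: ['watermark', 'stop']"

match = MODEL_KWARGS_NOT_USED_REGEX.search(error_text)
if match:
    unused_params = [kwarg.strip("' ") for kwarg in match.group(1).split(",")]
    print(unused_params)  # ['watermark', 'stop']
```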