huggingface-hub 0.22.2__py3-none-any.whl → 0.23.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of huggingface-hub might be problematic.
- huggingface_hub/__init__.py +51 -19
- huggingface_hub/_commit_api.py +9 -8
- huggingface_hub/_commit_scheduler.py +2 -2
- huggingface_hub/_inference_endpoints.py +10 -17
- huggingface_hub/_local_folder.py +229 -0
- huggingface_hub/_login.py +4 -3
- huggingface_hub/_multi_commits.py +1 -1
- huggingface_hub/_snapshot_download.py +16 -38
- huggingface_hub/_tensorboard_logger.py +16 -6
- huggingface_hub/_webhooks_payload.py +22 -1
- huggingface_hub/_webhooks_server.py +24 -20
- huggingface_hub/commands/download.py +11 -34
- huggingface_hub/commands/huggingface_cli.py +2 -0
- huggingface_hub/commands/tag.py +159 -0
- huggingface_hub/constants.py +3 -5
- huggingface_hub/errors.py +58 -0
- huggingface_hub/file_download.py +545 -376
- huggingface_hub/hf_api.py +756 -622
- huggingface_hub/hf_file_system.py +20 -5
- huggingface_hub/hub_mixin.py +127 -43
- huggingface_hub/inference/_client.py +402 -183
- huggingface_hub/inference/_common.py +19 -29
- huggingface_hub/inference/_generated/_async_client.py +402 -184
- huggingface_hub/inference/_generated/types/__init__.py +23 -6
- huggingface_hub/inference/_generated/types/chat_completion.py +197 -43
- huggingface_hub/inference/_generated/types/text_generation.py +57 -79
- huggingface_hub/inference/_templating.py +2 -4
- huggingface_hub/keras_mixin.py +0 -3
- huggingface_hub/lfs.py +9 -1
- huggingface_hub/repository.py +1 -0
- huggingface_hub/utils/__init__.py +12 -6
- huggingface_hub/utils/_fixes.py +1 -0
- huggingface_hub/utils/_headers.py +2 -4
- huggingface_hub/utils/_http.py +2 -4
- huggingface_hub/utils/_paths.py +13 -1
- huggingface_hub/utils/_runtime.py +10 -0
- huggingface_hub/utils/_safetensors.py +0 -13
- huggingface_hub/utils/_validators.py +2 -7
- huggingface_hub/utils/tqdm.py +124 -46
- {huggingface_hub-0.22.2.dist-info → huggingface_hub-0.23.1.dist-info}/METADATA +5 -1
- {huggingface_hub-0.22.2.dist-info → huggingface_hub-0.23.1.dist-info}/RECORD +45 -43
- {huggingface_hub-0.22.2.dist-info → huggingface_hub-0.23.1.dist-info}/LICENSE +0 -0
- {huggingface_hub-0.22.2.dist-info → huggingface_hub-0.23.1.dist-info}/WHEEL +0 -0
- {huggingface_hub-0.22.2.dist-info → huggingface_hub-0.23.1.dist-info}/entry_points.txt +0 -0
- {huggingface_hub-0.22.2.dist-info → huggingface_hub-0.23.1.dist-info}/top_level.txt +0 -0
huggingface_hub/inference/_generated/_async_client.py

@@ -21,6 +21,7 @@
 import asyncio
 import base64
 import logging
+import re
 import time
 import warnings
 from typing import (
@@ -44,7 +45,6 @@ from huggingface_hub.inference._common import (
     ContentT,
     ModelStatus,
     _async_stream_chat_completion_response_from_bytes,
-    _async_stream_chat_completion_response_from_text_generation,
     _async_stream_text_generation_response,
     _b64_encode,
     _b64_to_image,
@@ -52,21 +52,23 @@ from huggingface_hub.inference._common import (
     _bytes_to_image,
     _bytes_to_list,
     _fetch_recommended_models,
+    _get_unsupported_text_generation_kwargs,
     _import_numpy,
     _is_chat_completion_server,
-    _is_tgi_server,
     _open_as_binary,
     _set_as_non_chat_completion_server,
-
+    _set_unsupported_text_generation_kwargs,
     raise_text_generation_error,
 )
 from huggingface_hub.inference._generated.types import (
     AudioClassificationOutputElement,
     AudioToAudioOutputElement,
     AutomaticSpeechRecognitionOutput,
+    ChatCompletionInputTool,
+    ChatCompletionInputToolTypeClass,
     ChatCompletionOutput,
-
-
+    ChatCompletionOutputComplete,
+    ChatCompletionOutputMessage,
     ChatCompletionStreamOutput,
     DocumentQuestionAnsweringOutputElement,
     FillMaskOutputElement,
@@ -78,6 +80,7 @@ from huggingface_hub.inference._generated.types import (
     SummarizationOutput,
     TableQuestionAnsweringOutputElement,
     TextClassificationOutputElement,
+    TextGenerationInputGrammarType,
     TextGenerationOutput,
     TextGenerationStreamOutput,
     TokenClassificationOutputElement,
@@ -86,7 +89,7 @@ from huggingface_hub.inference._generated.types import (
     ZeroShotClassificationOutputElement,
     ZeroShotImageClassificationOutputElement,
 )
-from huggingface_hub.inference.
+from huggingface_hub.inference._generated.types.chat_completion import ChatCompletionInputToolTypeEnum
 from huggingface_hub.inference._types import (
     ConversationalOutput,  # soon to be removed
 )
@@ -99,11 +102,14 @@ from .._common import _async_yield_from, _import_aiohttp
 
 if TYPE_CHECKING:
     import numpy as np
-    from PIL import Image
+    from PIL.Image import Image
 
 logger = logging.getLogger(__name__)
 
 
+MODEL_KWARGS_NOT_USED_REGEX = re.compile(r"The following `model_kwargs` are not used by the model: \[(.*?)\]")
+
+
 class AsyncInferenceClient:
     """
     Initialize a new Inference Client.
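The new `MODEL_KWARGS_NOT_USED_REGEX` constant is consumed further down in this diff (in the `text_generation` error handling) to detect which generation parameters a non-TGI backend rejected. A minimal sketch of that parsing, using a hypothetical error message shaped the way the regex expects:

```python
import re

# Same pattern as the module-level constant added in this diff.
MODEL_KWARGS_NOT_USED_REGEX = re.compile(
    r"The following `model_kwargs` are not used by the model: \[(.*?)\]"
)

# Hypothetical error payload returned by a non-TGI backend.
error_message = "The following `model_kwargs` are not used by the model: ['watermark', 'stop']"

match = MODEL_KWARGS_NOT_USED_REGEX.search(error_message)
if match:
    # Mirrors the parsing added later in this diff before the retry.
    unused_params = [kwarg.strip("' ") for kwarg in match.group(1).split(",")]
    print(unused_params)  # ['watermark', 'stop']
```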
@@ -415,10 +421,19 @@ class AsyncInferenceClient:
         *,
         model: Optional[str] = None,
         stream: Literal[False] = False,
-
+        frequency_penalty: Optional[float] = None,
+        logit_bias: Optional[List[float]] = None,
+        logprobs: Optional[bool] = None,
+        max_tokens: Optional[int] = None,
+        n: Optional[int] = None,
+        presence_penalty: Optional[float] = None,
         seed: Optional[int] = None,
-        stop: Optional[
-        temperature: float =
+        stop: Optional[List[str]] = None,
+        temperature: Optional[float] = None,
+        tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, ChatCompletionInputToolTypeEnum]] = None,
+        tool_prompt: Optional[str] = None,
+        tools: Optional[List[ChatCompletionInputTool]] = None,
+        top_logprobs: Optional[int] = None,
         top_p: Optional[float] = None,
     ) -> ChatCompletionOutput: ...
 
@@ -429,10 +444,19 @@ class AsyncInferenceClient:
         *,
         model: Optional[str] = None,
         stream: Literal[True] = True,
-
+        frequency_penalty: Optional[float] = None,
+        logit_bias: Optional[List[float]] = None,
+        logprobs: Optional[bool] = None,
+        max_tokens: Optional[int] = None,
+        n: Optional[int] = None,
+        presence_penalty: Optional[float] = None,
         seed: Optional[int] = None,
-        stop: Optional[
-        temperature: float =
+        stop: Optional[List[str]] = None,
+        temperature: Optional[float] = None,
+        tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, ChatCompletionInputToolTypeEnum]] = None,
+        tool_prompt: Optional[str] = None,
+        tools: Optional[List[ChatCompletionInputTool]] = None,
+        top_logprobs: Optional[int] = None,
         top_p: Optional[float] = None,
     ) -> AsyncIterable[ChatCompletionStreamOutput]: ...
 
@@ -443,10 +467,19 @@ class AsyncInferenceClient:
         *,
         model: Optional[str] = None,
         stream: bool = False,
-
+        frequency_penalty: Optional[float] = None,
+        logit_bias: Optional[List[float]] = None,
+        logprobs: Optional[bool] = None,
+        max_tokens: Optional[int] = None,
+        n: Optional[int] = None,
+        presence_penalty: Optional[float] = None,
         seed: Optional[int] = None,
-        stop: Optional[
-        temperature: float =
+        stop: Optional[List[str]] = None,
+        temperature: Optional[float] = None,
+        tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, ChatCompletionInputToolTypeEnum]] = None,
+        tool_prompt: Optional[str] = None,
+        tools: Optional[List[ChatCompletionInputTool]] = None,
+        top_logprobs: Optional[int] = None,
         top_p: Optional[float] = None,
     ) -> Union[ChatCompletionOutput, AsyncIterable[ChatCompletionStreamOutput]]: ...
 
@@ -456,10 +489,20 @@ class AsyncInferenceClient:
         *,
         model: Optional[str] = None,
         stream: bool = False,
-
+        # Parameters from ChatCompletionInput (handled manually)
+        frequency_penalty: Optional[float] = None,
+        logit_bias: Optional[List[float]] = None,
+        logprobs: Optional[bool] = None,
+        max_tokens: Optional[int] = None,
+        n: Optional[int] = None,
+        presence_penalty: Optional[float] = None,
         seed: Optional[int] = None,
-        stop: Optional[
-        temperature: float =
+        stop: Optional[List[str]] = None,
+        temperature: Optional[float] = None,
+        tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, ChatCompletionInputToolTypeEnum]] = None,
+        tool_prompt: Optional[str] = None,
+        tools: Optional[List[ChatCompletionInputTool]] = None,
+        top_logprobs: Optional[int] = None,
         top_p: Optional[float] = None,
     ) -> Union[ChatCompletionOutput, AsyncIterable[ChatCompletionStreamOutput]]:
         """
@@ -482,27 +525,52 @@ class AsyncInferenceClient:
                 The model to use for chat-completion. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
                 Inference Endpoint. If not provided, the default recommended model for chat-based text-generation will be used.
                 See https://huggingface.co/tasks/text-generation for more details.
-            frequency_penalty (`float`, optional):
+            frequency_penalty (`float`, *optional*):
                 Penalizes new tokens based on their existing frequency
                 in the text so far. Range: [-2.0, 2.0]. Defaults to 0.0.
-
+            logit_bias (`List[float]`, *optional*):
+                Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens
+                (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,
+                the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,
+                but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should
+                result in a ban or exclusive selection of the relevant token. Defaults to None.
+            logprobs (`bool`, *optional*):
+                Whether to return log probabilities of the output tokens or not. If true, returns the log
+                probabilities of each output token returned in the content of message.
+            max_tokens (`int`, *optional*):
                 Maximum number of tokens allowed in the response. Defaults to 20.
-
+            n (`int`, *optional*):
+                UNUSED.
+            presence_penalty (`float`, *optional*):
+                Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the
+                text so far, increasing the model's likelihood to talk about new topics.
+            seed (Optional[`int`], *optional*):
                 Seed for reproducible control flow. Defaults to None.
-            stop (Optional[`str`], optional):
+            stop (Optional[`str`], *optional*):
                 Up to four strings which trigger the end of the response.
                 Defaults to None.
-            stream (`bool`, optional):
+            stream (`bool`, *optional*):
                 Enable realtime streaming of responses. Defaults to False.
-            temperature (`float`, optional):
+            temperature (`float`, *optional*):
                 Controls randomness of the generations. Lower values ensure
                 less random completions. Range: [0, 2]. Defaults to 1.0.
-
+            top_logprobs (`int`, *optional*):
+                An integer between 0 and 5 specifying the number of most likely tokens to return at each token
+                position, each with an associated log probability. logprobs must be set to true if this parameter is
+                used.
+            top_p (`float`, *optional*):
                 Fraction of the most likely next words to sample from.
                 Must be between 0 and 1. Defaults to 1.0.
+            tool_choice ([`ChatCompletionInputToolTypeClass`] or [`ChatCompletionInputToolTypeEnum`], *optional*):
+                The tool to use for the completion. Defaults to "auto".
+            tool_prompt (`str`, *optional*):
+                A prompt to be appended before the tools.
+            tools (List of [`ChatCompletionInputTool`], *optional*):
+                A list of tools the model may call. Currently, only functions are supported as a tool. Use this to
+                provide a list of functions the model may generate JSON inputs for.
 
         Returns:
-            `
+            [`ChatCompletionOutput] or Iterable of [`ChatCompletionStreamOutput`]:
             Generated text returned from the server:
             - if `stream=False`, the generated text is returned as a [`ChatCompletionOutput`] (default).
             - if `stream=True`, the generated text is returned token by token as a sequence of [`ChatCompletionStreamOutput`].
@@ -514,19 +582,21 @@ class AsyncInferenceClient:
                 If the request fails with an HTTP error status code other than HTTP 503.
 
         Example:
+
         ```py
         # Must be run in an async context
+        # Chat example
         >>> from huggingface_hub import AsyncInferenceClient
         >>> messages = [{"role": "user", "content": "What is the capital of France?"}]
         >>> client = AsyncInferenceClient("HuggingFaceH4/zephyr-7b-beta")
         >>> await client.chat_completion(messages, max_tokens=100)
         ChatCompletionOutput(
             choices=[
-
+                ChatCompletionOutputComplete(
                     finish_reason='eos_token',
                     index=0,
-                    message=
-                        content='The capital of France is Paris. The official name of the city is
+                    message=ChatCompletionOutputMessage(
+                        content='The capital of France is Paris. The official name of the city is Ville de Paris (City of Paris) and the name of the country governing body, which is located in Paris, is La République française (The French Republic). \nI hope that helps! Let me know if you need any further information.'
                     )
                 )
             ],
@@ -539,7 +609,87 @@ class AsyncInferenceClient:
         ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content=' capital', role='assistant'), index=0, finish_reason=None)], created=1710498504)
         (...)
         ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content=' may', role='assistant'), index=0, finish_reason=None)], created=1710498504)
-
+
+        # Chat example with tools
+        >>> client = AsyncInferenceClient("meta-llama/Meta-Llama-3-70B-Instruct")
+        >>> messages = [
+        ...     {
+        ...         "role": "system",
+        ...         "content": "Don't make assumptions about what values to plug into functions. Ask async for clarification if a user request is ambiguous.",
+        ...     },
+        ...     {
+        ...         "role": "user",
+        ...         "content": "What's the weather like the next 3 days in San Francisco, CA?",
+        ...     },
+        ... ]
+        >>> tools = [
+        ...     {
+        ...         "type": "function",
+        ...         "function": {
+        ...             "name": "get_current_weather",
+        ...             "description": "Get the current weather",
+        ...             "parameters": {
+        ...                 "type": "object",
+        ...                 "properties": {
+        ...                     "location": {
+        ...                         "type": "string",
+        ...                         "description": "The city and state, e.g. San Francisco, CA",
+        ...                     },
+        ...                     "format": {
+        ...                         "type": "string",
+        ...                         "enum": ["celsius", "fahrenheit"],
+        ...                         "description": "The temperature unit to use. Infer this from the users location.",
+        ...                     },
+        ...                 },
+        ...                 "required": ["location", "format"],
+        ...             },
+        ...         },
+        ...     },
+        ...     {
+        ...         "type": "function",
+        ...         "function": {
+        ...             "name": "get_n_day_weather_forecast",
+        ...             "description": "Get an N-day weather forecast",
+        ...             "parameters": {
+        ...                 "type": "object",
+        ...                 "properties": {
+        ...                     "location": {
+        ...                         "type": "string",
+        ...                         "description": "The city and state, e.g. San Francisco, CA",
+        ...                     },
+        ...                     "format": {
+        ...                         "type": "string",
+        ...                         "enum": ["celsius", "fahrenheit"],
+        ...                         "description": "The temperature unit to use. Infer this from the users location.",
+        ...                     },
+        ...                     "num_days": {
+        ...                         "type": "integer",
+        ...                         "description": "The number of days to forecast",
+        ...                     },
+        ...                 },
+        ...                 "required": ["location", "format", "num_days"],
+        ...             },
+        ...         },
+        ...     },
+        ... ]
+
+        >>> response = await client.chat_completion(
+        ...     model="meta-llama/Meta-Llama-3-70B-Instruct",
+        ...     messages=messages,
+        ...     tools=tools,
+        ...     tool_choice="auto",
+        ...     max_tokens=500,
+        ... )
+        >>> response.choices[0].message.tool_calls[0].function
+        ChatCompletionOutputFunctionDefinition(
+            arguments={
+                'location': 'San Francisco, CA',
+                'format': 'fahrenheit',
+                'num_days': 3
+            },
+            name='get_n_day_weather_forecast',
+            description=None
+        )
         ```
         """
         # determine model
@@ -558,30 +708,44 @@ class AsyncInferenceClient:
                 json=dict(
                     model="tgi",  # random string
                     messages=messages,
+                    frequency_penalty=frequency_penalty,
+                    logit_bias=logit_bias,
+                    logprobs=logprobs,
                     max_tokens=max_tokens,
+                    n=n,
+                    presence_penalty=presence_penalty,
                     seed=seed,
                     stop=stop,
                     temperature=temperature,
+                    tool_choice=tool_choice,
+                    tool_prompt=tool_prompt,
+                    tools=tools,
+                    top_logprobs=top_logprobs,
                     top_p=top_p,
                     stream=stream,
                 ),
                 stream=stream,
             )
-        except _import_aiohttp().ClientResponseError:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        except _import_aiohttp().ClientResponseError as e:
+            if e.status in (400, 404, 500):
+                # Let's consider the server is not a chat completion server.
+                # Then we call again `chat_completion` which will render the chat template client side.
+                # (can be HTTP 500, HTTP 400, HTTP 404 depending on the server)
+                _set_as_non_chat_completion_server(model)
+                logger.warning(
+                    f"Server {model_url} does not seem to support chat completion. Falling back to text generation. Error: {e}"
+                )
+                return await self.chat_completion(
+                    messages=messages,
+                    model=model,
+                    stream=stream,
+                    max_tokens=max_tokens,
+                    seed=seed,
+                    stop=stop,
+                    temperature=temperature,
+                    top_p=top_p,
+                )
+            raise
 
         if stream:
             return _async_stream_chat_completion_response_from_bytes(data)  # type: ignore[arg-type]
@@ -589,75 +753,46 @@ class AsyncInferenceClient:
         return ChatCompletionOutput.parse_obj_as_instance(data)  # type: ignore[arg-type]
 
         # At this point, we know the server is not a chat completion server.
-        #
-        # the
-
-        model_id = None
-        if model.startswith(("http://", "https://")):
-            # If URL, we need to know which model is served. This is not always possible.
-            # A workaround is to list the user Inference Endpoints and check if one of them correspond to the model URL.
-            # If not, we raise an error.
-            # TODO: fix when we have a proper API for this (at least for Inference Endpoints)
-            # TODO: what if Sagemaker URL?
-            # TODO: what if Azure URL?
-            from ..hf_api import HfApi
-
-            for endpoint in HfApi(token=self.token).list_inference_endpoints():
-                if endpoint.url == model:
-                    model_id = endpoint.repository
-                    break
-        else:
-            model_id = model
-
-        if model_id is None:
-            # If we don't have the model ID, we can't fetch the chat template.
-            # We raise an error.
+        # It means it's a transformers-backed server for which we can send a list of messages directly to the
+        # `text-generation` pipeline. We won't receive a detailed response but only the generated text.
+        if stream:
             raise ValueError(
-                "
-                "
-
+                "Streaming token is not supported by the model. This is due to the model not been served by a "
+                "Text-Generation-Inference server. Please pass `stream=False` as input."
+            )
+        if tool_choice is not None or tool_prompt is not None or tools is not None:
+            warnings.warn(
+                "Tools are not supported by the model. This is due to the model not been served by a "
+                "Text-Generation-Inference server. The provided tool parameters will be ignored."
             )
-
-        # fetch chat template + tokens
-        prompt = render_chat_prompt(model_id=model_id, token=self.token, messages=messages)
 
         # generate response
-        stop_sequences = [stop] if isinstance(stop, str) else stop
         text_generation_output = await self.text_generation(
-            prompt=
-            details=True,
-            stream=stream,
+            prompt=messages,  # type: ignore # Not correct type but works implicitly
             model=model,
+            stream=False,
+            details=False,
             max_new_tokens=max_tokens,
             seed=seed,
-            stop_sequences=
+            stop_sequences=stop,
             temperature=temperature,
             top_p=top_p,
         )
 
-
-
-        if stream:
-            return _async_stream_chat_completion_response_from_text_generation(text_generation_output)  # type: ignore [arg-type]
-
-        if isinstance(text_generation_output, TextGenerationOutput):
-            # General use case => format ChatCompletionOutput from text generation details
-            content: str = text_generation_output.generated_text
-            finish_reason: str = text_generation_output.details.finish_reason  # type: ignore[union-attr]
-        else:
-            # Corner case: if server doesn't support details (e.g. if not a TGI server), we only receive an output string.
-            # In such a case, `finish_reason` is set to `"unk"`.
-            content = text_generation_output  # type: ignore[assignment]
-            finish_reason = "unk"
-
+        # Format as a ChatCompletionOutput with dummy values for fields we can't provide
         return ChatCompletionOutput(
-
+            id="dummy",
+            model="dummy",
+            object="dummy",
+            system_fingerprint="dummy",
+            usage=None,  # type: ignore # set to `None` as we don't want to provide false information
+            created=int(time.time()),
             choices=[
-
-                    finish_reason=
+                ChatCompletionOutputComplete(
+                    finish_reason="unk",  # type: ignore # set to `unk` as we don't want to provide false information
                     index=0,
-                    message=
-                        content=
+                    message=ChatCompletionOutputMessage(
+                        content=text_generation_output,
                         role="assistant",
                     ),
                 )
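Taken together, the two hunks above change the caller-visible behaviour of `chat_completion` against servers that are not running Text-Generation-Inference: instead of rendering the chat template client side, the client now marks the endpoint as non-chat-completion, retries through the `text-generation` route, and wraps the plain string result in a `ChatCompletionOutput` with placeholder metadata. A rough usage sketch (the model name and the assumption that it is not TGI-served are illustrative, not taken from the diff):

```python
import asyncio
from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    # Assumption: "gpt2" is served by a plain transformers backend, not by TGI.
    client = AsyncInferenceClient("gpt2")
    # If the server answers 400/404/500 on the chat-completion route, the client now
    # logs a warning and retries through text generation instead of raising.
    output = await client.chat_completion(
        [{"role": "user", "content": "Hello!"}],
        max_tokens=20,
    )
    # On the fallback path, id/model/object are the "dummy" placeholders shown above,
    # but the generated text is still available here.
    print(output.choices[0].message.content)


asyncio.run(main())
```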
@@ -1063,7 +1198,7 @@ class AsyncInferenceClient:
         self, frameworks: Union[None, str, Literal["all"], List[str]] = None
     ) -> Dict[str, List[str]]:
         """
-        List models
+        List models deployed on the Serverless Inference API service.
 
         This helper checks deployed models framework by framework. By default, it will check the 4 main frameworks that
         are supported and account for 95% of the hosted models. However, if you want a complete list of models you can
@@ -1071,9 +1206,17 @@ class AsyncInferenceClient:
         in, you can also restrict to search to this one (e.g. `frameworks="text-generation-inference"`). The more
         frameworks are checked, the more time it will take.
 
+        <Tip warning={true}>
+
+        This endpoint method does not return a live list of all models available for the Serverless Inference API service.
+        It searches over a cached list of models that were recently available and the list may not be up to date.
+        If you want to know the live status of a specific model, use [`~InferenceClient.get_model_status`].
+
+        </Tip>
+
         <Tip>
 
-        This endpoint is mostly useful for discoverability. If you already know which model you want to use and want to
+        This endpoint method is mostly useful for discoverability. If you already know which model you want to use and want to
         check its availability, you can directly use [`~InferenceClient.get_model_status`].
 
         </Tip>
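The new warning tip steers users toward `get_model_status` for live checks rather than the cached framework lists. A short sketch of that pattern (the model ID is reused from the examples earlier in this diff):

```python
import asyncio
from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient()
    # Live check for a single model, as the new tip recommends, instead of
    # scanning the cached results of list_deployed_models().
    status = await client.get_model_status("meta-llama/Meta-Llama-3-70B-Instruct")
    print(status.state, status.loaded)


asyncio.run(main())
```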
@@ -1499,19 +1642,24 @@ class AsyncInferenceClient:
         details: Literal[False] = ...,
         stream: Literal[False] = ...,
         model: Optional[str] = None,
-
-        max_new_tokens: int = 20,
+        # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
         best_of: Optional[int] = None,
+        decoder_input_details: Optional[bool] = None,
+        do_sample: Optional[bool] = False,  # Manual default value
+        frequency_penalty: Optional[float] = None,
+        grammar: Optional[TextGenerationInputGrammarType] = None,
+        max_new_tokens: Optional[int] = None,
         repetition_penalty: Optional[float] = None,
-        return_full_text: bool = False,
+        return_full_text: Optional[bool] = False,  # Manual default value
         seed: Optional[int] = None,
-        stop_sequences: Optional[List[str]] = None,
+        stop_sequences: Optional[List[str]] = None,  # Same as `stop`
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
+        top_n_tokens: Optional[int] = None,
         top_p: Optional[float] = None,
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
-        watermark: bool =
+        watermark: Optional[bool] = None,
     ) -> str: ...
 
     @overload
@@ -1522,19 +1670,24 @@ class AsyncInferenceClient:
         details: Literal[True] = ...,
         stream: Literal[False] = ...,
         model: Optional[str] = None,
-
-        max_new_tokens: int = 20,
+        # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
         best_of: Optional[int] = None,
+        decoder_input_details: Optional[bool] = None,
+        do_sample: Optional[bool] = False,  # Manual default value
+        frequency_penalty: Optional[float] = None,
+        grammar: Optional[TextGenerationInputGrammarType] = None,
+        max_new_tokens: Optional[int] = None,
         repetition_penalty: Optional[float] = None,
-        return_full_text: bool = False,
+        return_full_text: Optional[bool] = False,  # Manual default value
         seed: Optional[int] = None,
-        stop_sequences: Optional[List[str]] = None,
+        stop_sequences: Optional[List[str]] = None,  # Same as `stop`
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
+        top_n_tokens: Optional[int] = None,
         top_p: Optional[float] = None,
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
-        watermark: bool =
+        watermark: Optional[bool] = None,
     ) -> TextGenerationOutput: ...
 
     @overload
@@ -1545,19 +1698,24 @@ class AsyncInferenceClient:
         details: Literal[False] = ...,
         stream: Literal[True] = ...,
         model: Optional[str] = None,
-
-        max_new_tokens: int = 20,
+        # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
         best_of: Optional[int] = None,
+        decoder_input_details: Optional[bool] = None,
+        do_sample: Optional[bool] = False,  # Manual default value
+        frequency_penalty: Optional[float] = None,
+        grammar: Optional[TextGenerationInputGrammarType] = None,
+        max_new_tokens: Optional[int] = None,
         repetition_penalty: Optional[float] = None,
-        return_full_text: bool = False,
+        return_full_text: Optional[bool] = False,  # Manual default value
         seed: Optional[int] = None,
-        stop_sequences: Optional[List[str]] = None,
+        stop_sequences: Optional[List[str]] = None,  # Same as `stop`
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
+        top_n_tokens: Optional[int] = None,
         top_p: Optional[float] = None,
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
-        watermark: bool =
+        watermark: Optional[bool] = None,
     ) -> AsyncIterable[str]: ...
 
     @overload
@@ -1568,19 +1726,24 @@ class AsyncInferenceClient:
         details: Literal[True] = ...,
         stream: Literal[True] = ...,
         model: Optional[str] = None,
-
-        max_new_tokens: int = 20,
+        # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
         best_of: Optional[int] = None,
+        decoder_input_details: Optional[bool] = None,
+        do_sample: Optional[bool] = False,  # Manual default value
+        frequency_penalty: Optional[float] = None,
+        grammar: Optional[TextGenerationInputGrammarType] = None,
+        max_new_tokens: Optional[int] = None,
         repetition_penalty: Optional[float] = None,
-        return_full_text: bool = False,
+        return_full_text: Optional[bool] = False,  # Manual default value
         seed: Optional[int] = None,
-        stop_sequences: Optional[List[str]] = None,
+        stop_sequences: Optional[List[str]] = None,  # Same as `stop`
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
+        top_n_tokens: Optional[int] = None,
         top_p: Optional[float] = None,
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
-        watermark: bool =
+        watermark: Optional[bool] = None,
     ) -> AsyncIterable[TextGenerationStreamOutput]: ...
 
     @overload
@@ -1591,19 +1754,24 @@ class AsyncInferenceClient:
         details: Literal[True] = ...,
         stream: bool = ...,
         model: Optional[str] = None,
-
-        max_new_tokens: int = 20,
+        # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
         best_of: Optional[int] = None,
+        decoder_input_details: Optional[bool] = None,
+        do_sample: Optional[bool] = False,  # Manual default value
+        frequency_penalty: Optional[float] = None,
+        grammar: Optional[TextGenerationInputGrammarType] = None,
+        max_new_tokens: Optional[int] = None,
         repetition_penalty: Optional[float] = None,
-        return_full_text: bool = False,
+        return_full_text: Optional[bool] = False,  # Manual default value
         seed: Optional[int] = None,
-        stop_sequences: Optional[List[str]] = None,
+        stop_sequences: Optional[List[str]] = None,  # Same as `stop`
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
+        top_n_tokens: Optional[int] = None,
         top_p: Optional[float] = None,
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
-        watermark: bool =
+        watermark: Optional[bool] = None,
     ) -> Union[TextGenerationOutput, AsyncIterable[TextGenerationStreamOutput]]: ...
 
     async def text_generation(
@@ -1613,20 +1781,24 @@ class AsyncInferenceClient:
         details: bool = False,
         stream: bool = False,
         model: Optional[str] = None,
-
-        max_new_tokens: int = 20,
+        # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
         best_of: Optional[int] = None,
+        decoder_input_details: Optional[bool] = None,
+        do_sample: Optional[bool] = False,  # Manual default value
+        frequency_penalty: Optional[float] = None,
+        grammar: Optional[TextGenerationInputGrammarType] = None,
+        max_new_tokens: Optional[int] = None,
         repetition_penalty: Optional[float] = None,
-        return_full_text: bool = False,
+        return_full_text: Optional[bool] = False,  # Manual default value
         seed: Optional[int] = None,
-        stop_sequences: Optional[List[str]] = None,
+        stop_sequences: Optional[List[str]] = None,  # Same as `stop`
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
+        top_n_tokens: Optional[int] = None,
         top_p: Optional[float] = None,
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
-        watermark: bool =
-        decoder_input_details: bool = False,
+        watermark: Optional[bool] = None,
     ) -> Union[str, TextGenerationOutput, AsyncIterable[str], AsyncIterable[TextGenerationStreamOutput]]:
         """
         Given a prompt, generate the following text.
@@ -1654,38 +1826,46 @@ class AsyncInferenceClient:
             model (`str`, *optional*):
                 The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
                 Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
-
+            best_of (`int`, *optional*):
+                Generate best_of sequences and return the one if the highest token logprobs.
+            decoder_input_details (`bool`, *optional*):
+                Return the decoder input token logprobs and ids. You must set `details=True` as well for it to be taken
+                into account. Defaults to `False`.
+            do_sample (`bool`, *optional*):
                 Activate logits sampling
-
+            frequency_penalty (`float`, *optional*):
+                Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in
+                the text so far, decreasing the model's likelihood to repeat the same line verbatim.
+            grammar ([`TextGenerationInputGrammarType`], *optional*):
+                Grammar constraints. Can be either a JSONSchema or a regex.
+            max_new_tokens (`int`, *optional*):
                 Maximum number of generated tokens
-
-                Generate best_of sequences and return the one if the highest token logprobs
-            repetition_penalty (`float`):
+            repetition_penalty (`float`, *optional*):
                 The parameter for repetition penalty. 1.0 means no penalty. See [this
                 paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
-            return_full_text (`bool
+            return_full_text (`bool`, *optional*):
                 Whether to prepend the prompt to the generated text
-            seed (`int
+            seed (`int`, *optional*):
                 Random sampling seed
-            stop_sequences (`List[str]
+            stop_sequences (`List[str]`, *optional*):
                 Stop generating tokens if a member of `stop_sequences` is generated
-            temperature (`float
+            temperature (`float`, *optional*):
                 The value used to module the logits distribution.
-
+            top_n_tokens (`int`, *optional*):
+                Return information about the `top_n_tokens` most likely tokens at each generation step, instead of
+                just the sampled token.
+            top_k (`int`, *optional`):
                 The number of highest probability vocabulary tokens to keep for top-k-filtering.
-            top_p (`float`):
+            top_p (`float`, *optional`):
                 If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
                 higher are kept for generation.
-            truncate (`int`):
-                Truncate inputs tokens to the given size
-            typical_p (`float`):
+            truncate (`int`, *optional`):
+                Truncate inputs tokens to the given size.
+            typical_p (`float`, *optional`):
                 Typical Decoding mass
                 See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
-            watermark (`bool`):
+            watermark (`bool`, *optional`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
-            decoder_input_details (`bool`):
-                Return the decoder input token logprobs and ids. You must set `details=True` as well for it to be taken
-                into account. Defaults to `False`.
 
         Returns:
             `Union[str, TextGenerationOutput, Iterable[str], Iterable[TextGenerationStreamOutput]]`:
@@ -1738,10 +1918,10 @@ class AsyncInferenceClient:
             generated_tokens=12,
             seed=None,
             prefill=[
-
-
+                TextGenerationPrefillOutputToken(id=487, text='The', logprob=None),
+                TextGenerationPrefillOutputToken(id=53789, text=' hugging', logprob=-13.171875),
                 (...)
-
+                TextGenerationPrefillOutputToken(id=204, text=' ', logprob=-7.0390625)
             ],
             tokens=[
                 TokenElement(id=1425, text='100', logprob=-1.0175781, special=False),
@@ -1775,8 +1955,35 @@ class AsyncInferenceClient:
             logprob=-0.5703125,
             special=False),
             generated_text='100% open source and built to be easy to use.',
-            details=
+            details=TextGenerationStreamOutputStreamDetails(finish_reason='length', generated_tokens=12, seed=None)
         )
+
+        # Case 5: generate constrained output using grammar
+        >>> response = await client.text_generation(
+        ...     prompt="I saw a puppy a cat and a raccoon during my bike ride in the park",
+        ...     model="HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1",
+        ...     max_new_tokens=100,
+        ...     repetition_penalty=1.3,
+        ...     grammar={
+        ...         "type": "json",
+        ...         "value": {
+        ...             "properties": {
+        ...                 "location": {"type": "string"},
+        ...                 "activity": {"type": "string"},
+        ...                 "animals_seen": {"type": "integer", "minimum": 1, "maximum": 5},
+        ...                 "animals": {"type": "array", "items": {"type": "string"}},
+        ...             },
+        ...             "required": ["location", "activity", "animals_seen", "animals"],
+        ...         },
+        ...     },
+        ... )
+        >>> json.loads(response)
+        {
+            "activity": "bike riding",
+            "animals": ["puppy", "cat", "raccoon"],
+            "animals_seen": 3,
+            "location": "park"
+        }
         ```
         """
         if decoder_input_details and not details:
@@ -1787,41 +1994,48 @@ class AsyncInferenceClient:
             decoder_input_details = False
 
         # Build payload
+        parameters = {
+            "best_of": best_of,
+            "decoder_input_details": decoder_input_details,
+            "do_sample": do_sample,
+            "frequency_penalty": frequency_penalty,
+            "grammar": grammar,
+            "max_new_tokens": max_new_tokens,
+            "repetition_penalty": repetition_penalty,
+            "return_full_text": return_full_text,
+            "seed": seed,
+            "stop": stop_sequences if stop_sequences is not None else [],
+            "temperature": temperature,
+            "top_k": top_k,
+            "top_n_tokens": top_n_tokens,
+            "top_p": top_p,
+            "truncate": truncate,
+            "typical_p": typical_p,
+            "watermark": watermark,
+        }
+        parameters = {k: v for k, v in parameters.items() if v is not None}
         payload = {
             "inputs": prompt,
-            "parameters":
-                "best_of": best_of,
-                "decoder_input_details": decoder_input_details,
-                "details": details,
-                "do_sample": do_sample,
-                "max_new_tokens": max_new_tokens,
-                "repetition_penalty": repetition_penalty,
-                "return_full_text": return_full_text,
-                "seed": seed,
-                "stop": stop_sequences if stop_sequences is not None else [],
-                "temperature": temperature,
-                "top_k": top_k,
-                "top_p": top_p,
-                "truncate": truncate,
-                "typical_p": typical_p,
-                "watermark": watermark,
-            },
+            "parameters": parameters,
             "stream": stream,
         }
 
         # Remove some parameters if not a TGI server
-
-
+        unsupported_kwargs = _get_unsupported_text_generation_kwargs(model)
+        if len(unsupported_kwargs) > 0:
+            # The server does not support some parameters
+            # => means it is not a TGI server
+            # => remove unsupported parameters and warn the user
 
             ignored_parameters = []
-            for key in
-                if parameters
+            for key in unsupported_kwargs:
+                if parameters.get(key):
                     ignored_parameters.append(key)
-
+                    parameters.pop(key, None)
             if len(ignored_parameters) > 0:
                 warnings.warn(
-                    "API endpoint/model for text-generation is not served via TGI. Ignoring parameters"
-                    f" {ignored_parameters}.",
+                    "API endpoint/model for text-generation is not served via TGI. Ignoring following parameters:"
+                    f" {', '.join(ignored_parameters)}.",
                     UserWarning,
                 )
         if details:
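One behavioural consequence of the rewritten payload construction above: any generation parameter left at `None` is now dropped from the request body instead of being sent explicitly. A small sketch of that filtering with made-up values:

```python
# Illustrative values only; mirrors the dict comprehension added in the hunk above.
parameters = {
    "max_new_tokens": 100,
    "do_sample": False,    # manual default, kept because it is not None
    "temperature": None,   # unset -> dropped
    "grammar": None,       # unset -> dropped
}
parameters = {k: v for k, v in parameters.items() if v is not None}

payload = {"inputs": "My prompt", "parameters": parameters, "stream": False}
print(payload)
# {'inputs': 'My prompt', 'parameters': {'max_new_tokens': 100, 'do_sample': False}, 'stream': False}
```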
@@ -1841,28 +2055,32 @@ class AsyncInferenceClient:
         try:
             bytes_output = await self.post(json=payload, model=model, task="text-generation", stream=stream)  # type: ignore
         except _import_aiohttp().ClientResponseError as e:
-
-            if e.
-
+            match = MODEL_KWARGS_NOT_USED_REGEX.search(e.response_error_payload["error"])
+            if e.status == 400 and match:
+                unused_params = [kwarg.strip("' ") for kwarg in match.group(1).split(",")]
+                _set_unsupported_text_generation_kwargs(model, unused_params)
                 return await self.text_generation(  # type: ignore
                     prompt=prompt,
                     details=details,
                     stream=stream,
                     model=model,
+                    best_of=best_of,
+                    decoder_input_details=decoder_input_details,
                     do_sample=do_sample,
+                    frequency_penalty=frequency_penalty,
+                    grammar=grammar,
                     max_new_tokens=max_new_tokens,
-                    best_of=best_of,
                     repetition_penalty=repetition_penalty,
                     return_full_text=return_full_text,
                     seed=seed,
                     stop_sequences=stop_sequences,
                     temperature=temperature,
                     top_k=top_k,
+                    top_n_tokens=top_n_tokens,
                     top_p=top_p,
                     truncate=truncate,
                     typical_p=typical_p,
                     watermark=watermark,
-                    decoder_input_details=decoder_input_details,
                 )
             raise_text_generation_error(e)
 