huggingface-hub 0.23.4__py3-none-any.whl → 0.24.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of huggingface-hub has been flagged as possibly problematic (see the registry's advisory for more details).
- huggingface_hub/__init__.py +47 -15
- huggingface_hub/_commit_api.py +38 -8
- huggingface_hub/_inference_endpoints.py +11 -4
- huggingface_hub/_local_folder.py +22 -13
- huggingface_hub/_snapshot_download.py +12 -7
- huggingface_hub/_webhooks_server.py +3 -1
- huggingface_hub/commands/huggingface_cli.py +4 -3
- huggingface_hub/commands/repo_files.py +128 -0
- huggingface_hub/constants.py +12 -0
- huggingface_hub/file_download.py +127 -91
- huggingface_hub/hf_api.py +976 -341
- huggingface_hub/hf_file_system.py +30 -3
- huggingface_hub/hub_mixin.py +17 -6
- huggingface_hub/inference/_client.py +379 -43
- huggingface_hub/inference/_common.py +0 -2
- huggingface_hub/inference/_generated/_async_client.py +396 -49
- huggingface_hub/inference/_generated/types/__init__.py +4 -1
- huggingface_hub/inference/_generated/types/chat_completion.py +41 -21
- huggingface_hub/inference/_generated/types/feature_extraction.py +23 -5
- huggingface_hub/inference/_generated/types/text_generation.py +29 -0
- huggingface_hub/lfs.py +11 -6
- huggingface_hub/repocard_data.py +3 -3
- huggingface_hub/repository.py +6 -6
- huggingface_hub/serialization/__init__.py +8 -3
- huggingface_hub/serialization/_base.py +13 -16
- huggingface_hub/serialization/_tensorflow.py +4 -3
- huggingface_hub/serialization/_torch.py +399 -22
- huggingface_hub/utils/__init__.py +0 -1
- huggingface_hub/utils/_errors.py +1 -1
- huggingface_hub/utils/_fixes.py +14 -3
- huggingface_hub/utils/_paths.py +17 -6
- huggingface_hub/utils/_subprocess.py +0 -1
- huggingface_hub/utils/_telemetry.py +9 -1
- huggingface_hub/utils/endpoint_helpers.py +2 -186
- huggingface_hub/utils/sha.py +36 -1
- huggingface_hub/utils/tqdm.py +0 -1
- {huggingface_hub-0.23.4.dist-info → huggingface_hub-0.24.0.dist-info}/METADATA +12 -9
- {huggingface_hub-0.23.4.dist-info → huggingface_hub-0.24.0.dist-info}/RECORD +42 -42
- huggingface_hub/serialization/_numpy.py +0 -68
- {huggingface_hub-0.23.4.dist-info → huggingface_hub-0.24.0.dist-info}/LICENSE +0 -0
- {huggingface_hub-0.23.4.dist-info → huggingface_hub-0.24.0.dist-info}/WHEEL +0 -0
- {huggingface_hub-0.23.4.dist-info → huggingface_hub-0.24.0.dist-info}/entry_points.txt +0 -0
- {huggingface_hub-0.23.4.dist-info → huggingface_hub-0.24.0.dist-info}/top_level.txt +0 -0
Diff of huggingface_hub/inference/_generated/_async_client.py:
@@ -64,6 +64,7 @@ from huggingface_hub.inference._generated.types import (
     AudioClassificationOutputElement,
     AudioToAudioOutputElement,
     AutomaticSpeechRecognitionOutput,
+    ChatCompletionInputGrammarType,
     ChatCompletionInputTool,
     ChatCompletionInputToolTypeClass,
     ChatCompletionOutput,
@@ -89,13 +90,13 @@ from huggingface_hub.inference._generated.types import (
     ZeroShotClassificationOutputElement,
     ZeroShotImageClassificationOutputElement,
 )
-from huggingface_hub.inference._generated.types.chat_completion import ChatCompletionInputToolTypeEnum
 from huggingface_hub.inference._types import (
     ConversationalOutput,  # soon to be removed
 )
 from huggingface_hub.utils import (
     build_hf_headers,
 )
+from huggingface_hub.utils._deprecation import _deprecate_positional_args

 from .._common import _async_yield_from, _import_aiohttp

@@ -119,12 +120,16 @@ class AsyncInferenceClient:

    Args:
        model (`str`, `optional`):
-            The model to run inference with. Can be a model id hosted on the Hugging Face Hub, e.g. `
+            The model to run inference with. Can be a model id hosted on the Hugging Face Hub, e.g. `meta-llama/Meta-Llama-3-8B-Instruct`
            or a URL to a deployed Inference Endpoint. Defaults to None, in which case a recommended model is
            automatically selected for the task.
+            Note: for better compatibility with OpenAI's client, `model` has been aliased as `base_url`. Those 2
+            arguments are mutually exclusive and have the exact same behavior.
        token (`str` or `bool`, *optional*):
            Hugging Face token. Will default to the locally saved token if not provided.
            Pass `token=False` if you don't want to send your token to the server.
+            Note: for better compatibility with OpenAI's client, `token` has been aliased as `api_key`. Those 2
+            arguments are mutually exclusive and have the exact same behavior.
        timeout (`float`, `optional`):
            The maximum number of seconds to wait for a response from the server. Loading a new model in Inference
            API can take up to several minutes. Defaults to None, meaning it will loop until the server is available.
@@ -133,23 +138,52 @@ class AsyncInferenceClient:
            Values in this dictionary will override the default values.
        cookies (`Dict[str, str]`, `optional`):
            Additional cookies to send to the server.
+        base_url (`str`, `optional`):
+            Base URL to run inference. This is a duplicated argument from `model` to make [`InferenceClient`]
+            follow the same pattern as `openai.OpenAI` client. Cannot be used if `model` is set. Defaults to None.
+        api_key (`str`, `optional`):
+            Token to use for authentication. This is a duplicated argument from `token` to make [`InferenceClient`]
+            follow the same pattern as `openai.OpenAI` client. Cannot be used if `token` is set. Defaults to None.
    """

+    @_deprecate_positional_args(version="0.26")
    def __init__(
        self,
        model: Optional[str] = None,
+        *,
        token: Union[str, bool, None] = None,
        timeout: Optional[float] = None,
        headers: Optional[Dict[str, str]] = None,
        cookies: Optional[Dict[str, str]] = None,
+        proxies: Optional[Any] = None,
+        # OpenAI compatibility
+        base_url: Optional[str] = None,
+        api_key: Optional[str] = None,
    ) -> None:
+        if model is not None and base_url is not None:
+            raise ValueError(
+                "Received both `model` and `base_url` arguments. Please provide only one of them."
+                " `base_url` is an alias for `model` to make the API compatible with OpenAI's client."
+                " It has the exact same behavior as `model`."
+            )
+        if token is not None and api_key is not None:
+            raise ValueError(
+                "Received both `token` and `api_key` arguments. Please provide only one of them."
+                " `api_key` is an alias for `token` to make the API compatible with OpenAI's client."
+                " It has the exact same behavior as `token`."
+            )
+
        self.model: Optional[str] = model
-        self.token: Union[str, bool, None] = token
-        self.headers = CaseInsensitiveDict(build_hf_headers(token=token))  #
+        self.token: Union[str, bool, None] = token or api_key
+        self.headers = CaseInsensitiveDict(build_hf_headers(token=self.token))  # 'authorization' + 'user-agent'
        if headers is not None:
            self.headers.update(headers)
        self.cookies = cookies
        self.timeout = timeout
+        self.proxies = proxies
+
+        # OpenAI compatibility
+        self.base_url = base_url

    def __repr__(self):
        return f"<InferenceClient(model='{self.model if self.model else ''}', timeout={self.timeout})>"
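Taken together, the hunk above makes the constructor OpenAI-style: everything after `model` becomes keyword-only (positional usage is flagged for removal in 0.26), `base_url`/`api_key` act as mutually exclusive aliases for `model`/`token`, and `proxies` is forwarded to aiohttp's `proxy=` on each request. A minimal usage sketch; the endpoint URL, token value, and proxy address are placeholders, not values from the release:

```py
from huggingface_hub import AsyncInferenceClient

# Classic usage, now with keyword arguments only.
client = AsyncInferenceClient(model="meta-llama/Meta-Llama-3-8B-Instruct", token="hf_***")

# OpenAI-style aliases: base_url/api_key behave exactly like model/token.
openai_style = AsyncInferenceClient(
    base_url="https://my-endpoint.example/v1/",   # placeholder URL
    api_key="hf_***",                             # placeholder token
    proxies="http://localhost:3128",              # forwarded as aiohttp's `proxy=`
)

# Mixing an alias with its original argument raises a ValueError.
try:
    AsyncInferenceClient(model="gpt2", base_url="https://my-endpoint.example")
except ValueError as err:
    print(err)
```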
@@ -250,7 +284,7 @@ class AsyncInferenceClient:
                )

                try:
-                    response = await client.post(url, json=json, data=data_as_binary)
+                    response = await client.post(url, json=json, data=data_as_binary, proxy=self.proxies)
                    response_error_payload = None
                    if response.status != 200:
                        try:
@@ -284,11 +318,16 @@ class AsyncInferenceClient:
                            ) from error
                        # ...or wait 1s and retry
                        logger.info(f"Waiting for model to be loaded on the server: {error}")
+                        if "X-wait-for-model" not in headers and url.startswith(INFERENCE_ENDPOINT):
+                            headers["X-wait-for-model"] = "1"
                        time.sleep(1)
                        if timeout is not None:
                            timeout = max(self.timeout - (time.time() - t0), 1)  # type: ignore
                        continue
                    raise error
+                except Exception:
+                    await client.close()
+                    raise

    async def audio_classification(
        self,
@@ -427,10 +466,11 @@ class AsyncInferenceClient:
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[float] = None,
+        response_format: Optional[ChatCompletionInputGrammarType] = None,
        seed: Optional[int] = None,
        stop: Optional[List[str]] = None,
        temperature: Optional[float] = None,
-        tool_choice: Optional[Union[ChatCompletionInputToolTypeClass,
+        tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, str]] = None,
        tool_prompt: Optional[str] = None,
        tools: Optional[List[ChatCompletionInputTool]] = None,
        top_logprobs: Optional[int] = None,
@@ -450,10 +490,11 @@ class AsyncInferenceClient:
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[float] = None,
+        response_format: Optional[ChatCompletionInputGrammarType] = None,
        seed: Optional[int] = None,
        stop: Optional[List[str]] = None,
        temperature: Optional[float] = None,
-        tool_choice: Optional[Union[ChatCompletionInputToolTypeClass,
+        tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, str]] = None,
        tool_prompt: Optional[str] = None,
        tools: Optional[List[ChatCompletionInputTool]] = None,
        top_logprobs: Optional[int] = None,
@@ -473,10 +514,11 @@ class AsyncInferenceClient:
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[float] = None,
+        response_format: Optional[ChatCompletionInputGrammarType] = None,
        seed: Optional[int] = None,
        stop: Optional[List[str]] = None,
        temperature: Optional[float] = None,
-        tool_choice: Optional[Union[ChatCompletionInputToolTypeClass,
+        tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, str]] = None,
        tool_prompt: Optional[str] = None,
        tools: Optional[List[ChatCompletionInputTool]] = None,
        top_logprobs: Optional[int] = None,
@@ -496,10 +538,11 @@ class AsyncInferenceClient:
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[float] = None,
+        response_format: Optional[ChatCompletionInputGrammarType] = None,
        seed: Optional[int] = None,
        stop: Optional[List[str]] = None,
        temperature: Optional[float] = None,
-        tool_choice: Optional[Union[ChatCompletionInputToolTypeClass,
+        tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, str]] = None,
        tool_prompt: Optional[str] = None,
        tools: Optional[List[ChatCompletionInputTool]] = None,
        top_logprobs: Optional[int] = None,
@@ -510,11 +553,10 @@ class AsyncInferenceClient:

        <Tip>

-
-
-
-
-        to install it.
+        The `client.chat_completion` method is aliased as `client.chat.completions.create` for compatibility with OpenAI's client.
+        Inputs and outputs are strictly the same and using either syntax will yield the same results.
+        Check out the [Inference guide](https://huggingface.co/docs/huggingface_hub/guides/inference#openai-compatibility)
+        for more details about OpenAI's compatibility.

        </Tip>

@@ -525,6 +567,9 @@ class AsyncInferenceClient:
                The model to use for chat-completion. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
                Inference Endpoint. If not provided, the default recommended model for chat-based text-generation will be used.
                See https://huggingface.co/tasks/text-generation for more details.
+
+                If `model` is a model ID, it is passed to the server as the `model` parameter. If you want to define a
+                custom URL while setting `model` in the request payload, you must set `base_url` when initializing [`InferenceClient`].
            frequency_penalty (`float`, *optional*):
                Penalizes new tokens based on their existing frequency
                in the text so far. Range: [-2.0, 2.0]. Defaults to 0.0.
@@ -544,6 +589,8 @@ class AsyncInferenceClient:
            presence_penalty (`float`, *optional*):
                Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the
                text so far, increasing the model's likelihood to talk about new topics.
+            response_format ([`ChatCompletionInputGrammarType`], *optional*):
+                Grammar constraints. Can be either a JSONSchema or a regex.
            seed (Optional[`int`], *optional*):
                Seed for reproducible control flow. Defaults to None.
            stop (Optional[`str`], *optional*):
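The new `response_format` parameter maps to TGI's grammar-constrained generation: either a JSON schema or a regular expression. A full JSON-schema example appears in a later hunk of this diff; for the regex variant, here is a hedged sketch (the pattern is illustrative and a TGI-backed model is assumed):

```py
# Must be run in an async context
from huggingface_hub import AsyncInferenceClient

client = AsyncInferenceClient("meta-llama/Meta-Llama-3-70B-Instruct")
response = await client.chat_completion(
    messages=[{"role": "user", "content": "Give me a date in ISO 8601 format."}],
    # Constrain the output to a date-like string (illustrative regex).
    response_format={"type": "regex", "value": r"\d{4}-\d{2}-\d{2}"},
    max_tokens=20,
)
print(response.choices[0].message.content)
```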
@@ -561,7 +608,7 @@ class AsyncInferenceClient:
            top_p (`float`, *optional*):
                Fraction of the most likely next words to sample from.
                Must be between 0 and 1. Defaults to 1.0.
-            tool_choice ([`ChatCompletionInputToolTypeClass`] or
+            tool_choice ([`ChatCompletionInputToolTypeClass`] or `str`, *optional*):
                The tool to use for the completion. Defaults to "auto".
            tool_prompt (`str`, *optional*):
                A prompt to be appended before the tools.
@@ -570,7 +617,7 @@ class AsyncInferenceClient:
                provide a list of functions the model may generate JSON inputs for.

        Returns:
-            [`ChatCompletionOutput] or Iterable of [`ChatCompletionStreamOutput`]:
+            [`ChatCompletionOutput`] or Iterable of [`ChatCompletionStreamOutput`]:
                Generated text returned from the server:
                - if `stream=False`, the generated text is returned as a [`ChatCompletionOutput`] (default).
                - if `stream=True`, the generated text is returned token by token as a sequence of [`ChatCompletionStreamOutput`].
@@ -585,10 +632,9 @@ class AsyncInferenceClient:

        ```py
        # Must be run in an async context
-        # Chat example
        >>> from huggingface_hub import AsyncInferenceClient
        >>> messages = [{"role": "user", "content": "What is the capital of France?"}]
-        >>> client = AsyncInferenceClient("
+        >>> client = AsyncInferenceClient("meta-llama/Meta-Llama-3-8B-Instruct")
        >>> await client.chat_completion(messages, max_tokens=100)
        ChatCompletionOutput(
            choices=[
@@ -596,26 +642,75 @@ class AsyncInferenceClient:
                    finish_reason='eos_token',
                    index=0,
                    message=ChatCompletionOutputMessage(
-
-
+                        role='assistant',
+                        content='The capital of France is Paris.',
+                        name=None,
+                        tool_calls=None
+                    ),
+                    logprobs=None
                )
            ],
-            created=
+            created=1719907176,
+            id='',
+            model='meta-llama/Meta-Llama-3-8B-Instruct',
+            object='text_completion',
+            system_fingerprint='2.0.4-sha-f426a33',
+            usage=ChatCompletionOutputUsage(
+                completion_tokens=8,
+                prompt_tokens=17,
+                total_tokens=25
+            )
        )
+        ```

+        Example (stream=True):
+        ```py
+        # Must be run in an async context
+        >>> from huggingface_hub import AsyncInferenceClient
+        >>> messages = [{"role": "user", "content": "What is the capital of France?"}]
+        >>> client = AsyncInferenceClient("meta-llama/Meta-Llama-3-8B-Instruct")
        >>> async for token in await client.chat_completion(messages, max_tokens=10, stream=True):
        ...     print(token)
        ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content='The', role='assistant'), index=0, finish_reason=None)], created=1710498504)
        ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content=' capital', role='assistant'), index=0, finish_reason=None)], created=1710498504)
        (...)
        ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content=' may', role='assistant'), index=0, finish_reason=None)], created=1710498504)
+        ```

-
+        Example using OpenAI's syntax:
+        ```py
+        # Must be run in an async context
+        # instead of `from openai import OpenAI`
+        from huggingface_hub import AsyncInferenceClient
+
+        # instead of `client = OpenAI(...)`
+        client = AsyncInferenceClient(
+            base_url=...,
+            api_key=...,
+        )
+
+        output = await client.chat.completions.create(
+            model="meta-llama/Meta-Llama-3-8B-Instruct",
+            messages=[
+                {"role": "system", "content": "You are a helpful assistant."},
+                {"role": "user", "content": "Count to 10"},
+            ],
+            stream=True,
+            max_tokens=1024,
+        )
+
+        for chunk in output:
+            print(chunk.choices[0].delta.content)
+        ```
+
+        Example using tools:
+        ```py
+        # Must be run in an async context
        >>> client = AsyncInferenceClient("meta-llama/Meta-Llama-3-70B-Instruct")
        >>> messages = [
        ...     {
        ...         "role": "system",
-        ...         "content": "Don't make assumptions about what values to plug into functions. Ask
+        ...         "content": "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.",
        ...     },
        ...     {
        ...         "role": "user",
@@ -691,9 +786,44 @@ class AsyncInferenceClient:
            description=None
        )
        ```
+
+        Example using response_format:
+        ```py
+        # Must be run in an async context
+        >>> from huggingface_hub import AsyncInferenceClient
+        >>> client = AsyncInferenceClient("meta-llama/Meta-Llama-3-70B-Instruct")
+        >>> messages = [
+        ...     {
+        ...         "role": "user",
+        ...         "content": "I saw a puppy a cat and a raccoon during my bike ride in the park. What did I saw and when?",
+        ...     },
+        ... ]
+        >>> response_format = {
+        ...     "type": "json",
+        ...     "value": {
+        ...         "properties": {
+        ...             "location": {"type": "string"},
+        ...             "activity": {"type": "string"},
+        ...             "animals_seen": {"type": "integer", "minimum": 1, "maximum": 5},
+        ...             "animals": {"type": "array", "items": {"type": "string"}},
+        ...         },
+        ...         "required": ["location", "activity", "animals_seen", "animals"],
+        ...     },
+        ... }
+        >>> response = await client.chat_completion(
+        ...     messages=messages,
+        ...     response_format=response_format,
+        ...     max_tokens=500,
+        )
+        >>> response.choices[0].message.content
+        '{\n\n"activity": "bike ride",\n"animals": ["puppy", "cat", "raccoon"],\n"animals_seen": 3,\n"location": "park"}'
+        ```
        """
-        #
-
+        # Determine model
+        # `self.xxx` takes precedence over the method argument only in `chat_completion`
+        # since `chat_completion(..., model=xxx)` is also a payload parameter for the
+        # server, we need to handle it differently
+        model = self.base_url or self.model or model or self.get_recommended_model("text-generation")

        if _is_chat_completion_server(model):
            # First, let's consider the server has a `/v1/chat/completions` endpoint.
@@ -702,11 +832,19 @@ class AsyncInferenceClient:
            if not model_url.endswith("/chat/completions"):
                model_url += "/v1/chat/completions"

+            # `model` is sent in the payload. Not used by the server but can be useful for debugging/routing.
+            if not model.startswith("http") and model.count("/") == 1:
+                # If it's a ID on the Hub => use it
+                model_id = model
+            else:
+                # Otherwise, we use a random string
+                model_id = "tgi"
+
            try:
                data = await self.post(
                    model=model_url,
                    json=dict(
-                        model=
+                        model=model_id,
                        messages=messages,
                        frequency_penalty=frequency_penalty,
                        logit_bias=logit_bias,
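The `model_id` resolution above is small but worth spelling out: the payload's `model` field is only kept when the target is identified by a Hub id (exactly one `/`, no scheme); for URLs the server ignores it, so the placeholder string `"tgi"` is sent. The same rule, restated as a standalone sketch for clarity:

```py
def payload_model_id(model: str) -> str:
    # Hub ids look like "org/name"; URLs and anything else get the "tgi" placeholder.
    if not model.startswith("http") and model.count("/") == 1:
        return model
    return "tgi"

assert payload_model_id("meta-llama/Meta-Llama-3-8B-Instruct") == "meta-llama/Meta-Llama-3-8B-Instruct"
assert payload_model_id("https://my-endpoint.example/v1/chat/completions") == "tgi"  # placeholder URL
```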
@@ -714,6 +852,7 @@ class AsyncInferenceClient:
                        max_tokens=max_tokens,
                        n=n,
                        presence_penalty=presence_penalty,
+                        response_format=response_format,
                        seed=seed,
                        stop=stop,
                        temperature=temperature,
@@ -765,6 +904,11 @@ class AsyncInferenceClient:
                    "Tools are not supported by the model. This is due to the model not been served by a "
                    "Text-Generation-Inference server. The provided tool parameters will be ignored."
                )
+            if response_format is not None:
+                warnings.warn(
+                    "Response format is not supported by the model. This is due to the model not been served by a "
+                    "Text-Generation-Inference server. The provided response format will be ignored."
+                )

            # generate response
            text_generation_output = await self.text_generation(
@@ -783,7 +927,6 @@ class AsyncInferenceClient:
            return ChatCompletionOutput(
                id="dummy",
                model="dummy",
-                object="dummy",
                system_fingerprint="dummy",
                usage=None,  # type: ignore # set to `None` as we don't want to provide false information
                created=int(time.time()),
@@ -850,7 +993,7 @@ class AsyncInferenceClient:
        >>> client = AsyncInferenceClient()
        >>> output = await client.conversational("Hi, who are you?")
        >>> output
-        {'generated_text': 'I am the one who knocks.', 'conversation': {'generated_responses': ['I am the one who knocks.'], 'past_user_inputs': ['Hi, who are you?']}, 'warnings': ['Setting `pad_token_id` to `eos_token_id`:50256
+        {'generated_text': 'I am the one who knocks.', 'conversation': {'generated_responses': ['I am the one who knocks.'], 'past_user_inputs': ['Hi, who are you?']}, 'warnings': ['Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.']}
        >>> await client.conversational(
        ...     "Wow, that's scary!",
        ...     generated_responses=output["conversation"]["generated_responses"],
@@ -915,7 +1058,16 @@ class AsyncInferenceClient:
        response = await self.post(json=payload, model=model, task="document-question-answering")
        return DocumentQuestionAnsweringOutputElement.parse_obj_as_list(response)

-    async def feature_extraction(
+    async def feature_extraction(
+        self,
+        text: str,
+        *,
+        normalize: Optional[bool] = None,
+        prompt_name: Optional[str] = None,
+        truncate: Optional[bool] = None,
+        truncation_direction: Optional[Literal["Left", "Right"]] = None,
+        model: Optional[str] = None,
+    ) -> "np.ndarray":
        """
        Generate embeddings for a given text.

@@ -926,6 +1078,20 @@ class AsyncInferenceClient:
                The model to use for the conversational task. Can be a model ID hosted on the Hugging Face Hub or a URL to
                a deployed Inference Endpoint. If not provided, the default recommended conversational model will be used.
                Defaults to None.
+            normalize (`bool`, *optional*):
+                Whether to normalize the embeddings or not. Defaults to None.
+                Only available on server powered by Text-Embedding-Inference.
+            prompt_name (`str`, *optional*):
+                The name of the prompt that should be used by for encoding. If not set, no prompt will be applied.
+                Must be a key in the `Sentence Transformers` configuration `prompts` dictionary.
+                For example if ``prompt_name`` is "query" and the ``prompts`` is {"query": "query: ",...},
+                then the sentence "What is the capital of France?" will be encoded as "query: What is the capital of France?"
+                because the prompt text will be prepended before any text to encode.
+            truncate (`bool`, *optional*):
+                Whether to truncate the embeddings or not. Defaults to None.
+                Only available on server powered by Text-Embedding-Inference.
+            truncation_direction (`Literal["Left", "Right"]`, *optional*):
+                Which side of the input should be truncated when `truncate=True` is passed.

        Returns:
            `np.ndarray`: The embedding representing the input text as a float32 numpy array.
@@ -948,7 +1114,16 @@ class AsyncInferenceClient:
        [ 0.28552425, -0.928395  , -1.2077185 , ...,  0.76810825, -2.1069427 ,  0.6236161 ]], dtype=float32)
        ```
        """
-
+        payload: Dict = {"inputs": text}
+        if normalize is not None:
+            payload["normalize"] = normalize
+        if prompt_name is not None:
+            payload["prompt_name"] = prompt_name
+        if truncate is not None:
+            payload["truncate"] = truncate
+        if truncation_direction is not None:
+            payload["truncation_direction"] = truncation_direction
+        response = await self.post(json=payload, model=model, task="feature-extraction")
        np = _import_numpy()
        return np.array(_bytes_to_dict(response), dtype="float32")

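Because the payload above only includes the new keys when they are explicitly set, older servers keep working unchanged. A hedged usage sketch; the model id is illustrative, and `normalize`/`truncate` assume a Text-Embedding-Inference backend:

```py
# Must be run in an async context
from huggingface_hub import AsyncInferenceClient

client = AsyncInferenceClient("sentence-transformers/all-MiniLM-L6-v2")  # illustrative model id
embedding = await client.feature_extraction(
    "What is the capital of France?",
    normalize=True,                 # TEI-only option
    truncate=True,                  # TEI-only option
    truncation_direction="Right",
)
print(embedding.shape)  # float32 numpy array
```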
@@ -1192,7 +1367,8 @@ class AsyncInferenceClient:
        ```
        """
        response = await self.post(data=image, model=model, task="image-to-text")
-
+        output = ImageToTextOutput.parse_obj(response)
+        return output[0] if isinstance(output, list) else output

    async def list_deployed_models(
        self, frameworks: Union[None, str, Literal["all"], List[str]] = None
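The `image_to_text` change above normalizes the server response: some backends return a single object and others a one-element list, and the method now returns an `ImageToTextOutput` in both cases. A hedged usage sketch (the image path is a placeholder):

```py
# Must be run in an async context
from huggingface_hub import AsyncInferenceClient

client = AsyncInferenceClient()
output = await client.image_to_text("cat.jpg")  # placeholder local file; bytes or a URL also work
print(output.generated_text)
```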
@@ -1643,6 +1819,7 @@ class AsyncInferenceClient:
        stream: Literal[False] = ...,
        model: Optional[str] = None,
        # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
+        adapter_id: Optional[str] = None,
        best_of: Optional[int] = None,
        decoder_input_details: Optional[bool] = None,
        do_sample: Optional[bool] = False,  # Manual default value
@@ -1671,6 +1848,7 @@ class AsyncInferenceClient:
        stream: Literal[False] = ...,
        model: Optional[str] = None,
        # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
+        adapter_id: Optional[str] = None,
        best_of: Optional[int] = None,
        decoder_input_details: Optional[bool] = None,
        do_sample: Optional[bool] = False,  # Manual default value
@@ -1699,6 +1877,7 @@ class AsyncInferenceClient:
        stream: Literal[True] = ...,
        model: Optional[str] = None,
        # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
+        adapter_id: Optional[str] = None,
        best_of: Optional[int] = None,
        decoder_input_details: Optional[bool] = None,
        do_sample: Optional[bool] = False,  # Manual default value
@@ -1727,6 +1906,7 @@ class AsyncInferenceClient:
        stream: Literal[True] = ...,
        model: Optional[str] = None,
        # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
+        adapter_id: Optional[str] = None,
        best_of: Optional[int] = None,
        decoder_input_details: Optional[bool] = None,
        do_sample: Optional[bool] = False,  # Manual default value
@@ -1755,6 +1935,7 @@ class AsyncInferenceClient:
        stream: bool = ...,
        model: Optional[str] = None,
        # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
+        adapter_id: Optional[str] = None,
        best_of: Optional[int] = None,
        decoder_input_details: Optional[bool] = None,
        do_sample: Optional[bool] = False,  # Manual default value
@@ -1782,6 +1963,7 @@ class AsyncInferenceClient:
        stream: bool = False,
        model: Optional[str] = None,
        # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
+        adapter_id: Optional[str] = None,
        best_of: Optional[int] = None,
        decoder_input_details: Optional[bool] = None,
        do_sample: Optional[bool] = False,  # Manual default value
@@ -1812,6 +1994,13 @@ class AsyncInferenceClient:

        To learn more about the TGI project, please refer to https://github.com/huggingface/text-generation-inference.

+        <Tip>
+
+        If you want to generate a response from chat messages, you should use the [`InferenceClient.chat_completion`] method.
+        It accepts a list of messages instead of a single text prompt and handles the chat templating for you.
+
+        </Tip>
+
        Args:
            prompt (`str`):
                Input text.
@@ -1826,6 +2015,8 @@ class AsyncInferenceClient:
            model (`str`, *optional*):
                The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
                Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
+            adapter_id (`str`, *optional*):
+                Lora adapter id.
            best_of (`int`, *optional*):
                Generate best_of sequences and return the one if the highest token logprobs.
            decoder_input_details (`bool`, *optional*):
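`adapter_id` is forwarded as-is in the request parameters, so it only has an effect when the backing TGI server was launched with the corresponding LoRA adapters loaded. A hedged sketch; the endpoint URL and adapter id are placeholders:

```py
# Must be run in an async context
from huggingface_hub import AsyncInferenceClient

client = AsyncInferenceClient("https://my-tgi-endpoint.example")  # placeholder TGI endpoint URL
result = await client.text_generation(
    "Summarize the following text: ...",
    adapter_id="my-org/my-lora-adapter",  # placeholder LoRA adapter id, must be preloaded server-side
    max_new_tokens=64,
)
print(result)
```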
@@ -1893,7 +2084,7 @@ class AsyncInferenceClient:
        >>> await client.text_generation("The huggingface_hub library is ", max_new_tokens=12)
        '100% open source and built to be easy to use.'

-        # Case 2: iterate over the generated tokens. Useful
+        # Case 2: iterate over the generated tokens. Useful for large generation.
        >>> async for token in await client.text_generation("The huggingface_hub library is ", max_new_tokens=12, stream=True):
        ...     print(token)
        100
@@ -1995,6 +2186,7 @@ class AsyncInferenceClient:

        # Build payload
        parameters = {
+            "adapter_id": adapter_id,
            "best_of": best_of,
            "decoder_input_details": decoder_input_details,
            "details": details,
@@ -2065,6 +2257,7 @@ class AsyncInferenceClient:
            details=details,
            stream=stream,
            model=model,
+            adapter_id=adapter_id,
            best_of=best_of,
            decoder_input_details=decoder_input_details,
            do_sample=do_sample,
@@ -2089,7 +2282,12 @@ class AsyncInferenceClient:
        if stream:
            return _async_stream_text_generation_response(bytes_output, details)  # type: ignore

-        data = _bytes_to_dict(bytes_output)
+        data = _bytes_to_dict(bytes_output)  # type: ignore[arg-type]
+
+        # Data can be a single element (dict) or an iterable of dicts where we select the first element of.
+        if isinstance(data, list):
+            data = data[0]
+
        return TextGenerationOutput.parse_obj_as_instance(data) if details else data["generated_text"]

    async def text_to_image(
@@ -2377,7 +2575,13 @@ class AsyncInferenceClient:
        return VisualQuestionAnsweringOutputElement.parse_obj_as_list(response)

    async def zero_shot_classification(
-        self,
+        self,
+        text: str,
+        labels: List[str],
+        *,
+        multi_label: bool = False,
+        hypothesis_template: Optional[str] = None,
+        model: Optional[str] = None,
    ) -> List[ZeroShotClassificationOutputElement]:
        """
        Provide as input a text and a set of candidate labels to classify the input text.
@@ -2386,9 +2590,15 @@ class AsyncInferenceClient:
            text (`str`):
                The input text to classify.
            labels (`List[str]`):
-                List of string
+                List of strings. Each string is the verbalization of a possible label for the input text.
            multi_label (`bool`):
-                Boolean
+                Boolean. If True, the probability for each label is evaluated independently and multiple labels can have a probability close to 1 simultaneously or all probabilities can be close to 0.
+                If False, the labels are considered mutually exclusive and the probability over all labels always sums to 1. Defaults to False.
+            hypothesis_template (`str`, *optional*):
+                A template sentence string with curly brackets to which the label strings are added. The label strings are added at the position of the curly brackets "{}".
+                Zero-shot classifiers are based on NLI models, which evaluate if a hypothesis is entailed in another text or not.
+                For example, with hypothesis_template="This text is about {}." and labels=["economics", "politics"], the system internally creates the two hypotheses "This text is about economics." and "This text is about politics.".
+                The model then evaluates for both hypotheses if they are entailed in the provided `text` or not.
            model (`str`, *optional*):
                The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
                Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
@@ -2402,15 +2612,15 @@ class AsyncInferenceClient:
            `aiohttp.ClientResponseError`:
                If the request fails with an HTTP error status code other than HTTP 503.

-        Example
+        Example with `multi_label=False`:
        ```py
        # Must be run in an async context
        >>> from huggingface_hub import AsyncInferenceClient
        >>> client = AsyncInferenceClient()
        >>> text = (
-        ...     "A new model offers an explanation
+        ...     "A new model offers an explanation for how the Galilean satellites formed around the solar system's"
        ...     "largest world. Konstantin Batygin did not set out to solve one of the solar system's most puzzling"
-        ...     " mysteries when he went
+        ...     " mysteries when he went for a run up a hill in Nice, France."
        ...     )
        >>> labels = ["space & cosmos", "scientific discovery", "microbiology", "robots", "archeology"]
        >>> await client.zero_shot_classification(text, labels)
@@ -2430,21 +2640,38 @@ class AsyncInferenceClient:
            ZeroShotClassificationOutputElement(label='robots', score=0.00030448526376858354),
        ]
        ```
+
+        Example with `multi_label=True` and a custom `hypothesis_template`:
+        ```py
+        # Must be run in an async context
+        >>> from huggingface_hub import AsyncInferenceClient
+        >>> client = AsyncInferenceClient()
+        >>> await client.zero_shot_classification(
+        ...    text="I really like our dinner and I'm very happy. I don't like the weather though.",
+        ...    labels=["positive", "negative", "pessimistic", "optimistic"],
+        ...    multi_label=True,
+        ...    hypothesis_template="This text is {} towards the weather"
+        ... )
+        [
+            ZeroShotClassificationOutputElement(label='negative', score=0.9231801629066467),
+            ZeroShotClassificationOutputElement(label='pessimistic', score=0.8760990500450134),
+            ZeroShotClassificationOutputElement(label='optimistic', score=0.0008674879791215062),
+            ZeroShotClassificationOutputElement(label='positive', score=0.0005250611575320363)
+        ]
+        ```
        """
-
-
-
+
+        parameters = {"candidate_labels": labels, "multi_label": multi_label}
+        if hypothesis_template is not None:
+            parameters["hypothesis_template"] = hypothesis_template

        response = await self.post(
            json={
                "inputs": text,
-                "parameters":
-                "candidate_labels": ",".join(labels),
-                "multi_label": multi_label,
-                },
+                "parameters": parameters,
            },
-            model=model,
            task="zero-shot-classification",
+            model=model,
        )
        output = _bytes_to_dict(response)
        return [
@@ -2501,7 +2728,7 @@ class AsyncInferenceClient:
        return ZeroShotImageClassificationOutputElement.parse_obj_as_list(response)

    def _resolve_url(self, model: Optional[str] = None, task: Optional[str] = None) -> str:
-        model = model or self.model
+        model = model or self.model or self.base_url

        # If model is already a URL, ignore `task` and return directly
        if model is not None and (model.startswith("http://") or model.startswith("https://")):
@@ -2554,6 +2781,99 @@ class AsyncInferenceClient:
            )
        return model

+    async def get_endpoint_info(self, *, model: Optional[str] = None) -> Dict[str, Any]:
+        """
+        Get information about the deployed endpoint.
+
+        This endpoint is only available on endpoints powered by Text-Generation-Inference (TGI) or Text-Embedding-Inference (TEI).
+        Endpoints powered by `transformers` return an empty payload.
+
+        Args:
+            model (`str`, *optional*):
+                The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
+                Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
+
+        Returns:
+            `Dict[str, Any]`: Information about the endpoint.
+
+        Example:
+        ```py
+        # Must be run in an async context
+        >>> from huggingface_hub import AsyncInferenceClient
+        >>> client = AsyncInferenceClient("meta-llama/Meta-Llama-3-70B-Instruct")
+        >>> await client.get_endpoint_info()
+        {
+            'model_id': 'meta-llama/Meta-Llama-3-70B-Instruct',
+            'model_sha': None,
+            'model_dtype': 'torch.float16',
+            'model_device_type': 'cuda',
+            'model_pipeline_tag': None,
+            'max_concurrent_requests': 128,
+            'max_best_of': 2,
+            'max_stop_sequences': 4,
+            'max_input_length': 8191,
+            'max_total_tokens': 8192,
+            'waiting_served_ratio': 0.3,
+            'max_batch_total_tokens': 1259392,
+            'max_waiting_tokens': 20,
+            'max_batch_size': None,
+            'validation_workers': 32,
+            'max_client_batch_size': 4,
+            'version': '2.0.2',
+            'sha': 'dccab72549635c7eb5ddb17f43f0b7cdff07c214',
+            'docker_label': 'sha-dccab72'
+        }
+        ```
+        """
+        model = model or self.model
+        if model is None:
+            raise ValueError("Model id not provided.")
+        if model.startswith(("http://", "https://")):
+            url = model.rstrip("/") + "/info"
+        else:
+            url = f"{INFERENCE_ENDPOINT}/models/{model}/info"
+
+        async with _import_aiohttp().ClientSession(headers=self.headers) as client:
+            response = await client.get(url)
+            response.raise_for_status()
+            return await response.json()
+
+    async def health_check(self, model: Optional[str] = None) -> bool:
+        """
+        Check the health of the deployed endpoint.
+
+        Health check is only available with Inference Endpoints powered by Text-Generation-Inference (TGI) or Text-Embedding-Inference (TEI).
+        For Inference API, please use [`InferenceClient.get_model_status`] instead.
+
+        Args:
+            model (`str`, *optional*):
+                URL of the Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
+
+        Returns:
+            `bool`: True if everything is working fine.
+
+        Example:
+        ```py
+        # Must be run in an async context
+        >>> from huggingface_hub import AsyncInferenceClient
+        >>> client = AsyncInferenceClient("https://jzgu0buei5.us-east-1.aws.endpoints.huggingface.cloud")
+        >>> await client.health_check()
+        True
+        ```
+        """
+        model = model or self.model
+        if model is None:
+            raise ValueError("Model id not provided.")
+        if not model.startswith(("http://", "https://")):
+            raise ValueError(
+                "Model must be an Inference Endpoint URL. For serverless Inference API, please use `InferenceClient.get_model_status`."
+            )
+        url = model.rstrip("/") + "/health"
+
+        async with _import_aiohttp().ClientSession(headers=self.headers) as client:
+            response = await client.get(url)
+            return response.status == 200
+
    async def get_model_status(self, model: Optional[str] = None) -> ModelStatus:
        """
        Get the status of a model hosted on the Inference API.
@@ -2581,7 +2901,7 @@ class AsyncInferenceClient:
        # Must be run in an async context
        >>> from huggingface_hub import AsyncInferenceClient
        >>> client = AsyncInferenceClient()
-        >>> await client.get_model_status("
+        >>> await client.get_model_status("meta-llama/Meta-Llama-3-8B-Instruct")
        ModelStatus(loaded=True, state='Loaded', compute_type='gpu', framework='text-generation-inference')
        ```
        """
@@ -2606,3 +2926,30 @@ class AsyncInferenceClient:
            compute_type=response_data["compute_type"],
            framework=response_data["framework"],
        )
+
+    @property
+    def chat(self) -> "ProxyClientChat":
+        return ProxyClientChat(self)
+
+
+class _ProxyClient:
+    """Proxy class to be able to call `client.chat.completion.create(...)` as OpenAI client."""
+
+    def __init__(self, client: AsyncInferenceClient):
+        self._client = client
+
+
+class ProxyClientChat(_ProxyClient):
+    """Proxy class to be able to call `client.chat.completion.create(...)` as OpenAI client."""
+
+    @property
+    def completions(self) -> "ProxyClientChatCompletions":
+        return ProxyClientChatCompletions(self._client)
+
+
+class ProxyClientChatCompletions(_ProxyClient):
+    """Proxy class to be able to call `client.chat.completion.create(...)` as OpenAI client."""
+
+    @property
+    def create(self):
+        return self._client.chat_completion