huggingface-hub 0.23.5__py3-none-any.whl → 0.24.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of huggingface-hub has been flagged as potentially problematic.
- huggingface_hub/__init__.py +47 -15
- huggingface_hub/_commit_api.py +38 -8
- huggingface_hub/_inference_endpoints.py +11 -4
- huggingface_hub/_local_folder.py +22 -13
- huggingface_hub/_snapshot_download.py +12 -7
- huggingface_hub/_webhooks_server.py +3 -1
- huggingface_hub/commands/huggingface_cli.py +4 -3
- huggingface_hub/commands/repo_files.py +128 -0
- huggingface_hub/constants.py +12 -0
- huggingface_hub/file_download.py +127 -91
- huggingface_hub/hf_api.py +976 -341
- huggingface_hub/hf_file_system.py +30 -3
- huggingface_hub/inference/_client.py +408 -147
- huggingface_hub/inference/_common.py +25 -63
- huggingface_hub/inference/_generated/_async_client.py +425 -153
- huggingface_hub/inference/_generated/types/__init__.py +4 -1
- huggingface_hub/inference/_generated/types/chat_completion.py +41 -21
- huggingface_hub/inference/_generated/types/feature_extraction.py +23 -5
- huggingface_hub/inference/_generated/types/text_generation.py +29 -0
- huggingface_hub/lfs.py +11 -6
- huggingface_hub/repocard_data.py +3 -3
- huggingface_hub/repository.py +6 -6
- huggingface_hub/serialization/__init__.py +8 -3
- huggingface_hub/serialization/_base.py +13 -16
- huggingface_hub/serialization/_tensorflow.py +4 -3
- huggingface_hub/serialization/_torch.py +399 -22
- huggingface_hub/utils/__init__.py +0 -1
- huggingface_hub/utils/_errors.py +1 -1
- huggingface_hub/utils/_fixes.py +14 -3
- huggingface_hub/utils/_paths.py +17 -6
- huggingface_hub/utils/_subprocess.py +0 -1
- huggingface_hub/utils/_telemetry.py +9 -1
- huggingface_hub/utils/endpoint_helpers.py +2 -186
- huggingface_hub/utils/sha.py +36 -1
- huggingface_hub/utils/tqdm.py +0 -1
- {huggingface_hub-0.23.5.dist-info → huggingface_hub-0.24.1.dist-info}/METADATA +12 -9
- {huggingface_hub-0.23.5.dist-info → huggingface_hub-0.24.1.dist-info}/RECORD +41 -41
- huggingface_hub/serialization/_numpy.py +0 -68
- {huggingface_hub-0.23.5.dist-info → huggingface_hub-0.24.1.dist-info}/LICENSE +0 -0
- {huggingface_hub-0.23.5.dist-info → huggingface_hub-0.24.1.dist-info}/WHEEL +0 -0
- {huggingface_hub-0.23.5.dist-info → huggingface_hub-0.24.1.dist-info}/entry_points.txt +0 -0
- {huggingface_hub-0.23.5.dist-info → huggingface_hub-0.24.1.dist-info}/top_level.txt +0 -0
huggingface_hub/inference/_generated/_async_client.py
@@ -44,7 +44,7 @@ from huggingface_hub.inference._common import (
     TASKS_EXPECTING_IMAGES,
     ContentT,
     ModelStatus,
-
+    _async_stream_chat_completion_response,
     _async_stream_text_generation_response,
     _b64_encode,
     _b64_to_image,
@@ -54,9 +54,7 @@ from huggingface_hub.inference._common import (
     _fetch_recommended_models,
     _get_unsupported_text_generation_kwargs,
     _import_numpy,
-    _is_chat_completion_server,
     _open_as_binary,
-    _set_as_non_chat_completion_server,
     _set_unsupported_text_generation_kwargs,
     raise_text_generation_error,
 )
@@ -64,11 +62,10 @@ from huggingface_hub.inference._generated.types import (
     AudioClassificationOutputElement,
     AudioToAudioOutputElement,
     AutomaticSpeechRecognitionOutput,
+    ChatCompletionInputGrammarType,
     ChatCompletionInputTool,
     ChatCompletionInputToolTypeClass,
     ChatCompletionOutput,
-    ChatCompletionOutputComplete,
-    ChatCompletionOutputMessage,
     ChatCompletionStreamOutput,
     DocumentQuestionAnsweringOutputElement,
     FillMaskOutputElement,
@@ -89,13 +86,13 @@ from huggingface_hub.inference._generated.types import (
     ZeroShotClassificationOutputElement,
     ZeroShotImageClassificationOutputElement,
 )
-from huggingface_hub.inference._generated.types.chat_completion import ChatCompletionInputToolTypeEnum
 from huggingface_hub.inference._types import (
     ConversationalOutput,  # soon to be removed
 )
 from huggingface_hub.utils import (
     build_hf_headers,
 )
+from huggingface_hub.utils._deprecation import _deprecate_positional_args

 from .._common import _async_yield_from, _import_aiohttp

@@ -119,12 +116,16 @@ class AsyncInferenceClient:
 
     Args:
         model (`str`, `optional`):
-            The model to run inference with. Can be a model id hosted on the Hugging Face Hub, e.g. `
+            The model to run inference with. Can be a model id hosted on the Hugging Face Hub, e.g. `meta-llama/Meta-Llama-3-8B-Instruct`
             or a URL to a deployed Inference Endpoint. Defaults to None, in which case a recommended model is
             automatically selected for the task.
+            Note: for better compatibility with OpenAI's client, `model` has been aliased as `base_url`. Those 2
+            arguments are mutually exclusive and have the exact same behavior.
         token (`str` or `bool`, *optional*):
             Hugging Face token. Will default to the locally saved token if not provided.
             Pass `token=False` if you don't want to send your token to the server.
+            Note: for better compatibility with OpenAI's client, `token` has been aliased as `api_key`. Those 2
+            arguments are mutually exclusive and have the exact same behavior.
         timeout (`float`, `optional`):
             The maximum number of seconds to wait for a response from the server. Loading a new model in Inference
             API can take up to several minutes. Defaults to None, meaning it will loop until the server is available.
@@ -133,23 +134,52 @@ class AsyncInferenceClient:
             Values in this dictionary will override the default values.
         cookies (`Dict[str, str]`, `optional`):
             Additional cookies to send to the server.
+        base_url (`str`, `optional`):
+            Base URL to run inference. This is a duplicated argument from `model` to make [`InferenceClient`]
+            follow the same pattern as `openai.OpenAI` client. Cannot be used if `model` is set. Defaults to None.
+        api_key (`str`, `optional`):
+            Token to use for authentication. This is a duplicated argument from `token` to make [`InferenceClient`]
+            follow the same pattern as `openai.OpenAI` client. Cannot be used if `token` is set. Defaults to None.
     """

+    @_deprecate_positional_args(version="0.26")
     def __init__(
         self,
         model: Optional[str] = None,
+        *,
         token: Union[str, bool, None] = None,
         timeout: Optional[float] = None,
         headers: Optional[Dict[str, str]] = None,
         cookies: Optional[Dict[str, str]] = None,
+        proxies: Optional[Any] = None,
+        # OpenAI compatibility
+        base_url: Optional[str] = None,
+        api_key: Optional[str] = None,
     ) -> None:
+        if model is not None and base_url is not None:
+            raise ValueError(
+                "Received both `model` and `base_url` arguments. Please provide only one of them."
+                " `base_url` is an alias for `model` to make the API compatible with OpenAI's client."
+                " It has the exact same behavior as `model`."
+            )
+        if token is not None and api_key is not None:
+            raise ValueError(
+                "Received both `token` and `api_key` arguments. Please provide only one of them."
+                " `api_key` is an alias for `token` to make the API compatible with OpenAI's client."
+                " It has the exact same behavior as `token`."
+            )
+
         self.model: Optional[str] = model
-        self.token: Union[str, bool, None] = token
-        self.headers = CaseInsensitiveDict(build_hf_headers(token=token))  #
+        self.token: Union[str, bool, None] = token if token is not None else api_key
+        self.headers = CaseInsensitiveDict(build_hf_headers(token=self.token))  # 'authorization' + 'user-agent'
         if headers is not None:
             self.headers.update(headers)
         self.cookies = cookies
         self.timeout = timeout
+        self.proxies = proxies
+
+        # OpenAI compatibility
+        self.base_url = base_url

     def __repr__(self):
         return f"<InferenceClient(model='{self.model if self.model else ''}', timeout={self.timeout})>"
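For context on the hunk above, a minimal sketch of the new constructor contract as the diff documents it; the endpoint URL, token, and proxy values below are placeholders.

```py
from huggingface_hub import AsyncInferenceClient

# OpenAI-style construction: `base_url`/`api_key` are aliases for `model`/`token`.
client = AsyncInferenceClient(
    base_url="https://my-endpoint.example.com",  # placeholder Inference Endpoint URL
    api_key="hf_xxx",                            # placeholder token
    proxies="http://localhost:3128",             # placeholder proxy, forwarded as aiohttp's `proxy=`
)

# Passing both spellings of the same argument raises a ValueError:
#   AsyncInferenceClient(model="gpt2", base_url="https://...")  -> ValueError
#   AsyncInferenceClient(token="hf_xxx", api_key="hf_xxx")      -> ValueError
```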
@@ -250,7 +280,7 @@ class AsyncInferenceClient:
                 )

                 try:
-                    response = await client.post(url, json=json, data=data_as_binary)
+                    response = await client.post(url, json=json, data=data_as_binary, proxy=self.proxies)
                     response_error_payload = None
                     if response.status != 200:
                         try:
@@ -284,11 +314,16 @@ class AsyncInferenceClient:
                            ) from error
                        # ...or wait 1s and retry
                        logger.info(f"Waiting for model to be loaded on the server: {error}")
+                        if "X-wait-for-model" not in headers and url.startswith(INFERENCE_ENDPOINT):
+                            headers["X-wait-for-model"] = "1"
                        time.sleep(1)
                        if timeout is not None:
                            timeout = max(self.timeout - (time.time() - t0), 1)  # type: ignore
                        continue
                    raise error
+                except Exception:
+                    await client.close()
+                    raise

     async def audio_classification(
         self,
@@ -427,10 +462,11 @@ class AsyncInferenceClient:
         max_tokens: Optional[int] = None,
         n: Optional[int] = None,
         presence_penalty: Optional[float] = None,
+        response_format: Optional[ChatCompletionInputGrammarType] = None,
         seed: Optional[int] = None,
         stop: Optional[List[str]] = None,
         temperature: Optional[float] = None,
-        tool_choice: Optional[Union[ChatCompletionInputToolTypeClass,
+        tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, str]] = None,
         tool_prompt: Optional[str] = None,
         tools: Optional[List[ChatCompletionInputTool]] = None,
         top_logprobs: Optional[int] = None,
@@ -450,10 +486,11 @@ class AsyncInferenceClient:
         max_tokens: Optional[int] = None,
         n: Optional[int] = None,
         presence_penalty: Optional[float] = None,
+        response_format: Optional[ChatCompletionInputGrammarType] = None,
         seed: Optional[int] = None,
         stop: Optional[List[str]] = None,
         temperature: Optional[float] = None,
-        tool_choice: Optional[Union[ChatCompletionInputToolTypeClass,
+        tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, str]] = None,
         tool_prompt: Optional[str] = None,
         tools: Optional[List[ChatCompletionInputTool]] = None,
         top_logprobs: Optional[int] = None,
@@ -473,10 +510,11 @@ class AsyncInferenceClient:
         max_tokens: Optional[int] = None,
         n: Optional[int] = None,
         presence_penalty: Optional[float] = None,
+        response_format: Optional[ChatCompletionInputGrammarType] = None,
         seed: Optional[int] = None,
         stop: Optional[List[str]] = None,
         temperature: Optional[float] = None,
-        tool_choice: Optional[Union[ChatCompletionInputToolTypeClass,
+        tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, str]] = None,
         tool_prompt: Optional[str] = None,
         tools: Optional[List[ChatCompletionInputTool]] = None,
         top_logprobs: Optional[int] = None,
@@ -496,10 +534,11 @@ class AsyncInferenceClient:
         max_tokens: Optional[int] = None,
         n: Optional[int] = None,
         presence_penalty: Optional[float] = None,
+        response_format: Optional[ChatCompletionInputGrammarType] = None,
         seed: Optional[int] = None,
         stop: Optional[List[str]] = None,
         temperature: Optional[float] = None,
-        tool_choice: Optional[Union[ChatCompletionInputToolTypeClass,
+        tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, str]] = None,
         tool_prompt: Optional[str] = None,
         tools: Optional[List[ChatCompletionInputTool]] = None,
         top_logprobs: Optional[int] = None,
@@ -510,11 +549,10 @@ class AsyncInferenceClient:
 
         <Tip>
 
-
-
-
-
-        to install it.
+        The `client.chat_completion` method is aliased as `client.chat.completions.create` for compatibility with OpenAI's client.
+        Inputs and outputs are strictly the same and using either syntax will yield the same results.
+        Check out the [Inference guide](https://huggingface.co/docs/huggingface_hub/guides/inference#openai-compatibility)
+        for more details about OpenAI's compatibility.
 
         </Tip>
 
@@ -525,6 +563,9 @@ class AsyncInferenceClient:
                 The model to use for chat-completion. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
                 Inference Endpoint. If not provided, the default recommended model for chat-based text-generation will be used.
                 See https://huggingface.co/tasks/text-generation for more details.
+
+                If `model` is a model ID, it is passed to the server as the `model` parameter. If you want to define a
+                custom URL while setting `model` in the request payload, you must set `base_url` when initializing [`InferenceClient`].
             frequency_penalty (`float`, *optional*):
                 Penalizes new tokens based on their existing frequency
                 in the text so far. Range: [-2.0, 2.0]. Defaults to 0.0.
@@ -544,6 +585,8 @@ class AsyncInferenceClient:
             presence_penalty (`float`, *optional*):
                 Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the
                 text so far, increasing the model's likelihood to talk about new topics.
+            response_format ([`ChatCompletionInputGrammarType`], *optional*):
+                Grammar constraints. Can be either a JSONSchema or a regex.
             seed (Optional[`int`], *optional*):
                 Seed for reproducible control flow. Defaults to None.
             stop (Optional[`str`], *optional*):
@@ -561,7 +604,7 @@ class AsyncInferenceClient:
             top_p (`float`, *optional*):
                 Fraction of the most likely next words to sample from.
                 Must be between 0 and 1. Defaults to 1.0.
-            tool_choice ([`ChatCompletionInputToolTypeClass`] or
+            tool_choice ([`ChatCompletionInputToolTypeClass`] or `str`, *optional*):
                 The tool to use for the completion. Defaults to "auto".
             tool_prompt (`str`, *optional*):
                 A prompt to be appended before the tools.
@@ -570,7 +613,7 @@ class AsyncInferenceClient:
                 provide a list of functions the model may generate JSON inputs for.

         Returns:
-            [`ChatCompletionOutput] or Iterable of [`ChatCompletionStreamOutput`]:
+            [`ChatCompletionOutput`] or Iterable of [`ChatCompletionStreamOutput`]:
             Generated text returned from the server:
             - if `stream=False`, the generated text is returned as a [`ChatCompletionOutput`] (default).
             - if `stream=True`, the generated text is returned token by token as a sequence of [`ChatCompletionStreamOutput`].
@@ -585,10 +628,9 @@ class AsyncInferenceClient:
 
         ```py
         # Must be run in an async context
-        # Chat example
         >>> from huggingface_hub import AsyncInferenceClient
         >>> messages = [{"role": "user", "content": "What is the capital of France?"}]
-        >>> client = AsyncInferenceClient("
+        >>> client = AsyncInferenceClient("meta-llama/Meta-Llama-3-8B-Instruct")
         >>> await client.chat_completion(messages, max_tokens=100)
         ChatCompletionOutput(
             choices=[
@@ -596,26 +638,75 @@ class AsyncInferenceClient:
                     finish_reason='eos_token',
                     index=0,
                     message=ChatCompletionOutputMessage(
-
-
+                        role='assistant',
+                        content='The capital of France is Paris.',
+                        name=None,
+                        tool_calls=None
+                    ),
+                    logprobs=None
                 )
             ],
-            created=
+            created=1719907176,
+            id='',
+            model='meta-llama/Meta-Llama-3-8B-Instruct',
+            object='text_completion',
+            system_fingerprint='2.0.4-sha-f426a33',
+            usage=ChatCompletionOutputUsage(
+                completion_tokens=8,
+                prompt_tokens=17,
+                total_tokens=25
+            )
         )
+        ```

+        Example (stream=True):
+        ```py
+        # Must be run in an async context
+        >>> from huggingface_hub import AsyncInferenceClient
+        >>> messages = [{"role": "user", "content": "What is the capital of France?"}]
+        >>> client = AsyncInferenceClient("meta-llama/Meta-Llama-3-8B-Instruct")
         >>> async for token in await client.chat_completion(messages, max_tokens=10, stream=True):
         ...     print(token)
         ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content='The', role='assistant'), index=0, finish_reason=None)], created=1710498504)
         ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content=' capital', role='assistant'), index=0, finish_reason=None)], created=1710498504)
         (...)
         ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content=' may', role='assistant'), index=0, finish_reason=None)], created=1710498504)
+        ```
+
+        Example using OpenAI's syntax:
+        ```py
+        # Must be run in an async context
+        # instead of `from openai import OpenAI`
+        from huggingface_hub import AsyncInferenceClient

-        #
+        # instead of `client = OpenAI(...)`
+        client = AsyncInferenceClient(
+            base_url=...,
+            api_key=...,
+        )
+
+        output = await client.chat.completions.create(
+            model="meta-llama/Meta-Llama-3-8B-Instruct",
+            messages=[
+                {"role": "system", "content": "You are a helpful assistant."},
+                {"role": "user", "content": "Count to 10"},
+            ],
+            stream=True,
+            max_tokens=1024,
+        )
+
+        for chunk in output:
+            print(chunk.choices[0].delta.content)
+        ```
+
+        Example using tools:
+        ```py
+        # Must be run in an async context
         >>> client = AsyncInferenceClient("meta-llama/Meta-Llama-3-70B-Instruct")
         >>> messages = [
         ...     {
         ...         "role": "system",
-        ...         "content": "Don't make assumptions about what values to plug into functions. Ask
+        ...         "content": "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.",
         ...     },
         ...     {
         ...         "role": "user",
@@ -691,113 +782,90 @@ class AsyncInferenceClient:
                 description=None
             )
         ```
-        """
-        # determine model
-        model = model or self.model or self.get_recommended_model("text-generation")
-
-        if _is_chat_completion_server(model):
-            # First, let's consider the server has a `/v1/chat/completions` endpoint.
-            # If that's the case, we don't have to render the chat template client-side.
-            model_url = self._resolve_url(model)
-            if not model_url.endswith("/chat/completions"):
-                model_url += "/v1/chat/completions"
-
-            try:
-                data = await self.post(
-                    model=model_url,
-                    json=dict(
-                        model="tgi",  # random string
-                        messages=messages,
-                        frequency_penalty=frequency_penalty,
-                        logit_bias=logit_bias,
-                        logprobs=logprobs,
-                        max_tokens=max_tokens,
-                        n=n,
-                        presence_penalty=presence_penalty,
-                        seed=seed,
-                        stop=stop,
-                        temperature=temperature,
-                        tool_choice=tool_choice,
-                        tool_prompt=tool_prompt,
-                        tools=tools,
-                        top_logprobs=top_logprobs,
-                        top_p=top_p,
-                        stream=stream,
-                    ),
-                    stream=stream,
-                )
-            except _import_aiohttp().ClientResponseError as e:
-                if e.status in (400, 404, 500):
-                    # Let's consider the server is not a chat completion server.
-                    # Then we call again `chat_completion` which will render the chat template client side.
-                    # (can be HTTP 500, HTTP 400, HTTP 404 depending on the server)
-                    _set_as_non_chat_completion_server(model)
-                    logger.warning(
-                        f"Server {model_url} does not seem to support chat completion. Falling back to text generation. Error: {e}"
-                    )
-                    return await self.chat_completion(
-                        messages=messages,
-                        model=model,
-                        stream=stream,
-                        max_tokens=max_tokens,
-                        seed=seed,
-                        stop=stop,
-                        temperature=temperature,
-                        top_p=top_p,
-                    )
-                raise

-
-
-
-
+        Example using response_format:
+        ```py
+        # Must be run in an async context
+        >>> from huggingface_hub import AsyncInferenceClient
+        >>> client = AsyncInferenceClient("meta-llama/Meta-Llama-3-70B-Instruct")
+        >>> messages = [
+        ...     {
+        ...         "role": "user",
+        ...         "content": "I saw a puppy a cat and a raccoon during my bike ride in the park. What did I saw and when?",
+        ...     },
+        ... ]
+        >>> response_format = {
+        ...     "type": "json",
+        ...     "value": {
+        ...         "properties": {
+        ...             "location": {"type": "string"},
+        ...             "activity": {"type": "string"},
+        ...             "animals_seen": {"type": "integer", "minimum": 1, "maximum": 5},
+        ...             "animals": {"type": "array", "items": {"type": "string"}},
+        ...         },
+        ...         "required": ["location", "activity", "animals_seen", "animals"],
+        ...     },
+        ... }
+        >>> response = await client.chat_completion(
+        ...     messages=messages,
+        ...     response_format=response_format,
+        ...     max_tokens=500,
+        ... )
+        >>> response.choices[0].message.content
+        '{\n\n"activity": "bike ride",\n"animals": ["puppy", "cat", "raccoon"],\n"animals_seen": 3,\n"location": "park"}'
+        ```
+        """
+        # Determine model
+        # `self.xxx` takes precedence over the method argument only in `chat_completion`
+        # since `chat_completion(..., model=xxx)` is also a payload parameter for the
+        # server, we need to handle it differently
+        model = self.base_url or self.model or model or self.get_recommended_model("text-generation")
+        is_url = model.startswith(("http://", "https://"))
+
+        # First, resolve the model chat completions URL
+        if model == self.base_url:
+            # base_url passed => add server route
+            model_url = model + "/v1/chat/completions"
+        elif is_url:
+            # model is a URL => use it directly
+            model_url = model
+        else:
+            # model is a model ID => resolve it + add server route
+            model_url = self._resolve_url(model) + "/v1/chat/completions"
+
+        # `model` is sent in the payload. Not used by the server but can be useful for debugging/routing.
+        # If it's a ID on the Hub => use it. Otherwise, we use a random string.
+        model_id = model if not is_url and model.count("/") == 1 else "tgi"
+
+        data = await self.post(
+            model=model_url,
+            json=dict(
+                model=model_id,
+                messages=messages,
+                frequency_penalty=frequency_penalty,
+                logit_bias=logit_bias,
+                logprobs=logprobs,
+                max_tokens=max_tokens,
+                n=n,
+                presence_penalty=presence_penalty,
+                response_format=response_format,
+                seed=seed,
+                stop=stop,
+                temperature=temperature,
+                tool_choice=tool_choice,
+                tool_prompt=tool_prompt,
+                tools=tools,
+                top_logprobs=top_logprobs,
+                top_p=top_p,
+                stream=stream,
+            ),
+            stream=stream,
+        )

-        # At this point, we know the server is not a chat completion server.
-        # It means it's a transformers-backed server for which we can send a list of messages directly to the
-        # `text-generation` pipeline. We won't receive a detailed response but only the generated text.
         if stream:
-
-                "Streaming token is not supported by the model. This is due to the model not been served by a "
-                "Text-Generation-Inference server. Please pass `stream=False` as input."
-            )
-        if tool_choice is not None or tool_prompt is not None or tools is not None:
-            warnings.warn(
-                "Tools are not supported by the model. This is due to the model not been served by a "
-                "Text-Generation-Inference server. The provided tool parameters will be ignored."
-            )
-
-        # generate response
-        text_generation_output = await self.text_generation(
-            prompt=messages,  # type: ignore # Not correct type but works implicitly
-            model=model,
-            stream=False,
-            details=False,
-            max_new_tokens=max_tokens,
-            seed=seed,
-            stop_sequences=stop,
-            temperature=temperature,
-            top_p=top_p,
-        )
+            return _async_stream_chat_completion_response(data)  # type: ignore[arg-type]

-
-        return ChatCompletionOutput(
-            id="dummy",
-            model="dummy",
-            object="dummy",
-            system_fingerprint="dummy",
-            usage=None,  # type: ignore # set to `None` as we don't want to provide false information
-            created=int(time.time()),
-            choices=[
-                ChatCompletionOutputComplete(
-                    finish_reason="unk",  # type: ignore # set to `unk` as we don't want to provide false information
-                    index=0,
-                    message=ChatCompletionOutputMessage(
-                        content=text_generation_output,
-                        role="assistant",
-                    ),
-                )
-            ],
-        )
+        return ChatCompletionOutput.parse_obj_as_instance(data)  # type: ignore[arg-type]

     async def conversational(
         self,
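A rough illustration of the three URL-resolution branches added above; the localhost URL is a placeholder and the behaviour described in the comments is read off the diff, not from separate documentation.

```py
from huggingface_hub import AsyncInferenceClient

# `base_url` gets the chat route appended before posting:
#   POST http://localhost:8080/v1/chat/completions
client_a = AsyncInferenceClient(base_url="http://localhost:8080")

# A URL passed as `model` is used exactly as given, so include the route yourself:
client_b = AsyncInferenceClient(model="http://localhost:8080/v1/chat/completions")

# A Hub model ID is resolved through `_resolve_url()` and is also sent as the
# payload's `model` field; for URLs the payload falls back to the string "tgi".
client_c = AsyncInferenceClient(model="meta-llama/Meta-Llama-3-8B-Instruct")
```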
@@ -850,7 +918,7 @@ class AsyncInferenceClient:
         >>> client = AsyncInferenceClient()
         >>> output = await client.conversational("Hi, who are you?")
         >>> output
-        {'generated_text': 'I am the one who knocks.', 'conversation': {'generated_responses': ['I am the one who knocks.'], 'past_user_inputs': ['Hi, who are you?']}, 'warnings': ['Setting `pad_token_id` to `eos_token_id`:50256
+        {'generated_text': 'I am the one who knocks.', 'conversation': {'generated_responses': ['I am the one who knocks.'], 'past_user_inputs': ['Hi, who are you?']}, 'warnings': ['Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.']}
         >>> await client.conversational(
         ...     "Wow, that's scary!",
         ...     generated_responses=output["conversation"]["generated_responses"],
@@ -915,7 +983,16 @@ class AsyncInferenceClient:
         response = await self.post(json=payload, model=model, task="document-question-answering")
         return DocumentQuestionAnsweringOutputElement.parse_obj_as_list(response)

-    async def feature_extraction(
+    async def feature_extraction(
+        self,
+        text: str,
+        *,
+        normalize: Optional[bool] = None,
+        prompt_name: Optional[str] = None,
+        truncate: Optional[bool] = None,
+        truncation_direction: Optional[Literal["Left", "Right"]] = None,
+        model: Optional[str] = None,
+    ) -> "np.ndarray":
         """
         Generate embeddings for a given text.

@@ -926,6 +1003,20 @@ class AsyncInferenceClient:
                 The model to use for the conversational task. Can be a model ID hosted on the Hugging Face Hub or a URL to
                 a deployed Inference Endpoint. If not provided, the default recommended conversational model will be used.
                 Defaults to None.
+            normalize (`bool`, *optional*):
+                Whether to normalize the embeddings or not. Defaults to None.
+                Only available on server powered by Text-Embedding-Inference.
+            prompt_name (`str`, *optional*):
+                The name of the prompt that should be used by for encoding. If not set, no prompt will be applied.
+                Must be a key in the `Sentence Transformers` configuration `prompts` dictionary.
+                For example if ``prompt_name`` is "query" and the ``prompts`` is {"query": "query: ",...},
+                then the sentence "What is the capital of France?" will be encoded as "query: What is the capital of France?"
+                because the prompt text will be prepended before any text to encode.
+            truncate (`bool`, *optional*):
+                Whether to truncate the embeddings or not. Defaults to None.
+                Only available on server powered by Text-Embedding-Inference.
+            truncation_direction (`Literal["Left", "Right"]`, *optional*):
+                Which side of the input should be truncated when `truncate=True` is passed.

         Returns:
             `np.ndarray`: The embedding representing the input text as a float32 numpy array.
@@ -948,7 +1039,16 @@ class AsyncInferenceClient:
         [ 0.28552425, -0.928395 , -1.2077185 , ..., 0.76810825, -2.1069427 , 0.6236161 ]], dtype=float32)
         ```
         """
-
+        payload: Dict = {"inputs": text}
+        if normalize is not None:
+            payload["normalize"] = normalize
+        if prompt_name is not None:
+            payload["prompt_name"] = prompt_name
+        if truncate is not None:
+            payload["truncate"] = truncate
+        if truncation_direction is not None:
+            payload["truncation_direction"] = truncation_direction
+        response = await self.post(json=payload, model=model, task="feature-extraction")
         np = _import_numpy()
         return np.array(_bytes_to_dict(response), dtype="float32")

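A small, hypothetical call using the new keyword-only parameters of `feature_extraction`; it assumes the target model is served by Text-Embedding-Inference, as the docstring above requires, and the endpoint URL is a placeholder.

```py
# Must be run in an async context
from huggingface_hub import AsyncInferenceClient

client = AsyncInferenceClient("http://localhost:8080")  # placeholder TEI endpoint URL
embedding = await client.feature_extraction(
    "What is the capital of France?",
    normalize=True,                # TEI-only: normalize the returned embedding
    truncate=True,                 # TEI-only: truncate inputs that are too long
    truncation_direction="Right",  # which side to cut when truncating
)
print(embedding.shape)
```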
@@ -1192,7 +1292,8 @@ class AsyncInferenceClient:
         ```
         """
         response = await self.post(data=image, model=model, task="image-to-text")
-
+        output = ImageToTextOutput.parse_obj(response)
+        return output[0] if isinstance(output, list) else output

     async def list_deployed_models(
         self, frameworks: Union[None, str, Literal["all"], List[str]] = None
@@ -1643,6 +1744,7 @@ class AsyncInferenceClient:
         stream: Literal[False] = ...,
         model: Optional[str] = None,
         # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
+        adapter_id: Optional[str] = None,
         best_of: Optional[int] = None,
         decoder_input_details: Optional[bool] = None,
         do_sample: Optional[bool] = False,  # Manual default value
@@ -1671,6 +1773,7 @@ class AsyncInferenceClient:
         stream: Literal[False] = ...,
         model: Optional[str] = None,
         # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
+        adapter_id: Optional[str] = None,
         best_of: Optional[int] = None,
         decoder_input_details: Optional[bool] = None,
         do_sample: Optional[bool] = False,  # Manual default value
@@ -1699,6 +1802,7 @@ class AsyncInferenceClient:
         stream: Literal[True] = ...,
         model: Optional[str] = None,
         # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
+        adapter_id: Optional[str] = None,
         best_of: Optional[int] = None,
         decoder_input_details: Optional[bool] = None,
         do_sample: Optional[bool] = False,  # Manual default value
@@ -1727,6 +1831,7 @@ class AsyncInferenceClient:
         stream: Literal[True] = ...,
         model: Optional[str] = None,
         # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
+        adapter_id: Optional[str] = None,
         best_of: Optional[int] = None,
         decoder_input_details: Optional[bool] = None,
         do_sample: Optional[bool] = False,  # Manual default value
@@ -1755,6 +1860,7 @@ class AsyncInferenceClient:
         stream: bool = ...,
         model: Optional[str] = None,
         # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
+        adapter_id: Optional[str] = None,
         best_of: Optional[int] = None,
         decoder_input_details: Optional[bool] = None,
         do_sample: Optional[bool] = False,  # Manual default value
@@ -1782,6 +1888,7 @@ class AsyncInferenceClient:
         stream: bool = False,
         model: Optional[str] = None,
         # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
+        adapter_id: Optional[str] = None,
         best_of: Optional[int] = None,
         decoder_input_details: Optional[bool] = None,
         do_sample: Optional[bool] = False,  # Manual default value
@@ -1812,6 +1919,13 @@ class AsyncInferenceClient:
 
         To learn more about the TGI project, please refer to https://github.com/huggingface/text-generation-inference.
 
+        <Tip>
+
+        If you want to generate a response from chat messages, you should use the [`InferenceClient.chat_completion`] method.
+        It accepts a list of messages instead of a single text prompt and handles the chat templating for you.
+
+        </Tip>
+
         Args:
             prompt (`str`):
                 Input text.
@@ -1826,6 +1940,8 @@ class AsyncInferenceClient:
             model (`str`, *optional*):
                 The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
                 Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
+            adapter_id (`str`, *optional*):
+                Lora adapter id.
             best_of (`int`, *optional*):
                 Generate best_of sequences and return the one if the highest token logprobs.
             decoder_input_details (`bool`, *optional*):
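For the new `adapter_id` parameter, a hypothetical call might look like the following; the endpoint URL and adapter id are placeholders, and the value is simply forwarded in the request payload (see the payload hunk further down).

```py
# Must be run in an async context
from huggingface_hub import AsyncInferenceClient

client = AsyncInferenceClient("http://localhost:8080")  # placeholder TGI endpoint URL
output = await client.text_generation(
    "What is Deep Learning?",
    max_new_tokens=64,
    adapter_id="my-org/my-lora-adapter",  # hypothetical LoRA adapter id loaded on the server
)
print(output)
```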
@@ -1893,7 +2009,7 @@ class AsyncInferenceClient:
         >>> await client.text_generation("The huggingface_hub library is ", max_new_tokens=12)
         '100% open source and built to be easy to use.'

-        # Case 2: iterate over the generated tokens. Useful
+        # Case 2: iterate over the generated tokens. Useful for large generation.
         >>> async for token in await client.text_generation("The huggingface_hub library is ", max_new_tokens=12, stream=True):
         ...     print(token)
         100
@@ -1995,6 +2111,7 @@ class AsyncInferenceClient:
 
         # Build payload
         parameters = {
+            "adapter_id": adapter_id,
             "best_of": best_of,
             "decoder_input_details": decoder_input_details,
             "details": details,
@@ -2065,6 +2182,7 @@ class AsyncInferenceClient:
                 details=details,
                 stream=stream,
                 model=model,
+                adapter_id=adapter_id,
                 best_of=best_of,
                 decoder_input_details=decoder_input_details,
                 do_sample=do_sample,
@@ -2089,7 +2207,12 @@ class AsyncInferenceClient:
         if stream:
             return _async_stream_text_generation_response(bytes_output, details)  # type: ignore

-        data = _bytes_to_dict(bytes_output)
+        data = _bytes_to_dict(bytes_output)  # type: ignore[arg-type]
+
+        # Data can be a single element (dict) or an iterable of dicts where we select the first element of.
+        if isinstance(data, list):
+            data = data[0]
+
         return TextGenerationOutput.parse_obj_as_instance(data) if details else data["generated_text"]

     async def text_to_image(
@@ -2377,7 +2500,13 @@ class AsyncInferenceClient:
         return VisualQuestionAnsweringOutputElement.parse_obj_as_list(response)

     async def zero_shot_classification(
-        self,
+        self,
+        text: str,
+        labels: List[str],
+        *,
+        multi_label: bool = False,
+        hypothesis_template: Optional[str] = None,
+        model: Optional[str] = None,
     ) -> List[ZeroShotClassificationOutputElement]:
         """
         Provide as input a text and a set of candidate labels to classify the input text.
@@ -2386,9 +2515,15 @@ class AsyncInferenceClient:
             text (`str`):
                 The input text to classify.
             labels (`List[str]`):
-                List of string
+                List of strings. Each string is the verbalization of a possible label for the input text.
             multi_label (`bool`):
-                Boolean
+                Boolean. If True, the probability for each label is evaluated independently and multiple labels can have a probability close to 1 simultaneously or all probabilities can be close to 0.
+                If False, the labels are considered mutually exclusive and the probability over all labels always sums to 1. Defaults to False.
+            hypothesis_template (`str`, *optional*):
+                A template sentence string with curly brackets to which the label strings are added. The label strings are added at the position of the curly brackets "{}".
+                Zero-shot classifiers are based on NLI models, which evaluate if a hypothesis is entailed in another text or not.
+                For example, with hypothesis_template="This text is about {}." and labels=["economics", "politics"], the system internally creates the two hypotheses "This text is about economics." and "This text is about politics.".
+                The model then evaluates for both hypotheses if they are entailed in the provided `text` or not.
             model (`str`, *optional*):
                 The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
                 Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
@@ -2402,15 +2537,15 @@ class AsyncInferenceClient:
             `aiohttp.ClientResponseError`:
                 If the request fails with an HTTP error status code other than HTTP 503.

-        Example
+        Example with `multi_label=False`:
         ```py
         # Must be run in an async context
         >>> from huggingface_hub import AsyncInferenceClient
         >>> client = AsyncInferenceClient()
         >>> text = (
-        ...     "A new model offers an explanation
+        ...     "A new model offers an explanation for how the Galilean satellites formed around the solar system's"
         ...     "largest world. Konstantin Batygin did not set out to solve one of the solar system's most puzzling"
-        ...     " mysteries when he went
+        ...     " mysteries when he went for a run up a hill in Nice, France."
         ... )
         >>> labels = ["space & cosmos", "scientific discovery", "microbiology", "robots", "archeology"]
         >>> await client.zero_shot_classification(text, labels)
@@ -2430,21 +2565,38 @@ class AsyncInferenceClient:
             ZeroShotClassificationOutputElement(label='robots', score=0.00030448526376858354),
         ]
         ```
+
+        Example with `multi_label=True` and a custom `hypothesis_template`:
+        ```py
+        # Must be run in an async context
+        >>> from huggingface_hub import AsyncInferenceClient
+        >>> client = AsyncInferenceClient()
+        >>> await client.zero_shot_classification(
+        ...    text="I really like our dinner and I'm very happy. I don't like the weather though.",
+        ...    labels=["positive", "negative", "pessimistic", "optimistic"],
+        ...    multi_label=True,
+        ...    hypothesis_template="This text is {} towards the weather"
+        ... )
+        [
+            ZeroShotClassificationOutputElement(label='negative', score=0.9231801629066467),
+            ZeroShotClassificationOutputElement(label='pessimistic', score=0.8760990500450134),
+            ZeroShotClassificationOutputElement(label='optimistic', score=0.0008674879791215062),
+            ZeroShotClassificationOutputElement(label='positive', score=0.0005250611575320363)
+        ]
+        ```
         """
-
-
-
+
+        parameters = {"candidate_labels": labels, "multi_label": multi_label}
+        if hypothesis_template is not None:
+            parameters["hypothesis_template"] = hypothesis_template

         response = await self.post(
             json={
                 "inputs": text,
-                "parameters":
-                    "candidate_labels": ",".join(labels),
-                    "multi_label": multi_label,
-                },
+                "parameters": parameters,
             },
-            model=model,
             task="zero-shot-classification",
+            model=model,
         )
         output = _bytes_to_dict(response)
         return [
@@ -2501,7 +2653,7 @@ class AsyncInferenceClient:
         return ZeroShotImageClassificationOutputElement.parse_obj_as_list(response)

     def _resolve_url(self, model: Optional[str] = None, task: Optional[str] = None) -> str:
-        model = model or self.model
+        model = model or self.model or self.base_url

         # If model is already a URL, ignore `task` and return directly
         if model is not None and (model.startswith("http://") or model.startswith("https://")):
@@ -2554,6 +2706,99 @@ class AsyncInferenceClient:
         )
         return model

+    async def get_endpoint_info(self, *, model: Optional[str] = None) -> Dict[str, Any]:
+        """
+        Get information about the deployed endpoint.
+
+        This endpoint is only available on endpoints powered by Text-Generation-Inference (TGI) or Text-Embedding-Inference (TEI).
+        Endpoints powered by `transformers` return an empty payload.
+
+        Args:
+            model (`str`, *optional*):
+                The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
+                Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
+
+        Returns:
+            `Dict[str, Any]`: Information about the endpoint.
+
+        Example:
+        ```py
+        # Must be run in an async context
+        >>> from huggingface_hub import AsyncInferenceClient
+        >>> client = AsyncInferenceClient("meta-llama/Meta-Llama-3-70B-Instruct")
+        >>> await client.get_endpoint_info()
+        {
+            'model_id': 'meta-llama/Meta-Llama-3-70B-Instruct',
+            'model_sha': None,
+            'model_dtype': 'torch.float16',
+            'model_device_type': 'cuda',
+            'model_pipeline_tag': None,
+            'max_concurrent_requests': 128,
+            'max_best_of': 2,
+            'max_stop_sequences': 4,
+            'max_input_length': 8191,
+            'max_total_tokens': 8192,
+            'waiting_served_ratio': 0.3,
+            'max_batch_total_tokens': 1259392,
+            'max_waiting_tokens': 20,
+            'max_batch_size': None,
+            'validation_workers': 32,
+            'max_client_batch_size': 4,
+            'version': '2.0.2',
+            'sha': 'dccab72549635c7eb5ddb17f43f0b7cdff07c214',
+            'docker_label': 'sha-dccab72'
+        }
+        ```
+        """
+        model = model or self.model
+        if model is None:
+            raise ValueError("Model id not provided.")
+        if model.startswith(("http://", "https://")):
+            url = model.rstrip("/") + "/info"
+        else:
+            url = f"{INFERENCE_ENDPOINT}/models/{model}/info"
+
+        async with _import_aiohttp().ClientSession(headers=self.headers) as client:
+            response = await client.get(url)
+            response.raise_for_status()
+            return await response.json()
+
+    async def health_check(self, model: Optional[str] = None) -> bool:
+        """
+        Check the health of the deployed endpoint.
+
+        Health check is only available with Inference Endpoints powered by Text-Generation-Inference (TGI) or Text-Embedding-Inference (TEI).
+        For Inference API, please use [`InferenceClient.get_model_status`] instead.
+
+        Args:
+            model (`str`, *optional*):
+                URL of the Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
+
+        Returns:
+            `bool`: True if everything is working fine.
+
+        Example:
+        ```py
+        # Must be run in an async context
+        >>> from huggingface_hub import AsyncInferenceClient
+        >>> client = AsyncInferenceClient("https://jzgu0buei5.us-east-1.aws.endpoints.huggingface.cloud")
+        >>> await client.health_check()
+        True
+        ```
+        """
+        model = model or self.model
+        if model is None:
+            raise ValueError("Model id not provided.")
+        if not model.startswith(("http://", "https://")):
+            raise ValueError(
+                "Model must be an Inference Endpoint URL. For serverless Inference API, please use `InferenceClient.get_model_status`."
+            )
+        url = model.rstrip("/") + "/health"
+
+        async with _import_aiohttp().ClientSession(headers=self.headers) as client:
+            response = await client.get(url)
+            return response.status == 200
+
     async def get_model_status(self, model: Optional[str] = None) -> ModelStatus:
         """
         Get the status of a model hosted on the Inference API.
@@ -2581,7 +2826,7 @@ class AsyncInferenceClient:
         # Must be run in an async context
         >>> from huggingface_hub import AsyncInferenceClient
         >>> client = AsyncInferenceClient()
-        >>> await client.get_model_status("
+        >>> await client.get_model_status("meta-llama/Meta-Llama-3-8B-Instruct")
         ModelStatus(loaded=True, state='Loaded', compute_type='gpu', framework='text-generation-inference')
         ```
         """
@@ -2606,3 +2851,30 @@ class AsyncInferenceClient:
             compute_type=response_data["compute_type"],
             framework=response_data["framework"],
         )
+
+    @property
+    def chat(self) -> "ProxyClientChat":
+        return ProxyClientChat(self)
+
+
+class _ProxyClient:
+    """Proxy class to be able to call `client.chat.completion.create(...)` as OpenAI client."""
+
+    def __init__(self, client: AsyncInferenceClient):
+        self._client = client
+
+
+class ProxyClientChat(_ProxyClient):
+    """Proxy class to be able to call `client.chat.completion.create(...)` as OpenAI client."""
+
+    @property
+    def completions(self) -> "ProxyClientChatCompletions":
+        return ProxyClientChatCompletions(self._client)
+
+
+class ProxyClientChatCompletions(_ProxyClient):
+    """Proxy class to be able to call `client.chat.completion.create(...)` as OpenAI client."""
+
+    @property
+    def create(self):
+        return self._client.chat_completion