huggingface-hub 0.29.3__py3-none-any.whl → 0.30.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
Potentially problematic release: this version of huggingface-hub has been flagged as possibly problematic (details are available on the registry page).
- huggingface_hub/__init__.py +16 -1
- huggingface_hub/_commit_api.py +142 -4
- huggingface_hub/_space_api.py +15 -2
- huggingface_hub/_webhooks_server.py +2 -0
- huggingface_hub/commands/delete_cache.py +66 -20
- huggingface_hub/commands/upload.py +16 -2
- huggingface_hub/constants.py +45 -7
- huggingface_hub/errors.py +19 -0
- huggingface_hub/file_download.py +163 -35
- huggingface_hub/hf_api.py +349 -28
- huggingface_hub/hub_mixin.py +19 -4
- huggingface_hub/inference/_client.py +73 -70
- huggingface_hub/inference/_generated/_async_client.py +80 -77
- huggingface_hub/inference/_generated/types/__init__.py +1 -0
- huggingface_hub/inference/_generated/types/chat_completion.py +20 -10
- huggingface_hub/inference/_generated/types/image_to_image.py +2 -0
- huggingface_hub/inference/_providers/__init__.py +7 -1
- huggingface_hub/inference/_providers/_common.py +9 -5
- huggingface_hub/inference/_providers/black_forest_labs.py +5 -5
- huggingface_hub/inference/_providers/cohere.py +1 -1
- huggingface_hub/inference/_providers/fal_ai.py +64 -7
- huggingface_hub/inference/_providers/fireworks_ai.py +4 -1
- huggingface_hub/inference/_providers/hf_inference.py +41 -4
- huggingface_hub/inference/_providers/hyperbolic.py +3 -3
- huggingface_hub/inference/_providers/nebius.py +3 -3
- huggingface_hub/inference/_providers/novita.py +35 -5
- huggingface_hub/inference/_providers/openai.py +22 -0
- huggingface_hub/inference/_providers/replicate.py +3 -3
- huggingface_hub/inference/_providers/together.py +3 -3
- huggingface_hub/utils/__init__.py +8 -0
- huggingface_hub/utils/_http.py +4 -1
- huggingface_hub/utils/_runtime.py +11 -0
- huggingface_hub/utils/_xet.py +199 -0
- huggingface_hub/utils/tqdm.py +30 -2
- {huggingface_hub-0.29.3.dist-info → huggingface_hub-0.30.0.dist-info}/METADATA +3 -1
- {huggingface_hub-0.29.3.dist-info → huggingface_hub-0.30.0.dist-info}/RECORD +40 -38
- {huggingface_hub-0.29.3.dist-info → huggingface_hub-0.30.0.dist-info}/LICENSE +0 -0
- {huggingface_hub-0.29.3.dist-info → huggingface_hub-0.30.0.dist-info}/WHEEL +0 -0
- {huggingface_hub-0.29.3.dist-info → huggingface_hub-0.30.0.dist-info}/entry_points.txt +0 -0
- {huggingface_hub-0.29.3.dist-info → huggingface_hub-0.30.0.dist-info}/top_level.txt +0 -0
huggingface_hub/hub_mixin.py
CHANGED
@@ -58,7 +58,8 @@ DEFAULT_MODEL_CARD = """
 ---
 
 This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
-- Library: {{ repo_url | default("[More Information Needed]", true) }}
+- Code: {{ repo_url | default("[More Information Needed]", true) }}
+- Paper: {{ paper_url | default("[More Information Needed]", true) }}
 - Docs: {{ docs_url | default("[More Information Needed]", true) }}
 """
 

@@ -67,8 +68,9 @@ This model has been pushed to the Hub using the [PytorchModelHubMixin](https://h
 class MixinInfo:
     model_card_template: str
     model_card_data: ModelCardData
-    repo_url: Optional[str] = None
     docs_url: Optional[str] = None
+    paper_url: Optional[str] = None
+    repo_url: Optional[str] = None
 
 
 class ModelHubMixin:

@@ -88,6 +90,8 @@ class ModelHubMixin:
     Args:
         repo_url (`str`, *optional*):
             URL of the library repository. Used to generate model card.
+        paper_url (`str`, *optional*):
+            URL of the library paper. Used to generate model card.
         docs_url (`str`, *optional*):
             URL of the library documentation. Used to generate model card.
         model_card_template (`str`, *optional*):

@@ -110,7 +114,7 @@ class ModelHubMixin:
         pipeline_tag (`str`, *optional*):
             Tag of the pipeline. Used to generate model card. E.g. "text-classification".
         tags (`List[str]`, *optional*):
-            Tags to be added to the model card. Used to generate model card. E.g. ["
+            Tags to be added to the model card. Used to generate model card. E.g. ["computer-vision"]
         coders (`Dict[Type, Tuple[Callable, Callable]]`, *optional*):
             Dictionary of custom types and their encoders/decoders. Used to encode/decode arguments that are not
             jsonable by default. E.g dataclasses, argparse.Namespace, OmegaConf, etc.

@@ -124,8 +128,9 @@ class ModelHubMixin:
    >>> class MyCustomModel(
    ...     ModelHubMixin,
    ...     library_name="my-library",
-    ...     tags=["
+    ...     tags=["computer-vision"],
    ...     repo_url="https://github.com/huggingface/my-cool-library",
+    ...     paper_url="https://arxiv.org/abs/2304.12244",
    ...     docs_url="https://huggingface.co/docs/my-cool-library",
    ...     # ^ optional metadata to generate model card
    ... ):

@@ -194,6 +199,7 @@ class ModelHubMixin:
         *,
         # Generic info for model card
         repo_url: Optional[str] = None,
+        paper_url: Optional[str] = None,
         docs_url: Optional[str] = None,
         # Model card template
         model_card_template: str = DEFAULT_MODEL_CARD,

@@ -234,6 +240,7 @@ class ModelHubMixin:
 
             # Inherit other info
             info.docs_url = cls._hub_mixin_info.docs_url
+            info.paper_url = cls._hub_mixin_info.paper_url
             info.repo_url = cls._hub_mixin_info.repo_url
             cls._hub_mixin_info = info
 

@@ -242,6 +249,8 @@ class ModelHubMixin:
             info.model_card_template = model_card_template
         if repo_url is not None:
             info.repo_url = repo_url
+        if paper_url is not None:
+            info.paper_url = paper_url
         if docs_url is not None:
             info.docs_url = docs_url
         if language is not None:

@@ -334,6 +343,8 @@ class ModelHubMixin:
     @classmethod
     def _is_jsonable(cls, value: Any) -> bool:
         """Check if a value is JSON serializable."""
+        if is_dataclass(value):
+            return True
         if isinstance(value, cls._hub_mixin_jsonable_custom_types):
             return True
         return is_jsonable(value)

@@ -341,6 +352,8 @@ class ModelHubMixin:
     @classmethod
     def _encode_arg(cls, arg: Any) -> Any:
         """Encode an argument into a JSON serializable format."""
+        if is_dataclass(arg):
+            return asdict(arg)
         for type_, (encoder, _) in cls._hub_mixin_coders.items():
             if isinstance(arg, type_):
                 if arg is None:
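The two hunks above teach the mixin about dataclasses: `_is_jsonable` now accepts any dataclass instance and `_encode_arg` serializes it with `asdict` (both helpers come from the stdlib `dataclasses` module). A minimal sketch of what this enables; the `TrainConfig` and `MyModel` names are illustrative, not from the diff:

from dataclasses import dataclass

from huggingface_hub import ModelHubMixin


@dataclass
class TrainConfig:  # hypothetical user-defined config
    lr: float = 3e-4
    epochs: int = 10


class MyModel(ModelHubMixin):  # hypothetical model class
    def __init__(self, config: TrainConfig):
        self.config = config


# With the new `is_dataclass` branches, a dataclass init argument counts as
# jsonable and is encoded via `asdict` into the config saved next to the model:
print(MyModel._is_jsonable(TrainConfig()))  # True
print(MyModel._encode_arg(TrainConfig()))   # {'lr': 0.0003, 'epochs': 10}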
@@ -692,6 +705,7 @@ class ModelHubMixin:
             card_data=self._hub_mixin_info.model_card_data,
             template_str=self._hub_mixin_info.model_card_template,
             repo_url=self._hub_mixin_info.repo_url,
+            paper_url=self._hub_mixin_info.paper_url,
             docs_url=self._hub_mixin_info.docs_url,
             **kwargs,
         )

@@ -718,6 +732,7 @@ class PyTorchModelHubMixin(ModelHubMixin):
    ...     PyTorchModelHubMixin,
    ...     library_name="keras-nlp",
    ...     repo_url="https://github.com/keras-team/keras-nlp",
+    ...     paper_url="https://arxiv.org/abs/2304.12244",
    ...     docs_url="https://keras.io/keras_nlp/",
    ...     # ^ optional metadata to generate model card
    ... ):
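Taken together, the hub_mixin.py hunks thread a new `paper_url` through `MixinInfo`, the `__init_subclass__` kwargs, and the model card template. A hedged usage sketch (class name and URLs are illustrative):

import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin


class MyNet(
    nn.Module,
    PyTorchModelHubMixin,
    repo_url="https://github.com/me/my-net",       # rendered as "- Code: ..."
    paper_url="https://arxiv.org/abs/2304.12244",  # rendered as "- Paper: ..." (new)
    docs_url="https://example.com/my-net",         # rendered as "- Docs: ..."
):
    def __init__(self, hidden: int = 8):
        super().__init__()
        self.fc = nn.Linear(hidden, hidden)


# The generated model card now links Code, Paper and Docs, per the template
# change at the top of this file:
MyNet().push_to_hub("me/my-net")  # repo id is illustrative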
huggingface_hub/inference/_client.py
CHANGED

@@ -102,7 +102,8 @@ from huggingface_hub.inference._generated.types import (
 )
 from huggingface_hub.inference._providers import PROVIDER_T, HFInferenceTask, get_provider_helper
 from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status
-from huggingface_hub.utils.
+from huggingface_hub.utils._auth import get_token
+from huggingface_hub.utils._deprecation import _deprecate_method
 
 
 if TYPE_CHECKING:

@@ -132,12 +133,11 @@ class InferenceClient:
         path will be appended to the base URL (see the [TGI Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api)
         documentation for details). When passing a URL as `model`, the client will not append any suffix path to it.
     provider (`str`, *optional*):
-        Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"replicate"`, "sambanova"` or `"together"`.
+        Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"openai"`, `"replicate"`, "sambanova"` or `"together"`.
         defaults to hf-inference (Hugging Face Serverless Inference API).
         If model is a URL or `base_url` is passed, then `provider` is not used.
-    token (`str
+    token (`str`, *optional*):
         Hugging Face token. Will default to the locally saved token if not provided.
-        Pass `token=False` if you don't want to send your token to the server.
         Note: for better compatibility with OpenAI's client, `token` has been aliased as `api_key`. Those 2
         arguments are mutually exclusive and have the exact same behavior.
     timeout (`float`, `optional`):
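The provider docstring now lists `"openai"` (matching the new `_providers/openai.py` in the file list). A hedged sketch of what routing through it presumably looks like; the model id is illustrative, and the key is an OpenAI key rather than an `hf_` token, so billing goes through the provider directly:

import os

from huggingface_hub import InferenceClient

client = InferenceClient(provider="openai", api_key=os.environ["OPENAI_API_KEY"])
response = client.chat_completion(
    messages=[{"role": "user", "content": "Hello!"}],
    model="gpt-4o-mini",  # illustrative OpenAI model name
)
print(response.choices[0].message.content)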
@@ -146,6 +146,9 @@ class InferenceClient:
     headers (`Dict[str, str]`, `optional`):
         Additional headers to send to the server. By default only the authorization and user-agent headers are sent.
         Values in this dictionary will override the default values.
+    bill_to (`str`, `optional`):
+        The billing account to use for the requests. By default the requests are billed on the user's account.
+        Requests can only be billed to an organization the user is a member of, and which has subscribed to Enterprise Hub.
     cookies (`Dict[str, str]`, `optional`):
         Additional cookies to send to the server.
     proxies (`Any`, `optional`):

@@ -168,6 +171,7 @@ class InferenceClient:
         headers: Optional[Dict[str, str]] = None,
         cookies: Optional[Dict[str, str]] = None,
         proxies: Optional[Any] = None,
+        bill_to: Optional[str] = None,
         # OpenAI compatibility
         base_url: Optional[str] = None,
         api_key: Optional[str] = None,

@@ -185,10 +189,43 @@ class InferenceClient:
                 " `api_key` is an alias for `token` to make the API compatible with OpenAI's client."
                 " It has the exact same behavior as `token`."
             )
+        token = token if token is not None else api_key
+        if isinstance(token, bool):
+            # Legacy behavior: previously is was possible to pass `token=False` to disable authentication. This is not
+            # supported anymore as authentication is required. Better to explicitly raise here rather than risking
+            # sending the locally saved token without the user knowing about it.
+            if token is False:
+                raise ValueError(
+                    "Cannot use `token=False` to disable authentication as authentication is required to run Inference."
+                )
+            warnings.warn(
+                "Using `token=True` to automatically use the locally saved token is deprecated and will be removed in a future release. "
+                "Please use `token=None` instead (default).",
+                DeprecationWarning,
+            )
+            token = get_token()
 
         self.model: Optional[str] = base_url or model
-        self.token: Optional[str] = token
-
+        self.token: Optional[str] = token
+
+        self.headers = {**headers} if headers is not None else {}
+        if bill_to is not None:
+            if (
+                constants.HUGGINGFACE_HEADER_X_BILL_TO in self.headers
+                and self.headers[constants.HUGGINGFACE_HEADER_X_BILL_TO] != bill_to
+            ):
+                warnings.warn(
+                    f"Overriding existing '{self.headers[constants.HUGGINGFACE_HEADER_X_BILL_TO]}' value in headers with '{bill_to}'.",
+                    UserWarning,
+                )
+            self.headers[constants.HUGGINGFACE_HEADER_X_BILL_TO] = bill_to
+
+            if token is not None and not token.startswith("hf_"):
+                warnings.warn(
+                    "You've provided an external provider's API key, so requests will be billed directly by the provider. "
+                    "The `bill_to` parameter is only applicable for Hugging Face billing and will be ignored.",
+                    UserWarning,
+                )
 
         # Configure provider
         self.provider = provider if provider is not None else "hf-inference"
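Two behavioral changes land in `__init__` above: boolean tokens are either rejected (`token=False`) or deprecated (`token=True`), and the new `bill_to` parameter injects a billing header keyed by `constants.HUGGINGFACE_HEADER_X_BILL_TO`. A minimal sketch; the organization name is illustrative:

from huggingface_hub import InferenceClient

# Bill requests to an Enterprise Hub organization instead of the personal account:
client = InferenceClient(provider="hf-inference", bill_to="my-company")
print(client.headers)  # now contains the bill-to header set in __init__

# InferenceClient(token=False)  # raises ValueError as of this release
# InferenceClient(token=True)   # still works, but emits a DeprecationWarning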
@@ -300,33 +337,32 @@ class InferenceClient:
         if request_parameters.task in TASKS_EXPECTING_IMAGES and "Accept" not in request_parameters.headers:
             request_parameters.headers["Accept"] = "image/png"
 
-
-        with _open_as_binary(request_parameters.data) as data_as_binary:
-            try:
-                response = get_session().post(
-                    request_parameters.url,
-                    json=request_parameters.json,
-                    data=data_as_binary,
-                    headers=request_parameters.headers,
-                    cookies=self.cookies,
-                    timeout=self.timeout,
-                    stream=stream,
-                    proxies=self.proxies,
-                )
-            except TimeoutError as error:
-                # Convert any `TimeoutError` to a `InferenceTimeoutError`
-                raise InferenceTimeoutError(f"Inference call timed out: {request_parameters.url}") from error  # type: ignore
-
+        with _open_as_binary(request_parameters.data) as data_as_binary:
             try:
[… 9 removed lines not preserved in this extract …]
+                response = get_session().post(
+                    request_parameters.url,
+                    json=request_parameters.json,
+                    data=data_as_binary,
+                    headers=request_parameters.headers,
+                    cookies=self.cookies,
+                    timeout=self.timeout,
+                    stream=stream,
+                    proxies=self.proxies,
+                )
+            except TimeoutError as error:
+                # Convert any `TimeoutError` to a `InferenceTimeoutError`
+                raise InferenceTimeoutError(f"Inference call timed out: {request_parameters.url}") from error  # type: ignore
+
+            try:
+                hf_raise_for_status(response)
+                return response.iter_lines() if stream else response.content
+            except HTTPError as error:
+                if error.response.status_code == 422 and request_parameters.task != "unknown":
+                    msg = str(error.args[0])
+                    if len(error.response.text) > 0:
+                        msg += f"\n{error.response.text}\n"
+                    error.args = (msg,) + error.args[1:]
+                raise
 
     def audio_classification(
         self,
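Besides folding the error handling into the `with` block, the rewritten `_inner_post` appends the server's response body to 422 errors, which makes validation failures self-explanatory. A hedged sketch of what callers observe (`hf_raise_for_status` raises `HfHubHTTPError`, a `requests.HTTPError` subclass; the model id is illustrative):

from huggingface_hub import InferenceClient
from huggingface_hub.errors import HfHubHTTPError

client = InferenceClient()
try:
    client.text_classification("some text", model="my-org/some-model")  # illustrative
except HfHubHTTPError as err:
    print(err)  # on a 422, the message now ends with the server's error payload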
@@ -910,7 +946,7 @@ class InferenceClient:
    ...     messages=messages,
    ...     response_format=response_format,
    ...     max_tokens=500,
-    )
+    ... )
    >>> response.choices[0].message.content
    '{\n\n"activity": "bike ride",\n"animals": ["puppy", "cat", "raccoon"],\n"animals_seen": 3,\n"location": "park"}'
    ```

@@ -1272,7 +1308,7 @@ class InferenceClient:
    [ImageSegmentationOutputElement(score=0.989008, label='LABEL_184', mask=<PIL.PngImagePlugin.PngImageFile image mode=L size=400x300 at 0x7FDD2B129CC0>), ...]
    ```
    """
-        provider_helper = get_provider_helper(self.provider, task="
+        provider_helper = get_provider_helper(self.provider, task="image-segmentation")
         request_parameters = provider_helper.prepare_request(
             inputs=image,
             parameters={

@@ -2602,7 +2638,7 @@ class InferenceClient:
             api_key=self.token,
         )
         response = self._inner_post(request_parameters)
-        response = provider_helper.get_response(response)
+        response = provider_helper.get_response(response, request_parameters)
         return response
 
     def text_to_speech(

@@ -3033,22 +3069,14 @@ class InferenceClient:
         response = self._inner_post(request_parameters)
         return VisualQuestionAnsweringOutputElement.parse_obj_as_list(response)
 
-    @_deprecate_arguments(
-        version="0.30.0",
-        deprecated_args=["labels"],
-        custom_message="`labels`has been renamed to `candidate_labels` and will be removed in huggingface_hub>=0.30.0.",
-    )
     def zero_shot_classification(
         self,
         text: str,
-
-        candidate_labels: List[str] = None,  # type: ignore
+        candidate_labels: List[str],
         *,
         multi_label: Optional[bool] = False,
         hypothesis_template: Optional[str] = None,
         model: Optional[str] = None,
-        # deprecated argument
-        labels: List[str] = None,  # type: ignore
     ) -> List[ZeroShotClassificationOutputElement]:
        """
        Provide as input a text and a set of candidate labels to classify the input text.

@@ -3127,16 +3155,6 @@ class InferenceClient:
        ]
        ```
        """
-        # handle deprecation
-        if labels is not None:
-            if candidate_labels is not None:
-                raise ValueError(
-                    "Cannot specify both `labels` and `candidate_labels`. Use `candidate_labels` instead."
-                )
-            candidate_labels = labels
-        elif candidate_labels is None:
-            raise ValueError("Must specify `candidate_labels`")
-
         provider_helper = get_provider_helper(self.provider, task="zero-shot-classification")
         request_parameters = provider_helper.prepare_request(
             inputs=text,
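With the deprecation shims above removed, `labels=` is gone and `candidate_labels` is a plain required argument. A short migration sketch:

from huggingface_hub import InferenceClient

client = InferenceClient()
result = client.zero_shot_classification(
    "I really enjoyed this movie",
    candidate_labels=["positive", "negative"],  # formerly also accepted as `labels=`
)
print(result)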
@@ -3156,16 +3174,10 @@ class InferenceClient:
                for label, score in zip(output["labels"], output["scores"])
            ]
 
-    @_deprecate_arguments(
-        version="0.30.0",
-        deprecated_args=["labels"],
-        custom_message="`labels`has been renamed to `candidate_labels` and will be removed in huggingface_hub>=0.30.0.",
-    )
     def zero_shot_image_classification(
         self,
         image: ContentT,
-
-        candidate_labels: List[str] = None,  # type: ignore
+        candidate_labels: List[str],
         *,
         model: Optional[str] = None,
         hypothesis_template: Optional[str] = None,

@@ -3210,15 +3222,6 @@ class InferenceClient:
        [ZeroShotImageClassificationOutputElement(label='dog', score=0.956),...]
        ```
        """
-        # handle deprecation
-        if labels is not None:
-            if candidate_labels is not None:
-                raise ValueError(
-                    "Cannot specify both `labels` and `candidate_labels`. Use `candidate_labels` instead."
-                )
-            candidate_labels = labels
-        elif candidate_labels is None:
-            raise ValueError("Must specify `candidate_labels`")
         # Raise ValueError if input is less than 2 labels
         if len(candidate_labels) < 2:
             raise ValueError("You must specify at least 2 classes to compare.")
huggingface_hub/inference/_generated/_async_client.py
CHANGED

@@ -87,7 +87,8 @@ from huggingface_hub.inference._generated.types import (
 )
 from huggingface_hub.inference._providers import PROVIDER_T, HFInferenceTask, get_provider_helper
 from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status
-from huggingface_hub.utils.
+from huggingface_hub.utils._auth import get_token
+from huggingface_hub.utils._deprecation import _deprecate_method
 
 from .._common import _async_yield_from, _import_aiohttp
 

@@ -120,12 +121,11 @@ class AsyncInferenceClient:
         path will be appended to the base URL (see the [TGI Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api)
         documentation for details). When passing a URL as `model`, the client will not append any suffix path to it.
     provider (`str`, *optional*):
-        Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"replicate"`, "sambanova"` or `"together"`.
+        Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"openai"`, `"replicate"`, "sambanova"` or `"together"`.
         defaults to hf-inference (Hugging Face Serverless Inference API).
         If model is a URL or `base_url` is passed, then `provider` is not used.
-    token (`str
+    token (`str`, *optional*):
         Hugging Face token. Will default to the locally saved token if not provided.
-        Pass `token=False` if you don't want to send your token to the server.
         Note: for better compatibility with OpenAI's client, `token` has been aliased as `api_key`. Those 2
         arguments are mutually exclusive and have the exact same behavior.
     timeout (`float`, `optional`):

@@ -134,6 +134,9 @@ class AsyncInferenceClient:
     headers (`Dict[str, str]`, `optional`):
         Additional headers to send to the server. By default only the authorization and user-agent headers are sent.
         Values in this dictionary will override the default values.
+    bill_to (`str`, `optional`):
+        The billing account to use for the requests. By default the requests are billed on the user's account.
+        Requests can only be billed to an organization the user is a member of, and which has subscribed to Enterprise Hub.
     cookies (`Dict[str, str]`, `optional`):
         Additional cookies to send to the server.
     trust_env ('bool', 'optional'):

@@ -159,6 +162,7 @@ class AsyncInferenceClient:
         cookies: Optional[Dict[str, str]] = None,
         trust_env: bool = False,
         proxies: Optional[Any] = None,
+        bill_to: Optional[str] = None,
         # OpenAI compatibility
         base_url: Optional[str] = None,
         api_key: Optional[str] = None,

@@ -176,10 +180,43 @@ class AsyncInferenceClient:
                 " `api_key` is an alias for `token` to make the API compatible with OpenAI's client."
                 " It has the exact same behavior as `token`."
             )
+        token = token if token is not None else api_key
+        if isinstance(token, bool):
+            # Legacy behavior: previously is was possible to pass `token=False` to disable authentication. This is not
+            # supported anymore as authentication is required. Better to explicitly raise here rather than risking
+            # sending the locally saved token without the user knowing about it.
+            if token is False:
+                raise ValueError(
+                    "Cannot use `token=False` to disable authentication as authentication is required to run Inference."
+                )
+            warnings.warn(
+                "Using `token=True` to automatically use the locally saved token is deprecated and will be removed in a future release. "
+                "Please use `token=None` instead (default).",
+                DeprecationWarning,
+            )
+            token = get_token()
 
         self.model: Optional[str] = base_url or model
-        self.token: Optional[str] = token
-
+        self.token: Optional[str] = token
+
+        self.headers = {**headers} if headers is not None else {}
+        if bill_to is not None:
+            if (
+                constants.HUGGINGFACE_HEADER_X_BILL_TO in self.headers
+                and self.headers[constants.HUGGINGFACE_HEADER_X_BILL_TO] != bill_to
+            ):
+                warnings.warn(
+                    f"Overriding existing '{self.headers[constants.HUGGINGFACE_HEADER_X_BILL_TO]}' value in headers with '{bill_to}'.",
+                    UserWarning,
+                )
+            self.headers[constants.HUGGINGFACE_HEADER_X_BILL_TO] = bill_to
+
+            if token is not None and not token.startswith("hf_"):
+                warnings.warn(
+                    "You've provided an external provider's API key, so requests will be billed directly by the provider. "
+                    "The `bill_to` parameter is only applicable for Hugging Face billing and will be ignored.",
+                    UserWarning,
+                )
 
         # Configure provider
         self.provider = provider if provider is not None else "hf-inference"

@@ -298,40 +335,39 @@ class AsyncInferenceClient:
         if request_parameters.task in TASKS_EXPECTING_IMAGES and "Accept" not in request_parameters.headers:
             request_parameters.headers["Accept"] = "image/png"
 
[… 20 removed lines not preserved in this extract …]
-                content = await response.read()
-                await session.close()
-                return content
-            except asyncio.TimeoutError as error:
-                await session.close()
-                # Convert any `TimeoutError` to a `InferenceTimeoutError`
-                raise InferenceTimeoutError(f"Inference call timed out: {request_parameters.url}") from error  # type: ignore
-            except aiohttp.ClientResponseError as error:
-                error.response_error_payload = response_error_payload
-                await session.close()
-                raise error
-            except Exception:
+        with _open_as_binary(request_parameters.data) as data_as_binary:
+            # Do not use context manager as we don't want to close the connection immediately when returning
+            # a stream
+            session = self._get_client_session(headers=request_parameters.headers)
+
+            try:
+                response = await session.post(
+                    request_parameters.url, json=request_parameters.json, data=data_as_binary, proxy=self.proxies
+                )
+                response_error_payload = None
+                if response.status != 200:
+                    try:
+                        response_error_payload = await response.json()  # get payload before connection closed
+                    except Exception:
+                        pass
+                response.raise_for_status()
+                if stream:
+                    return _async_yield_from(session, response)
+                else:
+                    content = await response.read()
                     await session.close()
-
+                    return content
+            except asyncio.TimeoutError as error:
+                await session.close()
+                # Convert any `TimeoutError` to a `InferenceTimeoutError`
+                raise InferenceTimeoutError(f"Inference call timed out: {request_parameters.url}") from error  # type: ignore
+            except aiohttp.ClientResponseError as error:
+                error.response_error_payload = response_error_payload
+                await session.close()
+                raise error
+            except Exception:
+                await session.close()
+                raise
 
     async def __aenter__(self):
         return self
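Note how the async `_inner_post` above deliberately avoids a session context manager: when `stream=True`, the open `aiohttp` session is handed to `_async_yield_from` together with the response, so it can be closed only after the stream has been consumed. A hedged usage sketch (model id is illustrative):

import asyncio

from huggingface_hub import AsyncInferenceClient


async def main():
    client = AsyncInferenceClient()  # assumes a locally saved HF token
    stream = await client.chat_completion(
        messages=[{"role": "user", "content": "Say hello"}],
        model="meta-llama/Llama-3.1-8B-Instruct",  # illustrative
        stream=True,
        max_tokens=20,
    )
    async for chunk in stream:  # session stays open until the stream is exhausted
        print(chunk.choices[0].delta.content or "", end="")


asyncio.run(main())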
@@ -950,7 +986,7 @@ class AsyncInferenceClient:
    ...     messages=messages,
    ...     response_format=response_format,
    ...     max_tokens=500,
-    )
+    ... )
    >>> response.choices[0].message.content
    '{\n\n"activity": "bike ride",\n"animals": ["puppy", "cat", "raccoon"],\n"animals_seen": 3,\n"location": "park"}'
    ```

@@ -1317,7 +1353,7 @@ class AsyncInferenceClient:
    [ImageSegmentationOutputElement(score=0.989008, label='LABEL_184', mask=<PIL.PngImagePlugin.PngImageFile image mode=L size=400x300 at 0x7FDD2B129CC0>), ...]
    ```
    """
-        provider_helper = get_provider_helper(self.provider, task="
+        provider_helper = get_provider_helper(self.provider, task="image-segmentation")
         request_parameters = provider_helper.prepare_request(
             inputs=image,
             parameters={

@@ -2659,7 +2695,7 @@ class AsyncInferenceClient:
             api_key=self.token,
         )
         response = await self._inner_post(request_parameters)
-        response = provider_helper.get_response(response)
+        response = provider_helper.get_response(response, request_parameters)
         return response
 
     async def text_to_speech(

@@ -3094,22 +3130,14 @@ class AsyncInferenceClient:
         response = await self._inner_post(request_parameters)
         return VisualQuestionAnsweringOutputElement.parse_obj_as_list(response)
 
-    @_deprecate_arguments(
-        version="0.30.0",
-        deprecated_args=["labels"],
-        custom_message="`labels`has been renamed to `candidate_labels` and will be removed in huggingface_hub>=0.30.0.",
-    )
     async def zero_shot_classification(
         self,
         text: str,
-
-        candidate_labels: List[str] = None,  # type: ignore
+        candidate_labels: List[str],
         *,
         multi_label: Optional[bool] = False,
         hypothesis_template: Optional[str] = None,
         model: Optional[str] = None,
-        # deprecated argument
-        labels: List[str] = None,  # type: ignore
     ) -> List[ZeroShotClassificationOutputElement]:
        """
        Provide as input a text and a set of candidate labels to classify the input text.

@@ -3190,16 +3218,6 @@ class AsyncInferenceClient:
        ]
        ```
        """
-        # handle deprecation
-        if labels is not None:
-            if candidate_labels is not None:
-                raise ValueError(
-                    "Cannot specify both `labels` and `candidate_labels`. Use `candidate_labels` instead."
-                )
-            candidate_labels = labels
-        elif candidate_labels is None:
-            raise ValueError("Must specify `candidate_labels`")
-
         provider_helper = get_provider_helper(self.provider, task="zero-shot-classification")
         request_parameters = provider_helper.prepare_request(
             inputs=text,

@@ -3219,16 +3237,10 @@ class AsyncInferenceClient:
                for label, score in zip(output["labels"], output["scores"])
            ]
 
-    @_deprecate_arguments(
-        version="0.30.0",
-        deprecated_args=["labels"],
-        custom_message="`labels`has been renamed to `candidate_labels` and will be removed in huggingface_hub>=0.30.0.",
-    )
     async def zero_shot_image_classification(
         self,
         image: ContentT,
-
-        candidate_labels: List[str] = None,  # type: ignore
+        candidate_labels: List[str],
         *,
         model: Optional[str] = None,
         hypothesis_template: Optional[str] = None,

@@ -3274,15 +3286,6 @@ class AsyncInferenceClient:
        [ZeroShotImageClassificationOutputElement(label='dog', score=0.956),...]
        ```
        """
-        # handle deprecation
-        if labels is not None:
-            if candidate_labels is not None:
-                raise ValueError(
-                    "Cannot specify both `labels` and `candidate_labels`. Use `candidate_labels` instead."
-                )
-            candidate_labels = labels
-        elif candidate_labels is None:
-            raise ValueError("Must specify `candidate_labels`")
         # Raise ValueError if input is less than 2 labels
         if len(candidate_labels) < 2:
             raise ValueError("You must specify at least 2 classes to compare.")
huggingface_hub/inference/_generated/types/__init__.py
CHANGED

@@ -30,6 +30,7 @@ from .chat_completion import (
     ChatCompletionInputMessageChunkType,
     ChatCompletionInputStreamOptions,
     ChatCompletionInputTool,
+    ChatCompletionInputToolCall,
     ChatCompletionInputToolChoiceClass,
     ChatCompletionInputToolChoiceEnum,
     ChatCompletionInputURL,