huggingface_hub-0.29.3rc0-py3-none-any.whl → huggingface_hub-0.30.0rc1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- huggingface_hub/__init__.py +16 -1
- huggingface_hub/_commit_api.py +142 -4
- huggingface_hub/_space_api.py +15 -2
- huggingface_hub/_webhooks_server.py +2 -0
- huggingface_hub/commands/delete_cache.py +66 -20
- huggingface_hub/commands/upload.py +16 -2
- huggingface_hub/constants.py +44 -7
- huggingface_hub/errors.py +19 -0
- huggingface_hub/file_download.py +163 -35
- huggingface_hub/hf_api.py +349 -28
- huggingface_hub/hub_mixin.py +19 -4
- huggingface_hub/inference/_client.py +50 -69
- huggingface_hub/inference/_generated/_async_client.py +57 -76
- huggingface_hub/inference/_generated/types/__init__.py +1 -0
- huggingface_hub/inference/_generated/types/chat_completion.py +20 -10
- huggingface_hub/inference/_generated/types/image_to_image.py +2 -0
- huggingface_hub/inference/_providers/__init__.py +7 -1
- huggingface_hub/inference/_providers/_common.py +9 -5
- huggingface_hub/inference/_providers/black_forest_labs.py +5 -5
- huggingface_hub/inference/_providers/cohere.py +1 -1
- huggingface_hub/inference/_providers/fal_ai.py +64 -7
- huggingface_hub/inference/_providers/fireworks_ai.py +4 -1
- huggingface_hub/inference/_providers/hf_inference.py +41 -4
- huggingface_hub/inference/_providers/hyperbolic.py +3 -3
- huggingface_hub/inference/_providers/nebius.py +3 -3
- huggingface_hub/inference/_providers/novita.py +35 -5
- huggingface_hub/inference/_providers/openai.py +22 -0
- huggingface_hub/inference/_providers/replicate.py +3 -3
- huggingface_hub/inference/_providers/together.py +3 -3
- huggingface_hub/utils/__init__.py +8 -0
- huggingface_hub/utils/_http.py +4 -1
- huggingface_hub/utils/_runtime.py +11 -0
- huggingface_hub/utils/_xet.py +199 -0
- huggingface_hub/utils/tqdm.py +30 -2
- {huggingface_hub-0.29.3rc0.dist-info → huggingface_hub-0.30.0rc1.dist-info}/METADATA +3 -1
- {huggingface_hub-0.29.3rc0.dist-info → huggingface_hub-0.30.0rc1.dist-info}/RECORD +40 -38
- {huggingface_hub-0.29.3rc0.dist-info → huggingface_hub-0.30.0rc1.dist-info}/LICENSE +0 -0
- {huggingface_hub-0.29.3rc0.dist-info → huggingface_hub-0.30.0rc1.dist-info}/WHEEL +0 -0
- {huggingface_hub-0.29.3rc0.dist-info → huggingface_hub-0.30.0rc1.dist-info}/entry_points.txt +0 -0
- {huggingface_hub-0.29.3rc0.dist-info → huggingface_hub-0.30.0rc1.dist-info}/top_level.txt +0 -0
huggingface_hub/hub_mixin.py
CHANGED
@@ -58,7 +58,8 @@ DEFAULT_MODEL_CARD = """
 ---
 
 This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
-- Library: {{ repo_url | default("[More Information Needed]", true) }}
+- Code: {{ repo_url | default("[More Information Needed]", true) }}
+- Paper: {{ paper_url | default("[More Information Needed]", true) }}
 - Docs: {{ docs_url | default("[More Information Needed]", true) }}
 """
 
@@ -67,8 +68,9 @@ This model has been pushed to the Hub using the [PytorchModelHubMixin](https://h
 class MixinInfo:
     model_card_template: str
     model_card_data: ModelCardData
-    repo_url: Optional[str] = None
     docs_url: Optional[str] = None
+    paper_url: Optional[str] = None
+    repo_url: Optional[str] = None
 
 
 class ModelHubMixin:
@@ -88,6 +90,8 @@ class ModelHubMixin:
         Args:
             repo_url (`str`, *optional*):
                 URL of the library repository. Used to generate model card.
+            paper_url (`str`, *optional*):
+                URL of the library paper. Used to generate model card.
             docs_url (`str`, *optional*):
                 URL of the library documentation. Used to generate model card.
             model_card_template (`str`, *optional*):
@@ -110,7 +114,7 @@ class ModelHubMixin:
             pipeline_tag (`str`, *optional*):
                 Tag of the pipeline. Used to generate model card. E.g. "text-classification".
             tags (`List[str]`, *optional*):
-                Tags to be added to the model card. Used to generate model card. E.g. ["
+                Tags to be added to the model card. Used to generate model card. E.g. ["computer-vision"]
             coders (`Dict[Type, Tuple[Callable, Callable]]`, *optional*):
                 Dictionary of custom types and their encoders/decoders. Used to encode/decode arguments that are not
                 jsonable by default. E.g dataclasses, argparse.Namespace, OmegaConf, etc.
@@ -124,8 +128,9 @@ class ModelHubMixin:
         >>> class MyCustomModel(
         ...     ModelHubMixin,
         ...     library_name="my-library",
-        ...     tags=["
+        ...     tags=["computer-vision"],
         ...     repo_url="https://github.com/huggingface/my-cool-library",
+        ...     paper_url="https://arxiv.org/abs/2304.12244",
         ...     docs_url="https://huggingface.co/docs/my-cool-library",
         ...     # ^ optional metadata to generate model card
         ... ):
@@ -194,6 +199,7 @@ class ModelHubMixin:
         *,
         # Generic info for model card
         repo_url: Optional[str] = None,
+        paper_url: Optional[str] = None,
         docs_url: Optional[str] = None,
         # Model card template
         model_card_template: str = DEFAULT_MODEL_CARD,
@@ -234,6 +240,7 @@ class ModelHubMixin:
 
             # Inherit other info
             info.docs_url = cls._hub_mixin_info.docs_url
+            info.paper_url = cls._hub_mixin_info.paper_url
             info.repo_url = cls._hub_mixin_info.repo_url
             cls._hub_mixin_info = info
 
@@ -242,6 +249,8 @@ class ModelHubMixin:
             info.model_card_template = model_card_template
         if repo_url is not None:
             info.repo_url = repo_url
+        if paper_url is not None:
+            info.paper_url = paper_url
         if docs_url is not None:
             info.docs_url = docs_url
         if language is not None:
@@ -334,6 +343,8 @@ class ModelHubMixin:
     @classmethod
     def _is_jsonable(cls, value: Any) -> bool:
         """Check if a value is JSON serializable."""
+        if is_dataclass(value):
+            return True
         if isinstance(value, cls._hub_mixin_jsonable_custom_types):
             return True
         return is_jsonable(value)
@@ -341,6 +352,8 @@ class ModelHubMixin:
     @classmethod
     def _encode_arg(cls, arg: Any) -> Any:
         """Encode an argument into a JSON serializable format."""
+        if is_dataclass(arg):
+            return asdict(arg)
         for type_, (encoder, _) in cls._hub_mixin_coders.items():
             if isinstance(arg, type_):
                 if arg is None:
@@ -692,6 +705,7 @@ class ModelHubMixin:
             card_data=self._hub_mixin_info.model_card_data,
             template_str=self._hub_mixin_info.model_card_template,
             repo_url=self._hub_mixin_info.repo_url,
+            paper_url=self._hub_mixin_info.paper_url,
             docs_url=self._hub_mixin_info.docs_url,
             **kwargs,
         )
@@ -718,6 +732,7 @@ class PyTorchModelHubMixin(ModelHubMixin):
         ...     PyTorchModelHubMixin,
         ...     library_name="keras-nlp",
         ...     repo_url="https://github.com/keras-team/keras-nlp",
+        ...     paper_url="https://arxiv.org/abs/2304.12244",
         ...     docs_url="https://keras.io/keras_nlp/",
         ...     # ^ optional metadata to generate model card
         ... ):
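Taken together, the hub_mixin.py hunks thread a new optional `paper_url` through `__init_subclass__`, `MixinInfo`, and the model card template. A minimal sketch of how a library opts in, reusing the URLs from the diff's own docstring example (the model body itself is illustrative):

import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin

class MyModel(
    nn.Module,
    PyTorchModelHubMixin,
    library_name="my-library",
    repo_url="https://github.com/huggingface/my-cool-library",
    paper_url="https://arxiv.org/abs/2304.12244",  # new in 0.30.0
    docs_url="https://huggingface.co/docs/my-cool-library",
):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 1)

The generated card then renders a "- Paper: ..." line between the Code and Docs lines, per the template change above.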
huggingface_hub/inference/_client.py
CHANGED

@@ -102,7 +102,8 @@ from huggingface_hub.inference._generated.types import (
 )
 from huggingface_hub.inference._providers import PROVIDER_T, HFInferenceTask, get_provider_helper
 from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status
-from huggingface_hub.utils.
+from huggingface_hub.utils._auth import get_token
+from huggingface_hub.utils._deprecation import _deprecate_method
 
 
 if TYPE_CHECKING:
@@ -132,12 +133,11 @@ class InferenceClient:
             path will be appended to the base URL (see the [TGI Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api)
             documentation for details). When passing a URL as `model`, the client will not append any suffix path to it.
         provider (`str`, *optional*):
-            Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"replicate"`, "sambanova"` or `"together"`.
+            Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"openai"`, `"replicate"`, "sambanova"` or `"together"`.
             defaults to hf-inference (Hugging Face Serverless Inference API).
             If model is a URL or `base_url` is passed, then `provider` is not used.
-        token (`str
+        token (`str`, *optional*):
             Hugging Face token. Will default to the locally saved token if not provided.
-            Pass `token=False` if you don't want to send your token to the server.
             Note: for better compatibility with OpenAI's client, `token` has been aliased as `api_key`. Those 2
             arguments are mutually exclusive and have the exact same behavior.
         timeout (`float`, `optional`):
@@ -185,9 +185,24 @@ class InferenceClient:
                 " `api_key` is an alias for `token` to make the API compatible with OpenAI's client."
                 " It has the exact same behavior as `token`."
             )
+        token = token if token is not None else api_key
+        if isinstance(token, bool):
+            # Legacy behavior: previously it was possible to pass `token=False` to disable authentication. This is not
+            # supported anymore as authentication is required. Better to explicitly raise here rather than risking
+            # sending the locally saved token without the user knowing about it.
+            if token is False:
+                raise ValueError(
+                    "Cannot use `token=False` to disable authentication as authentication is required to run Inference."
+                )
+            warnings.warn(
+                "Using `token=True` to automatically use the locally saved token is deprecated and will be removed in a future release. "
+                "Please use `token=None` instead (default).",
+                DeprecationWarning,
+            )
+            token = get_token()
 
         self.model: Optional[str] = base_url or model
-        self.token: Optional[str] = token
+        self.token: Optional[str] = token
         self.headers = headers if headers is not None else {}
 
         # Configure provider
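In practice the new branch changes what callers may pass as `token`; a sketch of the three cases (behavior follows the branches above):

from huggingface_hub import InferenceClient

client = InferenceClient()             # token=None (default): falls back to the locally saved token
client = InferenceClient(token=True)   # DeprecationWarning, then resolves via get_token()
# client = InferenceClient(token=False)  # would raise ValueError: authentication is now required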
@@ -300,33 +315,32 @@ class InferenceClient:
         if request_parameters.task in TASKS_EXPECTING_IMAGES and "Accept" not in request_parameters.headers:
             request_parameters.headers["Accept"] = "image/png"
 
-
-        with _open_as_binary(request_parameters.data) as data_as_binary:
-            try:
-                response = get_session().post(
-                    request_parameters.url,
-                    json=request_parameters.json,
-                    data=data_as_binary,
-                    headers=request_parameters.headers,
-                    cookies=self.cookies,
-                    timeout=self.timeout,
-                    stream=stream,
-                    proxies=self.proxies,
-                )
-            except TimeoutError as error:
-                # Convert any `TimeoutError` to a `InferenceTimeoutError`
-                raise InferenceTimeoutError(f"Inference call timed out: {request_parameters.url}") from error  # type: ignore
-
+        with _open_as_binary(request_parameters.data) as data_as_binary:
             try:
-
-
-
-
-
-
-
-
-
+                response = get_session().post(
+                    request_parameters.url,
+                    json=request_parameters.json,
+                    data=data_as_binary,
+                    headers=request_parameters.headers,
+                    cookies=self.cookies,
+                    timeout=self.timeout,
+                    stream=stream,
+                    proxies=self.proxies,
+                )
+            except TimeoutError as error:
+                # Convert any `TimeoutError` to a `InferenceTimeoutError`
+                raise InferenceTimeoutError(f"Inference call timed out: {request_parameters.url}") from error  # type: ignore
+
+        try:
+            hf_raise_for_status(response)
+            return response.iter_lines() if stream else response.content
+        except HTTPError as error:
+            if error.response.status_code == 422 and request_parameters.task != "unknown":
+                msg = str(error.args[0])
+                if len(error.response.text) > 0:
+                    msg += f"\n{error.response.text}\n"
+                error.args = (msg,) + error.args[1:]
+            raise
 
     def audio_classification(
         self,
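A consequence of the new `except HTTPError` branch: HTTP 422 validation failures on a known task now carry the server's response body in the raised error. A sketch (client, text, and model are illustrative):

from requests import HTTPError

try:
    client.text_classification(text, model="my-org/my-model")
except HTTPError as err:
    # The message now ends with the 422 response text explaining the invalid payload
    print(err)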
@@ -910,7 +924,7 @@ class InferenceClient:
         ...     messages=messages,
         ...     response_format=response_format,
         ...     max_tokens=500,
-        )
+        ... )
         >>> response.choices[0].message.content
         '{\n\n"activity": "bike ride",\n"animals": ["puppy", "cat", "raccoon"],\n"animals_seen": 3,\n"location": "park"}'
         ```
@@ -1272,7 +1286,7 @@ class InferenceClient:
         [ImageSegmentationOutputElement(score=0.989008, label='LABEL_184', mask=<PIL.PngImagePlugin.PngImageFile image mode=L size=400x300 at 0x7FDD2B129CC0>), ...]
         ```
         """
-        provider_helper = get_provider_helper(self.provider, task="
+        provider_helper = get_provider_helper(self.provider, task="image-segmentation")
         request_parameters = provider_helper.prepare_request(
             inputs=image,
             parameters={
@@ -2602,7 +2616,7 @@ class InferenceClient:
             api_key=self.token,
         )
         response = self._inner_post(request_parameters)
-        response = provider_helper.get_response(response)
+        response = provider_helper.get_response(response, request_parameters)
         return response
 
     def text_to_speech(
@@ -3033,22 +3047,14 @@ class InferenceClient:
         response = self._inner_post(request_parameters)
         return VisualQuestionAnsweringOutputElement.parse_obj_as_list(response)
 
-    @_deprecate_arguments(
-        version="0.30.0",
-        deprecated_args=["labels"],
-        custom_message="`labels`has been renamed to `candidate_labels` and will be removed in huggingface_hub>=0.30.0.",
-    )
     def zero_shot_classification(
         self,
         text: str,
-
-        candidate_labels: List[str] = None,  # type: ignore
+        candidate_labels: List[str],
         *,
         multi_label: Optional[bool] = False,
         hypothesis_template: Optional[str] = None,
         model: Optional[str] = None,
-        # deprecated argument
-        labels: List[str] = None,  # type: ignore
     ) -> List[ZeroShotClassificationOutputElement]:
         """
         Provide as input a text and a set of candidate labels to classify the input text.
@@ -3127,16 +3133,6 @@ class InferenceClient:
         ]
         ```
         """
-        # handle deprecation
-        if labels is not None:
-            if candidate_labels is not None:
-                raise ValueError(
-                    "Cannot specify both `labels` and `candidate_labels`. Use `candidate_labels` instead."
-                )
-            candidate_labels = labels
-        elif candidate_labels is None:
-            raise ValueError("Must specify `candidate_labels`")
-
         provider_helper = get_provider_helper(self.provider, task="zero-shot-classification")
         request_parameters = provider_helper.prepare_request(
             inputs=text,
@@ -3156,16 +3152,10 @@ class InferenceClient:
             for label, score in zip(output["labels"], output["scores"])
         ]
 
-    @_deprecate_arguments(
-        version="0.30.0",
-        deprecated_args=["labels"],
-        custom_message="`labels`has been renamed to `candidate_labels` and will be removed in huggingface_hub>=0.30.0.",
-    )
     def zero_shot_image_classification(
         self,
         image: ContentT,
-
-        candidate_labels: List[str] = None,  # type: ignore
+        candidate_labels: List[str],
         *,
         model: Optional[str] = None,
         hypothesis_template: Optional[str] = None,
@@ -3210,15 +3200,6 @@ class InferenceClient:
         [ZeroShotImageClassificationOutputElement(label='dog', score=0.956),...]
         ```
         """
-        # handle deprecation
-        if labels is not None:
-            if candidate_labels is not None:
-                raise ValueError(
-                    "Cannot specify both `labels` and `candidate_labels`. Use `candidate_labels` instead."
-                )
-            candidate_labels = labels
-        elif candidate_labels is None:
-            raise ValueError("Must specify `candidate_labels`")
         # Raise ValueError if input is less than 2 labels
         if len(candidate_labels) < 2:
             raise ValueError("You must specify at least 2 classes to compare.")
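With the deprecation shims removed from both zero-shot methods, `candidate_labels` is now required and the old `labels=` spelling is rejected outright. A minimal call sketch (inputs illustrative):

results = client.zero_shot_classification(
    "I love the new release!",
    candidate_labels=["positive", "negative", "neutral"],  # `labels=` no longer exists
)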
huggingface_hub/inference/_generated/_async_client.py
CHANGED

@@ -87,7 +87,8 @@ from huggingface_hub.inference._generated.types import (
 )
 from huggingface_hub.inference._providers import PROVIDER_T, HFInferenceTask, get_provider_helper
 from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status
-from huggingface_hub.utils.
+from huggingface_hub.utils._auth import get_token
+from huggingface_hub.utils._deprecation import _deprecate_method
 
 from .._common import _async_yield_from, _import_aiohttp
 
@@ -120,12 +121,11 @@ class AsyncInferenceClient:
             path will be appended to the base URL (see the [TGI Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api)
             documentation for details). When passing a URL as `model`, the client will not append any suffix path to it.
         provider (`str`, *optional*):
-            Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"replicate"`, "sambanova"` or `"together"`.
+            Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"openai"`, `"replicate"`, "sambanova"` or `"together"`.
             defaults to hf-inference (Hugging Face Serverless Inference API).
             If model is a URL or `base_url` is passed, then `provider` is not used.
-        token (`str
+        token (`str`, *optional*):
             Hugging Face token. Will default to the locally saved token if not provided.
-            Pass `token=False` if you don't want to send your token to the server.
             Note: for better compatibility with OpenAI's client, `token` has been aliased as `api_key`. Those 2
             arguments are mutually exclusive and have the exact same behavior.
         timeout (`float`, `optional`):
@@ -176,9 +176,24 @@ class AsyncInferenceClient:
                 " `api_key` is an alias for `token` to make the API compatible with OpenAI's client."
                 " It has the exact same behavior as `token`."
             )
+        token = token if token is not None else api_key
+        if isinstance(token, bool):
+            # Legacy behavior: previously it was possible to pass `token=False` to disable authentication. This is not
+            # supported anymore as authentication is required. Better to explicitly raise here rather than risking
+            # sending the locally saved token without the user knowing about it.
+            if token is False:
+                raise ValueError(
+                    "Cannot use `token=False` to disable authentication as authentication is required to run Inference."
+                )
+            warnings.warn(
+                "Using `token=True` to automatically use the locally saved token is deprecated and will be removed in a future release. "
+                "Please use `token=None` instead (default).",
+                DeprecationWarning,
+            )
+            token = get_token()
 
         self.model: Optional[str] = base_url or model
-        self.token: Optional[str] = token
+        self.token: Optional[str] = token
         self.headers = headers if headers is not None else {}
 
         # Configure provider
@@ -298,40 +313,39 @@ class AsyncInferenceClient:
         if request_parameters.task in TASKS_EXPECTING_IMAGES and "Accept" not in request_parameters.headers:
             request_parameters.headers["Accept"] = "image/png"
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                    content = await response.read()
-                    await session.close()
-                    return content
-            except asyncio.TimeoutError as error:
-                await session.close()
-                # Convert any `TimeoutError` to a `InferenceTimeoutError`
-                raise InferenceTimeoutError(f"Inference call timed out: {request_parameters.url}") from error  # type: ignore
-            except aiohttp.ClientResponseError as error:
-                error.response_error_payload = response_error_payload
-                await session.close()
-                raise error
-            except Exception:
+        with _open_as_binary(request_parameters.data) as data_as_binary:
+            # Do not use context manager as we don't want to close the connection immediately when returning
+            # a stream
+            session = self._get_client_session(headers=request_parameters.headers)
+
+            try:
+                response = await session.post(
+                    request_parameters.url, json=request_parameters.json, data=data_as_binary, proxy=self.proxies
+                )
+                response_error_payload = None
+                if response.status != 200:
+                    try:
+                        response_error_payload = await response.json()  # get payload before connection closed
+                    except Exception:
+                        pass
+                response.raise_for_status()
+                if stream:
+                    return _async_yield_from(session, response)
+                else:
+                    content = await response.read()
                     await session.close()
-
+                    return content
+            except asyncio.TimeoutError as error:
+                await session.close()
+                # Convert any `TimeoutError` to a `InferenceTimeoutError`
+                raise InferenceTimeoutError(f"Inference call timed out: {request_parameters.url}") from error  # type: ignore
+            except aiohttp.ClientResponseError as error:
+                error.response_error_payload = response_error_payload
+                await session.close()
+                raise error
+            except Exception:
+                await session.close()
+                raise
 
     async def __aenter__(self):
         return self
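The inline comment in the new code explains the session handling: when `stream=True` the aiohttp session is deliberately left open, since it is handed to `_async_yield_from` and must stay alive while the caller consumes the stream. A usage sketch (prompt illustrative):

import asyncio
from huggingface_hub import AsyncInferenceClient

async def main():
    client = AsyncInferenceClient()
    # The returned async iterator owns the HTTP session until exhausted
    async for token in await client.text_generation("The sky is", stream=True):
        print(token, end="")

asyncio.run(main())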
@@ -950,7 +964,7 @@ class AsyncInferenceClient:
         ...     messages=messages,
         ...     response_format=response_format,
         ...     max_tokens=500,
-        )
+        ... )
         >>> response.choices[0].message.content
         '{\n\n"activity": "bike ride",\n"animals": ["puppy", "cat", "raccoon"],\n"animals_seen": 3,\n"location": "park"}'
         ```
@@ -1317,7 +1331,7 @@ class AsyncInferenceClient:
         [ImageSegmentationOutputElement(score=0.989008, label='LABEL_184', mask=<PIL.PngImagePlugin.PngImageFile image mode=L size=400x300 at 0x7FDD2B129CC0>), ...]
         ```
         """
-        provider_helper = get_provider_helper(self.provider, task="
+        provider_helper = get_provider_helper(self.provider, task="image-segmentation")
         request_parameters = provider_helper.prepare_request(
             inputs=image,
             parameters={
@@ -2659,7 +2673,7 @@ class AsyncInferenceClient:
             api_key=self.token,
         )
         response = await self._inner_post(request_parameters)
-        response = provider_helper.get_response(response)
+        response = provider_helper.get_response(response, request_parameters)
         return response
 
     async def text_to_speech(
@@ -3094,22 +3108,14 @@ class AsyncInferenceClient:
         response = await self._inner_post(request_parameters)
         return VisualQuestionAnsweringOutputElement.parse_obj_as_list(response)
 
-    @_deprecate_arguments(
-        version="0.30.0",
-        deprecated_args=["labels"],
-        custom_message="`labels`has been renamed to `candidate_labels` and will be removed in huggingface_hub>=0.30.0.",
-    )
     async def zero_shot_classification(
         self,
         text: str,
-
-        candidate_labels: List[str] = None,  # type: ignore
+        candidate_labels: List[str],
         *,
         multi_label: Optional[bool] = False,
         hypothesis_template: Optional[str] = None,
         model: Optional[str] = None,
-        # deprecated argument
-        labels: List[str] = None,  # type: ignore
     ) -> List[ZeroShotClassificationOutputElement]:
         """
         Provide as input a text and a set of candidate labels to classify the input text.
@@ -3190,16 +3196,6 @@ class AsyncInferenceClient:
         ]
         ```
         """
-        # handle deprecation
-        if labels is not None:
-            if candidate_labels is not None:
-                raise ValueError(
-                    "Cannot specify both `labels` and `candidate_labels`. Use `candidate_labels` instead."
-                )
-            candidate_labels = labels
-        elif candidate_labels is None:
-            raise ValueError("Must specify `candidate_labels`")
-
         provider_helper = get_provider_helper(self.provider, task="zero-shot-classification")
         request_parameters = provider_helper.prepare_request(
             inputs=text,
@@ -3219,16 +3215,10 @@ class AsyncInferenceClient:
             for label, score in zip(output["labels"], output["scores"])
         ]
 
-    @_deprecate_arguments(
-        version="0.30.0",
-        deprecated_args=["labels"],
-        custom_message="`labels`has been renamed to `candidate_labels` and will be removed in huggingface_hub>=0.30.0.",
-    )
     async def zero_shot_image_classification(
         self,
         image: ContentT,
-
-        candidate_labels: List[str] = None,  # type: ignore
+        candidate_labels: List[str],
         *,
         model: Optional[str] = None,
         hypothesis_template: Optional[str] = None,
@@ -3274,15 +3264,6 @@ class AsyncInferenceClient:
         [ZeroShotImageClassificationOutputElement(label='dog', score=0.956),...]
         ```
         """
-        # handle deprecation
-        if labels is not None:
-            if candidate_labels is not None:
-                raise ValueError(
-                    "Cannot specify both `labels` and `candidate_labels`. Use `candidate_labels` instead."
-                )
-            candidate_labels = labels
-        elif candidate_labels is None:
-            raise ValueError("Must specify `candidate_labels`")
         # Raise ValueError if input is less than 2 labels
         if len(candidate_labels) < 2:
             raise ValueError("You must specify at least 2 classes to compare.")
huggingface_hub/inference/_generated/types/__init__.py
CHANGED

@@ -30,6 +30,7 @@ from .chat_completion import (
     ChatCompletionInputMessageChunkType,
     ChatCompletionInputStreamOptions,
     ChatCompletionInputTool,
+    ChatCompletionInputToolCall,
     ChatCompletionInputToolChoiceClass,
     ChatCompletionInputToolChoiceEnum,
     ChatCompletionInputURL,
huggingface_hub/inference/_generated/types/chat_completion.py
CHANGED

@@ -23,11 +23,26 @@ class ChatCompletionInputMessageChunk(BaseInferenceType):
     text: Optional[str] = None
 
 
+@dataclass_with_extra
+class ChatCompletionInputFunctionDefinition(BaseInferenceType):
+    arguments: Any
+    name: str
+    description: Optional[str] = None
+
+
+@dataclass_with_extra
+class ChatCompletionInputToolCall(BaseInferenceType):
+    function: ChatCompletionInputFunctionDefinition
+    id: str
+    type: str
+
+
 @dataclass_with_extra
 class ChatCompletionInputMessage(BaseInferenceType):
-    content: Union[List[ChatCompletionInputMessageChunk], str]
     role: str
+    content: Optional[Union[List[ChatCompletionInputMessageChunk], str]] = None
     name: Optional[str] = None
+    tool_calls: Optional[List[ChatCompletionInputToolCall]] = None
 
 
 ChatCompletionInputGrammarTypeType = Literal["json", "regex"]
@@ -45,7 +60,7 @@ class ChatCompletionInputGrammarType(BaseInferenceType):
 
 @dataclass_with_extra
 class ChatCompletionInputStreamOptions(BaseInferenceType):
-    include_usage: bool
+    include_usage: Optional[bool] = None
     """If set, an additional chunk will be streamed before the data: [DONE] message. The usage
     field on this chunk shows the token usage statistics for the entire request, and the
     choices field will always be an empty array. All other chunks will also include a usage
@@ -66,13 +81,6 @@ class ChatCompletionInputToolChoiceClass(BaseInferenceType):
 ChatCompletionInputToolChoiceEnum = Literal["auto", "none", "required"]
 
 
-@dataclass_with_extra
-class ChatCompletionInputFunctionDefinition(BaseInferenceType):
-    arguments: Any
-    name: str
-    description: Optional[str] = None
-
-
 @dataclass_with_extra
 class ChatCompletionInputTool(BaseInferenceType):
     function: ChatCompletionInputFunctionDefinition
@@ -197,6 +205,7 @@ class ChatCompletionOutputToolCall(BaseInferenceType):
 class ChatCompletionOutputMessage(BaseInferenceType):
     role: str
     content: Optional[str] = None
+    tool_call_id: Optional[str] = None
     tool_calls: Optional[List[ChatCompletionOutputToolCall]] = None
 
 
@@ -249,7 +258,8 @@ class ChatCompletionStreamOutputDeltaToolCall(BaseInferenceType):
 class ChatCompletionStreamOutputDelta(BaseInferenceType):
     role: str
     content: Optional[str] = None
-
+    tool_call_id: Optional[str] = None
+    tool_calls: Optional[List[ChatCompletionStreamOutputDeltaToolCall]] = None
 
 
 @dataclass_with_extra
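Together, these type changes let an assistant turn carry tool calls with no text content, mirroring the OpenAI message shape. A construction sketch (the id and function payload are illustrative):

from huggingface_hub.inference._generated.types.chat_completion import (
    ChatCompletionInputFunctionDefinition,
    ChatCompletionInputMessage,
    ChatCompletionInputToolCall,
)

# `content` is now Optional, so a pure tool-call turn is valid input
msg = ChatCompletionInputMessage(
    role="assistant",
    tool_calls=[
        ChatCompletionInputToolCall(
            id="call_0",
            type="function",
            function=ChatCompletionInputFunctionDefinition(
                name="get_weather",
                arguments={"city": "Paris"},
            ),
        )
    ],
)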
huggingface_hub/inference/_generated/types/image_to_image.py
CHANGED

@@ -30,6 +30,8 @@ class ImageToImageParameters(BaseInferenceType):
     """For diffusion models. The number of denoising steps. More denoising steps usually lead to
     a higher quality image at the expense of slower inference.
     """
+    prompt: Optional[str] = None
+    """The text prompt to guide the image generation."""
     target_size: Optional[ImageToImageTargetSize] = None
     """The size in pixel of the output image."""
 
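With `prompt` added to `ImageToImageParameters`, a text prompt can ride along with an image-to-image request. A client-side sketch (file name, prompt, and the default model are illustrative):

from huggingface_hub import InferenceClient

client = InferenceClient()
out = client.image_to_image(
    "cat.png",
    prompt="Turn the cat into a tiger",  # serialized into the new parameters field
)
out.save("tiger.png")  # the client returns a PIL image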
huggingface_hub/inference/_providers/__init__.py
CHANGED

@@ -14,7 +14,8 @@ from .fireworks_ai import FireworksAIConversationalTask
 from .hf_inference import HFInferenceBinaryInputTask, HFInferenceConversational, HFInferenceTask
 from .hyperbolic import HyperbolicTextGenerationTask, HyperbolicTextToImageTask
 from .nebius import NebiusConversationalTask, NebiusTextGenerationTask, NebiusTextToImageTask
-from .novita import NovitaConversationalTask, NovitaTextGenerationTask
+from .novita import NovitaConversationalTask, NovitaTextGenerationTask, NovitaTextToVideoTask
+from .openai import OpenAIConversationalTask
 from .replicate import ReplicateTask, ReplicateTextToSpeechTask
 from .sambanova import SambanovaConversationalTask
 from .together import TogetherConversationalTask, TogetherTextGenerationTask, TogetherTextToImageTask
@@ -30,6 +31,7 @@ PROVIDER_T = Literal[
     "hyperbolic",
     "nebius",
     "novita",
+    "openai",
     "replicate",
     "sambanova",
     "together",
@@ -95,6 +97,10 @@ PROVIDERS: Dict[PROVIDER_T, Dict[str, TaskProviderHelper]] = {
     "novita": {
         "text-generation": NovitaTextGenerationTask(),
         "conversational": NovitaConversationalTask(),
+        "text-to-video": NovitaTextToVideoTask(),
+    },
+    "openai": {
+        "conversational": OpenAIConversationalTask(),
     },
     "replicate": {
         "text-to-image": ReplicateTask("text-to-image"),