huggingface-hub 0.29.3rc0__py3-none-any.whl → 0.30.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.

Files changed (40)
  1. huggingface_hub/__init__.py +16 -1
  2. huggingface_hub/_commit_api.py +142 -4
  3. huggingface_hub/_space_api.py +15 -2
  4. huggingface_hub/_webhooks_server.py +2 -0
  5. huggingface_hub/commands/delete_cache.py +66 -20
  6. huggingface_hub/commands/upload.py +16 -2
  7. huggingface_hub/constants.py +45 -7
  8. huggingface_hub/errors.py +19 -0
  9. huggingface_hub/file_download.py +163 -35
  10. huggingface_hub/hf_api.py +349 -28
  11. huggingface_hub/hub_mixin.py +19 -4
  12. huggingface_hub/inference/_client.py +73 -70
  13. huggingface_hub/inference/_generated/_async_client.py +80 -77
  14. huggingface_hub/inference/_generated/types/__init__.py +1 -0
  15. huggingface_hub/inference/_generated/types/chat_completion.py +20 -10
  16. huggingface_hub/inference/_generated/types/image_to_image.py +2 -0
  17. huggingface_hub/inference/_providers/__init__.py +7 -1
  18. huggingface_hub/inference/_providers/_common.py +9 -5
  19. huggingface_hub/inference/_providers/black_forest_labs.py +5 -5
  20. huggingface_hub/inference/_providers/cohere.py +1 -1
  21. huggingface_hub/inference/_providers/fal_ai.py +64 -7
  22. huggingface_hub/inference/_providers/fireworks_ai.py +4 -1
  23. huggingface_hub/inference/_providers/hf_inference.py +41 -4
  24. huggingface_hub/inference/_providers/hyperbolic.py +3 -3
  25. huggingface_hub/inference/_providers/nebius.py +3 -3
  26. huggingface_hub/inference/_providers/novita.py +35 -5
  27. huggingface_hub/inference/_providers/openai.py +22 -0
  28. huggingface_hub/inference/_providers/replicate.py +3 -3
  29. huggingface_hub/inference/_providers/together.py +3 -3
  30. huggingface_hub/utils/__init__.py +8 -0
  31. huggingface_hub/utils/_http.py +4 -1
  32. huggingface_hub/utils/_runtime.py +11 -0
  33. huggingface_hub/utils/_xet.py +199 -0
  34. huggingface_hub/utils/tqdm.py +30 -2
  35. {huggingface_hub-0.29.3rc0.dist-info → huggingface_hub-0.30.0.dist-info}/METADATA +3 -1
  36. {huggingface_hub-0.29.3rc0.dist-info → huggingface_hub-0.30.0.dist-info}/RECORD +40 -38
  37. {huggingface_hub-0.29.3rc0.dist-info → huggingface_hub-0.30.0.dist-info}/LICENSE +0 -0
  38. {huggingface_hub-0.29.3rc0.dist-info → huggingface_hub-0.30.0.dist-info}/WHEEL +0 -0
  39. {huggingface_hub-0.29.3rc0.dist-info → huggingface_hub-0.30.0.dist-info}/entry_points.txt +0 -0
  40. {huggingface_hub-0.29.3rc0.dist-info → huggingface_hub-0.30.0.dist-info}/top_level.txt +0 -0
huggingface_hub/hub_mixin.py

@@ -58,7 +58,8 @@ DEFAULT_MODEL_CARD = """
 ---

 This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
-- Library: {{ repo_url | default("[More Information Needed]", true) }}
+- Code: {{ repo_url | default("[More Information Needed]", true) }}
+- Paper: {{ paper_url | default("[More Information Needed]", true) }}
 - Docs: {{ docs_url | default("[More Information Needed]", true) }}
 """

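Note: the template change above renames the "Library" bullet to "Code" and adds a "Paper" bullet. A minimal sketch of how the new placeholders render, assuming `jinja2` is installed (the snippet is illustrative, not part of huggingface_hub):

```python
from jinja2 import Template

# Render only the two changed bullets of the default model card template.
snippet = Template(
    '- Code: {{ repo_url | default("[More Information Needed]", true) }}\n'
    '- Paper: {{ paper_url | default("[More Information Needed]", true) }}\n'
)
print(snippet.render(repo_url="https://github.com/huggingface/my-cool-library"))
# - Code: https://github.com/huggingface/my-cool-library
# - Paper: [More Information Needed]
```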
@@ -67,8 +68,9 @@ This model has been pushed to the Hub using the [PytorchModelHubMixin](https://h
 class MixinInfo:
     model_card_template: str
     model_card_data: ModelCardData
-    repo_url: Optional[str] = None
     docs_url: Optional[str] = None
+    paper_url: Optional[str] = None
+    repo_url: Optional[str] = None


 class ModelHubMixin:
@@ -88,6 +90,8 @@ class ModelHubMixin:
     Args:
         repo_url (`str`, *optional*):
             URL of the library repository. Used to generate model card.
+        paper_url (`str`, *optional*):
+            URL of the library paper. Used to generate model card.
         docs_url (`str`, *optional*):
            URL of the library documentation. Used to generate model card.
         model_card_template (`str`, *optional*):
@@ -110,7 +114,7 @@ class ModelHubMixin:
         pipeline_tag (`str`, *optional*):
             Tag of the pipeline. Used to generate model card. E.g. "text-classification".
         tags (`List[str]`, *optional*):
-            Tags to be added to the model card. Used to generate model card. E.g. ["x-custom-tag", "arxiv:2304.12244"]
+            Tags to be added to the model card. Used to generate model card. E.g. ["computer-vision"]
         coders (`Dict[Type, Tuple[Callable, Callable]]`, *optional*):
             Dictionary of custom types and their encoders/decoders. Used to encode/decode arguments that are not
             jsonable by default. E.g. dataclasses, argparse.Namespace, OmegaConf, etc.
@@ -124,8 +128,9 @@ class ModelHubMixin:
     >>> class MyCustomModel(
     ...     ModelHubMixin,
     ...     library_name="my-library",
-    ...     tags=["x-custom-tag", "arxiv:2304.12244"],
+    ...     tags=["computer-vision"],
     ...     repo_url="https://github.com/huggingface/my-cool-library",
+    ...     paper_url="https://arxiv.org/abs/2304.12244",
     ...     docs_url="https://huggingface.co/docs/my-cool-library",
     ...     # ^ optional metadata to generate model card
     ... ):
@@ -194,6 +199,7 @@ class ModelHubMixin:
         *,
         # Generic info for model card
         repo_url: Optional[str] = None,
+        paper_url: Optional[str] = None,
         docs_url: Optional[str] = None,
         # Model card template
         model_card_template: str = DEFAULT_MODEL_CARD,
@@ -234,6 +240,7 @@ class ModelHubMixin:

         # Inherit other info
         info.docs_url = cls._hub_mixin_info.docs_url
+        info.paper_url = cls._hub_mixin_info.paper_url
         info.repo_url = cls._hub_mixin_info.repo_url
         cls._hub_mixin_info = info

@@ -242,6 +249,8 @@ class ModelHubMixin:
         info.model_card_template = model_card_template
         if repo_url is not None:
             info.repo_url = repo_url
+        if paper_url is not None:
+            info.paper_url = paper_url
         if docs_url is not None:
             info.docs_url = docs_url
         if language is not None:
@@ -334,6 +343,8 @@ class ModelHubMixin:
     @classmethod
     def _is_jsonable(cls, value: Any) -> bool:
         """Check if a value is JSON serializable."""
+        if is_dataclass(value):
+            return True
         if isinstance(value, cls._hub_mixin_jsonable_custom_types):
             return True
         return is_jsonable(value)
@@ -341,6 +352,8 @@ class ModelHubMixin:
     @classmethod
     def _encode_arg(cls, arg: Any) -> Any:
         """Encode an argument into a JSON serializable format."""
+        if is_dataclass(arg):
+            return asdict(arg)
         for type_, (encoder, _) in cls._hub_mixin_coders.items():
             if isinstance(arg, type_):
                 if arg is None:
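Note: with the two `is_dataclass` guards above, dataclass arguments are now considered JSON-serializable and are encoded via `asdict` before being written to `config.json`. A minimal sketch of the effect (the config class is hypothetical):

```python
from dataclasses import asdict, dataclass, is_dataclass

@dataclass
class TrainingConfig:  # hypothetical user-defined dataclass
    hidden_size: int = 128
    dropout: float = 0.1

cfg = TrainingConfig()
assert is_dataclass(cfg)  # _is_jsonable now returns True for this value
print(asdict(cfg))        # what _encode_arg now stores: {'hidden_size': 128, 'dropout': 0.1}
```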
@@ -692,6 +705,7 @@ class ModelHubMixin:
             card_data=self._hub_mixin_info.model_card_data,
             template_str=self._hub_mixin_info.model_card_template,
             repo_url=self._hub_mixin_info.repo_url,
+            paper_url=self._hub_mixin_info.paper_url,
             docs_url=self._hub_mixin_info.docs_url,
             **kwargs,
         )
@@ -718,6 +732,7 @@ class PyTorchModelHubMixin(ModelHubMixin):
     ...     PyTorchModelHubMixin,
     ...     library_name="keras-nlp",
     ...     repo_url="https://github.com/keras-team/keras-nlp",
+    ...     paper_url="https://arxiv.org/abs/2304.12244",
     ...     docs_url="https://keras.io/keras_nlp/",
     ...     # ^ optional metadata to generate model card
     ... ):
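Note: taken together, the `hub_mixin.py` changes let a library set a paper link once at subclass time and have it rendered in every generated card. A sketch assuming `torch` is installed (the model and URLs are hypothetical; the rendered wording follows the template change above):

```python
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin

class TinyModel(
    nn.Module,
    PyTorchModelHubMixin,
    repo_url="https://github.com/me/tiny-model",   # hypothetical
    paper_url="https://arxiv.org/abs/2304.12244",
):
    def __init__(self, hidden: int = 8):
        super().__init__()
        self.fc = nn.Linear(hidden, hidden)

card = TinyModel(hidden=4).generate_model_card()
print("arxiv.org/abs/2304.12244" in str(card))  # expected: True
```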
huggingface_hub/inference/_client.py

@@ -102,7 +102,8 @@ from huggingface_hub.inference._generated.types import (
 )
 from huggingface_hub.inference._providers import PROVIDER_T, HFInferenceTask, get_provider_helper
 from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status
-from huggingface_hub.utils._deprecation import _deprecate_arguments, _deprecate_method
+from huggingface_hub.utils._auth import get_token
+from huggingface_hub.utils._deprecation import _deprecate_method


 if TYPE_CHECKING:
@@ -132,12 +133,11 @@ class InferenceClient:
         path will be appended to the base URL (see the [TGI Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api)
         documentation for details). When passing a URL as `model`, the client will not append any suffix path to it.
     provider (`str`, *optional*):
-        Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"replicate"`, `"sambanova"` or `"together"`.
+        Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"openai"`, `"replicate"`, `"sambanova"` or `"together"`.
         defaults to hf-inference (Hugging Face Serverless Inference API).
         If model is a URL or `base_url` is passed, then `provider` is not used.
-    token (`str` or `bool`, *optional*):
+    token (`str`, *optional*):
         Hugging Face token. Will default to the locally saved token if not provided.
-        Pass `token=False` if you don't want to send your token to the server.
         Note: for better compatibility with OpenAI's client, `token` has been aliased as `api_key`. Those 2
         arguments are mutually exclusive and have the exact same behavior.
     timeout (`float`, `optional`):
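Note: `"openai"` joins the provider list in this release (see `inference/_providers/openai.py` in the file list). A minimal sketch of selecting it; the key and model name are placeholders:

```python
from huggingface_hub import InferenceClient

# New in 0.30.0: route chat-completion requests to OpenAI's API.
client = InferenceClient(provider="openai", api_key="sk-...")  # placeholder key
response = client.chat_completion(
    messages=[{"role": "user", "content": "Say hi"}],
    model="gpt-4o-mini",  # hypothetical model choice
)
print(response.choices[0].message.content)
```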
@@ -146,6 +146,9 @@ class InferenceClient:
     headers (`Dict[str, str]`, `optional`):
         Additional headers to send to the server. By default only the authorization and user-agent headers are sent.
         Values in this dictionary will override the default values.
+    bill_to (`str`, `optional`):
+        The billing account to use for the requests. By default the requests are billed on the user's account.
+        Requests can only be billed to an organization the user is a member of, and which has subscribed to Enterprise Hub.
     cookies (`Dict[str, str]`, `optional`):
         Additional cookies to send to the server.
     proxies (`Any`, `optional`):
@@ -168,6 +171,7 @@ class InferenceClient:
         headers: Optional[Dict[str, str]] = None,
         cookies: Optional[Dict[str, str]] = None,
         proxies: Optional[Any] = None,
+        bill_to: Optional[str] = None,
         # OpenAI compatibility
         base_url: Optional[str] = None,
         api_key: Optional[str] = None,
@@ -185,10 +189,43 @@ class InferenceClient:
                 " `api_key` is an alias for `token` to make the API compatible with OpenAI's client."
                 " It has the exact same behavior as `token`."
             )
+        token = token if token is not None else api_key
+        if isinstance(token, bool):
+            # Legacy behavior: previously it was possible to pass `token=False` to disable authentication. This is not
+            # supported anymore as authentication is required. Better to explicitly raise here rather than risking
+            # sending the locally saved token without the user knowing about it.
+            if token is False:
+                raise ValueError(
+                    "Cannot use `token=False` to disable authentication as authentication is required to run Inference."
+                )
+            warnings.warn(
+                "Using `token=True` to automatically use the locally saved token is deprecated and will be removed in a future release. "
+                "Please use `token=None` instead (default).",
+                DeprecationWarning,
+            )
+            token = get_token()

         self.model: Optional[str] = base_url or model
-        self.token: Optional[str] = token if token is not None else api_key
-        self.headers = headers if headers is not None else {}
+        self.token: Optional[str] = token
+
+        self.headers = {**headers} if headers is not None else {}
+        if bill_to is not None:
+            if (
+                constants.HUGGINGFACE_HEADER_X_BILL_TO in self.headers
+                and self.headers[constants.HUGGINGFACE_HEADER_X_BILL_TO] != bill_to
+            ):
+                warnings.warn(
+                    f"Overriding existing '{self.headers[constants.HUGGINGFACE_HEADER_X_BILL_TO]}' value in headers with '{bill_to}'.",
+                    UserWarning,
+                )
+            self.headers[constants.HUGGINGFACE_HEADER_X_BILL_TO] = bill_to
+
+            if token is not None and not token.startswith("hf_"):
+                warnings.warn(
+                    "You've provided an external provider's API key, so requests will be billed directly by the provider. "
+                    "The `bill_to` parameter is only applicable for Hugging Face billing and will be ignored.",
+                    UserWarning,
+                )

         # Configure provider
         self.provider = provider if provider is not None else "hf-inference"
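Note: the net effect of this block is that `token=False` now raises, `token=True` warns and resolves to the locally saved token, and `bill_to` injects a billing header. A usage sketch; the organization name is hypothetical:

```python
from huggingface_hub import InferenceClient

# Bill requests to an Enterprise Hub organization instead of the personal account.
client = InferenceClient(provider="together", bill_to="my-org")

# client = InferenceClient(token=False)  # now raises ValueError (0.30.0)
```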
@@ -300,33 +337,32 @@ class InferenceClient:
         if request_parameters.task in TASKS_EXPECTING_IMAGES and "Accept" not in request_parameters.headers:
             request_parameters.headers["Accept"] = "image/png"

-        while True:
-            with _open_as_binary(request_parameters.data) as data_as_binary:
-                try:
-                    response = get_session().post(
-                        request_parameters.url,
-                        json=request_parameters.json,
-                        data=data_as_binary,
-                        headers=request_parameters.headers,
-                        cookies=self.cookies,
-                        timeout=self.timeout,
-                        stream=stream,
-                        proxies=self.proxies,
-                    )
-                except TimeoutError as error:
-                    # Convert any `TimeoutError` to a `InferenceTimeoutError`
-                    raise InferenceTimeoutError(f"Inference call timed out: {request_parameters.url}") from error  # type: ignore
-
+        with _open_as_binary(request_parameters.data) as data_as_binary:
             try:
-                hf_raise_for_status(response)
-                return response.iter_lines() if stream else response.content
-            except HTTPError as error:
-                if error.response.status_code == 422 and request_parameters.task != "unknown":
-                    msg = str(error.args[0])
-                    if len(error.response.text) > 0:
-                        msg += f"\n{error.response.text}\n"
-                    error.args = (msg,) + error.args[1:]
-                raise
+                response = get_session().post(
+                    request_parameters.url,
+                    json=request_parameters.json,
+                    data=data_as_binary,
+                    headers=request_parameters.headers,
+                    cookies=self.cookies,
+                    timeout=self.timeout,
+                    stream=stream,
+                    proxies=self.proxies,
+                )
+            except TimeoutError as error:
+                # Convert any `TimeoutError` to a `InferenceTimeoutError`
+                raise InferenceTimeoutError(f"Inference call timed out: {request_parameters.url}") from error  # type: ignore
+
+        try:
+            hf_raise_for_status(response)
+            return response.iter_lines() if stream else response.content
+        except HTTPError as error:
+            if error.response.status_code == 422 and request_parameters.task != "unknown":
+                msg = str(error.args[0])
+                if len(error.response.text) > 0:
+                    msg += f"\n{error.response.text}\n"
+                error.args = (msg,) + error.args[1:]
+            raise

     def audio_classification(
         self,
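Note: as the removed lines show, the old loop body always returned or raised, so `while True:` was dead code and the rewrite flattens it without changing behavior. If you want retries on transient errors, add them at the call site; a minimal sketch with plain exponential backoff:

```python
import time

from huggingface_hub import InferenceClient
from huggingface_hub.utils import HfHubHTTPError

client = InferenceClient()
for attempt in range(3):
    try:
        result = client.text_classification("I like you.")
        break
    except HfHubHTTPError:
        if attempt == 2:
            raise
        time.sleep(2**attempt)  # back off: 1s, then 2s
```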
@@ -910,7 +946,7 @@ class InferenceClient:
         ...     messages=messages,
         ...     response_format=response_format,
         ...     max_tokens=500,
-        )
+        ... )
         >>> response.choices[0].message.content
         '{\n\n"activity": "bike ride",\n"animals": ["puppy", "cat", "raccoon"],\n"animals_seen": 3,\n"location": "park"}'
         ```
@@ -1272,7 +1308,7 @@ class InferenceClient:
         [ImageSegmentationOutputElement(score=0.989008, label='LABEL_184', mask=<PIL.PngImagePlugin.PngImageFile image mode=L size=400x300 at 0x7FDD2B129CC0>), ...]
         ```
         """
-        provider_helper = get_provider_helper(self.provider, task="audio-classification")
+        provider_helper = get_provider_helper(self.provider, task="image-segmentation")
         request_parameters = provider_helper.prepare_request(
             inputs=image,
             parameters={
@@ -2602,7 +2638,7 @@ class InferenceClient:
             api_key=self.token,
         )
         response = self._inner_post(request_parameters)
-        response = provider_helper.get_response(response)
+        response = provider_helper.get_response(response, request_parameters)
        return response

     def text_to_speech(
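Note: the provider-helper hook now also receives the request parameters (judging by the neighboring `text_to_speech` definition, this call site is `text_to_video`). A sketch of a conforming helper; the class is a simplified assumption, not the exact internal interface:

```python
class MyVideoProviderHelper:
    """Sketch of a helper compatible with the new two-argument call."""

    def get_response(self, response, request_parameters=None):
        # request_parameters (URL, headers, ...) is now available for
        # post-processing, e.g. polling a provider queue for the final result.
        return response
```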
@@ -3033,22 +3069,14 @@ class InferenceClient:
         response = self._inner_post(request_parameters)
         return VisualQuestionAnsweringOutputElement.parse_obj_as_list(response)

-    @_deprecate_arguments(
-        version="0.30.0",
-        deprecated_args=["labels"],
-        custom_message="`labels` has been renamed to `candidate_labels` and will be removed in huggingface_hub>=0.30.0.",
-    )
     def zero_shot_classification(
         self,
         text: str,
-        # temporarily keeping it optional for backward compatibility.
-        candidate_labels: List[str] = None,  # type: ignore
+        candidate_labels: List[str],
         *,
         multi_label: Optional[bool] = False,
         hypothesis_template: Optional[str] = None,
         model: Optional[str] = None,
-        # deprecated argument
-        labels: List[str] = None,  # type: ignore
     ) -> List[ZeroShotClassificationOutputElement]:
         """
         Provide as input a text and a set of candidate labels to classify the input text.
@@ -3127,16 +3155,6 @@ class InferenceClient:
         ]
         ```
         """
-        # handle deprecation
-        if labels is not None:
-            if candidate_labels is not None:
-                raise ValueError(
-                    "Cannot specify both `labels` and `candidate_labels`. Use `candidate_labels` instead."
-                )
-            candidate_labels = labels
-        elif candidate_labels is None:
-            raise ValueError("Must specify `candidate_labels`")
-
         provider_helper = get_provider_helper(self.provider, task="zero-shot-classification")
         request_parameters = provider_helper.prepare_request(
             inputs=text,
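Note: with the deprecation shim gone, `candidate_labels` is simply required; passing the removed `labels` keyword now fails with a `TypeError`. Updated usage:

```python
from huggingface_hub import InferenceClient

client = InferenceClient()
result = client.zero_shot_classification(
    "A new open-source model was released on the Hub today.",
    candidate_labels=["technology", "sports", "politics"],
)
# Before 0.30.0, `labels=[...]` still worked with a deprecation warning.
```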
@@ -3156,16 +3174,10 @@ class InferenceClient:
                 for label, score in zip(output["labels"], output["scores"])
             ]

-    @_deprecate_arguments(
-        version="0.30.0",
-        deprecated_args=["labels"],
-        custom_message="`labels` has been renamed to `candidate_labels` and will be removed in huggingface_hub>=0.30.0.",
-    )
     def zero_shot_image_classification(
         self,
         image: ContentT,
-        # temporarily keeping it optional for backward compatibility.
-        candidate_labels: List[str] = None,  # type: ignore
+        candidate_labels: List[str],
         *,
         model: Optional[str] = None,
         hypothesis_template: Optional[str] = None,
@@ -3210,15 +3222,6 @@ class InferenceClient:
         [ZeroShotImageClassificationOutputElement(label='dog', score=0.956),...]
         ```
         """
-        # handle deprecation
-        if labels is not None:
-            if candidate_labels is not None:
-                raise ValueError(
-                    "Cannot specify both `labels` and `candidate_labels`. Use `candidate_labels` instead."
-                )
-            candidate_labels = labels
-        elif candidate_labels is None:
-            raise ValueError("Must specify `candidate_labels`")
         # Raise ValueError if input is less than 2 labels
         if len(candidate_labels) < 2:
             raise ValueError("You must specify at least 2 classes to compare.")
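Note: same migration for the image task; the image path below is a placeholder:

```python
from huggingface_hub import InferenceClient

client = InferenceClient()
result = client.zero_shot_image_classification(
    "path/to/image.jpg",              # placeholder; a URL or raw bytes also work
    candidate_labels=["cat", "dog"],  # at least 2 labels are required
)
```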
huggingface_hub/inference/_generated/_async_client.py

@@ -87,7 +87,8 @@ from huggingface_hub.inference._generated.types import (
 )
 from huggingface_hub.inference._providers import PROVIDER_T, HFInferenceTask, get_provider_helper
 from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status
-from huggingface_hub.utils._deprecation import _deprecate_arguments, _deprecate_method
+from huggingface_hub.utils._auth import get_token
+from huggingface_hub.utils._deprecation import _deprecate_method

 from .._common import _async_yield_from, _import_aiohttp

@@ -120,12 +121,11 @@ class AsyncInferenceClient:
         path will be appended to the base URL (see the [TGI Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api)
         documentation for details). When passing a URL as `model`, the client will not append any suffix path to it.
     provider (`str`, *optional*):
-        Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"replicate"`, `"sambanova"` or `"together"`.
+        Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"openai"`, `"replicate"`, `"sambanova"` or `"together"`.
         defaults to hf-inference (Hugging Face Serverless Inference API).
         If model is a URL or `base_url` is passed, then `provider` is not used.
-    token (`str` or `bool`, *optional*):
+    token (`str`, *optional*):
         Hugging Face token. Will default to the locally saved token if not provided.
-        Pass `token=False` if you don't want to send your token to the server.
         Note: for better compatibility with OpenAI's client, `token` has been aliased as `api_key`. Those 2
         arguments are mutually exclusive and have the exact same behavior.
     timeout (`float`, `optional`):
@@ -134,6 +134,9 @@ class AsyncInferenceClient:
     headers (`Dict[str, str]`, `optional`):
         Additional headers to send to the server. By default only the authorization and user-agent headers are sent.
         Values in this dictionary will override the default values.
+    bill_to (`str`, `optional`):
+        The billing account to use for the requests. By default the requests are billed on the user's account.
+        Requests can only be billed to an organization the user is a member of, and which has subscribed to Enterprise Hub.
     cookies (`Dict[str, str]`, `optional`):
         Additional cookies to send to the server.
     trust_env ('bool', 'optional'):
@@ -159,6 +162,7 @@ class AsyncInferenceClient:
         cookies: Optional[Dict[str, str]] = None,
         trust_env: bool = False,
         proxies: Optional[Any] = None,
+        bill_to: Optional[str] = None,
         # OpenAI compatibility
         base_url: Optional[str] = None,
         api_key: Optional[str] = None,
@@ -176,10 +180,43 @@ class AsyncInferenceClient:
                 " `api_key` is an alias for `token` to make the API compatible with OpenAI's client."
                 " It has the exact same behavior as `token`."
             )
+        token = token if token is not None else api_key
+        if isinstance(token, bool):
+            # Legacy behavior: previously it was possible to pass `token=False` to disable authentication. This is not
+            # supported anymore as authentication is required. Better to explicitly raise here rather than risking
+            # sending the locally saved token without the user knowing about it.
+            if token is False:
+                raise ValueError(
+                    "Cannot use `token=False` to disable authentication as authentication is required to run Inference."
+                )
+            warnings.warn(
+                "Using `token=True` to automatically use the locally saved token is deprecated and will be removed in a future release. "
+                "Please use `token=None` instead (default).",
+                DeprecationWarning,
+            )
+            token = get_token()

         self.model: Optional[str] = base_url or model
-        self.token: Optional[str] = token if token is not None else api_key
-        self.headers = headers if headers is not None else {}
+        self.token: Optional[str] = token
+
+        self.headers = {**headers} if headers is not None else {}
+        if bill_to is not None:
+            if (
+                constants.HUGGINGFACE_HEADER_X_BILL_TO in self.headers
+                and self.headers[constants.HUGGINGFACE_HEADER_X_BILL_TO] != bill_to
+            ):
+                warnings.warn(
+                    f"Overriding existing '{self.headers[constants.HUGGINGFACE_HEADER_X_BILL_TO]}' value in headers with '{bill_to}'.",
+                    UserWarning,
+                )
+            self.headers[constants.HUGGINGFACE_HEADER_X_BILL_TO] = bill_to
+
+            if token is not None and not token.startswith("hf_"):
+                warnings.warn(
+                    "You've provided an external provider's API key, so requests will be billed directly by the provider. "
+                    "The `bill_to` parameter is only applicable for Hugging Face billing and will be ignored.",
+                    UserWarning,
+                )

         # Configure provider
         self.provider = provider if provider is not None else "hf-inference"
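Note: the async client mirrors the sync token and `bill_to` handling exactly. A minimal sketch; the organization name is hypothetical:

```python
import asyncio

from huggingface_hub import AsyncInferenceClient

async def main() -> None:
    client = AsyncInferenceClient(provider="together", bill_to="my-org")
    result = await client.zero_shot_classification(
        "The match went to penalties.",
        candidate_labels=["sports", "finance"],
    )
    print(result)

asyncio.run(main())
```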
@@ -298,40 +335,39 @@ class AsyncInferenceClient:
         if request_parameters.task in TASKS_EXPECTING_IMAGES and "Accept" not in request_parameters.headers:
             request_parameters.headers["Accept"] = "image/png"

-        while True:
-            with _open_as_binary(request_parameters.data) as data_as_binary:
-                # Do not use context manager as we don't want to close the connection immediately when returning
-                # a stream
-                session = self._get_client_session(headers=request_parameters.headers)
-
-                try:
-                    response = await session.post(
-                        request_parameters.url, json=request_parameters.json, data=data_as_binary, proxy=self.proxies
-                    )
-                    response_error_payload = None
-                    if response.status != 200:
-                        try:
-                            response_error_payload = await response.json()  # get payload before connection closed
-                        except Exception:
-                            pass
-                    response.raise_for_status()
-                    if stream:
-                        return _async_yield_from(session, response)
-                    else:
-                        content = await response.read()
-                        await session.close()
-                        return content
-                except asyncio.TimeoutError as error:
-                    await session.close()
-                    # Convert any `TimeoutError` to a `InferenceTimeoutError`
-                    raise InferenceTimeoutError(f"Inference call timed out: {request_parameters.url}") from error  # type: ignore
-                except aiohttp.ClientResponseError as error:
-                    error.response_error_payload = response_error_payload
-                    await session.close()
-                    raise error
-                except Exception:
+        with _open_as_binary(request_parameters.data) as data_as_binary:
+            # Do not use context manager as we don't want to close the connection immediately when returning
+            # a stream
+            session = self._get_client_session(headers=request_parameters.headers)
+
+            try:
+                response = await session.post(
+                    request_parameters.url, json=request_parameters.json, data=data_as_binary, proxy=self.proxies
+                )
+                response_error_payload = None
+                if response.status != 200:
+                    try:
+                        response_error_payload = await response.json()  # get payload before connection closed
+                    except Exception:
+                        pass
+                response.raise_for_status()
+                if stream:
+                    return _async_yield_from(session, response)
+                else:
+                    content = await response.read()
                     await session.close()
-                    raise
+                    return content
+            except asyncio.TimeoutError as error:
+                await session.close()
+                # Convert any `TimeoutError` to a `InferenceTimeoutError`
+                raise InferenceTimeoutError(f"Inference call timed out: {request_parameters.url}") from error  # type: ignore
+            except aiohttp.ClientResponseError as error:
+                error.response_error_payload = response_error_payload
+                await session.close()
+                raise error
+            except Exception:
+                await session.close()
+                raise

     async def __aenter__(self):
         return self
@@ -950,7 +986,7 @@ class AsyncInferenceClient:
         ...     messages=messages,
         ...     response_format=response_format,
         ...     max_tokens=500,
-        )
+        ... )
         >>> response.choices[0].message.content
         '{\n\n"activity": "bike ride",\n"animals": ["puppy", "cat", "raccoon"],\n"animals_seen": 3,\n"location": "park"}'
         ```
@@ -1317,7 +1353,7 @@ class AsyncInferenceClient:
         [ImageSegmentationOutputElement(score=0.989008, label='LABEL_184', mask=<PIL.PngImagePlugin.PngImageFile image mode=L size=400x300 at 0x7FDD2B129CC0>), ...]
         ```
         """
-        provider_helper = get_provider_helper(self.provider, task="audio-classification")
+        provider_helper = get_provider_helper(self.provider, task="image-segmentation")
         request_parameters = provider_helper.prepare_request(
             inputs=image,
             parameters={
@@ -2659,7 +2695,7 @@ class AsyncInferenceClient:
             api_key=self.token,
         )
         response = await self._inner_post(request_parameters)
-        response = provider_helper.get_response(response)
+        response = provider_helper.get_response(response, request_parameters)
         return response

     async def text_to_speech(
@@ -3094,22 +3130,14 @@ class AsyncInferenceClient:
         response = await self._inner_post(request_parameters)
         return VisualQuestionAnsweringOutputElement.parse_obj_as_list(response)

-    @_deprecate_arguments(
-        version="0.30.0",
-        deprecated_args=["labels"],
-        custom_message="`labels` has been renamed to `candidate_labels` and will be removed in huggingface_hub>=0.30.0.",
-    )
     async def zero_shot_classification(
         self,
         text: str,
-        # temporarily keeping it optional for backward compatibility.
-        candidate_labels: List[str] = None,  # type: ignore
+        candidate_labels: List[str],
         *,
         multi_label: Optional[bool] = False,
         hypothesis_template: Optional[str] = None,
         model: Optional[str] = None,
-        # deprecated argument
-        labels: List[str] = None,  # type: ignore
     ) -> List[ZeroShotClassificationOutputElement]:
         """
         Provide as input a text and a set of candidate labels to classify the input text.
@@ -3190,16 +3218,6 @@ class AsyncInferenceClient:
         ]
         ```
         """
-        # handle deprecation
-        if labels is not None:
-            if candidate_labels is not None:
-                raise ValueError(
-                    "Cannot specify both `labels` and `candidate_labels`. Use `candidate_labels` instead."
-                )
-            candidate_labels = labels
-        elif candidate_labels is None:
-            raise ValueError("Must specify `candidate_labels`")
-
         provider_helper = get_provider_helper(self.provider, task="zero-shot-classification")
         request_parameters = provider_helper.prepare_request(
             inputs=text,
@@ -3219,16 +3237,10 @@ class AsyncInferenceClient:
                 for label, score in zip(output["labels"], output["scores"])
             ]

-    @_deprecate_arguments(
-        version="0.30.0",
-        deprecated_args=["labels"],
-        custom_message="`labels` has been renamed to `candidate_labels` and will be removed in huggingface_hub>=0.30.0.",
-    )
     async def zero_shot_image_classification(
         self,
         image: ContentT,
-        # temporarily keeping it optional for backward compatibility.
-        candidate_labels: List[str] = None,  # type: ignore
+        candidate_labels: List[str],
         *,
         model: Optional[str] = None,
         hypothesis_template: Optional[str] = None,
@@ -3274,15 +3286,6 @@ class AsyncInferenceClient:
         [ZeroShotImageClassificationOutputElement(label='dog', score=0.956),...]
         ```
         """
-        # handle deprecation
-        if labels is not None:
-            if candidate_labels is not None:
-                raise ValueError(
-                    "Cannot specify both `labels` and `candidate_labels`. Use `candidate_labels` instead."
-                )
-            candidate_labels = labels
-        elif candidate_labels is None:
-            raise ValueError("Must specify `candidate_labels`")
         # Raise ValueError if input is less than 2 labels
         if len(candidate_labels) < 2:
             raise ValueError("You must specify at least 2 classes to compare.")
huggingface_hub/inference/_generated/types/__init__.py

@@ -30,6 +30,7 @@ from .chat_completion import (
     ChatCompletionInputMessageChunkType,
     ChatCompletionInputStreamOptions,
     ChatCompletionInputTool,
+    ChatCompletionInputToolCall,
     ChatCompletionInputToolChoiceClass,
     ChatCompletionInputToolChoiceEnum,
     ChatCompletionInputURL,
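Note: the new symbol is re-exported alongside the other chat-completion input types, so both import paths below should resolve to the same class (a minimal check, assuming the same re-export pattern as its siblings):

```python
from huggingface_hub.inference._generated.types import ChatCompletionInputToolCall
from huggingface_hub.inference._generated.types.chat_completion import (
    ChatCompletionInputToolCall as _SameClass,
)

assert ChatCompletionInputToolCall is _SameClass
```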