huggingface-hub 0.30.1__py3-none-any.whl → 0.31.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. huggingface_hub/__init__.py +1 -1
  2. huggingface_hub/_commit_api.py +23 -4
  3. huggingface_hub/_inference_endpoints.py +8 -5
  4. huggingface_hub/_snapshot_download.py +2 -1
  5. huggingface_hub/_space_api.py +0 -5
  6. huggingface_hub/_upload_large_folder.py +26 -3
  7. huggingface_hub/commands/upload.py +2 -1
  8. huggingface_hub/constants.py +1 -0
  9. huggingface_hub/file_download.py +58 -10
  10. huggingface_hub/hf_api.py +81 -15
  11. huggingface_hub/inference/_client.py +105 -150
  12. huggingface_hub/inference/_generated/_async_client.py +105 -150
  13. huggingface_hub/inference/_generated/types/automatic_speech_recognition.py +2 -3
  14. huggingface_hub/inference/_generated/types/chat_completion.py +3 -3
  15. huggingface_hub/inference/_generated/types/image_to_text.py +2 -3
  16. huggingface_hub/inference/_generated/types/text_generation.py +1 -1
  17. huggingface_hub/inference/_generated/types/text_to_audio.py +1 -2
  18. huggingface_hub/inference/_generated/types/text_to_speech.py +1 -2
  19. huggingface_hub/inference/_providers/__init__.py +55 -17
  20. huggingface_hub/inference/_providers/_common.py +34 -19
  21. huggingface_hub/inference/_providers/black_forest_labs.py +4 -1
  22. huggingface_hub/inference/_providers/fal_ai.py +36 -11
  23. huggingface_hub/inference/_providers/hf_inference.py +33 -11
  24. huggingface_hub/inference/_providers/hyperbolic.py +5 -1
  25. huggingface_hub/inference/_providers/nebius.py +15 -1
  26. huggingface_hub/inference/_providers/novita.py +14 -1
  27. huggingface_hub/inference/_providers/openai.py +3 -2
  28. huggingface_hub/inference/_providers/replicate.py +22 -3
  29. huggingface_hub/inference/_providers/sambanova.py +23 -1
  30. huggingface_hub/inference/_providers/together.py +15 -1
  31. huggingface_hub/repocard_data.py +24 -4
  32. huggingface_hub/utils/_pagination.py +2 -2
  33. huggingface_hub/utils/_runtime.py +4 -0
  34. huggingface_hub/utils/_xet.py +1 -12
  35. {huggingface_hub-0.30.1.dist-info → huggingface_hub-0.31.0rc0.dist-info}/METADATA +3 -2
  36. {huggingface_hub-0.30.1.dist-info → huggingface_hub-0.31.0rc0.dist-info}/RECORD +40 -40
  37. {huggingface_hub-0.30.1.dist-info → huggingface_hub-0.31.0rc0.dist-info}/LICENSE +0 -0
  38. {huggingface_hub-0.30.1.dist-info → huggingface_hub-0.31.0rc0.dist-info}/WHEEL +0 -0
  39. {huggingface_hub-0.30.1.dist-info → huggingface_hub-0.31.0rc0.dist-info}/entry_points.txt +0 -0
  40. {huggingface_hub-0.30.1.dist-info → huggingface_hub-0.31.0rc0.dist-info}/top_level.txt +0 -0
@@ -85,7 +85,7 @@ from huggingface_hub.inference._generated.types import (
  ZeroShotClassificationOutputElement,
  ZeroShotImageClassificationOutputElement,
  )
- from huggingface_hub.inference._providers import PROVIDER_T, HFInferenceTask, get_provider_helper
+ from huggingface_hub.inference._providers import PROVIDER_T, get_provider_helper
  from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status
  from huggingface_hub.utils._auth import get_token
  from huggingface_hub.utils._deprecation import _deprecate_method
@@ -122,15 +122,14 @@ class AsyncInferenceClient:
  documentation for details). When passing a URL as `model`, the client will not append any suffix path to it.
  provider (`str`, *optional*):
  Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"openai"`, `"replicate"`, `"sambanova"` or `"together"`.
- defaults to hf-inference (Hugging Face Serverless Inference API).
+ Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
  If model is a URL or `base_url` is passed, then `provider` is not used.
  token (`str`, *optional*):
  Hugging Face token. Will default to the locally saved token if not provided.
  Note: for better compatibility with OpenAI's client, `token` has been aliased as `api_key`. Those 2
  arguments are mutually exclusive and have the exact same behavior.
  timeout (`float`, `optional`):
- The maximum number of seconds to wait for a response from the server. Loading a new model in Inference
- API can take up to several minutes. Defaults to None, meaning it will loop until the server is available.
+ The maximum number of seconds to wait for a response from the server. Defaults to None, meaning it will loop until the server is available.
  headers (`Dict[str, str]`, `optional`):
  Additional headers to send to the server. By default only the authorization and user-agent headers are sent.
  Values in this dictionary will override the default values.
@@ -155,7 +154,7 @@ class AsyncInferenceClient:
  self,
  model: Optional[str] = None,
  *,
- provider: Optional[PROVIDER_T] = None,
+ provider: Union[Literal["auto"], PROVIDER_T, None] = None,
  token: Optional[str] = None,
  timeout: Optional[float] = None,
  headers: Optional[Dict[str, str]] = None,
@@ -219,7 +218,7 @@ class AsyncInferenceClient:
  )

  # Configure provider
- self.provider = provider if provider is not None else "hf-inference"
+ self.provider = provider

  self.cookies = cookies
  self.timeout = timeout
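With this change, leaving `provider` unset no longer pins the client to `hf-inference`; per the docstring above it resolves to "auto" and is picked per model. A minimal usage sketch, assuming the model is served by at least one provider enabled in the user's settings (the model ID is illustrative):

```python
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main():
    # No provider given: resolves to "auto", i.e. the first provider enabled for the
    # model in https://hf.co/settings/inference-providers. Passing provider="hf-inference"
    # restores the previous default behavior.
    client = AsyncInferenceClient()
    output = await client.chat_completion(
        messages=[{"role": "user", "content": "Say hello in one word."}],
        model="meta-llama/Llama-3.1-8B-Instruct",  # illustrative model ID
        max_tokens=16,
    )
    print(output.choices[0].message.content)


asyncio.run(main())
```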
@@ -232,83 +231,6 @@ class AsyncInferenceClient:
  def __repr__(self):
  return f"<InferenceClient(model='{self.model if self.model else ''}', timeout={self.timeout})>"

- @overload
- async def post( # type: ignore[misc]
- self,
- *,
- json: Optional[Union[str, Dict, List]] = None,
- data: Optional[ContentT] = None,
- model: Optional[str] = None,
- task: Optional[str] = None,
- stream: Literal[False] = ...,
- ) -> bytes: ...
-
- @overload
- async def post( # type: ignore[misc]
- self,
- *,
- json: Optional[Union[str, Dict, List]] = None,
- data: Optional[ContentT] = None,
- model: Optional[str] = None,
- task: Optional[str] = None,
- stream: Literal[True] = ...,
- ) -> AsyncIterable[bytes]: ...
-
- @overload
- async def post(
- self,
- *,
- json: Optional[Union[str, Dict, List]] = None,
- data: Optional[ContentT] = None,
- model: Optional[str] = None,
- task: Optional[str] = None,
- stream: bool = False,
- ) -> Union[bytes, AsyncIterable[bytes]]: ...
-
- @_deprecate_method(
- version="0.31.0",
- message=(
- "Making direct POST requests to the inference server is not supported anymore. "
- "Please use task methods instead (e.g. `InferenceClient.chat_completion`). "
- "If your use case is not supported, please open an issue in https://github.com/huggingface/huggingface_hub."
- ),
- )
- async def post(
- self,
- *,
- json: Optional[Union[str, Dict, List]] = None,
- data: Optional[ContentT] = None,
- model: Optional[str] = None,
- task: Optional[str] = None,
- stream: bool = False,
- ) -> Union[bytes, AsyncIterable[bytes]]:
- """
- Make a POST request to the inference server.
-
- This method is deprecated and will be removed in the future.
- Please use task methods instead (e.g. `InferenceClient.chat_completion`).
- """
- if self.provider != "hf-inference":
- raise ValueError(
- "Cannot use `post` with another provider than `hf-inference`. "
- "`InferenceClient.post` is deprecated and should not be used directly anymore."
- )
- provider_helper = HFInferenceTask(task or "unknown")
- mapped_model = provider_helper._prepare_mapped_model(model or self.model)
- url = provider_helper._prepare_url(self.token, mapped_model) # type: ignore[arg-type]
- headers = provider_helper._prepare_headers(self.headers, self.token) # type: ignore[arg-type]
- return await self._inner_post(
- request_parameters=RequestParameters(
- url=url,
- task=task or "unknown",
- model=model or "unknown",
- json=json,
- data=data,
- headers=headers,
- ),
- stream=stream,
- )
-
  @overload
  async def _inner_post( # type: ignore[misc]
  self, request_parameters: RequestParameters, *, stream: Literal[False] = ...
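The deprecated `post()` entry point is removed entirely in this release. A hedged migration sketch for callers that were still issuing raw POST requests (model ID and inputs are illustrative; pick the task method that matches your payload):

```python
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main():
    client = AsyncInferenceClient()

    # Before (<= 0.30.x), a raw payload went through the deprecated helper:
    #   raw = await client.post(
    #       json={"inputs": "I love this library!"},
    #       model="distilbert-base-uncased-finetuned-sst-2-english",
    #       task="text-classification",
    #   )

    # After (>= 0.31.0), call the corresponding task method; the request is built by a
    # provider helper and the response is returned already parsed:
    labels = await client.text_classification(
        "I love this library!",
        model="distilbert-base-uncased-finetuned-sst-2-english",  # illustrative model ID
    )
    print(labels)


asyncio.run(main())
```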
@@ -441,12 +363,13 @@ class AsyncInferenceClient:
  ]
  ```
  """
- provider_helper = get_provider_helper(self.provider, task="audio-classification")
+ model_id = model or self.model
+ provider_helper = get_provider_helper(self.provider, task="audio-classification", model=model_id)
  request_parameters = provider_helper.prepare_request(
  inputs=audio,
  parameters={"function_to_apply": function_to_apply, "top_k": top_k},
  headers=self.headers,
- model=model or self.model,
+ model=model_id,
  api_key=self.token,
  )
  response = await self._inner_post(request_parameters)
@@ -490,12 +413,13 @@ class AsyncInferenceClient:
  f.write(item.blob)
  ```
  """
- provider_helper = get_provider_helper(self.provider, task="audio-to-audio")
+ model_id = model or self.model
+ provider_helper = get_provider_helper(self.provider, task="audio-to-audio", model=model_id)
  request_parameters = provider_helper.prepare_request(
  inputs=audio,
  parameters={},
  headers=self.headers,
- model=model or self.model,
+ model=model_id,
  api_key=self.token,
  )
  response = await self._inner_post(request_parameters)
@@ -541,12 +465,13 @@ class AsyncInferenceClient:
  "hello world"
  ```
  """
- provider_helper = get_provider_helper(self.provider, task="automatic-speech-recognition")
+ model_id = model or self.model
+ provider_helper = get_provider_helper(self.provider, task="automatic-speech-recognition", model=model_id)
  request_parameters = provider_helper.prepare_request(
  inputs=audio,
  parameters={**(extra_body or {})},
  headers=self.headers,
- model=model or self.model,
+ model=model_id,
  api_key=self.token,
  )
  response = await self._inner_post(request_parameters)
@@ -991,15 +916,21 @@ class AsyncInferenceClient:
  '{\n\n"activity": "bike ride",\n"animals": ["puppy", "cat", "raccoon"],\n"animals_seen": 3,\n"location": "park"}'
  ```
  """
- # Get the provider helper
- provider_helper = get_provider_helper(self.provider, task="conversational")
-
  # Since `chat_completion(..., model=xxx)` is also a payload parameter for the server, we need to handle 'model' differently.
  # `self.model` takes precedence over 'model' argument for building URL.
  # `model` takes precedence for payload value.
  model_id_or_url = self.model or model
  payload_model = model or self.model

+ # Get the provider helper
+ provider_helper = get_provider_helper(
+ self.provider,
+ task="conversational",
+ model=model_id_or_url
+ if model_id_or_url is not None and model_id_or_url.startswith(("http://", "https://"))
+ else payload_model,
+ )
+
  # Prepare the payload
  parameters = {
  "model": payload_model,
@@ -1102,8 +1033,9 @@ class AsyncInferenceClient:
  [DocumentQuestionAnsweringOutputElement(answer='us-001', end=16, score=0.9999666213989258, start=16)]
  ```
  """
+ model_id = model or self.model
+ provider_helper = get_provider_helper(self.provider, task="document-question-answering", model=model_id)
  inputs: Dict[str, Any] = {"question": question, "image": _b64_encode(image)}
- provider_helper = get_provider_helper(self.provider, task="document-question-answering")
  request_parameters = provider_helper.prepare_request(
  inputs=inputs,
  parameters={
@@ -1117,7 +1049,7 @@ class AsyncInferenceClient:
  "word_boxes": word_boxes,
  },
  headers=self.headers,
- model=model or self.model,
+ model=model_id,
  api_key=self.token,
  )
  response = await self._inner_post(request_parameters)
@@ -1140,8 +1072,8 @@ class AsyncInferenceClient:
  text (`str`):
  The text to embed.
  model (`str`, *optional*):
- The model to use for the conversational task. Can be a model ID hosted on the Hugging Face Hub or a URL to
- a deployed Inference Endpoint. If not provided, the default recommended conversational model will be used.
+ The model to use for the feature extraction task. Can be a model ID hosted on the Hugging Face Hub or a URL to
+ a deployed Inference Endpoint. If not provided, the default recommended feature extraction model will be used.
  Defaults to None.
  normalize (`bool`, *optional*):
  Whether to normalize the embeddings or not.
@@ -1179,7 +1111,8 @@ class AsyncInferenceClient:
  [ 0.28552425, -0.928395 , -1.2077185 , ..., 0.76810825, -2.1069427 , 0.6236161 ]], dtype=float32)
  ```
  """
- provider_helper = get_provider_helper(self.provider, task="feature-extraction")
+ model_id = model or self.model
+ provider_helper = get_provider_helper(self.provider, task="feature-extraction", model=model_id)
  request_parameters = provider_helper.prepare_request(
  inputs=text,
  parameters={
@@ -1189,12 +1122,12 @@ class AsyncInferenceClient:
  "truncation_direction": truncation_direction,
  },
  headers=self.headers,
- model=model or self.model,
+ model=model_id,
  api_key=self.token,
  )
  response = await self._inner_post(request_parameters)
  np = _import_numpy()
- return np.array(_bytes_to_dict(response), dtype="float32")
+ return np.array(provider_helper.get_response(response), dtype="float32")

  async def fill_mask(
  self,
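The raw bytes are now decoded by the provider helper instead of `_bytes_to_dict`, but the public return type is unchanged: a float32 NumPy array. A hedged usage sketch (model ID is illustrative):

```python
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main():
    client = AsyncInferenceClient()
    embeddings = await client.feature_extraction(
        "Today is a sunny day.",
        model="sentence-transformers/all-MiniLM-L6-v2",  # illustrative model ID
    )
    # Still a NumPy array of dtype float32, whichever provider served the call.
    print(embeddings.shape, embeddings.dtype)


asyncio.run(main())
```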
@@ -1241,12 +1174,13 @@ class AsyncInferenceClient:
  ]
  ```
  """
- provider_helper = get_provider_helper(self.provider, task="fill-mask")
+ model_id = model or self.model
+ provider_helper = get_provider_helper(self.provider, task="fill-mask", model=model_id)
  request_parameters = provider_helper.prepare_request(
  inputs=text,
  parameters={"targets": targets, "top_k": top_k},
  headers=self.headers,
- model=model or self.model,
+ model=model_id,
  api_key=self.token,
  )
  response = await self._inner_post(request_parameters)
@@ -1291,12 +1225,13 @@ class AsyncInferenceClient:
  [ImageClassificationOutputElement(label='Blenheim spaniel', score=0.9779096841812134), ...]
  ```
  """
- provider_helper = get_provider_helper(self.provider, task="image-classification")
+ model_id = model or self.model
+ provider_helper = get_provider_helper(self.provider, task="image-classification", model=model_id)
  request_parameters = provider_helper.prepare_request(
  inputs=image,
  parameters={"function_to_apply": function_to_apply, "top_k": top_k},
  headers=self.headers,
- model=model or self.model,
+ model=model_id,
  api_key=self.token,
  )
  response = await self._inner_post(request_parameters)
@@ -1353,7 +1288,8 @@ class AsyncInferenceClient:
  [ImageSegmentationOutputElement(score=0.989008, label='LABEL_184', mask=<PIL.PngImagePlugin.PngImageFile image mode=L size=400x300 at 0x7FDD2B129CC0>), ...]
  ```
  """
- provider_helper = get_provider_helper(self.provider, task="image-segmentation")
+ model_id = model or self.model
+ provider_helper = get_provider_helper(self.provider, task="image-segmentation", model=model_id)
  request_parameters = provider_helper.prepare_request(
  inputs=image,
  parameters={
@@ -1363,7 +1299,7 @@ class AsyncInferenceClient:
  "threshold": threshold,
  },
  headers=self.headers,
- model=model or self.model,
+ model=model_id,
  api_key=self.token,
  )
  response = await self._inner_post(request_parameters)
@@ -1430,7 +1366,8 @@ class AsyncInferenceClient:
  >>> image.save("tiger.jpg")
  ```
  """
- provider_helper = get_provider_helper(self.provider, task="image-to-image")
+ model_id = model or self.model
+ provider_helper = get_provider_helper(self.provider, task="image-to-image", model=model_id)
  request_parameters = provider_helper.prepare_request(
  inputs=image,
  parameters={
@@ -1442,7 +1379,7 @@ class AsyncInferenceClient:
  **kwargs,
  },
  headers=self.headers,
- model=model or self.model,
+ model=model_id,
  api_key=self.token,
  )
  response = await self._inner_post(request_parameters)
@@ -1482,12 +1419,13 @@ class AsyncInferenceClient:
  'a dog laying on the grass next to a flower pot '
  ```
  """
- provider_helper = get_provider_helper(self.provider, task="image-to-text")
+ model_id = model or self.model
+ provider_helper = get_provider_helper(self.provider, task="image-to-text", model=model_id)
  request_parameters = provider_helper.prepare_request(
  inputs=image,
  parameters={},
  headers=self.headers,
- model=model or self.model,
+ model=model_id,
  api_key=self.token,
  )
  response = await self._inner_post(request_parameters)
@@ -1534,12 +1472,13 @@ class AsyncInferenceClient:
  [ObjectDetectionOutputElement(score=0.9486683011054993, label='person', box=ObjectDetectionBoundingBox(xmin=59, ymin=39, xmax=420, ymax=510)), ...]
  ```
  """
- provider_helper = get_provider_helper(self.provider, task="object-detection")
+ model_id = model or self.model
+ provider_helper = get_provider_helper(self.provider, task="object-detection", model=model_id)
  request_parameters = provider_helper.prepare_request(
  inputs=image,
  parameters={"threshold": threshold},
  headers=self.headers,
- model=model or self.model,
+ model=model_id,
  api_key=self.token,
  )
  response = await self._inner_post(request_parameters)
@@ -1608,7 +1547,8 @@ class AsyncInferenceClient:
  QuestionAnsweringOutputElement(answer='Clara', end=16, score=0.9326565265655518, start=11)
  ```
  """
- provider_helper = get_provider_helper(self.provider, task="question-answering")
+ model_id = model or self.model
+ provider_helper = get_provider_helper(self.provider, task="question-answering", model=model_id)
  request_parameters = provider_helper.prepare_request(
  inputs=None,
  parameters={
@@ -1622,7 +1562,7 @@ class AsyncInferenceClient:
  },
  extra_payload={"question": question, "context": context},
  headers=self.headers,
- model=model or self.model,
+ model=model_id,
  api_key=self.token,
  )
  response = await self._inner_post(request_parameters)
@@ -1642,8 +1582,8 @@ class AsyncInferenceClient:
  other_sentences (`List[str]`):
  The list of sentences to compare to.
  model (`str`, *optional*):
- The model to use for the conversational task. Can be a model ID hosted on the Hugging Face Hub or a URL to
- a deployed Inference Endpoint. If not provided, the default recommended conversational model will be used.
+ The model to use for the sentence similarity task. Can be a model ID hosted on the Hugging Face Hub or a URL to
+ a deployed Inference Endpoint. If not provided, the default recommended sentence similarity model will be used.
  Defaults to None.

  Returns:
@@ -1671,13 +1611,14 @@ class AsyncInferenceClient:
  [0.7785726189613342, 0.45876261591911316, 0.2906220555305481]
  ```
  """
- provider_helper = get_provider_helper(self.provider, task="sentence-similarity")
+ model_id = model or self.model
+ provider_helper = get_provider_helper(self.provider, task="sentence-similarity", model=model_id)
  request_parameters = provider_helper.prepare_request(
- inputs=None,
+ inputs={"source_sentence": sentence, "sentences": other_sentences},
  parameters={},
- extra_payload={"source_sentence": sentence, "sentences": other_sentences},
+ extra_payload={},
  headers=self.headers,
- model=model or self.model,
+ model=model_id,
  api_key=self.token,
  )
  response = await self._inner_post(request_parameters)
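The sentences now travel as the task `inputs` instead of a separate top-level payload; the public call signature is unchanged. A hedged usage sketch mirroring the docstring example (sentences and model ID are illustrative):

```python
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main():
    client = AsyncInferenceClient()
    scores = await client.sentence_similarity(
        "Machine learning is so easy.",
        other_sentences=[
            "Deep learning is so straightforward.",
            "This is so difficult, like rocket science.",
        ],
        model="sentence-transformers/all-MiniLM-L6-v2",  # illustrative model ID
    )
    print(scores)  # one similarity score per entry in `other_sentences`


asyncio.run(main())
```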
@@ -1730,12 +1671,13 @@ class AsyncInferenceClient:
  "generate_parameters": generate_parameters,
  "truncation": truncation,
  }
- provider_helper = get_provider_helper(self.provider, task="summarization")
+ model_id = model or self.model
+ provider_helper = get_provider_helper(self.provider, task="summarization", model=model_id)
  request_parameters = provider_helper.prepare_request(
  inputs=text,
  parameters=parameters,
  headers=self.headers,
- model=model or self.model,
+ model=model_id,
  api_key=self.token,
  )
  response = await self._inner_post(request_parameters)
@@ -1792,13 +1734,14 @@ class AsyncInferenceClient:
  TableQuestionAnsweringOutputElement(answer='36542', coordinates=[[0, 1]], cells=['36542'], aggregator='AVERAGE')
  ```
  """
- provider_helper = get_provider_helper(self.provider, task="table-question-answering")
+ model_id = model or self.model
+ provider_helper = get_provider_helper(self.provider, task="table-question-answering", model=model_id)
  request_parameters = provider_helper.prepare_request(
  inputs=None,
  parameters={"model": model, "padding": padding, "sequential": sequential, "truncation": truncation},
  extra_payload={"query": query, "table": table},
  headers=self.headers,
- model=model or self.model,
+ model=model_id,
  api_key=self.token,
  )
  response = await self._inner_post(request_parameters)
@@ -1847,13 +1790,14 @@ class AsyncInferenceClient:
  ["5", "5", "5"]
  ```
  """
- provider_helper = get_provider_helper(self.provider, task="tabular-classification")
+ model_id = model or self.model
+ provider_helper = get_provider_helper(self.provider, task="tabular-classification", model=model_id)
  request_parameters = provider_helper.prepare_request(
  inputs=None,
  extra_payload={"table": table},
  parameters={},
  headers=self.headers,
- model=model or self.model,
+ model=model_id,
  api_key=self.token,
  )
  response = await self._inner_post(request_parameters)
@@ -1897,13 +1841,14 @@ class AsyncInferenceClient:
  [110, 120, 130]
  ```
  """
- provider_helper = get_provider_helper(self.provider, task="tabular-regression")
+ model_id = model or self.model
+ provider_helper = get_provider_helper(self.provider, task="tabular-regression", model=model_id)
  request_parameters = provider_helper.prepare_request(
  inputs=None,
  parameters={},
  extra_payload={"table": table},
  headers=self.headers,
- model=model or self.model,
+ model=model_id,
  api_key=self.token,
  )
  response = await self._inner_post(request_parameters)
@@ -1953,7 +1898,8 @@ class AsyncInferenceClient:
  ]
  ```
  """
- provider_helper = get_provider_helper(self.provider, task="text-classification")
+ model_id = model or self.model
+ provider_helper = get_provider_helper(self.provider, task="text-classification", model=model_id)
  request_parameters = provider_helper.prepare_request(
  inputs=text,
  parameters={
@@ -1961,7 +1907,7 @@ class AsyncInferenceClient:
  "top_k": top_k,
  },
  headers=self.headers,
- model=model or self.model,
+ model=model_id,
  api_key=self.token,
  )
  response = await self._inner_post(request_parameters)
@@ -2403,13 +2349,14 @@ class AsyncInferenceClient:
  " Please pass `stream=False` as input."
  )

- provider_helper = get_provider_helper(self.provider, task="text-generation")
+ model_id = model or self.model
+ provider_helper = get_provider_helper(self.provider, task="text-generation", model=model_id)
  request_parameters = provider_helper.prepare_request(
  inputs=prompt,
  parameters=parameters,
  extra_payload={"stream": stream},
  headers=self.headers,
- model=model or self.model,
+ model=model_id,
  api_key=self.token,
  )

@@ -2425,7 +2372,7 @@ class AsyncInferenceClient:
  prompt=prompt,
  details=details,
  stream=stream,
- model=model or self.model,
+ model=model_id,
  adapter_id=adapter_id,
  best_of=best_of,
  decoder_input_details=decoder_input_details,
@@ -2456,8 +2403,8 @@ class AsyncInferenceClient:
  # Data can be a single element (dict) or an iterable of dicts where we select the first element of.
  if isinstance(data, list):
  data = data[0]
-
- return TextGenerationOutput.parse_obj_as_instance(data) if details else data["generated_text"]
+ response = provider_helper.get_response(data, request_parameters)
+ return TextGenerationOutput.parse_obj_as_instance(response) if details else response["generated_text"]

  async def text_to_image(
  self,
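The non-streaming result is now routed through `provider_helper.get_response(...)` before being parsed, but callers still receive a plain string (or a `TextGenerationOutput` when `details=True`). A hedged usage sketch (model ID is illustrative):

```python
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main():
    client = AsyncInferenceClient()
    text = await client.text_generation(
        "The huggingface_hub library is ",
        max_new_tokens=12,
        model="HuggingFaceH4/zephyr-7b-beta",  # illustrative model ID
    )
    print(text)  # generated continuation as a plain string


asyncio.run(main())
```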
@@ -2581,7 +2528,8 @@ class AsyncInferenceClient:
  >>> image.save("astronaut.png")
  ```
  """
- provider_helper = get_provider_helper(self.provider, task="text-to-image")
+ model_id = model or self.model
+ provider_helper = get_provider_helper(self.provider, task="text-to-image", model=model_id)
  request_parameters = provider_helper.prepare_request(
  inputs=prompt,
  parameters={
@@ -2595,7 +2543,7 @@ class AsyncInferenceClient:
  **(extra_body or {}),
  },
  headers=self.headers,
- model=model or self.model,
+ model=model_id,
  api_key=self.token,
  )
  response = await self._inner_post(request_parameters)
@@ -2679,7 +2627,8 @@ class AsyncInferenceClient:
  ... file.write(video)
  ```
  """
- provider_helper = get_provider_helper(self.provider, task="text-to-video")
+ model_id = model or self.model
+ provider_helper = get_provider_helper(self.provider, task="text-to-video", model=model_id)
  request_parameters = provider_helper.prepare_request(
  inputs=prompt,
  parameters={
@@ -2691,7 +2640,7 @@ class AsyncInferenceClient:
  **(extra_body or {}),
  },
  headers=self.headers,
- model=model or self.model,
+ model=model_id,
  api_key=self.token,
  )
  response = await self._inner_post(request_parameters)
@@ -2877,7 +2826,8 @@ class AsyncInferenceClient:
  ... f.write(audio)
  ```
  """
- provider_helper = get_provider_helper(self.provider, task="text-to-speech")
+ model_id = model or self.model
+ provider_helper = get_provider_helper(self.provider, task="text-to-speech", model=model_id)
  request_parameters = provider_helper.prepare_request(
  inputs=text,
  parameters={
@@ -2900,7 +2850,7 @@ class AsyncInferenceClient:
  **(extra_body or {}),
  },
  headers=self.headers,
- model=model or self.model,
+ model=model_id,
  api_key=self.token,
  )
  response = await self._inner_post(request_parameters)
@@ -2967,7 +2917,8 @@ class AsyncInferenceClient:
  ]
  ```
  """
- provider_helper = get_provider_helper(self.provider, task="token-classification")
+ model_id = model or self.model
+ provider_helper = get_provider_helper(self.provider, task="token-classification", model=model_id)
  request_parameters = provider_helper.prepare_request(
  inputs=text,
  parameters={
@@ -2976,7 +2927,7 @@ class AsyncInferenceClient:
  "stride": stride,
  },
  headers=self.headers,
- model=model or self.model,
+ model=model_id,
  api_key=self.token,
  )
  response = await self._inner_post(request_parameters)
@@ -3054,7 +3005,8 @@ class AsyncInferenceClient:
  if src_lang is None and tgt_lang is not None:
  raise ValueError("You cannot specify `tgt_lang` without specifying `src_lang`.")

- provider_helper = get_provider_helper(self.provider, task="translation")
+ model_id = model or self.model
+ provider_helper = get_provider_helper(self.provider, task="translation", model=model_id)
  request_parameters = provider_helper.prepare_request(
  inputs=text,
  parameters={
@@ -3065,7 +3017,7 @@ class AsyncInferenceClient:
  "generate_parameters": generate_parameters,
  },
  headers=self.headers,
- model=model or self.model,
+ model=model_id,
  api_key=self.token,
  )
  response = await self._inner_post(request_parameters)
@@ -3118,12 +3070,13 @@ class AsyncInferenceClient:
  ]
  ```
  """
- provider_helper = get_provider_helper(self.provider, task="visual-question-answering")
+ model_id = model or self.model
+ provider_helper = get_provider_helper(self.provider, task="visual-question-answering", model=model_id)
  request_parameters = provider_helper.prepare_request(
  inputs=image,
  parameters={"top_k": top_k},
  headers=self.headers,
- model=model or self.model,
+ model=model_id,
  api_key=self.token,
  extra_payload={"question": question, "image": _b64_encode(image)},
  )
@@ -3218,7 +3171,8 @@ class AsyncInferenceClient:
  ]
  ```
  """
- provider_helper = get_provider_helper(self.provider, task="zero-shot-classification")
+ model_id = model or self.model
+ provider_helper = get_provider_helper(self.provider, task="zero-shot-classification", model=model_id)
  request_parameters = provider_helper.prepare_request(
  inputs=text,
  parameters={
@@ -3227,7 +3181,7 @@ class AsyncInferenceClient:
  "hypothesis_template": hypothesis_template,
  },
  headers=self.headers,
- model=model or self.model,
+ model=model_id,
  api_key=self.token,
  )
  response = await self._inner_post(request_parameters)
@@ -3290,7 +3244,8 @@ class AsyncInferenceClient:
  if len(candidate_labels) < 2:
  raise ValueError("You must specify at least 2 classes to compare.")

- provider_helper = get_provider_helper(self.provider, task="zero-shot-image-classification")
+ model_id = model or self.model
+ provider_helper = get_provider_helper(self.provider, task="zero-shot-image-classification", model=model_id)
  request_parameters = provider_helper.prepare_request(
  inputs=image,
  parameters={
@@ -3298,7 +3253,7 @@ class AsyncInferenceClient:
  "hypothesis_template": hypothesis_template,
  },
  headers=self.headers,
- model=model or self.model,
+ model=model_id,
  api_key=self.token,
  )
  response = await self._inner_post(request_parameters)