huggingface-hub 0.21.4__py3-none-any.whl → 0.22.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of huggingface-hub has been flagged as potentially problematic.

Files changed (96)
  1. huggingface_hub/__init__.py +217 -1
  2. huggingface_hub/_commit_api.py +14 -15
  3. huggingface_hub/_inference_endpoints.py +12 -11
  4. huggingface_hub/_login.py +1 -0
  5. huggingface_hub/_multi_commits.py +1 -0
  6. huggingface_hub/_snapshot_download.py +9 -1
  7. huggingface_hub/_tensorboard_logger.py +1 -0
  8. huggingface_hub/_webhooks_payload.py +1 -0
  9. huggingface_hub/_webhooks_server.py +1 -0
  10. huggingface_hub/commands/_cli_utils.py +1 -0
  11. huggingface_hub/commands/delete_cache.py +1 -0
  12. huggingface_hub/commands/download.py +1 -0
  13. huggingface_hub/commands/env.py +1 -0
  14. huggingface_hub/commands/scan_cache.py +1 -0
  15. huggingface_hub/commands/upload.py +1 -0
  16. huggingface_hub/community.py +1 -0
  17. huggingface_hub/constants.py +3 -1
  18. huggingface_hub/errors.py +38 -0
  19. huggingface_hub/file_download.py +24 -24
  20. huggingface_hub/hf_api.py +47 -35
  21. huggingface_hub/hub_mixin.py +210 -54
  22. huggingface_hub/inference/_client.py +554 -239
  23. huggingface_hub/inference/_common.py +195 -41
  24. huggingface_hub/inference/_generated/_async_client.py +558 -239
  25. huggingface_hub/inference/_generated/types/__init__.py +115 -0
  26. huggingface_hub/inference/_generated/types/audio_classification.py +43 -0
  27. huggingface_hub/inference/_generated/types/audio_to_audio.py +31 -0
  28. huggingface_hub/inference/_generated/types/automatic_speech_recognition.py +116 -0
  29. huggingface_hub/inference/_generated/types/base.py +149 -0
  30. huggingface_hub/inference/_generated/types/chat_completion.py +106 -0
  31. huggingface_hub/inference/_generated/types/depth_estimation.py +29 -0
  32. huggingface_hub/inference/_generated/types/document_question_answering.py +85 -0
  33. huggingface_hub/inference/_generated/types/feature_extraction.py +19 -0
  34. huggingface_hub/inference/_generated/types/fill_mask.py +50 -0
  35. huggingface_hub/inference/_generated/types/image_classification.py +43 -0
  36. huggingface_hub/inference/_generated/types/image_segmentation.py +52 -0
  37. huggingface_hub/inference/_generated/types/image_to_image.py +55 -0
  38. huggingface_hub/inference/_generated/types/image_to_text.py +105 -0
  39. huggingface_hub/inference/_generated/types/object_detection.py +55 -0
  40. huggingface_hub/inference/_generated/types/question_answering.py +77 -0
  41. huggingface_hub/inference/_generated/types/sentence_similarity.py +28 -0
  42. huggingface_hub/inference/_generated/types/summarization.py +46 -0
  43. huggingface_hub/inference/_generated/types/table_question_answering.py +45 -0
  44. huggingface_hub/inference/_generated/types/text2text_generation.py +45 -0
  45. huggingface_hub/inference/_generated/types/text_classification.py +43 -0
  46. huggingface_hub/inference/_generated/types/text_generation.py +161 -0
  47. huggingface_hub/inference/_generated/types/text_to_audio.py +105 -0
  48. huggingface_hub/inference/_generated/types/text_to_image.py +57 -0
  49. huggingface_hub/inference/_generated/types/token_classification.py +53 -0
  50. huggingface_hub/inference/_generated/types/translation.py +46 -0
  51. huggingface_hub/inference/_generated/types/video_classification.py +47 -0
  52. huggingface_hub/inference/_generated/types/visual_question_answering.py +53 -0
  53. huggingface_hub/inference/_generated/types/zero_shot_classification.py +56 -0
  54. huggingface_hub/inference/_generated/types/zero_shot_image_classification.py +51 -0
  55. huggingface_hub/inference/_generated/types/zero_shot_object_detection.py +55 -0
  56. huggingface_hub/inference/_templating.py +105 -0
  57. huggingface_hub/inference/_types.py +4 -152
  58. huggingface_hub/keras_mixin.py +39 -17
  59. huggingface_hub/lfs.py +20 -8
  60. huggingface_hub/repocard.py +11 -3
  61. huggingface_hub/repocard_data.py +12 -2
  62. huggingface_hub/serialization/__init__.py +1 -0
  63. huggingface_hub/serialization/_base.py +1 -0
  64. huggingface_hub/serialization/_numpy.py +1 -0
  65. huggingface_hub/serialization/_tensorflow.py +1 -0
  66. huggingface_hub/serialization/_torch.py +1 -0
  67. huggingface_hub/utils/__init__.py +4 -1
  68. huggingface_hub/utils/_cache_manager.py +7 -0
  69. huggingface_hub/utils/_chunk_utils.py +1 -0
  70. huggingface_hub/utils/_datetime.py +1 -0
  71. huggingface_hub/utils/_errors.py +10 -1
  72. huggingface_hub/utils/_experimental.py +1 -0
  73. huggingface_hub/utils/_fixes.py +19 -3
  74. huggingface_hub/utils/_git_credential.py +1 -0
  75. huggingface_hub/utils/_headers.py +10 -3
  76. huggingface_hub/utils/_hf_folder.py +1 -0
  77. huggingface_hub/utils/_http.py +1 -0
  78. huggingface_hub/utils/_pagination.py +1 -0
  79. huggingface_hub/utils/_paths.py +1 -0
  80. huggingface_hub/utils/_runtime.py +22 -0
  81. huggingface_hub/utils/_subprocess.py +1 -0
  82. huggingface_hub/utils/_token.py +1 -0
  83. huggingface_hub/utils/_typing.py +29 -1
  84. huggingface_hub/utils/_validators.py +1 -0
  85. huggingface_hub/utils/endpoint_helpers.py +1 -0
  86. huggingface_hub/utils/logging.py +1 -1
  87. huggingface_hub/utils/sha.py +1 -0
  88. huggingface_hub/utils/tqdm.py +1 -0
  89. {huggingface_hub-0.21.4.dist-info → huggingface_hub-0.22.0rc0.dist-info}/METADATA +14 -15
  90. huggingface_hub-0.22.0rc0.dist-info/RECORD +113 -0
  91. {huggingface_hub-0.21.4.dist-info → huggingface_hub-0.22.0rc0.dist-info}/WHEEL +1 -1
  92. huggingface_hub/inference/_text_generation.py +0 -551
  93. huggingface_hub-0.21.4.dist-info/RECORD +0 -81
  94. {huggingface_hub-0.21.4.dist-info → huggingface_hub-0.22.0rc0.dist-info}/LICENSE +0 -0
  95. {huggingface_hub-0.21.4.dist-info → huggingface_hub-0.22.0rc0.dist-info}/entry_points.txt +0 -0
  96. {huggingface_hub-0.21.4.dist-info → huggingface_hub-0.22.0rc0.dist-info}/top_level.txt +0 -0
@@ -13,10 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Contains utilities used by both the sync and async inference clients."""
+
 import base64
 import io
 import json
 import logging
+import time
 from contextlib import contextmanager
 from dataclasses import dataclass
 from pathlib import Path
@@ -31,6 +33,7 @@ from typing import (
     Iterable,
     List,
     Literal,
+    NoReturn,
     Optional,
     Set,
     Union,
@@ -39,6 +42,15 @@ from typing import (
 
 from requests import HTTPError
 
+from huggingface_hub.errors import (
+    GenerationError,
+    IncompleteGenerationError,
+    OverloadedError,
+    TextGenerationError,
+    UnknownError,
+    ValidationError,
+)
+
 from ..constants import ENDPOINT
 from ..utils import (
     build_hf_headers,
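
These exception classes previously lived in `huggingface_hub/inference/_text_generation.py` (removed in this release, see file 92 above) and are now imported from the new top-level `huggingface_hub.errors` module. A minimal sketch of catching them around a call, assuming `InferenceClient.text_generation` keeps raising these types:

from huggingface_hub import InferenceClient
from huggingface_hub.errors import TextGenerationError, ValidationError

client = InferenceClient()
try:
    client.text_generation("The huggingface_hub library is ", max_new_tokens=12)
except ValidationError:
    # Invalid request parameters (e.g. a prompt longer than the model allows);
    # caught first because it subclasses TextGenerationError.
    raise
except TextGenerationError:
    # Any other error reported by a text-generation-inference server.
    raise
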
@@ -48,7 +60,12 @@ from ..utils import (
     is_numpy_available,
     is_pillow_available,
 )
-from ._text_generation import TextGenerationStreamResponse, _parse_text_generation_error
+from ._generated.types import (
+    ChatCompletionStreamOutput,
+    ChatCompletionStreamOutputChoice,
+    ChatCompletionStreamOutputDelta,
+    TextGenerationStreamOutput,
+)
 
 
 if TYPE_CHECKING:
@@ -98,10 +115,6 @@ class ModelStatus:
     framework: str
 
 
-class InferenceTimeoutError(HTTPError, TimeoutError):
-    """Error raised when a model is unavailable or the request times out."""
-
-
 ## IMPORT UTILS
 
 
@@ -163,13 +176,15 @@ def _first_or_none(items: List[Any]) -> Optional[Any]:
 
 
 @overload
-def _open_as_binary(content: ContentT) -> ContextManager[BinaryT]:
-    ...  # means "if input is not None, output is not None"
+def _open_as_binary(
+    content: ContentT,
+) -> ContextManager[BinaryT]: ...  # means "if input is not None, output is not None"
 
 
 @overload
-def _open_as_binary(content: Literal[None]) -> ContextManager[Literal[None]]:
-    ...  # means "if input is None, output is None"
+def _open_as_binary(
+    content: Literal[None],
+) -> ContextManager[Literal[None]]: ...  # means "if input is None, output is None"
 
 
 @contextmanager  # type: ignore
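
The `_open_as_binary` change above is purely stylistic (the ellipsis moves onto the signature line); the two overloads still encode the same None-propagation contract spelled out in the comments. A generic, self-contained sketch of that typing pattern:

from typing import Optional, overload

@overload
def double(x: int) -> int: ...  # non-None in, non-None out
@overload
def double(x: None) -> None: ...  # None in, None out

def double(x: Optional[int]) -> Optional[int]:
    return None if x is None else 2 * x

y = double(21)    # type checkers infer `int` here, not `Optional[int]`
z = double(None)  # inferred as `None`
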
@@ -253,48 +268,125 @@ def _bytes_to_image(content: bytes) -> "Image":
 
 def _stream_text_generation_response(
     bytes_output_as_lines: Iterable[bytes], details: bool
-) -> Union[Iterable[str], Iterable[TextGenerationStreamResponse]]:
+) -> Union[Iterable[str], Iterable[TextGenerationStreamOutput]]:
+    """Used in `InferenceClient.text_generation`."""
     # Parse ServerSentEvents
     for byte_payload in bytes_output_as_lines:
-        # Skip line
-        if byte_payload == b"\n":
-            continue
-
-        payload = byte_payload.decode("utf-8")
-
-        # Event data
-        if payload.startswith("data:"):
-            # Decode payload
-            json_payload = json.loads(payload.lstrip("data:").rstrip("/n"))
-            # Either an error as being returned
-            if json_payload.get("error") is not None:
-                raise _parse_text_generation_error(json_payload["error"], json_payload.get("error_type"))
-            # Or parse token payload
-            output = TextGenerationStreamResponse(**json_payload)
-            yield output.token.text if not details else output
+        output = _format_text_generation_stream_output(byte_payload, details)
+        if output is not None:
+            yield output
 
 
 async def _async_stream_text_generation_response(
     bytes_output_as_lines: AsyncIterable[bytes], details: bool
-) -> Union[AsyncIterable[str], AsyncIterable[TextGenerationStreamResponse]]:
+) -> Union[AsyncIterable[str], AsyncIterable[TextGenerationStreamOutput]]:
+    """Used in `AsyncInferenceClient.text_generation`."""
     # Parse ServerSentEvents
     async for byte_payload in bytes_output_as_lines:
-        # Skip line
-        if byte_payload == b"\n":
-            continue
+        output = _format_text_generation_stream_output(byte_payload, details)
+        if output is not None:
+            yield output
+
+
+def _format_text_generation_stream_output(
+    byte_payload: bytes, details: bool
+) -> Optional[Union[str, TextGenerationStreamOutput]]:
+    if not byte_payload.startswith(b"data:"):
+        return None  # empty line
+
+    # Decode payload
+    payload = byte_payload.decode("utf-8")
+    json_payload = json.loads(payload.lstrip("data:").rstrip("/n"))
+
+    # Either an error as being returned
+    if json_payload.get("error") is not None:
+        raise _parse_text_generation_error(json_payload["error"], json_payload.get("error_type"))
+
+    # Or parse token payload
+    output = TextGenerationStreamOutput.parse_obj_as_instance(json_payload)
+    return output.token.text if not details else output
+
+
+def _stream_chat_completion_response_from_text_generation(
+    text_generation_output: Iterable[TextGenerationStreamOutput],
+) -> Iterable[ChatCompletionStreamOutput]:
+    """Used in `InferenceClient.chat_completion`."""
+    created = int(time.time())
+    for item in text_generation_output:
+        yield _format_chat_completion_stream_output_from_text_generation(item, created)
+
+
+async def _async_stream_chat_completion_response_from_text_generation(
+    text_generation_output: AsyncIterable[TextGenerationStreamOutput],
+) -> AsyncIterable[ChatCompletionStreamOutput]:
+    """Used in `AsyncInferenceClient.chat_completion`."""
+    created = int(time.time())
+    async for item in text_generation_output:
+        yield _format_chat_completion_stream_output_from_text_generation(item, created)
+
+
+def _format_chat_completion_stream_output_from_text_generation(
+    item: TextGenerationStreamOutput, created: int
+) -> ChatCompletionStreamOutput:
+    if item.details is None:
+        # new token generated => return delta
+        return ChatCompletionStreamOutput(
+            choices=[
+                ChatCompletionStreamOutputChoice(
+                    delta=ChatCompletionStreamOutputDelta(
+                        role="assistant",
+                        content=item.token.text,
+                    ),
+                    finish_reason=None,
+                    index=0,
+                )
+            ],
+            created=created,
+        )
+    else:
+        # generation is completed => return finish reason
+        return ChatCompletionStreamOutput(
+            choices=[
+                ChatCompletionStreamOutputChoice(
+                    delta=ChatCompletionStreamOutputDelta(),
+                    finish_reason=item.details.finish_reason,
+                    index=0,
+                )
+            ],
+            created=created,
+        )
+
+
+def _stream_chat_completion_response_from_bytes(
+    bytes_lines: Iterable[bytes],
+) -> Iterable[ChatCompletionStreamOutput]:
+    """Used in `InferenceClient.chat_completion` if model is served with TGI."""
+    for item in bytes_lines:
+        output = _format_chat_completion_stream_output_from_text_generation_from_bytes(item)
+        if output is not None:
+            yield output
 
-        payload = byte_payload.decode("utf-8")
 
-        # Event data
-        if payload.startswith("data:"):
-            # Decode payload
-            json_payload = json.loads(payload.lstrip("data:").rstrip("/n"))
-            # Either an error as being returned
-            if json_payload.get("error") is not None:
-                raise _parse_text_generation_error(json_payload["error"], json_payload.get("error_type"))
-            # Or parse token payload
-            output = TextGenerationStreamResponse(**json_payload)
-            yield output.token.text if not details else output
+async def _async_stream_chat_completion_response_from_bytes(
+    bytes_lines: AsyncIterable[bytes],
+) -> AsyncIterable[ChatCompletionStreamOutput]:
+    """Used in `AsyncInferenceClient.chat_completion`."""
+    async for item in bytes_lines:
+        output = _format_chat_completion_stream_output_from_text_generation_from_bytes(item)
+        if output is not None:
+            yield output
+
+
+def _format_chat_completion_stream_output_from_text_generation_from_bytes(
+    byte_payload: bytes,
+) -> Optional[ChatCompletionStreamOutput]:
+    if not byte_payload.startswith(b"data:"):
+        return None  # empty line
+
+    # Decode payload
+    payload = byte_payload.decode("utf-8")
+    json_payload = json.loads(payload.lstrip("data:").rstrip("/n"))
+    return ChatCompletionStreamOutput.parse_obj_as_instance(json_payload)
 
 
 async def _async_yield_from(client: "ClientSession", response: "ClientResponse") -> AsyncIterable[bytes]:
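
For context on the refactor above: a TGI server streams server-sent events as `data: {...}` lines separated by blank keep-alive lines, which is why both `_format_*` helpers return `None` for anything without the `data:` prefix. A standalone sketch of the decode step on an illustrative token event (the payload is invented for the example; real ones come from a text-generation-inference server):

import json

line = b'data: {"token": {"id": 42, "text": " world", "logprob": -0.15, "special": false}}'

if line.startswith(b"data:"):
    # Same decode as _format_text_generation_stream_output. Note that
    # lstrip/rstrip strip *characters*, not prefixes; this works here because
    # the JSON body itself never starts or ends with those characters.
    json_payload = json.loads(line.decode("utf-8").lstrip("data:").rstrip("/n"))
    if json_payload.get("error") is None:
        print(json_payload["token"]["text"])  # " world", i.e. what details=False yields

With `details=True` the full `TextGenerationStreamOutput` is returned instead, and the chat-completion adapters above wrap each such token in a `ChatCompletionStreamOutputDelta(role="assistant", content=...)` chunk until `details` is set, at which point a final chunk carrying only the `finish_reason` is emitted.
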
@@ -314,6 +406,10 @@ async def _async_yield_from(client: "ClientSession", response: "ClientResponse")
 # default API with a warning message. We remember for each model if it's a TGI server
 # or not using `_NON_TGI_SERVERS` global variable.
 #
+# In addition, TGI servers have a built-in API route for chat-completion, which is not
+# available on the default API. We use this route to provide a more consistent behavior
+# when available.
+#
 # For more details, see https://github.com/huggingface/text-generation-inference and
 # https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task.
 
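The bookkeeping described in this comment block is a remember-on-failure scheme: the first request revealing that a server lacks a TGI-only route marks the model, and subsequent calls go straight to the default API. A self-contained sketch of the pattern (the route callables are stand-ins, not the real client code from `_client.py`):

from typing import Callable, Set

_unsupported: Set[str] = set()  # plays the role of _NON_TGI_SERVERS

def call_with_fallback(model: str, tgi_route: Callable[[], str], default_route: Callable[[], str]) -> str:
    if model not in _unsupported:  # cf. _is_tgi_server(model)
        try:
            return tgi_route()
        except ValueError:  # stand-in for "route not supported by this server"
            _unsupported.add(model)  # cf. _set_as_non_tgi(model)
    return default_route()

The same pattern backs the new `_NON_CHAT_COMPLETION_SERVER` set introduced in the next hunk.
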
@@ -326,3 +422,61 @@ def _set_as_non_tgi(model: Optional[str]) -> None:
 
 def _is_tgi_server(model: Optional[str]) -> bool:
     return model not in _NON_TGI_SERVERS
+
+
+_NON_CHAT_COMPLETION_SERVER: Set[str] = set()
+
+
+def _set_as_non_chat_completion_server(model: str) -> None:
+    print("Set as non chat completion", model)
+    _NON_CHAT_COMPLETION_SERVER.add(model)
+
+
+def _is_chat_completion_server(model: str) -> bool:
+    return model not in _NON_CHAT_COMPLETION_SERVER
+
+
+# TEXT GENERATION ERRORS
+# ----------------------
+# Text-generation errors are parsed separately to handle as much as possible the errors returned by the text generation
+# inference project (https://github.com/huggingface/text-generation-inference).
+# ----------------------
+
+
+def raise_text_generation_error(http_error: HTTPError) -> NoReturn:
+    """
+    Try to parse text-generation-inference error message and raise HTTPError in any case.
+
+    Args:
+        error (`HTTPError`):
+            The HTTPError that have been raised.
+    """
+    # Try to parse a Text Generation Inference error
+
+    try:
+        # Hacky way to retrieve payload in case of aiohttp error
+        payload = getattr(http_error, "response_error_payload", None) or http_error.response.json()
+        error = payload.get("error")
+        error_type = payload.get("error_type")
+    except Exception:  # no payload
+        raise http_error
+
+    # If error_type => more information than `hf_raise_for_status`
+    if error_type is not None:
+        exception = _parse_text_generation_error(error, error_type)
+        raise exception from http_error
+
+    # Otherwise, fallback to default error
+    raise http_error
+
+
+def _parse_text_generation_error(error: Optional[str], error_type: Optional[str]) -> TextGenerationError:
+    if error_type == "generation":
+        return GenerationError(error)  # type: ignore
+    if error_type == "incomplete_generation":
+        return IncompleteGenerationError(error)  # type: ignore
+    if error_type == "overloaded":
+        return OverloadedError(error)  # type: ignore
+    if error_type == "validation":
+        return ValidationError(error)  # type: ignore
+    return UnknownError(error)  # type: ignore
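
To read `raise_text_generation_error` end to end: a JSON error body carrying an `error_type` is upgraded to one of the typed exceptions; anything else re-raises the original `HTTPError`. A short check against `_parse_text_generation_error` (the error message is illustrative):

from huggingface_hub.errors import OverloadedError, UnknownError
from huggingface_hub.inference._common import _parse_text_generation_error

payload = {"error": "Model is overloaded", "error_type": "overloaded"}
exc = _parse_text_generation_error(payload["error"], payload.get("error_type"))
assert isinstance(exc, OverloadedError)

# Any unrecognized error_type falls through to UnknownError.
assert isinstance(_parse_text_generation_error("boom", "other"), UnknownError)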