huggingface-hub 0.21.4__py3-none-any.whl → 0.22.0rc0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to their registry, and is provided for informational purposes only.
Note: this release has been flagged as potentially problematic.
- huggingface_hub/__init__.py +217 -1
- huggingface_hub/_commit_api.py +14 -15
- huggingface_hub/_inference_endpoints.py +12 -11
- huggingface_hub/_login.py +1 -0
- huggingface_hub/_multi_commits.py +1 -0
- huggingface_hub/_snapshot_download.py +9 -1
- huggingface_hub/_tensorboard_logger.py +1 -0
- huggingface_hub/_webhooks_payload.py +1 -0
- huggingface_hub/_webhooks_server.py +1 -0
- huggingface_hub/commands/_cli_utils.py +1 -0
- huggingface_hub/commands/delete_cache.py +1 -0
- huggingface_hub/commands/download.py +1 -0
- huggingface_hub/commands/env.py +1 -0
- huggingface_hub/commands/scan_cache.py +1 -0
- huggingface_hub/commands/upload.py +1 -0
- huggingface_hub/community.py +1 -0
- huggingface_hub/constants.py +3 -1
- huggingface_hub/errors.py +38 -0
- huggingface_hub/file_download.py +24 -24
- huggingface_hub/hf_api.py +47 -35
- huggingface_hub/hub_mixin.py +210 -54
- huggingface_hub/inference/_client.py +554 -239
- huggingface_hub/inference/_common.py +195 -41
- huggingface_hub/inference/_generated/_async_client.py +558 -239
- huggingface_hub/inference/_generated/types/__init__.py +115 -0
- huggingface_hub/inference/_generated/types/audio_classification.py +43 -0
- huggingface_hub/inference/_generated/types/audio_to_audio.py +31 -0
- huggingface_hub/inference/_generated/types/automatic_speech_recognition.py +116 -0
- huggingface_hub/inference/_generated/types/base.py +149 -0
- huggingface_hub/inference/_generated/types/chat_completion.py +106 -0
- huggingface_hub/inference/_generated/types/depth_estimation.py +29 -0
- huggingface_hub/inference/_generated/types/document_question_answering.py +85 -0
- huggingface_hub/inference/_generated/types/feature_extraction.py +19 -0
- huggingface_hub/inference/_generated/types/fill_mask.py +50 -0
- huggingface_hub/inference/_generated/types/image_classification.py +43 -0
- huggingface_hub/inference/_generated/types/image_segmentation.py +52 -0
- huggingface_hub/inference/_generated/types/image_to_image.py +55 -0
- huggingface_hub/inference/_generated/types/image_to_text.py +105 -0
- huggingface_hub/inference/_generated/types/object_detection.py +55 -0
- huggingface_hub/inference/_generated/types/question_answering.py +77 -0
- huggingface_hub/inference/_generated/types/sentence_similarity.py +28 -0
- huggingface_hub/inference/_generated/types/summarization.py +46 -0
- huggingface_hub/inference/_generated/types/table_question_answering.py +45 -0
- huggingface_hub/inference/_generated/types/text2text_generation.py +45 -0
- huggingface_hub/inference/_generated/types/text_classification.py +43 -0
- huggingface_hub/inference/_generated/types/text_generation.py +161 -0
- huggingface_hub/inference/_generated/types/text_to_audio.py +105 -0
- huggingface_hub/inference/_generated/types/text_to_image.py +57 -0
- huggingface_hub/inference/_generated/types/token_classification.py +53 -0
- huggingface_hub/inference/_generated/types/translation.py +46 -0
- huggingface_hub/inference/_generated/types/video_classification.py +47 -0
- huggingface_hub/inference/_generated/types/visual_question_answering.py +53 -0
- huggingface_hub/inference/_generated/types/zero_shot_classification.py +56 -0
- huggingface_hub/inference/_generated/types/zero_shot_image_classification.py +51 -0
- huggingface_hub/inference/_generated/types/zero_shot_object_detection.py +55 -0
- huggingface_hub/inference/_templating.py +105 -0
- huggingface_hub/inference/_types.py +4 -152
- huggingface_hub/keras_mixin.py +39 -17
- huggingface_hub/lfs.py +20 -8
- huggingface_hub/repocard.py +11 -3
- huggingface_hub/repocard_data.py +12 -2
- huggingface_hub/serialization/__init__.py +1 -0
- huggingface_hub/serialization/_base.py +1 -0
- huggingface_hub/serialization/_numpy.py +1 -0
- huggingface_hub/serialization/_tensorflow.py +1 -0
- huggingface_hub/serialization/_torch.py +1 -0
- huggingface_hub/utils/__init__.py +4 -1
- huggingface_hub/utils/_cache_manager.py +7 -0
- huggingface_hub/utils/_chunk_utils.py +1 -0
- huggingface_hub/utils/_datetime.py +1 -0
- huggingface_hub/utils/_errors.py +10 -1
- huggingface_hub/utils/_experimental.py +1 -0
- huggingface_hub/utils/_fixes.py +19 -3
- huggingface_hub/utils/_git_credential.py +1 -0
- huggingface_hub/utils/_headers.py +10 -3
- huggingface_hub/utils/_hf_folder.py +1 -0
- huggingface_hub/utils/_http.py +1 -0
- huggingface_hub/utils/_pagination.py +1 -0
- huggingface_hub/utils/_paths.py +1 -0
- huggingface_hub/utils/_runtime.py +22 -0
- huggingface_hub/utils/_subprocess.py +1 -0
- huggingface_hub/utils/_token.py +1 -0
- huggingface_hub/utils/_typing.py +29 -1
- huggingface_hub/utils/_validators.py +1 -0
- huggingface_hub/utils/endpoint_helpers.py +1 -0
- huggingface_hub/utils/logging.py +1 -1
- huggingface_hub/utils/sha.py +1 -0
- huggingface_hub/utils/tqdm.py +1 -0
- {huggingface_hub-0.21.4.dist-info → huggingface_hub-0.22.0rc0.dist-info}/METADATA +14 -15
- huggingface_hub-0.22.0rc0.dist-info/RECORD +113 -0
- {huggingface_hub-0.21.4.dist-info → huggingface_hub-0.22.0rc0.dist-info}/WHEEL +1 -1
- huggingface_hub/inference/_text_generation.py +0 -551
- huggingface_hub-0.21.4.dist-info/RECORD +0 -81
- {huggingface_hub-0.21.4.dist-info → huggingface_hub-0.22.0rc0.dist-info}/LICENSE +0 -0
- {huggingface_hub-0.21.4.dist-info → huggingface_hub-0.22.0rc0.dist-info}/entry_points.txt +0 -0
- {huggingface_hub-0.21.4.dist-info → huggingface_hub-0.22.0rc0.dist-info}/top_level.txt +0 -0
huggingface_hub/inference/_common.py

@@ -13,10 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Contains utilities used by both the sync and async inference clients."""
+
 import base64
 import io
 import json
 import logging
+import time
 from contextlib import contextmanager
 from dataclasses import dataclass
 from pathlib import Path
@@ -31,6 +33,7 @@ from typing import (
     Iterable,
     List,
     Literal,
+    NoReturn,
     Optional,
     Set,
     Union,
@@ -39,6 +42,15 @@ from typing import (
 
 from requests import HTTPError
 
+from huggingface_hub.errors import (
+    GenerationError,
+    IncompleteGenerationError,
+    OverloadedError,
+    TextGenerationError,
+    UnknownError,
+    ValidationError,
+)
+
 from ..constants import ENDPOINT
 from ..utils import (
     build_hf_headers,
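The text-generation error classes are now imported from a new top-level `huggingface_hub.errors` module (added in this release; see `errors.py +38` in the file list), which gives downstream code a public location to catch them from. A minimal caller-side sketch, assuming only that the module exports the classes exactly as imported in the hunk above:

```python
# Hypothetical caller-side error handling; class names are taken from the
# import block above, the rest of the snippet is illustrative.
from huggingface_hub import InferenceClient
from huggingface_hub.errors import OverloadedError, ValidationError

client = InferenceClient()
try:
    text = client.text_generation("Explain server-sent events briefly.", max_new_tokens=40)
    print(text)
except ValidationError as e:
    # Bad request parameters (e.g. too many tokens requested for the model)
    print(f"Invalid request parameters: {e}")
except OverloadedError:
    # Server-side back-pressure from text-generation-inference
    print("Model is overloaded; worth retrying with backoff.")
```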
@@ -48,7 +60,12 @@ from ..utils import (
     is_numpy_available,
     is_pillow_available,
 )
-from .
+from ._generated.types import (
+    ChatCompletionStreamOutput,
+    ChatCompletionStreamOutputChoice,
+    ChatCompletionStreamOutputDelta,
+    TextGenerationStreamOutput,
+)
 
 
 if TYPE_CHECKING:
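The removed import (truncated to `from .` by the diff viewer) is replaced by the new generated stream types; the `_generated/types/*` files in the list above are all new in this release. Judging by the `parse_obj_as_instance` calls later in this diff, these types are built from decoded JSON dicts rather than instantiated directly from keyword arguments as `TextGenerationStreamResponse(**json_payload)` was. A rough sketch of that usage; the payload fields are assumed from how the code below accesses `output.token.text` and `output.details`, not from an official schema:

```python
import json

from huggingface_hub.inference._generated.types import TextGenerationStreamOutput

# Assumed minimal TGI stream event (field names inferred from this diff).
raw = b'data: {"index": 0, "token": {"id": 42, "text": "Hello", "logprob": -0.1, "special": false}, "generated_text": null, "details": null}'
json_payload = json.loads(raw.decode("utf-8").lstrip("data:"))

output = TextGenerationStreamOutput.parse_obj_as_instance(json_payload)
print(output.token.text)  # -> "Hello"
```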
@@ -98,10 +115,6 @@ class ModelStatus:
     framework: str
 
 
-class InferenceTimeoutError(HTTPError, TimeoutError):
-    """Error raised when a model is unavailable or the request times out."""
-
-
 ## IMPORT UTILS
 
 
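`InferenceTimeoutError` is removed from `_common.py` without a replacement in this hunk. Given the new `errors.py` module (+38 lines) and the error imports added earlier in this file, the class has presumably been relocated there rather than deleted, in which case the top-level `from huggingface_hub import InferenceTimeoutError` path keeps working. A quick check against an installed 0.22.0rc0 to confirm that assumption:

```python
# Assumption being verified: the class moved to huggingface_hub.errors and is
# still re-exported at the package top level for backward compatibility.
import huggingface_hub

err_cls = huggingface_hub.InferenceTimeoutError
print(err_cls.__module__)  # expected: huggingface_hub.errors
print(err_cls.__bases__)   # expected to still include HTTPError and TimeoutError
```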
@@ -163,13 +176,15 @@ def _first_or_none(items: List[Any]) -> Optional[Any]:
 
 
 @overload
-def _open_as_binary(content: ContentT) -> ContextManager[BinaryT]:  # means "if input is not None, output is not None"
-    ...
+def _open_as_binary(
+    content: ContentT,
+) -> ContextManager[BinaryT]: ...  # means "if input is not None, output is not None"
 
 
 @overload
-def _open_as_binary(content: Literal[None]) -> ContextManager[Literal[None]]:  # means "if input is None, output is None"
-    ...
+def _open_as_binary(
+    content: Literal[None],
+) -> ContextManager[Literal[None]]: ...  # means "if input is None, output is None"
 
 
 @contextmanager  # type: ignore
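The change to `_open_as_binary` is purely a reflow: one parameter per line so the explanatory comments sit next to the overload stubs. The idiom itself is the standard `typing.overload` trick for telling a type checker that `None` in means `None` out while a real value in means a real value out, backed by a single runtime implementation. A self-contained toy version of the same pattern (names invented for illustration):

```python
from contextlib import contextmanager
from typing import BinaryIO, ContextManager, Iterator, Literal, Optional, overload


@overload
def open_maybe(path: str) -> ContextManager[BinaryIO]: ...  # real path in, file out
@overload
def open_maybe(path: Literal[None]) -> ContextManager[None]: ...  # None in, None out
@contextmanager  # type: ignore
def open_maybe(path: Optional[str]) -> Iterator[Optional[BinaryIO]]:
    # Single implementation backing both overload signatures.
    if path is None:
        yield None
    else:
        with open(path, "rb") as f:
            yield f
```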
@@ -253,48 +268,125 @@ def _bytes_to_image(content: bytes) -> "Image":
 
 def _stream_text_generation_response(
     bytes_output_as_lines: Iterable[bytes], details: bool
-) -> Union[Iterable[str], Iterable[TextGenerationStreamResponse]]:
+) -> Union[Iterable[str], Iterable[TextGenerationStreamOutput]]:
+    """Used in `InferenceClient.text_generation`."""
     # Parse ServerSentEvents
     for byte_payload in bytes_output_as_lines:
-        # Skip line
-        if byte_payload == b"\n":
-            continue
-
-        payload = byte_payload.decode("utf-8")
-
-        # Event data
-        if payload.startswith("data:"):
-            # Decode payload
-            json_payload = json.loads(payload.lstrip("data:").rstrip("/n"))
-            # Either an error as being returned
-            if json_payload.get("error") is not None:
-                raise _parse_text_generation_error(json_payload["error"], json_payload.get("error_type"))
-            # Or parse token payload
-            output = TextGenerationStreamResponse(**json_payload)
-            yield output.token.text if not details else output
+        output = _format_text_generation_stream_output(byte_payload, details)
+        if output is not None:
+            yield output
 
 
 async def _async_stream_text_generation_response(
     bytes_output_as_lines: AsyncIterable[bytes], details: bool
-) -> Union[AsyncIterable[str], AsyncIterable[TextGenerationStreamResponse]]:
+) -> Union[AsyncIterable[str], AsyncIterable[TextGenerationStreamOutput]]:
+    """Used in `AsyncInferenceClient.text_generation`."""
     # Parse ServerSentEvents
     async for byte_payload in bytes_output_as_lines:
-        # Skip line
-        if byte_payload == b"\n":
-            continue
+        output = _format_text_generation_stream_output(byte_payload, details)
+        if output is not None:
+            yield output
+
+
+def _format_text_generation_stream_output(
+    byte_payload: bytes, details: bool
+) -> Optional[Union[str, TextGenerationStreamOutput]]:
+    if not byte_payload.startswith(b"data:"):
+        return None  # empty line
+
+    # Decode payload
+    payload = byte_payload.decode("utf-8")
+    json_payload = json.loads(payload.lstrip("data:").rstrip("/n"))
+
+    # Either an error as being returned
+    if json_payload.get("error") is not None:
+        raise _parse_text_generation_error(json_payload["error"], json_payload.get("error_type"))
+
+    # Or parse token payload
+    output = TextGenerationStreamOutput.parse_obj_as_instance(json_payload)
+    return output.token.text if not details else output
+
+
+def _stream_chat_completion_response_from_text_generation(
+    text_generation_output: Iterable[TextGenerationStreamOutput],
+) -> Iterable[ChatCompletionStreamOutput]:
+    """Used in `InferenceClient.chat_completion`."""
+    created = int(time.time())
+    for item in text_generation_output:
+        yield _format_chat_completion_stream_output_from_text_generation(item, created)
+
+
+async def _async_stream_chat_completion_response_from_text_generation(
+    text_generation_output: AsyncIterable[TextGenerationStreamOutput],
+) -> AsyncIterable[ChatCompletionStreamOutput]:
+    """Used in `AsyncInferenceClient.chat_completion`."""
+    created = int(time.time())
+    async for item in text_generation_output:
+        yield _format_chat_completion_stream_output_from_text_generation(item, created)
+
+
+def _format_chat_completion_stream_output_from_text_generation(
+    item: TextGenerationStreamOutput, created: int
+) -> ChatCompletionStreamOutput:
+    if item.details is None:
+        # new token generated => return delta
+        return ChatCompletionStreamOutput(
+            choices=[
+                ChatCompletionStreamOutputChoice(
+                    delta=ChatCompletionStreamOutputDelta(
+                        role="assistant",
+                        content=item.token.text,
+                    ),
+                    finish_reason=None,
+                    index=0,
+                )
+            ],
+            created=created,
+        )
+    else:
+        # generation is completed => return finish reason
+        return ChatCompletionStreamOutput(
+            choices=[
+                ChatCompletionStreamOutputChoice(
+                    delta=ChatCompletionStreamOutputDelta(),
+                    finish_reason=item.details.finish_reason,
+                    index=0,
+                )
+            ],
+            created=created,
+        )
+
+
+def _stream_chat_completion_response_from_bytes(
+    bytes_lines: Iterable[bytes],
+) -> Iterable[ChatCompletionStreamOutput]:
+    """Used in `InferenceClient.chat_completion` if model is served with TGI."""
+    for item in bytes_lines:
+        output = _format_chat_completion_stream_output_from_text_generation_from_bytes(item)
+        if output is not None:
+            yield output
 
-        payload = byte_payload.decode("utf-8")
 
-        # Event data
-        if payload.startswith("data:"):
-            # Decode payload
-            json_payload = json.loads(payload.lstrip("data:").rstrip("/n"))
-            # Either an error as being returned
-            if json_payload.get("error") is not None:
-                raise _parse_text_generation_error(json_payload["error"], json_payload.get("error_type"))
-            # Or parse token payload
-            output = TextGenerationStreamResponse(**json_payload)
-            yield output.token.text if not details else output
+async def _async_stream_chat_completion_response_from_bytes(
+    bytes_lines: AsyncIterable[bytes],
+) -> AsyncIterable[ChatCompletionStreamOutput]:
+    """Used in `AsyncInferenceClient.chat_completion`."""
+    async for item in bytes_lines:
+        output = _format_chat_completion_stream_output_from_text_generation_from_bytes(item)
+        if output is not None:
+            yield output
+
+
+def _format_chat_completion_stream_output_from_text_generation_from_bytes(
+    byte_payload: bytes,
+) -> Optional[ChatCompletionStreamOutput]:
+    if not byte_payload.startswith(b"data:"):
+        return None  # empty line
+
+    # Decode payload
+    payload = byte_payload.decode("utf-8")
+    json_payload = json.loads(payload.lstrip("data:").rstrip("/n"))
+    return ChatCompletionStreamOutput.parse_obj_as_instance(json_payload)
 
 
 async def _async_yield_from(client: "ClientSession", response: "ClientResponse") -> AsyncIterable[bytes]:
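The bulk of this hunk factors the server-sent-events parsing out of the sync and async generators into a shared `_format_text_generation_stream_output`, then reuses the same pattern for the new chat-completion stream (either adapted token-by-token from `text_generation`, or decoded directly from bytes when the server exposes a native chat route). A standalone sketch of the SSE handling these helpers share; the payloads are made up, while the `data:` framing and the stripping mirror the code above:

```python
import json
from typing import Iterable, Optional


def parse_sse_line(byte_payload: bytes) -> Optional[dict]:
    """Same shape as _format_text_generation_stream_output, minus the typed output."""
    if not byte_payload.startswith(b"data:"):
        return None  # keep-alive / empty line between events
    payload = byte_payload.decode("utf-8")
    return json.loads(payload.lstrip("data:").rstrip("/n"))


# Simulated TGI stream: two token events separated by a blank keep-alive line.
lines: Iterable[bytes] = [
    b'data: {"token": {"text": "Hello"}, "details": null}',
    b"",
    b'data: {"token": {"text": " world"}, "details": {"finish_reason": "length"}}',
]
for line in lines:
    event = parse_sse_line(line)
    if event is not None:
        print(event["token"]["text"], "finished:", event["details"] is not None)
```

Worth noting while reading: `lstrip("data:")` and `rstrip("/n")` strip *character sets*, not prefixes, and `"/n"` looks like a long-standing typo for `"\n"`; both are carried over unchanged from the 0.21.4 implementation.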
@@ -314,6 +406,10 @@ async def _async_yield_from(client: "ClientSession", response: "ClientResponse")
 # default API with a warning message. We remember for each model if it's a TGI server
 # or not using `_NON_TGI_SERVERS` global variable.
 #
+# In addition, TGI servers have a built-in API route for chat-completion, which is not
+# available on the default API. We use this route to provide a more consistent behavior
+# when available.
+#
 # For more details, see https://github.com/huggingface/text-generation-inference and
 # https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task.
 
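This comment documents the dispatch strategy behind the new `chat_completion` API: try the TGI chat-completion route first, and fall back (tracked via `_NON_CHAT_COMPLETION_SERVER` below) when the server does not expose it. From the caller's side the dispatch is invisible; a sketch of the new client surface, assuming the 0.22 signatures:

```python
from huggingface_hub import InferenceClient

# Model name is illustrative; any conversational model served by TGI should work.
client = InferenceClient(model="HuggingFaceH4/zephyr-7b-beta")

# Streaming chat completion: yields ChatCompletionStreamOutput chunks whether the
# server has a native chat route (TGI) or the client falls back to text_generation.
for chunk in client.chat_completion(
    messages=[{"role": "user", "content": "What is SSE?"}],
    max_tokens=60,
    stream=True,
):
    print(chunk.choices[0].delta.content or "", end="")
```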
@@ -326,3 +422,61 @@ def _set_as_non_tgi(model: Optional[str]) -> None:
 
 def _is_tgi_server(model: Optional[str]) -> bool:
     return model not in _NON_TGI_SERVERS
+
+
+_NON_CHAT_COMPLETION_SERVER: Set[str] = set()
+
+
+def _set_as_non_chat_completion_server(model: str) -> None:
+    print("Set as non chat completion", model)
+    _NON_CHAT_COMPLETION_SERVER.add(model)
+
+
+def _is_chat_completion_server(model: str) -> bool:
+    return model not in _NON_CHAT_COMPLETION_SERVER
+
+
+# TEXT GENERATION ERRORS
+# ----------------------
+# Text-generation errors are parsed separately to handle as much as possible the errors returned by the text generation
+# inference project (https://github.com/huggingface/text-generation-inference).
+# ----------------------
+
+
+def raise_text_generation_error(http_error: HTTPError) -> NoReturn:
+    """
+    Try to parse text-generation-inference error message and raise HTTPError in any case.
+
+    Args:
+        error (`HTTPError`):
+            The HTTPError that have been raised.
+    """
+    # Try to parse a Text Generation Inference error
+
+    try:
+        # Hacky way to retrieve payload in case of aiohttp error
+        payload = getattr(http_error, "response_error_payload", None) or http_error.response.json()
+        error = payload.get("error")
+        error_type = payload.get("error_type")
+    except Exception:  # no payload
+        raise http_error
+
+    # If error_type => more information than `hf_raise_for_status`
+    if error_type is not None:
+        exception = _parse_text_generation_error(error, error_type)
+        raise exception from http_error
+
+    # Otherwise, fallback to default error
+    raise http_error
+
+
+def _parse_text_generation_error(error: Optional[str], error_type: Optional[str]) -> TextGenerationError:
+    if error_type == "generation":
+        return GenerationError(error)  # type: ignore
+    if error_type == "incomplete_generation":
+        return IncompleteGenerationError(error)  # type: ignore
+    if error_type == "overloaded":
+        return OverloadedError(error)  # type: ignore
+    if error_type == "validation":
+        return ValidationError(error)  # type: ignore
+    return UnknownError(error)  # type: ignore