huggingface-hub 0.34.4__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their public registry, and is provided for informational purposes only.
This release has been flagged as potentially problematic by the registry's automated checks.
- huggingface_hub/__init__.py +46 -45
- huggingface_hub/_commit_api.py +28 -28
- huggingface_hub/_commit_scheduler.py +11 -8
- huggingface_hub/_inference_endpoints.py +8 -8
- huggingface_hub/_jobs_api.py +167 -10
- huggingface_hub/_login.py +13 -39
- huggingface_hub/_oauth.py +8 -8
- huggingface_hub/_snapshot_download.py +14 -28
- huggingface_hub/_space_api.py +4 -4
- huggingface_hub/_tensorboard_logger.py +13 -14
- huggingface_hub/_upload_large_folder.py +15 -15
- huggingface_hub/_webhooks_payload.py +3 -3
- huggingface_hub/_webhooks_server.py +2 -2
- huggingface_hub/cli/_cli_utils.py +2 -2
- huggingface_hub/cli/auth.py +5 -6
- huggingface_hub/cli/cache.py +14 -20
- huggingface_hub/cli/download.py +4 -4
- huggingface_hub/cli/jobs.py +560 -11
- huggingface_hub/cli/lfs.py +4 -4
- huggingface_hub/cli/repo.py +7 -7
- huggingface_hub/cli/repo_files.py +2 -2
- huggingface_hub/cli/upload.py +4 -4
- huggingface_hub/cli/upload_large_folder.py +3 -3
- huggingface_hub/commands/_cli_utils.py +2 -2
- huggingface_hub/commands/delete_cache.py +13 -13
- huggingface_hub/commands/download.py +4 -13
- huggingface_hub/commands/lfs.py +4 -4
- huggingface_hub/commands/repo_files.py +2 -2
- huggingface_hub/commands/scan_cache.py +1 -1
- huggingface_hub/commands/tag.py +1 -3
- huggingface_hub/commands/upload.py +4 -4
- huggingface_hub/commands/upload_large_folder.py +3 -3
- huggingface_hub/commands/user.py +5 -6
- huggingface_hub/community.py +5 -5
- huggingface_hub/constants.py +3 -41
- huggingface_hub/dataclasses.py +16 -19
- huggingface_hub/errors.py +42 -29
- huggingface_hub/fastai_utils.py +8 -9
- huggingface_hub/file_download.py +153 -252
- huggingface_hub/hf_api.py +815 -600
- huggingface_hub/hf_file_system.py +98 -62
- huggingface_hub/hub_mixin.py +37 -57
- huggingface_hub/inference/_client.py +177 -325
- huggingface_hub/inference/_common.py +110 -124
- huggingface_hub/inference/_generated/_async_client.py +226 -432
- huggingface_hub/inference/_generated/types/automatic_speech_recognition.py +3 -3
- huggingface_hub/inference/_generated/types/base.py +10 -7
- huggingface_hub/inference/_generated/types/chat_completion.py +18 -16
- huggingface_hub/inference/_generated/types/depth_estimation.py +2 -2
- huggingface_hub/inference/_generated/types/document_question_answering.py +2 -2
- huggingface_hub/inference/_generated/types/feature_extraction.py +2 -2
- huggingface_hub/inference/_generated/types/fill_mask.py +2 -2
- huggingface_hub/inference/_generated/types/sentence_similarity.py +3 -3
- huggingface_hub/inference/_generated/types/summarization.py +2 -2
- huggingface_hub/inference/_generated/types/table_question_answering.py +4 -4
- huggingface_hub/inference/_generated/types/text2text_generation.py +2 -2
- huggingface_hub/inference/_generated/types/text_generation.py +10 -10
- huggingface_hub/inference/_generated/types/text_to_video.py +2 -2
- huggingface_hub/inference/_generated/types/token_classification.py +2 -2
- huggingface_hub/inference/_generated/types/translation.py +2 -2
- huggingface_hub/inference/_generated/types/zero_shot_classification.py +2 -2
- huggingface_hub/inference/_generated/types/zero_shot_image_classification.py +2 -2
- huggingface_hub/inference/_generated/types/zero_shot_object_detection.py +1 -3
- huggingface_hub/inference/_mcp/_cli_hacks.py +3 -3
- huggingface_hub/inference/_mcp/agent.py +3 -3
- huggingface_hub/inference/_mcp/cli.py +1 -1
- huggingface_hub/inference/_mcp/constants.py +2 -3
- huggingface_hub/inference/_mcp/mcp_client.py +58 -30
- huggingface_hub/inference/_mcp/types.py +10 -7
- huggingface_hub/inference/_mcp/utils.py +11 -7
- huggingface_hub/inference/_providers/__init__.py +2 -2
- huggingface_hub/inference/_providers/_common.py +49 -25
- huggingface_hub/inference/_providers/black_forest_labs.py +6 -6
- huggingface_hub/inference/_providers/cohere.py +3 -3
- huggingface_hub/inference/_providers/fal_ai.py +25 -25
- huggingface_hub/inference/_providers/featherless_ai.py +4 -4
- huggingface_hub/inference/_providers/fireworks_ai.py +3 -3
- huggingface_hub/inference/_providers/hf_inference.py +28 -20
- huggingface_hub/inference/_providers/hyperbolic.py +4 -4
- huggingface_hub/inference/_providers/nebius.py +10 -10
- huggingface_hub/inference/_providers/novita.py +5 -5
- huggingface_hub/inference/_providers/nscale.py +4 -4
- huggingface_hub/inference/_providers/replicate.py +15 -15
- huggingface_hub/inference/_providers/sambanova.py +6 -6
- huggingface_hub/inference/_providers/together.py +7 -7
- huggingface_hub/lfs.py +20 -31
- huggingface_hub/repocard.py +18 -18
- huggingface_hub/repocard_data.py +56 -56
- huggingface_hub/serialization/__init__.py +0 -1
- huggingface_hub/serialization/_base.py +9 -9
- huggingface_hub/serialization/_dduf.py +7 -7
- huggingface_hub/serialization/_torch.py +28 -28
- huggingface_hub/utils/__init__.py +10 -4
- huggingface_hub/utils/_auth.py +5 -5
- huggingface_hub/utils/_cache_manager.py +31 -31
- huggingface_hub/utils/_deprecation.py +1 -1
- huggingface_hub/utils/_dotenv.py +3 -3
- huggingface_hub/utils/_fixes.py +0 -10
- huggingface_hub/utils/_git_credential.py +4 -4
- huggingface_hub/utils/_headers.py +7 -29
- huggingface_hub/utils/_http.py +366 -208
- huggingface_hub/utils/_pagination.py +4 -4
- huggingface_hub/utils/_paths.py +5 -5
- huggingface_hub/utils/_runtime.py +15 -13
- huggingface_hub/utils/_safetensors.py +21 -21
- huggingface_hub/utils/_subprocess.py +9 -9
- huggingface_hub/utils/_telemetry.py +3 -3
- huggingface_hub/utils/_typing.py +25 -5
- huggingface_hub/utils/_validators.py +53 -72
- huggingface_hub/utils/_xet.py +16 -16
- huggingface_hub/utils/_xet_progress_reporting.py +32 -11
- huggingface_hub/utils/insecure_hashlib.py +3 -9
- huggingface_hub/utils/tqdm.py +3 -3
- {huggingface_hub-0.34.4.dist-info → huggingface_hub-1.0.0rc0.dist-info}/METADATA +18 -29
- huggingface_hub-1.0.0rc0.dist-info/RECORD +161 -0
- huggingface_hub/inference_api.py +0 -217
- huggingface_hub/keras_mixin.py +0 -500
- huggingface_hub/repository.py +0 -1477
- huggingface_hub/serialization/_tensorflow.py +0 -95
- huggingface_hub/utils/_hf_folder.py +0 -68
- huggingface_hub-0.34.4.dist-info/RECORD +0 -166
- {huggingface_hub-0.34.4.dist-info → huggingface_hub-1.0.0rc0.dist-info}/LICENSE +0 -0
- {huggingface_hub-0.34.4.dist-info → huggingface_hub-1.0.0rc0.dist-info}/WHEEL +0 -0
- {huggingface_hub-0.34.4.dist-info → huggingface_hub-1.0.0rc0.dist-info}/entry_points.txt +0 -0
- {huggingface_hub-0.34.4.dist-info → huggingface_hub-1.0.0rc0.dist-info}/top_level.txt +0 -0
huggingface_hub/hf_file_system.py
CHANGED
@@ -2,24 +2,25 @@ import os
 import re
 import tempfile
 from collections import deque
+from contextlib import ExitStack
 from dataclasses import dataclass, field
 from datetime import datetime
 from itertools import chain
 from pathlib import Path
-from typing import Any,
+from typing import Any, Iterator, NoReturn, Optional, Union
 from urllib.parse import quote, unquote
 
 import fsspec
+import httpx
 from fsspec.callbacks import _DEFAULT_CALLBACK, NoOpCallback, TqdmCallback
 from fsspec.utils import isfilelike
-from requests import Response
 
 from . import constants
 from ._commit_api import CommitOperationCopy, CommitOperationDelete
-from .errors import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
+from .errors import EntryNotFoundError, HfHubHTTPError, RepositoryNotFoundError, RevisionNotFoundError
 from .file_download import hf_hub_url, http_get
 from .hf_api import HfApi, LastCommitInfo, RepoFile
-from .utils import HFValidationError, hf_raise_for_status, http_backoff
+from .utils import HFValidationError, hf_raise_for_status, http_backoff, http_stream_backoff
 
 
 # Regex used to match special revisions with "/" in them (see #1710)
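The switch from `requests` to `httpx` matters for the streaming code later in this file: an `httpx` streamed response is a context manager, which is why the rewritten class keeps an `ExitStack` around instead of closing the raw `requests` connection by hand. A generic sketch of that idiom, independent of the library's `http_stream_backoff` wrapper (the URL is a placeholder):

import httpx

# Placeholder URL; any large file works.
with httpx.stream("GET", "https://example.com/large-file.bin", timeout=10.0) as response:
    response.raise_for_status()
    for chunk in response.iter_bytes(chunk_size=1024 * 1024):
        ...  # process the file in 1 MiB chunks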
@@ -113,13 +114,13 @@ class HfFileSystem(fsspec.AbstractFileSystem):
         # Maps (repo_type, repo_id, revision) to a 2-tuple with:
         # * the 1st element indicating whether the repositoy and the revision exist
         # * the 2nd element being the exception raised if the repository or revision doesn't exist
-        self._repo_and_revision_exists_cache:
-
+        self._repo_and_revision_exists_cache: dict[
+            tuple[str, str, Optional[str]], tuple[bool, Optional[Exception]]
         ] = {}
 
     def _repo_and_revision_exist(
         self, repo_type: str, repo_id: str, revision: Optional[str]
-    ) ->
+    ) -> tuple[bool, Optional[Exception]]:
         if (repo_type, repo_id, revision) not in self._repo_and_revision_exists_cache:
             try:
                 self._api.repo_info(
@@ -338,7 +339,7 @@ class HfFileSystem(fsspec.AbstractFileSystem):
 
     def ls(
         self, path: str, detail: bool = True, refresh: bool = False, revision: Optional[str] = None, **kwargs
-    ) ->
+    ) -> list[Union[str, dict[str, Any]]]:
         """
         List the contents of a directory.
 
@@ -362,7 +363,7 @@ class HfFileSystem(fsspec.AbstractFileSystem):
                 The git revision to list from.
 
         Returns:
-            `
+            `list[Union[str, dict[str, Any]]]`: List of file paths (if detail=False) or list of file information
             dictionaries (if detail=True).
         """
         resolved_path = self.resolve_path(path, revision=revision)
@@ -483,7 +484,7 @@ class HfFileSystem(fsspec.AbstractFileSystem):
                 out.append(cache_path_info)
         return out
 
-    def walk(self, path: str, *args, **kwargs) -> Iterator[
+    def walk(self, path: str, *args, **kwargs) -> Iterator[tuple[str, list[str], list[str]]]:
         """
         Return all files below the given path.
 
@@ -494,12 +495,12 @@ class HfFileSystem(fsspec.AbstractFileSystem):
                 Root path to list files from.
 
         Returns:
-            `Iterator[
+            `Iterator[tuple[str, list[str], list[str]]]`: An iterator of (path, list of directory names, list of file names) tuples.
         """
         path = self.resolve_path(path, revision=kwargs.get("revision")).unresolve()
         yield from super().walk(path, *args, **kwargs)
 
-    def glob(self, path: str, **kwargs) ->
+    def glob(self, path: str, **kwargs) -> list[str]:
        """
         Find files by glob-matching.
 
@@ -510,7 +511,7 @@ class HfFileSystem(fsspec.AbstractFileSystem):
                 Path pattern to match.
 
         Returns:
-            `
+            `list[str]`: List of paths matching the pattern.
         """
         path = self.resolve_path(path, revision=kwargs.get("revision")).unresolve()
         return super().glob(path, **kwargs)
@@ -524,7 +525,7 @@ class HfFileSystem(fsspec.AbstractFileSystem):
         refresh: bool = False,
         revision: Optional[str] = None,
         **kwargs,
-    ) -> Union[
+    ) -> Union[list[str], dict[str, dict[str, Any]]]:
         """
         List all files below path.
 
@@ -545,7 +546,7 @@ class HfFileSystem(fsspec.AbstractFileSystem):
                 The git revision to list from.
 
         Returns:
-            `Union[
+            `Union[list[str], dict[str, dict[str, Any]]]`: List of paths or dict of file information.
         """
         if maxdepth:
             return super().find(
@@ -650,7 +651,7 @@ class HfFileSystem(fsspec.AbstractFileSystem):
         info = self.info(path, **{**kwargs, "expand_info": True})
         return info["last_commit"]["date"]
 
-    def info(self, path: str, refresh: bool = False, revision: Optional[str] = None, **kwargs) ->
+    def info(self, path: str, refresh: bool = False, revision: Optional[str] = None, **kwargs) -> dict[str, Any]:
         """
         Get information about a file or directory.
 
@@ -671,7 +672,7 @@ class HfFileSystem(fsspec.AbstractFileSystem):
                 The git revision to get info from.
 
         Returns:
-            `
+            `dict[str, Any]`: Dictionary containing file information (type, size, commit info, etc.).
 
         """
         resolved_path = self.resolve_path(path, revision=revision)
@@ -896,7 +897,7 @@ class HfFileSystem(fsspec.AbstractFileSystem):
                 repo_type=resolve_remote_path.repo_type,
                 endpoint=self.endpoint,
             ),
-            temp_file=outfile,
+            temp_file=outfile,  # type: ignore[arg-type]
             displayed_filename=rpath,
             expected_size=expected_size,
             resume_size=0,
@@ -1039,8 +1040,9 @@ class HfFileSystemStreamFile(fsspec.spec.AbstractBufferedFile):
         super().__init__(
             fs, self.resolved_path.unresolve(), mode=mode, block_size=block_size, cache_type=cache_type, **kwargs
         )
-        self.response: Optional[Response] = None
+        self.response: Optional[httpx.Response] = None
         self.fs: HfFileSystem
+        self._exit_stack = ExitStack()
 
     def seek(self, loc: int, whence: int = 0):
         if loc == 0 and whence == 1:
@@ -1050,53 +1052,32 @@ class HfFileSystemStreamFile(fsspec.spec.AbstractBufferedFile):
         raise ValueError("Cannot seek streaming HF file")
 
     def read(self, length: int = -1):
-
+        """Read the remote file.
+
+        If the file is already open, we reuse the connection.
+        Otherwise, open a new connection and read from it.
+
+        If reading the stream fails, we retry with a new connection.
+        """
         if self.response is None:
-
-                repo_id=self.resolved_path.repo_id,
-                revision=self.resolved_path.revision,
-                filename=self.resolved_path.path_in_repo,
-                repo_type=self.resolved_path.repo_type,
-                endpoint=self.fs.endpoint,
-            )
-            self.response = http_backoff(
-                "GET",
-                url,
-                headers=self.fs._api._build_hf_headers(),
-                retry_on_status_codes=(500, 502, 503, 504),
-                stream=True,
-                timeout=constants.HF_HUB_DOWNLOAD_TIMEOUT,
-            )
-            hf_raise_for_status(self.response)
-        try:
-            out = self.response.raw.read(*read_args)
-        except Exception:
-            self.response.close()
+            self._open_connection()
 
-
-
-                repo_id=self.resolved_path.repo_id,
-                revision=self.resolved_path.revision,
-                filename=self.resolved_path.path_in_repo,
-                repo_type=self.resolved_path.repo_type,
-                endpoint=self.fs.endpoint,
-            )
-            self.response = http_backoff(
-                "GET",
-                url,
-                headers={"Range": "bytes=%d-" % self.loc, **self.fs._api._build_hf_headers()},
-                retry_on_status_codes=(500, 502, 503, 504),
-                stream=True,
-                timeout=constants.HF_HUB_DOWNLOAD_TIMEOUT,
-            )
-            hf_raise_for_status(self.response)
+        retried_once = False
+        while True:
             try:
-
+                if self.response is None:
+                    return b""  # Already read the entire file
+                out = _partial_read(self.response, length)
+                self.loc += len(out)
+                return out
             except Exception:
-            self.response
-
-
-
+                if self.response is not None:
+                    self.response.close()
+                if retried_once:  # Already retried once, give up
+                    raise
+                # First failure, retry with range header
+                self._open_connection()
+                retried_once = True
 
     def url(self) -> str:
         return self.fs.url(self.path)
@@ -1105,11 +1086,43 @@ class HfFileSystemStreamFile(fsspec.spec.AbstractBufferedFile):
         if not hasattr(self, "resolved_path"):
             # Means that the constructor failed. Nothing to do.
             return
+        self._exit_stack.close()
         return super().__del__()
 
     def __reduce__(self):
         return reopen, (self.fs, self.path, self.mode, self.blocksize, self.cache.name)
 
+    def _open_connection(self):
+        """Open a connection to the remote file."""
+        url = hf_hub_url(
+            repo_id=self.resolved_path.repo_id,
+            revision=self.resolved_path.revision,
+            filename=self.resolved_path.path_in_repo,
+            repo_type=self.resolved_path.repo_type,
+            endpoint=self.fs.endpoint,
+        )
+        headers = self.fs._api._build_hf_headers()
+        if self.loc > 0:
+            headers["Range"] = f"bytes={self.loc}-"
+        self.response = self._exit_stack.enter_context(
+            http_stream_backoff(
+                "GET",
+                url,
+                headers=headers,
+                retry_on_status_codes=(500, 502, 503, 504),
+                timeout=constants.HF_HUB_DOWNLOAD_TIMEOUT,
+            )
+        )
+
+        try:
+            hf_raise_for_status(self.response)
+        except HfHubHTTPError as e:
+            if e.response.status_code == 416:
+                # Range not satisfiable => means that we have already read the entire file
+                self.response = None
+                return
+            raise
+
 
 def safe_revision(revision: str) -> str:
     return revision if SPECIAL_REFS_REVISION_REGEX.match(revision) else safe_quote(revision)
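With these changes, a streaming open keeps a single httpx connection and transparently reopens it with a `Range` header when the stream drops or is resumed. A minimal usage sketch, assuming `block_size=0` still selects the streaming file class (`HfFileSystemStreamFile`) as in previous releases; the repo path is just an example:

from huggingface_hub import HfFileSystem

fs = HfFileSystem()
# block_size=0 -> streaming file backed by a single httpx stream
with fs.open("gpt2/config.json", "rb", block_size=0) as f:
    head = f.read(64)   # served from the open httpx stream
    rest = f.read()     # -1 => read until EOF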
@@ -1132,3 +1145,26 @@ def _raise_file_not_found(path: str, err: Optional[Exception]) -> NoReturn:
 
 def reopen(fs: HfFileSystem, path: str, mode: str, block_size: int, cache_type: str):
     return fs.open(path, mode=mode, block_size=block_size, cache_type=cache_type)
+
+
+def _partial_read(response: httpx.Response, length: int = -1) -> bytes:
+    """
+    Read up to `length` bytes from a streamed response.
+    If length == -1, read until EOF.
+    """
+    buf = bytearray()
+    if length < -1:
+        raise ValueError("length must be -1 or >= 0")
+    if length == 0:
+        return b""
+    if length == -1:
+        for chunk in response.iter_bytes():
+            buf.extend(chunk)
+        return bytes(buf)
+
+    for chunk in response.iter_bytes(chunk_size=length):
+        buf.extend(chunk)
+        if len(buf) >= length:
+            return bytes(buf[:length])
+
+    return bytes(buf)  # may be < length if response ended
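The chunk-accumulation behaviour of the new `_partial_read` helper can be exercised against an in-memory `httpx.Response`, which exposes the same `iter_bytes()` API as a streamed one. A small sketch (the helper is private; importing it here is for illustration only):

import httpx
from huggingface_hub.hf_file_system import _partial_read  # private helper added in 1.0.0rc0 (see diff above)

resp = httpx.Response(200, content=b"hello world")
print(_partial_read(resp, 5))                                  # b'hello'
print(_partial_read(httpx.Response(200, content=b"!"), -1))    # b'!'  (length=-1 reads to EOF)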
huggingface_hub/hub_mixin.py
CHANGED
@@ -3,7 +3,7 @@ import json
 import os
 from dataclasses import Field, asdict, dataclass, is_dataclass
 from pathlib import Path
-from typing import Any, Callable, ClassVar,
+from typing import Any, Callable, ClassVar, Optional, Protocol, Type, TypeVar, Union
 
 import packaging.version
 
@@ -38,7 +38,7 @@ logger = logging.get_logger(__name__)
 
 # Type alias for dataclass instances, copied from https://github.com/python/typeshed/blob/9f28171658b9ca6c32a7cb93fbb99fc92b17858b/stdlib/_typeshed/__init__.pyi#L349
 class DataclassInstance(Protocol):
-    __dataclass_fields__: ClassVar[
+    __dataclass_fields__: ClassVar[dict[str, Field]]
 
 
 # Generic variable that is either ModelHubMixin or a subclass thereof
@@ -47,7 +47,7 @@ T = TypeVar("T", bound="ModelHubMixin")
 ARGS_T = TypeVar("ARGS_T")
 ENCODER_T = Callable[[ARGS_T], Any]
 DECODER_T = Callable[[Any], ARGS_T]
-CODER_T =
+CODER_T = tuple[ENCODER_T, DECODER_T]
 
 
 DEFAULT_MODEL_CARD = """
@@ -96,7 +96,7 @@ class ModelHubMixin:
            URL of the library documentation. Used to generate model card.
        model_card_template (`str`, *optional*):
            Template of the model card. Used to generate model card. Defaults to a generic template.
-       language (`str` or `
+       language (`str` or `list[str]`, *optional*):
            Language supported by the library. Used to generate model card.
        library_name (`str`, *optional*):
            Name of the library integrating ModelHubMixin. Used to generate model card.
@@ -113,9 +113,9 @@ class ModelHubMixin:
            E.g: "https://coqui.ai/cpml".
        pipeline_tag (`str`, *optional*):
            Tag of the pipeline. Used to generate model card. E.g. "text-classification".
-       tags (`
+       tags (`list[str]`, *optional*):
            Tags to be added to the model card. Used to generate model card. E.g. ["computer-vision"]
-       coders (`
+       coders (`dict[Type, tuple[Callable, Callable]]`, *optional*):
            Dictionary of custom types and their encoders/decoders. Used to encode/decode arguments that are not
            jsonable by default. E.g dataclasses, argparse.Namespace, OmegaConf, etc.
 
@@ -145,12 +145,10 @@ class ModelHubMixin:
     ...
     ...     @classmethod
     ...     def from_pretrained(
-    ...         cls:
+    ...         cls: type[T],
     ...         pretrained_model_name_or_path: Union[str, Path],
     ...         *,
     ...         force_download: bool = False,
-    ...         resume_download: Optional[bool] = None,
-    ...         proxies: Optional[Dict] = None,
     ...         token: Optional[Union[str, bool]] = None,
     ...         cache_dir: Optional[Union[str, Path]] = None,
     ...         local_files_only: bool = False,
@@ -188,10 +186,10 @@ class ModelHubMixin:
     _hub_mixin_info: MixinInfo
     # ^ information about the library integrating ModelHubMixin (used to generate model card)
     _hub_mixin_inject_config: bool  # whether `_from_pretrained` expects `config` or not
-    _hub_mixin_init_parameters:
-    _hub_mixin_jsonable_default_values:
-    _hub_mixin_jsonable_custom_types:
-    _hub_mixin_coders:
+    _hub_mixin_init_parameters: dict[str, inspect.Parameter]  # __init__ parameters
+    _hub_mixin_jsonable_default_values: dict[str, Any]  # default values for __init__ parameters
+    _hub_mixin_jsonable_custom_types: tuple[Type, ...]  # custom types that can be encoded/decoded
+    _hub_mixin_coders: dict[Type, CODER_T]  # encoders/decoders for custom types
     # ^ internal values to handle config
 
     def __init_subclass__(
@@ -204,16 +202,16 @@ class ModelHubMixin:
         # Model card template
         model_card_template: str = DEFAULT_MODEL_CARD,
         # Model card metadata
-        language: Optional[
+        language: Optional[list[str]] = None,
         library_name: Optional[str] = None,
         license: Optional[str] = None,
         license_name: Optional[str] = None,
         license_link: Optional[str] = None,
         pipeline_tag: Optional[str] = None,
-        tags: Optional[
+        tags: Optional[list[str]] = None,
         # How to encode/decode arguments with custom type into a JSON config?
         coders: Optional[
-
+            dict[Type, CODER_T]
             # Key is a type.
             # Value is a tuple (encoder, decoder).
             # Example: {MyCustomType: (lambda x: x.value, lambda data: MyCustomType(data))}
@@ -266,12 +264,14 @@ class ModelHubMixin:
         if pipeline_tag is not None:
             info.model_card_data.pipeline_tag = pipeline_tag
         if tags is not None:
+            normalized_tags = list(tags)
            if info.model_card_data.tags is not None:
-                info.model_card_data.tags.extend(
+                info.model_card_data.tags.extend(normalized_tags)
            else:
-                info.model_card_data.tags =
+                info.model_card_data.tags = normalized_tags
 
-        info.model_card_data.tags
+        if info.model_card_data.tags is not None:
+            info.model_card_data.tags = sorted(set(info.model_card_data.tags))
 
         # Handle encoders/decoders for args
         cls._hub_mixin_coders = coders or {}
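The net effect of the new tag handling above (merge, then dedupe and sort) can be illustrated in isolation; the tag values here are made up:

# Standalone illustration of the merge + dedupe + sort behaviour (tag values are made up).
existing = ["pytorch_model_hub_mixin", "text-classification"]
new_tags = ["text-classification", "my-library"]
existing.extend(list(new_tags))
merged = sorted(set(existing))
print(merged)  # ['my-library', 'pytorch_model_hub_mixin', 'text-classification']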
@@ -286,7 +286,7 @@ class ModelHubMixin:
         }
         cls._hub_mixin_inject_config = "config" in inspect.signature(cls._from_pretrained).parameters
 
-    def __new__(cls:
+    def __new__(cls: type[T], *args, **kwargs) -> T:
         """Create a new instance of the class and handle config.
 
        3 cases:
@@ -362,7 +362,7 @@ class ModelHubMixin:
            return arg
 
     @classmethod
-    def _decode_arg(cls, expected_type:
+    def _decode_arg(cls, expected_type: type[ARGS_T], value: Any) -> Optional[ARGS_T]:
         """Decode a JSON serializable value into an argument."""
         if is_simple_optional_type(expected_type):
             if value is None:
@@ -385,7 +385,7 @@ class ModelHubMixin:
         config: Optional[Union[dict, DataclassInstance]] = None,
         repo_id: Optional[str] = None,
         push_to_hub: bool = False,
-        model_card_kwargs: Optional[
+        model_card_kwargs: Optional[dict[str, Any]] = None,
         **push_to_hub_kwargs,
     ) -> Optional[str]:
         """
@@ -401,7 +401,7 @@ class ModelHubMixin:
            repo_id (`str`, *optional*):
                ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if
                not provided.
-            model_card_kwargs (`
+            model_card_kwargs (`dict[str, Any]`, *optional*):
                Additional arguments passed to the model card template to customize the model card.
            push_to_hub_kwargs:
                Additional key word arguments passed along to the [`~ModelHubMixin.push_to_hub`] method.
@@ -460,12 +460,10 @@ class ModelHubMixin:
     @classmethod
     @validate_hf_hub_args
     def from_pretrained(
-        cls:
+        cls: type[T],
         pretrained_model_name_or_path: Union[str, Path],
         *,
         force_download: bool = False,
-        resume_download: Optional[bool] = None,
-        proxies: Optional[Dict] = None,
         token: Optional[Union[str, bool]] = None,
         cache_dir: Optional[Union[str, Path]] = None,
         local_files_only: bool = False,
@@ -486,9 +484,6 @@ class ModelHubMixin:
            force_download (`bool`, *optional*, defaults to `False`):
                Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding
                the existing cache.
-            proxies (`Dict[str, str]`, *optional*):
-                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
-                'http://hostname': 'foo.bar:4012'}`. The proxies are used on every request.
            token (`str` or `bool`, *optional*):
                The token to use as HTTP bearer authorization for remote files. By default, it will use the token
                cached when running `hf auth login`.
@@ -496,7 +491,7 @@ class ModelHubMixin:
                Path to the folder where cached files are stored.
            local_files_only (`bool`, *optional*, defaults to `False`):
                If `True`, avoid downloading the file and return the path to the local cached file if it exists.
-            model_kwargs (`
+            model_kwargs (`dict`, *optional*):
                Additional kwargs to pass to the model during initialization.
         """
         model_id = str(pretrained_model_name_or_path)
@@ -514,8 +509,6 @@ class ModelHubMixin:
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
-                proxies=proxies,
-                resume_download=resume_download,
                token=token,
                local_files_only=local_files_only,
            )
@@ -555,7 +548,7 @@ class ModelHubMixin:
                if key not in model_kwargs and key in config:
                    model_kwargs[key] = config[key]
        elif any(param.kind == inspect.Parameter.VAR_KEYWORD for param in cls._hub_mixin_init_parameters.values()):
-            for key, value in config.items():
+            for key, value in config.items():  # type: ignore[union-attr]
                if key not in model_kwargs:
                    model_kwargs[key] = value
 
@@ -568,8 +561,6 @@ class ModelHubMixin:
            revision=revision,
            cache_dir=cache_dir,
            force_download=force_download,
-            proxies=proxies,
-            resume_download=resume_download,
            local_files_only=local_files_only,
            token=token,
            **model_kwargs,
@@ -584,14 +575,12 @@ class ModelHubMixin:
 
     @classmethod
     def _from_pretrained(
-        cls:
+        cls: type[T],
         *,
         model_id: str,
         revision: Optional[str],
         cache_dir: Optional[Union[str, Path]],
         force_download: bool,
-        proxies: Optional[Dict],
-        resume_download: Optional[bool],
         local_files_only: bool,
         token: Optional[Union[str, bool]],
         **model_kwargs,
@@ -614,9 +603,6 @@ class ModelHubMixin:
            force_download (`bool`, *optional*, defaults to `False`):
                Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding
                the existing cache.
-            proxies (`Dict[str, str]`, *optional*):
-                A dictionary of proxy servers to use by protocol or endpoint (e.g., `{'http': 'foo.bar:3128',
-                'http://hostname': 'foo.bar:4012'}`).
            token (`str` or `bool`, *optional*):
                The token to use as HTTP bearer authorization for remote files. By default, it will use the token
                cached when running `hf auth login`.
@@ -640,10 +626,10 @@ class ModelHubMixin:
         token: Optional[str] = None,
         branch: Optional[str] = None,
         create_pr: Optional[bool] = None,
-        allow_patterns: Optional[Union[
-        ignore_patterns: Optional[Union[
-        delete_patterns: Optional[Union[
-        model_card_kwargs: Optional[
+        allow_patterns: Optional[Union[list[str], str]] = None,
+        ignore_patterns: Optional[Union[list[str], str]] = None,
+        delete_patterns: Optional[Union[list[str], str]] = None,
+        model_card_kwargs: Optional[dict[str, Any]] = None,
     ) -> str:
         """
         Upload model checkpoint to the Hub.
@@ -669,13 +655,13 @@ class ModelHubMixin:
                The git branch on which to push the model. This defaults to `"main"`.
            create_pr (`boolean`, *optional*):
                Whether or not to create a Pull Request from `branch` with that commit. Defaults to `False`.
-            allow_patterns (`
+            allow_patterns (`list[str]` or `str`, *optional*):
                If provided, only files matching at least one pattern are pushed.
-            ignore_patterns (`
+            ignore_patterns (`list[str]` or `str`, *optional*):
                If provided, files matching any of the patterns are not pushed.
-            delete_patterns (`
+            delete_patterns (`list[str]` or `str`, *optional*):
                If provided, remote files matching any of the patterns will be deleted from the repo.
-            model_card_kwargs (`
+            model_card_kwargs (`dict[str, Any]`, *optional*):
                Additional arguments passed to the model card template to customize the model card.
 
        Returns:
@@ -758,7 +744,7 @@ class PyTorchModelHubMixin(ModelHubMixin):
    ```
    """
 
-    def __init_subclass__(cls, *args, tags: Optional[
+    def __init_subclass__(cls, *args, tags: Optional[list[str]] = None, **kwargs) -> None:
        tags = tags or []
        tags.append("pytorch_model_hub_mixin")
        kwargs["tags"] = tags
@@ -777,8 +763,6 @@ class PyTorchModelHubMixin(ModelHubMixin):
        revision: Optional[str],
        cache_dir: Optional[Union[str, Path]],
        force_download: bool,
-        proxies: Optional[Dict],
-        resume_download: Optional[bool],
        local_files_only: bool,
        token: Union[str, bool, None],
        map_location: str = "cpu",
@@ -799,8 +783,6 @@ class PyTorchModelHubMixin(ModelHubMixin):
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
-                proxies=proxies,
-                resume_download=resume_download,
                token=token,
                local_files_only=local_files_only,
            )
@@ -812,8 +794,6 @@ class PyTorchModelHubMixin(ModelHubMixin):
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
-                proxies=proxies,
-                resume_download=resume_download,
                token=token,
                local_files_only=local_files_only,
            )
@@ -843,7 +823,7 @@ class PyTorchModelHubMixin(ModelHubMixin):
        return model
 
 
-def _load_dataclass(datacls:
+def _load_dataclass(datacls: type[DataclassInstance], data: dict) -> DataclassInstance:
    """Load a dataclass instance from a dictionary.
 
    Fields not expected by the dataclass are ignored.