huggingface-hub 0.36.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of huggingface-hub might be problematic. Click here for more details.
- huggingface_hub/__init__.py +33 -45
- huggingface_hub/_commit_api.py +39 -43
- huggingface_hub/_commit_scheduler.py +11 -8
- huggingface_hub/_inference_endpoints.py +8 -8
- huggingface_hub/_jobs_api.py +20 -20
- huggingface_hub/_login.py +17 -43
- huggingface_hub/_oauth.py +8 -8
- huggingface_hub/_snapshot_download.py +135 -50
- huggingface_hub/_space_api.py +4 -4
- huggingface_hub/_tensorboard_logger.py +5 -5
- huggingface_hub/_upload_large_folder.py +18 -32
- huggingface_hub/_webhooks_payload.py +3 -3
- huggingface_hub/_webhooks_server.py +2 -2
- huggingface_hub/cli/__init__.py +0 -14
- huggingface_hub/cli/_cli_utils.py +143 -39
- huggingface_hub/cli/auth.py +105 -171
- huggingface_hub/cli/cache.py +594 -361
- huggingface_hub/cli/download.py +120 -112
- huggingface_hub/cli/hf.py +38 -41
- huggingface_hub/cli/jobs.py +689 -1017
- huggingface_hub/cli/lfs.py +120 -143
- huggingface_hub/cli/repo.py +282 -216
- huggingface_hub/cli/repo_files.py +50 -84
- huggingface_hub/cli/system.py +6 -25
- huggingface_hub/cli/upload.py +198 -220
- huggingface_hub/cli/upload_large_folder.py +91 -106
- huggingface_hub/community.py +5 -5
- huggingface_hub/constants.py +17 -52
- huggingface_hub/dataclasses.py +135 -21
- huggingface_hub/errors.py +47 -30
- huggingface_hub/fastai_utils.py +8 -9
- huggingface_hub/file_download.py +351 -303
- huggingface_hub/hf_api.py +398 -570
- huggingface_hub/hf_file_system.py +101 -66
- huggingface_hub/hub_mixin.py +32 -54
- huggingface_hub/inference/_client.py +177 -162
- huggingface_hub/inference/_common.py +38 -54
- huggingface_hub/inference/_generated/_async_client.py +218 -258
- huggingface_hub/inference/_generated/types/automatic_speech_recognition.py +3 -3
- huggingface_hub/inference/_generated/types/base.py +10 -7
- huggingface_hub/inference/_generated/types/chat_completion.py +16 -16
- huggingface_hub/inference/_generated/types/depth_estimation.py +2 -2
- huggingface_hub/inference/_generated/types/document_question_answering.py +2 -2
- huggingface_hub/inference/_generated/types/feature_extraction.py +2 -2
- huggingface_hub/inference/_generated/types/fill_mask.py +2 -2
- huggingface_hub/inference/_generated/types/sentence_similarity.py +3 -3
- huggingface_hub/inference/_generated/types/summarization.py +2 -2
- huggingface_hub/inference/_generated/types/table_question_answering.py +4 -4
- huggingface_hub/inference/_generated/types/text2text_generation.py +2 -2
- huggingface_hub/inference/_generated/types/text_generation.py +10 -10
- huggingface_hub/inference/_generated/types/text_to_video.py +2 -2
- huggingface_hub/inference/_generated/types/token_classification.py +2 -2
- huggingface_hub/inference/_generated/types/translation.py +2 -2
- huggingface_hub/inference/_generated/types/zero_shot_classification.py +2 -2
- huggingface_hub/inference/_generated/types/zero_shot_image_classification.py +2 -2
- huggingface_hub/inference/_generated/types/zero_shot_object_detection.py +1 -3
- huggingface_hub/inference/_mcp/agent.py +3 -3
- huggingface_hub/inference/_mcp/constants.py +1 -2
- huggingface_hub/inference/_mcp/mcp_client.py +33 -22
- huggingface_hub/inference/_mcp/types.py +10 -10
- huggingface_hub/inference/_mcp/utils.py +4 -4
- huggingface_hub/inference/_providers/__init__.py +12 -4
- huggingface_hub/inference/_providers/_common.py +62 -24
- huggingface_hub/inference/_providers/black_forest_labs.py +6 -6
- huggingface_hub/inference/_providers/cohere.py +3 -3
- huggingface_hub/inference/_providers/fal_ai.py +25 -25
- huggingface_hub/inference/_providers/featherless_ai.py +4 -4
- huggingface_hub/inference/_providers/fireworks_ai.py +3 -3
- huggingface_hub/inference/_providers/hf_inference.py +13 -13
- huggingface_hub/inference/_providers/hyperbolic.py +4 -4
- huggingface_hub/inference/_providers/nebius.py +10 -10
- huggingface_hub/inference/_providers/novita.py +5 -5
- huggingface_hub/inference/_providers/nscale.py +4 -4
- huggingface_hub/inference/_providers/replicate.py +15 -15
- huggingface_hub/inference/_providers/sambanova.py +6 -6
- huggingface_hub/inference/_providers/together.py +7 -7
- huggingface_hub/lfs.py +21 -94
- huggingface_hub/repocard.py +15 -16
- huggingface_hub/repocard_data.py +57 -57
- huggingface_hub/serialization/__init__.py +0 -1
- huggingface_hub/serialization/_base.py +9 -9
- huggingface_hub/serialization/_dduf.py +7 -7
- huggingface_hub/serialization/_torch.py +28 -28
- huggingface_hub/utils/__init__.py +11 -6
- huggingface_hub/utils/_auth.py +5 -5
- huggingface_hub/utils/_cache_manager.py +49 -74
- huggingface_hub/utils/_deprecation.py +1 -1
- huggingface_hub/utils/_dotenv.py +3 -3
- huggingface_hub/utils/_fixes.py +0 -10
- huggingface_hub/utils/_git_credential.py +3 -3
- huggingface_hub/utils/_headers.py +7 -29
- huggingface_hub/utils/_http.py +371 -208
- huggingface_hub/utils/_pagination.py +4 -4
- huggingface_hub/utils/_parsing.py +98 -0
- huggingface_hub/utils/_paths.py +5 -5
- huggingface_hub/utils/_runtime.py +59 -23
- huggingface_hub/utils/_safetensors.py +21 -21
- huggingface_hub/utils/_subprocess.py +9 -9
- huggingface_hub/utils/_telemetry.py +3 -3
- huggingface_hub/{commands/_cli_utils.py → utils/_terminal.py} +4 -9
- huggingface_hub/utils/_typing.py +3 -3
- huggingface_hub/utils/_validators.py +53 -72
- huggingface_hub/utils/_xet.py +16 -16
- huggingface_hub/utils/_xet_progress_reporting.py +1 -1
- huggingface_hub/utils/insecure_hashlib.py +3 -9
- huggingface_hub/utils/tqdm.py +3 -3
- {huggingface_hub-0.36.0.dist-info → huggingface_hub-1.0.0.dist-info}/METADATA +16 -35
- huggingface_hub-1.0.0.dist-info/RECORD +152 -0
- {huggingface_hub-0.36.0.dist-info → huggingface_hub-1.0.0.dist-info}/entry_points.txt +0 -1
- huggingface_hub/commands/__init__.py +0 -27
- huggingface_hub/commands/delete_cache.py +0 -476
- huggingface_hub/commands/download.py +0 -204
- huggingface_hub/commands/env.py +0 -39
- huggingface_hub/commands/huggingface_cli.py +0 -65
- huggingface_hub/commands/lfs.py +0 -200
- huggingface_hub/commands/repo.py +0 -151
- huggingface_hub/commands/repo_files.py +0 -132
- huggingface_hub/commands/scan_cache.py +0 -183
- huggingface_hub/commands/tag.py +0 -161
- huggingface_hub/commands/upload.py +0 -318
- huggingface_hub/commands/upload_large_folder.py +0 -131
- huggingface_hub/commands/user.py +0 -208
- huggingface_hub/commands/version.py +0 -40
- huggingface_hub/inference_api.py +0 -217
- huggingface_hub/keras_mixin.py +0 -497
- huggingface_hub/repository.py +0 -1471
- huggingface_hub/serialization/_tensorflow.py +0 -92
- huggingface_hub/utils/_hf_folder.py +0 -68
- huggingface_hub-0.36.0.dist-info/RECORD +0 -170
- {huggingface_hub-0.36.0.dist-info → huggingface_hub-1.0.0.dist-info}/LICENSE +0 -0
- {huggingface_hub-0.36.0.dist-info → huggingface_hub-1.0.0.dist-info}/WHEEL +0 -0
- {huggingface_hub-0.36.0.dist-info → huggingface_hub-1.0.0.dist-info}/top_level.txt +0 -0
|
@@ -2,24 +2,25 @@ import os
|
|
|
2
2
|
import re
|
|
3
3
|
import tempfile
|
|
4
4
|
from collections import deque
|
|
5
|
+
from contextlib import ExitStack
|
|
5
6
|
from dataclasses import dataclass, field
|
|
6
7
|
from datetime import datetime
|
|
7
8
|
from itertools import chain
|
|
8
9
|
from pathlib import Path
|
|
9
|
-
from typing import Any,
|
|
10
|
+
from typing import Any, Iterator, NoReturn, Optional, Union
|
|
10
11
|
from urllib.parse import quote, unquote
|
|
11
12
|
|
|
12
13
|
import fsspec
|
|
14
|
+
import httpx
|
|
13
15
|
from fsspec.callbacks import _DEFAULT_CALLBACK, NoOpCallback, TqdmCallback
|
|
14
16
|
from fsspec.utils import isfilelike
|
|
15
|
-
from requests import Response
|
|
16
17
|
|
|
17
18
|
from . import constants
|
|
18
19
|
from ._commit_api import CommitOperationCopy, CommitOperationDelete
|
|
19
|
-
from .errors import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
|
|
20
|
+
from .errors import EntryNotFoundError, HfHubHTTPError, RepositoryNotFoundError, RevisionNotFoundError
|
|
20
21
|
from .file_download import hf_hub_url, http_get
|
|
21
22
|
from .hf_api import HfApi, LastCommitInfo, RepoFile
|
|
22
|
-
from .utils import HFValidationError, hf_raise_for_status, http_backoff
|
|
23
|
+
from .utils import HFValidationError, hf_raise_for_status, http_backoff, http_stream_backoff
|
|
23
24
|
|
|
24
25
|
|
|
25
26
|
# Regex used to match special revisions with "/" in them (see #1710)
|
|
@@ -112,15 +113,15 @@ class HfFileSystem(fsspec.AbstractFileSystem):
|
|
|
112
113
|
# Maps (repo_type, repo_id, revision) to a 2-tuple with:
|
|
113
114
|
# * the 1st element indicating whether the repository and the revision exist
|
|
114
115
|
# * the 2nd element being the exception raised if the repository or revision doesn't exist
|
|
115
|
-
self._repo_and_revision_exists_cache:
|
|
116
|
-
|
|
116
|
+
self._repo_and_revision_exists_cache: dict[
|
|
117
|
+
tuple[str, str, Optional[str]], tuple[bool, Optional[Exception]]
|
|
117
118
|
] = {}
|
|
118
119
|
# Maps parent directory path to path infos
|
|
119
|
-
self.dircache:
|
|
120
|
+
self.dircache: dict[str, list[dict[str, Any]]] = {}
|
|
120
121
|
|
|
121
122
|
def _repo_and_revision_exist(
|
|
122
123
|
self, repo_type: str, repo_id: str, revision: Optional[str]
|
|
123
|
-
) ->
|
|
124
|
+
) -> tuple[bool, Optional[Exception]]:
|
|
124
125
|
if (repo_type, repo_id, revision) not in self._repo_and_revision_exists_cache:
|
|
125
126
|
try:
|
|
126
127
|
self._api.repo_info(
|
|
@@ -339,7 +340,7 @@ class HfFileSystem(fsspec.AbstractFileSystem):
|
|
|
339
340
|
|
|
340
341
|
def ls(
|
|
341
342
|
self, path: str, detail: bool = True, refresh: bool = False, revision: Optional[str] = None, **kwargs
|
|
342
|
-
) ->
|
|
343
|
+
) -> list[Union[str, dict[str, Any]]]:
|
|
343
344
|
"""
|
|
344
345
|
List the contents of a directory.
|
|
345
346
|
|
|
@@ -360,7 +361,7 @@ class HfFileSystem(fsspec.AbstractFileSystem):
|
|
|
360
361
|
The git revision to list from.
|
|
361
362
|
|
|
362
363
|
Returns:
|
|
363
|
-
`
|
|
364
|
+
`list[Union[str, dict[str, Any]]]`: List of file paths (if detail=False) or list of file information
|
|
364
365
|
dictionaries (if detail=True).
|
|
365
366
|
"""
|
|
366
367
|
resolved_path = self.resolve_path(path, revision=revision)
|
|
@@ -495,7 +496,7 @@ class HfFileSystem(fsspec.AbstractFileSystem):
|
|
|
495
496
|
out.append(cache_path_info)
|
|
496
497
|
return out
|
|
497
498
|
|
|
498
|
-
def walk(self, path: str, *args, **kwargs) -> Iterator[
|
|
499
|
+
def walk(self, path: str, *args, **kwargs) -> Iterator[tuple[str, list[str], list[str]]]:
|
|
499
500
|
"""
|
|
500
501
|
Return all files below the given path.
|
|
501
502
|
|
|
@@ -506,12 +507,12 @@ class HfFileSystem(fsspec.AbstractFileSystem):
|
|
|
506
507
|
Root path to list files from.
|
|
507
508
|
|
|
508
509
|
Returns:
|
|
509
|
-
`Iterator[
|
|
510
|
+
`Iterator[tuple[str, list[str], list[str]]]`: An iterator of (path, list of directory names, list of file names) tuples.
|
|
510
511
|
"""
|
|
511
512
|
path = self.resolve_path(path, revision=kwargs.get("revision")).unresolve()
|
|
512
513
|
yield from super().walk(path, *args, **kwargs)
|
|
513
514
|
|
|
514
|
-
def glob(self, path: str, **kwargs) ->
|
|
515
|
+
def glob(self, path: str, **kwargs) -> list[str]:
|
|
515
516
|
"""
|
|
516
517
|
Find files by glob-matching.
|
|
517
518
|
|
|
@@ -522,7 +523,7 @@ class HfFileSystem(fsspec.AbstractFileSystem):
|
|
|
522
523
|
Path pattern to match.
|
|
523
524
|
|
|
524
525
|
Returns:
|
|
525
|
-
`
|
|
526
|
+
`list[str]`: List of paths matching the pattern.
|
|
526
527
|
"""
|
|
527
528
|
path = self.resolve_path(path, revision=kwargs.get("revision")).unresolve()
|
|
528
529
|
return super().glob(path, **kwargs)
|
|
@@ -536,7 +537,7 @@ class HfFileSystem(fsspec.AbstractFileSystem):
|
|
|
536
537
|
refresh: bool = False,
|
|
537
538
|
revision: Optional[str] = None,
|
|
538
539
|
**kwargs,
|
|
539
|
-
) -> Union[
|
|
540
|
+
) -> Union[list[str], dict[str, dict[str, Any]]]:
|
|
540
541
|
"""
|
|
541
542
|
List all files below path.
|
|
542
543
|
|
|
@@ -557,7 +558,7 @@ class HfFileSystem(fsspec.AbstractFileSystem):
|
|
|
557
558
|
The git revision to list from.
|
|
558
559
|
|
|
559
560
|
Returns:
|
|
560
|
-
`Union[
|
|
561
|
+
`Union[list[str], dict[str, dict[str, Any]]]`: List of paths or dict of file information.
|
|
561
562
|
"""
|
|
562
563
|
if maxdepth is not None and maxdepth < 1:
|
|
563
564
|
raise ValueError("maxdepth must be at least 1")
|
|
@@ -659,10 +660,10 @@ class HfFileSystem(fsspec.AbstractFileSystem):
|
|
|
659
660
|
Returns:
|
|
660
661
|
`datetime`: Last commit date of the file.
|
|
661
662
|
"""
|
|
662
|
-
info = self.info(path, **{**kwargs, "expand_info": True})
|
|
663
|
+
info = self.info(path, **{**kwargs, "expand_info": True}) # type: ignore
|
|
663
664
|
return info["last_commit"]["date"]
|
|
664
665
|
|
|
665
|
-
def info(self, path: str, refresh: bool = False, revision: Optional[str] = None, **kwargs) ->
|
|
666
|
+
def info(self, path: str, refresh: bool = False, revision: Optional[str] = None, **kwargs) -> dict[str, Any]:
|
|
666
667
|
"""
|
|
667
668
|
Get information about a file or directory.
|
|
668
669
|
|
|
@@ -680,7 +681,7 @@ class HfFileSystem(fsspec.AbstractFileSystem):
|
|
|
680
681
|
The git revision to get info from.
|
|
681
682
|
|
|
682
683
|
Returns:
|
|
683
|
-
`
|
|
684
|
+
`dict[str, Any]`: Dictionary containing file information (type, size, commit info, etc.).
|
|
684
685
|
|
|
685
686
|
"""
|
|
686
687
|
resolved_path = self.resolve_path(path, revision=revision)
|
|
@@ -1004,9 +1005,8 @@ class HfFileSystemFile(fsspec.spec.AbstractBufferedFile):
|
|
|
1004
1005
|
def read(self, length=-1):
|
|
1005
1006
|
"""Read remote file.
|
|
1006
1007
|
|
|
1007
|
-
If `length` is not provided or is -1, the entire file is downloaded and read. On POSIX systems
|
|
1008
|
-
|
|
1009
|
-
temporary file and read from there.
|
|
1008
|
+
If `length` is not provided or is -1, the entire file is downloaded and read. On POSIX systems the file is
|
|
1009
|
+
loaded in memory directly. Otherwise, the file is downloaded to a temporary file and read from there.
|
|
1010
1010
|
"""
|
|
1011
1011
|
if self.mode == "rb" and (length is None or length == -1) and self.loc == 0:
|
|
1012
1012
|
with self.fs.open(self.path, "rb", block_size=0) as f: # block_size=0 enables fast streaming
|
|
@@ -1048,8 +1048,9 @@ class HfFileSystemStreamFile(fsspec.spec.AbstractBufferedFile):
|
|
|
1048
1048
|
super().__init__(
|
|
1049
1049
|
fs, self.resolved_path.unresolve(), mode=mode, block_size=block_size, cache_type=cache_type, **kwargs
|
|
1050
1050
|
)
|
|
1051
|
-
self.response: Optional[Response] = None
|
|
1051
|
+
self.response: Optional[httpx.Response] = None
|
|
1052
1052
|
self.fs: HfFileSystem
|
|
1053
|
+
self._exit_stack = ExitStack()
|
|
1053
1054
|
|
|
1054
1055
|
def seek(self, loc: int, whence: int = 0):
|
|
1055
1056
|
if loc == 0 and whence == 1:
|
|
@@ -1059,53 +1060,32 @@ class HfFileSystemStreamFile(fsspec.spec.AbstractBufferedFile):
|
|
|
1059
1060
|
raise ValueError("Cannot seek streaming HF file")
|
|
1060
1061
|
|
|
1061
1062
|
def read(self, length: int = -1):
|
|
1062
|
-
|
|
1063
|
+
"""Read the remote file.
|
|
1064
|
+
|
|
1065
|
+
If the file is already open, we reuse the connection.
|
|
1066
|
+
Otherwise, open a new connection and read from it.
|
|
1067
|
+
|
|
1068
|
+
If reading the stream fails, we retry with a new connection.
|
|
1069
|
+
"""
|
|
1063
1070
|
if self.response is None:
|
|
1064
|
-
|
|
1065
|
-
repo_id=self.resolved_path.repo_id,
|
|
1066
|
-
revision=self.resolved_path.revision,
|
|
1067
|
-
filename=self.resolved_path.path_in_repo,
|
|
1068
|
-
repo_type=self.resolved_path.repo_type,
|
|
1069
|
-
endpoint=self.fs.endpoint,
|
|
1070
|
-
)
|
|
1071
|
-
self.response = http_backoff(
|
|
1072
|
-
"GET",
|
|
1073
|
-
url,
|
|
1074
|
-
headers=self.fs._api._build_hf_headers(),
|
|
1075
|
-
stream=True,
|
|
1076
|
-
timeout=constants.HF_HUB_DOWNLOAD_TIMEOUT,
|
|
1077
|
-
)
|
|
1078
|
-
hf_raise_for_status(self.response)
|
|
1079
|
-
try:
|
|
1080
|
-
self.response.raw.decode_content = True
|
|
1081
|
-
out = self.response.raw.read(*read_args)
|
|
1082
|
-
except Exception:
|
|
1083
|
-
self.response.close()
|
|
1071
|
+
self._open_connection()
|
|
1084
1072
|
|
|
1085
|
-
|
|
1086
|
-
|
|
1087
|
-
repo_id=self.resolved_path.repo_id,
|
|
1088
|
-
revision=self.resolved_path.revision,
|
|
1089
|
-
filename=self.resolved_path.path_in_repo,
|
|
1090
|
-
repo_type=self.resolved_path.repo_type,
|
|
1091
|
-
endpoint=self.fs.endpoint,
|
|
1092
|
-
)
|
|
1093
|
-
self.response = http_backoff(
|
|
1094
|
-
"GET",
|
|
1095
|
-
url,
|
|
1096
|
-
headers={"Range": "bytes=%d-" % self.loc, **self.fs._api._build_hf_headers()},
|
|
1097
|
-
stream=True,
|
|
1098
|
-
timeout=constants.HF_HUB_DOWNLOAD_TIMEOUT,
|
|
1099
|
-
)
|
|
1100
|
-
hf_raise_for_status(self.response)
|
|
1073
|
+
retried_once = False
|
|
1074
|
+
while True:
|
|
1101
1075
|
try:
|
|
1102
|
-
self.response
|
|
1103
|
-
|
|
1076
|
+
if self.response is None:
|
|
1077
|
+
return b"" # Already read the entire file
|
|
1078
|
+
out = _partial_read(self.response, length)
|
|
1079
|
+
self.loc += len(out)
|
|
1080
|
+
return out
|
|
1104
1081
|
except Exception:
|
|
1105
|
-
self.response
|
|
1106
|
-
|
|
1107
|
-
|
|
1108
|
-
|
|
1082
|
+
if self.response is not None:
|
|
1083
|
+
self.response.close()
|
|
1084
|
+
if retried_once: # Already retried once, give up
|
|
1085
|
+
raise
|
|
1086
|
+
# First failure, retry with range header
|
|
1087
|
+
self._open_connection()
|
|
1088
|
+
retried_once = True
|
|
1109
1089
|
|
|
1110
1090
|
def url(self) -> str:
|
|
1111
1091
|
return self.fs.url(self.path)
|
|
@@ -1114,11 +1094,43 @@ class HfFileSystemStreamFile(fsspec.spec.AbstractBufferedFile):
|
|
|
1114
1094
|
if not hasattr(self, "resolved_path"):
|
|
1115
1095
|
# Means that the constructor failed. Nothing to do.
|
|
1116
1096
|
return
|
|
1097
|
+
self._exit_stack.close()
|
|
1117
1098
|
return super().__del__()
|
|
1118
1099
|
|
|
1119
1100
|
def __reduce__(self):
|
|
1120
1101
|
return reopen, (self.fs, self.path, self.mode, self.blocksize, self.cache.name)
|
|
1121
1102
|
|
|
1103
|
+
def _open_connection(self):
|
|
1104
|
+
"""Open a connection to the remote file."""
|
|
1105
|
+
url = hf_hub_url(
|
|
1106
|
+
repo_id=self.resolved_path.repo_id,
|
|
1107
|
+
revision=self.resolved_path.revision,
|
|
1108
|
+
filename=self.resolved_path.path_in_repo,
|
|
1109
|
+
repo_type=self.resolved_path.repo_type,
|
|
1110
|
+
endpoint=self.fs.endpoint,
|
|
1111
|
+
)
|
|
1112
|
+
headers = self.fs._api._build_hf_headers()
|
|
1113
|
+
if self.loc > 0:
|
|
1114
|
+
headers["Range"] = f"bytes={self.loc}-"
|
|
1115
|
+
self.response = self._exit_stack.enter_context(
|
|
1116
|
+
http_stream_backoff(
|
|
1117
|
+
"GET",
|
|
1118
|
+
url,
|
|
1119
|
+
headers=headers,
|
|
1120
|
+
retry_on_status_codes=(500, 502, 503, 504),
|
|
1121
|
+
timeout=constants.HF_HUB_DOWNLOAD_TIMEOUT,
|
|
1122
|
+
)
|
|
1123
|
+
)
|
|
1124
|
+
|
|
1125
|
+
try:
|
|
1126
|
+
hf_raise_for_status(self.response)
|
|
1127
|
+
except HfHubHTTPError as e:
|
|
1128
|
+
if e.response.status_code == 416:
|
|
1129
|
+
# Range not satisfiable => means that we have already read the entire file
|
|
1130
|
+
self.response = None
|
|
1131
|
+
return
|
|
1132
|
+
raise
|
|
1133
|
+
|
|
1122
1134
|
|
|
1123
1135
|
def safe_revision(revision: str) -> str:
|
|
1124
1136
|
return revision if SPECIAL_REFS_REVISION_REGEX.match(revision) else safe_quote(revision)
|
|
@@ -1143,6 +1155,29 @@ def reopen(fs: HfFileSystem, path: str, mode: str, block_size: int, cache_type:
|
|
|
1143
1155
|
return fs.open(path, mode=mode, block_size=block_size, cache_type=cache_type)
|
|
1144
1156
|
|
|
1145
1157
|
|
|
1158
|
+
def _partial_read(response: httpx.Response, length: int = -1) -> bytes:
|
|
1159
|
+
"""
|
|
1160
|
+
Read up to `length` bytes from a streamed response.
|
|
1161
|
+
If length == -1, read until EOF.
|
|
1162
|
+
"""
|
|
1163
|
+
buf = bytearray()
|
|
1164
|
+
if length < -1:
|
|
1165
|
+
raise ValueError("length must be -1 or >= 0")
|
|
1166
|
+
if length == 0:
|
|
1167
|
+
return b""
|
|
1168
|
+
if length == -1:
|
|
1169
|
+
for chunk in response.iter_bytes():
|
|
1170
|
+
buf.extend(chunk)
|
|
1171
|
+
return bytes(buf)
|
|
1172
|
+
|
|
1173
|
+
for chunk in response.iter_bytes(chunk_size=length):
|
|
1174
|
+
buf.extend(chunk)
|
|
1175
|
+
if len(buf) >= length:
|
|
1176
|
+
return bytes(buf[:length])
|
|
1177
|
+
|
|
1178
|
+
return bytes(buf) # may be < length if response ended
|
|
1179
|
+
|
|
1180
|
+
|
|
1146
1181
|
def make_instance(cls, args, kwargs, instance_cache_attributes_dict):
|
|
1147
1182
|
fs = cls(*args, **kwargs)
|
|
1148
1183
|
for attr, cached_value in instance_cache_attributes_dict.items():
|
huggingface_hub/hub_mixin.py
CHANGED
|
@@ -3,7 +3,7 @@ import json
|
|
|
3
3
|
import os
|
|
4
4
|
from dataclasses import Field, asdict, dataclass, is_dataclass
|
|
5
5
|
from pathlib import Path
|
|
6
|
-
from typing import Any, Callable, ClassVar,
|
|
6
|
+
from typing import Any, Callable, ClassVar, Optional, Protocol, Type, TypeVar, Union
|
|
7
7
|
|
|
8
8
|
import packaging.version
|
|
9
9
|
|
|
@@ -38,7 +38,7 @@ logger = logging.get_logger(__name__)
|
|
|
38
38
|
|
|
39
39
|
# Type alias for dataclass instances, copied from https://github.com/python/typeshed/blob/9f28171658b9ca6c32a7cb93fbb99fc92b17858b/stdlib/_typeshed/__init__.pyi#L349
|
|
40
40
|
class DataclassInstance(Protocol):
|
|
41
|
-
__dataclass_fields__: ClassVar[
|
|
41
|
+
__dataclass_fields__: ClassVar[dict[str, Field]]
|
|
42
42
|
|
|
43
43
|
|
|
44
44
|
# Generic variable that is either ModelHubMixin or a subclass thereof
|
|
@@ -47,7 +47,7 @@ T = TypeVar("T", bound="ModelHubMixin")
|
|
|
47
47
|
ARGS_T = TypeVar("ARGS_T")
|
|
48
48
|
ENCODER_T = Callable[[ARGS_T], Any]
|
|
49
49
|
DECODER_T = Callable[[Any], ARGS_T]
|
|
50
|
-
CODER_T =
|
|
50
|
+
CODER_T = tuple[ENCODER_T, DECODER_T]
|
|
51
51
|
|
|
52
52
|
|
|
53
53
|
DEFAULT_MODEL_CARD = """
|
|
@@ -96,7 +96,7 @@ class ModelHubMixin:
|
|
|
96
96
|
URL of the library documentation. Used to generate model card.
|
|
97
97
|
model_card_template (`str`, *optional*):
|
|
98
98
|
Template of the model card. Used to generate model card. Defaults to a generic template.
|
|
99
|
-
language (`str` or `
|
|
99
|
+
language (`str` or `list[str]`, *optional*):
|
|
100
100
|
Language supported by the library. Used to generate model card.
|
|
101
101
|
library_name (`str`, *optional*):
|
|
102
102
|
Name of the library integrating ModelHubMixin. Used to generate model card.
|
|
@@ -113,9 +113,9 @@ class ModelHubMixin:
|
|
|
113
113
|
E.g: "https://coqui.ai/cpml".
|
|
114
114
|
pipeline_tag (`str`, *optional*):
|
|
115
115
|
Tag of the pipeline. Used to generate model card. E.g. "text-classification".
|
|
116
|
-
tags (`
|
|
116
|
+
tags (`list[str]`, *optional*):
|
|
117
117
|
Tags to be added to the model card. Used to generate model card. E.g. ["computer-vision"]
|
|
118
|
-
coders (`
|
|
118
|
+
coders (`dict[Type, tuple[Callable, Callable]]`, *optional*):
|
|
119
119
|
Dictionary of custom types and their encoders/decoders. Used to encode/decode arguments that are not
|
|
120
120
|
jsonable by default. E.g dataclasses, argparse.Namespace, OmegaConf, etc.
|
|
121
121
|
|
|
@@ -145,12 +145,10 @@ class ModelHubMixin:
|
|
|
145
145
|
...
|
|
146
146
|
... @classmethod
|
|
147
147
|
... def from_pretrained(
|
|
148
|
-
... cls:
|
|
148
|
+
... cls: type[T],
|
|
149
149
|
... pretrained_model_name_or_path: Union[str, Path],
|
|
150
150
|
... *,
|
|
151
151
|
... force_download: bool = False,
|
|
152
|
-
... resume_download: Optional[bool] = None,
|
|
153
|
-
... proxies: Optional[Dict] = None,
|
|
154
152
|
... token: Optional[Union[str, bool]] = None,
|
|
155
153
|
... cache_dir: Optional[Union[str, Path]] = None,
|
|
156
154
|
... local_files_only: bool = False,
|
|
@@ -188,10 +186,10 @@ class ModelHubMixin:
|
|
|
188
186
|
_hub_mixin_info: MixinInfo
|
|
189
187
|
# ^ information about the library integrating ModelHubMixin (used to generate model card)
|
|
190
188
|
_hub_mixin_inject_config: bool # whether `_from_pretrained` expects `config` or not
|
|
191
|
-
_hub_mixin_init_parameters:
|
|
192
|
-
_hub_mixin_jsonable_default_values:
|
|
193
|
-
_hub_mixin_jsonable_custom_types:
|
|
194
|
-
_hub_mixin_coders:
|
|
189
|
+
_hub_mixin_init_parameters: dict[str, inspect.Parameter] # __init__ parameters
|
|
190
|
+
_hub_mixin_jsonable_default_values: dict[str, Any] # default values for __init__ parameters
|
|
191
|
+
_hub_mixin_jsonable_custom_types: tuple[Type, ...] # custom types that can be encoded/decoded
|
|
192
|
+
_hub_mixin_coders: dict[Type, CODER_T] # encoders/decoders for custom types
|
|
195
193
|
# ^ internal values to handle config
|
|
196
194
|
|
|
197
195
|
def __init_subclass__(
|
|
@@ -204,16 +202,16 @@ class ModelHubMixin:
|
|
|
204
202
|
# Model card template
|
|
205
203
|
model_card_template: str = DEFAULT_MODEL_CARD,
|
|
206
204
|
# Model card metadata
|
|
207
|
-
language: Optional[
|
|
205
|
+
language: Optional[list[str]] = None,
|
|
208
206
|
library_name: Optional[str] = None,
|
|
209
207
|
license: Optional[str] = None,
|
|
210
208
|
license_name: Optional[str] = None,
|
|
211
209
|
license_link: Optional[str] = None,
|
|
212
210
|
pipeline_tag: Optional[str] = None,
|
|
213
|
-
tags: Optional[
|
|
211
|
+
tags: Optional[list[str]] = None,
|
|
214
212
|
# How to encode/decode arguments with custom type into a JSON config?
|
|
215
213
|
coders: Optional[
|
|
216
|
-
|
|
214
|
+
dict[Type, CODER_T]
|
|
217
215
|
# Key is a type.
|
|
218
216
|
# Value is a tuple (encoder, decoder).
|
|
219
217
|
# Example: {MyCustomType: (lambda x: x.value, lambda data: MyCustomType(data))}
|
|
@@ -288,7 +286,7 @@ class ModelHubMixin:
|
|
|
288
286
|
}
|
|
289
287
|
cls._hub_mixin_inject_config = "config" in inspect.signature(cls._from_pretrained).parameters
|
|
290
288
|
|
|
291
|
-
def __new__(cls:
|
|
289
|
+
def __new__(cls: type[T], *args, **kwargs) -> T:
|
|
292
290
|
"""Create a new instance of the class and handle config.
|
|
293
291
|
|
|
294
292
|
3 cases:
|
|
@@ -364,7 +362,7 @@ class ModelHubMixin:
|
|
|
364
362
|
return arg
|
|
365
363
|
|
|
366
364
|
@classmethod
|
|
367
|
-
def _decode_arg(cls, expected_type:
|
|
365
|
+
def _decode_arg(cls, expected_type: type[ARGS_T], value: Any) -> Optional[ARGS_T]:
|
|
368
366
|
"""Decode a JSON serializable value into an argument."""
|
|
369
367
|
if is_simple_optional_type(expected_type):
|
|
370
368
|
if value is None:
|
|
@@ -387,7 +385,7 @@ class ModelHubMixin:
|
|
|
387
385
|
config: Optional[Union[dict, DataclassInstance]] = None,
|
|
388
386
|
repo_id: Optional[str] = None,
|
|
389
387
|
push_to_hub: bool = False,
|
|
390
|
-
model_card_kwargs: Optional[
|
|
388
|
+
model_card_kwargs: Optional[dict[str, Any]] = None,
|
|
391
389
|
**push_to_hub_kwargs,
|
|
392
390
|
) -> Optional[str]:
|
|
393
391
|
"""
|
|
@@ -403,7 +401,7 @@ class ModelHubMixin:
|
|
|
403
401
|
repo_id (`str`, *optional*):
|
|
404
402
|
ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if
|
|
405
403
|
not provided.
|
|
406
|
-
model_card_kwargs (`
|
|
404
|
+
model_card_kwargs (`dict[str, Any]`, *optional*):
|
|
407
405
|
Additional arguments passed to the model card template to customize the model card.
|
|
408
406
|
push_to_hub_kwargs:
|
|
409
407
|
Additional key word arguments passed along to the [`~ModelHubMixin.push_to_hub`] method.
|
|
@@ -462,12 +460,10 @@ class ModelHubMixin:
|
|
|
462
460
|
@classmethod
|
|
463
461
|
@validate_hf_hub_args
|
|
464
462
|
def from_pretrained(
|
|
465
|
-
cls:
|
|
463
|
+
cls: type[T],
|
|
466
464
|
pretrained_model_name_or_path: Union[str, Path],
|
|
467
465
|
*,
|
|
468
466
|
force_download: bool = False,
|
|
469
|
-
resume_download: Optional[bool] = None,
|
|
470
|
-
proxies: Optional[Dict] = None,
|
|
471
467
|
token: Optional[Union[str, bool]] = None,
|
|
472
468
|
cache_dir: Optional[Union[str, Path]] = None,
|
|
473
469
|
local_files_only: bool = False,
|
|
@@ -488,9 +484,6 @@ class ModelHubMixin:
|
|
|
488
484
|
force_download (`bool`, *optional*, defaults to `False`):
|
|
489
485
|
Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding
|
|
490
486
|
the existing cache.
|
|
491
|
-
proxies (`Dict[str, str]`, *optional*):
|
|
492
|
-
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
|
493
|
-
'http://hostname': 'foo.bar:4012'}`. The proxies are used on every request.
|
|
494
487
|
token (`str` or `bool`, *optional*):
|
|
495
488
|
The token to use as HTTP bearer authorization for remote files. By default, it will use the token
|
|
496
489
|
cached when running `hf auth login`.
|
|
@@ -498,7 +491,7 @@ class ModelHubMixin:
|
|
|
498
491
|
Path to the folder where cached files are stored.
|
|
499
492
|
local_files_only (`bool`, *optional*, defaults to `False`):
|
|
500
493
|
If `True`, avoid downloading the file and return the path to the local cached file if it exists.
|
|
501
|
-
model_kwargs (`
|
|
494
|
+
model_kwargs (`dict`, *optional*):
|
|
502
495
|
Additional kwargs to pass to the model during initialization.
|
|
503
496
|
"""
|
|
504
497
|
model_id = str(pretrained_model_name_or_path)
|
|
@@ -516,8 +509,6 @@ class ModelHubMixin:
|
|
|
516
509
|
revision=revision,
|
|
517
510
|
cache_dir=cache_dir,
|
|
518
511
|
force_download=force_download,
|
|
519
|
-
proxies=proxies,
|
|
520
|
-
resume_download=resume_download,
|
|
521
512
|
token=token,
|
|
522
513
|
local_files_only=local_files_only,
|
|
523
514
|
)
|
|
@@ -557,7 +548,7 @@ class ModelHubMixin:
|
|
|
557
548
|
if key not in model_kwargs and key in config:
|
|
558
549
|
model_kwargs[key] = config[key]
|
|
559
550
|
elif any(param.kind == inspect.Parameter.VAR_KEYWORD for param in cls._hub_mixin_init_parameters.values()):
|
|
560
|
-
for key, value in config.items():
|
|
551
|
+
for key, value in config.items(): # type: ignore[union-attr]
|
|
561
552
|
if key not in model_kwargs:
|
|
562
553
|
model_kwargs[key] = value
|
|
563
554
|
|
|
@@ -570,8 +561,6 @@ class ModelHubMixin:
|
|
|
570
561
|
revision=revision,
|
|
571
562
|
cache_dir=cache_dir,
|
|
572
563
|
force_download=force_download,
|
|
573
|
-
proxies=proxies,
|
|
574
|
-
resume_download=resume_download,
|
|
575
564
|
local_files_only=local_files_only,
|
|
576
565
|
token=token,
|
|
577
566
|
**model_kwargs,
|
|
@@ -586,14 +575,12 @@ class ModelHubMixin:
|
|
|
586
575
|
|
|
587
576
|
@classmethod
|
|
588
577
|
def _from_pretrained(
|
|
589
|
-
cls:
|
|
578
|
+
cls: type[T],
|
|
590
579
|
*,
|
|
591
580
|
model_id: str,
|
|
592
581
|
revision: Optional[str],
|
|
593
582
|
cache_dir: Optional[Union[str, Path]],
|
|
594
583
|
force_download: bool,
|
|
595
|
-
proxies: Optional[Dict],
|
|
596
|
-
resume_download: Optional[bool],
|
|
597
584
|
local_files_only: bool,
|
|
598
585
|
token: Optional[Union[str, bool]],
|
|
599
586
|
**model_kwargs,
|
|
@@ -616,9 +603,6 @@ class ModelHubMixin:
|
|
|
616
603
|
force_download (`bool`, *optional*, defaults to `False`):
|
|
617
604
|
Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding
|
|
618
605
|
the existing cache.
|
|
619
|
-
proxies (`Dict[str, str]`, *optional*):
|
|
620
|
-
A dictionary of proxy servers to use by protocol or endpoint (e.g., `{'http': 'foo.bar:3128',
|
|
621
|
-
'http://hostname': 'foo.bar:4012'}`).
|
|
622
606
|
token (`str` or `bool`, *optional*):
|
|
623
607
|
The token to use as HTTP bearer authorization for remote files. By default, it will use the token
|
|
624
608
|
cached when running `hf auth login`.
|
|
@@ -642,10 +626,10 @@ class ModelHubMixin:
|
|
|
642
626
|
token: Optional[str] = None,
|
|
643
627
|
branch: Optional[str] = None,
|
|
644
628
|
create_pr: Optional[bool] = None,
|
|
645
|
-
allow_patterns: Optional[Union[
|
|
646
|
-
ignore_patterns: Optional[Union[
|
|
647
|
-
delete_patterns: Optional[Union[
|
|
648
|
-
model_card_kwargs: Optional[
|
|
629
|
+
allow_patterns: Optional[Union[list[str], str]] = None,
|
|
630
|
+
ignore_patterns: Optional[Union[list[str], str]] = None,
|
|
631
|
+
delete_patterns: Optional[Union[list[str], str]] = None,
|
|
632
|
+
model_card_kwargs: Optional[dict[str, Any]] = None,
|
|
649
633
|
) -> str:
|
|
650
634
|
"""
|
|
651
635
|
Upload model checkpoint to the Hub.
|
|
@@ -671,13 +655,13 @@ class ModelHubMixin:
|
|
|
671
655
|
The git branch on which to push the model. This defaults to `"main"`.
|
|
672
656
|
create_pr (`boolean`, *optional*):
|
|
673
657
|
Whether or not to create a Pull Request from `branch` with that commit. Defaults to `False`.
|
|
674
|
-
allow_patterns (`
|
|
658
|
+
allow_patterns (`list[str]` or `str`, *optional*):
|
|
675
659
|
If provided, only files matching at least one pattern are pushed.
|
|
676
|
-
ignore_patterns (`
|
|
660
|
+
ignore_patterns (`list[str]` or `str`, *optional*):
|
|
677
661
|
If provided, files matching any of the patterns are not pushed.
|
|
678
|
-
delete_patterns (`
|
|
662
|
+
delete_patterns (`list[str]` or `str`, *optional*):
|
|
679
663
|
If provided, remote files matching any of the patterns will be deleted from the repo.
|
|
680
|
-
model_card_kwargs (`
|
|
664
|
+
model_card_kwargs (`dict[str, Any]`, *optional*):
|
|
681
665
|
Additional arguments passed to the model card template to customize the model card.
|
|
682
666
|
|
|
683
667
|
Returns:
|
|
@@ -760,7 +744,7 @@ class PyTorchModelHubMixin(ModelHubMixin):
|
|
|
760
744
|
```
|
|
761
745
|
"""
|
|
762
746
|
|
|
763
|
-
def __init_subclass__(cls, *args, tags: Optional[
|
|
747
|
+
def __init_subclass__(cls, *args, tags: Optional[list[str]] = None, **kwargs) -> None:
|
|
764
748
|
tags = tags or []
|
|
765
749
|
tags.append("pytorch_model_hub_mixin")
|
|
766
750
|
kwargs["tags"] = tags
|
|
@@ -779,8 +763,6 @@ class PyTorchModelHubMixin(ModelHubMixin):
|
|
|
779
763
|
revision: Optional[str],
|
|
780
764
|
cache_dir: Optional[Union[str, Path]],
|
|
781
765
|
force_download: bool,
|
|
782
|
-
proxies: Optional[Dict],
|
|
783
|
-
resume_download: Optional[bool],
|
|
784
766
|
local_files_only: bool,
|
|
785
767
|
token: Union[str, bool, None],
|
|
786
768
|
map_location: str = "cpu",
|
|
@@ -801,8 +783,6 @@ class PyTorchModelHubMixin(ModelHubMixin):
|
|
|
801
783
|
revision=revision,
|
|
802
784
|
cache_dir=cache_dir,
|
|
803
785
|
force_download=force_download,
|
|
804
|
-
proxies=proxies,
|
|
805
|
-
resume_download=resume_download,
|
|
806
786
|
token=token,
|
|
807
787
|
local_files_only=local_files_only,
|
|
808
788
|
)
|
|
@@ -814,8 +794,6 @@ class PyTorchModelHubMixin(ModelHubMixin):
|
|
|
814
794
|
revision=revision,
|
|
815
795
|
cache_dir=cache_dir,
|
|
816
796
|
force_download=force_download,
|
|
817
|
-
proxies=proxies,
|
|
818
|
-
resume_download=resume_download,
|
|
819
797
|
token=token,
|
|
820
798
|
local_files_only=local_files_only,
|
|
821
799
|
)
|
|
@@ -845,7 +823,7 @@ class PyTorchModelHubMixin(ModelHubMixin):
|
|
|
845
823
|
return model
|
|
846
824
|
|
|
847
825
|
|
|
848
|
-
def _load_dataclass(datacls:
|
|
826
|
+
def _load_dataclass(datacls: type[DataclassInstance], data: dict) -> DataclassInstance:
|
|
849
827
|
"""Load a dataclass instance from a dictionary.
|
|
850
828
|
|
|
851
829
|
Fields not expected by the dataclass are ignored.
|