huggingface-hub 0.20.2__py3-none-any.whl → 0.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of huggingface-hub might be problematic. Click here for more details.
- huggingface_hub/__init__.py +19 -1
- huggingface_hub/_commit_api.py +49 -20
- huggingface_hub/_inference_endpoints.py +10 -0
- huggingface_hub/_login.py +2 -2
- huggingface_hub/commands/download.py +1 -1
- huggingface_hub/file_download.py +57 -21
- huggingface_hub/hf_api.py +269 -54
- huggingface_hub/hf_file_system.py +131 -8
- huggingface_hub/hub_mixin.py +204 -42
- huggingface_hub/inference/_client.py +56 -9
- huggingface_hub/inference/_common.py +4 -3
- huggingface_hub/inference/_generated/_async_client.py +57 -9
- huggingface_hub/inference/_text_generation.py +5 -0
- huggingface_hub/inference/_types.py +17 -0
- huggingface_hub/lfs.py +6 -3
- huggingface_hub/repocard.py +5 -3
- huggingface_hub/repocard_data.py +11 -3
- huggingface_hub/serialization/__init__.py +19 -0
- huggingface_hub/serialization/_base.py +168 -0
- huggingface_hub/serialization/_numpy.py +67 -0
- huggingface_hub/serialization/_tensorflow.py +93 -0
- huggingface_hub/serialization/_torch.py +199 -0
- huggingface_hub/templates/datasetcard_template.md +1 -1
- huggingface_hub/templates/modelcard_template.md +1 -4
- huggingface_hub/utils/__init__.py +14 -10
- huggingface_hub/utils/_datetime.py +4 -11
- huggingface_hub/utils/_errors.py +29 -0
- huggingface_hub/utils/_hf_folder.py +4 -23
- huggingface_hub/utils/_runtime.py +21 -15
- huggingface_hub/utils/endpoint_helpers.py +27 -1
- {huggingface_hub-0.20.2.dist-info → huggingface_hub-0.21.0.dist-info}/METADATA +7 -3
- {huggingface_hub-0.20.2.dist-info → huggingface_hub-0.21.0.dist-info}/RECORD +36 -31
- {huggingface_hub-0.20.2.dist-info → huggingface_hub-0.21.0.dist-info}/LICENSE +0 -0
- {huggingface_hub-0.20.2.dist-info → huggingface_hub-0.21.0.dist-info}/WHEEL +0 -0
- {huggingface_hub-0.20.2.dist-info → huggingface_hub-0.21.0.dist-info}/entry_points.txt +0 -0
- {huggingface_hub-0.20.2.dist-info → huggingface_hub-0.21.0.dist-info}/top_level.txt +0 -0
|
@@ -10,6 +10,7 @@ from typing import Any, Dict, List, NoReturn, Optional, Tuple, Union
|
|
|
10
10
|
from urllib.parse import quote, unquote
|
|
11
11
|
|
|
12
12
|
import fsspec
|
|
13
|
+
from requests import Response
|
|
13
14
|
|
|
14
15
|
from ._commit_api import CommitOperationCopy, CommitOperationDelete
|
|
15
16
|
from .constants import DEFAULT_REVISION, ENDPOINT, REPO_TYPE_MODEL, REPO_TYPES_MAPPING, REPO_TYPES_URL_PREFIXES
|
|
@@ -216,11 +217,15 @@ class HfFileSystem(fsspec.AbstractFileSystem):
|
|
|
216
217
|
path: str,
|
|
217
218
|
mode: str = "rb",
|
|
218
219
|
revision: Optional[str] = None,
|
|
220
|
+
block_size: Optional[int] = None,
|
|
219
221
|
**kwargs,
|
|
220
222
|
) -> "HfFileSystemFile":
|
|
221
223
|
if "a" in mode:
|
|
222
224
|
raise NotImplementedError("Appending to remote files is not yet supported.")
|
|
223
|
-
|
|
225
|
+
if block_size == 0:
|
|
226
|
+
return HfFileSystemStreamFile(self, path, mode=mode, revision=revision, block_size=block_size, **kwargs)
|
|
227
|
+
else:
|
|
228
|
+
return HfFileSystemFile(self, path, mode=mode, revision=revision, block_size=block_size, **kwargs)
|
|
224
229
|
|
|
225
230
|
def _rm(self, path: str, revision: Optional[str] = None, **kwargs) -> None:
|
|
226
231
|
resolved_path = self.resolve_path(path, revision=revision)
|
|
@@ -244,9 +249,8 @@ class HfFileSystem(fsspec.AbstractFileSystem):
|
|
|
244
249
|
**kwargs,
|
|
245
250
|
) -> None:
|
|
246
251
|
resolved_path = self.resolve_path(path, revision=revision)
|
|
247
|
-
root_path = REPO_TYPES_URL_PREFIXES.get(resolved_path.repo_type, "") + resolved_path.repo_id
|
|
248
252
|
paths = self.expand_path(path, recursive=recursive, maxdepth=maxdepth, revision=revision)
|
|
249
|
-
paths_in_repo = [path
|
|
253
|
+
paths_in_repo = [self.resolve_path(path).path_in_repo for path in paths if not self.isdir(path)]
|
|
250
254
|
operations = [CommitOperationDelete(path_in_repo=path_in_repo) for path_in_repo in paths_in_repo]
|
|
251
255
|
commit_message = f"Delete {path} "
|
|
252
256
|
commit_message += "recursively " if recursive else ""
|
|
@@ -439,7 +443,7 @@ class HfFileSystem(fsspec.AbstractFileSystem):
|
|
|
439
443
|
resolved_path1.repo_type == resolved_path2.repo_type and resolved_path1.repo_id == resolved_path2.repo_id
|
|
440
444
|
)
|
|
441
445
|
|
|
442
|
-
if same_repo
|
|
446
|
+
if same_repo:
|
|
443
447
|
commit_message = f"Copy {path1} to {path2}"
|
|
444
448
|
self._api.create_commit(
|
|
445
449
|
repo_id=resolved_path1.repo_id,
|
|
@@ -573,6 +577,20 @@ class HfFileSystem(fsspec.AbstractFileSystem):
|
|
|
573
577
|
except: # noqa: E722
|
|
574
578
|
return False
|
|
575
579
|
|
|
580
|
+
def url(self, path: str) -> str:
|
|
581
|
+
"""Get the HTTP URL of the given path"""
|
|
582
|
+
resolved_path = self.resolve_path(path)
|
|
583
|
+
url = hf_hub_url(
|
|
584
|
+
resolved_path.repo_id,
|
|
585
|
+
resolved_path.path_in_repo,
|
|
586
|
+
repo_type=resolved_path.repo_type,
|
|
587
|
+
revision=resolved_path.revision,
|
|
588
|
+
endpoint=self.endpoint,
|
|
589
|
+
)
|
|
590
|
+
if self.isdir(path):
|
|
591
|
+
url = url.replace("/resolve/", "/tree/", 1)
|
|
592
|
+
return url
|
|
593
|
+
|
|
576
594
|
@property
|
|
577
595
|
def transaction(self):
|
|
578
596
|
"""A context within which files are committed together upon exit
|
|
@@ -593,9 +611,6 @@ class HfFileSystem(fsspec.AbstractFileSystem):
|
|
|
593
611
|
|
|
594
612
|
class HfFileSystemFile(fsspec.spec.AbstractBufferedFile):
|
|
595
613
|
def __init__(self, fs: HfFileSystem, path: str, revision: Optional[str] = None, **kwargs):
|
|
596
|
-
super().__init__(fs, path, **kwargs)
|
|
597
|
-
self.fs: HfFileSystem
|
|
598
|
-
|
|
599
614
|
try:
|
|
600
615
|
self.resolved_path = fs.resolve_path(path, revision=revision)
|
|
601
616
|
except FileNotFoundError as e:
|
|
@@ -603,6 +618,8 @@ class HfFileSystemFile(fsspec.spec.AbstractBufferedFile):
|
|
|
603
618
|
raise FileNotFoundError(
|
|
604
619
|
f"{e}.\nMake sure the repository and revision exist before writing data."
|
|
605
620
|
) from e
|
|
621
|
+
super().__init__(fs, self.resolved_path.unresolve(), **kwargs)
|
|
622
|
+
self.fs: HfFileSystem
|
|
606
623
|
|
|
607
624
|
def __del__(self):
|
|
608
625
|
if not hasattr(self, "resolved_path"):
|
|
@@ -622,7 +639,7 @@ class HfFileSystemFile(fsspec.spec.AbstractBufferedFile):
|
|
|
622
639
|
repo_type=self.resolved_path.repo_type,
|
|
623
640
|
endpoint=self.fs.endpoint,
|
|
624
641
|
)
|
|
625
|
-
r = http_backoff("GET", url, headers=headers)
|
|
642
|
+
r = http_backoff("GET", url, headers=headers, retry_on_status_codes=(502, 503, 504))
|
|
626
643
|
hf_raise_for_status(r)
|
|
627
644
|
return r.content
|
|
628
645
|
|
|
@@ -650,6 +667,108 @@ class HfFileSystemFile(fsspec.spec.AbstractBufferedFile):
|
|
|
650
667
|
path=self.resolved_path.unresolve(),
|
|
651
668
|
)
|
|
652
669
|
|
|
670
|
+
def url(self) -> str:
|
|
671
|
+
return self.fs.url(self.path)
|
|
672
|
+
|
|
673
|
+
|
|
674
|
+
class HfFileSystemStreamFile(fsspec.spec.AbstractBufferedFile):
|
|
675
|
+
def __init__(
|
|
676
|
+
self,
|
|
677
|
+
fs: HfFileSystem,
|
|
678
|
+
path: str,
|
|
679
|
+
mode: str = "rb",
|
|
680
|
+
revision: Optional[str] = None,
|
|
681
|
+
block_size: int = 0,
|
|
682
|
+
cache_type: str = "none",
|
|
683
|
+
**kwargs,
|
|
684
|
+
):
|
|
685
|
+
if block_size != 0:
|
|
686
|
+
raise ValueError(f"HfFileSystemStreamFile only supports block_size=0 but got {block_size}")
|
|
687
|
+
if cache_type != "none":
|
|
688
|
+
raise ValueError(f"HfFileSystemStreamFile only supports cache_type='none' but got {cache_type}")
|
|
689
|
+
if "w" in mode:
|
|
690
|
+
raise ValueError(f"HfFileSystemStreamFile only supports reading but got mode='{mode}'")
|
|
691
|
+
try:
|
|
692
|
+
self.resolved_path = fs.resolve_path(path, revision=revision)
|
|
693
|
+
except FileNotFoundError as e:
|
|
694
|
+
if "w" in kwargs.get("mode", ""):
|
|
695
|
+
raise FileNotFoundError(
|
|
696
|
+
f"{e}.\nMake sure the repository and revision exist before writing data."
|
|
697
|
+
) from e
|
|
698
|
+
# avoid an unnecessary .info() call to instantiate .details
|
|
699
|
+
self.details = {"name": self.resolved_path.unresolve(), "size": None}
|
|
700
|
+
super().__init__(
|
|
701
|
+
fs, self.resolved_path.unresolve(), mode=mode, block_size=block_size, cache_type=cache_type, **kwargs
|
|
702
|
+
)
|
|
703
|
+
self.response: Optional[Response] = None
|
|
704
|
+
self.fs: HfFileSystem
|
|
705
|
+
|
|
706
|
+
def seek(self, loc: int, whence: int = 0):
|
|
707
|
+
if loc == 0 and whence == 1:
|
|
708
|
+
return
|
|
709
|
+
if loc == self.loc and whence == 0:
|
|
710
|
+
return
|
|
711
|
+
raise ValueError("Cannot seek streaming HF file")
|
|
712
|
+
|
|
713
|
+
def read(self, length: int = -1):
|
|
714
|
+
read_args = (length,) if length >= 0 else ()
|
|
715
|
+
if self.response is None or self.response.raw.isclosed():
|
|
716
|
+
url = hf_hub_url(
|
|
717
|
+
repo_id=self.resolved_path.repo_id,
|
|
718
|
+
revision=self.resolved_path.revision,
|
|
719
|
+
filename=self.resolved_path.path_in_repo,
|
|
720
|
+
repo_type=self.resolved_path.repo_type,
|
|
721
|
+
endpoint=self.fs.endpoint,
|
|
722
|
+
)
|
|
723
|
+
self.response = http_backoff(
|
|
724
|
+
"GET",
|
|
725
|
+
url,
|
|
726
|
+
headers=self.fs._api._build_hf_headers(),
|
|
727
|
+
retry_on_status_codes=(502, 503, 504),
|
|
728
|
+
stream=True,
|
|
729
|
+
)
|
|
730
|
+
hf_raise_for_status(self.response)
|
|
731
|
+
try:
|
|
732
|
+
out = self.response.raw.read(*read_args)
|
|
733
|
+
except Exception:
|
|
734
|
+
self.response.close()
|
|
735
|
+
|
|
736
|
+
# Retry by recreating the connection
|
|
737
|
+
url = hf_hub_url(
|
|
738
|
+
repo_id=self.resolved_path.repo_id,
|
|
739
|
+
revision=self.resolved_path.revision,
|
|
740
|
+
filename=self.resolved_path.path_in_repo,
|
|
741
|
+
repo_type=self.resolved_path.repo_type,
|
|
742
|
+
endpoint=self.fs.endpoint,
|
|
743
|
+
)
|
|
744
|
+
self.response = http_backoff(
|
|
745
|
+
"GET",
|
|
746
|
+
url,
|
|
747
|
+
headers={"Range": "bytes=%d-" % self.loc, **self.fs._api._build_hf_headers()},
|
|
748
|
+
retry_on_status_codes=(502, 503, 504),
|
|
749
|
+
stream=True,
|
|
750
|
+
)
|
|
751
|
+
hf_raise_for_status(self.response)
|
|
752
|
+
try:
|
|
753
|
+
out = self.response.raw.read(*read_args)
|
|
754
|
+
except Exception:
|
|
755
|
+
self.response.close()
|
|
756
|
+
raise
|
|
757
|
+
self.loc += len(out)
|
|
758
|
+
return out
|
|
759
|
+
|
|
760
|
+
def url(self) -> str:
|
|
761
|
+
return self.fs.url(self.path)
|
|
762
|
+
|
|
763
|
+
def __del__(self):
|
|
764
|
+
if not hasattr(self, "resolved_path"):
|
|
765
|
+
# Means that the constructor failed. Nothing to do.
|
|
766
|
+
return
|
|
767
|
+
return super().__del__()
|
|
768
|
+
|
|
769
|
+
def __reduce__(self):
|
|
770
|
+
return reopen, (self.fs, self.path, self.mode, self.blocksize, self.cache.name)
|
|
771
|
+
|
|
653
772
|
|
|
654
773
|
def safe_revision(revision: str) -> str:
|
|
655
774
|
return revision if SPECIAL_REFS_REVISION_REGEX.match(revision) else safe_quote(revision)
|
|
@@ -668,3 +787,7 @@ def _raise_file_not_found(path: str, err: Optional[Exception]) -> NoReturn:
|
|
|
668
787
|
elif isinstance(err, HFValidationError):
|
|
669
788
|
msg = f"{path} (invalid repository id)"
|
|
670
789
|
raise FileNotFoundError(msg) from err
|
|
790
|
+
|
|
791
|
+
|
|
792
|
+
def reopen(fs: HfFileSystem, path: str, mode: str, block_size: int, cache_type: str):
|
|
793
|
+
return fs.open(path, mode=mode, block_size=block_size, cache_type=cache_type)
|
huggingface_hub/hub_mixin.py
CHANGED
|
@@ -1,17 +1,36 @@
|
|
|
1
|
+
import inspect
|
|
1
2
|
import json
|
|
2
3
|
import os
|
|
4
|
+
from dataclasses import asdict, is_dataclass
|
|
3
5
|
from pathlib import Path
|
|
4
|
-
from typing import Dict, List, Optional, Type, TypeVar, Union
|
|
6
|
+
from typing import TYPE_CHECKING, Dict, List, Optional, Type, TypeVar, Union, get_args
|
|
5
7
|
|
|
6
|
-
from .constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME
|
|
7
|
-
from .file_download import hf_hub_download
|
|
8
|
+
from .constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME, SAFETENSORS_SINGLE_FILE
|
|
9
|
+
from .file_download import hf_hub_download
|
|
8
10
|
from .hf_api import HfApi
|
|
9
|
-
from .utils import
|
|
11
|
+
from .utils import (
|
|
12
|
+
EntryNotFoundError,
|
|
13
|
+
HfHubHTTPError,
|
|
14
|
+
SoftTemporaryDirectory,
|
|
15
|
+
is_safetensors_available,
|
|
16
|
+
is_torch_available,
|
|
17
|
+
logging,
|
|
18
|
+
validate_hf_hub_args,
|
|
19
|
+
)
|
|
20
|
+
from .utils._deprecation import _deprecate_arguments
|
|
10
21
|
|
|
11
22
|
|
|
23
|
+
if TYPE_CHECKING:
|
|
24
|
+
from _typeshed import DataclassInstance
|
|
25
|
+
|
|
12
26
|
if is_torch_available():
|
|
13
27
|
import torch # type: ignore
|
|
14
28
|
|
|
29
|
+
if is_safetensors_available():
|
|
30
|
+
from safetensors import safe_open
|
|
31
|
+
from safetensors.torch import save_file
|
|
32
|
+
|
|
33
|
+
|
|
15
34
|
logger = logging.get_logger(__name__)
|
|
16
35
|
|
|
17
36
|
# Generic variable that is either ModelHubMixin or a subclass thereof
|
|
@@ -25,16 +44,89 @@ class ModelHubMixin:
|
|
|
25
44
|
To integrate your framework, your model class must inherit from this class. Custom logic for saving/loading models
|
|
26
45
|
have to be overwritten in [`_from_pretrained`] and [`_save_pretrained`]. [`PyTorchModelHubMixin`] is a good example
|
|
27
46
|
of mixin integration with the Hub. Check out our [integration guide](../guides/integrations) for more instructions.
|
|
47
|
+
|
|
48
|
+
Example:
|
|
49
|
+
|
|
50
|
+
```python
|
|
51
|
+
>>> from dataclasses import dataclass
|
|
52
|
+
>>> from huggingface_hub import ModelHubMixin
|
|
53
|
+
|
|
54
|
+
# Define your model configuration (optional)
|
|
55
|
+
>>> @dataclass
|
|
56
|
+
... class Config:
|
|
57
|
+
... foo: int = 512
|
|
58
|
+
... bar: str = "cpu"
|
|
59
|
+
|
|
60
|
+
# Inherit from ModelHubMixin (and optionally from your framework's model class)
|
|
61
|
+
>>> class MyCustomModel(ModelHubMixin):
|
|
62
|
+
... def __init__(self, config: Config):
|
|
63
|
+
... # define how to initialize your model
|
|
64
|
+
... super().__init__()
|
|
65
|
+
... ...
|
|
66
|
+
...
|
|
67
|
+
... def _save_pretrained(self, save_directory: Path) -> None:
|
|
68
|
+
... # define how to serialize your model
|
|
69
|
+
... ...
|
|
70
|
+
...
|
|
71
|
+
... @classmethod
|
|
72
|
+
... def from_pretrained(
|
|
73
|
+
... cls: Type[T],
|
|
74
|
+
... pretrained_model_name_or_path: Union[str, Path],
|
|
75
|
+
... *,
|
|
76
|
+
... force_download: bool = False,
|
|
77
|
+
... resume_download: bool = False,
|
|
78
|
+
... proxies: Optional[Dict] = None,
|
|
79
|
+
... token: Optional[Union[str, bool]] = None,
|
|
80
|
+
... cache_dir: Optional[Union[str, Path]] = None,
|
|
81
|
+
... local_files_only: bool = False,
|
|
82
|
+
... revision: Optional[str] = None,
|
|
83
|
+
... **model_kwargs,
|
|
84
|
+
... ) -> T:
|
|
85
|
+
... # define how to deserialize your model
|
|
86
|
+
... ...
|
|
87
|
+
|
|
88
|
+
>>> model = MyCustomModel(config=Config(foo=256, bar="gpu"))
|
|
89
|
+
|
|
90
|
+
# Save model weights to local directory
|
|
91
|
+
>>> model.save_pretrained("my-awesome-model")
|
|
92
|
+
|
|
93
|
+
# Push model weights to the Hub
|
|
94
|
+
>>> model.push_to_hub("my-awesome-model")
|
|
95
|
+
|
|
96
|
+
# Download and initialize weights from the Hub
|
|
97
|
+
>>> reloaded_model = MyCustomModel.from_pretrained("username/my-awesome-model")
|
|
98
|
+
>>> reloaded_model.config
|
|
99
|
+
Config(foo=256, bar="gpu")
|
|
100
|
+
```
|
|
28
101
|
"""
|
|
29
102
|
|
|
103
|
+
config: Optional[Union[dict, "DataclassInstance"]] = None
|
|
104
|
+
# ^ optional config attribute automatically set in `from_pretrained` (if not already set by the subclass)
|
|
105
|
+
|
|
106
|
+
def __new__(cls, *args, **kwargs) -> "ModelHubMixin":
|
|
107
|
+
instance = super().__new__(cls)
|
|
108
|
+
|
|
109
|
+
# Set `config` attribute if not already set by the subclass
|
|
110
|
+
if instance.config is None:
|
|
111
|
+
if "config" in kwargs:
|
|
112
|
+
instance.config = kwargs["config"]
|
|
113
|
+
elif len(args) > 0:
|
|
114
|
+
sig = inspect.signature(cls.__init__)
|
|
115
|
+
parameters = list(sig.parameters)[1:] # remove `self`
|
|
116
|
+
for key, value in zip(parameters, args):
|
|
117
|
+
if key == "config":
|
|
118
|
+
instance.config = value
|
|
119
|
+
break
|
|
120
|
+
return instance
|
|
121
|
+
|
|
30
122
|
def save_pretrained(
|
|
31
123
|
self,
|
|
32
124
|
save_directory: Union[str, Path],
|
|
33
125
|
*,
|
|
34
|
-
config: Optional[dict] = None,
|
|
126
|
+
config: Optional[Union[dict, "DataclassInstance"]] = None,
|
|
35
127
|
repo_id: Optional[str] = None,
|
|
36
128
|
push_to_hub: bool = False,
|
|
37
|
-
**
|
|
129
|
+
**push_to_hub_kwargs,
|
|
38
130
|
) -> Optional[str]:
|
|
39
131
|
"""
|
|
40
132
|
Save weights in local directory.
|
|
@@ -42,8 +134,8 @@ class ModelHubMixin:
|
|
|
42
134
|
Args:
|
|
43
135
|
save_directory (`str` or `Path`):
|
|
44
136
|
Path to directory in which the model weights and configuration will be saved.
|
|
45
|
-
config (`dict`, *optional*):
|
|
46
|
-
Model configuration specified as a key/value dictionary.
|
|
137
|
+
config (`dict` or `DataclassInstance`, *optional*):
|
|
138
|
+
Model configuration specified as a key/value dictionary or a dataclass instance.
|
|
47
139
|
push_to_hub (`bool`, *optional*, defaults to `False`):
|
|
48
140
|
Whether or not to push your model to the Huggingface Hub after saving it.
|
|
49
141
|
repo_id (`str`, *optional*):
|
|
@@ -55,15 +147,20 @@ class ModelHubMixin:
|
|
|
55
147
|
save_directory = Path(save_directory)
|
|
56
148
|
save_directory.mkdir(parents=True, exist_ok=True)
|
|
57
149
|
|
|
58
|
-
#
|
|
150
|
+
# save model weights/files (framework-specific)
|
|
59
151
|
self._save_pretrained(save_directory)
|
|
60
152
|
|
|
61
|
-
#
|
|
62
|
-
if
|
|
63
|
-
|
|
153
|
+
# save config (if provided)
|
|
154
|
+
if config is None:
|
|
155
|
+
config = self.config
|
|
156
|
+
if config is not None:
|
|
157
|
+
if is_dataclass(config):
|
|
158
|
+
config = asdict(config) # type: ignore[arg-type]
|
|
159
|
+
(save_directory / CONFIG_NAME).write_text(json.dumps(config, indent=2))
|
|
64
160
|
|
|
161
|
+
# push to the Hub if required
|
|
65
162
|
if push_to_hub:
|
|
66
|
-
kwargs =
|
|
163
|
+
kwargs = push_to_hub_kwargs.copy()  # shallow copy to avoid mutating input
|
|
67
164
|
if config is not None: # kwarg for `push_to_hub`
|
|
68
165
|
kwargs["config"] = config
|
|
69
166
|
if repo_id is None:
|
|
@@ -126,17 +223,17 @@ class ModelHubMixin:
|
|
|
126
223
|
model_kwargs (`Dict`, *optional*):
|
|
127
224
|
Additional kwargs to pass to the model during initialization.
|
|
128
225
|
"""
|
|
129
|
-
model_id = pretrained_model_name_or_path
|
|
226
|
+
model_id = str(pretrained_model_name_or_path)
|
|
130
227
|
config_file: Optional[str] = None
|
|
131
228
|
if os.path.isdir(model_id):
|
|
132
229
|
if CONFIG_NAME in os.listdir(model_id):
|
|
133
230
|
config_file = os.path.join(model_id, CONFIG_NAME)
|
|
134
231
|
else:
|
|
135
232
|
logger.warning(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
|
|
136
|
-
|
|
233
|
+
else:
|
|
137
234
|
try:
|
|
138
235
|
config_file = hf_hub_download(
|
|
139
|
-
repo_id=
|
|
236
|
+
repo_id=model_id,
|
|
140
237
|
filename=CONFIG_NAME,
|
|
141
238
|
revision=revision,
|
|
142
239
|
cache_dir=cache_dir,
|
|
@@ -146,15 +243,35 @@ class ModelHubMixin:
|
|
|
146
243
|
token=token,
|
|
147
244
|
local_files_only=local_files_only,
|
|
148
245
|
)
|
|
149
|
-
except HfHubHTTPError:
|
|
150
|
-
logger.info(f"{CONFIG_NAME} not found
|
|
246
|
+
except HfHubHTTPError as e:
|
|
247
|
+
logger.info(f"{CONFIG_NAME} not found on the HuggingFace Hub: {str(e)}")
|
|
151
248
|
|
|
249
|
+
config = None
|
|
152
250
|
if config_file is not None:
|
|
251
|
+
# Read config
|
|
153
252
|
with open(config_file, "r", encoding="utf-8") as f:
|
|
154
253
|
config = json.load(f)
|
|
155
|
-
model_kwargs.update({"config": config})
|
|
156
254
|
|
|
157
|
-
|
|
255
|
+
# Check if class expect a `config` argument
|
|
256
|
+
init_parameters = inspect.signature(cls.__init__).parameters
|
|
257
|
+
if "config" in init_parameters:
|
|
258
|
+
# Check if `config` argument is a dataclass
|
|
259
|
+
config_annotation = init_parameters["config"].annotation
|
|
260
|
+
if config_annotation is inspect.Parameter.empty:
|
|
261
|
+
pass # no annotation
|
|
262
|
+
elif is_dataclass(config_annotation):
|
|
263
|
+
config = config_annotation(**config) # expect a dataclass
|
|
264
|
+
else:
|
|
265
|
+
# if Optional/Union annotation => check if a dataclass is in the Union
|
|
266
|
+
for _sub_annotation in get_args(config_annotation):
|
|
267
|
+
if is_dataclass(_sub_annotation):
|
|
268
|
+
config = _sub_annotation(**config)
|
|
269
|
+
break
|
|
270
|
+
|
|
271
|
+
# Forward config to model initialization
|
|
272
|
+
model_kwargs["config"] = config
|
|
273
|
+
|
|
274
|
+
instance = cls._from_pretrained(
|
|
158
275
|
model_id=str(model_id),
|
|
159
276
|
revision=revision,
|
|
160
277
|
cache_dir=cache_dir,
|
|
@@ -166,6 +283,13 @@ class ModelHubMixin:
|
|
|
166
283
|
**model_kwargs,
|
|
167
284
|
)
|
|
168
285
|
|
|
286
|
+
# Implicitly set the config as instance attribute if not already set by the class
|
|
287
|
+
# This way `config` will be available when calling `save_pretrained` or `push_to_hub`.
|
|
288
|
+
if config is not None and instance.config is None:
|
|
289
|
+
instance.config = config
|
|
290
|
+
|
|
291
|
+
return instance
|
|
292
|
+
|
|
169
293
|
@classmethod
|
|
170
294
|
def _from_pretrained(
|
|
171
295
|
cls: Type[T],
|
|
@@ -215,21 +339,27 @@ class ModelHubMixin:
|
|
|
215
339
|
"""
|
|
216
340
|
raise NotImplementedError
|
|
217
341
|
|
|
342
|
+
@_deprecate_arguments(
|
|
343
|
+
version="0.23.0",
|
|
344
|
+
deprecated_args=["api_endpoint"],
|
|
345
|
+
custom_message="Use `HF_ENDPOINT` environment variable instead.",
|
|
346
|
+
)
|
|
218
347
|
@validate_hf_hub_args
|
|
219
348
|
def push_to_hub(
|
|
220
349
|
self,
|
|
221
350
|
repo_id: str,
|
|
222
351
|
*,
|
|
223
|
-
config: Optional[dict] = None,
|
|
352
|
+
config: Optional[Union[dict, "DataclassInstance"]] = None,
|
|
224
353
|
commit_message: str = "Push model using huggingface_hub.",
|
|
225
354
|
private: bool = False,
|
|
226
|
-
api_endpoint: Optional[str] = None,
|
|
227
355
|
token: Optional[str] = None,
|
|
228
356
|
branch: Optional[str] = None,
|
|
229
357
|
create_pr: Optional[bool] = None,
|
|
230
358
|
allow_patterns: Optional[Union[List[str], str]] = None,
|
|
231
359
|
ignore_patterns: Optional[Union[List[str], str]] = None,
|
|
232
360
|
delete_patterns: Optional[Union[List[str], str]] = None,
|
|
361
|
+
# TODO: remove once deprecated
|
|
362
|
+
api_endpoint: Optional[str] = None,
|
|
233
363
|
) -> str:
|
|
234
364
|
"""
|
|
235
365
|
Upload model checkpoint to the Hub.
|
|
@@ -238,12 +368,11 @@ class ModelHubMixin:
|
|
|
238
368
|
`delete_patterns` to delete existing remote files in the same commit. See [`upload_folder`] reference for more
|
|
239
369
|
details.
|
|
240
370
|
|
|
241
|
-
|
|
242
371
|
Args:
|
|
243
372
|
repo_id (`str`):
|
|
244
373
|
ID of the repository to push to (example: `"username/my-model"`).
|
|
245
|
-
config (`dict`, *optional*):
|
|
246
|
-
|
|
374
|
+
config (`dict` or `DataclassInstance`, *optional*):
|
|
375
|
+
Model configuration specified as a key/value dictionary or a dataclass instance.
|
|
247
376
|
commit_message (`str`, *optional*):
|
|
248
377
|
Message to commit while pushing.
|
|
249
378
|
private (`bool`, *optional*, defaults to `False`):
|
|
@@ -296,16 +425,22 @@ class PyTorchModelHubMixin(ModelHubMixin):
|
|
|
296
425
|
Example:
|
|
297
426
|
|
|
298
427
|
```python
|
|
428
|
+
>>> from dataclasses import dataclass
|
|
299
429
|
>>> import torch
|
|
300
430
|
>>> import torch.nn as nn
|
|
301
431
|
>>> from huggingface_hub import PyTorchModelHubMixin
|
|
302
432
|
|
|
433
|
+
>>> @dataclass
|
|
434
|
+
... class Config:
|
|
435
|
+
... hidden_size: int = 512
|
|
436
|
+
... vocab_size: int = 30000
|
|
437
|
+
... output_size: int = 4
|
|
303
438
|
|
|
304
439
|
>>> class MyModel(nn.Module, PyTorchModelHubMixin):
|
|
305
|
-
... def __init__(self):
|
|
440
|
+
... def __init__(self, config: Config):
|
|
306
441
|
... super().__init__()
|
|
307
|
-
... self.param = nn.Parameter(torch.rand(
|
|
308
|
-
... self.linear = nn.Linear(
|
|
442
|
+
... self.param = nn.Parameter(torch.rand(config.hidden_size, config.vocab_size))
|
|
443
|
+
... self.linear = nn.Linear(config.output_size, config.vocab_size)
|
|
309
444
|
|
|
310
445
|
... def forward(self, x):
|
|
311
446
|
... return self.linear(x + self.param)
|
|
@@ -325,7 +460,7 @@ class PyTorchModelHubMixin(ModelHubMixin):
|
|
|
325
460
|
def _save_pretrained(self, save_directory: Path) -> None:
|
|
326
461
|
"""Save weights from a Pytorch model to a local directory."""
|
|
327
462
|
model_to_save = self.module if hasattr(self, "module") else self # type: ignore
|
|
328
|
-
|
|
463
|
+
save_file(model_to_save.state_dict(), save_directory / SAFETENSORS_SINGLE_FILE)
|
|
329
464
|
|
|
330
465
|
@classmethod
|
|
331
466
|
def _from_pretrained(
|
|
@@ -344,25 +479,52 @@ class PyTorchModelHubMixin(ModelHubMixin):
|
|
|
344
479
|
**model_kwargs,
|
|
345
480
|
):
|
|
346
481
|
"""Load Pytorch pretrained weights and return the loaded model."""
|
|
482
|
+
model = cls(**model_kwargs)
|
|
347
483
|
if os.path.isdir(model_id):
|
|
348
484
|
print("Loading weights from local directory")
|
|
349
|
-
model_file = os.path.join(model_id,
|
|
485
|
+
model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
|
|
486
|
+
return cls._load_as_safetensor(model, model_file, map_location, strict)
|
|
350
487
|
else:
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
488
|
+
try:
|
|
489
|
+
model_file = hf_hub_download(
|
|
490
|
+
repo_id=model_id,
|
|
491
|
+
filename=SAFETENSORS_SINGLE_FILE,
|
|
492
|
+
revision=revision,
|
|
493
|
+
cache_dir=cache_dir,
|
|
494
|
+
force_download=force_download,
|
|
495
|
+
proxies=proxies,
|
|
496
|
+
resume_download=resume_download,
|
|
497
|
+
token=token,
|
|
498
|
+
local_files_only=local_files_only,
|
|
499
|
+
)
|
|
500
|
+
return cls._load_as_safetensor(model, model_file, map_location, strict)
|
|
501
|
+
except EntryNotFoundError:
|
|
502
|
+
model_file = hf_hub_download(
|
|
503
|
+
repo_id=model_id,
|
|
504
|
+
filename=PYTORCH_WEIGHTS_NAME,
|
|
505
|
+
revision=revision,
|
|
506
|
+
cache_dir=cache_dir,
|
|
507
|
+
force_download=force_download,
|
|
508
|
+
proxies=proxies,
|
|
509
|
+
resume_download=resume_download,
|
|
510
|
+
token=token,
|
|
511
|
+
local_files_only=local_files_only,
|
|
512
|
+
)
|
|
513
|
+
return cls._load_as_pickle(model, model_file, map_location, strict)
|
|
363
514
|
|
|
515
|
+
@classmethod
|
|
516
|
+
def _load_as_pickle(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
|
|
364
517
|
state_dict = torch.load(model_file, map_location=torch.device(map_location))
|
|
365
518
|
model.load_state_dict(state_dict, strict=strict) # type: ignore
|
|
366
519
|
model.eval() # type: ignore
|
|
520
|
+
return model
|
|
367
521
|
|
|
522
|
+
@classmethod
|
|
523
|
+
def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
|
|
524
|
+
state_dict = {}
|
|
525
|
+
with safe_open(model_file, framework="pt", device=map_location) as f: # type: ignore [attr-defined]
|
|
526
|
+
for k in f.keys():
|
|
527
|
+
state_dict[k] = f.get_tensor(k)
|
|
528
|
+
model.load_state_dict(state_dict, strict=strict) # type: ignore
|
|
529
|
+
model.eval() # type: ignore
|
|
368
530
|
return model
|