huggingface-hub 0.21.2__py3-none-any.whl → 0.22.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of huggingface-hub might be problematic. Click here for more details.
- huggingface_hub/__init__.py +217 -1
- huggingface_hub/_commit_api.py +14 -15
- huggingface_hub/_inference_endpoints.py +12 -11
- huggingface_hub/_login.py +1 -0
- huggingface_hub/_multi_commits.py +1 -0
- huggingface_hub/_snapshot_download.py +9 -1
- huggingface_hub/_tensorboard_logger.py +1 -0
- huggingface_hub/_webhooks_payload.py +1 -0
- huggingface_hub/_webhooks_server.py +1 -0
- huggingface_hub/commands/_cli_utils.py +1 -0
- huggingface_hub/commands/delete_cache.py +1 -0
- huggingface_hub/commands/download.py +1 -0
- huggingface_hub/commands/env.py +1 -0
- huggingface_hub/commands/scan_cache.py +1 -0
- huggingface_hub/commands/upload.py +1 -0
- huggingface_hub/community.py +1 -0
- huggingface_hub/constants.py +3 -1
- huggingface_hub/errors.py +38 -0
- huggingface_hub/file_download.py +102 -95
- huggingface_hub/hf_api.py +47 -35
- huggingface_hub/hf_file_system.py +77 -3
- huggingface_hub/hub_mixin.py +230 -61
- huggingface_hub/inference/_client.py +554 -239
- huggingface_hub/inference/_common.py +195 -41
- huggingface_hub/inference/_generated/_async_client.py +558 -239
- huggingface_hub/inference/_generated/types/__init__.py +115 -0
- huggingface_hub/inference/_generated/types/audio_classification.py +43 -0
- huggingface_hub/inference/_generated/types/audio_to_audio.py +31 -0
- huggingface_hub/inference/_generated/types/automatic_speech_recognition.py +116 -0
- huggingface_hub/inference/_generated/types/base.py +149 -0
- huggingface_hub/inference/_generated/types/chat_completion.py +106 -0
- huggingface_hub/inference/_generated/types/depth_estimation.py +29 -0
- huggingface_hub/inference/_generated/types/document_question_answering.py +85 -0
- huggingface_hub/inference/_generated/types/feature_extraction.py +19 -0
- huggingface_hub/inference/_generated/types/fill_mask.py +50 -0
- huggingface_hub/inference/_generated/types/image_classification.py +43 -0
- huggingface_hub/inference/_generated/types/image_segmentation.py +52 -0
- huggingface_hub/inference/_generated/types/image_to_image.py +55 -0
- huggingface_hub/inference/_generated/types/image_to_text.py +105 -0
- huggingface_hub/inference/_generated/types/object_detection.py +55 -0
- huggingface_hub/inference/_generated/types/question_answering.py +77 -0
- huggingface_hub/inference/_generated/types/sentence_similarity.py +28 -0
- huggingface_hub/inference/_generated/types/summarization.py +46 -0
- huggingface_hub/inference/_generated/types/table_question_answering.py +45 -0
- huggingface_hub/inference/_generated/types/text2text_generation.py +45 -0
- huggingface_hub/inference/_generated/types/text_classification.py +43 -0
- huggingface_hub/inference/_generated/types/text_generation.py +161 -0
- huggingface_hub/inference/_generated/types/text_to_audio.py +105 -0
- huggingface_hub/inference/_generated/types/text_to_image.py +57 -0
- huggingface_hub/inference/_generated/types/token_classification.py +53 -0
- huggingface_hub/inference/_generated/types/translation.py +46 -0
- huggingface_hub/inference/_generated/types/video_classification.py +47 -0
- huggingface_hub/inference/_generated/types/visual_question_answering.py +53 -0
- huggingface_hub/inference/_generated/types/zero_shot_classification.py +56 -0
- huggingface_hub/inference/_generated/types/zero_shot_image_classification.py +51 -0
- huggingface_hub/inference/_generated/types/zero_shot_object_detection.py +55 -0
- huggingface_hub/inference/_templating.py +105 -0
- huggingface_hub/inference/_types.py +4 -152
- huggingface_hub/keras_mixin.py +39 -17
- huggingface_hub/lfs.py +20 -8
- huggingface_hub/repocard.py +11 -3
- huggingface_hub/repocard_data.py +12 -2
- huggingface_hub/serialization/__init__.py +1 -0
- huggingface_hub/serialization/_base.py +1 -0
- huggingface_hub/serialization/_numpy.py +1 -0
- huggingface_hub/serialization/_tensorflow.py +1 -0
- huggingface_hub/serialization/_torch.py +1 -0
- huggingface_hub/utils/__init__.py +4 -1
- huggingface_hub/utils/_cache_manager.py +7 -0
- huggingface_hub/utils/_chunk_utils.py +1 -0
- huggingface_hub/utils/_datetime.py +1 -0
- huggingface_hub/utils/_errors.py +10 -1
- huggingface_hub/utils/_experimental.py +1 -0
- huggingface_hub/utils/_fixes.py +19 -3
- huggingface_hub/utils/_git_credential.py +1 -0
- huggingface_hub/utils/_headers.py +10 -3
- huggingface_hub/utils/_hf_folder.py +1 -0
- huggingface_hub/utils/_http.py +1 -0
- huggingface_hub/utils/_pagination.py +1 -0
- huggingface_hub/utils/_paths.py +1 -0
- huggingface_hub/utils/_runtime.py +22 -0
- huggingface_hub/utils/_subprocess.py +1 -0
- huggingface_hub/utils/_token.py +1 -0
- huggingface_hub/utils/_typing.py +29 -1
- huggingface_hub/utils/_validators.py +1 -0
- huggingface_hub/utils/endpoint_helpers.py +1 -0
- huggingface_hub/utils/logging.py +1 -1
- huggingface_hub/utils/sha.py +1 -0
- huggingface_hub/utils/tqdm.py +1 -0
- {huggingface_hub-0.21.2.dist-info → huggingface_hub-0.22.0.dist-info}/METADATA +14 -15
- huggingface_hub-0.22.0.dist-info/RECORD +113 -0
- {huggingface_hub-0.21.2.dist-info → huggingface_hub-0.22.0.dist-info}/WHEEL +1 -1
- huggingface_hub/inference/_text_generation.py +0 -551
- huggingface_hub-0.21.2.dist-info/RECORD +0 -81
- {huggingface_hub-0.21.2.dist-info → huggingface_hub-0.22.0.dist-info}/LICENSE +0 -0
- {huggingface_hub-0.21.2.dist-info → huggingface_hub-0.22.0.dist-info}/entry_points.txt +0 -0
- {huggingface_hub-0.21.2.dist-info → huggingface_hub-0.22.0.dist-info}/top_level.txt +0 -0
|
@@ -6,15 +6,24 @@ from collections import deque
|
|
|
6
6
|
from dataclasses import dataclass, field
|
|
7
7
|
from datetime import datetime
|
|
8
8
|
from itertools import chain
|
|
9
|
+
from pathlib import Path
|
|
9
10
|
from typing import Any, Dict, List, NoReturn, Optional, Tuple, Union
|
|
10
11
|
from urllib.parse import quote, unquote
|
|
11
12
|
|
|
12
13
|
import fsspec
|
|
14
|
+
from fsspec.callbacks import _DEFAULT_CALLBACK, NoOpCallback, TqdmCallback
|
|
15
|
+
from fsspec.utils import isfilelike
|
|
13
16
|
from requests import Response
|
|
14
17
|
|
|
15
18
|
from ._commit_api import CommitOperationCopy, CommitOperationDelete
|
|
16
|
-
from .constants import
|
|
17
|
-
|
|
19
|
+
from .constants import (
|
|
20
|
+
DEFAULT_REVISION,
|
|
21
|
+
ENDPOINT,
|
|
22
|
+
REPO_TYPE_MODEL,
|
|
23
|
+
REPO_TYPES_MAPPING,
|
|
24
|
+
REPO_TYPES_URL_PREFIXES,
|
|
25
|
+
)
|
|
26
|
+
from .file_download import hf_hub_url, http_get
|
|
18
27
|
from .hf_api import HfApi, LastCommitInfo, RepoFile
|
|
19
28
|
from .utils import (
|
|
20
29
|
EntryNotFoundError,
|
|
@@ -591,6 +600,58 @@ class HfFileSystem(fsspec.AbstractFileSystem):
|
|
|
591
600
|
url = url.replace("/resolve/", "/tree/", 1)
|
|
592
601
|
return url
|
|
593
602
|
|
|
603
|
+
def get_file(self, rpath, lpath, callback=_DEFAULT_CALLBACK, outfile=None, **kwargs) -> None:
|
|
604
|
+
"""Copy single remote file to local."""
|
|
605
|
+
revision = kwargs.get("revision")
|
|
606
|
+
unhandled_kwargs = set(kwargs.keys()) - {"revision"}
|
|
607
|
+
if not isinstance(callback, (NoOpCallback, TqdmCallback)) or len(unhandled_kwargs) > 0:
|
|
608
|
+
# for now, let's not handle custom callbacks
|
|
609
|
+
# and let's not handle custom kwargs
|
|
610
|
+
return super().get_file(rpath, lpath, callback=callback, outfile=outfile, **kwargs)
|
|
611
|
+
|
|
612
|
+
# Taken from https://github.com/fsspec/filesystem_spec/blob/47b445ae4c284a82dd15e0287b1ffc410e8fc470/fsspec/spec.py#L883
|
|
613
|
+
if isfilelike(lpath):
|
|
614
|
+
outfile = lpath
|
|
615
|
+
elif self.isdir(rpath):
|
|
616
|
+
os.makedirs(lpath, exist_ok=True)
|
|
617
|
+
return None
|
|
618
|
+
|
|
619
|
+
if isinstance(lpath, (str, Path)): # otherwise, let's assume it's a file-like object
|
|
620
|
+
os.makedirs(os.path.dirname(lpath), exist_ok=True)
|
|
621
|
+
|
|
622
|
+
# Open file if not already open
|
|
623
|
+
close_file = False
|
|
624
|
+
if outfile is None:
|
|
625
|
+
outfile = open(lpath, "wb")
|
|
626
|
+
close_file = True
|
|
627
|
+
initial_pos = outfile.tell()
|
|
628
|
+
|
|
629
|
+
# Custom implementation of `get_file` to use `http_get`.
|
|
630
|
+
resolve_remote_path = self.resolve_path(rpath, revision=revision)
|
|
631
|
+
expected_size = self.info(rpath, revision=revision)["size"]
|
|
632
|
+
callback.set_size(expected_size)
|
|
633
|
+
try:
|
|
634
|
+
http_get(
|
|
635
|
+
url=hf_hub_url(
|
|
636
|
+
repo_id=resolve_remote_path.repo_id,
|
|
637
|
+
revision=resolve_remote_path.revision,
|
|
638
|
+
filename=resolve_remote_path.path_in_repo,
|
|
639
|
+
repo_type=resolve_remote_path.repo_type,
|
|
640
|
+
endpoint=self.endpoint,
|
|
641
|
+
),
|
|
642
|
+
temp_file=outfile,
|
|
643
|
+
displayed_filename=rpath,
|
|
644
|
+
expected_size=expected_size,
|
|
645
|
+
resume_size=0,
|
|
646
|
+
headers=self._api._build_hf_headers(),
|
|
647
|
+
_tqdm_bar=callback.tqdm if isinstance(callback, TqdmCallback) else None,
|
|
648
|
+
)
|
|
649
|
+
outfile.seek(initial_pos)
|
|
650
|
+
finally:
|
|
651
|
+
# Close file only if we opened it ourselves
|
|
652
|
+
if close_file:
|
|
653
|
+
outfile.close()
|
|
654
|
+
|
|
594
655
|
@property
|
|
595
656
|
def transaction(self):
|
|
596
657
|
"""A context within which files are committed together upon exit
|
|
@@ -618,6 +679,7 @@ class HfFileSystemFile(fsspec.spec.AbstractBufferedFile):
|
|
|
618
679
|
raise FileNotFoundError(
|
|
619
680
|
f"{e}.\nMake sure the repository and revision exist before writing data."
|
|
620
681
|
) from e
|
|
682
|
+
raise
|
|
621
683
|
super().__init__(fs, self.resolved_path.unresolve(), **kwargs)
|
|
622
684
|
self.fs: HfFileSystem
|
|
623
685
|
|
|
@@ -667,6 +729,18 @@ class HfFileSystemFile(fsspec.spec.AbstractBufferedFile):
|
|
|
667
729
|
path=self.resolved_path.unresolve(),
|
|
668
730
|
)
|
|
669
731
|
|
|
732
|
+
def read(self, length=-1):
|
|
733
|
+
"""Read remote file.
|
|
734
|
+
|
|
735
|
+
If `length` is not provided or is -1, the entire file is downloaded and read. On POSIX systems and if
|
|
736
|
+
`hf_transfer` is not enabled, the file is loaded in memory directly. Otherwise, the file is downloaded to a
|
|
737
|
+
temporary file and read from there.
|
|
738
|
+
"""
|
|
739
|
+
if self.mode == "rb" and (length is None or length == -1) and self.loc == 0:
|
|
740
|
+
with self.fs.open(self.path, "rb", block_size=0) as f: # block_size=0 enables fast streaming
|
|
741
|
+
return f.read()
|
|
742
|
+
return super().read(length)
|
|
743
|
+
|
|
670
744
|
def url(self) -> str:
|
|
671
745
|
return self.fs.url(self.path)
|
|
672
746
|
|
|
@@ -695,7 +769,7 @@ class HfFileSystemStreamFile(fsspec.spec.AbstractBufferedFile):
|
|
|
695
769
|
raise FileNotFoundError(
|
|
696
770
|
f"{e}.\nMake sure the repository and revision exist before writing data."
|
|
697
771
|
) from e
|
|
698
|
-
# avoid an
|
|
772
|
+
# avoid an unnecessary .info() call to instantiate .details
|
|
699
773
|
self.details = {"name": self.resolved_path.unresolve(), "size": None}
|
|
700
774
|
super().__init__(
|
|
701
775
|
fs, self.resolved_path.unresolve(), mode=mode, block_size=block_size, cache_type=cache_type, **kwargs
|
huggingface_hub/hub_mixin.py
CHANGED
|
@@ -1,17 +1,19 @@
|
|
|
1
1
|
import inspect
|
|
2
2
|
import json
|
|
3
3
|
import os
|
|
4
|
-
from dataclasses import asdict, is_dataclass
|
|
4
|
+
from dataclasses import asdict, dataclass, is_dataclass
|
|
5
5
|
from pathlib import Path
|
|
6
|
-
from typing import TYPE_CHECKING, Dict, List, Optional, Type, TypeVar, Union, get_args
|
|
6
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, TypeVar, Union, get_args
|
|
7
7
|
|
|
8
8
|
from .constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME, SAFETENSORS_SINGLE_FILE
|
|
9
9
|
from .file_download import hf_hub_download
|
|
10
10
|
from .hf_api import HfApi
|
|
11
|
+
from .repocard import ModelCard, ModelCardData
|
|
11
12
|
from .utils import (
|
|
12
13
|
EntryNotFoundError,
|
|
13
14
|
HfHubHTTPError,
|
|
14
15
|
SoftTemporaryDirectory,
|
|
16
|
+
is_jsonable,
|
|
15
17
|
is_safetensors_available,
|
|
16
18
|
is_torch_available,
|
|
17
19
|
logging,
|
|
@@ -27,8 +29,8 @@ if is_torch_available():
|
|
|
27
29
|
import torch # type: ignore
|
|
28
30
|
|
|
29
31
|
if is_safetensors_available():
|
|
30
|
-
from safetensors import
|
|
31
|
-
from safetensors.torch import
|
|
32
|
+
from safetensors.torch import load_model as load_model_as_safetensor
|
|
33
|
+
from safetensors.torch import save_model as save_model_as_safetensor
|
|
32
34
|
|
|
33
35
|
|
|
34
36
|
logger = logging.get_logger(__name__)
|
|
@@ -36,6 +38,26 @@ logger = logging.get_logger(__name__)
|
|
|
36
38
|
# Generic variable that is either ModelHubMixin or a subclass thereof
|
|
37
39
|
T = TypeVar("T", bound="ModelHubMixin")
|
|
38
40
|
|
|
41
|
+
DEFAULT_MODEL_CARD = """
|
|
42
|
+
---
|
|
43
|
+
# For reference on model card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
|
|
44
|
+
# Doc / guide: https://huggingface.co/docs/hub/model-cards
|
|
45
|
+
{{ card_data }}
|
|
46
|
+
---
|
|
47
|
+
|
|
48
|
+
This model has been pushed to the Hub using **{{ library_name }}**:
|
|
49
|
+
- Repo: {{ repo_url | default("[More Information Needed]", true) }}
|
|
50
|
+
- Docs: {{ docs_url | default("[More Information Needed]", true) }}
|
|
51
|
+
"""
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
@dataclass
|
|
55
|
+
class MixinInfo:
|
|
56
|
+
library_name: Optional[str] = None
|
|
57
|
+
tags: Optional[List[str]] = None
|
|
58
|
+
repo_url: Optional[str] = None
|
|
59
|
+
docs_url: Optional[str] = None
|
|
60
|
+
|
|
39
61
|
|
|
40
62
|
class ModelHubMixin:
|
|
41
63
|
"""
|
|
@@ -45,21 +67,35 @@ class ModelHubMixin:
|
|
|
45
67
|
have to be overwritten in [`_from_pretrained`] and [`_save_pretrained`]. [`PyTorchModelHubMixin`] is a good example
|
|
46
68
|
of mixin integration with the Hub. Check out our [integration guide](../guides/integrations) for more instructions.
|
|
47
69
|
|
|
70
|
+
When inheriting from [`ModelHubMixin`], you can define class-level attributes. These attributes are not passed to
|
|
71
|
+
`__init__` but to the class definition itself. This is useful to define metadata about the library integrating
|
|
72
|
+
[`ModelHubMixin`].
|
|
73
|
+
|
|
74
|
+
Args:
|
|
75
|
+
library_name (`str`, *optional*):
|
|
76
|
+
Name of the library integrating ModelHubMixin. Used to generate model card.
|
|
77
|
+
tags (`List[str]`, *optional*):
|
|
78
|
+
Tags to be added to the model card. Used to generate model card.
|
|
79
|
+
repo_url (`str`, *optional*):
|
|
80
|
+
URL of the library repository. Used to generate model card.
|
|
81
|
+
docs_url (`str`, *optional*):
|
|
82
|
+
URL of the library documentation. Used to generate model card.
|
|
83
|
+
|
|
48
84
|
Example:
|
|
49
85
|
|
|
50
86
|
```python
|
|
51
|
-
>>> from dataclasses import dataclass
|
|
52
87
|
>>> from huggingface_hub import ModelHubMixin
|
|
53
88
|
|
|
54
|
-
#
|
|
55
|
-
>>>
|
|
56
|
-
...
|
|
57
|
-
...
|
|
58
|
-
...
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
...
|
|
89
|
+
# Inherit from ModelHubMixin
|
|
90
|
+
>>> class MyCustomModel(
|
|
91
|
+
... ModelHubMixin,
|
|
92
|
+
... library_name="my-library",
|
|
93
|
+
... tags=["x-custom-tag"],
|
|
94
|
+
... repo_url="https://github.com/huggingface/my-cool-library",
|
|
95
|
+
... docs_url="https://huggingface.co/docs/my-cool-library",
|
|
96
|
+
... # ^ optional metadata to generate model card
|
|
97
|
+
... ):
|
|
98
|
+
... def __init__(self, size: int = 512, device: str = "cpu"):
|
|
63
99
|
... # define how to initialize your model
|
|
64
100
|
... super().__init__()
|
|
65
101
|
... ...
|
|
@@ -85,7 +121,7 @@ class ModelHubMixin:
|
|
|
85
121
|
... # define how to deserialize your model
|
|
86
122
|
... ...
|
|
87
123
|
|
|
88
|
-
>>> model = MyCustomModel(
|
|
124
|
+
>>> model = MyCustomModel(size=256, device="gpu")
|
|
89
125
|
|
|
90
126
|
# Save model weights to local directory
|
|
91
127
|
>>> model.save_pretrained("my-awesome-model")
|
|
@@ -95,28 +131,107 @@ class ModelHubMixin:
|
|
|
95
131
|
|
|
96
132
|
# Download and initialize weights from the Hub
|
|
97
133
|
>>> reloaded_model = MyCustomModel.from_pretrained("username/my-awesome-model")
|
|
98
|
-
>>> reloaded_model.
|
|
99
|
-
|
|
134
|
+
>>> reloaded_model._hub_mixin_config
|
|
135
|
+
{"size": 256, "device": "gpu"}
|
|
136
|
+
|
|
137
|
+
# Model card has been correctly populated
|
|
138
|
+
>>> from huggingface_hub import ModelCard
|
|
139
|
+
>>> card = ModelCard.load("username/my-awesome-model")
|
|
140
|
+
>>> card.data.tags
|
|
141
|
+
["x-custom-tag", "pytorch_model_hub_mixin", "model_hub_mixin"]
|
|
142
|
+
>>> card.data.library_name
|
|
143
|
+
"my-library"
|
|
100
144
|
```
|
|
101
145
|
"""
|
|
102
146
|
|
|
103
|
-
|
|
104
|
-
# ^ optional config attribute automatically set in `from_pretrained`
|
|
147
|
+
_hub_mixin_config: Optional[Union[dict, "DataclassInstance"]] = None
|
|
148
|
+
# ^ optional config attribute automatically set in `from_pretrained`
|
|
149
|
+
_hub_mixin_info: MixinInfo
|
|
150
|
+
# ^ information about the library integrating ModelHubMixin (used to generate model card)
|
|
151
|
+
_hub_mixin_init_parameters: Dict[str, inspect.Parameter]
|
|
152
|
+
_hub_mixin_jsonable_default_values: Dict[str, Any]
|
|
153
|
+
_hub_mixin_inject_config: bool
|
|
154
|
+
# ^ internal values to handle config
|
|
155
|
+
|
|
156
|
+
def __init_subclass__(
|
|
157
|
+
cls,
|
|
158
|
+
*,
|
|
159
|
+
library_name: Optional[str] = None,
|
|
160
|
+
tags: Optional[List[str]] = None,
|
|
161
|
+
repo_url: Optional[str] = None,
|
|
162
|
+
docs_url: Optional[str] = None,
|
|
163
|
+
) -> None:
|
|
164
|
+
"""Inspect __init__ signature only once when subclassing + handle modelcard."""
|
|
165
|
+
super().__init_subclass__()
|
|
166
|
+
|
|
167
|
+
# Will be reused when creating modelcard
|
|
168
|
+
tags = tags or []
|
|
169
|
+
tags.append("model_hub_mixin")
|
|
170
|
+
cls._hub_mixin_info = MixinInfo(
|
|
171
|
+
library_name=library_name,
|
|
172
|
+
tags=tags,
|
|
173
|
+
repo_url=repo_url,
|
|
174
|
+
docs_url=docs_url,
|
|
175
|
+
)
|
|
176
|
+
|
|
177
|
+
# Inspect __init__ signature to handle config
|
|
178
|
+
cls._hub_mixin_init_parameters = dict(inspect.signature(cls.__init__).parameters)
|
|
179
|
+
cls._hub_mixin_jsonable_default_values = {
|
|
180
|
+
param.name: param.default
|
|
181
|
+
for param in cls._hub_mixin_init_parameters.values()
|
|
182
|
+
if param.default is not inspect.Parameter.empty and is_jsonable(param.default)
|
|
183
|
+
}
|
|
184
|
+
cls._hub_mixin_inject_config = "config" in inspect.signature(cls._from_pretrained).parameters
|
|
105
185
|
|
|
106
186
|
def __new__(cls, *args, **kwargs) -> "ModelHubMixin":
|
|
187
|
+
"""Create a new instance of the class and handle config.
|
|
188
|
+
|
|
189
|
+
3 cases:
|
|
190
|
+
- If `self._hub_mixin_config` is already set, do nothing.
|
|
191
|
+
- If `config` is passed as a dataclass, set it as `self._hub_mixin_config`.
|
|
192
|
+
- Otherwise, build `self._hub_mixin_config` from default values and passed values.
|
|
193
|
+
"""
|
|
107
194
|
instance = super().__new__(cls)
|
|
108
195
|
|
|
109
|
-
#
|
|
110
|
-
if instance.
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
196
|
+
# If `config` is already set, return early
|
|
197
|
+
if instance._hub_mixin_config is not None:
|
|
198
|
+
return instance
|
|
199
|
+
|
|
200
|
+
# Infer passed values
|
|
201
|
+
passed_values = {
|
|
202
|
+
**{
|
|
203
|
+
key: value
|
|
204
|
+
for key, value in zip(
|
|
205
|
+
# [1:] to skip `self` parameter
|
|
206
|
+
list(cls._hub_mixin_init_parameters)[1:],
|
|
207
|
+
args,
|
|
208
|
+
)
|
|
209
|
+
},
|
|
210
|
+
**kwargs,
|
|
211
|
+
}
|
|
212
|
+
|
|
213
|
+
# If config passed as dataclass => set it and return early
|
|
214
|
+
if is_dataclass(passed_values.get("config")):
|
|
215
|
+
instance._hub_mixin_config = passed_values["config"]
|
|
216
|
+
return instance
|
|
217
|
+
|
|
218
|
+
# Otherwise, build config from default + passed values
|
|
219
|
+
init_config = {
|
|
220
|
+
# default values
|
|
221
|
+
**cls._hub_mixin_jsonable_default_values,
|
|
222
|
+
# passed values
|
|
223
|
+
**{key: value for key, value in passed_values.items() if is_jsonable(value)},
|
|
224
|
+
}
|
|
225
|
+
init_config.pop("config", {})
|
|
226
|
+
|
|
227
|
+
# Populate `init_config` with provided config
|
|
228
|
+
provided_config = passed_values.get("config")
|
|
229
|
+
if isinstance(provided_config, dict):
|
|
230
|
+
init_config.update(provided_config)
|
|
231
|
+
|
|
232
|
+
# Set `config` attribute and return
|
|
233
|
+
if init_config != {}:
|
|
234
|
+
instance._hub_mixin_config = init_config
|
|
120
235
|
return instance
|
|
121
236
|
|
|
122
237
|
def save_pretrained(
|
|
@@ -147,16 +262,29 @@ class ModelHubMixin:
|
|
|
147
262
|
save_directory = Path(save_directory)
|
|
148
263
|
save_directory.mkdir(parents=True, exist_ok=True)
|
|
149
264
|
|
|
265
|
+
# Remove config.json if already exists. After `_save_pretrained` we don't want to overwrite config.json
|
|
266
|
+
# as it might have been saved by the custom `_save_pretrained` already. However we do want to overwrite
|
|
267
|
+
# an existing config.json if it was not saved by `_save_pretrained`.
|
|
268
|
+
config_path = save_directory / CONFIG_NAME
|
|
269
|
+
config_path.unlink(missing_ok=True)
|
|
270
|
+
|
|
150
271
|
# save model weights/files (framework-specific)
|
|
151
272
|
self._save_pretrained(save_directory)
|
|
152
273
|
|
|
153
|
-
# save config (if provided)
|
|
274
|
+
# save config (if provided and if not serialized yet in `_save_pretrained`)
|
|
154
275
|
if config is None:
|
|
155
|
-
config = self.
|
|
276
|
+
config = self._hub_mixin_config
|
|
156
277
|
if config is not None:
|
|
157
278
|
if is_dataclass(config):
|
|
158
279
|
config = asdict(config) # type: ignore[arg-type]
|
|
159
|
-
|
|
280
|
+
if not config_path.exists():
|
|
281
|
+
config_str = json.dumps(config, sort_keys=True, indent=2)
|
|
282
|
+
config_path.write_text(config_str)
|
|
283
|
+
|
|
284
|
+
# save model card
|
|
285
|
+
model_card_path = save_directory / "README.md"
|
|
286
|
+
if not model_card_path.exists(): # do not overwrite if already exists
|
|
287
|
+
self.generate_model_card().save(save_directory / "README.md")
|
|
160
288
|
|
|
161
289
|
# push to the Hub if required
|
|
162
290
|
if push_to_hub:
|
|
@@ -246,31 +374,44 @@ class ModelHubMixin:
|
|
|
246
374
|
except HfHubHTTPError as e:
|
|
247
375
|
logger.info(f"{CONFIG_NAME} not found on the HuggingFace Hub: {str(e)}")
|
|
248
376
|
|
|
377
|
+
# Read config
|
|
249
378
|
config = None
|
|
250
379
|
if config_file is not None:
|
|
251
|
-
# Read config
|
|
252
380
|
with open(config_file, "r", encoding="utf-8") as f:
|
|
253
381
|
config = json.load(f)
|
|
254
382
|
|
|
255
|
-
#
|
|
256
|
-
|
|
257
|
-
|
|
383
|
+
# Populate model_kwargs from config
|
|
384
|
+
for param in cls._hub_mixin_init_parameters.values():
|
|
385
|
+
if param.name not in model_kwargs and param.name in config:
|
|
386
|
+
model_kwargs[param.name] = config[param.name]
|
|
387
|
+
|
|
388
|
+
# Check if `config` argument was passed at init
|
|
389
|
+
if "config" in cls._hub_mixin_init_parameters:
|
|
258
390
|
# Check if `config` argument is a dataclass
|
|
259
|
-
config_annotation =
|
|
391
|
+
config_annotation = cls._hub_mixin_init_parameters["config"].annotation
|
|
260
392
|
if config_annotation is inspect.Parameter.empty:
|
|
261
393
|
pass # no annotation
|
|
262
394
|
elif is_dataclass(config_annotation):
|
|
263
|
-
config = config_annotation
|
|
395
|
+
config = _load_dataclass(config_annotation, config)
|
|
264
396
|
else:
|
|
265
397
|
# if Optional/Union annotation => check if a dataclass is in the Union
|
|
266
398
|
for _sub_annotation in get_args(config_annotation):
|
|
267
399
|
if is_dataclass(_sub_annotation):
|
|
268
|
-
config = _sub_annotation
|
|
400
|
+
config = _load_dataclass(_sub_annotation, config)
|
|
269
401
|
break
|
|
270
402
|
|
|
271
403
|
# Forward config to model initialization
|
|
272
404
|
model_kwargs["config"] = config
|
|
273
405
|
|
|
406
|
+
if any(param.kind == inspect.Parameter.VAR_KEYWORD for param in cls._hub_mixin_init_parameters.values()):
|
|
407
|
+
for key, value in config.items():
|
|
408
|
+
if key not in model_kwargs:
|
|
409
|
+
model_kwargs[key] = value
|
|
410
|
+
|
|
411
|
+
# Finally, also inject if `_from_pretrained` expects it
|
|
412
|
+
if cls._hub_mixin_inject_config:
|
|
413
|
+
model_kwargs["config"] = config
|
|
414
|
+
|
|
274
415
|
instance = cls._from_pretrained(
|
|
275
416
|
model_id=str(model_id),
|
|
276
417
|
revision=revision,
|
|
@@ -285,8 +426,8 @@ class ModelHubMixin:
|
|
|
285
426
|
|
|
286
427
|
# Implicitly set the config as instance attribute if not already set by the class
|
|
287
428
|
# This way `config` will be available when calling `save_pretrained` or `push_to_hub`.
|
|
288
|
-
if config is not None and instance
|
|
289
|
-
instance.
|
|
429
|
+
if config is not None and (getattr(instance, "_hub_mixin_config", None) in (None, {})):
|
|
430
|
+
instance._hub_mixin_config = config
|
|
290
431
|
|
|
291
432
|
return instance
|
|
292
433
|
|
|
@@ -415,6 +556,13 @@ class ModelHubMixin:
|
|
|
415
556
|
delete_patterns=delete_patterns,
|
|
416
557
|
)
|
|
417
558
|
|
|
559
|
+
def generate_model_card(self, *args, **kwargs) -> ModelCard:
|
|
560
|
+
card = ModelCard.from_template(
|
|
561
|
+
card_data=ModelCardData(**asdict(self._hub_mixin_info)),
|
|
562
|
+
template_str=DEFAULT_MODEL_CARD,
|
|
563
|
+
)
|
|
564
|
+
return card
|
|
565
|
+
|
|
418
566
|
|
|
419
567
|
class PyTorchModelHubMixin(ModelHubMixin):
|
|
420
568
|
"""
|
|
@@ -425,26 +573,26 @@ class PyTorchModelHubMixin(ModelHubMixin):
|
|
|
425
573
|
Example:
|
|
426
574
|
|
|
427
575
|
```python
|
|
428
|
-
>>> from dataclasses import dataclass
|
|
429
576
|
>>> import torch
|
|
430
577
|
>>> import torch.nn as nn
|
|
431
578
|
>>> from huggingface_hub import PyTorchModelHubMixin
|
|
432
579
|
|
|
433
|
-
>>>
|
|
434
|
-
...
|
|
435
|
-
...
|
|
436
|
-
...
|
|
437
|
-
...
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
...
|
|
580
|
+
>>> class MyModel(
|
|
581
|
+
... nn.Module,
|
|
582
|
+
... PyTorchModelHubMixin,
|
|
583
|
+
... library_name="keras-nlp",
|
|
584
|
+
... repo_url="https://github.com/keras-team/keras-nlp",
|
|
585
|
+
... docs_url="https://keras.io/keras_nlp/",
|
|
586
|
+
... # ^ optional metadata to generate model card
|
|
587
|
+
... ):
|
|
588
|
+
... def __init__(self, hidden_size: int = 512, vocab_size: int = 30000, output_size: int = 4):
|
|
441
589
|
... super().__init__()
|
|
442
|
-
... self.param = nn.Parameter(torch.rand(
|
|
443
|
-
... self.linear = nn.Linear(
|
|
590
|
+
... self.param = nn.Parameter(torch.rand(hidden_size, vocab_size))
|
|
591
|
+
... self.linear = nn.Linear(output_size, vocab_size)
|
|
444
592
|
|
|
445
593
|
... def forward(self, x):
|
|
446
594
|
... return self.linear(x + self.param)
|
|
447
|
-
>>> model = MyModel()
|
|
595
|
+
>>> model = MyModel(hidden_size=256)
|
|
448
596
|
|
|
449
597
|
# Save model weights to local directory
|
|
450
598
|
>>> model.save_pretrained("my-awesome-model")
|
|
@@ -454,13 +602,21 @@ class PyTorchModelHubMixin(ModelHubMixin):
|
|
|
454
602
|
|
|
455
603
|
# Download and initialize weights from the Hub
|
|
456
604
|
>>> model = MyModel.from_pretrained("username/my-awesome-model")
|
|
605
|
+
>>> model.hidden_size
|
|
606
|
+
256
|
|
457
607
|
```
|
|
458
608
|
"""
|
|
459
609
|
|
|
610
|
+
def __init_subclass__(cls, *args, tags: Optional[List[str]] = None, **kwargs) -> None:
|
|
611
|
+
tags = tags or []
|
|
612
|
+
tags.append("pytorch_model_hub_mixin")
|
|
613
|
+
kwargs["tags"] = tags
|
|
614
|
+
return super().__init_subclass__(*args, **kwargs)
|
|
615
|
+
|
|
460
616
|
def _save_pretrained(self, save_directory: Path) -> None:
|
|
461
617
|
"""Save weights from a Pytorch model to a local directory."""
|
|
462
618
|
model_to_save = self.module if hasattr(self, "module") else self # type: ignore
|
|
463
|
-
|
|
619
|
+
save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))
|
|
464
620
|
|
|
465
621
|
@classmethod
|
|
466
622
|
def _from_pretrained(
|
|
@@ -521,10 +677,23 @@ class PyTorchModelHubMixin(ModelHubMixin):
|
|
|
521
677
|
|
|
522
678
|
@classmethod
|
|
523
679
|
def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
|
|
680
|
+
load_model_as_safetensor(model, model_file, strict=strict) # type: ignore [arg-type]
|
|
681
|
+
if map_location != "cpu":
|
|
682
|
+
# TODO: remove this once https://github.com/huggingface/safetensors/pull/449 is merged.
|
|
683
|
+
logger.warning(
|
|
684
|
+
"Loading model weights on other devices than 'cpu' is not supported natively."
|
|
685
|
+
" This means that the model is loaded on 'cpu' first and then copied to the device."
|
|
686
|
+
" This leads to a slower loading time."
|
|
687
|
+
" Support for loading directly on other devices is planned to be added in future releases."
|
|
688
|
+
" See https://github.com/huggingface/huggingface_hub/pull/2086 for more details."
|
|
689
|
+
)
|
|
690
|
+
model.to(map_location) # type: ignore [attr-defined]
|
|
530
691
|
return model
|
|
692
|
+
|
|
693
|
+
|
|
694
|
+
def _load_dataclass(datacls: Type["DataclassInstance"], data: dict) -> "DataclassInstance":
|
|
695
|
+
"""Load a dataclass instance from a dictionary.
|
|
696
|
+
|
|
697
|
+
Fields not expected by the dataclass are ignored.
|
|
698
|
+
"""
|
|
699
|
+
return datacls(**{k: v for k, v in data.items() if k in datacls.__dataclass_fields__})
|