huggingface-hub 0.22.1__py3-none-any.whl → 0.23.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of huggingface-hub might be problematic. See the package registry's advisory page for more details.
- huggingface_hub/__init__.py +51 -19
- huggingface_hub/_commit_api.py +10 -9
- huggingface_hub/_commit_scheduler.py +2 -2
- huggingface_hub/_inference_endpoints.py +10 -17
- huggingface_hub/_local_folder.py +229 -0
- huggingface_hub/_login.py +4 -3
- huggingface_hub/_multi_commits.py +1 -1
- huggingface_hub/_snapshot_download.py +16 -38
- huggingface_hub/_tensorboard_logger.py +16 -6
- huggingface_hub/_webhooks_payload.py +22 -1
- huggingface_hub/_webhooks_server.py +24 -20
- huggingface_hub/commands/download.py +11 -34
- huggingface_hub/commands/huggingface_cli.py +2 -0
- huggingface_hub/commands/tag.py +159 -0
- huggingface_hub/constants.py +3 -5
- huggingface_hub/errors.py +58 -0
- huggingface_hub/file_download.py +545 -376
- huggingface_hub/hf_api.py +758 -629
- huggingface_hub/hf_file_system.py +14 -5
- huggingface_hub/hub_mixin.py +127 -43
- huggingface_hub/inference/_client.py +402 -183
- huggingface_hub/inference/_common.py +19 -29
- huggingface_hub/inference/_generated/_async_client.py +402 -184
- huggingface_hub/inference/_generated/types/__init__.py +23 -6
- huggingface_hub/inference/_generated/types/chat_completion.py +197 -43
- huggingface_hub/inference/_generated/types/text_generation.py +57 -79
- huggingface_hub/inference/_templating.py +2 -4
- huggingface_hub/keras_mixin.py +0 -3
- huggingface_hub/lfs.py +16 -4
- huggingface_hub/repository.py +1 -0
- huggingface_hub/utils/__init__.py +19 -6
- huggingface_hub/utils/_fixes.py +1 -0
- huggingface_hub/utils/_headers.py +2 -4
- huggingface_hub/utils/_http.py +16 -5
- huggingface_hub/utils/_paths.py +13 -1
- huggingface_hub/utils/_runtime.py +10 -0
- huggingface_hub/utils/_safetensors.py +0 -13
- huggingface_hub/utils/_validators.py +2 -7
- huggingface_hub/utils/tqdm.py +124 -46
- {huggingface_hub-0.22.1.dist-info → huggingface_hub-0.23.0.dist-info}/METADATA +5 -1
- {huggingface_hub-0.22.1.dist-info → huggingface_hub-0.23.0.dist-info}/RECORD +45 -43
- {huggingface_hub-0.22.1.dist-info → huggingface_hub-0.23.0.dist-info}/LICENSE +0 -0
- {huggingface_hub-0.22.1.dist-info → huggingface_hub-0.23.0.dist-info}/WHEEL +0 -0
- {huggingface_hub-0.22.1.dist-info → huggingface_hub-0.23.0.dist-info}/entry_points.txt +0 -0
- {huggingface_hub-0.22.1.dist-info → huggingface_hub-0.23.0.dist-info}/top_level.txt +0 -0
|
@@ -1,4 +1,3 @@
|
|
|
1
|
-
import copy
|
|
2
1
|
import os
|
|
3
2
|
import re
|
|
4
3
|
import tempfile
|
|
@@ -19,6 +18,8 @@ from ._commit_api import CommitOperationCopy, CommitOperationDelete
|
|
|
19
18
|
from .constants import (
|
|
20
19
|
DEFAULT_REVISION,
|
|
21
20
|
ENDPOINT,
|
|
21
|
+
HF_HUB_DOWNLOAD_TIMEOUT,
|
|
22
|
+
HF_HUB_ETAG_TIMEOUT,
|
|
22
23
|
REPO_TYPE_MODEL,
|
|
23
24
|
REPO_TYPES_MAPPING,
|
|
24
25
|
REPO_TYPES_URL_PREFIXES,
|
|
@@ -123,7 +124,7 @@ class HfFileSystem(fsspec.AbstractFileSystem):
|
|
|
123
124
|
) -> Tuple[bool, Optional[Exception]]:
|
|
124
125
|
if (repo_type, repo_id, revision) not in self._repo_and_revision_exists_cache:
|
|
125
126
|
try:
|
|
126
|
-
self._api.repo_info(repo_id, revision=revision, repo_type=repo_type)
|
|
127
|
+
self._api.repo_info(repo_id, revision=revision, repo_type=repo_type, timeout=HF_HUB_ETAG_TIMEOUT)
|
|
127
128
|
except (RepositoryNotFoundError, HFValidationError) as e:
|
|
128
129
|
self._repo_and_revision_exists_cache[(repo_type, repo_id, revision)] = False, e
|
|
129
130
|
self._repo_and_revision_exists_cache[(repo_type, repo_id, None)] = False, e
|
|
@@ -397,7 +398,7 @@ class HfFileSystem(fsspec.AbstractFileSystem):
|
|
|
397
398
|
parent_path = self._parent(cache_path_info["name"])
|
|
398
399
|
self.dircache.setdefault(parent_path, []).append(cache_path_info)
|
|
399
400
|
out.append(cache_path_info)
|
|
400
|
-
return
|
|
401
|
+
return out
|
|
401
402
|
|
|
402
403
|
def glob(self, path, **kwargs):
|
|
403
404
|
# Set expand_info=False by default to get a x10 speed boost
|
|
@@ -561,7 +562,7 @@ class HfFileSystem(fsspec.AbstractFileSystem):
|
|
|
561
562
|
if not expand_info:
|
|
562
563
|
out = {k: out[k] for k in ["name", "size", "type"]}
|
|
563
564
|
assert out is not None
|
|
564
|
-
return
|
|
565
|
+
return out
|
|
565
566
|
|
|
566
567
|
def exists(self, path, **kwargs):
|
|
567
568
|
"""Is there a file at the given path"""
|
|
@@ -701,7 +702,13 @@ class HfFileSystemFile(fsspec.spec.AbstractBufferedFile):
|
|
|
701
702
|
repo_type=self.resolved_path.repo_type,
|
|
702
703
|
endpoint=self.fs.endpoint,
|
|
703
704
|
)
|
|
704
|
-
r = http_backoff(
|
|
705
|
+
r = http_backoff(
|
|
706
|
+
"GET",
|
|
707
|
+
url,
|
|
708
|
+
headers=headers,
|
|
709
|
+
retry_on_status_codes=(502, 503, 504),
|
|
710
|
+
timeout=HF_HUB_DOWNLOAD_TIMEOUT,
|
|
711
|
+
)
|
|
705
712
|
hf_raise_for_status(r)
|
|
706
713
|
return r.content
|
|
707
714
|
|
|
@@ -800,6 +807,7 @@ class HfFileSystemStreamFile(fsspec.spec.AbstractBufferedFile):
|
|
|
800
807
|
headers=self.fs._api._build_hf_headers(),
|
|
801
808
|
retry_on_status_codes=(502, 503, 504),
|
|
802
809
|
stream=True,
|
|
810
|
+
timeout=HF_HUB_DOWNLOAD_TIMEOUT,
|
|
803
811
|
)
|
|
804
812
|
hf_raise_for_status(self.response)
|
|
805
813
|
try:
|
|
@@ -821,6 +829,7 @@ class HfFileSystemStreamFile(fsspec.spec.AbstractBufferedFile):
|
|
|
821
829
|
headers={"Range": "bytes=%d-" % self.loc, **self.fs._api._build_hf_headers()},
|
|
822
830
|
retry_on_status_codes=(502, 503, 504),
|
|
823
831
|
stream=True,
|
|
832
|
+
timeout=HF_HUB_DOWNLOAD_TIMEOUT,
|
|
824
833
|
)
|
|
825
834
|
hf_raise_for_status(self.response)
|
|
826
835
|
try:
|
huggingface_hub/hub_mixin.py
CHANGED
|
@@ -3,7 +3,7 @@ import json
|
|
|
3
3
|
import os
|
|
4
4
|
from dataclasses import asdict, dataclass, is_dataclass
|
|
5
5
|
from pathlib import Path
|
|
6
|
-
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, TypeVar, Union, get_args
|
|
6
|
+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union, get_args
|
|
7
7
|
|
|
8
8
|
from .constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME, SAFETENSORS_SINGLE_FILE
|
|
9
9
|
from .file_download import hf_hub_download
|
|
@@ -19,7 +19,6 @@ from .utils import (
|
|
|
19
19
|
logging,
|
|
20
20
|
validate_hf_hub_args,
|
|
21
21
|
)
|
|
22
|
-
from .utils._deprecation import _deprecate_arguments
|
|
23
22
|
|
|
24
23
|
|
|
25
24
|
if TYPE_CHECKING:
|
|
@@ -37,6 +36,12 @@ logger = logging.get_logger(__name__)
|
|
|
37
36
|
|
|
38
37
|
# Generic variable that is either ModelHubMixin or a subclass thereof
|
|
39
38
|
T = TypeVar("T", bound="ModelHubMixin")
|
|
39
|
+
# Generic variable to represent an args type
|
|
40
|
+
ARGS_T = TypeVar("ARGS_T")
|
|
41
|
+
ENCODER_T = Callable[[ARGS_T], Any]
|
|
42
|
+
DECODER_T = Callable[[Any], ARGS_T]
|
|
43
|
+
CODER_T = Tuple[ENCODER_T, DECODER_T]
|
|
44
|
+
|
|
40
45
|
|
|
41
46
|
DEFAULT_MODEL_CARD = """
|
|
42
47
|
---
|
|
@@ -45,16 +50,16 @@ DEFAULT_MODEL_CARD = """
|
|
|
45
50
|
{{ card_data }}
|
|
46
51
|
---
|
|
47
52
|
|
|
48
|
-
This model has been pushed to the Hub using
|
|
49
|
-
-
|
|
53
|
+
This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
|
|
54
|
+
- Library: {{ repo_url | default("[More Information Needed]", true) }}
|
|
50
55
|
- Docs: {{ docs_url | default("[More Information Needed]", true) }}
|
|
51
56
|
"""
|
|
52
57
|
|
|
53
58
|
|
|
54
59
|
@dataclass
|
|
55
60
|
class MixinInfo:
|
|
56
|
-
|
|
57
|
-
|
|
61
|
+
model_card_template: str
|
|
62
|
+
model_card_data: ModelCardData
|
|
58
63
|
repo_url: Optional[str] = None
|
|
59
64
|
docs_url: Optional[str] = None
|
|
60
65
|
|
|
@@ -71,15 +76,37 @@ class ModelHubMixin:
|
|
|
71
76
|
`__init__` but to the class definition itself. This is useful to define metadata about the library integrating
|
|
72
77
|
[`ModelHubMixin`].
|
|
73
78
|
|
|
79
|
+
For more details on how to integrate the mixin with your library, checkout the [integration guide](../guides/integrations).
|
|
80
|
+
|
|
74
81
|
Args:
|
|
75
|
-
library_name (`str`, *optional*):
|
|
76
|
-
Name of the library integrating ModelHubMixin. Used to generate model card.
|
|
77
|
-
tags (`List[str]`, *optional*):
|
|
78
|
-
Tags to be added to the model card. Used to generate model card.
|
|
79
82
|
repo_url (`str`, *optional*):
|
|
80
83
|
URL of the library repository. Used to generate model card.
|
|
81
84
|
docs_url (`str`, *optional*):
|
|
82
85
|
URL of the library documentation. Used to generate model card.
|
|
86
|
+
model_card_template (`str`, *optional*):
|
|
87
|
+
Template of the model card. Used to generate model card. Defaults to a generic template.
|
|
88
|
+
languages (`List[str]`, *optional*):
|
|
89
|
+
Languages supported by the library. Used to generate model card.
|
|
90
|
+
library_name (`str`, *optional*):
|
|
91
|
+
Name of the library integrating ModelHubMixin. Used to generate model card.
|
|
92
|
+
license (`str`, *optional*):
|
|
93
|
+
License of the library integrating ModelHubMixin. Used to generate model card.
|
|
94
|
+
E.g: "apache-2.0"
|
|
95
|
+
license_name (`str`, *optional*):
|
|
96
|
+
Name of the library integrating ModelHubMixin. Used to generate model card.
|
|
97
|
+
Only used if `license` is set to `other`.
|
|
98
|
+
E.g: "coqui-public-model-license".
|
|
99
|
+
license_link (`str`, *optional*):
|
|
100
|
+
URL to the license of the library integrating ModelHubMixin. Used to generate model card.
|
|
101
|
+
Only used if `license` is set to `other` and `license_name` is set.
|
|
102
|
+
E.g: "https://coqui.ai/cpml".
|
|
103
|
+
pipeline_tag (`str`, *optional*):
|
|
104
|
+
Tag of the pipeline. Used to generate model card. E.g. "text-classification".
|
|
105
|
+
tags (`List[str]`, *optional*):
|
|
106
|
+
Tags to be added to the model card. Used to generate model card. E.g. ["x-custom-tag", "arxiv:2304.12244"]
|
|
107
|
+
coders (`Dict[Type, Tuple[Callable, Callable]]`, *optional*):
|
|
108
|
+
Dictionary of custom types and their encoders/decoders. Used to encode/decode arguments that are not
|
|
109
|
+
jsonable by default. E.g dataclasses, argparse.Namespace, OmegaConf, etc.
|
|
83
110
|
|
|
84
111
|
Example:
|
|
85
112
|
|
|
@@ -90,7 +117,7 @@ class ModelHubMixin:
|
|
|
90
117
|
>>> class MyCustomModel(
|
|
91
118
|
... ModelHubMixin,
|
|
92
119
|
... library_name="my-library",
|
|
93
|
-
... tags=["x-custom-tag"],
|
|
120
|
+
... tags=["x-custom-tag", "arxiv:2304.12244"],
|
|
94
121
|
... repo_url="https://github.com/huggingface/my-cool-library",
|
|
95
122
|
... docs_url="https://huggingface.co/docs/my-cool-library",
|
|
96
123
|
... # ^ optional metadata to generate model card
|
|
@@ -110,7 +137,7 @@ class ModelHubMixin:
|
|
|
110
137
|
... pretrained_model_name_or_path: Union[str, Path],
|
|
111
138
|
... *,
|
|
112
139
|
... force_download: bool = False,
|
|
113
|
-
... resume_download: bool =
|
|
140
|
+
... resume_download: Optional[bool] = None,
|
|
114
141
|
... proxies: Optional[Dict] = None,
|
|
115
142
|
... token: Optional[Union[str, bool]] = None,
|
|
116
143
|
... cache_dir: Optional[Union[str, Path]] = None,
|
|
@@ -131,8 +158,8 @@ class ModelHubMixin:
|
|
|
131
158
|
|
|
132
159
|
# Download and initialize weights from the Hub
|
|
133
160
|
>>> reloaded_model = MyCustomModel.from_pretrained("username/my-awesome-model")
|
|
134
|
-
>>> reloaded_model.
|
|
135
|
-
|
|
161
|
+
>>> reloaded_model.size
|
|
162
|
+
256
|
|
136
163
|
|
|
137
164
|
# Model card has been correctly populated
|
|
138
165
|
>>> from huggingface_hub import ModelCard
|
|
@@ -148,18 +175,36 @@ class ModelHubMixin:
|
|
|
148
175
|
# ^ optional config attribute automatically set in `from_pretrained`
|
|
149
176
|
_hub_mixin_info: MixinInfo
|
|
150
177
|
# ^ information about the library integrating ModelHubMixin (used to generate model card)
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
178
|
+
_hub_mixin_inject_config: bool # whether `_from_pretrained` expects `config` or not
|
|
179
|
+
_hub_mixin_init_parameters: Dict[str, inspect.Parameter] # __init__ parameters
|
|
180
|
+
_hub_mixin_jsonable_default_values: Dict[str, Any] # default values for __init__ parameters
|
|
181
|
+
_hub_mixin_jsonable_custom_types: Tuple[Type, ...] # custom types that can be encoded/decoded
|
|
182
|
+
_hub_mixin_coders: Dict[Type, CODER_T] # encoders/decoders for custom types
|
|
154
183
|
# ^ internal values to handle config
|
|
155
184
|
|
|
156
185
|
def __init_subclass__(
|
|
157
186
|
cls,
|
|
158
187
|
*,
|
|
159
|
-
|
|
160
|
-
tags: Optional[List[str]] = None,
|
|
188
|
+
# Generic info for model card
|
|
161
189
|
repo_url: Optional[str] = None,
|
|
162
190
|
docs_url: Optional[str] = None,
|
|
191
|
+
# Model card template
|
|
192
|
+
model_card_template: str = DEFAULT_MODEL_CARD,
|
|
193
|
+
# Model card metadata
|
|
194
|
+
languages: Optional[List[str]] = None,
|
|
195
|
+
library_name: Optional[str] = None,
|
|
196
|
+
license: Optional[str] = None,
|
|
197
|
+
license_name: Optional[str] = None,
|
|
198
|
+
license_link: Optional[str] = None,
|
|
199
|
+
pipeline_tag: Optional[str] = None,
|
|
200
|
+
tags: Optional[List[str]] = None,
|
|
201
|
+
# How to encode/decode arguments with custom type into a JSON config?
|
|
202
|
+
coders: Optional[
|
|
203
|
+
Dict[Type, CODER_T]
|
|
204
|
+
# Key is a type.
|
|
205
|
+
# Value is a tuple (encoder, decoder).
|
|
206
|
+
# Example: {MyCustomType: (lambda x: x.value, lambda data: MyCustomType(data))}
|
|
207
|
+
] = None,
|
|
163
208
|
) -> None:
|
|
164
209
|
"""Inspect __init__ signature only once when subclassing + handle modelcard."""
|
|
165
210
|
super().__init_subclass__()
|
|
@@ -168,18 +213,30 @@ class ModelHubMixin:
|
|
|
168
213
|
tags = tags or []
|
|
169
214
|
tags.append("model_hub_mixin")
|
|
170
215
|
cls._hub_mixin_info = MixinInfo(
|
|
171
|
-
|
|
172
|
-
tags=tags,
|
|
216
|
+
model_card_template=model_card_template,
|
|
173
217
|
repo_url=repo_url,
|
|
174
218
|
docs_url=docs_url,
|
|
219
|
+
model_card_data=ModelCardData(
|
|
220
|
+
languages=languages,
|
|
221
|
+
library_name=library_name,
|
|
222
|
+
license=license,
|
|
223
|
+
license_name=license_name,
|
|
224
|
+
license_link=license_link,
|
|
225
|
+
pipeline_tag=pipeline_tag,
|
|
226
|
+
tags=tags,
|
|
227
|
+
),
|
|
175
228
|
)
|
|
176
229
|
|
|
230
|
+
# Handle encoders/decoders for args
|
|
231
|
+
cls._hub_mixin_coders = coders or {}
|
|
232
|
+
cls._hub_mixin_jsonable_custom_types = tuple(cls._hub_mixin_coders.keys())
|
|
233
|
+
|
|
177
234
|
# Inspect __init__ signature to handle config
|
|
178
235
|
cls._hub_mixin_init_parameters = dict(inspect.signature(cls.__init__).parameters)
|
|
179
236
|
cls._hub_mixin_jsonable_default_values = {
|
|
180
|
-
param.name: param.default
|
|
237
|
+
param.name: cls._encode_arg(param.default)
|
|
181
238
|
for param in cls._hub_mixin_init_parameters.values()
|
|
182
|
-
if param.default is not inspect.Parameter.empty and
|
|
239
|
+
if param.default is not inspect.Parameter.empty and cls._is_jsonable(param.default)
|
|
183
240
|
}
|
|
184
241
|
cls._hub_mixin_inject_config = "config" in inspect.signature(cls._from_pretrained).parameters
|
|
185
242
|
|
|
@@ -220,7 +277,11 @@ class ModelHubMixin:
|
|
|
220
277
|
# default values
|
|
221
278
|
**cls._hub_mixin_jsonable_default_values,
|
|
222
279
|
# passed values
|
|
223
|
-
**{
|
|
280
|
+
**{
|
|
281
|
+
key: cls._encode_arg(value) # Encode custom types as jsonable value
|
|
282
|
+
for key, value in passed_values.items()
|
|
283
|
+
if instance._is_jsonable(value) # Only if jsonable or we have a custom encoder
|
|
284
|
+
},
|
|
224
285
|
}
|
|
225
286
|
init_config.pop("config", {})
|
|
226
287
|
|
|
@@ -234,6 +295,29 @@ class ModelHubMixin:
|
|
|
234
295
|
instance._hub_mixin_config = init_config
|
|
235
296
|
return instance
|
|
236
297
|
|
|
298
|
+
@classmethod
|
|
299
|
+
def _is_jsonable(cls, value: Any) -> bool:
|
|
300
|
+
"""Check if a value is JSON serializable."""
|
|
301
|
+
if isinstance(value, cls._hub_mixin_jsonable_custom_types):
|
|
302
|
+
return True
|
|
303
|
+
return is_jsonable(value)
|
|
304
|
+
|
|
305
|
+
@classmethod
|
|
306
|
+
def _encode_arg(cls, arg: Any) -> Any:
|
|
307
|
+
"""Encode an argument into a JSON serializable format."""
|
|
308
|
+
for type_, (encoder, _) in cls._hub_mixin_coders.items():
|
|
309
|
+
if isinstance(arg, type_):
|
|
310
|
+
return encoder(arg)
|
|
311
|
+
return arg
|
|
312
|
+
|
|
313
|
+
@classmethod
|
|
314
|
+
def _decode_arg(cls, expected_type: Type[ARGS_T], value: Any) -> ARGS_T:
|
|
315
|
+
"""Decode a JSON serializable value into an argument."""
|
|
316
|
+
for type_, (_, decoder) in cls._hub_mixin_coders.items():
|
|
317
|
+
if issubclass(expected_type, type_):
|
|
318
|
+
return decoder(value)
|
|
319
|
+
return value
|
|
320
|
+
|
|
237
321
|
def save_pretrained(
|
|
238
322
|
self,
|
|
239
323
|
save_directory: Union[str, Path],
|
|
@@ -258,6 +342,8 @@ class ModelHubMixin:
|
|
|
258
342
|
not provided.
|
|
259
343
|
kwargs:
|
|
260
344
|
Additional key word arguments passed along to the [`~ModelHubMixin.push_to_hub`] method.
|
|
345
|
+
Returns:
|
|
346
|
+
`str` or `None`: url of the commit on the Hub if `push_to_hub=True`, `None` otherwise.
|
|
261
347
|
"""
|
|
262
348
|
save_directory = Path(save_directory)
|
|
263
349
|
save_directory.mkdir(parents=True, exist_ok=True)
|
|
@@ -314,7 +400,7 @@ class ModelHubMixin:
|
|
|
314
400
|
pretrained_model_name_or_path: Union[str, Path],
|
|
315
401
|
*,
|
|
316
402
|
force_download: bool = False,
|
|
317
|
-
resume_download: bool =
|
|
403
|
+
resume_download: Optional[bool] = None,
|
|
318
404
|
proxies: Optional[Dict] = None,
|
|
319
405
|
token: Optional[Union[str, bool]] = None,
|
|
320
406
|
cache_dir: Optional[Union[str, Path]] = None,
|
|
@@ -336,8 +422,6 @@ class ModelHubMixin:
|
|
|
336
422
|
force_download (`bool`, *optional*, defaults to `False`):
|
|
337
423
|
Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding
|
|
338
424
|
the existing cache.
|
|
339
|
-
resume_download (`bool`, *optional*, defaults to `False`):
|
|
340
|
-
Whether to delete incompletely received files. Will attempt to resume the download if such a file exists.
|
|
341
425
|
proxies (`Dict[str, str]`, *optional*):
|
|
342
426
|
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
|
343
427
|
'http://hostname': 'foo.bar:4012'}`. The proxies are used on every request.
|
|
@@ -380,6 +464,13 @@ class ModelHubMixin:
|
|
|
380
464
|
with open(config_file, "r", encoding="utf-8") as f:
|
|
381
465
|
config = json.load(f)
|
|
382
466
|
|
|
467
|
+
# Decode custom types in config
|
|
468
|
+
for key, value in config.items():
|
|
469
|
+
if key in cls._hub_mixin_init_parameters:
|
|
470
|
+
expected_type = cls._hub_mixin_init_parameters[key].annotation
|
|
471
|
+
if expected_type is not inspect.Parameter.empty:
|
|
472
|
+
config[key] = cls._decode_arg(expected_type, value)
|
|
473
|
+
|
|
383
474
|
# Populate model_kwargs from config
|
|
384
475
|
for param in cls._hub_mixin_init_parameters.values():
|
|
385
476
|
if param.name not in model_kwargs and param.name in config:
|
|
@@ -445,7 +536,7 @@ class ModelHubMixin:
|
|
|
445
536
|
cache_dir: Optional[Union[str, Path]],
|
|
446
537
|
force_download: bool,
|
|
447
538
|
proxies: Optional[Dict],
|
|
448
|
-
resume_download: bool,
|
|
539
|
+
resume_download: Optional[bool],
|
|
449
540
|
local_files_only: bool,
|
|
450
541
|
token: Optional[Union[str, bool]],
|
|
451
542
|
**model_kwargs,
|
|
@@ -468,8 +559,6 @@ class ModelHubMixin:
|
|
|
468
559
|
force_download (`bool`, *optional*, defaults to `False`):
|
|
469
560
|
Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding
|
|
470
561
|
the existing cache.
|
|
471
|
-
resume_download (`bool`, *optional*, defaults to `False`):
|
|
472
|
-
Whether to delete incompletely received files. Will attempt to resume the download if such a file exists.
|
|
473
562
|
proxies (`Dict[str, str]`, *optional*):
|
|
474
563
|
A dictionary of proxy servers to use by protocol or endpoint (e.g., `{'http': 'foo.bar:3128',
|
|
475
564
|
'http://hostname': 'foo.bar:4012'}`).
|
|
@@ -485,11 +574,6 @@ class ModelHubMixin:
|
|
|
485
574
|
"""
|
|
486
575
|
raise NotImplementedError
|
|
487
576
|
|
|
488
|
-
@_deprecate_arguments(
|
|
489
|
-
version="0.23.0",
|
|
490
|
-
deprecated_args=["api_endpoint"],
|
|
491
|
-
custom_message="Use `HF_ENDPOINT` environment variable instead.",
|
|
492
|
-
)
|
|
493
577
|
@validate_hf_hub_args
|
|
494
578
|
def push_to_hub(
|
|
495
579
|
self,
|
|
@@ -504,8 +588,6 @@ class ModelHubMixin:
|
|
|
504
588
|
allow_patterns: Optional[Union[List[str], str]] = None,
|
|
505
589
|
ignore_patterns: Optional[Union[List[str], str]] = None,
|
|
506
590
|
delete_patterns: Optional[Union[List[str], str]] = None,
|
|
507
|
-
# TODO: remove once deprecated
|
|
508
|
-
api_endpoint: Optional[str] = None,
|
|
509
591
|
) -> str:
|
|
510
592
|
"""
|
|
511
593
|
Upload model checkpoint to the Hub.
|
|
@@ -523,8 +605,6 @@ class ModelHubMixin:
|
|
|
523
605
|
Message to commit while pushing.
|
|
524
606
|
private (`bool`, *optional*, defaults to `False`):
|
|
525
607
|
Whether the repository created should be private.
|
|
526
|
-
api_endpoint (`str`, *optional*):
|
|
527
|
-
The API endpoint to use when pushing the model to the hub.
|
|
528
608
|
token (`str`, *optional*):
|
|
529
609
|
The token to use as HTTP bearer authorization for remote files. By default, it will use the token
|
|
530
610
|
cached when running `huggingface-cli login`.
|
|
@@ -542,7 +622,7 @@ class ModelHubMixin:
|
|
|
542
622
|
Returns:
|
|
543
623
|
The url of the commit of your model in the given repository.
|
|
544
624
|
"""
|
|
545
|
-
api = HfApi(
|
|
625
|
+
api = HfApi(token=token)
|
|
546
626
|
repo_id = api.create_repo(repo_id=repo_id, private=private, exist_ok=True).repo_id
|
|
547
627
|
|
|
548
628
|
# Push the files to the repo in a single commit
|
|
@@ -563,8 +643,10 @@ class ModelHubMixin:
|
|
|
563
643
|
|
|
564
644
|
def generate_model_card(self, *args, **kwargs) -> ModelCard:
|
|
565
645
|
card = ModelCard.from_template(
|
|
566
|
-
card_data=
|
|
567
|
-
template_str=
|
|
646
|
+
card_data=self._hub_mixin_info.model_card_data,
|
|
647
|
+
template_str=self._hub_mixin_info.model_card_template,
|
|
648
|
+
repo_url=self._hub_mixin_info.repo_url,
|
|
649
|
+
docs_url=self._hub_mixin_info.docs_url,
|
|
568
650
|
)
|
|
569
651
|
return card
|
|
570
652
|
|
|
@@ -575,6 +657,8 @@ class PyTorchModelHubMixin(ModelHubMixin):
|
|
|
575
657
|
is set in evaluation mode by default using `model.eval()` (dropout modules are deactivated). To train the model,
|
|
576
658
|
you should first set it back in training mode with `model.train()`.
|
|
577
659
|
|
|
660
|
+
See [`ModelHubMixin`] for more details on how to use the mixin.
|
|
661
|
+
|
|
578
662
|
Example:
|
|
579
663
|
|
|
580
664
|
```python
|
|
@@ -632,7 +716,7 @@ class PyTorchModelHubMixin(ModelHubMixin):
|
|
|
632
716
|
cache_dir: Optional[Union[str, Path]],
|
|
633
717
|
force_download: bool,
|
|
634
718
|
proxies: Optional[Dict],
|
|
635
|
-
resume_download: bool,
|
|
719
|
+
resume_download: Optional[bool],
|
|
636
720
|
local_files_only: bool,
|
|
637
721
|
token: Union[str, bool, None],
|
|
638
722
|
map_location: str = "cpu",
|