huggingface-hub 0.21.4__py3-none-any.whl → 0.22.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of huggingface-hub might be problematic. Click here for more details.
- huggingface_hub/__init__.py +217 -1
- huggingface_hub/_commit_api.py +14 -15
- huggingface_hub/_inference_endpoints.py +12 -11
- huggingface_hub/_login.py +1 -0
- huggingface_hub/_multi_commits.py +1 -0
- huggingface_hub/_snapshot_download.py +9 -1
- huggingface_hub/_tensorboard_logger.py +1 -0
- huggingface_hub/_webhooks_payload.py +1 -0
- huggingface_hub/_webhooks_server.py +1 -0
- huggingface_hub/commands/_cli_utils.py +1 -0
- huggingface_hub/commands/delete_cache.py +1 -0
- huggingface_hub/commands/download.py +1 -0
- huggingface_hub/commands/env.py +1 -0
- huggingface_hub/commands/scan_cache.py +1 -0
- huggingface_hub/commands/upload.py +1 -0
- huggingface_hub/community.py +1 -0
- huggingface_hub/constants.py +3 -1
- huggingface_hub/errors.py +38 -0
- huggingface_hub/file_download.py +24 -24
- huggingface_hub/hf_api.py +47 -35
- huggingface_hub/hub_mixin.py +210 -54
- huggingface_hub/inference/_client.py +554 -239
- huggingface_hub/inference/_common.py +195 -41
- huggingface_hub/inference/_generated/_async_client.py +558 -239
- huggingface_hub/inference/_generated/types/__init__.py +115 -0
- huggingface_hub/inference/_generated/types/audio_classification.py +43 -0
- huggingface_hub/inference/_generated/types/audio_to_audio.py +31 -0
- huggingface_hub/inference/_generated/types/automatic_speech_recognition.py +116 -0
- huggingface_hub/inference/_generated/types/base.py +149 -0
- huggingface_hub/inference/_generated/types/chat_completion.py +106 -0
- huggingface_hub/inference/_generated/types/depth_estimation.py +29 -0
- huggingface_hub/inference/_generated/types/document_question_answering.py +85 -0
- huggingface_hub/inference/_generated/types/feature_extraction.py +19 -0
- huggingface_hub/inference/_generated/types/fill_mask.py +50 -0
- huggingface_hub/inference/_generated/types/image_classification.py +43 -0
- huggingface_hub/inference/_generated/types/image_segmentation.py +52 -0
- huggingface_hub/inference/_generated/types/image_to_image.py +55 -0
- huggingface_hub/inference/_generated/types/image_to_text.py +105 -0
- huggingface_hub/inference/_generated/types/object_detection.py +55 -0
- huggingface_hub/inference/_generated/types/question_answering.py +77 -0
- huggingface_hub/inference/_generated/types/sentence_similarity.py +28 -0
- huggingface_hub/inference/_generated/types/summarization.py +46 -0
- huggingface_hub/inference/_generated/types/table_question_answering.py +45 -0
- huggingface_hub/inference/_generated/types/text2text_generation.py +45 -0
- huggingface_hub/inference/_generated/types/text_classification.py +43 -0
- huggingface_hub/inference/_generated/types/text_generation.py +161 -0
- huggingface_hub/inference/_generated/types/text_to_audio.py +105 -0
- huggingface_hub/inference/_generated/types/text_to_image.py +57 -0
- huggingface_hub/inference/_generated/types/token_classification.py +53 -0
- huggingface_hub/inference/_generated/types/translation.py +46 -0
- huggingface_hub/inference/_generated/types/video_classification.py +47 -0
- huggingface_hub/inference/_generated/types/visual_question_answering.py +53 -0
- huggingface_hub/inference/_generated/types/zero_shot_classification.py +56 -0
- huggingface_hub/inference/_generated/types/zero_shot_image_classification.py +51 -0
- huggingface_hub/inference/_generated/types/zero_shot_object_detection.py +55 -0
- huggingface_hub/inference/_templating.py +105 -0
- huggingface_hub/inference/_types.py +4 -152
- huggingface_hub/keras_mixin.py +39 -17
- huggingface_hub/lfs.py +20 -8
- huggingface_hub/repocard.py +11 -3
- huggingface_hub/repocard_data.py +12 -2
- huggingface_hub/serialization/__init__.py +1 -0
- huggingface_hub/serialization/_base.py +1 -0
- huggingface_hub/serialization/_numpy.py +1 -0
- huggingface_hub/serialization/_tensorflow.py +1 -0
- huggingface_hub/serialization/_torch.py +1 -0
- huggingface_hub/utils/__init__.py +4 -1
- huggingface_hub/utils/_cache_manager.py +7 -0
- huggingface_hub/utils/_chunk_utils.py +1 -0
- huggingface_hub/utils/_datetime.py +1 -0
- huggingface_hub/utils/_errors.py +10 -1
- huggingface_hub/utils/_experimental.py +1 -0
- huggingface_hub/utils/_fixes.py +19 -3
- huggingface_hub/utils/_git_credential.py +1 -0
- huggingface_hub/utils/_headers.py +10 -3
- huggingface_hub/utils/_hf_folder.py +1 -0
- huggingface_hub/utils/_http.py +1 -0
- huggingface_hub/utils/_pagination.py +1 -0
- huggingface_hub/utils/_paths.py +1 -0
- huggingface_hub/utils/_runtime.py +22 -0
- huggingface_hub/utils/_subprocess.py +1 -0
- huggingface_hub/utils/_token.py +1 -0
- huggingface_hub/utils/_typing.py +29 -1
- huggingface_hub/utils/_validators.py +1 -0
- huggingface_hub/utils/endpoint_helpers.py +1 -0
- huggingface_hub/utils/logging.py +1 -1
- huggingface_hub/utils/sha.py +1 -0
- huggingface_hub/utils/tqdm.py +1 -0
- {huggingface_hub-0.21.4.dist-info → huggingface_hub-0.22.0rc0.dist-info}/METADATA +14 -15
- huggingface_hub-0.22.0rc0.dist-info/RECORD +113 -0
- {huggingface_hub-0.21.4.dist-info → huggingface_hub-0.22.0rc0.dist-info}/WHEEL +1 -1
- huggingface_hub/inference/_text_generation.py +0 -551
- huggingface_hub-0.21.4.dist-info/RECORD +0 -81
- {huggingface_hub-0.21.4.dist-info → huggingface_hub-0.22.0rc0.dist-info}/LICENSE +0 -0
- {huggingface_hub-0.21.4.dist-info → huggingface_hub-0.22.0rc0.dist-info}/entry_points.txt +0 -0
- {huggingface_hub-0.21.4.dist-info → huggingface_hub-0.22.0rc0.dist-info}/top_level.txt +0 -0
huggingface_hub/hub_mixin.py
CHANGED
|
@@ -1,17 +1,19 @@
|
|
|
1
1
|
import inspect
|
|
2
2
|
import json
|
|
3
3
|
import os
|
|
4
|
-
from dataclasses import asdict, is_dataclass
|
|
4
|
+
from dataclasses import asdict, dataclass, is_dataclass
|
|
5
5
|
from pathlib import Path
|
|
6
|
-
from typing import TYPE_CHECKING, Dict, List, Optional, Type, TypeVar, Union, get_args
|
|
6
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, TypeVar, Union, get_args
|
|
7
7
|
|
|
8
8
|
from .constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME, SAFETENSORS_SINGLE_FILE
|
|
9
9
|
from .file_download import hf_hub_download
|
|
10
10
|
from .hf_api import HfApi
|
|
11
|
+
from .repocard import ModelCard, ModelCardData
|
|
11
12
|
from .utils import (
|
|
12
13
|
EntryNotFoundError,
|
|
13
14
|
HfHubHTTPError,
|
|
14
15
|
SoftTemporaryDirectory,
|
|
16
|
+
is_jsonable,
|
|
15
17
|
is_safetensors_available,
|
|
16
18
|
is_torch_available,
|
|
17
19
|
logging,
|
|
@@ -36,6 +38,26 @@ logger = logging.get_logger(__name__)
|
|
|
36
38
|
# Generic variable that is either ModelHubMixin or a subclass thereof
|
|
37
39
|
T = TypeVar("T", bound="ModelHubMixin")
|
|
38
40
|
|
|
41
|
+
DEFAULT_MODEL_CARD = """
|
|
42
|
+
---
|
|
43
|
+
# For reference on model card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
|
|
44
|
+
# Doc / guide: https://huggingface.co/docs/hub/model-cards
|
|
45
|
+
{{ card_data }}
|
|
46
|
+
---
|
|
47
|
+
|
|
48
|
+
This model has been pushed to the Hub using **{{ library_name }}**:
|
|
49
|
+
- Repo: {{ repo_url | default("[More Information Needed]", true) }}
|
|
50
|
+
- Docs: {{ docs_url | default("[More Information Needed]", true) }}
|
|
51
|
+
"""
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
@dataclass
class MixinInfo:
    """Optional metadata describing the library that integrates the mixin.

    All four fields default to `None`. The field order is part of the
    positional constructor signature and must not be changed.
    """

    # Name of the integrating library.
    library_name: Optional[str] = None
    # Tags associated with the integrating library.
    tags: Optional[List[str]] = None
    # URL of the library repository.
    repo_url: Optional[str] = None
    # URL of the library documentation.
    docs_url: Optional[str] = None
|
|
60
|
+
|
|
39
61
|
|
|
40
62
|
class ModelHubMixin:
|
|
41
63
|
"""
|
|
@@ -45,21 +67,35 @@ class ModelHubMixin:
|
|
|
45
67
|
have to be overwritten in [`_from_pretrained`] and [`_save_pretrained`]. [`PyTorchModelHubMixin`] is a good example
|
|
46
68
|
of mixin integration with the Hub. Check out our [integration guide](../guides/integrations) for more instructions.
|
|
47
69
|
|
|
70
|
+
When inheriting from [`ModelHubMixin`], you can define class-level attributes. These attributes are not passed to
|
|
71
|
+
`__init__` but to the class definition itself. This is useful to define metadata about the library integrating
|
|
72
|
+
[`ModelHubMixin`].
|
|
73
|
+
|
|
74
|
+
Args:
|
|
75
|
+
library_name (`str`, *optional*):
|
|
76
|
+
Name of the library integrating ModelHubMixin. Used to generate model card.
|
|
77
|
+
tags (`List[str]`, *optional*):
|
|
78
|
+
Tags to be added to the model card. Used to generate model card.
|
|
79
|
+
repo_url (`str`, *optional*):
|
|
80
|
+
URL of the library repository. Used to generate model card.
|
|
81
|
+
docs_url (`str`, *optional*):
|
|
82
|
+
URL of the library documentation. Used to generate model card.
|
|
83
|
+
|
|
48
84
|
Example:
|
|
49
85
|
|
|
50
86
|
```python
|
|
51
|
-
>>> from dataclasses import dataclass
|
|
52
87
|
>>> from huggingface_hub import ModelHubMixin
|
|
53
88
|
|
|
54
|
-
#
|
|
55
|
-
>>>
|
|
56
|
-
...
|
|
57
|
-
...
|
|
58
|
-
...
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
...
|
|
89
|
+
# Inherit from ModelHubMixin
|
|
90
|
+
>>> class MyCustomModel(
|
|
91
|
+
... ModelHubMixin,
|
|
92
|
+
... library_name="my-library",
|
|
93
|
+
... tags=["x-custom-tag"],
|
|
94
|
+
... repo_url="https://github.com/huggingface/my-cool-library",
|
|
95
|
+
... docs_url="https://huggingface.co/docs/my-cool-library",
|
|
96
|
+
... # ^ optional metadata to generate model card
|
|
97
|
+
... ):
|
|
98
|
+
... def __init__(self, size: int = 512, device: str = "cpu"):
|
|
63
99
|
... # define how to initialize your model
|
|
64
100
|
... super().__init__()
|
|
65
101
|
... ...
|
|
@@ -85,7 +121,7 @@ class ModelHubMixin:
|
|
|
85
121
|
... # define how to deserialize your model
|
|
86
122
|
... ...
|
|
87
123
|
|
|
88
|
-
>>> model = MyCustomModel(
|
|
124
|
+
>>> model = MyCustomModel(size=256, device="gpu")
|
|
89
125
|
|
|
90
126
|
# Save model weights to local directory
|
|
91
127
|
>>> model.save_pretrained("my-awesome-model")
|
|
@@ -95,28 +131,107 @@ class ModelHubMixin:
|
|
|
95
131
|
|
|
96
132
|
# Download and initialize weights from the Hub
|
|
97
133
|
>>> reloaded_model = MyCustomModel.from_pretrained("username/my-awesome-model")
|
|
98
|
-
>>> reloaded_model.
|
|
99
|
-
|
|
134
|
+
>>> reloaded_model._hub_mixin_config
|
|
135
|
+
{"size": 256, "device": "gpu"}
|
|
136
|
+
|
|
137
|
+
# Model card has been correctly populated
|
|
138
|
+
>>> from huggingface_hub import ModelCard
|
|
139
|
+
>>> card = ModelCard.load("username/my-awesome-model")
|
|
140
|
+
>>> card.data.tags
|
|
141
|
+
["x-custom-tag", "model_hub_mixin"]
|
|
142
|
+
>>> card.data.library_name
|
|
143
|
+
"my-library"
|
|
100
144
|
```
|
|
101
145
|
"""
|
|
102
146
|
|
|
103
|
-
|
|
104
|
-
# ^ optional config attribute automatically set in `from_pretrained`
|
|
147
|
+
_hub_mixin_config: Optional[Union[dict, "DataclassInstance"]] = None
|
|
148
|
+
# ^ optional config attribute automatically set in `from_pretrained`
|
|
149
|
+
_hub_mixin_info: MixinInfo
|
|
150
|
+
# ^ information about the library integrating ModelHubMixin (used to generate model card)
|
|
151
|
+
_hub_mixin_init_parameters: Dict[str, inspect.Parameter]
|
|
152
|
+
_hub_mixin_jsonable_default_values: Dict[str, Any]
|
|
153
|
+
_hub_mixin_inject_config: bool
|
|
154
|
+
# ^ internal values to handle config
|
|
155
|
+
|
|
156
|
+
def __init_subclass__(
    cls,
    *,
    library_name: Optional[str] = None,
    tags: Optional[List[str]] = None,
    repo_url: Optional[str] = None,
    docs_url: Optional[str] = None,
) -> None:
    """Inspect __init__ signature only once when subclassing + handle modelcard.

    Args:
        library_name (`str`, *optional*):
            Name of the library integrating the mixin. Stored in `cls._hub_mixin_info`.
        tags (`List[str]`, *optional*):
            Tags stored in `cls._hub_mixin_info`. The list passed by the caller is
            not modified; `"model_hub_mixin"` is appended to a copy.
        repo_url (`str`, *optional*):
            URL of the library repository. Stored in `cls._hub_mixin_info`.
        docs_url (`str`, *optional*):
            URL of the library documentation. Stored in `cls._hub_mixin_info`.
    """
    super().__init_subclass__()

    # Will be reused when creating modelcard.
    # Build a new list rather than appending in place: the caller may share one
    # `tags` list between several class definitions, and mutating it would add
    # "model_hub_mixin" to that shared list once per subclass.
    card_tags = [*(tags or []), "model_hub_mixin"]
    cls._hub_mixin_info = MixinInfo(
        library_name=library_name,
        tags=card_tags,
        repo_url=repo_url,
        docs_url=docs_url,
    )

    # Inspect __init__ signature to handle config
    cls._hub_mixin_init_parameters = dict(inspect.signature(cls.__init__).parameters)
    # Keep only the defaults that can be serialized to JSON
    cls._hub_mixin_jsonable_default_values = {
        param.name: param.default
        for param in cls._hub_mixin_init_parameters.values()
        if param.default is not inspect.Parameter.empty and is_jsonable(param.default)
    }
    # Remember whether `_from_pretrained` declares a `config` parameter
    cls._hub_mixin_inject_config = "config" in inspect.signature(cls._from_pretrained).parameters
|
|
105
185
|
|
|
106
186
|
def __new__(cls, *args, **kwargs) -> "ModelHubMixin":
    """Allocate the instance and record its init configuration.

    The outcome is one of three:
        - `_hub_mixin_config` is already set (not `None`): the instance is returned untouched.
        - a dataclass instance was passed as `config`: it becomes `_hub_mixin_config`.
        - otherwise `_hub_mixin_config` is assembled from the jsonable default values,
          the jsonable passed values and, if `config` is a dict, the content of that dict.
          It is only assigned when the assembled dict is not empty.
    """
    obj = super().__new__(cls)

    # Nothing to do if a config is already available
    if obj._hub_mixin_config is not None:
        return obj

    # Map positional arguments to parameter names ([1:] skips `self`),
    # then let keyword arguments take precedence.
    positional_names = list(cls._hub_mixin_init_parameters)[1:]
    given = dict(zip(positional_names, args))
    given.update(kwargs)

    explicit_config = given.get("config")

    # A dataclass config is stored as-is
    if is_dataclass(explicit_config):
        obj._hub_mixin_config = explicit_config
        return obj

    # Start from the jsonable defaults, then overlay the jsonable passed values
    merged = dict(cls._hub_mixin_jsonable_default_values)
    for name, value in given.items():
        if is_jsonable(value):
            merged[name] = value
    # `config` itself is never stored as a key
    merged.pop("config", None)

    # A dict config contributes its own entries
    if isinstance(explicit_config, dict):
        merged.update(explicit_config)

    # Only set the attribute when there is something to store
    if merged:
        obj._hub_mixin_config = merged
    return obj
|
|
121
236
|
|
|
122
237
|
def save_pretrained(
|
|
@@ -150,13 +265,21 @@ class ModelHubMixin:
|
|
|
150
265
|
# save model weights/files (framework-specific)
|
|
151
266
|
self._save_pretrained(save_directory)
|
|
152
267
|
|
|
153
|
-
# save config (if provided)
|
|
268
|
+
# save config (if provided and if not serialized yet in `_save_pretrained`)
|
|
154
269
|
if config is None:
|
|
155
|
-
config = self.
|
|
270
|
+
config = self._hub_mixin_config
|
|
156
271
|
if config is not None:
|
|
157
272
|
if is_dataclass(config):
|
|
158
273
|
config = asdict(config) # type: ignore[arg-type]
|
|
159
|
-
|
|
274
|
+
config_path = save_directory / CONFIG_NAME
|
|
275
|
+
if not config_path.exists():
|
|
276
|
+
config_str = json.dumps(config, sort_keys=True, indent=2)
|
|
277
|
+
config_path.write_text(config_str)
|
|
278
|
+
|
|
279
|
+
# save model card
|
|
280
|
+
model_card_path = save_directory / "README.md"
|
|
281
|
+
if not model_card_path.exists(): # do not overwrite if already exists
|
|
282
|
+
self.generate_model_card().save(save_directory / "README.md")
|
|
160
283
|
|
|
161
284
|
# push to the Hub if required
|
|
162
285
|
if push_to_hub:
|
|
@@ -246,32 +369,42 @@ class ModelHubMixin:
|
|
|
246
369
|
except HfHubHTTPError as e:
|
|
247
370
|
logger.info(f"{CONFIG_NAME} not found on the HuggingFace Hub: {str(e)}")
|
|
248
371
|
|
|
372
|
+
# Read config
|
|
249
373
|
config = None
|
|
250
374
|
if config_file is not None:
|
|
251
|
-
# Read config
|
|
252
375
|
with open(config_file, "r", encoding="utf-8") as f:
|
|
253
376
|
config = json.load(f)
|
|
254
377
|
|
|
255
|
-
#
|
|
256
|
-
|
|
257
|
-
|
|
378
|
+
# Populate model_kwargs from config
|
|
379
|
+
for param in cls._hub_mixin_init_parameters.values():
|
|
380
|
+
if param.name not in model_kwargs and param.name in config:
|
|
381
|
+
model_kwargs[param.name] = config[param.name]
|
|
382
|
+
|
|
383
|
+
# Check if `config` argument was passed at init
|
|
384
|
+
if "config" in cls._hub_mixin_init_parameters:
|
|
258
385
|
# Check if `config` argument is a dataclass
|
|
259
|
-
config_annotation =
|
|
386
|
+
config_annotation = cls._hub_mixin_init_parameters["config"].annotation
|
|
260
387
|
if config_annotation is inspect.Parameter.empty:
|
|
261
388
|
pass # no annotation
|
|
262
389
|
elif is_dataclass(config_annotation):
|
|
263
|
-
config = config_annotation
|
|
390
|
+
config = _load_dataclass(config_annotation, config)
|
|
264
391
|
else:
|
|
265
392
|
# if Optional/Union annotation => check if a dataclass is in the Union
|
|
266
393
|
for _sub_annotation in get_args(config_annotation):
|
|
267
394
|
if is_dataclass(_sub_annotation):
|
|
268
|
-
config = _sub_annotation
|
|
395
|
+
config = _load_dataclass(_sub_annotation, config)
|
|
269
396
|
break
|
|
270
397
|
|
|
271
398
|
# Forward config to model initialization
|
|
272
399
|
model_kwargs["config"] = config
|
|
273
|
-
|
|
274
|
-
|
|
400
|
+
|
|
401
|
+
if any(param.kind == inspect.Parameter.VAR_KEYWORD for param in cls._hub_mixin_init_parameters.values()):
|
|
402
|
+
for key, value in config.items():
|
|
403
|
+
if key not in model_kwargs:
|
|
404
|
+
model_kwargs[key] = value
|
|
405
|
+
|
|
406
|
+
# Finally, also inject if `_from_pretrained` expects it
|
|
407
|
+
if cls._hub_mixin_inject_config:
|
|
275
408
|
model_kwargs["config"] = config
|
|
276
409
|
|
|
277
410
|
instance = cls._from_pretrained(
|
|
@@ -288,8 +421,8 @@ class ModelHubMixin:
|
|
|
288
421
|
|
|
289
422
|
# Implicitly set the config as instance attribute if not already set by the class
|
|
290
423
|
# This way `config` will be available when calling `save_pretrained` or `push_to_hub`.
|
|
291
|
-
if config is not None and instance
|
|
292
|
-
instance.
|
|
424
|
+
if config is not None and (getattr(instance, "_hub_mixin_config", None) in (None, {})):
|
|
425
|
+
instance._hub_mixin_config = config
|
|
293
426
|
|
|
294
427
|
return instance
|
|
295
428
|
|
|
@@ -418,6 +551,13 @@ class ModelHubMixin:
|
|
|
418
551
|
delete_patterns=delete_patterns,
|
|
419
552
|
)
|
|
420
553
|
|
|
554
|
+
def generate_model_card(self, *args, **kwargs) -> ModelCard:
    """Create a model card from `DEFAULT_MODEL_CARD` and the class-level mixin info.

    `*args` and `**kwargs` are accepted but not used by this implementation.

    Returns:
        `ModelCard`: the card built from the template.
    """
    # Card metadata is derived from the fields of `_hub_mixin_info`
    metadata = ModelCardData(**asdict(self._hub_mixin_info))
    return ModelCard.from_template(card_data=metadata, template_str=DEFAULT_MODEL_CARD)
|
|
560
|
+
|
|
421
561
|
|
|
422
562
|
class PyTorchModelHubMixin(ModelHubMixin):
|
|
423
563
|
"""
|
|
@@ -428,26 +568,26 @@ class PyTorchModelHubMixin(ModelHubMixin):
|
|
|
428
568
|
Example:
|
|
429
569
|
|
|
430
570
|
```python
|
|
431
|
-
>>> from dataclasses import dataclass
|
|
432
571
|
>>> import torch
|
|
433
572
|
>>> import torch.nn as nn
|
|
434
573
|
>>> from huggingface_hub import PyTorchModelHubMixin
|
|
435
574
|
|
|
436
|
-
>>>
|
|
437
|
-
...
|
|
438
|
-
...
|
|
439
|
-
...
|
|
440
|
-
...
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
...
|
|
575
|
+
>>> class MyModel(
|
|
576
|
+
... nn.Module,
|
|
577
|
+
... PyTorchModelHubMixin,
|
|
578
|
+
... library_name="keras-nlp",
|
|
579
|
+
... repo_url="https://github.com/keras-team/keras-nlp",
|
|
580
|
+
... docs_url="https://keras.io/keras_nlp/",
|
|
581
|
+
... # ^ optional metadata to generate model card
|
|
582
|
+
... ):
|
|
583
|
+
... def __init__(self, hidden_size: int = 512, vocab_size: int = 30000, output_size: int = 4):
|
|
444
584
|
... super().__init__()
|
|
445
|
-
... self.param = nn.Parameter(torch.rand(
|
|
446
|
-
... self.linear = nn.Linear(
|
|
585
|
+
... self.param = nn.Parameter(torch.rand(hidden_size, vocab_size))
|
|
586
|
+
...         self.linear = nn.Linear(vocab_size, output_size)
|
|
447
587
|
|
|
448
588
|
... def forward(self, x):
|
|
449
589
|
... return self.linear(x + self.param)
|
|
450
|
-
>>> model = MyModel()
|
|
590
|
+
>>> model = MyModel(hidden_size=256)
|
|
451
591
|
|
|
452
592
|
# Save model weights to local directory
|
|
453
593
|
>>> model.save_pretrained("my-awesome-model")
|
|
@@ -457,9 +597,17 @@ class PyTorchModelHubMixin(ModelHubMixin):
|
|
|
457
597
|
|
|
458
598
|
# Download and initialize weights from the Hub
|
|
459
599
|
>>> model = MyModel.from_pretrained("username/my-awesome-model")
|
|
600
|
+
>>> model.hidden_size
|
|
601
|
+
256
|
|
460
602
|
```
|
|
461
603
|
"""
|
|
462
604
|
|
|
605
|
+
def __init_subclass__(cls, *args, tags: Optional[List[str]] = None, **kwargs) -> None:
    """Add the `"pytorch_model_hub_mixin"` tag and delegate to the parent hook.

    Args:
        tags (`List[str]`, *optional*):
            Tags provided in the class definition. The list passed by the caller
            is not modified; the extra tag is appended to a copy.
        *args, **kwargs:
            Forwarded unchanged to the parent `__init_subclass__`.
    """
    # Build a new list instead of appending in place, so that a `tags` list
    # reused across several class definitions does not grow with each subclass.
    kwargs["tags"] = [*(tags or []), "pytorch_model_hub_mixin"]
    return super().__init_subclass__(*args, **kwargs)
|
|
610
|
+
|
|
463
611
|
def _save_pretrained(self, save_directory: Path) -> None:
|
|
464
612
|
"""Save weights from a Pytorch model to a local directory."""
|
|
465
613
|
model_to_save = self.module if hasattr(self, "module") else self # type: ignore
|
|
@@ -536,3 +684,11 @@ class PyTorchModelHubMixin(ModelHubMixin):
|
|
|
536
684
|
)
|
|
537
685
|
model.to(map_location) # type: ignore [attr-defined]
|
|
538
686
|
return model
|
|
687
|
+
|
|
688
|
+
|
|
689
|
+
def _load_dataclass(datacls: Type["DataclassInstance"], data: dict) -> "DataclassInstance":
|
|
690
|
+
"""Load a dataclass instance from a dictionary.
|
|
691
|
+
|
|
692
|
+
Fields not expected by the dataclass are ignored.
|
|
693
|
+
"""
|
|
694
|
+
return datacls(**{k: v for k, v in data.items() if k in datacls.__dataclass_fields__})
|