huggingface-hub: huggingface_hub-0.21.4-py3-none-any.whl → huggingface_hub-0.22.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of huggingface-hub might be problematic; see the registry's advisory page for this release for more details.

Files changed (97)
  1. huggingface_hub/__init__.py +217 -1
  2. huggingface_hub/_commit_api.py +14 -15
  3. huggingface_hub/_inference_endpoints.py +12 -11
  4. huggingface_hub/_login.py +1 -0
  5. huggingface_hub/_multi_commits.py +1 -0
  6. huggingface_hub/_snapshot_download.py +9 -1
  7. huggingface_hub/_tensorboard_logger.py +1 -0
  8. huggingface_hub/_webhooks_payload.py +1 -0
  9. huggingface_hub/_webhooks_server.py +1 -0
  10. huggingface_hub/commands/_cli_utils.py +1 -0
  11. huggingface_hub/commands/delete_cache.py +1 -0
  12. huggingface_hub/commands/download.py +1 -0
  13. huggingface_hub/commands/env.py +1 -0
  14. huggingface_hub/commands/scan_cache.py +1 -0
  15. huggingface_hub/commands/upload.py +1 -0
  16. huggingface_hub/community.py +1 -0
  17. huggingface_hub/constants.py +3 -1
  18. huggingface_hub/errors.py +38 -0
  19. huggingface_hub/file_download.py +102 -95
  20. huggingface_hub/hf_api.py +47 -35
  21. huggingface_hub/hf_file_system.py +77 -3
  22. huggingface_hub/hub_mixin.py +215 -54
  23. huggingface_hub/inference/_client.py +554 -239
  24. huggingface_hub/inference/_common.py +195 -41
  25. huggingface_hub/inference/_generated/_async_client.py +558 -239
  26. huggingface_hub/inference/_generated/types/__init__.py +115 -0
  27. huggingface_hub/inference/_generated/types/audio_classification.py +43 -0
  28. huggingface_hub/inference/_generated/types/audio_to_audio.py +31 -0
  29. huggingface_hub/inference/_generated/types/automatic_speech_recognition.py +116 -0
  30. huggingface_hub/inference/_generated/types/base.py +149 -0
  31. huggingface_hub/inference/_generated/types/chat_completion.py +106 -0
  32. huggingface_hub/inference/_generated/types/depth_estimation.py +29 -0
  33. huggingface_hub/inference/_generated/types/document_question_answering.py +85 -0
  34. huggingface_hub/inference/_generated/types/feature_extraction.py +19 -0
  35. huggingface_hub/inference/_generated/types/fill_mask.py +50 -0
  36. huggingface_hub/inference/_generated/types/image_classification.py +43 -0
  37. huggingface_hub/inference/_generated/types/image_segmentation.py +52 -0
  38. huggingface_hub/inference/_generated/types/image_to_image.py +55 -0
  39. huggingface_hub/inference/_generated/types/image_to_text.py +105 -0
  40. huggingface_hub/inference/_generated/types/object_detection.py +55 -0
  41. huggingface_hub/inference/_generated/types/question_answering.py +77 -0
  42. huggingface_hub/inference/_generated/types/sentence_similarity.py +28 -0
  43. huggingface_hub/inference/_generated/types/summarization.py +46 -0
  44. huggingface_hub/inference/_generated/types/table_question_answering.py +45 -0
  45. huggingface_hub/inference/_generated/types/text2text_generation.py +45 -0
  46. huggingface_hub/inference/_generated/types/text_classification.py +43 -0
  47. huggingface_hub/inference/_generated/types/text_generation.py +161 -0
  48. huggingface_hub/inference/_generated/types/text_to_audio.py +105 -0
  49. huggingface_hub/inference/_generated/types/text_to_image.py +57 -0
  50. huggingface_hub/inference/_generated/types/token_classification.py +53 -0
  51. huggingface_hub/inference/_generated/types/translation.py +46 -0
  52. huggingface_hub/inference/_generated/types/video_classification.py +47 -0
  53. huggingface_hub/inference/_generated/types/visual_question_answering.py +53 -0
  54. huggingface_hub/inference/_generated/types/zero_shot_classification.py +56 -0
  55. huggingface_hub/inference/_generated/types/zero_shot_image_classification.py +51 -0
  56. huggingface_hub/inference/_generated/types/zero_shot_object_detection.py +55 -0
  57. huggingface_hub/inference/_templating.py +105 -0
  58. huggingface_hub/inference/_types.py +4 -152
  59. huggingface_hub/keras_mixin.py +39 -17
  60. huggingface_hub/lfs.py +20 -8
  61. huggingface_hub/repocard.py +11 -3
  62. huggingface_hub/repocard_data.py +12 -2
  63. huggingface_hub/serialization/__init__.py +1 -0
  64. huggingface_hub/serialization/_base.py +1 -0
  65. huggingface_hub/serialization/_numpy.py +1 -0
  66. huggingface_hub/serialization/_tensorflow.py +1 -0
  67. huggingface_hub/serialization/_torch.py +1 -0
  68. huggingface_hub/utils/__init__.py +4 -1
  69. huggingface_hub/utils/_cache_manager.py +7 -0
  70. huggingface_hub/utils/_chunk_utils.py +1 -0
  71. huggingface_hub/utils/_datetime.py +1 -0
  72. huggingface_hub/utils/_errors.py +10 -1
  73. huggingface_hub/utils/_experimental.py +1 -0
  74. huggingface_hub/utils/_fixes.py +19 -3
  75. huggingface_hub/utils/_git_credential.py +1 -0
  76. huggingface_hub/utils/_headers.py +10 -3
  77. huggingface_hub/utils/_hf_folder.py +1 -0
  78. huggingface_hub/utils/_http.py +1 -0
  79. huggingface_hub/utils/_pagination.py +1 -0
  80. huggingface_hub/utils/_paths.py +1 -0
  81. huggingface_hub/utils/_runtime.py +22 -0
  82. huggingface_hub/utils/_subprocess.py +1 -0
  83. huggingface_hub/utils/_token.py +1 -0
  84. huggingface_hub/utils/_typing.py +29 -1
  85. huggingface_hub/utils/_validators.py +1 -0
  86. huggingface_hub/utils/endpoint_helpers.py +1 -0
  87. huggingface_hub/utils/logging.py +1 -1
  88. huggingface_hub/utils/sha.py +1 -0
  89. huggingface_hub/utils/tqdm.py +1 -0
  90. {huggingface_hub-0.21.4.dist-info → huggingface_hub-0.22.0.dist-info}/METADATA +14 -15
  91. huggingface_hub-0.22.0.dist-info/RECORD +113 -0
  92. {huggingface_hub-0.21.4.dist-info → huggingface_hub-0.22.0.dist-info}/WHEEL +1 -1
  93. huggingface_hub/inference/_text_generation.py +0 -551
  94. huggingface_hub-0.21.4.dist-info/RECORD +0 -81
  95. {huggingface_hub-0.21.4.dist-info → huggingface_hub-0.22.0.dist-info}/LICENSE +0 -0
  96. {huggingface_hub-0.21.4.dist-info → huggingface_hub-0.22.0.dist-info}/entry_points.txt +0 -0
  97. {huggingface_hub-0.21.4.dist-info → huggingface_hub-0.22.0.dist-info}/top_level.txt +0 -0
@@ -6,15 +6,24 @@ from collections import deque
6
6
  from dataclasses import dataclass, field
7
7
  from datetime import datetime
8
8
  from itertools import chain
9
+ from pathlib import Path
9
10
  from typing import Any, Dict, List, NoReturn, Optional, Tuple, Union
10
11
  from urllib.parse import quote, unquote
11
12
 
12
13
  import fsspec
14
+ from fsspec.callbacks import _DEFAULT_CALLBACK, NoOpCallback, TqdmCallback
15
+ from fsspec.utils import isfilelike
13
16
  from requests import Response
14
17
 
15
18
  from ._commit_api import CommitOperationCopy, CommitOperationDelete
16
- from .constants import DEFAULT_REVISION, ENDPOINT, REPO_TYPE_MODEL, REPO_TYPES_MAPPING, REPO_TYPES_URL_PREFIXES
17
- from .file_download import hf_hub_url
19
+ from .constants import (
20
+ DEFAULT_REVISION,
21
+ ENDPOINT,
22
+ REPO_TYPE_MODEL,
23
+ REPO_TYPES_MAPPING,
24
+ REPO_TYPES_URL_PREFIXES,
25
+ )
26
+ from .file_download import hf_hub_url, http_get
18
27
  from .hf_api import HfApi, LastCommitInfo, RepoFile
19
28
  from .utils import (
20
29
  EntryNotFoundError,
@@ -591,6 +600,58 @@ class HfFileSystem(fsspec.AbstractFileSystem):
591
600
  url = url.replace("/resolve/", "/tree/", 1)
592
601
  return url
593
602
 
603
+ def get_file(self, rpath, lpath, callback=_DEFAULT_CALLBACK, outfile=None, **kwargs) -> None:
604
+ """Copy single remote file to local."""
605
+ revision = kwargs.get("revision")
606
+ unhandled_kwargs = set(kwargs.keys()) - {"revision"}
607
+ if not isinstance(callback, (NoOpCallback, TqdmCallback)) or len(unhandled_kwargs) > 0:
608
+ # for now, let's not handle custom callbacks
609
+ # and let's not handle custom kwargs
610
+ return super().get_file(rpath, lpath, callback=callback, outfile=outfile, **kwargs)
611
+
612
+ # Taken from https://github.com/fsspec/filesystem_spec/blob/47b445ae4c284a82dd15e0287b1ffc410e8fc470/fsspec/spec.py#L883
613
+ if isfilelike(lpath):
614
+ outfile = lpath
615
+ elif self.isdir(rpath):
616
+ os.makedirs(lpath, exist_ok=True)
617
+ return None
618
+
619
+ if isinstance(lpath, (str, Path)): # otherwise, let's assume it's a file-like object
620
+ os.makedirs(os.path.dirname(lpath), exist_ok=True)
621
+
622
+ # Open file if not already open
623
+ close_file = False
624
+ if outfile is None:
625
+ outfile = open(lpath, "wb")
626
+ close_file = True
627
+ initial_pos = outfile.tell()
628
+
629
+ # Custom implementation of `get_file` to use `http_get`.
630
+ resolve_remote_path = self.resolve_path(rpath, revision=revision)
631
+ expected_size = self.info(rpath, revision=revision)["size"]
632
+ callback.set_size(expected_size)
633
+ try:
634
+ http_get(
635
+ url=hf_hub_url(
636
+ repo_id=resolve_remote_path.repo_id,
637
+ revision=resolve_remote_path.revision,
638
+ filename=resolve_remote_path.path_in_repo,
639
+ repo_type=resolve_remote_path.repo_type,
640
+ endpoint=self.endpoint,
641
+ ),
642
+ temp_file=outfile,
643
+ displayed_filename=rpath,
644
+ expected_size=expected_size,
645
+ resume_size=0,
646
+ headers=self._api._build_hf_headers(),
647
+ _tqdm_bar=callback.tqdm if isinstance(callback, TqdmCallback) else None,
648
+ )
649
+ outfile.seek(initial_pos)
650
+ finally:
651
+ # Close file only if we opened it ourselves
652
+ if close_file:
653
+ outfile.close()
654
+
594
655
  @property
595
656
  def transaction(self):
596
657
  """A context within which files are committed together upon exit
@@ -618,6 +679,7 @@ class HfFileSystemFile(fsspec.spec.AbstractBufferedFile):
618
679
  raise FileNotFoundError(
619
680
  f"{e}.\nMake sure the repository and revision exist before writing data."
620
681
  ) from e
682
+ raise
621
683
  super().__init__(fs, self.resolved_path.unresolve(), **kwargs)
622
684
  self.fs: HfFileSystem
623
685
 
@@ -667,6 +729,18 @@ class HfFileSystemFile(fsspec.spec.AbstractBufferedFile):
667
729
  path=self.resolved_path.unresolve(),
668
730
  )
669
731
 
732
+ def read(self, length=-1):
733
+ """Read remote file.
734
+
735
+ If `length` is not provided or is -1, the entire file is downloaded and read. On POSIX systems and if
736
+ `hf_transfer` is not enabled, the file is loaded in memory directly. Otherwise, the file is downloaded to a
737
+ temporary file and read from there.
738
+ """
739
+ if self.mode == "rb" and (length is None or length == -1) and self.loc == 0:
740
+ with self.fs.open(self.path, "rb", block_size=0) as f: # block_size=0 enables fast streaming
741
+ return f.read()
742
+ return super().read(length)
743
+
670
744
  def url(self) -> str:
671
745
  return self.fs.url(self.path)
672
746
 
@@ -695,7 +769,7 @@ class HfFileSystemStreamFile(fsspec.spec.AbstractBufferedFile):
695
769
  raise FileNotFoundError(
696
770
  f"{e}.\nMake sure the repository and revision exist before writing data."
697
771
  ) from e
698
- # avoid an unecessary .info() call to instantiate .details
772
+ # avoid an unnecessary .info() call to instantiate .details
699
773
  self.details = {"name": self.resolved_path.unresolve(), "size": None}
700
774
  super().__init__(
701
775
  fs, self.resolved_path.unresolve(), mode=mode, block_size=block_size, cache_type=cache_type, **kwargs
@@ -1,17 +1,19 @@
1
1
  import inspect
2
2
  import json
3
3
  import os
4
- from dataclasses import asdict, is_dataclass
4
+ from dataclasses import asdict, dataclass, is_dataclass
5
5
  from pathlib import Path
6
- from typing import TYPE_CHECKING, Dict, List, Optional, Type, TypeVar, Union, get_args
6
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, TypeVar, Union, get_args
7
7
 
8
8
  from .constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME, SAFETENSORS_SINGLE_FILE
9
9
  from .file_download import hf_hub_download
10
10
  from .hf_api import HfApi
11
+ from .repocard import ModelCard, ModelCardData
11
12
  from .utils import (
12
13
  EntryNotFoundError,
13
14
  HfHubHTTPError,
14
15
  SoftTemporaryDirectory,
16
+ is_jsonable,
15
17
  is_safetensors_available,
16
18
  is_torch_available,
17
19
  logging,
@@ -36,6 +38,26 @@ logger = logging.get_logger(__name__)
36
38
  # Generic variable that is either ModelHubMixin or a subclass thereof
37
39
  T = TypeVar("T", bound="ModelHubMixin")
38
40
 
41
+ DEFAULT_MODEL_CARD = """
42
+ ---
43
+ # For reference on model card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
44
+ # Doc / guide: https://huggingface.co/docs/hub/model-cards
45
+ {{ card_data }}
46
+ ---
47
+
48
+ This model has been pushed to the Hub using **{{ library_name }}**:
49
+ - Repo: {{ repo_url | default("[More Information Needed]", true) }}
50
+ - Docs: {{ docs_url | default("[More Information Needed]", true) }}
51
+ """
52
+
53
+
54
+ @dataclass
55
+ class MixinInfo:
56
+ library_name: Optional[str] = None
57
+ tags: Optional[List[str]] = None
58
+ repo_url: Optional[str] = None
59
+ docs_url: Optional[str] = None
60
+
39
61
 
40
62
  class ModelHubMixin:
41
63
  """
@@ -45,21 +67,35 @@ class ModelHubMixin:
45
67
  have to be overwritten in [`_from_pretrained`] and [`_save_pretrained`]. [`PyTorchModelHubMixin`] is a good example
46
68
  of mixin integration with the Hub. Check out our [integration guide](../guides/integrations) for more instructions.
47
69
 
70
+ When inheriting from [`ModelHubMixin`], you can define class-level attributes. These attributes are not passed to
71
+ `__init__` but to the class definition itself. This is useful to define metadata about the library integrating
72
+ [`ModelHubMixin`].
73
+
74
+ Args:
75
+ library_name (`str`, *optional*):
76
+ Name of the library integrating ModelHubMixin. Used to generate model card.
77
+ tags (`List[str]`, *optional*):
78
+ Tags to be added to the model card. Used to generate model card.
79
+ repo_url (`str`, *optional*):
80
+ URL of the library repository. Used to generate model card.
81
+ docs_url (`str`, *optional*):
82
+ URL of the library documentation. Used to generate model card.
83
+
48
84
  Example:
49
85
 
50
86
  ```python
51
- >>> from dataclasses import dataclass
52
87
  >>> from huggingface_hub import ModelHubMixin
53
88
 
54
- # Define your model configuration (optional)
55
- >>> @dataclass
56
- ... class Config:
57
- ... foo: int = 512
58
- ... bar: str = "cpu"
59
-
60
- # Inherit from ModelHubMixin (and optionally from your framework's model class)
61
- >>> class MyCustomModel(ModelHubMixin):
62
- ... def __init__(self, config: Config):
89
+ # Inherit from ModelHubMixin
90
+ >>> class MyCustomModel(
91
+ ... ModelHubMixin,
92
+ ... library_name="my-library",
93
+ ... tags=["x-custom-tag"],
94
+ ... repo_url="https://github.com/huggingface/my-cool-library",
95
+ ... docs_url="https://huggingface.co/docs/my-cool-library",
96
+ ... # ^ optional metadata to generate model card
97
+ ... ):
98
+ ... def __init__(self, size: int = 512, device: str = "cpu"):
63
99
  ... # define how to initialize your model
64
100
  ... super().__init__()
65
101
  ... ...
@@ -85,7 +121,7 @@ class ModelHubMixin:
85
121
  ... # define how to deserialize your model
86
122
  ... ...
87
123
 
88
- >>> model = MyCustomModel(config=Config(foo=256, bar="gpu"))
124
+ >>> model = MyCustomModel(size=256, device="gpu")
89
125
 
90
126
  # Save model weights to local directory
91
127
  >>> model.save_pretrained("my-awesome-model")
@@ -95,28 +131,107 @@ class ModelHubMixin:
95
131
 
96
132
  # Download and initialize weights from the Hub
97
133
  >>> reloaded_model = MyCustomModel.from_pretrained("username/my-awesome-model")
98
- >>> reloaded_model.config
99
- Config(foo=256, bar="gpu")
134
+ >>> reloaded_model._hub_mixin_config
135
+ {"size": 256, "device": "gpu"}
136
+
137
+ # Model card has been correctly populated
138
+ >>> from huggingface_hub import ModelCard
139
+ >>> card = ModelCard.load("username/my-awesome-model")
140
+ >>> card.data.tags
141
+ ["x-custom-tag", "pytorch_model_hub_mixin", "model_hub_mixin"]
142
+ >>> card.data.library_name
143
+ "my-library"
100
144
  ```
101
145
  """
102
146
 
103
- config: Optional[Union[dict, "DataclassInstance"]] = None
104
- # ^ optional config attribute automatically set in `from_pretrained` (if not already set by the subclass)
147
+ _hub_mixin_config: Optional[Union[dict, "DataclassInstance"]] = None
148
+ # ^ optional config attribute automatically set in `from_pretrained`
149
+ _hub_mixin_info: MixinInfo
150
+ # ^ information about the library integrating ModelHubMixin (used to generate model card)
151
+ _hub_mixin_init_parameters: Dict[str, inspect.Parameter]
152
+ _hub_mixin_jsonable_default_values: Dict[str, Any]
153
+ _hub_mixin_inject_config: bool
154
+ # ^ internal values to handle config
155
+
156
+ def __init_subclass__(
157
+ cls,
158
+ *,
159
+ library_name: Optional[str] = None,
160
+ tags: Optional[List[str]] = None,
161
+ repo_url: Optional[str] = None,
162
+ docs_url: Optional[str] = None,
163
+ ) -> None:
164
+ """Inspect __init__ signature only once when subclassing + handle modelcard."""
165
+ super().__init_subclass__()
166
+
167
+ # Will be reused when creating modelcard
168
+ tags = tags or []
169
+ tags.append("model_hub_mixin")
170
+ cls._hub_mixin_info = MixinInfo(
171
+ library_name=library_name,
172
+ tags=tags,
173
+ repo_url=repo_url,
174
+ docs_url=docs_url,
175
+ )
176
+
177
+ # Inspect __init__ signature to handle config
178
+ cls._hub_mixin_init_parameters = dict(inspect.signature(cls.__init__).parameters)
179
+ cls._hub_mixin_jsonable_default_values = {
180
+ param.name: param.default
181
+ for param in cls._hub_mixin_init_parameters.values()
182
+ if param.default is not inspect.Parameter.empty and is_jsonable(param.default)
183
+ }
184
+ cls._hub_mixin_inject_config = "config" in inspect.signature(cls._from_pretrained).parameters
105
185
 
106
186
  def __new__(cls, *args, **kwargs) -> "ModelHubMixin":
187
+ """Create a new instance of the class and handle config.
188
+
189
+ 3 cases:
190
+ - If `self._hub_mixin_config` is already set, do nothing.
191
+ - If `config` is passed as a dataclass, set it as `self._hub_mixin_config`.
192
+ - Otherwise, build `self._hub_mixin_config` from default values and passed values.
193
+ """
107
194
  instance = super().__new__(cls)
108
195
 
109
- # Set `config` attribute if not already set by the subclass
110
- if instance.config is None:
111
- if "config" in kwargs:
112
- instance.config = kwargs["config"]
113
- elif len(args) > 0:
114
- sig = inspect.signature(cls.__init__)
115
- parameters = list(sig.parameters)[1:] # remove `self`
116
- for key, value in zip(parameters, args):
117
- if key == "config":
118
- instance.config = value
119
- break
196
+ # If `config` is already set, return early
197
+ if instance._hub_mixin_config is not None:
198
+ return instance
199
+
200
+ # Infer passed values
201
+ passed_values = {
202
+ **{
203
+ key: value
204
+ for key, value in zip(
205
+ # [1:] to skip `self` parameter
206
+ list(cls._hub_mixin_init_parameters)[1:],
207
+ args,
208
+ )
209
+ },
210
+ **kwargs,
211
+ }
212
+
213
+ # If config passed as dataclass => set it and return early
214
+ if is_dataclass(passed_values.get("config")):
215
+ instance._hub_mixin_config = passed_values["config"]
216
+ return instance
217
+
218
+ # Otherwise, build config from default + passed values
219
+ init_config = {
220
+ # default values
221
+ **cls._hub_mixin_jsonable_default_values,
222
+ # passed values
223
+ **{key: value for key, value in passed_values.items() if is_jsonable(value)},
224
+ }
225
+ init_config.pop("config", {})
226
+
227
+ # Populate `init_config` with provided config
228
+ provided_config = passed_values.get("config")
229
+ if isinstance(provided_config, dict):
230
+ init_config.update(provided_config)
231
+
232
+ # Set `config` attribute and return
233
+ if init_config != {}:
234
+ instance._hub_mixin_config = init_config
120
235
  return instance
121
236
 
122
237
  def save_pretrained(
@@ -147,16 +262,29 @@ class ModelHubMixin:
147
262
  save_directory = Path(save_directory)
148
263
  save_directory.mkdir(parents=True, exist_ok=True)
149
264
 
265
+ # Remove config.json if already exists. After `_save_pretrained` we don't want to overwrite config.json
266
+ # as it might have been saved by the custom `_save_pretrained` already. However we do want to overwrite
267
+ # an existing config.json if it was not saved by `_save_pretrained`.
268
+ config_path = save_directory / CONFIG_NAME
269
+ config_path.unlink(missing_ok=True)
270
+
150
271
  # save model weights/files (framework-specific)
151
272
  self._save_pretrained(save_directory)
152
273
 
153
- # save config (if provided)
274
+ # save config (if provided and if not serialized yet in `_save_pretrained`)
154
275
  if config is None:
155
- config = self.config
276
+ config = self._hub_mixin_config
156
277
  if config is not None:
157
278
  if is_dataclass(config):
158
279
  config = asdict(config) # type: ignore[arg-type]
159
- (save_directory / CONFIG_NAME).write_text(json.dumps(config, indent=2))
280
+ if not config_path.exists():
281
+ config_str = json.dumps(config, sort_keys=True, indent=2)
282
+ config_path.write_text(config_str)
283
+
284
+ # save model card
285
+ model_card_path = save_directory / "README.md"
286
+ if not model_card_path.exists(): # do not overwrite if already exists
287
+ self.generate_model_card().save(save_directory / "README.md")
160
288
 
161
289
  # push to the Hub if required
162
290
  if push_to_hub:
@@ -246,32 +374,42 @@ class ModelHubMixin:
246
374
  except HfHubHTTPError as e:
247
375
  logger.info(f"{CONFIG_NAME} not found on the HuggingFace Hub: {str(e)}")
248
376
 
377
+ # Read config
249
378
  config = None
250
379
  if config_file is not None:
251
- # Read config
252
380
  with open(config_file, "r", encoding="utf-8") as f:
253
381
  config = json.load(f)
254
382
 
255
- # Check if class expect a `config` argument
256
- init_parameters = inspect.signature(cls.__init__).parameters
257
- if "config" in init_parameters:
383
+ # Populate model_kwargs from config
384
+ for param in cls._hub_mixin_init_parameters.values():
385
+ if param.name not in model_kwargs and param.name in config:
386
+ model_kwargs[param.name] = config[param.name]
387
+
388
+ # Check if `config` argument was passed at init
389
+ if "config" in cls._hub_mixin_init_parameters:
258
390
  # Check if `config` argument is a dataclass
259
- config_annotation = init_parameters["config"].annotation
391
+ config_annotation = cls._hub_mixin_init_parameters["config"].annotation
260
392
  if config_annotation is inspect.Parameter.empty:
261
393
  pass # no annotation
262
394
  elif is_dataclass(config_annotation):
263
- config = config_annotation(**config) # expect a dataclass
395
+ config = _load_dataclass(config_annotation, config)
264
396
  else:
265
397
  # if Optional/Union annotation => check if a dataclass is in the Union
266
398
  for _sub_annotation in get_args(config_annotation):
267
399
  if is_dataclass(_sub_annotation):
268
- config = _sub_annotation(**config)
400
+ config = _load_dataclass(_sub_annotation, config)
269
401
  break
270
402
 
271
403
  # Forward config to model initialization
272
404
  model_kwargs["config"] = config
273
- elif any(param.kind == inspect.Parameter.VAR_KEYWORD for param in init_parameters.values()):
274
- # If __init__ accepts **kwargs, let's forward the config as well (as a dict)
405
+
406
+ if any(param.kind == inspect.Parameter.VAR_KEYWORD for param in cls._hub_mixin_init_parameters.values()):
407
+ for key, value in config.items():
408
+ if key not in model_kwargs:
409
+ model_kwargs[key] = value
410
+
411
+ # Finally, also inject if `_from_pretrained` expects it
412
+ if cls._hub_mixin_inject_config:
275
413
  model_kwargs["config"] = config
276
414
 
277
415
  instance = cls._from_pretrained(
@@ -288,8 +426,8 @@ class ModelHubMixin:
288
426
 
289
427
  # Implicitly set the config as instance attribute if not already set by the class
290
428
  # This way `config` will be available when calling `save_pretrained` or `push_to_hub`.
291
- if config is not None and instance.config is None:
292
- instance.config = config
429
+ if config is not None and (getattr(instance, "_hub_mixin_config", None) in (None, {})):
430
+ instance._hub_mixin_config = config
293
431
 
294
432
  return instance
295
433
 
@@ -418,6 +556,13 @@ class ModelHubMixin:
418
556
  delete_patterns=delete_patterns,
419
557
  )
420
558
 
559
+ def generate_model_card(self, *args, **kwargs) -> ModelCard:
560
+ card = ModelCard.from_template(
561
+ card_data=ModelCardData(**asdict(self._hub_mixin_info)),
562
+ template_str=DEFAULT_MODEL_CARD,
563
+ )
564
+ return card
565
+
421
566
 
422
567
  class PyTorchModelHubMixin(ModelHubMixin):
423
568
  """
@@ -428,26 +573,26 @@ class PyTorchModelHubMixin(ModelHubMixin):
428
573
  Example:
429
574
 
430
575
  ```python
431
- >>> from dataclasses import dataclass
432
576
  >>> import torch
433
577
  >>> import torch.nn as nn
434
578
  >>> from huggingface_hub import PyTorchModelHubMixin
435
579
 
436
- >>> @dataclass
437
- ... class Config:
438
- ... hidden_size: int = 512
439
- ... vocab_size: int = 30000
440
- ... output_size: int = 4
441
-
442
- >>> class MyModel(nn.Module, PyTorchModelHubMixin):
443
- ... def __init__(self, config: Config):
580
+ >>> class MyModel(
581
+ ... nn.Module,
582
+ ... PyTorchModelHubMixin,
583
+ ... library_name="keras-nlp",
584
+ ... repo_url="https://github.com/keras-team/keras-nlp",
585
+ ... docs_url="https://keras.io/keras_nlp/",
586
+ ... # ^ optional metadata to generate model card
587
+ ... ):
588
+ ... def __init__(self, hidden_size: int = 512, vocab_size: int = 30000, output_size: int = 4):
444
589
  ... super().__init__()
445
- ... self.param = nn.Parameter(torch.rand(config.hidden_size, config.vocab_size))
446
- ... self.linear = nn.Linear(config.output_size, config.vocab_size)
590
+ ... self.param = nn.Parameter(torch.rand(hidden_size, vocab_size))
591
+ ... self.linear = nn.Linear(output_size, vocab_size)
447
592
 
448
593
  ... def forward(self, x):
449
594
  ... return self.linear(x + self.param)
450
- >>> model = MyModel()
595
+ >>> model = MyModel(hidden_size=256)
451
596
 
452
597
  # Save model weights to local directory
453
598
  >>> model.save_pretrained("my-awesome-model")
@@ -457,9 +602,17 @@ class PyTorchModelHubMixin(ModelHubMixin):
457
602
 
458
603
  # Download and initialize weights from the Hub
459
604
  >>> model = MyModel.from_pretrained("username/my-awesome-model")
605
+ >>> model.hidden_size
606
+ 256
460
607
  ```
461
608
  """
462
609
 
610
+ def __init_subclass__(cls, *args, tags: Optional[List[str]] = None, **kwargs) -> None:
611
+ tags = tags or []
612
+ tags.append("pytorch_model_hub_mixin")
613
+ kwargs["tags"] = tags
614
+ return super().__init_subclass__(*args, **kwargs)
615
+
463
616
  def _save_pretrained(self, save_directory: Path) -> None:
464
617
  """Save weights from a Pytorch model to a local directory."""
465
618
  model_to_save = self.module if hasattr(self, "module") else self # type: ignore
@@ -536,3 +689,11 @@ class PyTorchModelHubMixin(ModelHubMixin):
536
689
  )
537
690
  model.to(map_location) # type: ignore [attr-defined]
538
691
  return model
692
+
693
+
694
+ def _load_dataclass(datacls: Type["DataclassInstance"], data: dict) -> "DataclassInstance":
695
+ """Load a dataclass instance from a dictionary.
696
+
697
+ Fields not expected by the dataclass are ignored.
698
+ """
699
+ return datacls(**{k: v for k, v in data.items() if k in datacls.__dataclass_fields__})