huggingface-hub 0.20.3__py3-none-any.whl → 0.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of huggingface-hub might be problematic. See the advisory on the package registry page for more details.

Files changed (35) hide show
  1. huggingface_hub/__init__.py +19 -1
  2. huggingface_hub/_commit_api.py +49 -20
  3. huggingface_hub/_inference_endpoints.py +10 -0
  4. huggingface_hub/_login.py +2 -2
  5. huggingface_hub/commands/download.py +1 -1
  6. huggingface_hub/file_download.py +57 -21
  7. huggingface_hub/hf_api.py +269 -54
  8. huggingface_hub/hf_file_system.py +131 -8
  9. huggingface_hub/hub_mixin.py +204 -42
  10. huggingface_hub/inference/_client.py +56 -9
  11. huggingface_hub/inference/_common.py +4 -3
  12. huggingface_hub/inference/_generated/_async_client.py +57 -9
  13. huggingface_hub/inference/_text_generation.py +5 -0
  14. huggingface_hub/inference/_types.py +17 -0
  15. huggingface_hub/lfs.py +6 -3
  16. huggingface_hub/repocard.py +5 -3
  17. huggingface_hub/repocard_data.py +11 -3
  18. huggingface_hub/serialization/__init__.py +19 -0
  19. huggingface_hub/serialization/_base.py +168 -0
  20. huggingface_hub/serialization/_numpy.py +67 -0
  21. huggingface_hub/serialization/_tensorflow.py +93 -0
  22. huggingface_hub/serialization/_torch.py +199 -0
  23. huggingface_hub/templates/datasetcard_template.md +1 -1
  24. huggingface_hub/templates/modelcard_template.md +1 -4
  25. huggingface_hub/utils/__init__.py +14 -10
  26. huggingface_hub/utils/_datetime.py +4 -11
  27. huggingface_hub/utils/_errors.py +29 -0
  28. huggingface_hub/utils/_runtime.py +21 -15
  29. huggingface_hub/utils/endpoint_helpers.py +27 -1
  30. {huggingface_hub-0.20.3.dist-info → huggingface_hub-0.21.0.dist-info}/METADATA +7 -3
  31. {huggingface_hub-0.20.3.dist-info → huggingface_hub-0.21.0.dist-info}/RECORD +35 -30
  32. {huggingface_hub-0.20.3.dist-info → huggingface_hub-0.21.0.dist-info}/LICENSE +0 -0
  33. {huggingface_hub-0.20.3.dist-info → huggingface_hub-0.21.0.dist-info}/WHEEL +0 -0
  34. {huggingface_hub-0.20.3.dist-info → huggingface_hub-0.21.0.dist-info}/entry_points.txt +0 -0
  35. {huggingface_hub-0.20.3.dist-info → huggingface_hub-0.21.0.dist-info}/top_level.txt +0 -0
@@ -10,6 +10,7 @@ from typing import Any, Dict, List, NoReturn, Optional, Tuple, Union
10
10
  from urllib.parse import quote, unquote
11
11
 
12
12
  import fsspec
13
+ from requests import Response
13
14
 
14
15
  from ._commit_api import CommitOperationCopy, CommitOperationDelete
15
16
  from .constants import DEFAULT_REVISION, ENDPOINT, REPO_TYPE_MODEL, REPO_TYPES_MAPPING, REPO_TYPES_URL_PREFIXES
@@ -216,11 +217,15 @@ class HfFileSystem(fsspec.AbstractFileSystem):
216
217
  path: str,
217
218
  mode: str = "rb",
218
219
  revision: Optional[str] = None,
220
+ block_size: Optional[int] = None,
219
221
  **kwargs,
220
222
  ) -> "HfFileSystemFile":
221
223
  if "a" in mode:
222
224
  raise NotImplementedError("Appending to remote files is not yet supported.")
223
- return HfFileSystemFile(self, path, mode=mode, revision=revision, **kwargs)
225
+ if block_size == 0:
226
+ return HfFileSystemStreamFile(self, path, mode=mode, revision=revision, block_size=block_size, **kwargs)
227
+ else:
228
+ return HfFileSystemFile(self, path, mode=mode, revision=revision, block_size=block_size, **kwargs)
224
229
 
225
230
  def _rm(self, path: str, revision: Optional[str] = None, **kwargs) -> None:
226
231
  resolved_path = self.resolve_path(path, revision=revision)
@@ -244,9 +249,8 @@ class HfFileSystem(fsspec.AbstractFileSystem):
244
249
  **kwargs,
245
250
  ) -> None:
246
251
  resolved_path = self.resolve_path(path, revision=revision)
247
- root_path = REPO_TYPES_URL_PREFIXES.get(resolved_path.repo_type, "") + resolved_path.repo_id
248
252
  paths = self.expand_path(path, recursive=recursive, maxdepth=maxdepth, revision=revision)
249
- paths_in_repo = [path[len(root_path) + 1 :] for path in paths if not self.isdir(path)]
253
+ paths_in_repo = [self.resolve_path(path).path_in_repo for path in paths if not self.isdir(path)]
250
254
  operations = [CommitOperationDelete(path_in_repo=path_in_repo) for path_in_repo in paths_in_repo]
251
255
  commit_message = f"Delete {path} "
252
256
  commit_message += "recursively " if recursive else ""
@@ -439,7 +443,7 @@ class HfFileSystem(fsspec.AbstractFileSystem):
439
443
  resolved_path1.repo_type == resolved_path2.repo_type and resolved_path1.repo_id == resolved_path2.repo_id
440
444
  )
441
445
 
442
- if same_repo and self.info(path1, revision=resolved_path1.revision)["lfs"] is not None:
446
+ if same_repo:
443
447
  commit_message = f"Copy {path1} to {path2}"
444
448
  self._api.create_commit(
445
449
  repo_id=resolved_path1.repo_id,
@@ -573,6 +577,20 @@ class HfFileSystem(fsspec.AbstractFileSystem):
573
577
  except: # noqa: E722
574
578
  return False
575
579
 
580
+ def url(self, path: str) -> str:
581
+ """Get the HTTP URL of the given path"""
582
+ resolved_path = self.resolve_path(path)
583
+ url = hf_hub_url(
584
+ resolved_path.repo_id,
585
+ resolved_path.path_in_repo,
586
+ repo_type=resolved_path.repo_type,
587
+ revision=resolved_path.revision,
588
+ endpoint=self.endpoint,
589
+ )
590
+ if self.isdir(path):
591
+ url = url.replace("/resolve/", "/tree/", 1)
592
+ return url
593
+
576
594
  @property
577
595
  def transaction(self):
578
596
  """A context within which files are committed together upon exit
@@ -593,9 +611,6 @@ class HfFileSystem(fsspec.AbstractFileSystem):
593
611
 
594
612
  class HfFileSystemFile(fsspec.spec.AbstractBufferedFile):
595
613
  def __init__(self, fs: HfFileSystem, path: str, revision: Optional[str] = None, **kwargs):
596
- super().__init__(fs, path, **kwargs)
597
- self.fs: HfFileSystem
598
-
599
614
  try:
600
615
  self.resolved_path = fs.resolve_path(path, revision=revision)
601
616
  except FileNotFoundError as e:
@@ -603,6 +618,8 @@ class HfFileSystemFile(fsspec.spec.AbstractBufferedFile):
603
618
  raise FileNotFoundError(
604
619
  f"{e}.\nMake sure the repository and revision exist before writing data."
605
620
  ) from e
621
+ super().__init__(fs, self.resolved_path.unresolve(), **kwargs)
622
+ self.fs: HfFileSystem
606
623
 
607
624
  def __del__(self):
608
625
  if not hasattr(self, "resolved_path"):
@@ -622,7 +639,7 @@ class HfFileSystemFile(fsspec.spec.AbstractBufferedFile):
622
639
  repo_type=self.resolved_path.repo_type,
623
640
  endpoint=self.fs.endpoint,
624
641
  )
625
- r = http_backoff("GET", url, headers=headers)
642
+ r = http_backoff("GET", url, headers=headers, retry_on_status_codes=(502, 503, 504))
626
643
  hf_raise_for_status(r)
627
644
  return r.content
628
645
 
@@ -650,6 +667,108 @@ class HfFileSystemFile(fsspec.spec.AbstractBufferedFile):
650
667
  path=self.resolved_path.unresolve(),
651
668
  )
652
669
 
670
+ def url(self) -> str:
671
+ return self.fs.url(self.path)
672
+
673
+
674
+ class HfFileSystemStreamFile(fsspec.spec.AbstractBufferedFile):
675
+ def __init__(
676
+ self,
677
+ fs: HfFileSystem,
678
+ path: str,
679
+ mode: str = "rb",
680
+ revision: Optional[str] = None,
681
+ block_size: int = 0,
682
+ cache_type: str = "none",
683
+ **kwargs,
684
+ ):
685
+ if block_size != 0:
686
+ raise ValueError(f"HfFileSystemStreamFile only supports block_size=0 but got {block_size}")
687
+ if cache_type != "none":
688
+ raise ValueError(f"HfFileSystemStreamFile only supports cache_type='none' but got {cache_type}")
689
+ if "w" in mode:
690
+ raise ValueError(f"HfFileSystemStreamFile only supports reading but got mode='{mode}'")
691
+ try:
692
+ self.resolved_path = fs.resolve_path(path, revision=revision)
693
+ except FileNotFoundError as e:
694
+ if "w" in kwargs.get("mode", ""):
695
+ raise FileNotFoundError(
696
+ f"{e}.\nMake sure the repository and revision exist before writing data."
697
+ ) from e
698
+ # avoid an unecessary .info() call to instantiate .details
699
+ self.details = {"name": self.resolved_path.unresolve(), "size": None}
700
+ super().__init__(
701
+ fs, self.resolved_path.unresolve(), mode=mode, block_size=block_size, cache_type=cache_type, **kwargs
702
+ )
703
+ self.response: Optional[Response] = None
704
+ self.fs: HfFileSystem
705
+
706
+ def seek(self, loc: int, whence: int = 0):
707
+ if loc == 0 and whence == 1:
708
+ return
709
+ if loc == self.loc and whence == 0:
710
+ return
711
+ raise ValueError("Cannot seek streaming HF file")
712
+
713
+ def read(self, length: int = -1):
714
+ read_args = (length,) if length >= 0 else ()
715
+ if self.response is None or self.response.raw.isclosed():
716
+ url = hf_hub_url(
717
+ repo_id=self.resolved_path.repo_id,
718
+ revision=self.resolved_path.revision,
719
+ filename=self.resolved_path.path_in_repo,
720
+ repo_type=self.resolved_path.repo_type,
721
+ endpoint=self.fs.endpoint,
722
+ )
723
+ self.response = http_backoff(
724
+ "GET",
725
+ url,
726
+ headers=self.fs._api._build_hf_headers(),
727
+ retry_on_status_codes=(502, 503, 504),
728
+ stream=True,
729
+ )
730
+ hf_raise_for_status(self.response)
731
+ try:
732
+ out = self.response.raw.read(*read_args)
733
+ except Exception:
734
+ self.response.close()
735
+
736
+ # Retry by recreating the connection
737
+ url = hf_hub_url(
738
+ repo_id=self.resolved_path.repo_id,
739
+ revision=self.resolved_path.revision,
740
+ filename=self.resolved_path.path_in_repo,
741
+ repo_type=self.resolved_path.repo_type,
742
+ endpoint=self.fs.endpoint,
743
+ )
744
+ self.response = http_backoff(
745
+ "GET",
746
+ url,
747
+ headers={"Range": "bytes=%d-" % self.loc, **self.fs._api._build_hf_headers()},
748
+ retry_on_status_codes=(502, 503, 504),
749
+ stream=True,
750
+ )
751
+ hf_raise_for_status(self.response)
752
+ try:
753
+ out = self.response.raw.read(*read_args)
754
+ except Exception:
755
+ self.response.close()
756
+ raise
757
+ self.loc += len(out)
758
+ return out
759
+
760
+ def url(self) -> str:
761
+ return self.fs.url(self.path)
762
+
763
+ def __del__(self):
764
+ if not hasattr(self, "resolved_path"):
765
+ # Means that the constructor failed. Nothing to do.
766
+ return
767
+ return super().__del__()
768
+
769
+ def __reduce__(self):
770
+ return reopen, (self.fs, self.path, self.mode, self.blocksize, self.cache.name)
771
+
653
772
 
654
773
  def safe_revision(revision: str) -> str:
655
774
  return revision if SPECIAL_REFS_REVISION_REGEX.match(revision) else safe_quote(revision)
@@ -668,3 +787,7 @@ def _raise_file_not_found(path: str, err: Optional[Exception]) -> NoReturn:
668
787
  elif isinstance(err, HFValidationError):
669
788
  msg = f"{path} (invalid repository id)"
670
789
  raise FileNotFoundError(msg) from err
790
+
791
+
792
+ def reopen(fs: HfFileSystem, path: str, mode: str, block_size: int, cache_type: str):
793
+ return fs.open(path, mode=mode, block_size=block_size, cache_type=cache_type)
@@ -1,17 +1,36 @@
1
+ import inspect
1
2
  import json
2
3
  import os
4
+ from dataclasses import asdict, is_dataclass
3
5
  from pathlib import Path
4
- from typing import Dict, List, Optional, Type, TypeVar, Union
6
+ from typing import TYPE_CHECKING, Dict, List, Optional, Type, TypeVar, Union, get_args
5
7
 
6
- from .constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME
7
- from .file_download import hf_hub_download, is_torch_available
8
+ from .constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME, SAFETENSORS_SINGLE_FILE
9
+ from .file_download import hf_hub_download
8
10
  from .hf_api import HfApi
9
- from .utils import HfHubHTTPError, SoftTemporaryDirectory, logging, validate_hf_hub_args
11
+ from .utils import (
12
+ EntryNotFoundError,
13
+ HfHubHTTPError,
14
+ SoftTemporaryDirectory,
15
+ is_safetensors_available,
16
+ is_torch_available,
17
+ logging,
18
+ validate_hf_hub_args,
19
+ )
20
+ from .utils._deprecation import _deprecate_arguments
10
21
 
11
22
 
23
+ if TYPE_CHECKING:
24
+ from _typeshed import DataclassInstance
25
+
12
26
  if is_torch_available():
13
27
  import torch # type: ignore
14
28
 
29
+ if is_safetensors_available():
30
+ from safetensors import safe_open
31
+ from safetensors.torch import save_file
32
+
33
+
15
34
  logger = logging.get_logger(__name__)
16
35
 
17
36
  # Generic variable that is either ModelHubMixin or a subclass thereof
@@ -25,16 +44,89 @@ class ModelHubMixin:
25
44
  To integrate your framework, your model class must inherit from this class. Custom logic for saving/loading models
26
45
  have to be overwritten in [`_from_pretrained`] and [`_save_pretrained`]. [`PyTorchModelHubMixin`] is a good example
27
46
  of mixin integration with the Hub. Check out our [integration guide](../guides/integrations) for more instructions.
47
+
48
+ Example:
49
+
50
+ ```python
51
+ >>> from dataclasses import dataclass
52
+ >>> from huggingface_hub import ModelHubMixin
53
+
54
+ # Define your model configuration (optional)
55
+ >>> @dataclass
56
+ ... class Config:
57
+ ... foo: int = 512
58
+ ... bar: str = "cpu"
59
+
60
+ # Inherit from ModelHubMixin (and optionally from your framework's model class)
61
+ >>> class MyCustomModel(ModelHubMixin):
62
+ ... def __init__(self, config: Config):
63
+ ... # define how to initialize your model
64
+ ... super().__init__()
65
+ ... ...
66
+ ...
67
+ ... def _save_pretrained(self, save_directory: Path) -> None:
68
+ ... # define how to serialize your model
69
+ ... ...
70
+ ...
71
+ ... @classmethod
72
+ ... def from_pretrained(
73
+ ... cls: Type[T],
74
+ ... pretrained_model_name_or_path: Union[str, Path],
75
+ ... *,
76
+ ... force_download: bool = False,
77
+ ... resume_download: bool = False,
78
+ ... proxies: Optional[Dict] = None,
79
+ ... token: Optional[Union[str, bool]] = None,
80
+ ... cache_dir: Optional[Union[str, Path]] = None,
81
+ ... local_files_only: bool = False,
82
+ ... revision: Optional[str] = None,
83
+ ... **model_kwargs,
84
+ ... ) -> T:
85
+ ... # define how to deserialize your model
86
+ ... ...
87
+
88
+ >>> model = MyCustomModel(config=Config(foo=256, bar="gpu"))
89
+
90
+ # Save model weights to local directory
91
+ >>> model.save_pretrained("my-awesome-model")
92
+
93
+ # Push model weights to the Hub
94
+ >>> model.push_to_hub("my-awesome-model")
95
+
96
+ # Download and initialize weights from the Hub
97
+ >>> reloaded_model = MyCustomModel.from_pretrained("username/my-awesome-model")
98
+ >>> reloaded_model.config
99
+ Config(foo=256, bar="gpu")
100
+ ```
28
101
  """
29
102
 
103
+ config: Optional[Union[dict, "DataclassInstance"]] = None
104
+ # ^ optional config attribute automatically set in `from_pretrained` (if not already set by the subclass)
105
+
106
+ def __new__(cls, *args, **kwargs) -> "ModelHubMixin":
107
+ instance = super().__new__(cls)
108
+
109
+ # Set `config` attribute if not already set by the subclass
110
+ if instance.config is None:
111
+ if "config" in kwargs:
112
+ instance.config = kwargs["config"]
113
+ elif len(args) > 0:
114
+ sig = inspect.signature(cls.__init__)
115
+ parameters = list(sig.parameters)[1:] # remove `self`
116
+ for key, value in zip(parameters, args):
117
+ if key == "config":
118
+ instance.config = value
119
+ break
120
+ return instance
121
+
30
122
  def save_pretrained(
31
123
  self,
32
124
  save_directory: Union[str, Path],
33
125
  *,
34
- config: Optional[dict] = None,
126
+ config: Optional[Union[dict, "DataclassInstance"]] = None,
35
127
  repo_id: Optional[str] = None,
36
128
  push_to_hub: bool = False,
37
- **kwargs,
129
+ **push_to_hub_kwargs,
38
130
  ) -> Optional[str]:
39
131
  """
40
132
  Save weights in local directory.
@@ -42,8 +134,8 @@ class ModelHubMixin:
42
134
  Args:
43
135
  save_directory (`str` or `Path`):
44
136
  Path to directory in which the model weights and configuration will be saved.
45
- config (`dict`, *optional*):
46
- Model configuration specified as a key/value dictionary.
137
+ config (`dict` or `DataclassInstance`, *optional*):
138
+ Model configuration specified as a key/value dictionary or a dataclass instance.
47
139
  push_to_hub (`bool`, *optional*, defaults to `False`):
48
140
  Whether or not to push your model to the Huggingface Hub after saving it.
49
141
  repo_id (`str`, *optional*):
@@ -55,15 +147,20 @@ class ModelHubMixin:
55
147
  save_directory = Path(save_directory)
56
148
  save_directory.mkdir(parents=True, exist_ok=True)
57
149
 
58
- # saving model weights/files
150
+ # save model weights/files (framework-specific)
59
151
  self._save_pretrained(save_directory)
60
152
 
61
- # saving config
62
- if isinstance(config, dict):
63
- (save_directory / CONFIG_NAME).write_text(json.dumps(config))
153
+ # save config (if provided)
154
+ if config is None:
155
+ config = self.config
156
+ if config is not None:
157
+ if is_dataclass(config):
158
+ config = asdict(config) # type: ignore[arg-type]
159
+ (save_directory / CONFIG_NAME).write_text(json.dumps(config, indent=2))
64
160
 
161
+ # push to the Hub if required
65
162
  if push_to_hub:
66
- kwargs = kwargs.copy() # soft-copy to avoid mutating input
163
+ kwargs = push_to_hub_kwargs.copy() # soft-copy to avoid mutating input
67
164
  if config is not None: # kwarg for `push_to_hub`
68
165
  kwargs["config"] = config
69
166
  if repo_id is None:
@@ -126,17 +223,17 @@ class ModelHubMixin:
126
223
  model_kwargs (`Dict`, *optional*):
127
224
  Additional kwargs to pass to the model during initialization.
128
225
  """
129
- model_id = pretrained_model_name_or_path
226
+ model_id = str(pretrained_model_name_or_path)
130
227
  config_file: Optional[str] = None
131
228
  if os.path.isdir(model_id):
132
229
  if CONFIG_NAME in os.listdir(model_id):
133
230
  config_file = os.path.join(model_id, CONFIG_NAME)
134
231
  else:
135
232
  logger.warning(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
136
- elif isinstance(model_id, str):
233
+ else:
137
234
  try:
138
235
  config_file = hf_hub_download(
139
- repo_id=str(model_id),
236
+ repo_id=model_id,
140
237
  filename=CONFIG_NAME,
141
238
  revision=revision,
142
239
  cache_dir=cache_dir,
@@ -146,15 +243,35 @@ class ModelHubMixin:
146
243
  token=token,
147
244
  local_files_only=local_files_only,
148
245
  )
149
- except HfHubHTTPError:
150
- logger.info(f"{CONFIG_NAME} not found in HuggingFace Hub.")
246
+ except HfHubHTTPError as e:
247
+ logger.info(f"{CONFIG_NAME} not found on the HuggingFace Hub: {str(e)}")
151
248
 
249
+ config = None
152
250
  if config_file is not None:
251
+ # Read config
153
252
  with open(config_file, "r", encoding="utf-8") as f:
154
253
  config = json.load(f)
155
- model_kwargs.update({"config": config})
156
254
 
157
- return cls._from_pretrained(
255
+ # Check if class expect a `config` argument
256
+ init_parameters = inspect.signature(cls.__init__).parameters
257
+ if "config" in init_parameters:
258
+ # Check if `config` argument is a dataclass
259
+ config_annotation = init_parameters["config"].annotation
260
+ if config_annotation is inspect.Parameter.empty:
261
+ pass # no annotation
262
+ elif is_dataclass(config_annotation):
263
+ config = config_annotation(**config) # expect a dataclass
264
+ else:
265
+ # if Optional/Union annotation => check if a dataclass is in the Union
266
+ for _sub_annotation in get_args(config_annotation):
267
+ if is_dataclass(_sub_annotation):
268
+ config = _sub_annotation(**config)
269
+ break
270
+
271
+ # Forward config to model initialization
272
+ model_kwargs["config"] = config
273
+
274
+ instance = cls._from_pretrained(
158
275
  model_id=str(model_id),
159
276
  revision=revision,
160
277
  cache_dir=cache_dir,
@@ -166,6 +283,13 @@ class ModelHubMixin:
166
283
  **model_kwargs,
167
284
  )
168
285
 
286
+ # Implicitly set the config as instance attribute if not already set by the class
287
+ # This way `config` will be available when calling `save_pretrained` or `push_to_hub`.
288
+ if config is not None and instance.config is None:
289
+ instance.config = config
290
+
291
+ return instance
292
+
169
293
  @classmethod
170
294
  def _from_pretrained(
171
295
  cls: Type[T],
@@ -215,21 +339,27 @@ class ModelHubMixin:
215
339
  """
216
340
  raise NotImplementedError
217
341
 
342
+ @_deprecate_arguments(
343
+ version="0.23.0",
344
+ deprecated_args=["api_endpoint"],
345
+ custom_message="Use `HF_ENDPOINT` environment variable instead.",
346
+ )
218
347
  @validate_hf_hub_args
219
348
  def push_to_hub(
220
349
  self,
221
350
  repo_id: str,
222
351
  *,
223
- config: Optional[dict] = None,
352
+ config: Optional[Union[dict, "DataclassInstance"]] = None,
224
353
  commit_message: str = "Push model using huggingface_hub.",
225
354
  private: bool = False,
226
- api_endpoint: Optional[str] = None,
227
355
  token: Optional[str] = None,
228
356
  branch: Optional[str] = None,
229
357
  create_pr: Optional[bool] = None,
230
358
  allow_patterns: Optional[Union[List[str], str]] = None,
231
359
  ignore_patterns: Optional[Union[List[str], str]] = None,
232
360
  delete_patterns: Optional[Union[List[str], str]] = None,
361
+ # TODO: remove once deprecated
362
+ api_endpoint: Optional[str] = None,
233
363
  ) -> str:
234
364
  """
235
365
  Upload model checkpoint to the Hub.
@@ -238,12 +368,11 @@ class ModelHubMixin:
238
368
  `delete_patterns` to delete existing remote files in the same commit. See [`upload_folder`] reference for more
239
369
  details.
240
370
 
241
-
242
371
  Args:
243
372
  repo_id (`str`):
244
373
  ID of the repository to push to (example: `"username/my-model"`).
245
- config (`dict`, *optional*):
246
- Configuration object to be saved alongside the model weights.
374
+ config (`dict` or `DataclassInstance`, *optional*):
375
+ Model configuration specified as a key/value dictionary or a dataclass instance.
247
376
  commit_message (`str`, *optional*):
248
377
  Message to commit while pushing.
249
378
  private (`bool`, *optional*, defaults to `False`):
@@ -296,16 +425,22 @@ class PyTorchModelHubMixin(ModelHubMixin):
296
425
  Example:
297
426
 
298
427
  ```python
428
+ >>> from dataclasses import dataclass
299
429
  >>> import torch
300
430
  >>> import torch.nn as nn
301
431
  >>> from huggingface_hub import PyTorchModelHubMixin
302
432
 
433
+ >>> @dataclass
434
+ ... class Config:
435
+ ... hidden_size: int = 512
436
+ ... vocab_size: int = 30000
437
+ ... output_size: int = 4
303
438
 
304
439
  >>> class MyModel(nn.Module, PyTorchModelHubMixin):
305
- ... def __init__(self):
440
+ ... def __init__(self, config: Config):
306
441
  ... super().__init__()
307
- ... self.param = nn.Parameter(torch.rand(3, 4))
308
- ... self.linear = nn.Linear(4, 5)
442
+ ... self.param = nn.Parameter(torch.rand(config.hidden_size, config.vocab_size))
443
+ ... self.linear = nn.Linear(config.output_size, config.vocab_size)
309
444
 
310
445
  ... def forward(self, x):
311
446
  ... return self.linear(x + self.param)
@@ -325,7 +460,7 @@ class PyTorchModelHubMixin(ModelHubMixin):
325
460
  def _save_pretrained(self, save_directory: Path) -> None:
326
461
  """Save weights from a Pytorch model to a local directory."""
327
462
  model_to_save = self.module if hasattr(self, "module") else self # type: ignore
328
- torch.save(model_to_save.state_dict(), save_directory / PYTORCH_WEIGHTS_NAME)
463
+ save_file(model_to_save.state_dict(), save_directory / SAFETENSORS_SINGLE_FILE)
329
464
 
330
465
  @classmethod
331
466
  def _from_pretrained(
@@ -344,25 +479,52 @@ class PyTorchModelHubMixin(ModelHubMixin):
344
479
  **model_kwargs,
345
480
  ):
346
481
  """Load Pytorch pretrained weights and return the loaded model."""
482
+ model = cls(**model_kwargs)
347
483
  if os.path.isdir(model_id):
348
484
  print("Loading weights from local directory")
349
- model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME)
485
+ model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
486
+ return cls._load_as_safetensor(model, model_file, map_location, strict)
350
487
  else:
351
- model_file = hf_hub_download(
352
- repo_id=model_id,
353
- filename=PYTORCH_WEIGHTS_NAME,
354
- revision=revision,
355
- cache_dir=cache_dir,
356
- force_download=force_download,
357
- proxies=proxies,
358
- resume_download=resume_download,
359
- token=token,
360
- local_files_only=local_files_only,
361
- )
362
- model = cls(**model_kwargs)
488
+ try:
489
+ model_file = hf_hub_download(
490
+ repo_id=model_id,
491
+ filename=SAFETENSORS_SINGLE_FILE,
492
+ revision=revision,
493
+ cache_dir=cache_dir,
494
+ force_download=force_download,
495
+ proxies=proxies,
496
+ resume_download=resume_download,
497
+ token=token,
498
+ local_files_only=local_files_only,
499
+ )
500
+ return cls._load_as_safetensor(model, model_file, map_location, strict)
501
+ except EntryNotFoundError:
502
+ model_file = hf_hub_download(
503
+ repo_id=model_id,
504
+ filename=PYTORCH_WEIGHTS_NAME,
505
+ revision=revision,
506
+ cache_dir=cache_dir,
507
+ force_download=force_download,
508
+ proxies=proxies,
509
+ resume_download=resume_download,
510
+ token=token,
511
+ local_files_only=local_files_only,
512
+ )
513
+ return cls._load_as_pickle(model, model_file, map_location, strict)
363
514
 
515
+ @classmethod
516
+ def _load_as_pickle(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
364
517
  state_dict = torch.load(model_file, map_location=torch.device(map_location))
365
518
  model.load_state_dict(state_dict, strict=strict) # type: ignore
366
519
  model.eval() # type: ignore
520
+ return model
367
521
 
522
+ @classmethod
523
+ def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
524
+ state_dict = {}
525
+ with safe_open(model_file, framework="pt", device=map_location) as f: # type: ignore [attr-defined]
526
+ for k in f.keys():
527
+ state_dict[k] = f.get_tensor(k)
528
+ model.load_state_dict(state_dict, strict=strict) # type: ignore
529
+ model.eval() # type: ignore
368
530
  return model