huggingface-hub 0.26.3__py3-none-any.whl → 0.26.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of huggingface-hub might be problematic. See the package's registry page for more details.

@@ -46,7 +46,7 @@ import sys
46
46
  from typing import TYPE_CHECKING
47
47
 
48
48
 
49
- __version__ = "0.26.3"
49
+ __version__ = "0.26.4"
50
50
 
51
51
  # Alphabetical order of definitions is ensured in tests
52
52
  # WARNING: any comment added in this dictionary definition will be lost when
@@ -41,6 +41,8 @@ def save_torch_model(
41
41
  max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
42
42
  metadata: Optional[Dict[str, str]] = None,
43
43
  safe_serialization: bool = True,
44
+ is_main_process: bool = True,
45
+ shared_tensors_to_discard: Optional[List[str]] = None,
44
46
  ):
45
47
  """
46
48
  Saves a given torch model to disk, handling sharding and shared tensors issues.
@@ -64,6 +66,12 @@ def save_torch_model(
64
66
 
65
67
  </Tip>
66
68
 
69
+ <Tip warning={true}>
70
+
71
+ If your model is a `transformers.PreTrainedModel`, you should pass `model._tied_weights_keys` as `shared_tensors_to_discard` to properly handle shared tensors saving. This ensures the correct duplicate tensors are discarded during saving.
72
+
73
+ </Tip>
74
+
67
75
  Args:
68
76
  model (`torch.nn.Module`):
69
77
  The model to save on disk.
@@ -88,6 +96,13 @@ def save_torch_model(
88
96
  Whether to save as safetensors, which is the default behavior. If `False`, the shards are saved as pickle.
89
97
  Safe serialization is recommended for security reasons. Saving as pickle is deprecated and will be removed
90
98
  in a future version.
99
+ is_main_process (`bool`, *optional*):
100
+ Whether the process calling this is the main process or not. Useful when in distributed training like
101
+ TPUs and need to call this function from all processes. In this case, set `is_main_process=True` only on
102
+ the main process to avoid race conditions. Defaults to True.
103
+ shared_tensors_to_discard (`List[str]`, *optional*):
104
+ List of tensor names to drop when saving shared tensors. If not provided and shared tensors are
105
+ detected, it will drop the first name alphabetically.
91
106
 
92
107
  Example:
93
108
 
@@ -112,6 +127,8 @@ def save_torch_model(
112
127
  metadata=metadata,
113
128
  safe_serialization=safe_serialization,
114
129
  save_directory=save_directory,
130
+ is_main_process=is_main_process,
131
+ shared_tensors_to_discard=shared_tensors_to_discard,
115
132
  )
116
133
 
117
134
 
@@ -124,6 +141,8 @@ def save_torch_state_dict(
124
141
  max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
125
142
  metadata: Optional[Dict[str, str]] = None,
126
143
  safe_serialization: bool = True,
144
+ is_main_process: bool = True,
145
+ shared_tensors_to_discard: Optional[List[str]] = None,
127
146
  ) -> None:
128
147
  """
129
148
  Save a model state dictionary to the disk, handling sharding and shared tensors issues.
@@ -147,6 +166,12 @@ def save_torch_state_dict(
147
166
 
148
167
  </Tip>
149
168
 
169
+ <Tip warning={true}>
170
+
171
+ If your model is a `transformers.PreTrainedModel`, you should pass `model._tied_weights_keys` as `shared_tensors_to_discard` to properly handle shared tensors saving. This ensures the correct duplicate tensors are discarded during saving.
172
+
173
+ </Tip>
174
+
150
175
  Args:
151
176
  state_dict (`Dict[str, torch.Tensor]`):
152
177
  The state dictionary to save.
@@ -171,6 +196,13 @@ def save_torch_state_dict(
171
196
  Whether to save as safetensors, which is the default behavior. If `False`, the shards are saved as pickle.
172
197
  Safe serialization is recommended for security reasons. Saving as pickle is deprecated and will be removed
173
198
  in a future version.
199
+ is_main_process (`bool`, *optional*):
200
+ Whether the process calling this is the main process or not. Useful when in distributed training like
201
+ TPUs and need to call this function from all processes. In this case, set `is_main_process=True` only on
202
+ the main process to avoid race conditions. Defaults to True.
203
+ shared_tensors_to_discard (`List[str]`, *optional*):
204
+ List of tensor names to drop when saving shared tensors. If not provided and shared tensors are
205
+ detected, it will drop the first name alphabetically.
174
206
 
175
207
  Example:
176
208
 
@@ -192,7 +224,8 @@ def save_torch_state_dict(
192
224
  else constants.PYTORCH_WEIGHTS_FILE_PATTERN
193
225
  )
194
226
 
195
- # Imports correct library
227
+ if metadata is None:
228
+ metadata = {}
196
229
  if safe_serialization:
197
230
  try:
198
231
  from safetensors.torch import save_file as save_file_fn
@@ -201,7 +234,13 @@ def save_torch_state_dict(
201
234
  "Please install `safetensors` to use safe serialization. "
202
235
  "You can install it with `pip install safetensors`."
203
236
  ) from e
204
-
237
+ # Clean state dict for safetensors
238
+ state_dict = _clean_state_dict_for_safetensors(
239
+ state_dict,
240
+ metadata,
241
+ force_contiguous=force_contiguous,
242
+ shared_tensors_to_discard=shared_tensors_to_discard,
243
+ )
205
244
  else:
206
245
  from torch import save as save_file_fn # type: ignore[assignment]
207
246
 
@@ -210,13 +249,6 @@ def save_torch_state_dict(
210
249
  "pickled models from untrusted sources. If you intend to share your model, we strongly recommend "
211
250
  "using safe serialization by installing `safetensors` with `pip install safetensors`."
212
251
  )
213
-
214
- # Clean state dict for safetensors
215
- if metadata is None:
216
- metadata = {}
217
- if safe_serialization:
218
- state_dict = _clean_state_dict_for_safetensors(state_dict, metadata, force_contiguous=force_contiguous)
219
-
220
252
  # Split dict
221
253
  state_dict_split = split_torch_state_dict_into_shards(
222
254
  state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
@@ -459,7 +491,10 @@ def storage_ptr(tensor: "torch.Tensor") -> Union[int, Tuple[Any, ...]]:
459
491
 
460
492
 
461
493
  def _clean_state_dict_for_safetensors(
462
- state_dict: Dict[str, "torch.Tensor"], metadata: Dict[str, str], force_contiguous: bool = True
494
+ state_dict: Dict[str, "torch.Tensor"],
495
+ metadata: Dict[str, str],
496
+ force_contiguous: bool = True,
497
+ shared_tensors_to_discard: Optional[List[str]] = None,
463
498
  ):
464
499
  """Remove shared tensors from state_dict and update metadata accordingly (for reloading).
465
500
 
@@ -467,7 +502,7 @@ def _clean_state_dict_for_safetensors(
467
502
 
468
503
  Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L155.
469
504
  """
470
- to_removes = _remove_duplicate_names(state_dict)
505
+ to_removes = _remove_duplicate_names(state_dict, discard_names=shared_tensors_to_discard)
471
506
  for kept_name, to_remove_group in to_removes.items():
472
507
  for to_remove in to_remove_group:
473
508
  if metadata is None:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: huggingface-hub
3
- Version: 0.26.3
3
+ Version: 0.26.4
4
4
  Summary: Client library to download and publish models, datasets and other repos on the huggingface.co hub
5
5
  Home-page: https://github.com/huggingface/huggingface_hub
6
6
  Author: Hugging Face, Inc.
@@ -1,4 +1,4 @@
1
- huggingface_hub/__init__.py,sha256=TvADZbi2Jw5oeg_Mu-yZ7xo66rN_x7RTs1F6c9gt2aI,35993
1
+ huggingface_hub/__init__.py,sha256=kXk5sqwCRAGkjc6nnpI_Mu-LGaIoinBcbXytSzhca3s,35993
2
2
  huggingface_hub/_commit_api.py,sha256=Y9eTaW4bYzxtrZsSniVtfeAuFafqx8x1ofMI5es8hvM,31057
3
3
  huggingface_hub/_commit_scheduler.py,sha256=nlJS_vnLb8i92NLrRwJX8Mg9QZ7f3kfLbLlQuEd5YjU,13647
4
4
  huggingface_hub/_inference_endpoints.py,sha256=wzjD8P68VpUDHzIDbXzFXsM2Y-aNVSAap7BXsZFuthk,16750
@@ -79,7 +79,7 @@ huggingface_hub/inference/_generated/types/zero_shot_object_detection.py,sha256=
79
79
  huggingface_hub/serialization/__init__.py,sha256=z5MLxMqz0Y2qST-3Lj0PZHUONL-SGRlc0g4Z6MdL6rw,988
80
80
  huggingface_hub/serialization/_base.py,sha256=JZneES-HgcRH9C2SQehIGRDtT7nS7emu-RRV4ZjB6xo,8124
81
81
  huggingface_hub/serialization/_tensorflow.py,sha256=zHOvEMg-JHC55Fm4roDT3LUCDO5zB9qtXZffG065RAM,3625
82
- huggingface_hub/serialization/_torch.py,sha256=i6UFAHk1MDx_RONaXYolsISVa0V3a_YH-bdQtCYnmtg,26498
82
+ huggingface_hub/serialization/_torch.py,sha256=KlCRgLarzegkbfUmb73h82p2vDvgKrWw03ltQ7klI2Q,28685
83
83
  huggingface_hub/templates/datasetcard_template.md,sha256=W-EMqR6wndbrnZorkVv56URWPG49l7MATGeI015kTvs,5503
84
84
  huggingface_hub/templates/modelcard_template.md,sha256=4AqArS3cqdtbit5Bo-DhjcnDFR-pza5hErLLTPM4Yuc,6870
85
85
  huggingface_hub/utils/__init__.py,sha256=aMEsiXGi93z-dXz1W7FFma71tAMeKw0SoKVZSQUeE_4,3525
@@ -109,9 +109,9 @@ huggingface_hub/utils/insecure_hashlib.py,sha256=OjxlvtSQHpbLp9PWSrXBDJ0wHjxCBU-
109
109
  huggingface_hub/utils/logging.py,sha256=Cp03s0uEl3kDM9XHQW9a8GAoExODQ-e7kEtgMt-_To8,4728
110
110
  huggingface_hub/utils/sha.py,sha256=OFnNGCba0sNcT2gUwaVCJnldxlltrHHe0DS_PCpV3C4,2134
111
111
  huggingface_hub/utils/tqdm.py,sha256=jQiVYwRG78HK4_54u0vTtz6Kt9IMGiHy3ixbIn3h2TU,9368
112
- huggingface_hub-0.26.3.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
113
- huggingface_hub-0.26.3.dist-info/METADATA,sha256=xeabjkgCLGRol8NXGS0ftagVPKUfTaq2iaKfd_k5P9M,13091
114
- huggingface_hub-0.26.3.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
115
- huggingface_hub-0.26.3.dist-info/entry_points.txt,sha256=Y3Z2L02rBG7va_iE6RPXolIgwOdwUFONyRN3kXMxZ0g,131
116
- huggingface_hub-0.26.3.dist-info/top_level.txt,sha256=8KzlQJAY4miUvjAssOAJodqKOw3harNzuiwGQ9qLSSk,16
117
- huggingface_hub-0.26.3.dist-info/RECORD,,
112
+ huggingface_hub-0.26.4.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
113
+ huggingface_hub-0.26.4.dist-info/METADATA,sha256=PqF3HduxB5xJmjEI0-bMVfg2z72wgsioKo-ZHyH226g,13091
114
+ huggingface_hub-0.26.4.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
115
+ huggingface_hub-0.26.4.dist-info/entry_points.txt,sha256=Y3Z2L02rBG7va_iE6RPXolIgwOdwUFONyRN3kXMxZ0g,131
116
+ huggingface_hub-0.26.4.dist-info/top_level.txt,sha256=8KzlQJAY4miUvjAssOAJodqKOw3harNzuiwGQ9qLSSk,16
117
+ huggingface_hub-0.26.4.dist-info/RECORD,,