huggingface-hub 0.26.3__py3-none-any.whl → 0.27.0rc1__py3-none-any.whl
This diff shows the changes between these two publicly released package versions as they appear in their public registry, and is provided for informational purposes only.
Potentially problematic release.
- huggingface_hub/__init__.py +49 -23
- huggingface_hub/_commit_scheduler.py +30 -4
- huggingface_hub/_local_folder.py +0 -4
- huggingface_hub/_login.py +38 -54
- huggingface_hub/_snapshot_download.py +6 -3
- huggingface_hub/_tensorboard_logger.py +2 -3
- huggingface_hub/_upload_large_folder.py +1 -1
- huggingface_hub/errors.py +19 -0
- huggingface_hub/fastai_utils.py +3 -2
- huggingface_hub/file_download.py +10 -12
- huggingface_hub/hf_api.py +102 -498
- huggingface_hub/hf_file_system.py +274 -35
- huggingface_hub/hub_mixin.py +5 -25
- huggingface_hub/inference/_client.py +185 -136
- huggingface_hub/inference/_common.py +2 -2
- huggingface_hub/inference/_generated/_async_client.py +186 -137
- huggingface_hub/inference/_generated/types/__init__.py +31 -10
- huggingface_hub/inference/_generated/types/audio_classification.py +3 -5
- huggingface_hub/inference/_generated/types/automatic_speech_recognition.py +6 -9
- huggingface_hub/inference/_generated/types/chat_completion.py +8 -5
- huggingface_hub/inference/_generated/types/depth_estimation.py +1 -1
- huggingface_hub/inference/_generated/types/document_question_answering.py +2 -6
- huggingface_hub/inference/_generated/types/feature_extraction.py +1 -1
- huggingface_hub/inference/_generated/types/fill_mask.py +2 -4
- huggingface_hub/inference/_generated/types/image_classification.py +3 -5
- huggingface_hub/inference/_generated/types/image_segmentation.py +2 -4
- huggingface_hub/inference/_generated/types/image_to_image.py +2 -4
- huggingface_hub/inference/_generated/types/image_to_text.py +6 -9
- huggingface_hub/inference/_generated/types/object_detection.py +2 -4
- huggingface_hub/inference/_generated/types/question_answering.py +2 -4
- huggingface_hub/inference/_generated/types/sentence_similarity.py +1 -1
- huggingface_hub/inference/_generated/types/summarization.py +2 -4
- huggingface_hub/inference/_generated/types/table_question_answering.py +21 -3
- huggingface_hub/inference/_generated/types/text2text_generation.py +2 -4
- huggingface_hub/inference/_generated/types/text_classification.py +4 -10
- huggingface_hub/inference/_generated/types/text_to_audio.py +7 -10
- huggingface_hub/inference/_generated/types/text_to_image.py +2 -4
- huggingface_hub/inference/_generated/types/text_to_speech.py +7 -10
- huggingface_hub/inference/_generated/types/token_classification.py +11 -12
- huggingface_hub/inference/_generated/types/translation.py +2 -4
- huggingface_hub/inference/_generated/types/video_classification.py +3 -4
- huggingface_hub/inference/_generated/types/visual_question_answering.py +2 -5
- huggingface_hub/inference/_generated/types/zero_shot_classification.py +8 -18
- huggingface_hub/inference/_generated/types/zero_shot_image_classification.py +9 -19
- huggingface_hub/inference/_generated/types/zero_shot_object_detection.py +7 -9
- huggingface_hub/keras_mixin.py +3 -2
- huggingface_hub/lfs.py +2 -5
- huggingface_hub/repocard_data.py +4 -4
- huggingface_hub/serialization/__init__.py +2 -0
- huggingface_hub/serialization/_dduf.py +387 -0
- huggingface_hub/serialization/_torch.py +407 -25
- huggingface_hub/utils/_cache_manager.py +1 -1
- huggingface_hub/utils/_headers.py +9 -25
- huggingface_hub/utils/tqdm.py +15 -0
- {huggingface_hub-0.26.3.dist-info → huggingface_hub-0.27.0rc1.dist-info}/METADATA +8 -3
- {huggingface_hub-0.26.3.dist-info → huggingface_hub-0.27.0rc1.dist-info}/RECORD +60 -60
- huggingface_hub/_multi_commits.py +0 -306
- {huggingface_hub-0.26.3.dist-info → huggingface_hub-0.27.0rc1.dist-info}/LICENSE +0 -0
- {huggingface_hub-0.26.3.dist-info → huggingface_hub-0.27.0rc1.dist-info}/WHEEL +0 -0
- {huggingface_hub-0.26.3.dist-info → huggingface_hub-0.27.0rc1.dist-info}/entry_points.txt +0 -0
- {huggingface_hub-0.26.3.dist-info → huggingface_hub-0.27.0rc1.dist-info}/top_level.txt +0 -0
huggingface_hub/serialization/_torch.py:

```diff
@@ -17,10 +17,12 @@ import importlib
 import json
 import os
 import re
-from collections import defaultdict
+from collections import defaultdict, namedtuple
 from functools import lru_cache
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, NamedTuple, Optional, Set, Tuple, Union
+
+from packaging import version
 
 from .. import constants, logging
 from ._base import MAX_SHARD_SIZE, StateDictSplit, split_state_dict_into_shards_factory
@@ -31,6 +33,8 @@ logger = logging.get_logger(__file__)
 if TYPE_CHECKING:
     import torch
 
+# SAVING
+
 
 def save_torch_model(
     model: "torch.nn.Module",
@@ -41,6 +45,8 @@ def save_torch_model(
     max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
     metadata: Optional[Dict[str, str]] = None,
     safe_serialization: bool = True,
+    is_main_process: bool = True,
+    shared_tensors_to_discard: Optional[List[str]] = None,
 ):
     """
     Saves a given torch model to disk, handling sharding and shared tensors issues.
@@ -64,6 +70,12 @@ def save_torch_model(
 
     </Tip>
 
+    <Tip warning={true}>
+
+    If your model is a `transformers.PreTrainedModel`, you should pass `model._tied_weights_keys` as `shared_tensors_to_discard` to properly handle shared tensors saving. This ensures the correct duplicate tensors are discarded during saving.
+
+    </Tip>
+
     Args:
         model (`torch.nn.Module`):
             The model to save on disk.
@@ -88,6 +100,13 @@ def save_torch_model(
             Whether to save as safetensors, which is the default behavior. If `False`, the shards are saved as pickle.
             Safe serialization is recommended for security reasons. Saving as pickle is deprecated and will be removed
             in a future version.
+        is_main_process (`bool`, *optional*):
+            Whether the process calling this is the main process or not. Useful when in distributed training like
+            TPUs and need to call this function from all processes. In this case, set `is_main_process=True` only on
+            the main process to avoid race conditions. Defaults to True.
+        shared_tensors_to_discard (`List[str]`, *optional*):
+            List of tensor names to drop when saving shared tensors. If not provided and shared tensors are
+            detected, it will drop the first name alphabetically.
 
     Example:
 
@@ -112,6 +131,8 @@ def save_torch_model(
         metadata=metadata,
         safe_serialization=safe_serialization,
         save_directory=save_directory,
+        is_main_process=is_main_process,
+        shared_tensors_to_discard=shared_tensors_to_discard,
     )
 
 
@@ -124,6 +145,8 @@ def save_torch_state_dict(
     max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
     metadata: Optional[Dict[str, str]] = None,
     safe_serialization: bool = True,
+    is_main_process: bool = True,
+    shared_tensors_to_discard: Optional[List[str]] = None,
 ) -> None:
     """
     Save a model state dictionary to the disk, handling sharding and shared tensors issues.
@@ -147,6 +170,12 @@ def save_torch_state_dict(
 
     </Tip>
 
+    <Tip warning={true}>
+
+    If your model is a `transformers.PreTrainedModel`, you should pass `model._tied_weights_keys` as `shared_tensors_to_discard` to properly handle shared tensors saving. This ensures the correct duplicate tensors are discarded during saving.
+
+    </Tip>
+
     Args:
         state_dict (`Dict[str, torch.Tensor]`):
             The state dictionary to save.
@@ -171,6 +200,13 @@ def save_torch_state_dict(
             Whether to save as safetensors, which is the default behavior. If `False`, the shards are saved as pickle.
             Safe serialization is recommended for security reasons. Saving as pickle is deprecated and will be removed
             in a future version.
+        is_main_process (`bool`, *optional*):
+            Whether the process calling this is the main process or not. Useful when in distributed training like
+            TPUs and need to call this function from all processes. In this case, set `is_main_process=True` only on
+            the main process to avoid race conditions. Defaults to True.
+        shared_tensors_to_discard (`List[str]`, *optional*):
+            List of tensor names to drop when saving shared tensors. If not provided and shared tensors are
+            detected, it will drop the first name alphabetically.
 
     Example:
 
@@ -192,7 +228,8 @@ def save_torch_state_dict(
         else constants.PYTORCH_WEIGHTS_FILE_PATTERN
     )
 
-
+    if metadata is None:
+        metadata = {}
     if safe_serialization:
         try:
             from safetensors.torch import save_file as save_file_fn
@@ -201,7 +238,13 @@ def save_torch_state_dict(
                 "Please install `safetensors` to use safe serialization. "
                 "You can install it with `pip install safetensors`."
             ) from e
-
+        # Clean state dict for safetensors
+        state_dict = _clean_state_dict_for_safetensors(
+            state_dict,
+            metadata,
+            force_contiguous=force_contiguous,
+            shared_tensors_to_discard=shared_tensors_to_discard,
+        )
     else:
         from torch import save as save_file_fn  # type: ignore[assignment]
 
@@ -210,27 +253,23 @@ def save_torch_state_dict(
             "pickled models from untrusted sources. If you intend to share your model, we strongly recommend "
             "using safe serialization by installing `safetensors` with `pip install safetensors`."
         )
-
-    # Clean state dict for safetensors
-    if metadata is None:
-        metadata = {}
-    if safe_serialization:
-        state_dict = _clean_state_dict_for_safetensors(state_dict, metadata, force_contiguous=force_contiguous)
-
 
     # Split dict
     state_dict_split = split_torch_state_dict_into_shards(
         state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
    )
 
-    # [9 removed lines collapsed in the original diff view: the previous cleanup of existing checkpoint files, which ran on every process]
+    # Only main process should clean up existing files to avoid race conditions in distributed environment
+    if is_main_process:
+        existing_files_regex = re.compile(filename_pattern.format(suffix=r"(-\d{5}-of-\d{5})?") + r"(\.index\.json)?")
+        for filename in os.listdir(save_directory):
+            if existing_files_regex.match(filename):
+                try:
+                    logger.debug(f"Removing existing file '{filename}' from folder.")
+                    os.remove(os.path.join(save_directory, filename))
+                except Exception as e:
+                    logger.warning(
+                        f"Error when trying to remove existing '{filename}' from folder: {e}. Continuing..."
+                    )
 
     # Save each shard
     per_file_metadata = {"format": "pt"}
```
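Both saving entry points gain the same two keyword arguments, and `save_torch_model` simply forwards them to `save_torch_state_dict`. A minimal sketch of how they might be used (the `TiedModel` module and tensor names are illustrative, not from the diff; the default `safe_serialization=True` path requires `safetensors` to be installed):

```python
import torch

from huggingface_hub import save_torch_state_dict


class TiedModel(torch.nn.Module):
    """Toy model whose output head shares its weight tensor with the embedding."""

    def __init__(self) -> None:
        super().__init__()
        self.embed = torch.nn.Embedding(10, 4)
        self.lm_head = torch.nn.Linear(4, 10, bias=False)
        self.lm_head.weight = self.embed.weight  # tied: both keys point at one storage


model = TiedModel()
save_torch_state_dict(
    model.state_dict(),
    save_directory="checkpoint",
    # Name the duplicate to drop explicitly instead of relying on the
    # "first name alphabetically" default documented above:
    shared_tensors_to_discard=["lm_head.weight"],
    # In distributed training, pass True only on the main rank so a single
    # process deletes stale shard files and writes the new ones:
    is_main_process=True,
)
```

For a `transformers.PreTrainedModel`, the new docstring tip recommends passing `model._tied_weights_keys` as `shared_tensors_to_discard`.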
huggingface_hub/serialization/_torch.py (continued):

````diff
@@ -336,6 +375,331 @@ def split_torch_state_dict_into_shards(
     )
 
 
+# LOADING
+
+
+def load_torch_model(
+    model: "torch.nn.Module",
+    checkpoint_path: Union[str, os.PathLike],
+    *,
+    strict: bool = False,
+    safe: bool = True,
+    weights_only: bool = False,
+    map_location: Optional[Union[str, "torch.device"]] = None,
+    mmap: bool = False,
+    filename_pattern: Optional[str] = None,
+) -> NamedTuple:
+    """
+    Load a checkpoint into a model, handling both sharded and non-sharded checkpoints.
+
+    Args:
+        model (`torch.nn.Module`):
+            The model in which to load the checkpoint.
+        checkpoint_path (`str` or `os.PathLike`):
+            Path to either the checkpoint file or directory containing the checkpoint(s).
+        strict (`bool`, *optional*, defaults to `False`):
+            Whether to strictly enforce that the keys in the model state dict match the keys in the checkpoint.
+        safe (`bool`, *optional*, defaults to `True`):
+            If `safe` is True, the safetensors files will be loaded. If `safe` is False, the function
+            will first attempt to load safetensors files if they are available, otherwise it will fall back to loading
+            pickle files. `filename_pattern` parameter takes precedence over `safe` parameter.
+        weights_only (`bool`, *optional*, defaults to `False`):
+            If True, only loads the model weights without optimizer states and other metadata.
+            Only supported in PyTorch >= 1.13.
+        map_location (`str` or `torch.device`, *optional*):
+            A `torch.device` object, string or a dict specifying how to remap storage locations. It
+            indicates the location where all tensors should be loaded.
+        mmap (`bool`, *optional*, defaults to `False`):
+            Whether to use memory-mapped file loading. Memory mapping can improve loading performance
+            for large models in PyTorch >= 2.1.0 with zipfile-based checkpoints.
+        filename_pattern (`str`, *optional*):
+            The pattern to look for the index file. Pattern must be a string that
+            can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix`
+            Defaults to `"model{suffix}.safetensors"`.
+    Returns:
+        `NamedTuple`: A named tuple with `missing_keys` and `unexpected_keys` fields.
+            - `missing_keys` is a list of str containing the missing keys, i.e. keys that are in the model but not in the checkpoint.
+            - `unexpected_keys` is a list of str containing the unexpected keys, i.e. keys that are in the checkpoint but not in the model.
+
+    Raises:
+        [`FileNotFoundError`](https://docs.python.org/3/library/exceptions.html#FileNotFoundError)
+            If the checkpoint file or directory does not exist.
+        [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError)
+            If safetensors or torch is not installed when trying to load a .safetensors file or a PyTorch checkpoint respectively.
+        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
+            If the checkpoint path is invalid or if the checkpoint format cannot be determined.
+
+    Example:
+    ```python
+    >>> from huggingface_hub import load_torch_model
+    >>> model = ... # A PyTorch model
+    >>> load_torch_model(model, "path/to/checkpoint")
+    ```
+    """
+    checkpoint_path = Path(checkpoint_path)
+
+    if not checkpoint_path.exists():
+        raise ValueError(f"Checkpoint path {checkpoint_path} does not exist")
+    # 1. Check if checkpoint is a single file
+    if checkpoint_path.is_file():
+        state_dict = load_state_dict_from_file(
+            checkpoint_file=checkpoint_path,
+            map_location=map_location,
+            weights_only=weights_only,
+        )
+        return model.load_state_dict(state_dict, strict=strict)
+
+    # 2. If not, checkpoint_path is a directory
+    if filename_pattern is None:
+        filename_pattern = constants.SAFETENSORS_WEIGHTS_FILE_PATTERN
+        index_path = checkpoint_path / (filename_pattern.format(suffix="") + ".index.json")
+        # Only fallback to pickle format if safetensors index is not found and safe is False.
+        if not index_path.is_file() and not safe:
+            filename_pattern = constants.PYTORCH_WEIGHTS_FILE_PATTERN
+
+    index_path = checkpoint_path / (filename_pattern.format(suffix="") + ".index.json")
+
+    if index_path.is_file():
+        return _load_sharded_checkpoint(
+            model=model,
+            save_directory=checkpoint_path,
+            strict=strict,
+            weights_only=weights_only,
+            filename_pattern=filename_pattern,
+        )
+
+    # Look for single model file
+    model_files = list(checkpoint_path.glob("*.safetensors" if safe else "*.bin"))
+    if len(model_files) == 1:
+        state_dict = load_state_dict_from_file(
+            checkpoint_file=model_files[0],
+            map_location=map_location,
+            weights_only=weights_only,
+            mmap=mmap,
+        )
+        return model.load_state_dict(state_dict, strict=strict)
+
+    raise ValueError(
+        f"Directory '{checkpoint_path}' does not contain a valid checkpoint. "
+        "Expected either a sharded checkpoint with an index file, or a single model file."
+    )
+
+
+def _load_sharded_checkpoint(
+    model: "torch.nn.Module",
+    save_directory: os.PathLike,
+    *,
+    strict: bool = False,
+    weights_only: bool = False,
+    filename_pattern: str = constants.SAFETENSORS_WEIGHTS_FILE_PATTERN,
+) -> NamedTuple:
+    """
+    Loads a sharded checkpoint into a model. This is the same as
+    [`torch.nn.Module.load_state_dict`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict)
+    but for a sharded checkpoint. Each shard is loaded one by one and removed from memory after being loaded into the model.
+
+    Args:
+        model (`torch.nn.Module`):
+            The model in which to load the checkpoint.
+        save_directory (`str` or `os.PathLike`):
+            A path to a folder containing the sharded checkpoint.
+        strict (`bool`, *optional*, defaults to `False`):
+            Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint.
+        weights_only (`bool`, *optional*, defaults to `False`):
+            If True, only loads the model weights without optimizer states and other metadata.
+            Only supported in PyTorch >= 1.13.
+        filename_pattern (`str`, *optional*, defaults to `"model{suffix}.safetensors"`):
+            The pattern to look for the index file. Pattern must be a string that
+            can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix`
+            Defaults to `"model{suffix}.safetensors"`.
+
+    Returns:
+        `NamedTuple`: A named tuple with `missing_keys` and `unexpected_keys` fields,
+            - `missing_keys` is a list of str containing the missing keys
+            - `unexpected_keys` is a list of str containing the unexpected keys
+    """
+
+    # 1. Load and validate index file
+    # The index file contains mapping of parameter names to shard files
+    index_path = filename_pattern.format(suffix="") + ".index.json"
+    index_file = os.path.join(save_directory, index_path)
+    with open(index_file, "r", encoding="utf-8") as f:
+        index = json.load(f)
+
+    # 2. Validate keys if in strict mode
+    # This is done before loading any shards to fail fast
+    if strict:
+        _validate_keys_for_strict_loading(model, index["weight_map"].keys())
+
+    # 3. Load each shard using `load_state_dict`
+    # Get unique shard files (multiple parameters can be in same shard)
+    shard_files = list(set(index["weight_map"].values()))
+    for shard_file in shard_files:
+        # Load shard into memory
+        shard_path = os.path.join(save_directory, shard_file)
+        state_dict = load_state_dict_from_file(
+            shard_path,
+            map_location="cpu",
+            weights_only=weights_only,
+        )
+        # Update model with parameters from this shard
+        model.load_state_dict(state_dict, strict=strict)
+        # Explicitly remove the state dict from memory
+        del state_dict
+
+    # 4. Return compatibility info
+    loaded_keys = set(index["weight_map"].keys())
+    model_keys = set(model.state_dict().keys())
+    return _IncompatibleKeys(
+        missing_keys=list(model_keys - loaded_keys), unexpected_keys=list(loaded_keys - model_keys)
+    )
+
+
+def load_state_dict_from_file(
+    checkpoint_file: Union[str, os.PathLike],
+    map_location: Optional[Union[str, "torch.device"]] = None,
+    weights_only: bool = False,
+    mmap: bool = False,
+) -> Union[Dict[str, "torch.Tensor"], Any]:
+    """
+    Loads a checkpoint file, handling both safetensors and pickle checkpoint formats.
+
+    Args:
+        checkpoint_file (`str` or `os.PathLike`):
+            Path to the checkpoint file to load. Can be either a safetensors or pickle (`.bin`) checkpoint.
+        map_location (`str` or `torch.device`, *optional*):
+            A `torch.device` object, string or a dict specifying how to remap storage locations. It
+            indicates the location where all tensors should be loaded.
+        weights_only (`bool`, *optional*, defaults to `False`):
+            If True, only loads the model weights without optimizer states and other metadata.
+            Only supported for pickle (`.bin`) checkpoints with PyTorch >= 1.13. Has no effect when
+            loading safetensors files.
+        mmap (`bool`, *optional*, defaults to `False`):
+            Whether to use memory-mapped file loading. Memory mapping can improve loading performance
+            for large models in PyTorch >= 2.1.0 with zipfile-based checkpoints. Has no effect when
+            loading safetensors files, as the `safetensors` library uses memory mapping by default.
+
+    Returns:
+        `Union[Dict[str, "torch.Tensor"], Any]`: The loaded checkpoint.
+            - For safetensors files: always returns a dictionary mapping parameter names to tensors.
+            - For pickle files: returns any Python object that was pickled (commonly a state dict, but could be
+              an entire model, optimizer state, or any other Python object).
+
+    Raises:
+        [`FileNotFoundError`](https://docs.python.org/3/library/exceptions.html#FileNotFoundError)
+            If the checkpoint file does not exist.
+        [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError)
+            If safetensors or torch is not installed when trying to load a .safetensors file or a PyTorch checkpoint respectively.
+        [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError)
+            If the checkpoint file format is invalid or if git-lfs files are not properly downloaded.
+        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
+            If the checkpoint file path is empty or invalid.
+
+    Example:
+    ```python
+    >>> from huggingface_hub import load_state_dict_from_file
+
+    # Load a PyTorch checkpoint
+    >>> state_dict = load_state_dict_from_file("path/to/model.bin", map_location="cpu")
+    >>> model.load_state_dict(state_dict)
+
+    # Load a safetensors checkpoint
+    >>> state_dict = load_state_dict_from_file("path/to/model.safetensors")
+    >>> model.load_state_dict(state_dict)
+    ```
+    """
+    checkpoint_path = Path(checkpoint_file)
+
+    # Check if file exists and is a regular file (not a directory)
+    if not checkpoint_path.is_file():
+        raise FileNotFoundError(
+            f"No checkpoint file found at '{checkpoint_path}'. Please verify the path is correct and "
+            "the file has been properly downloaded."
+        )
+
+    # Load safetensors checkpoint
+    if checkpoint_path.suffix == ".safetensors":
+        try:
+            from safetensors import safe_open
+            from safetensors.torch import load_file
+        except ImportError as e:
+            raise ImportError(
+                "Please install `safetensors` to load safetensors checkpoint. "
+                "You can install it with `pip install safetensors`."
+            ) from e
+
+        # Check format of the archive
+        with safe_open(checkpoint_file, framework="pt") as f:  # type: ignore[attr-defined]
+            metadata = f.metadata()
+        # see comment: https://github.com/huggingface/transformers/blob/3d213b57fe74302e5902d68ed9478c3ad1aaa713/src/transformers/modeling_utils.py#L3966
+        if metadata is not None and metadata.get("format") not in ["pt", "mlx"]:
+            raise OSError(
+                f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
+                "you save your model with the `save_torch_model` method."
+            )
+        device = str(map_location.type) if map_location is not None and hasattr(map_location, "type") else map_location
+        # meta device is not supported with safetensors, falling back to CPU
+        if device == "meta":
+            logger.warning("Meta device is not supported with safetensors. Falling back to CPU device.")
+            device = "cpu"
+        return load_file(checkpoint_file, device=device)  # type: ignore[arg-type]
+    # Otherwise, load from pickle
+    try:
+        import torch
+        from torch import load
+    except ImportError as e:
+        raise ImportError(
+            "Please install `torch` to load torch tensors. " "You can install it with `pip install torch`."
+        ) from e
+    # Add additional kwargs, mmap is only supported in torch >= 2.1.0
+    additional_kwargs = {}
+    if version.parse(torch.__version__) >= version.parse("2.1.0"):
+        additional_kwargs["mmap"] = mmap
+
+    # weights_only is only supported in torch >= 1.13.0
+    if version.parse(torch.__version__) >= version.parse("1.13.0"):
+        additional_kwargs["weights_only"] = weights_only
+
+    return load(
+        checkpoint_file,
+        map_location=map_location,
+        **additional_kwargs,
+    )
+
+
+# HELPERS
+
+
+def _validate_keys_for_strict_loading(
+    model: "torch.nn.Module",
+    loaded_keys: Iterable[str],
+) -> None:
+    """
+    Validate that model keys match loaded keys when strict loading is enabled.
+
+    Args:
+        model: The PyTorch model being loaded
+        loaded_keys: The keys present in the checkpoint
+
+    Raises:
+        RuntimeError: If there are missing or unexpected keys in strict mode
+    """
+    loaded_keys_set = set(loaded_keys)
+    model_keys = set(model.state_dict().keys())
+    missing_keys = model_keys - loaded_keys_set  # Keys in model but not in checkpoint
+    unexpected_keys = loaded_keys_set - model_keys  # Keys in checkpoint but not in model
+
+    if missing_keys or unexpected_keys:
+        error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}"
+        if missing_keys:
+            str_missing_keys = ",".join([f'"{k}"' for k in sorted(missing_keys)])
+            error_message += f"\nMissing key(s): {str_missing_keys}."
+        if unexpected_keys:
+            str_unexpected_keys = ",".join([f'"{k}"' for k in sorted(unexpected_keys)])
+            error_message += f"\nUnexpected key(s): {str_unexpected_keys}."
+        raise RuntimeError(error_message)
+
+
 def _get_unique_id(tensor: "torch.Tensor") -> Union[int, Tuple[Any, ...]]:
     """Returns a unique id for plain tensor
     or a (potentially nested) Tuple of unique id for the flattened Tensor
````
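Read as a whole, `load_torch_model` dispatches on what `checkpoint_path` points to: a single file is loaded via `load_state_dict_from_file`, a directory containing an index file goes through `_load_sharded_checkpoint`, and a directory containing exactly one weights file is loaded directly. A usage sketch (the model and paths are placeholders; both functions are exported at the top level per the docstring examples above):

```python
import torch

from huggingface_hub import load_state_dict_from_file, load_torch_model

model = torch.nn.Linear(4, 10)  # placeholder: any torch.nn.Module

# Accepts a single checkpoint file, a sharded directory, or a directory with
# one weights file; returns the missing/unexpected key report.
result = load_torch_model(model, "checkpoint", strict=False)
print(result.missing_keys, result.unexpected_keys)

# Lower-level: load a raw state dict from one file (safetensors or pickle).
state_dict = load_state_dict_from_file("checkpoint/model.safetensors", map_location="cpu")
model.load_state_dict(state_dict, strict=False)
```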
huggingface_hub/serialization/_torch.py (continued):

```diff
@@ -359,7 +723,7 @@ def _get_unique_id(tensor: "torch.Tensor") -> Union[int, Tuple[Any, ...]]:
         # use some other unique id to distinguish.
         # this is a XLA tensor, it must be created using torch_xla's
         # device. So the following import is safe:
-        import torch_xla
+        import torch_xla  # type: ignore[import]
 
         unique_id = torch_xla._XLAC._xla_get_tensor_id(tensor)
     else:
@@ -423,7 +787,7 @@ def is_torch_tpu_available(check_device=True):
     if check_device:
         # We need to check if `xla_device` can be found, will raise a RuntimeError if not
         try:
-            import torch_xla.core.xla_model as xm
+            import torch_xla.core.xla_model as xm  # type: ignore[import]
 
             _ = xm.xla_device()
             return True
@@ -442,7 +806,7 @@ def storage_ptr(tensor: "torch.Tensor") -> Union[int, Tuple[Any, ...]]:
         from torch.utils._python_dispatch import is_traceable_wrapper_subclass
 
         if is_traceable_wrapper_subclass(tensor):
-            return _get_unique_id(tensor)
+            return _get_unique_id(tensor)  # type: ignore
     except ImportError:
         # for torch version less than 2.1, we can fallback to original implementation
         pass
@@ -459,7 +823,10 @@ def storage_ptr(tensor: "torch.Tensor") -> Union[int, Tuple[Any, ...]]:
 
 
 def _clean_state_dict_for_safetensors(
-    state_dict: Dict[str, "torch.Tensor"], metadata: Dict[str, str], force_contiguous: bool = True
+    state_dict: Dict[str, "torch.Tensor"],
+    metadata: Dict[str, str],
+    force_contiguous: bool = True,
+    shared_tensors_to_discard: Optional[List[str]] = None,
 ):
     """Remove shared tensors from state_dict and update metadata accordingly (for reloading).
 
@@ -467,7 +834,7 @@ def _clean_state_dict_for_safetensors(
 
     Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L155.
     """
-    to_removes = _remove_duplicate_names(state_dict)
+    to_removes = _remove_duplicate_names(state_dict, discard_names=shared_tensors_to_discard)
     for kept_name, to_remove_group in to_removes.items():
         for to_remove in to_remove_group:
             if metadata is None:
@@ -631,3 +998,18 @@ def _get_dtype_size(dtype: "torch.dtype") -> int:
         _float8_e5m2: 1,
     }
     return _SIZE[dtype]
+
+
+class _IncompatibleKeys(namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"])):
+    """
+    This is used to report missing and unexpected keys in the state dict.
+    Taken from https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/module.py#L52.
+
+    """
+
+    def __repr__(self) -> str:
+        if not self.missing_keys and not self.unexpected_keys:
+            return "<All keys matched successfully>"
+        return super().__repr__()
+
+    __str__ = __repr__
```
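`_IncompatibleKeys` only customizes the repr of the underlying named tuple. A quick sketch of what callers see (the class is private; it is constructed here directly just for illustration):

```python
# Assuming the _IncompatibleKeys class defined in the hunk above:
report = _IncompatibleKeys(missing_keys=[], unexpected_keys=[])
print(report)  # <All keys matched successfully>

report = _IncompatibleKeys(missing_keys=["lm_head.weight"], unexpected_keys=[])
print(report)  # namedtuple repr: _IncompatibleKeys(missing_keys=['lm_head.weight'], unexpected_keys=[])
```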
huggingface_hub/utils/_cache_manager.py:

```diff
@@ -742,7 +742,7 @@ def _scan_cached_repo(repo_path: Path) -> CachedRepoInfo:
 
         for ref_path in refs_path.glob("**/*"):
             # glob("**/*") iterates over all files and directories -> skip directories
-            if ref_path.is_dir():
+            if ref_path.is_dir() or ref_path.name in FILES_TO_IGNORE:
                 continue
 
             ref_name = str(ref_path.relative_to(refs_path))
```
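The one-line change in the cache scanner stops OS metadata files inside `refs/` from being read as git refs. A small sketch of the guard, assuming `FILES_TO_IGNORE` is the module-level ignore list in `_cache_manager.py` (it contains `.DS_Store`; the cache path below is illustrative):

```python
from pathlib import Path

FILES_TO_IGNORE = [".DS_Store"]  # assumption: mirrors the constant in _cache_manager.py

refs_path = Path.home() / ".cache/huggingface/hub/models--gpt2/refs"
for ref_path in refs_path.glob("**/*"):
    # Skip directories and OS metadata files instead of parsing them as refs
    if ref_path.is_dir() or ref_path.name in FILES_TO_IGNORE:
        continue
    print(ref_path.relative_to(refs_path), "->", ref_path.read_text().strip())
```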
huggingface_hub/utils/_headers.py:

````diff
@@ -20,6 +20,7 @@ from huggingface_hub.errors import LocalTokenNotFoundError
 
 from .. import constants
 from ._auth import get_token
+from ._deprecation import _deprecate_arguments
 from ._runtime import (
     get_fastai_version,
     get_fastcore_version,
@@ -35,15 +36,20 @@ from ._runtime import (
 from ._validators import validate_hf_hub_args
 
 
+@_deprecate_arguments(
+    version="1.0",
+    deprecated_args="is_write_action",
+    custom_message="This argument is ignored and we let the server handle the permission error instead (if any).",
+)
 @validate_hf_hub_args
 def build_hf_headers(
     *,
     token: Optional[Union[bool, str]] = None,
-    is_write_action: bool = False,
     library_name: Optional[str] = None,
     library_version: Optional[str] = None,
     user_agent: Union[Dict, str, None] = None,
     headers: Optional[Dict[str, str]] = None,
+    is_write_action: bool = False,
 ) -> Dict[str, str]:
     """
     Build headers dictionary to send in a HF Hub call.
@@ -68,9 +74,6 @@ def build_hf_headers(
             - if `False`, authorization header is not set
             - if `None`, the token is read from the machine only except if
               `HF_HUB_DISABLE_IMPLICIT_TOKEN` env variable is set.
-        is_write_action (`bool`, default to `False`):
-            Set to True if the API call requires a write access. If `True`, the token
-            will be validated (cannot be `None`, cannot start by `"api_org***"`).
         library_name (`str`, *optional*):
             The name of the library that is making the HTTP request. Will be added to
             the user-agent header.
@@ -83,6 +86,8 @@ def build_hf_headers(
         headers (`dict`, *optional*):
             Additional headers to include in the request. Those headers take precedence
             over the ones generated by this function.
+        is_write_action (`bool`):
+            Ignored and deprecated argument.
 
     Returns:
         A `Dict` of headers to pass in your API call.
@@ -105,9 +110,6 @@ def build_hf_headers(
     >>> build_hf_headers() # token is not sent
     {"user-agent": ...}
 
-    >>> build_hf_headers(token="api_org_***", is_write_action=True)
-    ValueError: You must use your personal account token for write-access methods.
-
     >>> build_hf_headers(library_name="transformers", library_version="1.2.3")
     {"authorization": ..., "user-agent": "transformers/1.2.3; hf_hub/0.10.2; python/3.10.4; tensorflow/1.55"}
     ```
@@ -122,7 +124,6 @@ def build_hf_headers(
     """
     # Get auth token to send
     token_to_send = get_token_to_send(token)
-    _validate_token_to_send(token_to_send, is_write_action=is_write_action)
 
     # Combine headers
     hf_headers = {
@@ -171,23 +172,6 @@ def get_token_to_send(token: Optional[Union[bool, str]]) -> Optional[str]:
         return cached_token
 
 
-def _validate_token_to_send(token: Optional[str], is_write_action: bool) -> None:
-    if is_write_action:
-        if token is None:
-            raise ValueError(
-                "Token is required (write-access action) but no token found. You need"
-                " to provide a token or be logged in to Hugging Face with"
-                " `huggingface-cli login` or `huggingface_hub.login`. See"
-                " https://huggingface.co/settings/tokens."
-            )
-        if token.startswith("api_org"):
-            raise ValueError(
-                "You must use your personal account token for write-access methods. To"
-                " generate a write-access token, go to"
-                " https://huggingface.co/settings/tokens"
-            )
-
-
 def _http_user_agent(
     *,
     library_name: Optional[str] = None,
````