huggingface_hub-0.26.5-py3-none-any.whl → huggingface_hub-0.27.0-py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- huggingface_hub/__init__.py +49 -23
- huggingface_hub/_commit_scheduler.py +30 -4
- huggingface_hub/_local_folder.py +0 -4
- huggingface_hub/_login.py +38 -54
- huggingface_hub/_snapshot_download.py +6 -3
- huggingface_hub/_tensorboard_logger.py +2 -3
- huggingface_hub/_upload_large_folder.py +1 -1
- huggingface_hub/errors.py +19 -0
- huggingface_hub/fastai_utils.py +3 -2
- huggingface_hub/file_download.py +10 -12
- huggingface_hub/hf_api.py +102 -498
- huggingface_hub/hf_file_system.py +274 -35
- huggingface_hub/hub_mixin.py +5 -25
- huggingface_hub/inference/_client.py +185 -136
- huggingface_hub/inference/_common.py +2 -2
- huggingface_hub/inference/_generated/_async_client.py +186 -137
- huggingface_hub/inference/_generated/types/__init__.py +31 -10
- huggingface_hub/inference/_generated/types/audio_classification.py +3 -5
- huggingface_hub/inference/_generated/types/automatic_speech_recognition.py +6 -9
- huggingface_hub/inference/_generated/types/chat_completion.py +8 -5
- huggingface_hub/inference/_generated/types/depth_estimation.py +1 -1
- huggingface_hub/inference/_generated/types/document_question_answering.py +2 -6
- huggingface_hub/inference/_generated/types/feature_extraction.py +1 -1
- huggingface_hub/inference/_generated/types/fill_mask.py +2 -4
- huggingface_hub/inference/_generated/types/image_classification.py +3 -5
- huggingface_hub/inference/_generated/types/image_segmentation.py +2 -4
- huggingface_hub/inference/_generated/types/image_to_image.py +2 -4
- huggingface_hub/inference/_generated/types/image_to_text.py +6 -9
- huggingface_hub/inference/_generated/types/object_detection.py +2 -4
- huggingface_hub/inference/_generated/types/question_answering.py +2 -4
- huggingface_hub/inference/_generated/types/sentence_similarity.py +1 -1
- huggingface_hub/inference/_generated/types/summarization.py +2 -4
- huggingface_hub/inference/_generated/types/table_question_answering.py +21 -3
- huggingface_hub/inference/_generated/types/text2text_generation.py +2 -4
- huggingface_hub/inference/_generated/types/text_classification.py +4 -10
- huggingface_hub/inference/_generated/types/text_to_audio.py +7 -10
- huggingface_hub/inference/_generated/types/text_to_image.py +2 -4
- huggingface_hub/inference/_generated/types/text_to_speech.py +7 -10
- huggingface_hub/inference/_generated/types/token_classification.py +11 -12
- huggingface_hub/inference/_generated/types/translation.py +2 -4
- huggingface_hub/inference/_generated/types/video_classification.py +3 -4
- huggingface_hub/inference/_generated/types/visual_question_answering.py +2 -5
- huggingface_hub/inference/_generated/types/zero_shot_classification.py +8 -18
- huggingface_hub/inference/_generated/types/zero_shot_image_classification.py +9 -19
- huggingface_hub/inference/_generated/types/zero_shot_object_detection.py +7 -9
- huggingface_hub/keras_mixin.py +3 -2
- huggingface_hub/lfs.py +2 -5
- huggingface_hub/repocard_data.py +4 -4
- huggingface_hub/serialization/__init__.py +2 -0
- huggingface_hub/serialization/_dduf.py +387 -0
- huggingface_hub/serialization/_torch.py +372 -14
- huggingface_hub/utils/_cache_manager.py +1 -1
- huggingface_hub/utils/_headers.py +9 -25
- huggingface_hub/utils/tqdm.py +15 -0
- {huggingface_hub-0.26.5.dist-info → huggingface_hub-0.27.0.dist-info}/METADATA +8 -3
- {huggingface_hub-0.26.5.dist-info → huggingface_hub-0.27.0.dist-info}/RECORD +60 -60
- huggingface_hub/_multi_commits.py +0 -306
- {huggingface_hub-0.26.5.dist-info → huggingface_hub-0.27.0.dist-info}/LICENSE +0 -0
- {huggingface_hub-0.26.5.dist-info → huggingface_hub-0.27.0.dist-info}/WHEEL +0 -0
- {huggingface_hub-0.26.5.dist-info → huggingface_hub-0.27.0.dist-info}/entry_points.txt +0 -0
- {huggingface_hub-0.26.5.dist-info → huggingface_hub-0.27.0.dist-info}/top_level.txt +0 -0
huggingface_hub/serialization/_torch.py CHANGED

@@ -17,10 +17,12 @@ import importlib
 import json
 import os
 import re
-from collections import defaultdict
+from collections import defaultdict, namedtuple
 from functools import lru_cache
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, NamedTuple, Optional, Set, Tuple, Union
+
+from packaging import version
 
 from .. import constants, logging
 from ._base import MAX_SHARD_SIZE, StateDictSplit, split_state_dict_into_shards_factory
@@ -31,6 +33,8 @@ logger = logging.get_logger(__file__)
 if TYPE_CHECKING:
     import torch
 
+# SAVING
+
 
 def save_torch_model(
     model: "torch.nn.Module",
@@ -41,6 +45,7 @@ def save_torch_model(
     max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
     metadata: Optional[Dict[str, str]] = None,
     safe_serialization: bool = True,
+    is_main_process: bool = True,
     shared_tensors_to_discard: Optional[List[str]] = None,
 ):
     """
@@ -95,6 +100,10 @@
             Whether to save as safetensors, which is the default behavior. If `False`, the shards are saved as pickle.
             Safe serialization is recommended for security reasons. Saving as pickle is deprecated and will be removed
             in a future version.
+        is_main_process (`bool`, *optional*):
+            Whether the process calling this is the main process or not. Useful when in distributed training like
+            TPUs and need to call this function from all processes. In this case, set `is_main_process=True` only on
+            the main process to avoid race conditions. Defaults to True.
         shared_tensors_to_discard (`List[str]`, *optional*):
             List of tensor names to drop when saving shared tensors. If not provided and shared tensors are
             detected, it will drop the first name alphabetically.
@@ -122,6 +131,7 @@
         metadata=metadata,
         safe_serialization=safe_serialization,
        save_directory=save_directory,
+        is_main_process=is_main_process,
         shared_tensors_to_discard=shared_tensors_to_discard,
     )
 
@@ -135,6 +145,7 @@ def save_torch_state_dict(
     max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
     metadata: Optional[Dict[str, str]] = None,
     safe_serialization: bool = True,
+    is_main_process: bool = True,
     shared_tensors_to_discard: Optional[List[str]] = None,
 ) -> None:
     """
@@ -189,6 +200,10 @@
             Whether to save as safetensors, which is the default behavior. If `False`, the shards are saved as pickle.
             Safe serialization is recommended for security reasons. Saving as pickle is deprecated and will be removed
             in a future version.
+        is_main_process (`bool`, *optional*):
+            Whether the process calling this is the main process or not. Useful when in distributed training like
+            TPUs and need to call this function from all processes. In this case, set `is_main_process=True` only on
+            the main process to avoid race conditions. Defaults to True.
         shared_tensors_to_discard (`List[str]`, *optional*):
             List of tensor names to drop when saving shared tensors. If not provided and shared tensors are
             detected, it will drop the first name alphabetically.
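The race this flag guards against is concrete: `save_torch_state_dict` deletes stale weight files before writing new ones (see the next hunk), so if every worker in a distributed job ran that cleanup, one rank could delete shards that another rank had just written. A minimal sketch of the intended call pattern, assuming a `torch.distributed` job where rank 0 acts as the main process (the toy `Linear` module is illustrative):

```python
import torch
import torch.distributed as dist

from huggingface_hub import save_torch_model

model = torch.nn.Linear(4, 4)  # stand-in for a real model

# Every rank may call save_torch_model, but only the main process should
# perform the destructive cleanup of previously saved files.
is_main = not dist.is_initialized() or dist.get_rank() == 0
save_torch_model(model, "checkpoint/", is_main_process=is_main)
```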
@@ -243,15 +258,18 @@
         state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
     )
 
-    # Clean the folder from a previous save
-    existing_files_regex = re.compile(filename_pattern.format(suffix=r"(-\d{5}-of-\d{5})?") + r"(\.index\.json)?")
-    for filename in os.listdir(save_directory):
-        if existing_files_regex.match(filename):
-            try:
-                logger.debug(f"Removing existing file '{filename}' from folder.")
-                os.remove(os.path.join(save_directory, filename))
-            except Exception as e:
-                logger.warning(f"Error when trying to remove existing '{filename}' from folder: {e}. Continuing...")
+    # Only main process should clean up existing files to avoid race conditions in distributed environment
+    if is_main_process:
+        existing_files_regex = re.compile(filename_pattern.format(suffix=r"(-\d{5}-of-\d{5})?") + r"(\.index\.json)?")
+        for filename in os.listdir(save_directory):
+            if existing_files_regex.match(filename):
+                try:
+                    logger.debug(f"Removing existing file '{filename}' from folder.")
+                    os.remove(os.path.join(save_directory, filename))
+                except Exception as e:
+                    logger.warning(
+                        f"Error when trying to remove existing '{filename}' from folder: {e}. Continuing..."
+                    )
 
     # Save each shard
     per_file_metadata = {"format": "pt"}
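For reference, the cleanup regex compiled above matches exactly the artifacts a previous save could have produced for a given base pattern: the single-file checkpoint, numbered shards, and the shard index. A quick standalone check using the default `model{suffix}.safetensors` pattern named in the docstrings:

```python
import re

filename_pattern = "model{suffix}.safetensors"  # default pattern per the docstrings above
existing_files_regex = re.compile(
    filename_pattern.format(suffix=r"(-\d{5}-of-\d{5})?") + r"(\.index\.json)?"
)

for name in [
    "model.safetensors",                 # single-file checkpoint -> match
    "model-00001-of-00002.safetensors",  # shard -> match
    "model.safetensors.index.json",      # shard index -> match
    "training_args.bin",                 # unrelated file -> no match
]:
    print(f"{name}: {bool(existing_files_regex.match(name))}")
```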
@@ -357,6 +375,331 @@ def split_torch_state_dict_into_shards(
     )
 
 
+# LOADING
+
+
+def load_torch_model(
+    model: "torch.nn.Module",
+    checkpoint_path: Union[str, os.PathLike],
+    *,
+    strict: bool = False,
+    safe: bool = True,
+    weights_only: bool = False,
+    map_location: Optional[Union[str, "torch.device"]] = None,
+    mmap: bool = False,
+    filename_pattern: Optional[str] = None,
+) -> NamedTuple:
+    """
+    Load a checkpoint into a model, handling both sharded and non-sharded checkpoints.
+
+    Args:
+        model (`torch.nn.Module`):
+            The model in which to load the checkpoint.
+        checkpoint_path (`str` or `os.PathLike`):
+            Path to either the checkpoint file or directory containing the checkpoint(s).
+        strict (`bool`, *optional*, defaults to `False`):
+            Whether to strictly enforce that the keys in the model state dict match the keys in the checkpoint.
+        safe (`bool`, *optional*, defaults to `True`):
+            If `safe` is True, the safetensors files will be loaded. If `safe` is False, the function
+            will first attempt to load safetensors files if they are available, otherwise it will fall back to loading
+            pickle files. `filename_pattern` parameter takes precedence over `safe` parameter.
+        weights_only (`bool`, *optional*, defaults to `False`):
+            If True, only loads the model weights without optimizer states and other metadata.
+            Only supported in PyTorch >= 1.13.
+        map_location (`str` or `torch.device`, *optional*):
+            A `torch.device` object, string or a dict specifying how to remap storage locations. It
+            indicates the location where all tensors should be loaded.
+        mmap (`bool`, *optional*, defaults to `False`):
+            Whether to use memory-mapped file loading. Memory mapping can improve loading performance
+            for large models in PyTorch >= 2.1.0 with zipfile-based checkpoints.
+        filename_pattern (`str`, *optional*):
+            The pattern to look for the index file. Pattern must be a string that
+            can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix`
+            Defaults to `"model{suffix}.safetensors"`.
+    Returns:
+        `NamedTuple`: A named tuple with `missing_keys` and `unexpected_keys` fields.
+            - `missing_keys` is a list of str containing the missing keys, i.e. keys that are in the model but not in the checkpoint.
+            - `unexpected_keys` is a list of str containing the unexpected keys, i.e. keys that are in the checkpoint but not in the model.
+
+    Raises:
+        [`FileNotFoundError`](https://docs.python.org/3/library/exceptions.html#FileNotFoundError)
+            If the checkpoint file or directory does not exist.
+        [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError)
+            If safetensors or torch is not installed when trying to load a .safetensors file or a PyTorch checkpoint respectively.
+        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
+            If the checkpoint path is invalid or if the checkpoint format cannot be determined.
+
+    Example:
+    ```python
+    >>> from huggingface_hub import load_torch_model
+    >>> model = ... # A PyTorch model
+    >>> load_torch_model(model, "path/to/checkpoint")
+    ```
+    """
+    checkpoint_path = Path(checkpoint_path)
+
+    if not checkpoint_path.exists():
+        raise ValueError(f"Checkpoint path {checkpoint_path} does not exist")
+    # 1. Check if checkpoint is a single file
+    if checkpoint_path.is_file():
+        state_dict = load_state_dict_from_file(
+            checkpoint_file=checkpoint_path,
+            map_location=map_location,
+            weights_only=weights_only,
+        )
+        return model.load_state_dict(state_dict, strict=strict)
+
+    # 2. If not, checkpoint_path is a directory
+    if filename_pattern is None:
+        filename_pattern = constants.SAFETENSORS_WEIGHTS_FILE_PATTERN
+        index_path = checkpoint_path / (filename_pattern.format(suffix="") + ".index.json")
+        # Only fallback to pickle format if safetensors index is not found and safe is False.
+        if not index_path.is_file() and not safe:
+            filename_pattern = constants.PYTORCH_WEIGHTS_FILE_PATTERN
+
+    index_path = checkpoint_path / (filename_pattern.format(suffix="") + ".index.json")
+
+    if index_path.is_file():
+        return _load_sharded_checkpoint(
+            model=model,
+            save_directory=checkpoint_path,
+            strict=strict,
+            weights_only=weights_only,
+            filename_pattern=filename_pattern,
+        )
+
+    # Look for single model file
+    model_files = list(checkpoint_path.glob("*.safetensors" if safe else "*.bin"))
+    if len(model_files) == 1:
+        state_dict = load_state_dict_from_file(
+            checkpoint_file=model_files[0],
+            map_location=map_location,
+            weights_only=weights_only,
+            mmap=mmap,
+        )
+        return model.load_state_dict(state_dict, strict=strict)
+
+    raise ValueError(
+        f"Directory '{checkpoint_path}' does not contain a valid checkpoint. "
+        "Expected either a sharded checkpoint with an index file, or a single model file."
+    )
+
+
+def _load_sharded_checkpoint(
+    model: "torch.nn.Module",
+    save_directory: os.PathLike,
+    *,
+    strict: bool = False,
+    weights_only: bool = False,
+    filename_pattern: str = constants.SAFETENSORS_WEIGHTS_FILE_PATTERN,
+) -> NamedTuple:
+    """
+    Loads a sharded checkpoint into a model. This is the same as
+    [`torch.nn.Module.load_state_dict`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict)
+    but for a sharded checkpoint. Each shard is loaded one by one and removed from memory after being loaded into the model.
+
+    Args:
+        model (`torch.nn.Module`):
+            The model in which to load the checkpoint.
+        save_directory (`str` or `os.PathLike`):
+            A path to a folder containing the sharded checkpoint.
+        strict (`bool`, *optional*, defaults to `False`):
+            Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint.
+        weights_only (`bool`, *optional*, defaults to `False`):
+            If True, only loads the model weights without optimizer states and other metadata.
+            Only supported in PyTorch >= 1.13.
+        filename_pattern (`str`, *optional*, defaults to `"model{suffix}.safetensors"`):
+            The pattern to look for the index file. Pattern must be a string that
+            can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix`
+            Defaults to `"model{suffix}.safetensors"`.
+
+    Returns:
+        `NamedTuple`: A named tuple with `missing_keys` and `unexpected_keys` fields,
+            - `missing_keys` is a list of str containing the missing keys
+            - `unexpected_keys` is a list of str containing the unexpected keys
+    """
+
+    # 1. Load and validate index file
+    # The index file contains mapping of parameter names to shard files
+    index_path = filename_pattern.format(suffix="") + ".index.json"
+    index_file = os.path.join(save_directory, index_path)
+    with open(index_file, "r", encoding="utf-8") as f:
+        index = json.load(f)
+
+    # 2. Validate keys if in strict mode
+    # This is done before loading any shards to fail fast
+    if strict:
+        _validate_keys_for_strict_loading(model, index["weight_map"].keys())
+
+    # 3. Load each shard using `load_state_dict`
+    # Get unique shard files (multiple parameters can be in same shard)
+    shard_files = list(set(index["weight_map"].values()))
+    for shard_file in shard_files:
+        # Load shard into memory
+        shard_path = os.path.join(save_directory, shard_file)
+        state_dict = load_state_dict_from_file(
+            shard_path,
+            map_location="cpu",
+            weights_only=weights_only,
+        )
+        # Update model with parameters from this shard
+        model.load_state_dict(state_dict, strict=strict)
+        # Explicitly remove the state dict from memory
+        del state_dict
+
+    # 4. Return compatibility info
+    loaded_keys = set(index["weight_map"].keys())
+    model_keys = set(model.state_dict().keys())
+    return _IncompatibleKeys(
+        missing_keys=list(model_keys - loaded_keys), unexpected_keys=list(loaded_keys - model_keys)
+    )
+
+
+def load_state_dict_from_file(
+    checkpoint_file: Union[str, os.PathLike],
+    map_location: Optional[Union[str, "torch.device"]] = None,
+    weights_only: bool = False,
+    mmap: bool = False,
+) -> Union[Dict[str, "torch.Tensor"], Any]:
+    """
+    Loads a checkpoint file, handling both safetensors and pickle checkpoint formats.
+
+    Args:
+        checkpoint_file (`str` or `os.PathLike`):
+            Path to the checkpoint file to load. Can be either a safetensors or pickle (`.bin`) checkpoint.
+        map_location (`str` or `torch.device`, *optional*):
+            A `torch.device` object, string or a dict specifying how to remap storage locations. It
+            indicates the location where all tensors should be loaded.
+        weights_only (`bool`, *optional*, defaults to `False`):
+            If True, only loads the model weights without optimizer states and other metadata.
+            Only supported for pickle (`.bin`) checkpoints with PyTorch >= 1.13. Has no effect when
+            loading safetensors files.
+        mmap (`bool`, *optional*, defaults to `False`):
+            Whether to use memory-mapped file loading. Memory mapping can improve loading performance
+            for large models in PyTorch >= 2.1.0 with zipfile-based checkpoints. Has no effect when
+            loading safetensors files, as the `safetensors` library uses memory mapping by default.
+
+    Returns:
+        `Union[Dict[str, "torch.Tensor"], Any]`: The loaded checkpoint.
+            - For safetensors files: always returns a dictionary mapping parameter names to tensors.
+            - For pickle files: returns any Python object that was pickled (commonly a state dict, but could be
+              an entire model, optimizer state, or any other Python object).
+
+    Raises:
+        [`FileNotFoundError`](https://docs.python.org/3/library/exceptions.html#FileNotFoundError)
+            If the checkpoint file does not exist.
+        [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError)
+            If safetensors or torch is not installed when trying to load a .safetensors file or a PyTorch checkpoint respectively.
+        [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError)
+            If the checkpoint file format is invalid or if git-lfs files are not properly downloaded.
+        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
+            If the checkpoint file path is empty or invalid.
+
+    Example:
+    ```python
+    >>> from huggingface_hub import load_state_dict_from_file
+
+    # Load a PyTorch checkpoint
+    >>> state_dict = load_state_dict_from_file("path/to/model.bin", map_location="cpu")
+    >>> model.load_state_dict(state_dict)
+
+    # Load a safetensors checkpoint
+    >>> state_dict = load_state_dict_from_file("path/to/model.safetensors")
+    >>> model.load_state_dict(state_dict)
+    ```
+    """
+    checkpoint_path = Path(checkpoint_file)
+
+    # Check if file exists and is a regular file (not a directory)
+    if not checkpoint_path.is_file():
+        raise FileNotFoundError(
+            f"No checkpoint file found at '{checkpoint_path}'. Please verify the path is correct and "
+            "the file has been properly downloaded."
+        )
+
+    # Load safetensors checkpoint
+    if checkpoint_path.suffix == ".safetensors":
+        try:
+            from safetensors import safe_open
+            from safetensors.torch import load_file
+        except ImportError as e:
+            raise ImportError(
+                "Please install `safetensors` to load safetensors checkpoint. "
+                "You can install it with `pip install safetensors`."
+            ) from e
+
+        # Check format of the archive
+        with safe_open(checkpoint_file, framework="pt") as f:  # type: ignore[attr-defined]
+            metadata = f.metadata()
+        # see comment: https://github.com/huggingface/transformers/blob/3d213b57fe74302e5902d68ed9478c3ad1aaa713/src/transformers/modeling_utils.py#L3966
+        if metadata is not None and metadata.get("format") not in ["pt", "mlx"]:
+            raise OSError(
+                f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
+                "you save your model with the `save_torch_model` method."
+            )
+        device = str(map_location.type) if map_location is not None and hasattr(map_location, "type") else map_location
+        # meta device is not supported with safetensors, falling back to CPU
+        if device == "meta":
+            logger.warning("Meta device is not supported with safetensors. Falling back to CPU device.")
+            device = "cpu"
+        return load_file(checkpoint_file, device=device)  # type: ignore[arg-type]
+    # Otherwise, load from pickle
+    try:
+        import torch
+        from torch import load
+    except ImportError as e:
+        raise ImportError(
+            "Please install `torch` to load torch tensors. " "You can install it with `pip install torch`."
+        ) from e
+    # Add additional kwargs, mmap is only supported in torch >= 2.1.0
+    additional_kwargs = {}
+    if version.parse(torch.__version__) >= version.parse("2.1.0"):
+        additional_kwargs["mmap"] = mmap
+
+    # weights_only is only supported in torch >= 1.13.0
+    if version.parse(torch.__version__) >= version.parse("1.13.0"):
+        additional_kwargs["weights_only"] = weights_only
+
+    return load(
+        checkpoint_file,
+        map_location=map_location,
+        **additional_kwargs,
+    )
+
+
+# HELPERS
+
+
+def _validate_keys_for_strict_loading(
+    model: "torch.nn.Module",
+    loaded_keys: Iterable[str],
+) -> None:
+    """
+    Validate that model keys match loaded keys when strict loading is enabled.
+
+    Args:
+        model: The PyTorch model being loaded
+        loaded_keys: The keys present in the checkpoint
+
+    Raises:
+        RuntimeError: If there are missing or unexpected keys in strict mode
+    """
+    loaded_keys_set = set(loaded_keys)
+    model_keys = set(model.state_dict().keys())
+    missing_keys = model_keys - loaded_keys_set  # Keys in model but not in checkpoint
+    unexpected_keys = loaded_keys_set - model_keys  # Keys in checkpoint but not in model
+
+    if missing_keys or unexpected_keys:
+        error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}"
+        if missing_keys:
+            str_missing_keys = ",".join([f'"{k}"' for k in sorted(missing_keys)])
+            error_message += f"\nMissing key(s): {str_missing_keys}."
+        if unexpected_keys:
+            str_unexpected_keys = ",".join([f'"{k}"' for k in sorted(unexpected_keys)])
+            error_message += f"\nUnexpected key(s): {str_unexpected_keys}."
+        raise RuntimeError(error_message)
+
+
 def _get_unique_id(tensor: "torch.Tensor") -> Union[int, Tuple[Any, ...]]:
     """Returns a unique id for plain tensor
     or a (potentially nested) Tuple of unique id for the flattened Tensor
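Taken together, these three functions form the new loading path: `load_torch_model` resolves a file or directory, `_load_sharded_checkpoint` walks the index shard by shard, and `load_state_dict_from_file` handles a single file in either format. A small roundtrip sketch (the toy module is illustrative; any `nn.Module` behaves the same way):

```python
import torch

from huggingface_hub import load_torch_model, save_torch_model

model = torch.nn.Linear(8, 2)
save_torch_model(model, "checkpoint/")  # small model -> a single model.safetensors

# Reload into a freshly initialized copy of the same architecture.
fresh = torch.nn.Linear(8, 2)
result = load_torch_model(fresh, "checkpoint/")
print(result)  # <All keys matched successfully>

# The result is a namedtuple, so mismatches can be inspected directly.
missing_keys, unexpected_keys = result
```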
@@ -380,7 +723,7 @@ def _get_unique_id(tensor: "torch.Tensor") -> Union[int, Tuple[Any, ...]]:
         # use some other unique id to distinguish.
         # this is a XLA tensor, it must be created using torch_xla's
         # device. So the following import is safe:
-        import torch_xla
+        import torch_xla  # type: ignore[import]
 
         unique_id = torch_xla._XLAC._xla_get_tensor_id(tensor)
     else:
@@ -444,7 +787,7 @@ def is_torch_tpu_available(check_device=True):
     if check_device:
         # We need to check if `xla_device` can be found, will raise a RuntimeError if not
         try:
-            import torch_xla.core.xla_model as xm
+            import torch_xla.core.xla_model as xm  # type: ignore[import]
 
             _ = xm.xla_device()
             return True
@@ -463,7 +806,7 @@ def storage_ptr(tensor: "torch.Tensor") -> Union[int, Tuple[Any, ...]]:
         from torch.utils._python_dispatch import is_traceable_wrapper_subclass
 
         if is_traceable_wrapper_subclass(tensor):
-            return _get_unique_id(tensor)
+            return _get_unique_id(tensor)  # type: ignore
     except ImportError:
         # for torch version less than 2.1, we can fallback to original implementation
         pass
@@ -655,3 +998,18 @@ def _get_dtype_size(dtype: "torch.dtype") -> int:
         _float8_e5m2: 1,
     }
     return _SIZE[dtype]
+
+
+class _IncompatibleKeys(namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"])):
+    """
+    This is used to report missing and unexpected keys in the state dict.
+    Taken from https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/module.py#L52.
+
+    """
+
+    def __repr__(self) -> str:
+        if not self.missing_keys and not self.unexpected_keys:
+            return "<All keys matched successfully>"
+        return super().__repr__()
+
+    __str__ = __repr__
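The only behavior added on top of a plain namedtuple is the repr: an empty result reads as a success message instead of two empty lists. The same pattern, shown standalone for illustration:

```python
from collections import namedtuple


class IncompatibleKeys(namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"])):
    """Standalone copy of the private helper above, for illustration only."""

    def __repr__(self) -> str:
        if not self.missing_keys and not self.unexpected_keys:
            return "<All keys matched successfully>"
        return super().__repr__()


print(IncompatibleKeys([], []))           # <All keys matched successfully>
print(IncompatibleKeys(["fc.bias"], []))  # IncompatibleKeys(missing_keys=['fc.bias'], unexpected_keys=[])
```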
huggingface_hub/utils/_cache_manager.py CHANGED

@@ -742,7 +742,7 @@ def _scan_cached_repo(repo_path: Path) -> CachedRepoInfo:
 
     for ref_path in refs_path.glob("**/*"):
         # glob("**/*") iterates over all files and directories -> skip directories
-        if ref_path.is_dir():
+        if ref_path.is_dir() or ref_path.name in FILES_TO_IGNORE:
             continue
 
         ref_name = str(ref_path.relative_to(refs_path))
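This guard fixes cache scanning on machines where the OS drops metadata files into the `refs/` folder: previously a stray file such as `.DS_Store` was treated as a git ref. A sketch of the filter in isolation, assuming `FILES_TO_IGNORE` contains `".DS_Store"` (mirroring the constant in `_cache_manager.py`) and a typical cache layout:

```python
from pathlib import Path

FILES_TO_IGNORE = [".DS_Store"]  # assumed value, mirroring _cache_manager.py

refs_path = Path("~/.cache/huggingface/hub/models--gpt2/refs").expanduser()
for ref_path in refs_path.glob("**/*"):
    # Skip directories and OS metadata files; everything else is treated as a ref.
    if ref_path.is_dir() or ref_path.name in FILES_TO_IGNORE:
        continue
    print(ref_path.relative_to(refs_path))  # e.g. "main"
```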
huggingface_hub/utils/_headers.py CHANGED

@@ -20,6 +20,7 @@ from huggingface_hub.errors import LocalTokenNotFoundError
 
 from .. import constants
 from ._auth import get_token
+from ._deprecation import _deprecate_arguments
 from ._runtime import (
     get_fastai_version,
     get_fastcore_version,
@@ -35,15 +36,20 @@ from ._runtime import (
 from ._validators import validate_hf_hub_args
 
 
+@_deprecate_arguments(
+    version="1.0",
+    deprecated_args="is_write_action",
+    custom_message="This argument is ignored and we let the server handle the permission error instead (if any).",
+)
 @validate_hf_hub_args
 def build_hf_headers(
     *,
     token: Optional[Union[bool, str]] = None,
-    is_write_action: bool = False,
     library_name: Optional[str] = None,
     library_version: Optional[str] = None,
     user_agent: Union[Dict, str, None] = None,
     headers: Optional[Dict[str, str]] = None,
+    is_write_action: bool = False,
 ) -> Dict[str, str]:
     """
     Build headers dictionary to send in a HF Hub call.
@@ -68,9 +74,6 @@ def build_hf_headers(
             - if `False`, authorization header is not set
             - if `None`, the token is read from the machine only except if
               `HF_HUB_DISABLE_IMPLICIT_TOKEN` env variable is set.
-        is_write_action (`bool`, default to `False`):
-            Set to True if the API call requires a write access. If `True`, the token
-            will be validated (cannot be `None`, cannot start by `"api_org***"`).
         library_name (`str`, *optional*):
             The name of the library that is making the HTTP request. Will be added to
             the user-agent header.
@@ -83,6 +86,8 @@ def build_hf_headers(
         headers (`dict`, *optional*):
             Additional headers to include in the request. Those headers take precedence
             over the ones generated by this function.
+        is_write_action (`bool`):
+            Ignored and deprecated argument.
 
     Returns:
         A `Dict` of headers to pass in your API call.
@@ -105,9 +110,6 @@ def build_hf_headers(
     >>> build_hf_headers() # token is not sent
     {"user-agent": ...}
 
-    >>> build_hf_headers(token="api_org_***", is_write_action=True)
-    ValueError: You must use your personal account token for write-access methods.
-
     >>> build_hf_headers(library_name="transformers", library_version="1.2.3")
     {"authorization": ..., "user-agent": "transformers/1.2.3; hf_hub/0.10.2; python/3.10.4; tensorflow/1.55"}
     ```
@@ -122,7 +124,6 @@ def build_hf_headers(
     """
     # Get auth token to send
     token_to_send = get_token_to_send(token)
-    _validate_token_to_send(token_to_send, is_write_action=is_write_action)
 
     # Combine headers
     hf_headers = {
@@ -171,23 +172,6 @@ def get_token_to_send(token: Optional[Union[bool, str]]) -> Optional[str]:
         return cached_token
 
 
-def _validate_token_to_send(token: Optional[str], is_write_action: bool) -> None:
-    if is_write_action:
-        if token is None:
-            raise ValueError(
-                "Token is required (write-access action) but no token found. You need"
-                " to provide a token or be logged in to Hugging Face with"
-                " `huggingface-cli login` or `huggingface_hub.login`. See"
-                " https://huggingface.co/settings/tokens."
-            )
-        if token.startswith("api_org"):
-            raise ValueError(
-                "You must use your personal account token for write-access methods. To"
-                " generate a write-access token, go to"
-                " https://huggingface.co/settings/tokens"
-            )
-
-
 def _http_user_agent(
     *,
     library_name: Optional[str] = None,
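The net effect is that client-side token validation for write actions is gone: `is_write_action` is still accepted (now keyword-only at the end of the signature) but ignored, and permission failures surface as HTTP errors from the server instead of a local `ValueError`. A sketch of what callers see after the change (placeholder token):

```python
from huggingface_hub.utils import build_hf_headers

# Emits a deprecation warning but no longer raises: org tokens and missing
# tokens are passed through, and the server decides whether the call is allowed.
headers = build_hf_headers(token="api_org_xxxx", is_write_action=True)
print(headers["authorization"])  # Bearer api_org_xxxx
```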
huggingface_hub/utils/tqdm.py CHANGED

@@ -81,6 +81,8 @@ Group-based control:
 """
 
 import io
+import logging
+import os
 import warnings
 from contextlib import contextmanager
 from pathlib import Path
@@ -196,6 +198,19 @@ def are_progress_bars_disabled(name: Optional[str] = None) -> bool:
     return not progress_bar_states.get("_global", True)
 
 
+def is_tqdm_disabled(log_level: int) -> Optional[bool]:
+    """
+    Determine if tqdm progress bars should be disabled based on logging level and environment settings.
+
+    see https://github.com/huggingface/huggingface_hub/pull/2000 and https://github.com/huggingface/huggingface_hub/pull/2698.
+    """
+    if log_level == logging.NOTSET:
+        return True
+    if os.getenv("TQDM_POSITION") == "-1":
+        return False
+    return None
+
+
 class tqdm(old_tqdm):
     """
     Class to override `disable` argument in case progress bars are globally disabled.
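The helper encodes a three-way decision that `tqdm(disable=...)` already understands: `True` disables the bar, `False` forces it on, and `None` defers to tqdm's own heuristics (such as hiding bars on non-TTY output). A sketch of how a caller might consume it (the logger name is illustrative):

```python
import logging

from huggingface_hub.utils.tqdm import is_tqdm_disabled, tqdm

logger = logging.getLogger("huggingface_hub")

# Mapping implemented above:
#   level == logging.NOTSET -> True  (disable the bar)
#   TQDM_POSITION=-1        -> False (force the bar on)
#   anything else           -> None  (defer to tqdm's defaults)
for _ in tqdm(range(100), disable=is_tqdm_disabled(logger.getEffectiveLevel())):
    pass
```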
{huggingface_hub-0.26.5.dist-info → huggingface_hub-0.27.0.dist-info}/METADATA CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: huggingface-hub
-Version: 0.26.5
+Version: 0.27.0
 Summary: Client library to download and publish models, datasets and other repos on the huggingface.co hub
 Home-page: https://github.com/huggingface/huggingface_hub
 Author: Hugging Face, Inc.
@@ -20,6 +20,7 @@ Classifier: Programming Language :: Python :: 3.9
 Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
+Classifier: Programming Language :: Python :: 3.13
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.8.0
 Description-Content-Type: text/markdown
@@ -142,10 +143,14 @@ Requires-Dist: types-tqdm; extra == "typing"
 Requires-Dist: types-urllib3; extra == "typing"
 
 <p align="center">
+  <picture>
+    <source media="(prefers-color-scheme: dark)" srcset="https://huggingface.co/datasets/huggingface/documentation-images/raw/main/huggingface_hub-dark.svg">
+    <source media="(prefers-color-scheme: light)" srcset="https://huggingface.co/datasets/huggingface/documentation-images/raw/main/huggingface_hub.svg">
+    <img alt="huggingface_hub library logo" src="https://huggingface.co/datasets/huggingface/documentation-images/raw/main/huggingface_hub.svg" width="352" height="59" style="max-width: 100%;">
+  </picture>
   <br/>
-  <img alt="huggingface_hub library logo" src="https://huggingface.co/datasets/huggingface/documentation-images/raw/main/huggingface_hub.svg" width="376" height="59" style="max-width: 100%;">
   <br/>
-</p>
+</p>
 
 <p align="center">
   <i>The official Python client for the Huggingface Hub.</i>