huggingface-hub 0.26.3__py3-none-any.whl → 0.27.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of huggingface-hub might be problematic. Click here for more details.

Files changed (61) hide show
  1. huggingface_hub/__init__.py +49 -23
  2. huggingface_hub/_commit_scheduler.py +30 -4
  3. huggingface_hub/_local_folder.py +0 -4
  4. huggingface_hub/_login.py +38 -54
  5. huggingface_hub/_snapshot_download.py +6 -3
  6. huggingface_hub/_tensorboard_logger.py +2 -3
  7. huggingface_hub/_upload_large_folder.py +1 -1
  8. huggingface_hub/errors.py +19 -0
  9. huggingface_hub/fastai_utils.py +3 -2
  10. huggingface_hub/file_download.py +10 -12
  11. huggingface_hub/hf_api.py +102 -498
  12. huggingface_hub/hf_file_system.py +274 -35
  13. huggingface_hub/hub_mixin.py +5 -25
  14. huggingface_hub/inference/_client.py +185 -136
  15. huggingface_hub/inference/_common.py +2 -2
  16. huggingface_hub/inference/_generated/_async_client.py +186 -137
  17. huggingface_hub/inference/_generated/types/__init__.py +31 -10
  18. huggingface_hub/inference/_generated/types/audio_classification.py +3 -5
  19. huggingface_hub/inference/_generated/types/automatic_speech_recognition.py +6 -9
  20. huggingface_hub/inference/_generated/types/chat_completion.py +8 -5
  21. huggingface_hub/inference/_generated/types/depth_estimation.py +1 -1
  22. huggingface_hub/inference/_generated/types/document_question_answering.py +2 -6
  23. huggingface_hub/inference/_generated/types/feature_extraction.py +1 -1
  24. huggingface_hub/inference/_generated/types/fill_mask.py +2 -4
  25. huggingface_hub/inference/_generated/types/image_classification.py +3 -5
  26. huggingface_hub/inference/_generated/types/image_segmentation.py +2 -4
  27. huggingface_hub/inference/_generated/types/image_to_image.py +2 -4
  28. huggingface_hub/inference/_generated/types/image_to_text.py +6 -9
  29. huggingface_hub/inference/_generated/types/object_detection.py +2 -4
  30. huggingface_hub/inference/_generated/types/question_answering.py +2 -4
  31. huggingface_hub/inference/_generated/types/sentence_similarity.py +1 -1
  32. huggingface_hub/inference/_generated/types/summarization.py +2 -4
  33. huggingface_hub/inference/_generated/types/table_question_answering.py +21 -3
  34. huggingface_hub/inference/_generated/types/text2text_generation.py +2 -4
  35. huggingface_hub/inference/_generated/types/text_classification.py +4 -10
  36. huggingface_hub/inference/_generated/types/text_to_audio.py +7 -10
  37. huggingface_hub/inference/_generated/types/text_to_image.py +2 -4
  38. huggingface_hub/inference/_generated/types/text_to_speech.py +7 -10
  39. huggingface_hub/inference/_generated/types/token_classification.py +11 -12
  40. huggingface_hub/inference/_generated/types/translation.py +2 -4
  41. huggingface_hub/inference/_generated/types/video_classification.py +3 -4
  42. huggingface_hub/inference/_generated/types/visual_question_answering.py +2 -5
  43. huggingface_hub/inference/_generated/types/zero_shot_classification.py +8 -18
  44. huggingface_hub/inference/_generated/types/zero_shot_image_classification.py +9 -19
  45. huggingface_hub/inference/_generated/types/zero_shot_object_detection.py +7 -9
  46. huggingface_hub/keras_mixin.py +3 -2
  47. huggingface_hub/lfs.py +2 -5
  48. huggingface_hub/repocard_data.py +4 -4
  49. huggingface_hub/serialization/__init__.py +2 -0
  50. huggingface_hub/serialization/_dduf.py +387 -0
  51. huggingface_hub/serialization/_torch.py +407 -25
  52. huggingface_hub/utils/_cache_manager.py +1 -1
  53. huggingface_hub/utils/_headers.py +9 -25
  54. huggingface_hub/utils/tqdm.py +15 -0
  55. {huggingface_hub-0.26.3.dist-info → huggingface_hub-0.27.0rc1.dist-info}/METADATA +8 -3
  56. {huggingface_hub-0.26.3.dist-info → huggingface_hub-0.27.0rc1.dist-info}/RECORD +60 -60
  57. huggingface_hub/_multi_commits.py +0 -306
  58. {huggingface_hub-0.26.3.dist-info → huggingface_hub-0.27.0rc1.dist-info}/LICENSE +0 -0
  59. {huggingface_hub-0.26.3.dist-info → huggingface_hub-0.27.0rc1.dist-info}/WHEEL +0 -0
  60. {huggingface_hub-0.26.3.dist-info → huggingface_hub-0.27.0rc1.dist-info}/entry_points.txt +0 -0
  61. {huggingface_hub-0.26.3.dist-info → huggingface_hub-0.27.0rc1.dist-info}/top_level.txt +0 -0
@@ -17,10 +17,12 @@ import importlib
17
17
  import json
18
18
  import os
19
19
  import re
20
- from collections import defaultdict
20
+ from collections import defaultdict, namedtuple
21
21
  from functools import lru_cache
22
22
  from pathlib import Path
23
- from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
23
+ from typing import TYPE_CHECKING, Any, Dict, Iterable, List, NamedTuple, Optional, Set, Tuple, Union
24
+
25
+ from packaging import version
24
26
 
25
27
  from .. import constants, logging
26
28
  from ._base import MAX_SHARD_SIZE, StateDictSplit, split_state_dict_into_shards_factory
@@ -31,6 +33,8 @@ logger = logging.get_logger(__file__)
31
33
  if TYPE_CHECKING:
32
34
  import torch
33
35
 
36
+ # SAVING
37
+
34
38
 
35
39
  def save_torch_model(
36
40
  model: "torch.nn.Module",
@@ -41,6 +45,8 @@ def save_torch_model(
41
45
  max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
42
46
  metadata: Optional[Dict[str, str]] = None,
43
47
  safe_serialization: bool = True,
48
+ is_main_process: bool = True,
49
+ shared_tensors_to_discard: Optional[List[str]] = None,
44
50
  ):
45
51
  """
46
52
  Saves a given torch model to disk, handling sharding and shared tensors issues.
@@ -64,6 +70,12 @@ def save_torch_model(
64
70
 
65
71
  </Tip>
66
72
 
73
+ <Tip warning={true}>
74
+
75
+ If your model is a `transformers.PreTrainedModel`, you should pass `model._tied_weights_keys` as `shared_tensors_to_discard` to properly handle shared tensors saving. This ensures the correct duplicate tensors are discarded during saving.
76
+
77
+ </Tip>
78
+
67
79
  Args:
68
80
  model (`torch.nn.Module`):
69
81
  The model to save on disk.
@@ -88,6 +100,13 @@ def save_torch_model(
88
100
  Whether to save as safetensors, which is the default behavior. If `False`, the shards are saved as pickle.
89
101
  Safe serialization is recommended for security reasons. Saving as pickle is deprecated and will be removed
90
102
  in a future version.
103
+ is_main_process (`bool`, *optional*):
104
+ Whether the process calling this is the main process or not. Useful when in distributed training like
105
+ TPUs and need to call this function from all processes. In this case, set `is_main_process=True` only on
106
+ the main process to avoid race conditions. Defaults to True.
107
+ shared_tensors_to_discard (`List[str]`, *optional*):
108
+ List of tensor names to drop when saving shared tensors. If not provided and shared tensors are
109
+ detected, it will drop the first name alphabetically.
91
110
 
92
111
  Example:
93
112
 
@@ -112,6 +131,8 @@ def save_torch_model(
112
131
  metadata=metadata,
113
132
  safe_serialization=safe_serialization,
114
133
  save_directory=save_directory,
134
+ is_main_process=is_main_process,
135
+ shared_tensors_to_discard=shared_tensors_to_discard,
115
136
  )
116
137
 
117
138
 
@@ -124,6 +145,8 @@ def save_torch_state_dict(
124
145
  max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
125
146
  metadata: Optional[Dict[str, str]] = None,
126
147
  safe_serialization: bool = True,
148
+ is_main_process: bool = True,
149
+ shared_tensors_to_discard: Optional[List[str]] = None,
127
150
  ) -> None:
128
151
  """
129
152
  Save a model state dictionary to the disk, handling sharding and shared tensors issues.
@@ -147,6 +170,12 @@ def save_torch_state_dict(
147
170
 
148
171
  </Tip>
149
172
 
173
+ <Tip warning={true}>
174
+
175
+ If your model is a `transformers.PreTrainedModel`, you should pass `model._tied_weights_keys` as `shared_tensors_to_discard` to properly handle shared tensors saving. This ensures the correct duplicate tensors are discarded during saving.
176
+
177
+ </Tip>
178
+
150
179
  Args:
151
180
  state_dict (`Dict[str, torch.Tensor]`):
152
181
  The state dictionary to save.
@@ -171,6 +200,13 @@ def save_torch_state_dict(
171
200
  Whether to save as safetensors, which is the default behavior. If `False`, the shards are saved as pickle.
172
201
  Safe serialization is recommended for security reasons. Saving as pickle is deprecated and will be removed
173
202
  in a future version.
203
+ is_main_process (`bool`, *optional*):
204
+ Whether the process calling this is the main process or not. Useful when in distributed training like
205
+ TPUs and need to call this function from all processes. In this case, set `is_main_process=True` only on
206
+ the main process to avoid race conditions. Defaults to True.
207
+ shared_tensors_to_discard (`List[str]`, *optional*):
208
+ List of tensor names to drop when saving shared tensors. If not provided and shared tensors are
209
+ detected, it will drop the first name alphabetically.
174
210
 
175
211
  Example:
176
212
 
@@ -192,7 +228,8 @@ def save_torch_state_dict(
192
228
  else constants.PYTORCH_WEIGHTS_FILE_PATTERN
193
229
  )
194
230
 
195
- # Imports correct library
231
+ if metadata is None:
232
+ metadata = {}
196
233
  if safe_serialization:
197
234
  try:
198
235
  from safetensors.torch import save_file as save_file_fn
@@ -201,7 +238,13 @@ def save_torch_state_dict(
201
238
  "Please install `safetensors` to use safe serialization. "
202
239
  "You can install it with `pip install safetensors`."
203
240
  ) from e
204
-
241
+ # Clean state dict for safetensors
242
+ state_dict = _clean_state_dict_for_safetensors(
243
+ state_dict,
244
+ metadata,
245
+ force_contiguous=force_contiguous,
246
+ shared_tensors_to_discard=shared_tensors_to_discard,
247
+ )
205
248
  else:
206
249
  from torch import save as save_file_fn # type: ignore[assignment]
207
250
 
@@ -210,27 +253,23 @@ def save_torch_state_dict(
210
253
  "pickled models from untrusted sources. If you intend to share your model, we strongly recommend "
211
254
  "using safe serialization by installing `safetensors` with `pip install safetensors`."
212
255
  )
213
-
214
- # Clean state dict for safetensors
215
- if metadata is None:
216
- metadata = {}
217
- if safe_serialization:
218
- state_dict = _clean_state_dict_for_safetensors(state_dict, metadata, force_contiguous=force_contiguous)
219
-
220
256
  # Split dict
221
257
  state_dict_split = split_torch_state_dict_into_shards(
222
258
  state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
223
259
  )
224
260
 
225
- # Clean the folder from previous save
226
- existing_files_regex = re.compile(filename_pattern.format(suffix=r"(-\d{5}-of-\d{5})?") + r"(\.index\.json)?")
227
- for filename in os.listdir(save_directory):
228
- if existing_files_regex.match(filename):
229
- try:
230
- logger.debug(f"Removing existing file '{filename}' from folder.")
231
- os.remove(os.path.join(save_directory, filename))
232
- except Exception as e:
233
- logger.warning(f"Error when trying to remove existing '{filename}' from folder: {e}. Continuing...")
261
+ # Only main process should clean up existing files to avoid race conditions in distributed environment
262
+ if is_main_process:
263
+ existing_files_regex = re.compile(filename_pattern.format(suffix=r"(-\d{5}-of-\d{5})?") + r"(\.index\.json)?")
264
+ for filename in os.listdir(save_directory):
265
+ if existing_files_regex.match(filename):
266
+ try:
267
+ logger.debug(f"Removing existing file '{filename}' from folder.")
268
+ os.remove(os.path.join(save_directory, filename))
269
+ except Exception as e:
270
+ logger.warning(
271
+ f"Error when trying to remove existing '{filename}' from folder: {e}. Continuing..."
272
+ )
234
273
 
235
274
  # Save each shard
236
275
  per_file_metadata = {"format": "pt"}
@@ -336,6 +375,331 @@ def split_torch_state_dict_into_shards(
336
375
  )
337
376
 
338
377
 
378
+ # LOADING
379
+
380
+
381
def load_torch_model(
    model: "torch.nn.Module",
    checkpoint_path: Union[str, os.PathLike],
    *,
    strict: bool = False,
    safe: bool = True,
    weights_only: bool = False,
    map_location: Optional[Union[str, "torch.device"]] = None,
    mmap: bool = False,
    filename_pattern: Optional[str] = None,
) -> NamedTuple:
    """
    Load a checkpoint into a model, handling both sharded and non-sharded checkpoints.

    Args:
        model (`torch.nn.Module`):
            The model in which to load the checkpoint.
        checkpoint_path (`str` or `os.PathLike`):
            Path to either the checkpoint file or directory containing the checkpoint(s).
        strict (`bool`, *optional*, defaults to `False`):
            Whether to strictly enforce that the keys in the model state dict match the keys in the checkpoint.
        safe (`bool`, *optional*, defaults to `True`):
            If `safe` is True, only safetensors files are loaded. If `safe` is False, the function
            first attempts to load safetensors files if they are available, otherwise it falls back to loading
            pickle files. `filename_pattern` parameter takes precedence over `safe` parameter.
        weights_only (`bool`, *optional*, defaults to `False`):
            If True, only loads the model weights without optimizer states and other metadata.
            Only supported in PyTorch >= 1.13.
        map_location (`str` or `torch.device`, *optional*):
            A `torch.device` object, string or a dict specifying how to remap storage locations. It
            indicates the location where all tensors should be loaded. Only used when a single checkpoint
            file is loaded; it is not forwarded to the sharded loader.
        mmap (`bool`, *optional*, defaults to `False`):
            Whether to use memory-mapped file loading. Memory mapping can improve loading performance
            for large models in PyTorch >= 2.1.0 with zipfile-based checkpoints. Only used when a single
            checkpoint file is loaded.
        filename_pattern (`str`, *optional*):
            The pattern to look for the index file. Pattern must be a string that
            can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix`.
            Defaults to `"model{suffix}.safetensors"`.

    Returns:
        `NamedTuple`: A named tuple with `missing_keys` and `unexpected_keys` fields.
            - `missing_keys` is a list of str containing the missing keys, i.e. keys that are in the model but not in the checkpoint.
            - `unexpected_keys` is a list of str containing the unexpected keys, i.e. keys that are in the checkpoint but not in the model.

    Raises:
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If the checkpoint path does not exist, or if it is a directory that contains neither a sharded
            checkpoint index nor exactly one model file.
        [`FileNotFoundError`](https://docs.python.org/3/library/exceptions.html#FileNotFoundError)
            If a checkpoint file that should be loaded (e.g. a shard listed in the index) does not exist.
        [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError)
            If safetensors or torch is not installed when trying to load a .safetensors file or a PyTorch checkpoint respectively.

    Example:
    ```python
    >>> from huggingface_hub import load_torch_model
    >>> model = ... # A PyTorch model
    >>> load_torch_model(model, "path/to/checkpoint")
    ```
    """
    checkpoint_path = Path(checkpoint_path)

    if not checkpoint_path.exists():
        raise ValueError(f"Checkpoint path {checkpoint_path} does not exist")

    # 1. Check if checkpoint is a single file
    if checkpoint_path.is_file():
        state_dict = load_state_dict_from_file(
            checkpoint_file=checkpoint_path,
            map_location=map_location,
            weights_only=weights_only,
            mmap=mmap,  # forwarded here as well, consistent with the single-file-in-directory branch below
        )
        return model.load_state_dict(state_dict, strict=strict)

    # 2. If not, checkpoint_path is a directory
    if filename_pattern is None:
        filename_pattern = constants.SAFETENSORS_WEIGHTS_FILE_PATTERN
        index_path = checkpoint_path / (filename_pattern.format(suffix="") + ".index.json")
        # Only fallback to pickle format if safetensors index is not found and safe is False.
        if not index_path.is_file() and not safe:
            filename_pattern = constants.PYTORCH_WEIGHTS_FILE_PATTERN

    index_path = checkpoint_path / (filename_pattern.format(suffix="") + ".index.json")

    if index_path.is_file():
        return _load_sharded_checkpoint(
            model=model,
            save_directory=checkpoint_path,
            strict=strict,
            weights_only=weights_only,
            filename_pattern=filename_pattern,
        )

    # 3. Look for a single model file. Safetensors is always tried first; pickle (`.bin`) files are
    # only considered when `safe=False`, as documented for the `safe` parameter.
    glob_patterns = ("*.safetensors",) if safe else ("*.safetensors", "*.bin")
    for glob_pattern in glob_patterns:
        model_files = list(checkpoint_path.glob(glob_pattern))
        if len(model_files) == 1:
            state_dict = load_state_dict_from_file(
                checkpoint_file=model_files[0],
                map_location=map_location,
                weights_only=weights_only,
                mmap=mmap,
            )
            return model.load_state_dict(state_dict, strict=strict)

    raise ValueError(
        f"Directory '{checkpoint_path}' does not contain a valid checkpoint. "
        "Expected either a sharded checkpoint with an index file, or a single model file."
    )
486
+
487
+
488
def _load_sharded_checkpoint(
    model: "torch.nn.Module",
    save_directory: os.PathLike,
    *,
    strict: bool = False,
    weights_only: bool = False,
    filename_pattern: str = constants.SAFETENSORS_WEIGHTS_FILE_PATTERN,
) -> NamedTuple:
    """
    Loads a sharded checkpoint into a model. This is the same as
    [`torch.nn.Module.load_state_dict`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict)
    but for a sharded checkpoint. Each shard is loaded one by one and removed from memory after being loaded into the model.

    Args:
        model (`torch.nn.Module`):
            The model in which to load the checkpoint.
        save_directory (`str` or `os.PathLike`):
            A path to a folder containing the sharded checkpoint.
        strict (`bool`, *optional*, defaults to `False`):
            Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint.
            The check is done once against the index file, before any shard is loaded.
        weights_only (`bool`, *optional*, defaults to `False`):
            If True, only loads the model weights without optimizer states and other metadata.
            Only supported in PyTorch >= 1.13.
        filename_pattern (`str`, *optional*, defaults to `"model{suffix}.safetensors"`):
            The pattern to look for the index file. Pattern must be a string that
            can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix`.

    Returns:
        `NamedTuple`: A named tuple with `missing_keys` and `unexpected_keys` fields,
            - `missing_keys` is a list of str containing the missing keys
            - `unexpected_keys` is a list of str containing the unexpected keys

    Raises:
        `RuntimeError`: If `strict=True` and the keys listed in the index do not match the model keys.
    """

    # 1. Load and validate index file
    # The index file contains mapping of parameter names to shard files
    index_path = filename_pattern.format(suffix="") + ".index.json"
    index_file = os.path.join(save_directory, index_path)
    with open(index_file, "r", encoding="utf-8") as f:
        index = json.load(f)

    # 2. Validate keys if in strict mode
    # This is done before loading any shards to fail fast
    if strict:
        _validate_keys_for_strict_loading(model, index["weight_map"].keys())

    # 3. Load each shard using `load_state_dict`
    # Get unique shard files (multiple parameters can be in same shard); sorted for a deterministic order
    shard_files = sorted(set(index["weight_map"].values()))
    for shard_file in shard_files:
        # Load shard into memory
        shard_path = os.path.join(save_directory, shard_file)
        state_dict = load_state_dict_from_file(
            shard_path,
            map_location="cpu",
            weights_only=weights_only,
        )
        # Update model with parameters from this shard.
        # A shard only holds a subset of the parameters, so it must be loaded with `strict=False`:
        # passing `strict=True` here would raise on the keys stored in the other shards even when the
        # checkpoint as a whole matches the model. Strictness is enforced globally in step 2.
        model.load_state_dict(state_dict, strict=False)
        # Explicitly remove the state dict from memory
        del state_dict

    # 4. Return compatibility info
    loaded_keys = set(index["weight_map"].keys())
    model_keys = set(model.state_dict().keys())
    return _IncompatibleKeys(
        missing_keys=list(model_keys - loaded_keys), unexpected_keys=list(loaded_keys - model_keys)
    )
556
+
557
+
558
def load_state_dict_from_file(
    checkpoint_file: Union[str, os.PathLike],
    map_location: Optional[Union[str, "torch.device"]] = None,
    weights_only: bool = False,
    mmap: bool = False,
) -> Union[Dict[str, "torch.Tensor"], Any]:
    """
    Loads a checkpoint file, handling both safetensors and pickle checkpoint formats.

    Args:
        checkpoint_file (`str` or `os.PathLike`):
            Path to the checkpoint file to load. Can be either a safetensors or pickle (`.bin`) checkpoint.
        map_location (`str` or `torch.device`, *optional*):
            A `torch.device` object, string or a dict specifying how to remap storage locations. It
            indicates the location where all tensors should be loaded.
        weights_only (`bool`, *optional*, defaults to `False`):
            If True, only loads the model weights without optimizer states and other metadata.
            Only supported for pickle (`.bin`) checkpoints with PyTorch >= 1.13. Has no effect when
            loading safetensors files.
        mmap (`bool`, *optional*, defaults to `False`):
            Whether to use memory-mapped file loading. Memory mapping can improve loading performance
            for large models in PyTorch >= 2.1.0 with zipfile-based checkpoints. Has no effect when
            loading safetensors files, as the `safetensors` library uses memory mapping by default.

    Returns:
        `Union[Dict[str, "torch.Tensor"], Any]`: The loaded checkpoint.
            - For safetensors files: always returns a dictionary mapping parameter names to tensors.
            - For pickle files: returns any Python object that was pickled (commonly a state dict, but could be
              an entire model, optimizer state, or any other Python object).

    Raises:
        [`FileNotFoundError`](https://docs.python.org/3/library/exceptions.html#FileNotFoundError)
            If the checkpoint path does not point to an existing regular file.
        [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError)
            If safetensors or torch is not installed when trying to load a .safetensors file or a PyTorch checkpoint respectively.
        [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError)
            If the safetensors archive carries metadata whose `format` is neither `"pt"` nor `"mlx"`.

    Example:
    ```python
    >>> from huggingface_hub import load_state_dict_from_file

    # Load a PyTorch checkpoint
    >>> state_dict = load_state_dict_from_file("path/to/model.bin", map_location="cpu")
    >>> model.load_state_dict(state_dict)

    # Load a safetensors checkpoint
    >>> state_dict = load_state_dict_from_file("path/to/model.safetensors")
    >>> model.load_state_dict(state_dict)
    ```
    """
    checkpoint_path = Path(checkpoint_file)

    # Check if file exists and is a regular file (not a directory)
    if not checkpoint_path.is_file():
        raise FileNotFoundError(
            f"No checkpoint file found at '{checkpoint_path}'. Please verify the path is correct and "
            "the file has been properly downloaded."
        )

    # Load safetensors checkpoint
    if checkpoint_path.suffix == ".safetensors":
        try:
            from safetensors import safe_open
            from safetensors.torch import load_file
        except ImportError as e:
            raise ImportError(
                "Please install `safetensors` to load safetensors checkpoint. "
                "You can install it with `pip install safetensors`."
            ) from e

        # Check format of the archive
        with safe_open(checkpoint_file, framework="pt") as f:  # type: ignore[attr-defined]
            metadata = f.metadata()
        # see comment: https://github.com/huggingface/transformers/blob/3d213b57fe74302e5902d68ed9478c3ad1aaa713/src/transformers/modeling_utils.py#L3966
        if metadata is not None and metadata.get("format") not in ["pt", "mlx"]:
            raise OSError(
                f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
                "you save your model with the `save_torch_model` method."
            )

        # Device-like objects (anything exposing `.type`) are converted with `str()` so that the device
        # index is preserved ("cuda:1" stays "cuda:1"); using only `.type` would drop the index.
        if map_location is not None and hasattr(map_location, "type"):
            device_type = str(map_location.type)
            device = str(map_location)
        else:
            device_type = device = map_location
        # meta device is not supported with safetensors, falling back to CPU
        if device_type == "meta":
            logger.warning("Meta device is not supported with safetensors. Falling back to CPU device.")
            device = "cpu"
        return load_file(checkpoint_file, device=device)  # type: ignore[arg-type]

    # Otherwise, load from pickle
    try:
        import torch
    except ImportError as e:
        raise ImportError(
            "Please install `torch` to load torch tensors. You can install it with `pip install torch`."
        ) from e

    # Only pass keyword arguments that the installed torch version understands.
    torch_version = version.parse(torch.__version__)
    additional_kwargs = {}
    # mmap is only supported in torch >= 2.1.0
    if torch_version >= version.parse("2.1.0"):
        additional_kwargs["mmap"] = mmap
    # weights_only is only supported in torch >= 1.13.0
    if torch_version >= version.parse("1.13.0"):
        additional_kwargs["weights_only"] = weights_only

    # SECURITY NOTE: with `weights_only=False` this unpickles arbitrary Python objects.
    # Only load pickle checkpoints that come from a trusted source.
    return torch.load(
        checkpoint_file,
        map_location=map_location,
        **additional_kwargs,
    )
668
+
669
+
670
+ # HELPERS
671
+
672
+
673
+ def _validate_keys_for_strict_loading(
674
+ model: "torch.nn.Module",
675
+ loaded_keys: Iterable[str],
676
+ ) -> None:
677
+ """
678
+ Validate that model keys match loaded keys when strict loading is enabled.
679
+
680
+ Args:
681
+ model: The PyTorch model being loaded
682
+ loaded_keys: The keys present in the checkpoint
683
+
684
+ Raises:
685
+ RuntimeError: If there are missing or unexpected keys in strict mode
686
+ """
687
+ loaded_keys_set = set(loaded_keys)
688
+ model_keys = set(model.state_dict().keys())
689
+ missing_keys = model_keys - loaded_keys_set # Keys in model but not in checkpoint
690
+ unexpected_keys = loaded_keys_set - model_keys # Keys in checkpoint but not in model
691
+
692
+ if missing_keys or unexpected_keys:
693
+ error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}"
694
+ if missing_keys:
695
+ str_missing_keys = ",".join([f'"{k}"' for k in sorted(missing_keys)])
696
+ error_message += f"\nMissing key(s): {str_missing_keys}."
697
+ if unexpected_keys:
698
+ str_unexpected_keys = ",".join([f'"{k}"' for k in sorted(unexpected_keys)])
699
+ error_message += f"\nUnexpected key(s): {str_unexpected_keys}."
700
+ raise RuntimeError(error_message)
701
+
702
+
339
703
  def _get_unique_id(tensor: "torch.Tensor") -> Union[int, Tuple[Any, ...]]:
340
704
  """Returns a unique id for plain tensor
341
705
  or a (potentially nested) Tuple of unique id for the flattened Tensor
@@ -359,7 +723,7 @@ def _get_unique_id(tensor: "torch.Tensor") -> Union[int, Tuple[Any, ...]]:
359
723
  # use some other unique id to distinguish.
360
724
  # this is a XLA tensor, it must be created using torch_xla's
361
725
  # device. So the following import is safe:
362
- import torch_xla
726
+ import torch_xla # type: ignore[import]
363
727
 
364
728
  unique_id = torch_xla._XLAC._xla_get_tensor_id(tensor)
365
729
  else:
@@ -423,7 +787,7 @@ def is_torch_tpu_available(check_device=True):
423
787
  if check_device:
424
788
  # We need to check if `xla_device` can be found, will raise a RuntimeError if not
425
789
  try:
426
- import torch_xla.core.xla_model as xm
790
+ import torch_xla.core.xla_model as xm # type: ignore[import]
427
791
 
428
792
  _ = xm.xla_device()
429
793
  return True
@@ -442,7 +806,7 @@ def storage_ptr(tensor: "torch.Tensor") -> Union[int, Tuple[Any, ...]]:
442
806
  from torch.utils._python_dispatch import is_traceable_wrapper_subclass
443
807
 
444
808
  if is_traceable_wrapper_subclass(tensor):
445
- return _get_unique_id(tensor)
809
+ return _get_unique_id(tensor) # type: ignore
446
810
  except ImportError:
447
811
  # for torch version less than 2.1, we can fallback to original implementation
448
812
  pass
@@ -459,7 +823,10 @@ def storage_ptr(tensor: "torch.Tensor") -> Union[int, Tuple[Any, ...]]:
459
823
 
460
824
 
461
825
  def _clean_state_dict_for_safetensors(
462
- state_dict: Dict[str, "torch.Tensor"], metadata: Dict[str, str], force_contiguous: bool = True
826
+ state_dict: Dict[str, "torch.Tensor"],
827
+ metadata: Dict[str, str],
828
+ force_contiguous: bool = True,
829
+ shared_tensors_to_discard: Optional[List[str]] = None,
463
830
  ):
464
831
  """Remove shared tensors from state_dict and update metadata accordingly (for reloading).
465
832
 
@@ -467,7 +834,7 @@ def _clean_state_dict_for_safetensors(
467
834
 
468
835
  Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L155.
469
836
  """
470
- to_removes = _remove_duplicate_names(state_dict)
837
+ to_removes = _remove_duplicate_names(state_dict, discard_names=shared_tensors_to_discard)
471
838
  for kept_name, to_remove_group in to_removes.items():
472
839
  for to_remove in to_remove_group:
473
840
  if metadata is None:
@@ -631,3 +998,18 @@ def _get_dtype_size(dtype: "torch.dtype") -> int:
631
998
  _float8_e5m2: 1,
632
999
  }
633
1000
  return _SIZE[dtype]
1001
+
1002
+
1003
+ class _IncompatibleKeys(namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"])):
1004
+ """
1005
+ This is used to report missing and unexpected keys in the state dict.
1006
+ Taken from https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/module.py#L52.
1007
+
1008
+ """
1009
+
1010
+ def __repr__(self) -> str:
1011
+ if not self.missing_keys and not self.unexpected_keys:
1012
+ return "<All keys matched successfully>"
1013
+ return super().__repr__()
1014
+
1015
+ __str__ = __repr__
@@ -742,7 +742,7 @@ def _scan_cached_repo(repo_path: Path) -> CachedRepoInfo:
742
742
 
743
743
  for ref_path in refs_path.glob("**/*"):
744
744
  # glob("**/*") iterates over all files and directories -> skip directories
745
- if ref_path.is_dir():
745
+ if ref_path.is_dir() or ref_path.name in FILES_TO_IGNORE:
746
746
  continue
747
747
 
748
748
  ref_name = str(ref_path.relative_to(refs_path))
@@ -20,6 +20,7 @@ from huggingface_hub.errors import LocalTokenNotFoundError
20
20
 
21
21
  from .. import constants
22
22
  from ._auth import get_token
23
+ from ._deprecation import _deprecate_arguments
23
24
  from ._runtime import (
24
25
  get_fastai_version,
25
26
  get_fastcore_version,
@@ -35,15 +36,20 @@ from ._runtime import (
35
36
  from ._validators import validate_hf_hub_args
36
37
 
37
38
 
39
+ @_deprecate_arguments(
40
+ version="1.0",
41
+ deprecated_args="is_write_action",
42
+ custom_message="This argument is ignored and we let the server handle the permission error instead (if any).",
43
+ )
38
44
  @validate_hf_hub_args
39
45
  def build_hf_headers(
40
46
  *,
41
47
  token: Optional[Union[bool, str]] = None,
42
- is_write_action: bool = False,
43
48
  library_name: Optional[str] = None,
44
49
  library_version: Optional[str] = None,
45
50
  user_agent: Union[Dict, str, None] = None,
46
51
  headers: Optional[Dict[str, str]] = None,
52
+ is_write_action: bool = False,
47
53
  ) -> Dict[str, str]:
48
54
  """
49
55
  Build headers dictionary to send in a HF Hub call.
@@ -68,9 +74,6 @@ def build_hf_headers(
68
74
  - if `False`, authorization header is not set
69
75
  - if `None`, the token is read from the machine only except if
70
76
  `HF_HUB_DISABLE_IMPLICIT_TOKEN` env variable is set.
71
- is_write_action (`bool`, default to `False`):
72
- Set to True if the API call requires a write access. If `True`, the token
73
- will be validated (cannot be `None`, cannot start by `"api_org***"`).
74
77
  library_name (`str`, *optional*):
75
78
  The name of the library that is making the HTTP request. Will be added to
76
79
  the user-agent header.
@@ -83,6 +86,8 @@ def build_hf_headers(
83
86
  headers (`dict`, *optional*):
84
87
  Additional headers to include in the request. Those headers take precedence
85
88
  over the ones generated by this function.
89
+ is_write_action (`bool`):
90
+ Ignored and deprecated argument.
86
91
 
87
92
  Returns:
88
93
  A `Dict` of headers to pass in your API call.
@@ -105,9 +110,6 @@ def build_hf_headers(
105
110
  >>> build_hf_headers() # token is not sent
106
111
  {"user-agent": ...}
107
112
 
108
- >>> build_hf_headers(token="api_org_***", is_write_action=True)
109
- ValueError: You must use your personal account token for write-access methods.
110
-
111
113
  >>> build_hf_headers(library_name="transformers", library_version="1.2.3")
112
114
  {"authorization": ..., "user-agent": "transformers/1.2.3; hf_hub/0.10.2; python/3.10.4; tensorflow/1.55"}
113
115
  ```
@@ -122,7 +124,6 @@ def build_hf_headers(
122
124
  """
123
125
  # Get auth token to send
124
126
  token_to_send = get_token_to_send(token)
125
- _validate_token_to_send(token_to_send, is_write_action=is_write_action)
126
127
 
127
128
  # Combine headers
128
129
  hf_headers = {
@@ -171,23 +172,6 @@ def get_token_to_send(token: Optional[Union[bool, str]]) -> Optional[str]:
171
172
  return cached_token
172
173
 
173
174
 
174
- def _validate_token_to_send(token: Optional[str], is_write_action: bool) -> None:
175
- if is_write_action:
176
- if token is None:
177
- raise ValueError(
178
- "Token is required (write-access action) but no token found. You need"
179
- " to provide a token or be logged in to Hugging Face with"
180
- " `huggingface-cli login` or `huggingface_hub.login`. See"
181
- " https://huggingface.co/settings/tokens."
182
- )
183
- if token.startswith("api_org"):
184
- raise ValueError(
185
- "You must use your personal account token for write-access methods. To"
186
- " generate a write-access token, go to"
187
- " https://huggingface.co/settings/tokens"
188
- )
189
-
190
-
191
175
  def _http_user_agent(
192
176
  *,
193
177
  library_name: Optional[str] = None,