huggingface-hub 0.12.0rc0__py3-none-any.whl → 0.13.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47) hide show
  1. huggingface_hub/__init__.py +166 -126
  2. huggingface_hub/_commit_api.py +25 -51
  3. huggingface_hub/_login.py +4 -13
  4. huggingface_hub/_snapshot_download.py +45 -23
  5. huggingface_hub/_space_api.py +7 -0
  6. huggingface_hub/commands/delete_cache.py +13 -39
  7. huggingface_hub/commands/env.py +1 -3
  8. huggingface_hub/commands/huggingface_cli.py +1 -3
  9. huggingface_hub/commands/lfs.py +4 -8
  10. huggingface_hub/commands/scan_cache.py +5 -16
  11. huggingface_hub/commands/user.py +27 -45
  12. huggingface_hub/community.py +4 -4
  13. huggingface_hub/constants.py +22 -19
  14. huggingface_hub/fastai_utils.py +14 -23
  15. huggingface_hub/file_download.py +210 -121
  16. huggingface_hub/hf_api.py +500 -255
  17. huggingface_hub/hub_mixin.py +181 -176
  18. huggingface_hub/inference_api.py +4 -10
  19. huggingface_hub/keras_mixin.py +39 -71
  20. huggingface_hub/lfs.py +8 -24
  21. huggingface_hub/repocard.py +33 -48
  22. huggingface_hub/repocard_data.py +141 -30
  23. huggingface_hub/repository.py +41 -112
  24. huggingface_hub/templates/modelcard_template.md +39 -34
  25. huggingface_hub/utils/__init__.py +1 -0
  26. huggingface_hub/utils/_cache_assets.py +1 -4
  27. huggingface_hub/utils/_cache_manager.py +17 -39
  28. huggingface_hub/utils/_deprecation.py +8 -12
  29. huggingface_hub/utils/_errors.py +10 -57
  30. huggingface_hub/utils/_fixes.py +2 -6
  31. huggingface_hub/utils/_git_credential.py +5 -16
  32. huggingface_hub/utils/_headers.py +22 -11
  33. huggingface_hub/utils/_http.py +1 -4
  34. huggingface_hub/utils/_paths.py +5 -12
  35. huggingface_hub/utils/_runtime.py +2 -1
  36. huggingface_hub/utils/_telemetry.py +120 -0
  37. huggingface_hub/utils/_validators.py +5 -13
  38. huggingface_hub/utils/endpoint_helpers.py +1 -3
  39. huggingface_hub/utils/logging.py +10 -8
  40. {huggingface_hub-0.12.0rc0.dist-info → huggingface_hub-0.13.0rc0.dist-info}/METADATA +7 -14
  41. huggingface_hub-0.13.0rc0.dist-info/RECORD +56 -0
  42. huggingface_hub/py.typed +0 -0
  43. huggingface_hub-0.12.0rc0.dist-info/RECORD +0 -56
  44. {huggingface_hub-0.12.0rc0.dist-info → huggingface_hub-0.13.0rc0.dist-info}/LICENSE +0 -0
  45. {huggingface_hub-0.12.0rc0.dist-info → huggingface_hub-0.13.0rc0.dist-info}/WHEEL +0 -0
  46. {huggingface_hub-0.12.0rc0.dist-info → huggingface_hub-0.13.0rc0.dist-info}/entry_points.txt +0 -0
  47. {huggingface_hub-0.12.0rc0.dist-info → huggingface_hub-0.13.0rc0.dist-info}/top_level.txt +0 -0
@@ -19,9 +19,10 @@ from urllib.parse import quote, urlparse
19
19
 
20
20
  import requests
21
21
  from filelock import FileLock
22
- from huggingface_hub import constants
23
22
  from requests.exceptions import ConnectTimeout, ProxyError
24
23
 
24
+ from huggingface_hub import constants
25
+
25
26
  from . import __version__ # noqa: F401 # for backward compatibility
26
27
  from .constants import (
27
28
  DEFAULT_REVISION,
@@ -36,38 +37,41 @@ from .constants import (
36
37
  REPO_TYPES,
37
38
  REPO_TYPES_URL_PREFIXES,
38
39
  )
39
- from .utils import get_fastai_version # noqa: F401 # for backward compatibility
40
- from .utils import get_fastcore_version # noqa: F401 # for backward compatibility
41
- from .utils import get_graphviz_version # noqa: F401 # for backward compatibility
42
- from .utils import get_jinja_version # noqa: F401 # for backward compatibility
43
- from .utils import get_pydot_version # noqa: F401 # for backward compatibility
44
- from .utils import get_tf_version # noqa: F401 # for backward compatibility
45
- from .utils import get_torch_version # noqa: F401 # for backward compatibility
46
- from .utils import is_fastai_available # noqa: F401 # for backward compatibility
47
- from .utils import is_fastcore_available # noqa: F401 # for backward compatibility
48
- from .utils import is_graphviz_available # noqa: F401 # for backward compatibility
49
- from .utils import is_jinja_available # noqa: F401 # for backward compatibility
50
- from .utils import is_pydot_available # noqa: F401 # for backward compatibility
51
- from .utils import is_tf_available # noqa: F401 # for backward compatibility
52
- from .utils import is_torch_available # noqa: F401 # for backward compatibility
53
40
  from .utils import (
54
41
  EntryNotFoundError,
55
42
  LocalEntryNotFoundError,
56
43
  SoftTemporaryDirectory,
57
44
  build_hf_headers,
45
+ get_fastai_version, # noqa: F401 # for backward compatibility
46
+ get_fastcore_version, # noqa: F401 # for backward compatibility
47
+ get_graphviz_version, # noqa: F401 # for backward compatibility
48
+ get_jinja_version, # noqa: F401 # for backward compatibility
49
+ get_pydot_version, # noqa: F401 # for backward compatibility
50
+ get_tf_version, # noqa: F401 # for backward compatibility
51
+ get_torch_version, # noqa: F401 # for backward compatibility
58
52
  hf_raise_for_status,
59
53
  http_backoff,
54
+ is_fastai_available, # noqa: F401 # for backward compatibility
55
+ is_fastcore_available, # noqa: F401 # for backward compatibility
56
+ is_graphviz_available, # noqa: F401 # for backward compatibility
57
+ is_jinja_available, # noqa: F401 # for backward compatibility
58
+ is_pydot_available, # noqa: F401 # for backward compatibility
59
+ is_tf_available, # noqa: F401 # for backward compatibility
60
+ is_torch_available, # noqa: F401 # for backward compatibility
60
61
  logging,
61
62
  tqdm,
62
63
  validate_hf_hub_args,
63
64
  )
64
65
  from .utils._headers import _http_user_agent
65
66
  from .utils._runtime import _PY_VERSION # noqa: F401 # for backward compatibility
66
- from .utils._typing import HTTP_METHOD_T
67
+ from .utils._typing import HTTP_METHOD_T, Literal
67
68
 
68
69
 
69
70
  logger = logging.get_logger(__name__)
70
71
 
72
+ # Regex to get filename from a "Content-Disposition" header for CDN-served files
73
+ HEADER_FILENAME_PATTERN = re.compile(r'filename="(?P<filename>.*?)";')
74
+
71
75
 
72
76
  _are_symlinks_supported_in_dir: Dict[str, bool] = {}
73
77
 
@@ -185,8 +189,8 @@ def hf_hub_url(
185
189
  subfolder (`str`, *optional*):
186
190
  An optional value corresponding to a folder inside the repo.
187
191
  repo_type (`str`, *optional*):
188
- Set to `"dataset"` or `"space"` if uploading to a dataset or space,
189
- `None` or `"model"` if uploading to a model. Default is `None`.
192
+ Set to `"dataset"` or `"space"` if downloading from a dataset or space,
193
+ `None` or `"model"` if downloading from a model. Default is `None`.
190
194
  revision (`str`, *optional*):
191
195
  An optional Git revision id which can be a branch name, a tag, or a
192
196
  commit hash.
@@ -347,9 +351,7 @@ def _raise_if_offline_mode_is_enabled(msg: Optional[str] = None):
347
351
  HF_HUB_OFFLINE is True."""
348
352
  if constants.HF_HUB_OFFLINE:
349
353
  raise OfflineModeIsEnabled(
350
- "Offline mode is enabled."
351
- if msg is None
352
- else "Offline mode is enabled. " + str(msg)
354
+ "Offline mode is enabled." if msg is None else "Offline mode is enabled. " + str(msg)
353
355
  )
354
356
 
355
357
 
@@ -515,16 +517,22 @@ def http_get(
515
517
 
516
518
  displayed_name = url
517
519
  content_disposition = r.headers.get("Content-Disposition")
518
- if content_disposition is not None and "filename=" in content_disposition:
519
- # Means file is on CDN
520
- displayed_name = content_disposition.split("filename=")[-1]
520
+ if content_disposition is not None:
521
+ match = HEADER_FILENAME_PATTERN.search(content_disposition)
522
+ if match is not None:
523
+ # Means file is on CDN
524
+ displayed_name = match.groupdict()["filename"]
525
+
526
+ # Truncate filename if too long to display
527
+ if len(displayed_name) > 22:
528
+ displayed_name = f"(…){displayed_name[-20:]}"
521
529
 
522
530
  progress = tqdm(
523
531
  unit="B",
524
532
  unit_scale=True,
525
533
  total=total,
526
534
  initial=resume_size,
527
- desc=f"Downloading (…){displayed_name[-20:]}",
535
+ desc=f"Downloading {displayed_name}",
528
536
  disable=bool(logger.getEffectiveLevel() == logging.NOTSET),
529
537
  )
530
538
  for chunk in r.iter_content(chunk_size=10 * 1024 * 1024):
@@ -627,8 +635,10 @@ def cached_download(
627
635
  """
628
636
  if not legacy_cache_layout:
629
637
  warnings.warn(
630
- "`cached_download` is the legacy way to download files from the HF hub,"
631
- " please consider upgrading to `hf_hub_download`",
638
+ (
639
+ "`cached_download` is the legacy way to download files from the HF hub,"
640
+ " please consider upgrading to `hf_hub_download`"
641
+ ),
632
642
  FutureWarning,
633
643
  )
634
644
 
@@ -666,8 +676,7 @@ def cached_download(
666
676
  # If we don't have any of those, raise an error.
667
677
  if etag is None:
668
678
  raise OSError(
669
- "Distant resource does not have an ETag, we won't be able to"
670
- " reliably ensure reproducibility."
679
+ "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility."
671
680
  )
672
681
  # In case of a redirect, save an extra redirect on the request.get call,
673
682
  # and ensure we download the exact atomic version even if it changed
@@ -675,6 +684,7 @@ def cached_download(
675
684
  # Useful for lfs blobs that are stored on a CDN.
676
685
  if 300 <= r.status_code <= 399:
677
686
  url_to_download = r.headers["Location"]
687
+ headers.pop("authorization", None)
678
688
  except (requests.exceptions.SSLError, requests.exceptions.ProxyError):
679
689
  # Actually raise for those subclasses of ConnectionError
680
690
  raise
@@ -687,9 +697,7 @@ def cached_download(
687
697
  # etag is None
688
698
  pass
689
699
 
690
- filename = (
691
- force_filename if force_filename is not None else url_to_filename(url, etag)
692
- )
700
+ filename = force_filename if force_filename is not None else url_to_filename(url, etag)
693
701
 
694
702
  # get cache path to put the file
695
703
  cache_path = os.path.join(cache_dir, filename)
@@ -702,16 +710,10 @@ def cached_download(
702
710
  else:
703
711
  matching_files = [
704
712
  file
705
- for file in fnmatch.filter(
706
- os.listdir(cache_dir), filename.split(".")[0] + ".*"
707
- )
713
+ for file in fnmatch.filter(os.listdir(cache_dir), filename.split(".")[0] + ".*")
708
714
  if not file.endswith(".json") and not file.endswith(".lock")
709
715
  ]
710
- if (
711
- len(matching_files) > 0
712
- and not force_download
713
- and force_filename is None
714
- ):
716
+ if len(matching_files) > 0 and not force_download and force_filename is None:
715
717
  return os.path.join(cache_dir, matching_files[-1])
716
718
  else:
717
719
  # If files cannot be found and local_files_only=True,
@@ -843,11 +845,19 @@ def _create_relative_symlink(src: str, dst: str, new_blob: bool = False) -> None
843
845
  except OSError:
844
846
  pass
845
847
 
846
- cache_dir = os.path.dirname(os.path.commonpath([src, dst]))
847
- if are_symlinks_supported(cache_dir=cache_dir):
848
- relative_src = os.path.relpath(src, start=os.path.dirname(dst))
848
+ try:
849
+ _support_symlinks = are_symlinks_supported(
850
+ os.path.dirname(os.path.commonpath([os.path.realpath(src), os.path.realpath(dst)]))
851
+ )
852
+ except PermissionError:
853
+ # Permission error means src and dst are not in the same volume (e.g. destination path has been provided
854
+ # by the user via `local_dir`. Let's test symlink support there)
855
+ _support_symlinks = are_symlinks_supported(os.path.dirname(dst))
856
+
857
+ if _support_symlinks:
858
+ logger.info(f"Creating pointer from {src} to {dst}")
849
859
  try:
850
- os.symlink(relative_src, dst)
860
+ os.symlink(src, dst)
851
861
  except FileExistsError:
852
862
  if os.path.islink(dst) and os.path.realpath(dst) == os.path.realpath(src):
853
863
  # `dst` already exists and is a symlink to the `src` blob. It is most
@@ -860,14 +870,14 @@ def _create_relative_symlink(src: str, dst: str, new_blob: bool = False) -> None
860
870
  # blob file. Raise exception.
861
871
  raise
862
872
  elif new_blob:
873
+ logger.info(f"Symlink not supported. Moving file from {src} to {dst}")
863
874
  os.replace(src, dst)
864
875
  else:
876
+ logger.info(f"Symlink not supported. Copying file from {src} to {dst}")
865
877
  shutil.copyfile(src, dst)
866
878
 
867
879
 
868
- def _cache_commit_hash_for_specific_revision(
869
- storage_folder: str, revision: str, commit_hash: str
870
- ) -> None:
880
+ def _cache_commit_hash_for_specific_revision(storage_folder: str, revision: str, commit_hash: str) -> None:
871
881
  """Cache reference between a revision (tag, branch or truncated commit hash) and the corresponding commit hash.
872
882
 
873
883
  Does nothing if `revision` is already a proper `commit_hash` or reference is already cached.
@@ -905,6 +915,8 @@ def hf_hub_download(
905
915
  library_name: Optional[str] = None,
906
916
  library_version: Optional[str] = None,
907
917
  cache_dir: Union[str, Path, None] = None,
918
+ local_dir: Union[str, Path, None] = None,
919
+ local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
908
920
  user_agent: Union[Dict, str, None] = None,
909
921
  force_download: bool = False,
910
922
  force_filename: Optional[str] = None,
@@ -927,6 +939,21 @@ def hf_hub_download(
927
939
  that have been resolved at that particular commit. Each filename is a symlink to the blob
928
940
  at that particular commit.
929
941
 
942
+ If `local_dir` is provided, the file structure from the repo will be replicated in this location. You can configure
943
+ how you want to move those files:
944
+ - If `local_dir_use_symlinks="auto"` (default), files are downloaded and stored in the cache directory as blob
945
+ files. Small files (<5MB) are duplicated in `local_dir` while a symlink is created for bigger files. The goal
946
+ is to be able to manually edit and save small files without corrupting the cache while saving disk space for
947
+ binary files. The 5MB threshold can be configured with the `HF_HUB_LOCAL_DIR_AUTO_SYMLINK_THRESHOLD`
948
+ environment variable.
949
+ - If `local_dir_use_symlinks=True`, files are downloaded, stored in the cache directory and symlinked in `local_dir`.
950
+ This is optimal in terms of disk usage but files must not be manually edited.
951
+ - If `local_dir_use_symlinks=False` and the blob files exist in the cache directory, they are duplicated in the
952
+ local dir. This means disk usage is not optimized.
953
+ - Finally, if `local_dir_use_symlinks=False` and the blob files do not exist in the cache directory, then the
954
+ files are downloaded and directly placed under `local_dir`. This means if you need to download them again later,
955
+ they will be re-downloaded entirely.
956
+
930
957
  ```
931
958
  [ 96] .
932
959
  └── [ 160] models--julien-c--EsperBERTo-small
@@ -953,8 +980,8 @@ def hf_hub_download(
953
980
  subfolder (`str`, *optional*):
954
981
  An optional value corresponding to a folder inside the model repo.
955
982
  repo_type (`str`, *optional*):
956
- Set to `"dataset"` or `"space"` if uploading to a dataset or space,
957
- `None` or `"model"` if uploading to a model. Default is `None`.
983
+ Set to `"dataset"` or `"space"` if downloading from a dataset or space,
984
+ `None` or `"model"` if downloading from a model. Default is `None`.
958
985
  revision (`str`, *optional*):
959
986
  An optional Git revision id which can be a branch name, a tag, or a
960
987
  commit hash.
@@ -964,6 +991,14 @@ def hf_hub_download(
964
991
  The version of the library.
965
992
  cache_dir (`str`, `Path`, *optional*):
966
993
  Path to the folder where cached files are stored.
994
+ local_dir (`str` or `Path`, *optional*):
995
+ If provided, the downloaded file will be placed under this directory, either as a symlink (default) or
996
+ a regular file (see description for more details).
997
+ local_dir_use_symlinks (`"auto"` or `bool`, defaults to `"auto"`):
998
+ To be used with `local_dir`. If set to "auto", the cache directory will be used and the file will be either
999
+ duplicated or symlinked to the local directory depending on its size. If set to `True`, a symlink will be
1000
+ created, no matter the file size. If set to `False`, the file will either be duplicated from cache (if
1001
+ already exists) or downloaded from the Hub and not cached. See description for more details.
967
1002
  user_agent (`dict`, `str`, *optional*):
968
1003
  The user-agent info in the form of a dictionary or a string.
969
1004
  force_download (`bool`, *optional*, defaults to `False`):
@@ -1018,8 +1053,10 @@ def hf_hub_download(
1018
1053
  """
1019
1054
  if force_filename is not None:
1020
1055
  warnings.warn(
1021
- "The `force_filename` parameter is deprecated as a new caching system, "
1022
- "which keeps the filenames as they are on the Hub, is now in place.",
1056
+ (
1057
+ "The `force_filename` parameter is deprecated as a new caching system, "
1058
+ "which keeps the filenames as they are on the Hub, is now in place."
1059
+ ),
1023
1060
  FutureWarning,
1024
1061
  )
1025
1062
  legacy_cache_layout = True
@@ -1055,6 +1092,8 @@ def hf_hub_download(
1055
1092
  revision = DEFAULT_REVISION
1056
1093
  if isinstance(cache_dir, Path):
1057
1094
  cache_dir = str(cache_dir)
1095
+ if isinstance(local_dir, Path):
1096
+ local_dir = str(local_dir)
1058
1097
 
1059
1098
  if subfolder == "":
1060
1099
  subfolder = None
@@ -1065,14 +1104,9 @@ def hf_hub_download(
1065
1104
  if repo_type is None:
1066
1105
  repo_type = "model"
1067
1106
  if repo_type not in REPO_TYPES:
1068
- raise ValueError(
1069
- f"Invalid repo type: {repo_type}. Accepted repo types are:"
1070
- f" {str(REPO_TYPES)}"
1071
- )
1107
+ raise ValueError(f"Invalid repo type: {repo_type}. Accepted repo types are: {str(REPO_TYPES)}")
1072
1108
 
1073
- storage_folder = os.path.join(
1074
- cache_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type)
1075
- )
1109
+ storage_folder = os.path.join(cache_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type))
1076
1110
  os.makedirs(storage_folder, exist_ok=True)
1077
1111
 
1078
1112
  # cross platform transcription of filename, to be used as a local file path.
@@ -1081,10 +1115,10 @@ def hf_hub_download(
1081
1115
  # if user provides a commit_hash and they already have the file on disk,
1082
1116
  # shortcut everything.
1083
1117
  if REGEX_COMMIT_HASH.match(revision):
1084
- pointer_path = os.path.join(
1085
- storage_folder, "snapshots", revision, relative_filename
1086
- )
1118
+ pointer_path = os.path.join(storage_folder, "snapshots", revision, relative_filename)
1087
1119
  if os.path.exists(pointer_path):
1120
+ if local_dir is not None:
1121
+ return _to_local_dir(pointer_path, local_dir, relative_filename, use_symlinks=local_dir_use_symlinks)
1088
1122
  return pointer_path
1089
1123
 
1090
1124
  url = hf_hub_url(repo_id, filename, repo_type=repo_type, revision=revision)
@@ -1110,30 +1144,18 @@ def hf_hub_download(
1110
1144
  )
1111
1145
  except EntryNotFoundError as http_error:
1112
1146
  # Cache the non-existence of the file and raise
1113
- commit_hash = http_error.response.headers.get(
1114
- HUGGINGFACE_HEADER_X_REPO_COMMIT
1115
- )
1147
+ commit_hash = http_error.response.headers.get(HUGGINGFACE_HEADER_X_REPO_COMMIT)
1116
1148
  if commit_hash is not None and not legacy_cache_layout:
1117
- no_exist_file_path = (
1118
- Path(storage_folder)
1119
- / ".no_exist"
1120
- / commit_hash
1121
- / relative_filename
1122
- )
1149
+ no_exist_file_path = Path(storage_folder) / ".no_exist" / commit_hash / relative_filename
1123
1150
  no_exist_file_path.parent.mkdir(parents=True, exist_ok=True)
1124
1151
  no_exist_file_path.touch()
1125
- _cache_commit_hash_for_specific_revision(
1126
- storage_folder, revision, commit_hash
1127
- )
1152
+ _cache_commit_hash_for_specific_revision(storage_folder, revision, commit_hash)
1128
1153
  raise
1129
1154
 
1130
1155
  # Commit hash must exist
1131
1156
  commit_hash = metadata.commit_hash
1132
1157
  if commit_hash is None:
1133
- raise OSError(
1134
- "Distant resource does not seem to be on huggingface.co (missing"
1135
- " commit header)."
1136
- )
1158
+ raise OSError("Distant resource does not seem to be on huggingface.co (missing commit header).")
1137
1159
 
1138
1160
  # Etag must exist
1139
1161
  etag = metadata.etag
@@ -1142,8 +1164,7 @@ def hf_hub_download(
1142
1164
  # If we don't have any of those, raise an error.
1143
1165
  if etag is None:
1144
1166
  raise OSError(
1145
- "Distant resource does not have an ETag, we won't be able to"
1146
- " reliably ensure reproducibility."
1167
+ "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility."
1147
1168
  )
1148
1169
 
1149
1170
  # In case of a redirect, save an extra redirect on the request.get call,
@@ -1174,24 +1195,30 @@ def hf_hub_download(
1174
1195
  # In those cases, we cannot force download.
1175
1196
  if force_download:
1176
1197
  raise ValueError(
1177
- "We have no connection or you passed local_files_only, so"
1178
- " force_download is not an accepted option."
1198
+ "We have no connection or you passed local_files_only, so force_download is not an accepted option."
1179
1199
  )
1200
+
1201
+ # Try to get "commit_hash" from "revision"
1202
+ commit_hash = None
1180
1203
  if REGEX_COMMIT_HASH.match(revision):
1181
1204
  commit_hash = revision
1182
1205
  else:
1183
1206
  ref_path = os.path.join(storage_folder, "refs", revision)
1184
- with open(ref_path) as f:
1185
- commit_hash = f.read()
1186
-
1187
- pointer_path = os.path.join(
1188
- storage_folder, "snapshots", commit_hash, relative_filename
1189
- )
1190
- if os.path.exists(pointer_path):
1191
- return pointer_path
1207
+ if os.path.isfile(ref_path):
1208
+ with open(ref_path) as f:
1209
+ commit_hash = f.read()
1210
+
1211
+ # Return pointer file if exists
1212
+ if commit_hash is not None:
1213
+ pointer_path = os.path.join(storage_folder, "snapshots", commit_hash, relative_filename)
1214
+ if os.path.exists(pointer_path):
1215
+ if local_dir is not None:
1216
+ return _to_local_dir(
1217
+ pointer_path, local_dir, relative_filename, use_symlinks=local_dir_use_symlinks
1218
+ )
1219
+ return pointer_path
1192
1220
 
1193
- # If we couldn't find an appropriate file on disk,
1194
- # raise an error.
1221
+ # If we couldn't find an appropriate file on disk, raise an error.
1195
1222
  # If files cannot be found and local_files_only=True,
1196
1223
  # the models might've been found if local_files_only=False
1197
1224
  # Notify the user about that
@@ -1212,9 +1239,7 @@ def hf_hub_download(
1212
1239
  assert etag is not None, "etag must have been retrieved from server"
1213
1240
  assert commit_hash is not None, "commit_hash must have been retrieved from server"
1214
1241
  blob_path = os.path.join(storage_folder, "blobs", etag)
1215
- pointer_path = os.path.join(
1216
- storage_folder, "snapshots", commit_hash, relative_filename
1217
- )
1242
+ pointer_path = os.path.join(storage_folder, "snapshots", commit_hash, relative_filename)
1218
1243
 
1219
1244
  os.makedirs(os.path.dirname(blob_path), exist_ok=True)
1220
1245
  os.makedirs(os.path.dirname(pointer_path), exist_ok=True)
@@ -1224,13 +1249,17 @@ def hf_hub_download(
1224
1249
  _cache_commit_hash_for_specific_revision(storage_folder, revision, commit_hash)
1225
1250
 
1226
1251
  if os.path.exists(pointer_path) and not force_download:
1252
+ if local_dir is not None:
1253
+ return _to_local_dir(pointer_path, local_dir, relative_filename, use_symlinks=local_dir_use_symlinks)
1227
1254
  return pointer_path
1228
1255
 
1229
1256
  if os.path.exists(blob_path) and not force_download:
1230
1257
  # we have the blob already, but not the pointer
1231
- logger.info("creating pointer to %s from %s", blob_path, pointer_path)
1232
- _create_relative_symlink(blob_path, pointer_path, new_blob=False)
1233
- return pointer_path
1258
+ if local_dir is not None: # to local dir
1259
+ return _to_local_dir(blob_path, local_dir, relative_filename, use_symlinks=local_dir_use_symlinks)
1260
+ else: # or in snapshot cache
1261
+ _create_relative_symlink(blob_path, pointer_path, new_blob=False)
1262
+ return pointer_path
1234
1263
 
1235
1264
  # Prevent parallel downloads of the same file with a lock.
1236
1265
  lock_path = blob_path + ".lock"
@@ -1281,11 +1310,31 @@ def hf_hub_download(
1281
1310
  headers=headers,
1282
1311
  )
1283
1312
 
1284
- logger.info("storing %s in cache at %s", url, blob_path)
1285
- _chmod_and_replace(temp_file.name, blob_path)
1286
-
1287
- logger.info("creating pointer to %s from %s", blob_path, pointer_path)
1288
- _create_relative_symlink(blob_path, pointer_path, new_blob=True)
1313
+ if local_dir is None:
1314
+ logger.info(f"Storing {url} in cache at {blob_path}")
1315
+ _chmod_and_replace(temp_file.name, blob_path)
1316
+ _create_relative_symlink(blob_path, pointer_path, new_blob=True)
1317
+ else:
1318
+ local_dir_filepath = os.path.join(local_dir, relative_filename)
1319
+ os.makedirs(os.path.dirname(local_dir_filepath), exist_ok=True)
1320
+
1321
+ # If "auto" (default) copy-paste small files to ease manual editing but symlink big files to save disk
1322
+ # In both cases, blob file is cached.
1323
+ is_big_file = os.stat(temp_file.name).st_size > constants.HF_HUB_LOCAL_DIR_AUTO_SYMLINK_THRESHOLD
1324
+ if local_dir_use_symlinks is True or (local_dir_use_symlinks == "auto" and is_big_file):
1325
+ logger.info(f"Storing {url} in cache at {blob_path}")
1326
+ _chmod_and_replace(temp_file.name, blob_path)
1327
+ logger.info("Create symlink to local dir")
1328
+ _create_relative_symlink(blob_path, local_dir_filepath, new_blob=False)
1329
+ elif local_dir_use_symlinks == "auto" and not is_big_file:
1330
+ logger.info(f"Storing {url} in cache at {blob_path}")
1331
+ _chmod_and_replace(temp_file.name, blob_path)
1332
+ logger.info("Duplicate in local dir (small file and use_symlink set to 'auto')")
1333
+ shutil.copyfile(blob_path, local_dir_filepath)
1334
+ else:
1335
+ logger.info(f"Storing {url} in local_dir at {local_dir_filepath} (not cached).")
1336
+ _chmod_and_replace(temp_file.name, local_dir_filepath)
1337
+ pointer_path = local_dir_filepath # for return value
1289
1338
 
1290
1339
  try:
1291
1340
  os.remove(lock_path)
@@ -1327,16 +1376,30 @@ def try_to_load_from_cache(
1327
1376
  - The exact path to the cached file if it's found in the cache
1328
1377
  - A special value `_CACHED_NO_EXIST` if the file does not exist at the given commit hash and this fact was
1329
1378
  cached.
1379
+
1380
+ Example:
1381
+
1382
+ ```python
1383
+ from huggingface_hub import try_to_load_from_cache, _CACHED_NO_EXIST
1384
+
1385
+ filepath = try_to_load_from_cache()
1386
+ if isinstance(filepath, str):
1387
+ # file exists and is cached
1388
+ ...
1389
+ elif filepath is _CACHED_NO_EXIST:
1390
+ # non-existence of file is cached
1391
+ ...
1392
+ else:
1393
+ # file is not cached
1394
+ ...
1395
+ ```
1330
1396
  """
1331
1397
  if revision is None:
1332
1398
  revision = "main"
1333
1399
  if repo_type is None:
1334
1400
  repo_type = "model"
1335
1401
  if repo_type not in REPO_TYPES:
1336
- raise ValueError(
1337
- f"Invalid repo type: {repo_type}. Accepted repo types are:"
1338
- f" {str(REPO_TYPES)}"
1339
- )
1402
+ raise ValueError(f"Invalid repo type: {repo_type}. Accepted repo types are: {str(REPO_TYPES)}")
1340
1403
  if cache_dir is None:
1341
1404
  cache_dir = HUGGINGFACE_HUB_CACHE
1342
1405
 
@@ -1345,25 +1408,32 @@ def try_to_load_from_cache(
1345
1408
  if not os.path.isdir(repo_cache):
1346
1409
  # No cache for this model
1347
1410
  return None
1348
- for subfolder in ["refs", "snapshots"]:
1349
- if not os.path.isdir(os.path.join(repo_cache, subfolder)):
1350
- return None
1351
1411
 
1352
- # Resolve refs (for instance to convert main to the associated commit sha)
1353
- cached_refs = os.listdir(os.path.join(repo_cache, "refs"))
1354
- if revision in cached_refs:
1355
- with open(os.path.join(repo_cache, "refs", revision)) as f:
1356
- revision = f.read()
1412
+ refs_dir = os.path.join(repo_cache, "refs")
1413
+ snapshots_dir = os.path.join(repo_cache, "snapshots")
1414
+ no_exist_dir = os.path.join(repo_cache, ".no_exist")
1357
1415
 
1358
- if os.path.isfile(os.path.join(repo_cache, ".no_exist", revision, filename)):
1416
+ # Resolve refs (for instance to convert main to the associated commit sha)
1417
+ if os.path.isdir(refs_dir):
1418
+ revision_file = os.path.join(refs_dir, revision)
1419
+ if os.path.isfile(revision_file):
1420
+ with open(revision_file) as f:
1421
+ revision = f.read()
1422
+
1423
+ # Check if file is cached as "no_exist"
1424
+ if os.path.isfile(os.path.join(no_exist_dir, revision, filename)):
1359
1425
  return _CACHED_NO_EXIST
1360
1426
 
1361
- cached_shas = os.listdir(os.path.join(repo_cache, "snapshots"))
1427
+ # Check if revision folder exists
1428
+ if not os.path.exists(snapshots_dir):
1429
+ return None
1430
+ cached_shas = os.listdir(snapshots_dir)
1362
1431
  if revision not in cached_shas:
1363
1432
  # No cache for this revision and we won't try to return a random revision
1364
1433
  return None
1365
1434
 
1366
- cached_file = os.path.join(repo_cache, "snapshots", revision, filename)
1435
+ # Check if file exists in cache
1436
+ cached_file = os.path.join(snapshots_dir, revision, filename)
1367
1437
  return cached_file if os.path.isfile(cached_file) else None
1368
1438
 
1369
1439
 
@@ -1422,10 +1492,7 @@ def get_hf_file_metadata(
1422
1492
  # Do not use directly `url`, as `_request_wrapper` might have followed relative
1423
1493
  # redirects.
1424
1494
  location=r.headers.get("Location") or r.request.url, # type: ignore
1425
- size=_int_or_none(
1426
- r.headers.get(HUGGINGFACE_HEADER_X_LINKED_SIZE)
1427
- or r.headers.get("Content-Length")
1428
- ),
1495
+ size=_int_or_none(r.headers.get(HUGGINGFACE_HEADER_X_LINKED_SIZE) or r.headers.get("Content-Length")),
1429
1496
  )
1430
1497
 
1431
1498
 
@@ -1459,3 +1526,25 @@ def _chmod_and_replace(src: str, dst: str) -> None:
1459
1526
  tmp_file.unlink()
1460
1527
 
1461
1528
  os.replace(src, dst)
1529
+
1530
+
1531
+ def _to_local_dir(
1532
+ path: str, local_dir: str, relative_filename: str, use_symlinks: Union[bool, Literal["auto"]]
1533
+ ) -> str:
1534
+ """Place a file in a local dir (different than cache_dir).
1535
+
1536
+ Either symlink to blob file in cache or duplicate file depending on `use_symlinks` and file size.
1537
+ """
1538
+ local_dir_filepath = os.path.join(local_dir, relative_filename)
1539
+ os.makedirs(os.path.dirname(local_dir_filepath), exist_ok=True)
1540
+ real_blob_path = os.path.realpath(path)
1541
+
1542
+ # If "auto" (default) copy-paste small files to ease manual editing but symlink big files to save disk
1543
+ if use_symlinks == "auto":
1544
+ use_symlinks = os.stat(real_blob_path).st_size > constants.HF_HUB_LOCAL_DIR_AUTO_SYMLINK_THRESHOLD
1545
+
1546
+ if use_symlinks:
1547
+ _create_relative_symlink(real_blob_path, local_dir_filepath, new_blob=False)
1548
+ else:
1549
+ shutil.copyfile(real_blob_path, local_dir_filepath)
1550
+ return local_dir_filepath