huggingface-hub 0.29.0rc2__py3-none-any.whl → 1.1.3__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
Files changed (153)
  1. huggingface_hub/__init__.py +160 -46
  2. huggingface_hub/_commit_api.py +277 -71
  3. huggingface_hub/_commit_scheduler.py +15 -15
  4. huggingface_hub/_inference_endpoints.py +33 -22
  5. huggingface_hub/_jobs_api.py +301 -0
  6. huggingface_hub/_local_folder.py +18 -3
  7. huggingface_hub/_login.py +31 -63
  8. huggingface_hub/_oauth.py +460 -0
  9. huggingface_hub/_snapshot_download.py +241 -81
  10. huggingface_hub/_space_api.py +18 -10
  11. huggingface_hub/_tensorboard_logger.py +15 -19
  12. huggingface_hub/_upload_large_folder.py +196 -76
  13. huggingface_hub/_webhooks_payload.py +3 -3
  14. huggingface_hub/_webhooks_server.py +15 -25
  15. huggingface_hub/{commands → cli}/__init__.py +1 -15
  16. huggingface_hub/cli/_cli_utils.py +173 -0
  17. huggingface_hub/cli/auth.py +147 -0
  18. huggingface_hub/cli/cache.py +841 -0
  19. huggingface_hub/cli/download.py +189 -0
  20. huggingface_hub/cli/hf.py +60 -0
  21. huggingface_hub/cli/inference_endpoints.py +377 -0
  22. huggingface_hub/cli/jobs.py +772 -0
  23. huggingface_hub/cli/lfs.py +175 -0
  24. huggingface_hub/cli/repo.py +315 -0
  25. huggingface_hub/cli/repo_files.py +94 -0
  26. huggingface_hub/{commands/env.py → cli/system.py} +10 -13
  27. huggingface_hub/cli/upload.py +294 -0
  28. huggingface_hub/cli/upload_large_folder.py +117 -0
  29. huggingface_hub/community.py +20 -12
  30. huggingface_hub/constants.py +83 -59
  31. huggingface_hub/dataclasses.py +609 -0
  32. huggingface_hub/errors.py +99 -30
  33. huggingface_hub/fastai_utils.py +30 -41
  34. huggingface_hub/file_download.py +606 -346
  35. huggingface_hub/hf_api.py +2445 -1132
  36. huggingface_hub/hf_file_system.py +269 -152
  37. huggingface_hub/hub_mixin.py +61 -66
  38. huggingface_hub/inference/_client.py +501 -630
  39. huggingface_hub/inference/_common.py +133 -121
  40. huggingface_hub/inference/_generated/_async_client.py +536 -722
  41. huggingface_hub/inference/_generated/types/__init__.py +6 -1
  42. huggingface_hub/inference/_generated/types/automatic_speech_recognition.py +5 -6
  43. huggingface_hub/inference/_generated/types/base.py +10 -7
  44. huggingface_hub/inference/_generated/types/chat_completion.py +77 -31
  45. huggingface_hub/inference/_generated/types/depth_estimation.py +2 -2
  46. huggingface_hub/inference/_generated/types/document_question_answering.py +2 -2
  47. huggingface_hub/inference/_generated/types/feature_extraction.py +2 -2
  48. huggingface_hub/inference/_generated/types/fill_mask.py +2 -2
  49. huggingface_hub/inference/_generated/types/image_to_image.py +8 -2
  50. huggingface_hub/inference/_generated/types/image_to_text.py +2 -3
  51. huggingface_hub/inference/_generated/types/image_to_video.py +60 -0
  52. huggingface_hub/inference/_generated/types/sentence_similarity.py +3 -3
  53. huggingface_hub/inference/_generated/types/summarization.py +2 -2
  54. huggingface_hub/inference/_generated/types/table_question_answering.py +5 -5
  55. huggingface_hub/inference/_generated/types/text2text_generation.py +2 -2
  56. huggingface_hub/inference/_generated/types/text_generation.py +11 -11
  57. huggingface_hub/inference/_generated/types/text_to_audio.py +1 -2
  58. huggingface_hub/inference/_generated/types/text_to_speech.py +1 -2
  59. huggingface_hub/inference/_generated/types/text_to_video.py +2 -2
  60. huggingface_hub/inference/_generated/types/token_classification.py +2 -2
  61. huggingface_hub/inference/_generated/types/translation.py +2 -2
  62. huggingface_hub/inference/_generated/types/zero_shot_classification.py +2 -2
  63. huggingface_hub/inference/_generated/types/zero_shot_image_classification.py +2 -2
  64. huggingface_hub/inference/_generated/types/zero_shot_object_detection.py +1 -3
  65. huggingface_hub/inference/_mcp/__init__.py +0 -0
  66. huggingface_hub/inference/_mcp/_cli_hacks.py +88 -0
  67. huggingface_hub/inference/_mcp/agent.py +100 -0
  68. huggingface_hub/inference/_mcp/cli.py +247 -0
  69. huggingface_hub/inference/_mcp/constants.py +81 -0
  70. huggingface_hub/inference/_mcp/mcp_client.py +395 -0
  71. huggingface_hub/inference/_mcp/types.py +45 -0
  72. huggingface_hub/inference/_mcp/utils.py +128 -0
  73. huggingface_hub/inference/_providers/__init__.py +149 -20
  74. huggingface_hub/inference/_providers/_common.py +160 -37
  75. huggingface_hub/inference/_providers/black_forest_labs.py +12 -9
  76. huggingface_hub/inference/_providers/cerebras.py +6 -0
  77. huggingface_hub/inference/_providers/clarifai.py +13 -0
  78. huggingface_hub/inference/_providers/cohere.py +32 -0
  79. huggingface_hub/inference/_providers/fal_ai.py +231 -22
  80. huggingface_hub/inference/_providers/featherless_ai.py +38 -0
  81. huggingface_hub/inference/_providers/fireworks_ai.py +22 -1
  82. huggingface_hub/inference/_providers/groq.py +9 -0
  83. huggingface_hub/inference/_providers/hf_inference.py +143 -33
  84. huggingface_hub/inference/_providers/hyperbolic.py +9 -5
  85. huggingface_hub/inference/_providers/nebius.py +47 -5
  86. huggingface_hub/inference/_providers/novita.py +48 -5
  87. huggingface_hub/inference/_providers/nscale.py +44 -0
  88. huggingface_hub/inference/_providers/openai.py +25 -0
  89. huggingface_hub/inference/_providers/publicai.py +6 -0
  90. huggingface_hub/inference/_providers/replicate.py +46 -9
  91. huggingface_hub/inference/_providers/sambanova.py +37 -1
  92. huggingface_hub/inference/_providers/scaleway.py +28 -0
  93. huggingface_hub/inference/_providers/together.py +34 -5
  94. huggingface_hub/inference/_providers/wavespeed.py +138 -0
  95. huggingface_hub/inference/_providers/zai_org.py +17 -0
  96. huggingface_hub/lfs.py +33 -100
  97. huggingface_hub/repocard.py +34 -38
  98. huggingface_hub/repocard_data.py +79 -59
  99. huggingface_hub/serialization/__init__.py +0 -1
  100. huggingface_hub/serialization/_base.py +12 -15
  101. huggingface_hub/serialization/_dduf.py +8 -8
  102. huggingface_hub/serialization/_torch.py +69 -69
  103. huggingface_hub/utils/__init__.py +27 -8
  104. huggingface_hub/utils/_auth.py +7 -7
  105. huggingface_hub/utils/_cache_manager.py +92 -147
  106. huggingface_hub/utils/_chunk_utils.py +2 -3
  107. huggingface_hub/utils/_deprecation.py +1 -1
  108. huggingface_hub/utils/_dotenv.py +55 -0
  109. huggingface_hub/utils/_experimental.py +7 -5
  110. huggingface_hub/utils/_fixes.py +0 -10
  111. huggingface_hub/utils/_git_credential.py +5 -5
  112. huggingface_hub/utils/_headers.py +8 -30
  113. huggingface_hub/utils/_http.py +399 -237
  114. huggingface_hub/utils/_pagination.py +6 -6
  115. huggingface_hub/utils/_parsing.py +98 -0
  116. huggingface_hub/utils/_paths.py +5 -5
  117. huggingface_hub/utils/_runtime.py +74 -22
  118. huggingface_hub/utils/_safetensors.py +21 -21
  119. huggingface_hub/utils/_subprocess.py +13 -11
  120. huggingface_hub/utils/_telemetry.py +4 -4
  121. huggingface_hub/{commands/_cli_utils.py → utils/_terminal.py} +4 -4
  122. huggingface_hub/utils/_typing.py +25 -5
  123. huggingface_hub/utils/_validators.py +55 -74
  124. huggingface_hub/utils/_verification.py +167 -0
  125. huggingface_hub/utils/_xet.py +235 -0
  126. huggingface_hub/utils/_xet_progress_reporting.py +162 -0
  127. huggingface_hub/utils/insecure_hashlib.py +3 -5
  128. huggingface_hub/utils/logging.py +8 -11
  129. huggingface_hub/utils/tqdm.py +33 -4
  130. {huggingface_hub-0.29.0rc2.dist-info → huggingface_hub-1.1.3.dist-info}/METADATA +94 -82
  131. huggingface_hub-1.1.3.dist-info/RECORD +155 -0
  132. {huggingface_hub-0.29.0rc2.dist-info → huggingface_hub-1.1.3.dist-info}/WHEEL +1 -1
  133. huggingface_hub-1.1.3.dist-info/entry_points.txt +6 -0
  134. huggingface_hub/commands/delete_cache.py +0 -428
  135. huggingface_hub/commands/download.py +0 -200
  136. huggingface_hub/commands/huggingface_cli.py +0 -61
  137. huggingface_hub/commands/lfs.py +0 -200
  138. huggingface_hub/commands/repo_files.py +0 -128
  139. huggingface_hub/commands/scan_cache.py +0 -181
  140. huggingface_hub/commands/tag.py +0 -159
  141. huggingface_hub/commands/upload.py +0 -299
  142. huggingface_hub/commands/upload_large_folder.py +0 -129
  143. huggingface_hub/commands/user.py +0 -304
  144. huggingface_hub/commands/version.py +0 -37
  145. huggingface_hub/inference_api.py +0 -217
  146. huggingface_hub/keras_mixin.py +0 -500
  147. huggingface_hub/repository.py +0 -1477
  148. huggingface_hub/serialization/_tensorflow.py +0 -95
  149. huggingface_hub/utils/_hf_folder.py +0 -68
  150. huggingface_hub-0.29.0rc2.dist-info/RECORD +0 -131
  151. huggingface_hub-0.29.0rc2.dist-info/entry_points.txt +0 -6
  152. {huggingface_hub-0.29.0rc2.dist-info → huggingface_hub-1.1.3.dist-info/licenses}/LICENSE +0 -0
  153. {huggingface_hub-0.29.0rc2.dist-info → huggingface_hub-1.1.3.dist-info}/top_level.txt +0 -0
huggingface_hub/_snapshot_download.py

@@ -1,21 +1,105 @@
 import os
 from pathlib import Path
-from typing import Dict, List, Literal, Optional, Union
+from typing import Iterable, List, Literal, Optional, Union, overload

-import requests
+import httpx
 from tqdm.auto import tqdm as base_tqdm
 from tqdm.contrib.concurrent import thread_map

 from . import constants
-from .errors import GatedRepoError, LocalEntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
-from .file_download import REGEX_COMMIT_HASH, hf_hub_download, repo_folder_name
-from .hf_api import DatasetInfo, HfApi, ModelInfo, SpaceInfo
-from .utils import OfflineModeIsEnabled, filter_repo_objects, logging, validate_hf_hub_args
+from .errors import (
+    DryRunError,
+    GatedRepoError,
+    HfHubHTTPError,
+    LocalEntryNotFoundError,
+    RepositoryNotFoundError,
+    RevisionNotFoundError,
+)
+from .file_download import REGEX_COMMIT_HASH, DryRunFileInfo, hf_hub_download, repo_folder_name
+from .hf_api import DatasetInfo, HfApi, ModelInfo, RepoFile, SpaceInfo
+from .utils import OfflineModeIsEnabled, filter_repo_objects, is_tqdm_disabled, logging, validate_hf_hub_args
 from .utils import tqdm as hf_tqdm


 logger = logging.get_logger(__name__)

+VERY_LARGE_REPO_THRESHOLD = 50000  # After this limit, we don't consider `repo_info.siblings` to be reliable enough
+
+
+@overload
+def snapshot_download(
+    repo_id: str,
+    *,
+    repo_type: Optional[str] = None,
+    revision: Optional[str] = None,
+    cache_dir: Union[str, Path, None] = None,
+    local_dir: Union[str, Path, None] = None,
+    library_name: Optional[str] = None,
+    library_version: Optional[str] = None,
+    user_agent: Optional[Union[dict, str]] = None,
+    etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT,
+    force_download: bool = False,
+    token: Optional[Union[bool, str]] = None,
+    local_files_only: bool = False,
+    allow_patterns: Optional[Union[list[str], str]] = None,
+    ignore_patterns: Optional[Union[list[str], str]] = None,
+    max_workers: int = 8,
+    tqdm_class: Optional[type[base_tqdm]] = None,
+    headers: Optional[dict[str, str]] = None,
+    endpoint: Optional[str] = None,
+    dry_run: Literal[False] = False,
+) -> str: ...
+
+
+@overload
+def snapshot_download(
+    repo_id: str,
+    *,
+    repo_type: Optional[str] = None,
+    revision: Optional[str] = None,
+    cache_dir: Union[str, Path, None] = None,
+    local_dir: Union[str, Path, None] = None,
+    library_name: Optional[str] = None,
+    library_version: Optional[str] = None,
+    user_agent: Optional[Union[dict, str]] = None,
+    etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT,
+    force_download: bool = False,
+    token: Optional[Union[bool, str]] = None,
+    local_files_only: bool = False,
+    allow_patterns: Optional[Union[list[str], str]] = None,
+    ignore_patterns: Optional[Union[list[str], str]] = None,
+    max_workers: int = 8,
+    tqdm_class: Optional[type[base_tqdm]] = None,
+    headers: Optional[dict[str, str]] = None,
+    endpoint: Optional[str] = None,
+    dry_run: Literal[True] = True,
+) -> list[DryRunFileInfo]: ...
+
+
+@overload
+def snapshot_download(
+    repo_id: str,
+    *,
+    repo_type: Optional[str] = None,
+    revision: Optional[str] = None,
+    cache_dir: Union[str, Path, None] = None,
+    local_dir: Union[str, Path, None] = None,
+    library_name: Optional[str] = None,
+    library_version: Optional[str] = None,
+    user_agent: Optional[Union[dict, str]] = None,
+    etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT,
+    force_download: bool = False,
+    token: Optional[Union[bool, str]] = None,
+    local_files_only: bool = False,
+    allow_patterns: Optional[Union[list[str], str]] = None,
+    ignore_patterns: Optional[Union[list[str], str]] = None,
+    max_workers: int = 8,
+    tqdm_class: Optional[type[base_tqdm]] = None,
+    headers: Optional[dict[str, str]] = None,
+    endpoint: Optional[str] = None,
+    dry_run: bool = False,
+) -> Union[str, list[DryRunFileInfo]]: ...
+

 @validate_hf_hub_args
 def snapshot_download(
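The three `@overload` stubs above let type checkers narrow `snapshot_download`'s return type from the literal value of `dry_run`. A minimal sketch of the same pattern, using a hypothetical `fetch` function that is not part of the library:

```python
from typing import Literal, Union, overload

@overload
def fetch(dry_run: Literal[False] = False) -> str: ...
@overload
def fetch(dry_run: Literal[True]) -> list[str]: ...
@overload
def fetch(dry_run: bool = False) -> Union[str, list[str]]: ...

def fetch(dry_run: bool = False) -> Union[str, list[str]]:
    # Overloads are type-checker-only; this single body runs at runtime.
    return ["would_download_a", "would_download_b"] if dry_run else "/path/to/snapshot"
```

The third, `bool`-typed overload covers call sites where `dry_run` is a variable rather than a literal.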
@@ -27,22 +111,19 @@ def snapshot_download(
     local_dir: Union[str, Path, None] = None,
     library_name: Optional[str] = None,
     library_version: Optional[str] = None,
-    user_agent: Optional[Union[Dict, str]] = None,
-    proxies: Optional[Dict] = None,
+    user_agent: Optional[Union[dict, str]] = None,
     etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT,
     force_download: bool = False,
     token: Optional[Union[bool, str]] = None,
     local_files_only: bool = False,
-    allow_patterns: Optional[Union[List[str], str]] = None,
-    ignore_patterns: Optional[Union[List[str], str]] = None,
+    allow_patterns: Optional[Union[list[str], str]] = None,
+    ignore_patterns: Optional[Union[list[str], str]] = None,
     max_workers: int = 8,
-    tqdm_class: Optional[base_tqdm] = None,
-    headers: Optional[Dict[str, str]] = None,
+    tqdm_class: Optional[type[base_tqdm]] = None,
+    headers: Optional[dict[str, str]] = None,
     endpoint: Optional[str] = None,
-    # Deprecated args
-    local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
-    resume_download: Optional[bool] = None,
-) -> str:
+    dry_run: bool = False,
+) -> Union[str, list[DryRunFileInfo]]:
     """Download repo files.

     Download a whole snapshot of a repo's files at the specified revision. This is useful when you want all files from
@@ -77,12 +158,9 @@ def snapshot_download(
             The version of the library.
         user_agent (`str`, `dict`, *optional*):
             The user-agent info in the form of a dictionary or a string.
-        proxies (`dict`, *optional*):
-            Dictionary mapping protocol to the URL of the proxy passed to
-            `requests.request`.
         etag_timeout (`float`, *optional*, defaults to `10`):
             When fetching ETag, how many seconds to wait for the server to send
-            data before giving up which is passed to `requests.request`.
+            data before giving up which is passed to `httpx.request`.
         force_download (`bool`, *optional*, defaults to `False`):
             Whether the file should be downloaded even if it already exists in the local cache.
         token (`str`, `bool`, *optional*):
@@ -95,9 +173,9 @@ def snapshot_download(
         local_files_only (`bool`, *optional*, defaults to `False`):
             If `True`, avoid downloading the file and return the path to the
             local cached file if it exists.
-        allow_patterns (`List[str]` or `str`, *optional*):
+        allow_patterns (`list[str]` or `str`, *optional*):
             If provided, only files matching at least one pattern are downloaded.
-        ignore_patterns (`List[str]` or `str`, *optional*):
+        ignore_patterns (`list[str]` or `str`, *optional*):
             If provided, files matching any of the patterns are not downloaded.
         max_workers (`int`, *optional*):
             Number of concurrent threads to download files (1 thread = 1 file download).
@@ -108,9 +186,14 @@ def snapshot_download(
             Note that the `tqdm_class` is not passed to each individual download.
             Defaults to the custom HF progress bar that can be disabled by setting
             `HF_HUB_DISABLE_PROGRESS_BARS` environment variable.
+        dry_run (`bool`, *optional*, defaults to `False`):
+            If `True`, perform a dry run without actually downloading the files. Returns a list of
+            [`DryRunFileInfo`] objects containing information about what would be downloaded.

     Returns:
-        `str`: folder path of the repo snapshot.
+        `str` or list of [`DryRunFileInfo`]:
+            - If `dry_run=False`: Local snapshot path.
+            - If `dry_run=True`: A list of [`DryRunFileInfo`] objects containing download information.

     Raises:
         [`~utils.RepositoryNotFoundError`]
@@ -139,28 +222,26 @@ def snapshot_download(

     storage_folder = os.path.join(cache_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type))

+    api = HfApi(
+        library_name=library_name,
+        library_version=library_version,
+        user_agent=user_agent,
+        endpoint=endpoint,
+        headers=headers,
+        token=token,
+    )
+
     repo_info: Union[ModelInfo, DatasetInfo, SpaceInfo, None] = None
     api_call_error: Optional[Exception] = None
     if not local_files_only:
         # try/except logic to handle different errors => taken from `hf_hub_download`
         try:
             # if we have internet connection we want to list files to download
-            api = HfApi(
-                library_name=library_name,
-                library_version=library_version,
-                user_agent=user_agent,
-                endpoint=endpoint,
-                headers=headers,
-            )
-            repo_info = api.repo_info(repo_id=repo_id, repo_type=repo_type, revision=revision, token=token)
-        except (requests.exceptions.SSLError, requests.exceptions.ProxyError):
-            # Actually raise for those subclasses of ConnectionError
+            repo_info = api.repo_info(repo_id=repo_id, repo_type=repo_type, revision=revision)
+        except httpx.ProxyError:
+            # Actually raise on proxy error
             raise
-        except (
-            requests.exceptions.ConnectionError,
-            requests.exceptions.Timeout,
-            OfflineModeIsEnabled,
-        ) as error:
+        except (httpx.ConnectError, httpx.TimeoutException, OfflineModeIsEnabled) as error:
             # Internet connection is down
             # => will try to use local files only
             api_call_error = error
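The hunk above is part of the broader `requests` → `httpx` migration in 1.x. For readers porting similar code, a rough (assumed, not exhaustive) mapping of the exception classes involved:

```python
import httpx

# Approximate httpx equivalents of the requests exceptions previously caught here:
#   requests.exceptions.ConnectionError -> httpx.ConnectError
#   requests.exceptions.Timeout         -> httpx.TimeoutException
#   requests.exceptions.ProxyError      -> httpx.ProxyError (still re-raised immediately)
try:
    httpx.get("https://huggingface.co/api/models", timeout=1.0)
except httpx.ProxyError:
    raise  # misconfigured proxy: don't silently fall back to local files
except (httpx.ConnectError, httpx.TimeoutException):
    pass  # offline: snapshot_download falls back to the local cache
```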
@@ -168,7 +249,7 @@ def snapshot_download(
         except RevisionNotFoundError:
             # The repo was found but the revision doesn't exist on the Hub (never existed or got deleted)
             raise
-        except requests.HTTPError as error:
+        except HfHubHTTPError as error:
             # Multiple reasons for an http error:
             # - Repository is private and invalid/missing token sent
             # - Repository is gated and invalid/missing token sent
@@ -188,6 +269,11 @@ def snapshot_download(
     # - if the specified revision is a branch or tag, look inside "refs".
     # => if local_dir is not None, we will return the path to the local folder if it exists.
     if repo_info is None:
+        if dry_run:
+            raise DryRunError(
+                "Dry run cannot be performed as the repository cannot be accessed. Please check your internet connection or authentication token."
+            ) from api_call_error
+
         # Try to get which commit hash corresponds to the specified revision
         commit_hash = None
         if REGEX_COMMIT_HASH.match(revision):
@@ -200,12 +286,13 @@ def snapshot_download(
                 commit_hash = f.read()

         # Try to locate snapshot folder for this commit hash
-        if commit_hash is not None:
+        if commit_hash is not None and local_dir is None:
             snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash)
             if os.path.exists(snapshot_folder):
                 # Snapshot folder exists => let's return it
                 # (but we can't check if all the files are actually there)
                 return snapshot_folder
+
         # If local_dir is not None, return it if it exists and is not empty
         if local_dir is not None:
             local_dir = Path(local_dir)
@@ -227,8 +314,10 @@ def snapshot_download(
                 "outgoing traffic has been disabled. To enable repo look-ups and downloads online, set "
                 "'HF_HUB_OFFLINE=0' as environment variable."
             ) from api_call_error
-        elif isinstance(api_call_error, RepositoryNotFoundError) or isinstance(api_call_error, GatedRepoError):
-            # Repo not found => let's raise the actual error
+        elif isinstance(api_call_error, (RepositoryNotFoundError, GatedRepoError)) or (
+            isinstance(api_call_error, HfHubHTTPError) and api_call_error.response.status_code == 401
+        ):
+            # Repo not found, gated, or specific authentication error => let's raise the actual error
             raise api_call_error
         else:
             # Otherwise: most likely a connection issue or Hub downtime => let's warn the user
@@ -241,14 +330,39 @@ def snapshot_download(
     # At this stage, internet connection is up and running
     # => let's download the files!
     assert repo_info.sha is not None, "Repo info returned from server must have a revision sha."
-    assert repo_info.siblings is not None, "Repo info returned from server must have a siblings list."
-    filtered_repo_files = list(
-        filter_repo_objects(
-            items=[f.rfilename for f in repo_info.siblings],
-            allow_patterns=allow_patterns,
-            ignore_patterns=ignore_patterns,
+
+    # Corner case: on very large repos, the siblings list in `repo_info` might not contain all files.
+    # In that case, we need to use the `list_repo_tree` method to prevent caching issues.
+    repo_files: Iterable[str] = [f.rfilename for f in repo_info.siblings] if repo_info.siblings is not None else []
+    unreliable_nb_files = (
+        repo_info.siblings is None
+        or len(repo_info.siblings) == 0
+        or len(repo_info.siblings) > VERY_LARGE_REPO_THRESHOLD
+    )
+    if unreliable_nb_files:
+        logger.info(
+            "Number of files in the repo is unreliable. Using `list_repo_tree` to ensure all files are listed."
         )
+        repo_files = (
+            f.rfilename
+            for f in api.list_repo_tree(repo_id=repo_id, recursive=True, revision=revision, repo_type=repo_type)
+            if isinstance(f, RepoFile)
+        )
+
+    filtered_repo_files: Iterable[str] = filter_repo_objects(
+        items=repo_files,
+        allow_patterns=allow_patterns,
+        ignore_patterns=ignore_patterns,
     )
+
+    if not unreliable_nb_files:
+        filtered_repo_files = list(filtered_repo_files)
+        tqdm_desc = f"Fetching {len(filtered_repo_files)} files"
+    else:
+        tqdm_desc = "Fetching ... files"
+    if dry_run:
+        tqdm_desc = "[dry-run] " + tqdm_desc
+
     commit_hash = repo_info.sha
     snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash)
     # if passed revision is not identical to commit_hash
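`HfApi.list_repo_tree` pages through the repo tree lazily, which is why the code above keeps `repo_files` as a generator rather than a list when the sibling count is unreliable. A standalone usage sketch (repo id is illustrative):

```python
from huggingface_hub import HfApi
from huggingface_hub.hf_api import RepoFile

api = HfApi()
# Yields RepoFile/RepoFolder entries lazily, page by page.
for entry in api.list_repo_tree("bigscience/bloom", recursive=True):
    if isinstance(entry, RepoFile):  # folders are skipped, as in the hunk above
        print(entry.path, entry.size)
```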
@@ -263,44 +377,90 @@ def snapshot_download(
         except OSError as e:
             logger.warning(f"Ignored error while writing commit hash to {ref_path}: {e}.")

+    results: List[Union[str, DryRunFileInfo]] = []
+
+    # User can use its own tqdm class or the default one from `huggingface_hub.utils`
+    tqdm_class = tqdm_class or hf_tqdm
+
+    # Create a progress bar for the bytes downloaded
+    # This progress bar is shared across threads/files and gets updated each time we fetch
+    # metadata for a file.
+    bytes_progress = tqdm_class(
+        desc="Downloading (incomplete total...)",
+        disable=is_tqdm_disabled(log_level=logger.getEffectiveLevel()),
+        total=0,
+        initial=0,
+        unit="B",
+        unit_scale=True,
+        name="huggingface_hub.snapshot_download",
+    )
+
+    class _AggregatedTqdm:
+        """Fake tqdm object to aggregate progress into the parent `bytes_progress` bar.
+
+        In practice the `_AggregatedTqdm` object won't be displayed, it's just used to update
+        the `bytes_progress` bar from each thread/file download.
+        """
+
+        def __init__(self, *args, **kwargs):
+            # Adjust the total of the parent progress bar
+            total = kwargs.pop("total", None)
+            if total is not None:
+                bytes_progress.total += total
+                bytes_progress.refresh()
+
+            # Adjust initial of the parent progress bar
+            initial = kwargs.pop("initial", 0)
+            if initial:
+                bytes_progress.update(initial)
+
+        def __enter__(self):
+            return self
+
+        def __exit__(self, exc_type, exc_value, traceback):
+            pass
+
+        def update(self, n: Optional[Union[int, float]] = 1) -> None:
+            bytes_progress.update(n)
+
     # we pass the commit_hash to hf_hub_download
     # so no network call happens if we already
     # have the file locally.
-    def _inner_hf_hub_download(repo_file: str):
-        return hf_hub_download(
-            repo_id,
-            filename=repo_file,
-            repo_type=repo_type,
-            revision=commit_hash,
-            endpoint=endpoint,
-            cache_dir=cache_dir,
-            local_dir=local_dir,
-            local_dir_use_symlinks=local_dir_use_symlinks,
-            library_name=library_name,
-            library_version=library_version,
-            user_agent=user_agent,
-            proxies=proxies,
-            etag_timeout=etag_timeout,
-            resume_download=resume_download,
-            force_download=force_download,
-            token=token,
-            headers=headers,
+    def _inner_hf_hub_download(repo_file: str) -> None:
+        results.append(
+            hf_hub_download(  # type: ignore
+                repo_id,
+                filename=repo_file,
+                repo_type=repo_type,
+                revision=commit_hash,
+                endpoint=endpoint,
+                cache_dir=cache_dir,
+                local_dir=local_dir,
+                library_name=library_name,
+                library_version=library_version,
+                user_agent=user_agent,
+                etag_timeout=etag_timeout,
+                force_download=force_download,
+                token=token,
+                headers=headers,
+                tqdm_class=_AggregatedTqdm,  # type: ignore
+                dry_run=dry_run,
+            )
         )

-    if constants.HF_HUB_ENABLE_HF_TRANSFER:
-        # when using hf_transfer we don't want extra parallelism
-        # from the one hf_transfer provides
-        for file in filtered_repo_files:
-            _inner_hf_hub_download(file)
-    else:
-        thread_map(
-            _inner_hf_hub_download,
-            filtered_repo_files,
-            desc=f"Fetching {len(filtered_repo_files)} files",
-            max_workers=max_workers,
-            # User can use its own tqdm class or the default one from `huggingface_hub.utils`
-            tqdm_class=tqdm_class or hf_tqdm,
-        )
+    thread_map(
+        _inner_hf_hub_download,
+        filtered_repo_files,
+        desc=tqdm_desc,
+        max_workers=max_workers,
+        tqdm_class=tqdm_class,
+    )
+
+    bytes_progress.set_description("Download complete")
+
+    if dry_run:
+        assert all(isinstance(r, DryRunFileInfo) for r in results)
+        return results  # type: ignore

     if local_dir is not None:
         return str(os.path.realpath(local_dir))
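Taken together, these changes let callers inspect a download plan without fetching any file content. A usage sketch of the new flag (repo id illustrative; the exact fields of `DryRunFileInfo` are defined in `file_download.py` and not shown in this diff):

```python
from huggingface_hub import snapshot_download

# Normal call: downloads everything and returns the snapshot folder path.
path = snapshot_download("gpt2")

# Dry run: resolves metadata only and returns one DryRunFileInfo per file
# that would be downloaded (after allow/ignore pattern filtering).
infos = snapshot_download("gpt2", dry_run=True, allow_patterns="*.json")
for info in infos:
    print(info)
```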
huggingface_hub/_space_api.py

@@ -15,7 +15,7 @@
 from dataclasses import dataclass
 from datetime import datetime
 from enum import Enum
-from typing import Dict, Optional
+from typing import Optional

 from huggingface_hub.utils import parse_datetime

@@ -54,24 +54,32 @@ class SpaceHardware(str, Enum):
     assert SpaceHardware.CPU_BASIC == "cpu-basic"
     ```

-    Taken from https://github.com/huggingface/moon-landing/blob/main/server/repo_types/SpaceInfo.ts#L73 (private url).
+    Taken from https://github.com/huggingface-internal/moon-landing/blob/main/server/repo_types/SpaceHardwareFlavor.ts (private url).
     """

+    # CPU
     CPU_BASIC = "cpu-basic"
     CPU_UPGRADE = "cpu-upgrade"
+    CPU_XL = "cpu-xl"
+
+    # ZeroGPU
+    ZERO_A10G = "zero-a10g"
+
+    # GPU
     T4_SMALL = "t4-small"
     T4_MEDIUM = "t4-medium"
     L4X1 = "l4x1"
     L4X4 = "l4x4"
-    ZERO_A10G = "zero-a10g"
+    L40SX1 = "l40sx1"
+    L40SX4 = "l40sx4"
+    L40SX8 = "l40sx8"
     A10G_SMALL = "a10g-small"
     A10G_LARGE = "a10g-large"
     A10G_LARGEX2 = "a10g-largex2"
     A10G_LARGEX4 = "a10g-largex4"
     A100_LARGE = "a100-large"
-    V5E_1X1 = "v5e-1x1"
-    V5E_2X2 = "v5e-2x2"
-    V5E_2X4 = "v5e-2x4"
+    H100 = "h100"
+    H100X8 = "h100x8"


 class SpaceStorage(str, Enum):
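Because `SpaceHardware` is a `str` enum, the new flavors can be passed anywhere a raw hardware string is accepted, for instance when requesting hardware for a Space (a sketch; the Space id is illustrative):

```python
from huggingface_hub import HfApi, SpaceHardware

api = HfApi()
# str-enum members compare equal to their raw values...
assert SpaceHardware.L40SX1 == "l40sx1"
# ...and can be passed directly when upgrading an existing Space.
api.request_space_hardware(repo_id="user/my-space", hardware=SpaceHardware.A10G_SMALL)
```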
@@ -103,7 +111,7 @@ class SpaceRuntime:
            Current hardware of the space. Example: "cpu-basic". Can be `None` if Space
            is `BUILDING` for the first time.
        requested_hardware (`str` or `None`):
-           Requested hardware. Can be different than `hardware` especially if the request
+           Requested hardware. Can be different from `hardware` especially if the request
            has just been made. Example: "t4-medium". Can be `None` if no hardware has
            been requested yet.
        sleep_time (`int` or `None`):
@@ -120,9 +128,9 @@ class SpaceRuntime:
     requested_hardware: Optional[SpaceHardware]
     sleep_time: Optional[int]
     storage: Optional[SpaceStorage]
-    raw: Dict
+    raw: dict

-    def __init__(self, data: Dict) -> None:
+    def __init__(self, data: dict) -> None:
         self.stage = data["stage"]
         self.hardware = data.get("hardware", {}).get("current")
         self.requested_hardware = data.get("hardware", {}).get("requested")
@@ -152,7 +160,7 @@ class SpaceVariable:
     description: Optional[str]
     updated_at: Optional[datetime]

-    def __init__(self, key: str, values: Dict) -> None:
+    def __init__(self, key: str, values: dict) -> None:
         self.key = key
         self.value = values["value"]
         self.description = values.get("description")
huggingface_hub/_tensorboard_logger.py

@@ -14,7 +14,7 @@
 """Contains a logger to push training logs to the Hub, using Tensorboard."""

 from pathlib import Path
-from typing import TYPE_CHECKING, List, Optional, Union
+from typing import Optional, Union

 from ._commit_scheduler import CommitScheduler
 from .errors import EntryNotFoundError
@@ -26,25 +26,24 @@ from .utils import experimental
 # or from 'torch.utils.tensorboard'. Both are compatible so let's try to load
 # from either of them.
 try:
-    from tensorboardX import SummaryWriter
+    from tensorboardX import SummaryWriter as _RuntimeSummaryWriter

     is_summary_writer_available = True
-
 except ImportError:
     try:
-        from torch.utils.tensorboard import SummaryWriter
+        from torch.utils.tensorboard import SummaryWriter as _RuntimeSummaryWriter

-        is_summary_writer_available = False
+        is_summary_writer_available = True
     except ImportError:
         # Dummy class to avoid failing at import. Will raise on instance creation.
-        SummaryWriter = object
-        is_summary_writer_available = False
+        class _DummySummaryWriter:
+            pass

-if TYPE_CHECKING:
-    from tensorboardX import SummaryWriter
+        _RuntimeSummaryWriter = _DummySummaryWriter  # type: ignore[assignment]
+        is_summary_writer_available = False


-class HFSummaryWriter(SummaryWriter):
+class HFSummaryWriter(_RuntimeSummaryWriter):
     """
     Wrapper around the tensorboard's `SummaryWriter` to push training logs to the Hub.

@@ -53,11 +52,8 @@ class HFSummaryWriter(SummaryWriter):
     issue), the main script will not be interrupted. Data is automatically pushed to the Hub every `commit_every`
     minutes (default to every 5 minutes).

-    <Tip warning={true}>
-
-    `HFSummaryWriter` is experimental. Its API is subject to change in the future without prior notice.
-
-    </Tip>
+    > [!WARNING]
+    > `HFSummaryWriter` is experimental. Its API is subject to change in the future without prior notice.

     Args:
         repo_id (`str`):
@@ -78,10 +74,10 @@ class HFSummaryWriter(SummaryWriter):
            Whether to make the repo private. If `None` (default), the repo will be public unless the organization's default is private. This value is ignored if the repo already exists.
        path_in_repo (`str`, *optional*):
            The path to the folder in the repo where the logs will be pushed. Defaults to "tensorboard/".
-       repo_allow_patterns (`List[str]` or `str`, *optional*):
+       repo_allow_patterns (`list[str]` or `str`, *optional*):
            A list of patterns to include in the upload. Defaults to `"*.tfevents.*"`. Check out the
            [upload guide](https://huggingface.co/docs/huggingface_hub/guides/upload#upload-a-folder) for more details.
-       repo_ignore_patterns (`List[str]` or `str`, *optional*):
+       repo_ignore_patterns (`list[str]` or `str`, *optional*):
            A list of patterns to exclude in the upload. Check out the
            [upload guide](https://huggingface.co/docs/huggingface_hub/guides/upload#upload-a-folder) for more details.
        token (`str`, *optional*):
@@ -138,8 +134,8 @@ class HFSummaryWriter(SummaryWriter):
         repo_revision: Optional[str] = None,
         repo_private: Optional[bool] = None,
         path_in_repo: Optional[str] = "tensorboard",
-        repo_allow_patterns: Optional[Union[List[str], str]] = "*.tfevents.*",
-        repo_ignore_patterns: Optional[Union[List[str], str]] = None,
+        repo_allow_patterns: Optional[Union[list[str], str]] = "*.tfevents.*",
+        repo_ignore_patterns: Optional[Union[list[str], str]] = None,
         token: Optional[str] = None,
         **kwargs,
     ):
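After the import fallback above, `HFSummaryWriter` behaves like a regular `SummaryWriter` whose log directory is committed to the Hub in the background every `commit_every` minutes. A minimal usage sketch (repo id is illustrative; requires `tensorboardX` or `torch` to be installed):

```python
from huggingface_hub import HFSummaryWriter

writer = HFSummaryWriter(repo_id="user/my-training-logs", commit_every=5)
for step in range(100):
    writer.add_scalar("train/loss", 1.0 / (step + 1), global_step=step)
writer.close()  # standard SummaryWriter API; Hub commits are handled by the scheduler
```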