huggingface-hub 1.0.0rc2__py3-none-any.whl → 1.0.0rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (44)
  1. huggingface_hub/__init__.py +4 -7
  2. huggingface_hub/_login.py +2 -2
  3. huggingface_hub/_snapshot_download.py +119 -21
  4. huggingface_hub/_upload_large_folder.py +1 -2
  5. huggingface_hub/cli/_cli_utils.py +12 -6
  6. huggingface_hub/cli/download.py +32 -7
  7. huggingface_hub/dataclasses.py +132 -3
  8. huggingface_hub/errors.py +4 -0
  9. huggingface_hub/file_download.py +216 -17
  10. huggingface_hub/hf_api.py +127 -14
  11. huggingface_hub/hf_file_system.py +38 -21
  12. huggingface_hub/inference/_client.py +3 -2
  13. huggingface_hub/inference/_generated/_async_client.py +3 -2
  14. huggingface_hub/inference/_generated/types/image_to_image.py +6 -2
  15. huggingface_hub/inference/_mcp/mcp_client.py +4 -3
  16. huggingface_hub/inference/_providers/__init__.py +5 -0
  17. huggingface_hub/inference/_providers/_common.py +1 -0
  18. huggingface_hub/inference/_providers/fal_ai.py +2 -0
  19. huggingface_hub/inference/_providers/zai_org.py +17 -0
  20. huggingface_hub/utils/__init__.py +1 -2
  21. huggingface_hub/utils/_cache_manager.py +1 -1
  22. huggingface_hub/utils/_http.py +10 -38
  23. huggingface_hub/utils/_validators.py +2 -2
  24. {huggingface_hub-1.0.0rc2.dist-info → huggingface_hub-1.0.0rc3.dist-info}/METADATA +1 -1
  25. {huggingface_hub-1.0.0rc2.dist-info → huggingface_hub-1.0.0rc3.dist-info}/RECORD +29 -43
  26. {huggingface_hub-1.0.0rc2.dist-info → huggingface_hub-1.0.0rc3.dist-info}/entry_points.txt +0 -1
  27. huggingface_hub/commands/__init__.py +0 -27
  28. huggingface_hub/commands/_cli_utils.py +0 -74
  29. huggingface_hub/commands/delete_cache.py +0 -476
  30. huggingface_hub/commands/download.py +0 -195
  31. huggingface_hub/commands/env.py +0 -39
  32. huggingface_hub/commands/huggingface_cli.py +0 -65
  33. huggingface_hub/commands/lfs.py +0 -200
  34. huggingface_hub/commands/repo.py +0 -151
  35. huggingface_hub/commands/repo_files.py +0 -132
  36. huggingface_hub/commands/scan_cache.py +0 -183
  37. huggingface_hub/commands/tag.py +0 -159
  38. huggingface_hub/commands/upload.py +0 -318
  39. huggingface_hub/commands/upload_large_folder.py +0 -131
  40. huggingface_hub/commands/user.py +0 -207
  41. huggingface_hub/commands/version.py +0 -40
  42. {huggingface_hub-1.0.0rc2.dist-info → huggingface_hub-1.0.0rc3.dist-info}/LICENSE +0 -0
  43. {huggingface_hub-1.0.0rc2.dist-info → huggingface_hub-1.0.0rc3.dist-info}/WHEEL +0 -0
  44. {huggingface_hub-1.0.0rc2.dist-info → huggingface_hub-1.0.0rc3.dist-info}/top_level.txt +0 -0
huggingface_hub/__init__.py CHANGED
@@ -46,7 +46,7 @@ import sys
 from typing import TYPE_CHECKING
 
 
-__version__ = "1.0.0.rc2"
+__version__ = "1.0.0.rc3"
 
 # Alphabetical order of definitions is ensured in tests
 # WARNING: any comment added in this dictionary definition will be lost when
@@ -138,6 +138,7 @@ _SUBMOD_ATTRS = {
         "push_to_hub_fastai",
     ],
     "file_download": [
+        "DryRunFileInfo",
         "HfFileMetadata",
         "_CACHED_NO_EXIST",
         "get_hf_file_metadata",
@@ -513,8 +514,6 @@ _SUBMOD_ATTRS = {
         "CorruptedCacheException",
         "DeleteCacheStrategy",
         "HFCacheInfo",
-        "HfHubAsyncTransport",
-        "HfHubTransport",
         "cached_assets_path",
         "close_session",
         "dump_environment_info",
@@ -625,6 +624,7 @@ __all__ = [
    "DocumentQuestionAnsweringInputData",
    "DocumentQuestionAnsweringOutputElement",
    "DocumentQuestionAnsweringParameters",
+   "DryRunFileInfo",
    "EvalResult",
    "FLAX_WEIGHTS_NAME",
    "FeatureExtractionInput",
@@ -645,8 +645,6 @@ __all__ = [
    "HfFileSystemFile",
    "HfFileSystemResolvedPath",
    "HfFileSystemStreamFile",
-   "HfHubAsyncTransport",
-   "HfHubTransport",
    "ImageClassificationInput",
    "ImageClassificationOutputElement",
    "ImageClassificationOutputTransform",
@@ -1147,6 +1145,7 @@ if TYPE_CHECKING: # pragma: no cover
     )
     from .file_download import (
         _CACHED_NO_EXIST,  # noqa: F401
+        DryRunFileInfo,  # noqa: F401
         HfFileMetadata,  # noqa: F401
         get_hf_file_metadata,  # noqa: F401
         hf_hub_download,  # noqa: F401
@@ -1515,8 +1514,6 @@ if TYPE_CHECKING: # pragma: no cover
         CorruptedCacheException,  # noqa: F401
         DeleteCacheStrategy,  # noqa: F401
         HFCacheInfo,  # noqa: F401
-        HfHubAsyncTransport,  # noqa: F401
-        HfHubTransport,  # noqa: F401
         cached_assets_path,  # noqa: F401
         close_session,  # noqa: F401
         dump_environment_info,  # noqa: F401
huggingface_hub/_login.py CHANGED
@@ -20,7 +20,7 @@ from pathlib import Path
 from typing import Optional
 
 from . import constants
-from .commands._cli_utils import ANSI
+from .cli._cli_utils import ANSI
 from .utils import (
     capture_output,
     get_token,
@@ -244,7 +244,7 @@ def interpreter_login(*, skip_if_logged_in: bool = False) -> None:
         logger.info("User is already logged in.")
         return
 
-    from .commands.delete_cache import _ask_for_confirmation_no_tui
+    from .cli.cache import _ask_for_confirmation_no_tui
 
     print(_HF_LOGO_ASCII)
     if get_token() is not None:
huggingface_hub/_snapshot_download.py CHANGED
@@ -1,6 +1,6 @@
 import os
 from pathlib import Path
-from typing import Iterable, Optional, Union
+from typing import Iterable, List, Literal, Optional, Union, overload
 
 import httpx
 from tqdm.auto import tqdm as base_tqdm
@@ -8,13 +8,14 @@ from tqdm.contrib.concurrent import thread_map
 
 from . import constants
 from .errors import (
+    DryRunError,
     GatedRepoError,
     HfHubHTTPError,
     LocalEntryNotFoundError,
     RepositoryNotFoundError,
     RevisionNotFoundError,
 )
-from .file_download import REGEX_COMMIT_HASH, hf_hub_download, repo_folder_name
+from .file_download import REGEX_COMMIT_HASH, DryRunFileInfo, hf_hub_download, repo_folder_name
 from .hf_api import DatasetInfo, HfApi, ModelInfo, RepoFile, SpaceInfo
 from .utils import OfflineModeIsEnabled, filter_repo_objects, logging, validate_hf_hub_args
 from .utils import tqdm as hf_tqdm
@@ -25,6 +26,81 @@ logger = logging.get_logger(__name__)
 VERY_LARGE_REPO_THRESHOLD = 50000  # After this limit, we don't consider `repo_info.siblings` to be reliable enough
 
 
+@overload
+def snapshot_download(
+    repo_id: str,
+    *,
+    repo_type: Optional[str] = None,
+    revision: Optional[str] = None,
+    cache_dir: Union[str, Path, None] = None,
+    local_dir: Union[str, Path, None] = None,
+    library_name: Optional[str] = None,
+    library_version: Optional[str] = None,
+    user_agent: Optional[Union[dict, str]] = None,
+    etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT,
+    force_download: bool = False,
+    token: Optional[Union[bool, str]] = None,
+    local_files_only: bool = False,
+    allow_patterns: Optional[Union[list[str], str]] = None,
+    ignore_patterns: Optional[Union[list[str], str]] = None,
+    max_workers: int = 8,
+    tqdm_class: Optional[type[base_tqdm]] = None,
+    headers: Optional[dict[str, str]] = None,
+    endpoint: Optional[str] = None,
+    dry_run: Literal[False] = False,
+) -> str: ...
+
+
+@overload
+def snapshot_download(
+    repo_id: str,
+    *,
+    repo_type: Optional[str] = None,
+    revision: Optional[str] = None,
+    cache_dir: Union[str, Path, None] = None,
+    local_dir: Union[str, Path, None] = None,
+    library_name: Optional[str] = None,
+    library_version: Optional[str] = None,
+    user_agent: Optional[Union[dict, str]] = None,
+    etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT,
+    force_download: bool = False,
+    token: Optional[Union[bool, str]] = None,
+    local_files_only: bool = False,
+    allow_patterns: Optional[Union[list[str], str]] = None,
+    ignore_patterns: Optional[Union[list[str], str]] = None,
+    max_workers: int = 8,
+    tqdm_class: Optional[type[base_tqdm]] = None,
+    headers: Optional[dict[str, str]] = None,
+    endpoint: Optional[str] = None,
+    dry_run: Literal[True] = True,
+) -> list[DryRunFileInfo]: ...
+
+
+@overload
+def snapshot_download(
+    repo_id: str,
+    *,
+    repo_type: Optional[str] = None,
+    revision: Optional[str] = None,
+    cache_dir: Union[str, Path, None] = None,
+    local_dir: Union[str, Path, None] = None,
+    library_name: Optional[str] = None,
+    library_version: Optional[str] = None,
+    user_agent: Optional[Union[dict, str]] = None,
+    etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT,
+    force_download: bool = False,
+    token: Optional[Union[bool, str]] = None,
+    local_files_only: bool = False,
+    allow_patterns: Optional[Union[list[str], str]] = None,
+    ignore_patterns: Optional[Union[list[str], str]] = None,
+    max_workers: int = 8,
+    tqdm_class: Optional[type[base_tqdm]] = None,
+    headers: Optional[dict[str, str]] = None,
+    endpoint: Optional[str] = None,
+    dry_run: bool = False,
+) -> Union[str, list[DryRunFileInfo]]: ...
+
+
 @validate_hf_hub_args
 def snapshot_download(
     repo_id: str,
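The three overloads above exist purely for static typing: a type checker can narrow the return type of `snapshot_download` from the literal value of `dry_run`. A minimal sketch (the repo id is illustrative):

```python
from huggingface_hub import snapshot_download

path = snapshot_download("gpt2")                  # checker infers: str
infos = snapshot_download("gpt2", dry_run=True)   # checker infers: list[DryRunFileInfo]
```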
@@ -46,7 +122,8 @@ def snapshot_download(
     tqdm_class: Optional[type[base_tqdm]] = None,
     headers: Optional[dict[str, str]] = None,
     endpoint: Optional[str] = None,
-) -> str:
+    dry_run: bool = False,
+) -> Union[str, list[DryRunFileInfo]]:
     """Download repo files.
 
     Download a whole snapshot of a repo's files at the specified revision. This is useful when you want all files from
@@ -109,9 +186,14 @@ def snapshot_download(
             Note that the `tqdm_class` is not passed to each individual download.
             Defaults to the custom HF progress bar that can be disabled by setting
             `HF_HUB_DISABLE_PROGRESS_BARS` environment variable.
+        dry_run (`bool`, *optional*, defaults to `False`):
+            If `True`, perform a dry run without actually downloading the files. Returns a list of
+            [`DryRunFileInfo`] objects containing information about what would be downloaded.
 
     Returns:
-        `str`: folder path of the repo snapshot.
+        `str` or list of [`DryRunFileInfo`]:
+            - If `dry_run=False`: Local snapshot path.
+            - If `dry_run=True`: A list of [`DryRunFileInfo`] objects containing download information.
 
     Raises:
         [`~utils.RepositoryNotFoundError`]
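At runtime, a sketch of the new mode, assuming only what this diff shows (`DryRunFileInfo` exposes at least `filename`, `file_size`, and `will_download`, the fields the CLI table below relies on; the repo id is illustrative):

```python
from huggingface_hub import snapshot_download

infos = snapshot_download("gpt2", dry_run=True)  # nothing is written to disk
for info in infos:
    # will_download is False for files already present locally
    print(f"{info.filename}: {info.file_size} bytes, download={info.will_download}")
total = sum(i.file_size for i in infos if i.will_download)
print(f"Would fetch {total} bytes")
```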
@@ -187,6 +269,11 @@ def snapshot_download(
     # - f the specified revision is a branch or tag, look inside "refs".
     # => if local_dir is not None, we will return the path to the local folder if it exists.
     if repo_info is None:
+        if dry_run:
+            raise DryRunError(
+                "Dry run cannot be performed as the repository cannot be accessed. Please check your internet connection or authentication token."
+            ) from api_call_error
+
         # Try to get which commit hash corresponds to the specified revision
         commit_hash = None
         if REGEX_COMMIT_HASH.match(revision):
@@ -273,6 +360,8 @@ def snapshot_download(
         tqdm_desc = f"Fetching {len(filtered_repo_files)} files"
     else:
         tqdm_desc = "Fetching ... files"
+    if dry_run:
+        tqdm_desc = "[dry-run] " + tqdm_desc
 
     commit_hash = repo_info.sha
     snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash)
@@ -288,28 +377,33 @@ def snapshot_download(
     except OSError as e:
         logger.warning(f"Ignored error while writing commit hash to {ref_path}: {e}.")
 
+    results: List[Union[str, DryRunFileInfo]] = []
+
     # we pass the commit_hash to hf_hub_download
     # so no network call happens if we already
     # have the file locally.
-    def _inner_hf_hub_download(repo_file: str):
-        return hf_hub_download(
-            repo_id,
-            filename=repo_file,
-            repo_type=repo_type,
-            revision=commit_hash,
-            endpoint=endpoint,
-            cache_dir=cache_dir,
-            local_dir=local_dir,
-            library_name=library_name,
-            library_version=library_version,
-            user_agent=user_agent,
-            etag_timeout=etag_timeout,
-            force_download=force_download,
-            token=token,
-            headers=headers,
+    def _inner_hf_hub_download(repo_file: str) -> None:
+        results.append(
+            hf_hub_download(  # type: ignore[no-matching-overload] # ty not happy, don't know why :/
+                repo_id,
+                filename=repo_file,
+                repo_type=repo_type,
+                revision=commit_hash,
+                endpoint=endpoint,
+                cache_dir=cache_dir,
+                local_dir=local_dir,
+                library_name=library_name,
+                library_version=library_version,
+                user_agent=user_agent,
+                etag_timeout=etag_timeout,
+                force_download=force_download,
+                token=token,
+                headers=headers,
+                dry_run=dry_run,
+            )
         )
 
-    if constants.HF_HUB_ENABLE_HF_TRANSFER:
+    if constants.HF_HUB_ENABLE_HF_TRANSFER and not dry_run:
         # when using hf_transfer we don't want extra parallelism
         # from the one hf_transfer provides
         for file in filtered_repo_files:
@@ -324,6 +418,10 @@ def snapshot_download(
             tqdm_class=tqdm_class or hf_tqdm,
         )
 
+    if dry_run:
+        assert all(isinstance(r, DryRunFileInfo) for r in results)
+        return results  # type: ignore
+
     if local_dir is not None:
         return str(os.path.realpath(local_dir))
     return snapshot_folder
huggingface_hub/_upload_large_folder.py CHANGED
@@ -31,8 +31,7 @@ from . import constants
 from ._commit_api import CommitOperationAdd, UploadInfo, _fetch_upload_modes
 from ._local_folder import LocalUploadFileMetadata, LocalUploadFilePaths, get_local_upload_paths, read_upload_metadata
 from .constants import DEFAULT_REVISION, REPO_TYPES
-from .utils import DEFAULT_IGNORE_PATTERNS, filter_repo_objects, tqdm
-from .utils._cache_manager import _format_size
+from .utils import DEFAULT_IGNORE_PATTERNS, _format_size, filter_repo_objects, tqdm
 from .utils._runtime import is_xet_available
 from .utils.sha import sha_fileobj
 
huggingface_hub/cli/_cli_utils.py CHANGED
@@ -15,13 +15,23 @@
 
 import os
 from enum import Enum
-from typing import Annotated, Optional, Union
+from typing import TYPE_CHECKING, Annotated, Optional, Union
 
 import click
 import typer
 
 from huggingface_hub import __version__
-from huggingface_hub.hf_api import HfApi
+
+
+if TYPE_CHECKING:
+    from huggingface_hub.hf_api import HfApi
+
+
+def get_hf_api(token: Optional[str] = None) -> "HfApi":
+    # Import here to avoid circular import
+    from huggingface_hub.hf_api import HfApi
+
+    return HfApi(token=token, library_name="hf", library_version=__version__)
 
 
 class ANSI:
@@ -140,7 +150,3 @@ RevisionOpt = Annotated[
         help="Git revision id which can be a branch name, a tag, or a commit hash.",
     ),
 ]
-
-
-def get_hf_api(token: Optional[str] = None) -> HfApi:
-    return HfApi(token=token, library_name="hf", library_version=__version__)
huggingface_hub/cli/download.py CHANGED
@@ -37,16 +37,16 @@ Usage:
 """
 
 import warnings
-from typing import Annotated, Optional
+from typing import Annotated, Optional, Union
 
 import typer
 
 from huggingface_hub import logging
 from huggingface_hub._snapshot_download import snapshot_download
-from huggingface_hub.file_download import hf_hub_download
-from huggingface_hub.utils import disable_progress_bars, enable_progress_bars
+from huggingface_hub.file_download import DryRunFileInfo, hf_hub_download
+from huggingface_hub.utils import _format_size, disable_progress_bars, enable_progress_bars
 
-from ._cli_utils import RepoIdArg, RepoTypeOpt, RevisionOpt, TokenOpt
+from ._cli_utils import RepoIdArg, RepoTypeOpt, RevisionOpt, TokenOpt, tabulate
 
 
 logger = logging.get_logger(__name__)
@@ -92,6 +92,12 @@ def download(
             help="If True, the files will be downloaded even if they are already cached.",
         ),
     ] = False,
+    dry_run: Annotated[
+        bool,
+        typer.Option(
+            help="If True, perform a dry run without actually downloading the file.",
+        ),
+    ] = False,
     token: TokenOpt = None,
     quiet: Annotated[
         bool,
@@ -108,7 +114,7 @@ def download(
 ) -> None:
     """Download files from the Hub."""
 
-    def run_download() -> str:
+    def run_download() -> Union[str, DryRunFileInfo, list[DryRunFileInfo]]:
         filenames_list = filenames if filenames is not None else []
         # Warn user if patterns are ignored
         if len(filenames_list) > 0:
@@ -129,6 +135,7 @@ def download(
             token=token,
             local_dir=local_dir,
             library_name="hf",
+            dry_run=dry_run,
         )
 
     # Otherwise: use `snapshot_download` to ensure all files comes from same revision
@@ -151,14 +158,32 @@ def download(
             local_dir=local_dir,
             library_name="hf",
             max_workers=max_workers,
+            dry_run=dry_run,
+        )
+
+    def _print_result(result: Union[str, DryRunFileInfo, list[DryRunFileInfo]]) -> None:
+        if isinstance(result, str):
+            print(result)
+            return
+
+        # Print dry run info
+        if isinstance(result, DryRunFileInfo):
+            result = [result]
+        print(
+            f"[dry-run] Will download {len([r for r in result if r.will_download])} files (out of {len(result)}) totalling {_format_size(sum(r.file_size for r in result if r.will_download))}."
         )
+        columns = ["File", "Bytes to download"]
+        items: list[list[Union[str, int]]] = []
+        for info in sorted(result, key=lambda x: x.filename):
+            items.append([info.filename, _format_size(info.file_size) if info.will_download else "-"])
+        print(tabulate(items, headers=columns))
 
     if quiet:
         disable_progress_bars()
         with warnings.catch_warnings():
             warnings.simplefilter("ignore")
-            print(run_download())
+            _print_result(run_download())
         enable_progress_bars()
     else:
-        print(run_download())
+        _print_result(run_download())
     logging.set_verbosity_warning()
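Putting the CLI change together, a hedged example invocation (repo id and filenames are illustrative; per `_print_result` above, the output is a `[dry-run] Will download ... totalling ...` summary followed by a File / "Bytes to download" table, with `-` for files that are already cached):

```sh
# Estimate a full snapshot download without fetching anything
hf download gpt2 --dry-run

# Dry-run a specific set of files
hf download gpt2 config.json model.safetensors --dry-run
```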
huggingface_hub/dataclasses.py CHANGED
@@ -1,7 +1,33 @@
 import inspect
-from dataclasses import _MISSING_TYPE, MISSING, Field, field, fields
-from functools import wraps
-from typing import Any, Callable, ForwardRef, Literal, Optional, Type, TypeVar, Union, get_args, get_origin, overload
+from dataclasses import _MISSING_TYPE, MISSING, Field, field, fields, make_dataclass
+from functools import lru_cache, wraps
+from typing import (
+    Annotated,
+    Any,
+    Callable,
+    ForwardRef,
+    Literal,
+    Optional,
+    Type,
+    TypeVar,
+    Union,
+    get_args,
+    get_origin,
+    overload,
+)
+
+
+try:
+    # Python 3.11+
+    from typing import NotRequired, Required  # type: ignore
+except ImportError:
+    try:
+        # In case typing_extensions is installed
+        from typing_extensions import NotRequired, Required  # type: ignore
+    except ImportError:
+        # Fallback: create dummy types that will never match
+        Required = type("Required", (), {})  # type: ignore
+        NotRequired = type("NotRequired", (), {})  # type: ignore
 
 
 from .errors import (
@@ -12,6 +38,9 @@ from .errors import (
 
 Validator_T = Callable[[Any], None]
 T = TypeVar("T")
+TypedDictType = TypeVar("TypedDictType", bound=dict[str, Any])
+
+_TYPED_DICT_DEFAULT_VALUE = object()  # used as default value in TypedDict fields (to distinguish from None)
 
 
 # The overload decorator helps type checkers understand the different return types
@@ -223,6 +252,92 @@ def strict(
     return wrap(cls) if cls is not None else wrap
 
 
+def validate_typed_dict(schema: type[TypedDictType], data: dict) -> None:
+    """
+    Validate that a dictionary conforms to the types defined in a TypedDict class.
+
+    Under the hood, the typed dict is converted to a strict dataclass and validated using the `@strict` decorator.
+
+    Args:
+        schema (`type[TypedDictType]`):
+            The TypedDict class defining the expected structure and types.
+        data (`dict`):
+            The dictionary to validate.
+
+    Raises:
+        `StrictDataclassFieldValidationError`:
+            If any field in the dictionary does not conform to the expected type.
+
+    Example:
+    ```py
+    >>> from typing import Annotated, TypedDict
+    >>> from huggingface_hub.dataclasses import validate_typed_dict
+
+    >>> def positive_int(value: int):
+    ...     if not value >= 0:
+    ...         raise ValueError(f"Value must be positive, got {value}")
+
+    >>> class User(TypedDict):
+    ...     name: str
+    ...     age: Annotated[int, positive_int]
+
+    >>> # Valid data
+    >>> validate_typed_dict(User, {"name": "John", "age": 30})
+
+    >>> # Invalid type for age
+    >>> validate_typed_dict(User, {"name": "John", "age": "30"})
+    huggingface_hub.errors.StrictDataclassFieldValidationError: Validation error for field 'age':
+      TypeError: Field 'age' expected int, got str (value: '30')
+
+    >>> # Invalid value for age
+    >>> validate_typed_dict(User, {"name": "John", "age": -1})
+    huggingface_hub.errors.StrictDataclassFieldValidationError: Validation error for field 'age':
+      ValueError: Value must be positive, got -1
+    ```
+    """
+    # Convert typed dict to dataclass
+    strict_cls = _build_strict_cls_from_typed_dict(schema)
+
+    # Validate the data by instantiating the strict dataclass
+    strict_cls(**data)  # will raise if validation fails
+
+
+@lru_cache
+def _build_strict_cls_from_typed_dict(schema: type[TypedDictType]) -> Type:
+    # Extract type hints from the TypedDict class
+    type_hints = {
+        # We do not use `get_type_hints` here to avoid evaluating ForwardRefs (which might fail).
+        # ForwardRefs are not validated by @strict anyway.
+        name: value if value is not None else type(None)
+        for name, value in schema.__dict__.get("__annotations__", {}).items()
+    }
+
+    # If the TypedDict is not total, wrap fields as NotRequired (unless explicitly Required or NotRequired)
+    if not getattr(schema, "__total__", True):
+        for key, value in type_hints.items():
+            origin = get_origin(value)
+
+            if origin is Annotated:
+                base, *meta = get_args(value)
+                if not _is_required_or_notrequired(base):
+                    base = NotRequired[base]
+                type_hints[key] = Annotated[tuple([base] + list(meta))]
+            elif not _is_required_or_notrequired(value):
+                type_hints[key] = NotRequired[value]
+
+    # Convert type hints to dataclass fields
+    fields = []
+    for key, value in type_hints.items():
+        if get_origin(value) is Annotated:
+            base, *meta = get_args(value)
+            fields.append((key, base, field(default=_TYPED_DICT_DEFAULT_VALUE, metadata={"validator": meta[0]})))
+        else:
+            fields.append((key, value, field(default=_TYPED_DICT_DEFAULT_VALUE)))
+
+    # Create a strict dataclass from the TypedDict fields
+    return strict(make_dataclass(schema.__name__, fields))
+
+
 def validated_field(
     validator: Union[list[Validator_T], Validator_T],
     default: Union[Any, _MISSING_TYPE] = MISSING,
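The docstring above covers total TypedDicts; the `__total__` handling in `_build_strict_cls_from_typed_dict` also makes partial dicts validate. A small sketch under the same API (class and field names are illustrative):

```python
from typing import TypedDict

from huggingface_hub.dataclasses import validate_typed_dict


class Payload(TypedDict, total=False):
    model: str
    temperature: float


# "temperature" may be omitted: with total=False every field is wrapped
# as NotRequired unless explicitly marked Required/NotRequired.
validate_typed_dict(Payload, {"model": "tiny-model"})

# A field that *is* present must still match its annotation:
# this raises a StrictDataclassFieldValidationError.
validate_typed_dict(Payload, {"model": "tiny-model", "temperature": "hot"})
```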
@@ -313,6 +428,14 @@ def type_validator(name: str, value: Any, expected_type: Any) -> None:
         _validate_simple_type(name, value, expected_type)
     elif isinstance(expected_type, ForwardRef) or isinstance(expected_type, str):
         return
+    elif origin is Required:
+        if value is _TYPED_DICT_DEFAULT_VALUE:
+            raise TypeError(f"Field '{name}' is required but missing.")
+        _validate_simple_type(name, value, args[0])
+    elif origin is NotRequired:
+        if value is _TYPED_DICT_DEFAULT_VALUE:
+            return
+        _validate_simple_type(name, value, args[0])
     else:
         raise TypeError(f"Unsupported type for field '{name}': {expected_type}")
 
@@ -449,6 +572,11 @@ def _is_validator(validator: Any) -> bool:
     return True
 
 
+def _is_required_or_notrequired(type_hint: Any) -> bool:
+    """Helper to check if a type is Required/NotRequired."""
+    return type_hint in (Required, NotRequired) or (get_origin(type_hint) in (Required, NotRequired))
+
+
 _BASIC_TYPE_VALIDATORS = {
     Union: _validate_union,
     Literal: _validate_literal,
@@ -461,6 +589,7 @@ _BASIC_TYPE_VALIDATORS = {
 
 __all__ = [
     "strict",
+    "validate_typed_dict",
     "validated_field",
     "Validator_T",
     "StrictDataclassClassValidationError",
huggingface_hub/errors.py CHANGED
@@ -160,6 +160,10 @@ class HFValidationError(ValueError):
 # FILE METADATA ERRORS
 
 
+class DryRunError(OSError):
+    """Error triggered when a dry run is requested but cannot be performed (e.g. invalid repo)."""
+
+
 class FileMetadataError(OSError):
     """Error triggered when the metadata of a file on the Hub cannot be retrieved (missing ETag or commit_hash).
 
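For completeness, a sketch of how the new error surfaces in the dry-run flow added to `_snapshot_download.py` above (the repo id is illustrative):

```python
from huggingface_hub import snapshot_download
from huggingface_hub.errors import DryRunError

try:
    # Raised when dry_run=True but the repository cannot be accessed
    # (offline, missing/invalid token, nonexistent repo, ...)
    snapshot_download("some-org/some-private-model", dry_run=True)
except DryRunError as err:
    print(f"Cannot estimate the download: {err}")
```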