huggingface-hub 1.0.0rc1__py3-none-any.whl → 1.0.0rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of huggingface-hub might be problematic. Click here for more details.
- huggingface_hub/__init__.py +4 -7
- huggingface_hub/_commit_api.py +126 -66
- huggingface_hub/_commit_scheduler.py +4 -7
- huggingface_hub/_login.py +10 -16
- huggingface_hub/_snapshot_download.py +119 -21
- huggingface_hub/_tensorboard_logger.py +2 -5
- huggingface_hub/_upload_large_folder.py +1 -2
- huggingface_hub/_webhooks_server.py +8 -20
- huggingface_hub/cli/_cli_utils.py +12 -6
- huggingface_hub/cli/download.py +32 -7
- huggingface_hub/cli/repo.py +137 -5
- huggingface_hub/dataclasses.py +122 -2
- huggingface_hub/errors.py +4 -0
- huggingface_hub/fastai_utils.py +22 -32
- huggingface_hub/file_download.py +234 -38
- huggingface_hub/hf_api.py +385 -424
- huggingface_hub/hf_file_system.py +55 -65
- huggingface_hub/inference/_client.py +27 -48
- huggingface_hub/inference/_generated/_async_client.py +27 -48
- huggingface_hub/inference/_generated/types/image_to_image.py +6 -2
- huggingface_hub/inference/_mcp/agent.py +2 -5
- huggingface_hub/inference/_mcp/mcp_client.py +6 -8
- huggingface_hub/inference/_providers/__init__.py +16 -0
- huggingface_hub/inference/_providers/_common.py +2 -0
- huggingface_hub/inference/_providers/fal_ai.py +2 -0
- huggingface_hub/inference/_providers/publicai.py +6 -0
- huggingface_hub/inference/_providers/scaleway.py +28 -0
- huggingface_hub/inference/_providers/zai_org.py +17 -0
- huggingface_hub/lfs.py +14 -8
- huggingface_hub/repocard.py +12 -16
- huggingface_hub/serialization/_base.py +3 -6
- huggingface_hub/serialization/_torch.py +16 -34
- huggingface_hub/utils/__init__.py +1 -2
- huggingface_hub/utils/_cache_manager.py +42 -72
- huggingface_hub/utils/_chunk_utils.py +2 -3
- huggingface_hub/utils/_http.py +37 -68
- huggingface_hub/utils/_validators.py +2 -2
- huggingface_hub/utils/logging.py +8 -11
- {huggingface_hub-1.0.0rc1.dist-info → huggingface_hub-1.0.0rc3.dist-info}/METADATA +2 -2
- {huggingface_hub-1.0.0rc1.dist-info → huggingface_hub-1.0.0rc3.dist-info}/RECORD +44 -56
- {huggingface_hub-1.0.0rc1.dist-info → huggingface_hub-1.0.0rc3.dist-info}/entry_points.txt +0 -1
- huggingface_hub/commands/__init__.py +0 -27
- huggingface_hub/commands/_cli_utils.py +0 -74
- huggingface_hub/commands/delete_cache.py +0 -476
- huggingface_hub/commands/download.py +0 -195
- huggingface_hub/commands/env.py +0 -39
- huggingface_hub/commands/huggingface_cli.py +0 -65
- huggingface_hub/commands/lfs.py +0 -200
- huggingface_hub/commands/repo.py +0 -151
- huggingface_hub/commands/repo_files.py +0 -132
- huggingface_hub/commands/scan_cache.py +0 -183
- huggingface_hub/commands/tag.py +0 -159
- huggingface_hub/commands/upload.py +0 -318
- huggingface_hub/commands/upload_large_folder.py +0 -131
- huggingface_hub/commands/user.py +0 -207
- huggingface_hub/commands/version.py +0 -40
- {huggingface_hub-1.0.0rc1.dist-info → huggingface_hub-1.0.0rc3.dist-info}/LICENSE +0 -0
- {huggingface_hub-1.0.0rc1.dist-info → huggingface_hub-1.0.0rc3.dist-info}/WHEEL +0 -0
- {huggingface_hub-1.0.0rc1.dist-info → huggingface_hub-1.0.0rc3.dist-info}/top_level.txt +0 -0
|
@@ -15,13 +15,23 @@
|
|
|
15
15
|
|
|
16
16
|
import os
|
|
17
17
|
from enum import Enum
|
|
18
|
-
from typing import Annotated, Optional, Union
|
|
18
|
+
from typing import TYPE_CHECKING, Annotated, Optional, Union
|
|
19
19
|
|
|
20
20
|
import click
|
|
21
21
|
import typer
|
|
22
22
|
|
|
23
23
|
from huggingface_hub import __version__
|
|
24
|
-
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
if TYPE_CHECKING:
|
|
27
|
+
from huggingface_hub.hf_api import HfApi
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def get_hf_api(token: Optional[str] = None) -> "HfApi":
|
|
31
|
+
# Import here to avoid circular import
|
|
32
|
+
from huggingface_hub.hf_api import HfApi
|
|
33
|
+
|
|
34
|
+
return HfApi(token=token, library_name="hf", library_version=__version__)
|
|
25
35
|
|
|
26
36
|
|
|
27
37
|
class ANSI:
|
|
@@ -140,7 +150,3 @@ RevisionOpt = Annotated[
|
|
|
140
150
|
help="Git revision id which can be a branch name, a tag, or a commit hash.",
|
|
141
151
|
),
|
|
142
152
|
]
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
def get_hf_api(token: Optional[str] = None) -> HfApi:
|
|
146
|
-
return HfApi(token=token, library_name="hf", library_version=__version__)
|
huggingface_hub/cli/download.py
CHANGED
|
@@ -37,16 +37,16 @@ Usage:
|
|
|
37
37
|
"""
|
|
38
38
|
|
|
39
39
|
import warnings
|
|
40
|
-
from typing import Annotated, Optional
|
|
40
|
+
from typing import Annotated, Optional, Union
|
|
41
41
|
|
|
42
42
|
import typer
|
|
43
43
|
|
|
44
44
|
from huggingface_hub import logging
|
|
45
45
|
from huggingface_hub._snapshot_download import snapshot_download
|
|
46
|
-
from huggingface_hub.file_download import hf_hub_download
|
|
47
|
-
from huggingface_hub.utils import disable_progress_bars, enable_progress_bars
|
|
46
|
+
from huggingface_hub.file_download import DryRunFileInfo, hf_hub_download
|
|
47
|
+
from huggingface_hub.utils import _format_size, disable_progress_bars, enable_progress_bars
|
|
48
48
|
|
|
49
|
-
from ._cli_utils import RepoIdArg, RepoTypeOpt, RevisionOpt, TokenOpt
|
|
49
|
+
from ._cli_utils import RepoIdArg, RepoTypeOpt, RevisionOpt, TokenOpt, tabulate
|
|
50
50
|
|
|
51
51
|
|
|
52
52
|
logger = logging.get_logger(__name__)
|
|
@@ -92,6 +92,12 @@ def download(
|
|
|
92
92
|
help="If True, the files will be downloaded even if they are already cached.",
|
|
93
93
|
),
|
|
94
94
|
] = False,
|
|
95
|
+
dry_run: Annotated[
|
|
96
|
+
bool,
|
|
97
|
+
typer.Option(
|
|
98
|
+
help="If True, perform a dry run without actually downloading the file.",
|
|
99
|
+
),
|
|
100
|
+
] = False,
|
|
95
101
|
token: TokenOpt = None,
|
|
96
102
|
quiet: Annotated[
|
|
97
103
|
bool,
|
|
@@ -108,7 +114,7 @@ def download(
|
|
|
108
114
|
) -> None:
|
|
109
115
|
"""Download files from the Hub."""
|
|
110
116
|
|
|
111
|
-
def run_download() -> str:
|
|
117
|
+
def run_download() -> Union[str, DryRunFileInfo, list[DryRunFileInfo]]:
|
|
112
118
|
filenames_list = filenames if filenames is not None else []
|
|
113
119
|
# Warn user if patterns are ignored
|
|
114
120
|
if len(filenames_list) > 0:
|
|
@@ -129,6 +135,7 @@ def download(
|
|
|
129
135
|
token=token,
|
|
130
136
|
local_dir=local_dir,
|
|
131
137
|
library_name="hf",
|
|
138
|
+
dry_run=dry_run,
|
|
132
139
|
)
|
|
133
140
|
|
|
134
141
|
# Otherwise: use `snapshot_download` to ensure all files comes from same revision
|
|
@@ -151,14 +158,32 @@ def download(
|
|
|
151
158
|
local_dir=local_dir,
|
|
152
159
|
library_name="hf",
|
|
153
160
|
max_workers=max_workers,
|
|
161
|
+
dry_run=dry_run,
|
|
162
|
+
)
|
|
163
|
+
|
|
164
|
+
def _print_result(result: Union[str, DryRunFileInfo, list[DryRunFileInfo]]) -> None:
|
|
165
|
+
if isinstance(result, str):
|
|
166
|
+
print(result)
|
|
167
|
+
return
|
|
168
|
+
|
|
169
|
+
# Print dry run info
|
|
170
|
+
if isinstance(result, DryRunFileInfo):
|
|
171
|
+
result = [result]
|
|
172
|
+
print(
|
|
173
|
+
f"[dry-run] Will download {len([r for r in result if r.will_download])} files (out of {len(result)}) totalling {_format_size(sum(r.file_size for r in result if r.will_download))}."
|
|
154
174
|
)
|
|
175
|
+
columns = ["File", "Bytes to download"]
|
|
176
|
+
items: list[list[Union[str, int]]] = []
|
|
177
|
+
for info in sorted(result, key=lambda x: x.filename):
|
|
178
|
+
items.append([info.filename, _format_size(info.file_size) if info.will_download else "-"])
|
|
179
|
+
print(tabulate(items, headers=columns))
|
|
155
180
|
|
|
156
181
|
if quiet:
|
|
157
182
|
disable_progress_bars()
|
|
158
183
|
with warnings.catch_warnings():
|
|
159
184
|
warnings.simplefilter("ignore")
|
|
160
|
-
|
|
185
|
+
_print_result(run_download())
|
|
161
186
|
enable_progress_bars()
|
|
162
187
|
else:
|
|
163
|
-
|
|
188
|
+
_print_result(run_download())
|
|
164
189
|
logging.set_verbosity_warning()
|
huggingface_hub/cli/repo.py
CHANGED
|
@@ -21,6 +21,7 @@ Usage:
|
|
|
21
21
|
hf repo create my-cool-model --private
|
|
22
22
|
"""
|
|
23
23
|
|
|
24
|
+
import enum
|
|
24
25
|
from typing import Annotated, Optional
|
|
25
26
|
|
|
26
27
|
import typer
|
|
@@ -44,8 +45,16 @@ from ._cli_utils import (
|
|
|
44
45
|
logger = logging.get_logger(__name__)
|
|
45
46
|
|
|
46
47
|
repo_cli = typer_factory(help="Manage repos on the Hub.")
|
|
47
|
-
|
|
48
|
-
|
|
48
|
+
tag_cli = typer_factory(help="Manage tags for a repo on the Hub.")
|
|
49
|
+
branch_cli = typer_factory(help="Manage branches for a repo on the Hub.")
|
|
50
|
+
repo_cli.add_typer(tag_cli, name="tag")
|
|
51
|
+
repo_cli.add_typer(branch_cli, name="branch")
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class GatedChoices(str, enum.Enum):
|
|
55
|
+
auto = "auto"
|
|
56
|
+
manual = "manual"
|
|
57
|
+
false = "false"
|
|
49
58
|
|
|
50
59
|
|
|
51
60
|
@repo_cli.command("create", help="Create a new repo on the Hub.")
|
|
@@ -87,7 +96,130 @@ def repo_create(
|
|
|
87
96
|
print(f"Your repo is now available at {ANSI.bold(repo_url)}")
|
|
88
97
|
|
|
89
98
|
|
|
90
|
-
@
|
|
99
|
+
@repo_cli.command("delete", help="Delete a repo from the Hub. this is an irreversible operation.")
|
|
100
|
+
def repo_delete(
|
|
101
|
+
repo_id: RepoIdArg,
|
|
102
|
+
repo_type: RepoTypeOpt = RepoType.model,
|
|
103
|
+
token: TokenOpt = None,
|
|
104
|
+
missing_ok: Annotated[
|
|
105
|
+
bool,
|
|
106
|
+
typer.Option(
|
|
107
|
+
help="If set to True, do not raise an error if repo does not exist.",
|
|
108
|
+
),
|
|
109
|
+
] = False,
|
|
110
|
+
) -> None:
|
|
111
|
+
api = get_hf_api(token=token)
|
|
112
|
+
api.delete_repo(
|
|
113
|
+
repo_id=repo_id,
|
|
114
|
+
repo_type=repo_type.value,
|
|
115
|
+
missing_ok=missing_ok,
|
|
116
|
+
)
|
|
117
|
+
print(f"Successfully deleted {ANSI.bold(repo_id)} on the Hub.")
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
@repo_cli.command("move", help="Move a repository from a namespace to another namespace.")
|
|
121
|
+
def repo_move(
|
|
122
|
+
from_id: RepoIdArg,
|
|
123
|
+
to_id: RepoIdArg,
|
|
124
|
+
token: TokenOpt = None,
|
|
125
|
+
repo_type: RepoTypeOpt = RepoType.model,
|
|
126
|
+
) -> None:
|
|
127
|
+
api = get_hf_api(token=token)
|
|
128
|
+
api.move_repo(
|
|
129
|
+
from_id=from_id,
|
|
130
|
+
to_id=to_id,
|
|
131
|
+
repo_type=repo_type.value,
|
|
132
|
+
)
|
|
133
|
+
print(f"Successfully moved {ANSI.bold(from_id)} to {ANSI.bold(to_id)} on the Hub.")
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
@repo_cli.command("settings", help="Update the settings of a repository.")
|
|
137
|
+
def repo_settings(
|
|
138
|
+
repo_id: RepoIdArg,
|
|
139
|
+
gated: Annotated[
|
|
140
|
+
Optional[GatedChoices],
|
|
141
|
+
typer.Option(
|
|
142
|
+
help="The gated status for the repository.",
|
|
143
|
+
),
|
|
144
|
+
] = None,
|
|
145
|
+
private: Annotated[
|
|
146
|
+
Optional[bool],
|
|
147
|
+
typer.Option(
|
|
148
|
+
help="Whether the repository should be private.",
|
|
149
|
+
),
|
|
150
|
+
] = None,
|
|
151
|
+
xet_enabled: Annotated[
|
|
152
|
+
Optional[bool],
|
|
153
|
+
typer.Option(
|
|
154
|
+
help=" Whether the repository should be enabled for Xet Storage.",
|
|
155
|
+
),
|
|
156
|
+
] = None,
|
|
157
|
+
token: TokenOpt = None,
|
|
158
|
+
repo_type: RepoTypeOpt = RepoType.model,
|
|
159
|
+
) -> None:
|
|
160
|
+
api = get_hf_api(token=token)
|
|
161
|
+
api.update_repo_settings(
|
|
162
|
+
repo_id=repo_id,
|
|
163
|
+
gated=(gated.value if gated else None), # type: ignore [arg-type]
|
|
164
|
+
private=private,
|
|
165
|
+
xet_enabled=xet_enabled,
|
|
166
|
+
repo_type=repo_type.value,
|
|
167
|
+
)
|
|
168
|
+
print(f"Successfully updated the settings of {ANSI.bold(repo_id)} on the Hub.")
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
@branch_cli.command("create", help="Create a new branch for a repo on the Hub.")
|
|
172
|
+
def branch_create(
|
|
173
|
+
repo_id: RepoIdArg,
|
|
174
|
+
branch: Annotated[
|
|
175
|
+
str,
|
|
176
|
+
typer.Argument(
|
|
177
|
+
help="The name of the branch to create.",
|
|
178
|
+
),
|
|
179
|
+
],
|
|
180
|
+
revision: RevisionOpt = None,
|
|
181
|
+
token: TokenOpt = None,
|
|
182
|
+
repo_type: RepoTypeOpt = RepoType.model,
|
|
183
|
+
exist_ok: Annotated[
|
|
184
|
+
bool,
|
|
185
|
+
typer.Option(
|
|
186
|
+
help="If set to True, do not raise an error if branch already exists.",
|
|
187
|
+
),
|
|
188
|
+
] = False,
|
|
189
|
+
) -> None:
|
|
190
|
+
api = get_hf_api(token=token)
|
|
191
|
+
api.create_branch(
|
|
192
|
+
repo_id=repo_id,
|
|
193
|
+
branch=branch,
|
|
194
|
+
revision=revision,
|
|
195
|
+
repo_type=repo_type.value,
|
|
196
|
+
exist_ok=exist_ok,
|
|
197
|
+
)
|
|
198
|
+
print(f"Successfully created {ANSI.bold(branch)} branch on {repo_type.value} {ANSI.bold(repo_id)}")
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
@branch_cli.command("delete", help="Delete a branch from a repo on the Hub.")
|
|
202
|
+
def branch_delete(
|
|
203
|
+
repo_id: RepoIdArg,
|
|
204
|
+
branch: Annotated[
|
|
205
|
+
str,
|
|
206
|
+
typer.Argument(
|
|
207
|
+
help="The name of the branch to delete.",
|
|
208
|
+
),
|
|
209
|
+
],
|
|
210
|
+
token: TokenOpt = None,
|
|
211
|
+
repo_type: RepoTypeOpt = RepoType.model,
|
|
212
|
+
) -> None:
|
|
213
|
+
api = get_hf_api(token=token)
|
|
214
|
+
api.delete_branch(
|
|
215
|
+
repo_id=repo_id,
|
|
216
|
+
branch=branch,
|
|
217
|
+
repo_type=repo_type.value,
|
|
218
|
+
)
|
|
219
|
+
print(f"Successfully deleted {ANSI.bold(branch)} branch on {repo_type.value} {ANSI.bold(repo_id)}")
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
@tag_cli.command("create", help="Create a tag for a repo.")
|
|
91
223
|
def tag_create(
|
|
92
224
|
repo_id: RepoIdArg,
|
|
93
225
|
tag: Annotated[
|
|
@@ -127,7 +259,7 @@ def tag_create(
|
|
|
127
259
|
print(f"Tag {ANSI.bold(tag)} created on {ANSI.bold(repo_id)}")
|
|
128
260
|
|
|
129
261
|
|
|
130
|
-
@
|
|
262
|
+
@tag_cli.command("list", help="List tags for a repo.")
|
|
131
263
|
def tag_list(
|
|
132
264
|
repo_id: RepoIdArg,
|
|
133
265
|
token: TokenOpt = None,
|
|
@@ -152,7 +284,7 @@ def tag_list(
|
|
|
152
284
|
print(t.name)
|
|
153
285
|
|
|
154
286
|
|
|
155
|
-
@
|
|
287
|
+
@tag_cli.command("delete", help="Delete a tag for a repo.")
|
|
156
288
|
def tag_delete(
|
|
157
289
|
repo_id: RepoIdArg,
|
|
158
290
|
tag: Annotated[
|
huggingface_hub/dataclasses.py
CHANGED
|
@@ -1,9 +1,11 @@
|
|
|
1
1
|
import inspect
|
|
2
|
-
from dataclasses import _MISSING_TYPE, MISSING, Field, field, fields
|
|
3
|
-
from functools import wraps
|
|
2
|
+
from dataclasses import _MISSING_TYPE, MISSING, Field, field, fields, make_dataclass
|
|
3
|
+
from functools import lru_cache, wraps
|
|
4
4
|
from typing import (
|
|
5
|
+
Annotated,
|
|
5
6
|
Any,
|
|
6
7
|
Callable,
|
|
8
|
+
ForwardRef,
|
|
7
9
|
Literal,
|
|
8
10
|
Optional,
|
|
9
11
|
Type,
|
|
@@ -14,6 +16,19 @@ from typing import (
|
|
|
14
16
|
overload,
|
|
15
17
|
)
|
|
16
18
|
|
|
19
|
+
|
|
20
|
+
try:
|
|
21
|
+
# Python 3.11+
|
|
22
|
+
from typing import NotRequired, Required # type: ignore
|
|
23
|
+
except ImportError:
|
|
24
|
+
try:
|
|
25
|
+
# In case typing_extensions is installed
|
|
26
|
+
from typing_extensions import NotRequired, Required # type: ignore
|
|
27
|
+
except ImportError:
|
|
28
|
+
# Fallback: create dummy types that will never match
|
|
29
|
+
Required = type("Required", (), {}) # type: ignore
|
|
30
|
+
NotRequired = type("NotRequired", (), {}) # type: ignore
|
|
31
|
+
|
|
17
32
|
from .errors import (
|
|
18
33
|
StrictDataclassClassValidationError,
|
|
19
34
|
StrictDataclassDefinitionError,
|
|
@@ -23,6 +38,9 @@ from .errors import (
|
|
|
23
38
|
|
|
24
39
|
Validator_T = Callable[[Any], None]
|
|
25
40
|
T = TypeVar("T")
|
|
41
|
+
TypedDictType = TypeVar("TypedDictType", bound=dict[str, Any])
|
|
42
|
+
|
|
43
|
+
_TYPED_DICT_DEFAULT_VALUE = object() # used as default value in TypedDict fields (to distinguish from None)
|
|
26
44
|
|
|
27
45
|
|
|
28
46
|
# The overload decorator helps type checkers understand the different return types
|
|
@@ -234,6 +252,92 @@ def strict(
|
|
|
234
252
|
return wrap(cls) if cls is not None else wrap
|
|
235
253
|
|
|
236
254
|
|
|
255
|
+
def validate_typed_dict(schema: type[TypedDictType], data: dict) -> None:
|
|
256
|
+
"""
|
|
257
|
+
Validate that a dictionary conforms to the types defined in a TypedDict class.
|
|
258
|
+
|
|
259
|
+
Under the hood, the typed dict is converted to a strict dataclass and validated using the `@strict` decorator.
|
|
260
|
+
|
|
261
|
+
Args:
|
|
262
|
+
schema (`type[TypedDictType]`):
|
|
263
|
+
The TypedDict class defining the expected structure and types.
|
|
264
|
+
data (`dict`):
|
|
265
|
+
The dictionary to validate.
|
|
266
|
+
|
|
267
|
+
Raises:
|
|
268
|
+
`StrictDataclassFieldValidationError`:
|
|
269
|
+
If any field in the dictionary does not conform to the expected type.
|
|
270
|
+
|
|
271
|
+
Example:
|
|
272
|
+
```py
|
|
273
|
+
>>> from typing import Annotated, TypedDict
|
|
274
|
+
>>> from huggingface_hub.dataclasses import validate_typed_dict
|
|
275
|
+
|
|
276
|
+
>>> def positive_int(value: int):
|
|
277
|
+
... if not value >= 0:
|
|
278
|
+
... raise ValueError(f"Value must be positive, got {value}")
|
|
279
|
+
|
|
280
|
+
>>> class User(TypedDict):
|
|
281
|
+
... name: str
|
|
282
|
+
... age: Annotated[int, positive_int]
|
|
283
|
+
|
|
284
|
+
>>> # Valid data
|
|
285
|
+
>>> validate_typed_dict(User, {"name": "John", "age": 30})
|
|
286
|
+
|
|
287
|
+
>>> # Invalid type for age
|
|
288
|
+
>>> validate_typed_dict(User, {"name": "John", "age": "30"})
|
|
289
|
+
huggingface_hub.errors.StrictDataclassFieldValidationError: Validation error for field 'age':
|
|
290
|
+
TypeError: Field 'age' expected int, got str (value: '30')
|
|
291
|
+
|
|
292
|
+
>>> # Invalid value for age
|
|
293
|
+
>>> validate_typed_dict(User, {"name": "John", "age": -1})
|
|
294
|
+
huggingface_hub.errors.StrictDataclassFieldValidationError: Validation error for field 'age':
|
|
295
|
+
ValueError: Value must be positive, got -1
|
|
296
|
+
```
|
|
297
|
+
"""
|
|
298
|
+
# Convert typed dict to dataclass
|
|
299
|
+
strict_cls = _build_strict_cls_from_typed_dict(schema)
|
|
300
|
+
|
|
301
|
+
# Validate the data by instantiating the strict dataclass
|
|
302
|
+
strict_cls(**data) # will raise if validation fails
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
@lru_cache
|
|
306
|
+
def _build_strict_cls_from_typed_dict(schema: type[TypedDictType]) -> Type:
|
|
307
|
+
# Extract type hints from the TypedDict class
|
|
308
|
+
type_hints = {
|
|
309
|
+
# We do not use `get_type_hints` here to avoid evaluating ForwardRefs (which might fail).
|
|
310
|
+
# ForwardRefs are not validated by @strict anyway.
|
|
311
|
+
name: value if value is not None else type(None)
|
|
312
|
+
for name, value in schema.__dict__.get("__annotations__", {}).items()
|
|
313
|
+
}
|
|
314
|
+
|
|
315
|
+
# If the TypedDict is not total, wrap fields as NotRequired (unless explicitly Required or NotRequired)
|
|
316
|
+
if not getattr(schema, "__total__", True):
|
|
317
|
+
for key, value in type_hints.items():
|
|
318
|
+
origin = get_origin(value)
|
|
319
|
+
|
|
320
|
+
if origin is Annotated:
|
|
321
|
+
base, *meta = get_args(value)
|
|
322
|
+
if not _is_required_or_notrequired(base):
|
|
323
|
+
base = NotRequired[base]
|
|
324
|
+
type_hints[key] = Annotated[tuple([base] + list(meta))]
|
|
325
|
+
elif not _is_required_or_notrequired(value):
|
|
326
|
+
type_hints[key] = NotRequired[value]
|
|
327
|
+
|
|
328
|
+
# Convert type hints to dataclass fields
|
|
329
|
+
fields = []
|
|
330
|
+
for key, value in type_hints.items():
|
|
331
|
+
if get_origin(value) is Annotated:
|
|
332
|
+
base, *meta = get_args(value)
|
|
333
|
+
fields.append((key, base, field(default=_TYPED_DICT_DEFAULT_VALUE, metadata={"validator": meta[0]})))
|
|
334
|
+
else:
|
|
335
|
+
fields.append((key, value, field(default=_TYPED_DICT_DEFAULT_VALUE)))
|
|
336
|
+
|
|
337
|
+
# Create a strict dataclass from the TypedDict fields
|
|
338
|
+
return strict(make_dataclass(schema.__name__, fields))
|
|
339
|
+
|
|
340
|
+
|
|
237
341
|
def validated_field(
|
|
238
342
|
validator: Union[list[Validator_T], Validator_T],
|
|
239
343
|
default: Union[Any, _MISSING_TYPE] = MISSING,
|
|
@@ -322,6 +426,16 @@ def type_validator(name: str, value: Any, expected_type: Any) -> None:
|
|
|
322
426
|
validator(name, value, args)
|
|
323
427
|
elif isinstance(expected_type, type): # simple types
|
|
324
428
|
_validate_simple_type(name, value, expected_type)
|
|
429
|
+
elif isinstance(expected_type, ForwardRef) or isinstance(expected_type, str):
|
|
430
|
+
return
|
|
431
|
+
elif origin is Required:
|
|
432
|
+
if value is _TYPED_DICT_DEFAULT_VALUE:
|
|
433
|
+
raise TypeError(f"Field '{name}' is required but missing.")
|
|
434
|
+
_validate_simple_type(name, value, args[0])
|
|
435
|
+
elif origin is NotRequired:
|
|
436
|
+
if value is _TYPED_DICT_DEFAULT_VALUE:
|
|
437
|
+
return
|
|
438
|
+
_validate_simple_type(name, value, args[0])
|
|
325
439
|
else:
|
|
326
440
|
raise TypeError(f"Unsupported type for field '{name}': {expected_type}")
|
|
327
441
|
|
|
@@ -458,6 +572,11 @@ def _is_validator(validator: Any) -> bool:
|
|
|
458
572
|
return True
|
|
459
573
|
|
|
460
574
|
|
|
575
|
+
def _is_required_or_notrequired(type_hint: Any) -> bool:
|
|
576
|
+
"""Helper to check if a type is Required/NotRequired."""
|
|
577
|
+
return type_hint in (Required, NotRequired) or (get_origin(type_hint) in (Required, NotRequired))
|
|
578
|
+
|
|
579
|
+
|
|
461
580
|
_BASIC_TYPE_VALIDATORS = {
|
|
462
581
|
Union: _validate_union,
|
|
463
582
|
Literal: _validate_literal,
|
|
@@ -470,6 +589,7 @@ _BASIC_TYPE_VALIDATORS = {
|
|
|
470
589
|
|
|
471
590
|
__all__ = [
|
|
472
591
|
"strict",
|
|
592
|
+
"validate_typed_dict",
|
|
473
593
|
"validated_field",
|
|
474
594
|
"Validator_T",
|
|
475
595
|
"StrictDataclassClassValidationError",
|
huggingface_hub/errors.py
CHANGED
|
@@ -160,6 +160,10 @@ class HFValidationError(ValueError):
|
|
|
160
160
|
# FILE METADATA ERRORS
|
|
161
161
|
|
|
162
162
|
|
|
163
|
+
class DryRunError(OSError):
|
|
164
|
+
"""Error triggered when a dry run is requested but cannot be performed (e.g. invalid repo)."""
|
|
165
|
+
|
|
166
|
+
|
|
163
167
|
class FileMetadataError(OSError):
|
|
164
168
|
"""Error triggered when the metadata of a file on the Hub cannot be retrieved (missing ETag or commit_hash).
|
|
165
169
|
|
huggingface_hub/fastai_utils.py
CHANGED
|
@@ -34,13 +34,11 @@ def _check_fastai_fastcore_versions(
|
|
|
34
34
|
fastcore_min_version (`str`, *optional*):
|
|
35
35
|
The minimum fastcore version supported.
|
|
36
36
|
|
|
37
|
-
|
|
38
|
-
Raises the following error:
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
</Tip>
|
|
37
|
+
> [!TIP]
|
|
38
|
+
> Raises the following error:
|
|
39
|
+
>
|
|
40
|
+
> - [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError)
|
|
41
|
+
> if the fastai or fastcore libraries are not available or are of an invalid version.
|
|
44
42
|
"""
|
|
45
43
|
|
|
46
44
|
if (get_fastcore_version() or get_fastai_version()) == "N/A":
|
|
@@ -88,15 +86,13 @@ def _check_fastai_fastcore_pyproject_versions(
|
|
|
88
86
|
fastcore_min_version (`str`, *optional*):
|
|
89
87
|
The minimum fastcore version supported.
|
|
90
88
|
|
|
91
|
-
|
|
92
|
-
Raises the following errors:
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
</Tip>
|
|
89
|
+
> [!TIP]
|
|
90
|
+
> Raises the following errors:
|
|
91
|
+
>
|
|
92
|
+
> - [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError)
|
|
93
|
+
> if the `toml` module is not installed.
|
|
94
|
+
> - [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError)
|
|
95
|
+
> if the `pyproject.toml` indicates a lower than minimum supported version of fastai or fastcore.
|
|
100
96
|
"""
|
|
101
97
|
|
|
102
98
|
try:
|
|
@@ -253,14 +249,11 @@ def _save_pretrained_fastai(
|
|
|
253
249
|
config (`dict`, *optional*):
|
|
254
250
|
Configuration object. Will be uploaded as a .json file. Example: 'https://huggingface.co/espejelomar/fastai-pet-breeds-classification/blob/main/config.json'.
|
|
255
251
|
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
if the config file provided is not a dictionary.
|
|
262
|
-
|
|
263
|
-
</Tip>
|
|
252
|
+
> [!TIP]
|
|
253
|
+
> Raises the following error:
|
|
254
|
+
>
|
|
255
|
+
> - [`RuntimeError`](https://docs.python.org/3/library/exceptions.html#RuntimeError)
|
|
256
|
+
> if the config file provided is not a dictionary.
|
|
264
257
|
"""
|
|
265
258
|
_check_fastai_fastcore_versions()
|
|
266
259
|
|
|
@@ -394,14 +387,11 @@ def push_to_hub_fastai(
|
|
|
394
387
|
Returns:
|
|
395
388
|
The url of the commit of your model in the given repository.
|
|
396
389
|
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
if the user is not log on to the Hugging Face Hub.
|
|
403
|
-
|
|
404
|
-
</Tip>
|
|
390
|
+
> [!TIP]
|
|
391
|
+
> Raises the following error:
|
|
392
|
+
>
|
|
393
|
+
> - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
|
|
394
|
+
> if the user is not log on to the Hugging Face Hub.
|
|
405
395
|
"""
|
|
406
396
|
_check_fastai_fastcore_versions()
|
|
407
397
|
api = HfApi(endpoint=api_endpoint)
|