huggingface-hub 0.31.0rc0__py3-none-any.whl → 1.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- huggingface_hub/__init__.py +145 -46
- huggingface_hub/_commit_api.py +168 -119
- huggingface_hub/_commit_scheduler.py +15 -15
- huggingface_hub/_inference_endpoints.py +15 -12
- huggingface_hub/_jobs_api.py +301 -0
- huggingface_hub/_local_folder.py +18 -3
- huggingface_hub/_login.py +31 -63
- huggingface_hub/_oauth.py +460 -0
- huggingface_hub/_snapshot_download.py +239 -80
- huggingface_hub/_space_api.py +5 -5
- huggingface_hub/_tensorboard_logger.py +15 -19
- huggingface_hub/_upload_large_folder.py +172 -76
- huggingface_hub/_webhooks_payload.py +3 -3
- huggingface_hub/_webhooks_server.py +13 -25
- huggingface_hub/{commands → cli}/__init__.py +1 -15
- huggingface_hub/cli/_cli_utils.py +173 -0
- huggingface_hub/cli/auth.py +147 -0
- huggingface_hub/cli/cache.py +841 -0
- huggingface_hub/cli/download.py +189 -0
- huggingface_hub/cli/hf.py +60 -0
- huggingface_hub/cli/inference_endpoints.py +377 -0
- huggingface_hub/cli/jobs.py +772 -0
- huggingface_hub/cli/lfs.py +175 -0
- huggingface_hub/cli/repo.py +315 -0
- huggingface_hub/cli/repo_files.py +94 -0
- huggingface_hub/{commands/env.py → cli/system.py} +10 -13
- huggingface_hub/cli/upload.py +294 -0
- huggingface_hub/cli/upload_large_folder.py +117 -0
- huggingface_hub/community.py +20 -12
- huggingface_hub/constants.py +38 -53
- huggingface_hub/dataclasses.py +609 -0
- huggingface_hub/errors.py +80 -30
- huggingface_hub/fastai_utils.py +30 -41
- huggingface_hub/file_download.py +435 -351
- huggingface_hub/hf_api.py +2050 -1124
- huggingface_hub/hf_file_system.py +269 -152
- huggingface_hub/hub_mixin.py +43 -63
- huggingface_hub/inference/_client.py +347 -434
- huggingface_hub/inference/_common.py +133 -121
- huggingface_hub/inference/_generated/_async_client.py +397 -541
- huggingface_hub/inference/_generated/types/__init__.py +5 -1
- huggingface_hub/inference/_generated/types/automatic_speech_recognition.py +3 -3
- huggingface_hub/inference/_generated/types/base.py +10 -7
- huggingface_hub/inference/_generated/types/chat_completion.py +59 -23
- huggingface_hub/inference/_generated/types/depth_estimation.py +2 -2
- huggingface_hub/inference/_generated/types/document_question_answering.py +2 -2
- huggingface_hub/inference/_generated/types/feature_extraction.py +2 -2
- huggingface_hub/inference/_generated/types/fill_mask.py +2 -2
- huggingface_hub/inference/_generated/types/image_to_image.py +6 -2
- huggingface_hub/inference/_generated/types/image_to_video.py +60 -0
- huggingface_hub/inference/_generated/types/sentence_similarity.py +3 -3
- huggingface_hub/inference/_generated/types/summarization.py +2 -2
- huggingface_hub/inference/_generated/types/table_question_answering.py +5 -5
- huggingface_hub/inference/_generated/types/text2text_generation.py +2 -2
- huggingface_hub/inference/_generated/types/text_generation.py +10 -10
- huggingface_hub/inference/_generated/types/text_to_video.py +2 -2
- huggingface_hub/inference/_generated/types/token_classification.py +2 -2
- huggingface_hub/inference/_generated/types/translation.py +2 -2
- huggingface_hub/inference/_generated/types/zero_shot_classification.py +2 -2
- huggingface_hub/inference/_generated/types/zero_shot_image_classification.py +2 -2
- huggingface_hub/inference/_generated/types/zero_shot_object_detection.py +1 -3
- huggingface_hub/inference/_mcp/__init__.py +0 -0
- huggingface_hub/inference/_mcp/_cli_hacks.py +88 -0
- huggingface_hub/inference/_mcp/agent.py +100 -0
- huggingface_hub/inference/_mcp/cli.py +247 -0
- huggingface_hub/inference/_mcp/constants.py +81 -0
- huggingface_hub/inference/_mcp/mcp_client.py +395 -0
- huggingface_hub/inference/_mcp/types.py +45 -0
- huggingface_hub/inference/_mcp/utils.py +128 -0
- huggingface_hub/inference/_providers/__init__.py +82 -7
- huggingface_hub/inference/_providers/_common.py +129 -27
- huggingface_hub/inference/_providers/black_forest_labs.py +6 -6
- huggingface_hub/inference/_providers/cerebras.py +1 -1
- huggingface_hub/inference/_providers/clarifai.py +13 -0
- huggingface_hub/inference/_providers/cohere.py +20 -3
- huggingface_hub/inference/_providers/fal_ai.py +183 -56
- huggingface_hub/inference/_providers/featherless_ai.py +38 -0
- huggingface_hub/inference/_providers/fireworks_ai.py +18 -0
- huggingface_hub/inference/_providers/groq.py +9 -0
- huggingface_hub/inference/_providers/hf_inference.py +69 -30
- huggingface_hub/inference/_providers/hyperbolic.py +4 -4
- huggingface_hub/inference/_providers/nebius.py +33 -5
- huggingface_hub/inference/_providers/novita.py +5 -5
- huggingface_hub/inference/_providers/nscale.py +44 -0
- huggingface_hub/inference/_providers/openai.py +3 -1
- huggingface_hub/inference/_providers/publicai.py +6 -0
- huggingface_hub/inference/_providers/replicate.py +31 -13
- huggingface_hub/inference/_providers/sambanova.py +18 -4
- huggingface_hub/inference/_providers/scaleway.py +28 -0
- huggingface_hub/inference/_providers/together.py +20 -5
- huggingface_hub/inference/_providers/wavespeed.py +138 -0
- huggingface_hub/inference/_providers/zai_org.py +17 -0
- huggingface_hub/lfs.py +33 -100
- huggingface_hub/repocard.py +34 -38
- huggingface_hub/repocard_data.py +57 -57
- huggingface_hub/serialization/__init__.py +0 -1
- huggingface_hub/serialization/_base.py +12 -15
- huggingface_hub/serialization/_dduf.py +8 -8
- huggingface_hub/serialization/_torch.py +69 -69
- huggingface_hub/utils/__init__.py +19 -8
- huggingface_hub/utils/_auth.py +7 -7
- huggingface_hub/utils/_cache_manager.py +92 -147
- huggingface_hub/utils/_chunk_utils.py +2 -3
- huggingface_hub/utils/_deprecation.py +1 -1
- huggingface_hub/utils/_dotenv.py +55 -0
- huggingface_hub/utils/_experimental.py +7 -5
- huggingface_hub/utils/_fixes.py +0 -10
- huggingface_hub/utils/_git_credential.py +5 -5
- huggingface_hub/utils/_headers.py +8 -30
- huggingface_hub/utils/_http.py +398 -239
- huggingface_hub/utils/_pagination.py +4 -4
- huggingface_hub/utils/_parsing.py +98 -0
- huggingface_hub/utils/_paths.py +5 -5
- huggingface_hub/utils/_runtime.py +61 -24
- huggingface_hub/utils/_safetensors.py +21 -21
- huggingface_hub/utils/_subprocess.py +9 -9
- huggingface_hub/utils/_telemetry.py +4 -4
- huggingface_hub/{commands/_cli_utils.py → utils/_terminal.py} +4 -4
- huggingface_hub/utils/_typing.py +25 -5
- huggingface_hub/utils/_validators.py +55 -74
- huggingface_hub/utils/_verification.py +167 -0
- huggingface_hub/utils/_xet.py +64 -17
- huggingface_hub/utils/_xet_progress_reporting.py +162 -0
- huggingface_hub/utils/insecure_hashlib.py +3 -5
- huggingface_hub/utils/logging.py +8 -11
- huggingface_hub/utils/tqdm.py +5 -4
- {huggingface_hub-0.31.0rc0.dist-info → huggingface_hub-1.1.3.dist-info}/METADATA +94 -85
- huggingface_hub-1.1.3.dist-info/RECORD +155 -0
- {huggingface_hub-0.31.0rc0.dist-info → huggingface_hub-1.1.3.dist-info}/WHEEL +1 -1
- huggingface_hub-1.1.3.dist-info/entry_points.txt +6 -0
- huggingface_hub/commands/delete_cache.py +0 -474
- huggingface_hub/commands/download.py +0 -200
- huggingface_hub/commands/huggingface_cli.py +0 -61
- huggingface_hub/commands/lfs.py +0 -200
- huggingface_hub/commands/repo_files.py +0 -128
- huggingface_hub/commands/scan_cache.py +0 -181
- huggingface_hub/commands/tag.py +0 -159
- huggingface_hub/commands/upload.py +0 -314
- huggingface_hub/commands/upload_large_folder.py +0 -129
- huggingface_hub/commands/user.py +0 -304
- huggingface_hub/commands/version.py +0 -37
- huggingface_hub/inference_api.py +0 -217
- huggingface_hub/keras_mixin.py +0 -500
- huggingface_hub/repository.py +0 -1477
- huggingface_hub/serialization/_tensorflow.py +0 -95
- huggingface_hub/utils/_hf_folder.py +0 -68
- huggingface_hub-0.31.0rc0.dist-info/RECORD +0 -135
- huggingface_hub-0.31.0rc0.dist-info/entry_points.txt +0 -6
- {huggingface_hub-0.31.0rc0.dist-info → huggingface_hub-1.1.3.dist-info/licenses}/LICENSE +0 -0
- {huggingface_hub-0.31.0rc0.dist-info → huggingface_hub-1.1.3.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,609 @@
|
|
|
1
|
+
import inspect
|
|
2
|
+
from dataclasses import _MISSING_TYPE, MISSING, Field, field, fields, make_dataclass
|
|
3
|
+
from functools import lru_cache, wraps
|
|
4
|
+
from typing import (
|
|
5
|
+
Annotated,
|
|
6
|
+
Any,
|
|
7
|
+
Callable,
|
|
8
|
+
ForwardRef,
|
|
9
|
+
Literal,
|
|
10
|
+
Optional,
|
|
11
|
+
Type,
|
|
12
|
+
TypeVar,
|
|
13
|
+
Union,
|
|
14
|
+
get_args,
|
|
15
|
+
get_origin,
|
|
16
|
+
overload,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
try:
|
|
21
|
+
# Python 3.11+
|
|
22
|
+
from typing import NotRequired, Required # type: ignore
|
|
23
|
+
except ImportError:
|
|
24
|
+
try:
|
|
25
|
+
# In case typing_extensions is installed
|
|
26
|
+
from typing_extensions import NotRequired, Required # type: ignore
|
|
27
|
+
except ImportError:
|
|
28
|
+
# Fallback: create dummy types that will never match
|
|
29
|
+
Required = type("Required", (), {}) # type: ignore
|
|
30
|
+
NotRequired = type("NotRequired", (), {}) # type: ignore
|
|
31
|
+
|
|
32
|
+
from .errors import (
|
|
33
|
+
StrictDataclassClassValidationError,
|
|
34
|
+
StrictDataclassDefinitionError,
|
|
35
|
+
StrictDataclassFieldValidationError,
|
|
36
|
+
)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
# Signature of a field validator: called with the value being assigned; it signals an invalid value by raising
# (ValueError / TypeError are the exceptions wrapped by `@strict` into StrictDataclassFieldValidationError).
Validator_T = Callable[[Any], None]
# Generic type variable used to type the `@strict` decorator (the decorated class is returned as-is).
T = TypeVar("T")
# Type variable for TypedDict schemas accepted by `validate_typed_dict`, bounded by `dict[str, Any]`.
TypedDictType = TypeVar("TypedDictType", bound=dict[str, Any])

# Sentinel default for the fields of dataclasses generated from TypedDicts: lets validators detect a missing key
# without confusing it with an explicit `None` value.
_TYPED_DICT_DEFAULT_VALUE = object()  # used as default value in TypedDict fields (to distinguish from None)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
# The overload decorator helps type checkers understand the different return types:
# - `@strict` applied directly to a class returns that same class;
# - `@strict(accept_kwargs=...)` (keyword arguments only) returns a decorator to apply to the class.
@overload
def strict(cls: Type[T]) -> Type[T]: ...


@overload
def strict(*, accept_kwargs: bool = False) -> Callable[[Type[T]], Type[T]]: ...
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def strict(
    cls: Optional[Type[T]] = None, *, accept_kwargs: bool = False
) -> Union[Type[T], Callable[[Type[T]], Type[T]]]:
    """
    Decorator to add strict validation to a dataclass.

    This decorator must be used on top of `@dataclass` to ensure IDEs and static typing tools
    recognize the class as a dataclass.

    Can be used with or without arguments:
    - `@strict`
    - `@strict(accept_kwargs=True)`

    What the decorator installs on the class:
    - a `__setattr__` that runs, for each field assignment, a type check against the field annotation followed by
      the custom validators stored under the `"validator"` key of the field metadata (for non-frozen dataclasses
      this includes the assignments done by the generated `__init__`);
    - a `validate()` method running every public method named `validate_*` (class validators, taking only `self`),
      also called automatically at the end of `__init__`.

    Args:
        cls:
            The class to convert to a strict dataclass.
        accept_kwargs (`bool`, *optional*):
            If True, allows arbitrary keyword arguments in `__init__`. Defaults to False. Dataclass fields can still
            be passed positionally or by keyword; extra keyword arguments are set as instance attributes and shown
            in `repr` with a leading `*`.

    Returns:
        The enhanced dataclass with strict validation on field assignment (or, when called with keyword arguments
        only, a decorator producing it).

    Raises:
        `StrictDataclassDefinitionError`:
            If the class is not a dataclass, if a custom field validator cannot be called with a single argument,
            if a `validate_*` method takes arguments other than `self`, or if the class already defines its own
            `validate` method.

    Example:
    ```py
    >>> from dataclasses import dataclass
    >>> from huggingface_hub.dataclasses import as_validated_field, strict, validated_field

    >>> @as_validated_field
    >>> def positive_int(value: int):
    ...     if not value >= 0:
    ...         raise ValueError(f"Value must be positive, got {value}")

    >>> @strict(accept_kwargs=True)
    ... @dataclass
    ... class User:
    ...     name: str
    ...     age: int = positive_int(default=10)

    # Initialize
    >>> User(name="John")
    User(name='John', age=10)

    # Extra kwargs are accepted
    >>> User(name="John", age=30, lastname="Doe")
    User(name='John', age=30, *lastname='Doe')

    # Invalid type => raises
    >>> User(name="John", age="30")
    huggingface_hub.errors.StrictDataclassFieldValidationError: Validation error for field 'age':
        TypeError: Field 'age' expected int, got str (value: '30')

    # Invalid value => raises
    >>> User(name="John", age=-1)
    huggingface_hub.errors.StrictDataclassFieldValidationError: Validation error for field 'age':
        ValueError: Value must be positive, got -1
    ```
    """

    def wrap(cls: Type[T]) -> Type[T]:
        if not hasattr(cls, "__dataclass_fields__"):
            raise StrictDataclassDefinitionError(
                f"Class '{cls.__name__}' must be a dataclass before applying @strict."
            )

        # List and store validators: the type check always runs first, then the custom validators from metadata.
        field_validators: dict[str, list[Validator_T]] = {}
        for f in fields(cls):  # type: ignore [arg-type]
            validators = [_create_type_validator(f)]
            custom_validator = f.metadata.get("validator")
            if custom_validator is not None:
                if not isinstance(custom_validator, list):
                    custom_validator = [custom_validator]
                for validator in custom_validator:
                    if not _is_validator(validator):
                        raise StrictDataclassDefinitionError(
                            f"Invalid validator for field '{f.name}': {validator}. Must be a callable taking a single argument."
                        )
                validators.extend(custom_validator)
            field_validators[f.name] = validators
        cls.__validators__ = field_validators  # type: ignore

        # Override __setattr__ to validate fields on assignment
        original_setattr = cls.__setattr__

        def __strict_setattr__(self: Any, name: str, value: Any) -> None:
            """Custom __setattr__ method for strict dataclasses."""
            # Run all validators (attributes that are not fields have none)
            for validator in self.__validators__.get(name, []):
                try:
                    validator(value)
                except (ValueError, TypeError) as e:
                    raise StrictDataclassFieldValidationError(field=name, cause=e) from e

            # If validation passed, set the attribute
            original_setattr(self, name, value)

        cls.__setattr__ = __strict_setattr__  # type: ignore[method-assign]

        if accept_kwargs:
            # (optional) Override __init__ to accept arbitrary keyword arguments
            original_init = cls.__init__
            # Field names are fixed once the dataclass exists: compute them once instead of on every __init__ call.
            dataclass_fields = {f.name for f in fields(cls)}  # type: ignore [arg-type]

            @wraps(original_init)
            def __init__(self, *args: Any, **kwargs: Any) -> None:
                # Positional arguments and keyword arguments matching a field go to the original __init__
                standard_kwargs = {k: v for k, v in kwargs.items() if k in dataclass_fields}
                original_init(self, *args, **standard_kwargs)

                # Add any additional kwargs as attributes
                for name, value in kwargs.items():
                    if name not in dataclass_fields:
                        self.__setattr__(name, value)

            cls.__init__ = __init__  # type: ignore[method-assign]

            # (optional) Override __repr__ to include additional kwargs
            original_repr = cls.__repr__

            @wraps(original_repr)
            def __repr__(self) -> str:
                # Call the original __repr__ to get the standard fields
                standard_repr = original_repr(self)

                # Get additional kwargs
                additional_kwargs = [
                    # add a '*' in front of additional kwargs to let the user know they are not part of the dataclass
                    f"*{k}={v!r}"
                    for k, v in self.__dict__.items()
                    if k not in cls.__dataclass_fields__  # type: ignore [attr-defined]
                ]
                if not additional_kwargs:
                    return standard_repr

                # A dataclass without fields renders as "Name()": do not insert a leading ", " in that case.
                separator = "" if standard_repr.endswith("()") else ", "
                return f"{standard_repr[:-1]}{separator}{', '.join(additional_kwargs)})"

            cls.__repr__ = __repr__  # type: ignore [method-assign]

        # List all public methods starting with `validate_` => class validators.
        class_validators = []

        for name in dir(cls):
            if not name.startswith("validate_"):
                continue
            method = getattr(cls, name)
            if not callable(method):
                continue
            if len(inspect.signature(method).parameters) != 1:
                raise StrictDataclassDefinitionError(
                    f"Class '{cls.__name__}' has a class validator '{name}' that takes more than one argument."
                    " Class validators must take only 'self' as an argument. Methods starting with 'validate_'"
                    " are considered to be class validators."
                )
            class_validators.append(method)

        cls.__class_validators__ = class_validators  # type: ignore [attr-defined]

        # Add `validate` method to the class, but first check if it already exists
        def validate(self: T) -> None:
            """Run class validators on the instance."""
            for validator in cls.__class_validators__:  # type: ignore [attr-defined]
                try:
                    validator(self)
                except (ValueError, TypeError) as e:
                    raise StrictDataclassClassValidationError(validator=validator.__name__, cause=e) from e

        # Hack to be able to raise if `.validate()` already exists except if it was created by this decorator on a parent class
        # (in which case we just override it)
        validate.__is_defined_by_strict_decorator__ = True  # type: ignore [attr-defined]

        if hasattr(cls, "validate"):
            if not getattr(cls.validate, "__is_defined_by_strict_decorator__", False):  # type: ignore [attr-defined]
                raise StrictDataclassDefinitionError(
                    f"Class '{cls.__name__}' already implements a method called 'validate'."
                    " This method name is reserved when using the @strict decorator on a dataclass."
                    " If you want to keep your own method, please rename it."
                )

        cls.validate = validate  # type: ignore

        # Run class validators after initialization
        initial_init = cls.__init__

        @wraps(initial_init)
        def init_with_validate(self, *args, **kwargs) -> None:
            """Run class validators after initialization."""
            initial_init(self, *args, **kwargs)  # type: ignore [call-arg]
            cls.validate(self)  # type: ignore [attr-defined]

        setattr(cls, "__init__", init_with_validate)

        return cls

    # Return wrapped class or the decorator itself
    return wrap(cls) if cls is not None else wrap
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
def validate_typed_dict(schema: type[TypedDictType], data: dict) -> None:
    """
    Check that a dictionary matches the keys and types declared by a TypedDict class.

    The TypedDict is first turned into a `@strict` dataclass (built once per schema, then cached), which is then
    instantiated with `data`. The instance itself is thrown away: only the validation it triggers matters.

    Args:
        schema (`type[TypedDictType]`):
            The TypedDict class describing the expected keys and value types.
        data (`dict`):
            The dictionary to check.

    Raises:
        `StrictDataclassFieldValidationError`:
            If a value does not match its declared type or fails an `Annotated` validator.
        `TypeError`:
            If `data` contains keys that `schema` does not declare (rejected by the generated `__init__`).

    Example:
    ```py
    >>> from typing import Annotated, TypedDict
    >>> from huggingface_hub.dataclasses import validate_typed_dict

    >>> def positive_int(value: int):
    ...     if not value >= 0:
    ...         raise ValueError(f"Value must be positive, got {value}")

    >>> class User(TypedDict):
    ...     name: str
    ...     age: Annotated[int, positive_int]

    >>> # Valid data
    >>> validate_typed_dict(User, {"name": "John", "age": 30})

    >>> # Invalid type for age
    >>> validate_typed_dict(User, {"name": "John", "age": "30"})
    huggingface_hub.errors.StrictDataclassFieldValidationError: Validation error for field 'age':
        TypeError: Field 'age' expected int, got str (value: '30')

    >>> # Invalid value for age
    >>> validate_typed_dict(User, {"name": "John", "age": -1})
    huggingface_hub.errors.StrictDataclassFieldValidationError: Validation error for field 'age':
        ValueError: Value must be positive, got -1
    ```
    """
    validator_cls = _build_strict_cls_from_typed_dict(schema)
    # Building the instance runs every field validator; any failure propagates to the caller.
    validator_cls(**data)
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
@lru_cache
def _build_strict_cls_from_typed_dict(schema: type[TypedDictType]) -> Type:
    """Convert a TypedDict class into a `@strict` dataclass (result cached per schema).

    Each TypedDict key becomes a dataclass field whose default is the `_TYPED_DICT_DEFAULT_VALUE` sentinel, so a
    missing key can be told apart from an explicit `None`. For `Annotated[T, validator, ...]` hints, the first
    metadata item is attached as the field's custom validator. `Required[Annotated[...]]` and
    `NotRequired[Annotated[...]]` are handled too: the validator is hoisted and the wrapper is kept on the base type.
    """
    # Extract type hints from the TypedDict class (copied, since entries are rewritten below)
    type_hints = dict(_get_typed_dict_annotations(schema))

    # If the TypedDict is not total, wrap fields as NotRequired (unless explicitly Required or NotRequired)
    if not getattr(schema, "__total__", True):
        for key, value in type_hints.items():
            if get_origin(value) is Annotated:
                base, *meta = get_args(value)
                if not _is_required_or_notrequired(base):
                    base = NotRequired[base]
                type_hints[key] = Annotated[tuple([base] + list(meta))]
            elif not _is_required_or_notrequired(value):
                type_hints[key] = NotRequired[value]

    # Convert type hints to dataclass fields.
    # Named `dataclass_fields` (not `fields`) to avoid shadowing `dataclasses.fields`.
    dataclass_fields = []
    for key, hint in type_hints.items():
        validator = None
        origin = get_origin(hint)
        if origin is Annotated:
            hint, validator = get_args(hint)[:2]
        elif origin in (Required, NotRequired) and get_origin(get_args(hint)[0]) is Annotated:
            # e.g. `NotRequired[Annotated[int, check]]`: without hoisting, the validator would never run.
            base, validator = get_args(get_args(hint)[0])[:2]
            hint = origin[base]

        if validator is None:
            dataclass_fields.append((key, hint, field(default=_TYPED_DICT_DEFAULT_VALUE)))
        else:
            dataclass_fields.append(
                (key, hint, field(default=_TYPED_DICT_DEFAULT_VALUE, metadata={"validator": validator}))
            )

    # Create a strict dataclass from the TypedDict fields
    return strict(make_dataclass(schema.__name__, dataclass_fields))
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
def _get_typed_dict_annotations(schema: type[TypedDictType]) -> dict[str, Any]:
    """Extract type annotations from a TypedDict class.

    Annotations are never evaluated strictly: names that cannot be resolved are kept as forward references, which
    `@strict` does not validate. A `None` annotation is normalized to `NoneType` on every Python version.
    """
    try:
        # Available in Python 3.14+
        import annotationlib
    except ImportError:
        # We do not use `get_type_hints` here to avoid evaluating ForwardRefs (which might fail).
        # ForwardRefs are not validated by @strict anyway.
        raw_annotations = schema.__dict__.get("__annotations__", {})
    else:
        # FORWARDREF format: unresolvable names become ForwardRef objects instead of raising NameError,
        # matching the intent of the fallback branch above.
        raw_annotations = annotationlib.get_annotations(schema, format=annotationlib.Format.FORWARDREF)

    return {name: type(None) if value is None else value for name, value in raw_annotations.items()}
|
|
350
|
+
|
|
351
|
+
|
|
352
|
+
def validated_field(
    validator: Union[list[Validator_T], Validator_T],
    default: Union[Any, _MISSING_TYPE] = MISSING,
    default_factory: Union[Callable[[], Any], _MISSING_TYPE] = MISSING,
    init: bool = True,
    repr: bool = True,
    hash: Optional[bool] = None,
    compare: bool = True,
    metadata: Optional[dict] = None,
    **kwargs: Any,
) -> Any:
    """
    Create a dataclass field with a custom validator.

    Useful to apply several checks to a field. If only applying one rule, check out the [`as_validated_field`] decorator.

    Args:
        validator (`Callable` or `list[Callable]`):
            A method that takes a value as input and raises ValueError/TypeError if the value is invalid.
            Can be a list of validators to apply multiple checks.
        default, default_factory, init, repr, hash, compare:
            Forwarded as-is to `dataclasses.field()`.
        metadata (`dict`, *optional*):
            Extra field metadata. The mapping is copied, never modified in place; the validator list is stored under
            the `"validator"` key of the copy (overriding any existing `"validator"` entry).
        **kwargs:
            Additional arguments to pass to `dataclasses.field()`.

    Returns:
        A field with the validator attached in metadata
    """
    validators = validator if isinstance(validator, list) else [validator]
    # Build a fresh dict: a `metadata` mapping shared between several fields must not be mutated.
    field_metadata = {**(metadata or {}), "validator": validators}
    return field(  # type: ignore
        default=default,  # type: ignore [arg-type]
        default_factory=default_factory,  # type: ignore [arg-type]
        init=init,
        repr=repr,
        hash=hash,
        compare=compare,
        metadata=field_metadata,
        **kwargs,
    )
|
|
393
|
+
|
|
394
|
+
|
|
395
|
+
def as_validated_field(validator: Validator_T):
    """
    Turn a single validator function into a factory of validated dataclass fields.

    Calling the returned factory is the same as calling [`validated_field`] with `validator` as first argument: it
    accepts the usual `dataclasses.field()` options (`default`, `default_factory`, `init`, `repr`, `hash`, `compare`,
    `metadata`, plus any extra keyword argument) and forwards them.

    Args:
        validator (`Callable`):
            A method that takes a value as input and raises ValueError/TypeError if the value is invalid.

    Returns:
        A callable building a dataclass field with `validator` attached.
    """

    def _inner(
        default: Union[Any, _MISSING_TYPE] = MISSING,
        default_factory: Union[Callable[[], Any], _MISSING_TYPE] = MISSING,
        init: bool = True,
        repr: bool = True,
        hash: Optional[bool] = None,
        compare: bool = True,
        metadata: Optional[dict] = None,
        **kwargs: Any,
    ):
        field_options: dict[str, Any] = {
            "default": default,
            "default_factory": default_factory,
            "init": init,
            "repr": repr,
            "hash": hash,
            "compare": compare,
            "metadata": metadata,
        }
        return validated_field(validator, **field_options, **kwargs)

    return _inner
|
|
427
|
+
|
|
428
|
+
|
|
429
|
+
def type_validator(name: str, value: Any, expected_type: Any) -> None:
    """Validate that 'value' matches 'expected_type'.

    Args:
        name: Field name used in error messages (nested items are reported as e.g. `field[0]`).
        value: The value to check.
        expected_type: The annotation to check against: `Any`, `None`, a plain class, `Annotated[...]` (only the
            underlying type is checked), `Union`/`Optional`, `Literal`, `list`/`dict`/`tuple`/`set` generics,
            `Required`/`NotRequired`, or a forward reference (skipped).

    Raises:
        TypeError: If `value` does not match, if a `Required[...]` value is missing, or if `expected_type` is not
            supported.
    """
    origin = get_origin(expected_type)
    args = get_args(expected_type)

    if expected_type is Any:
        return
    elif expected_type is None:
        # `None` used directly as an annotation stands for `NoneType` (e.g. `x: None`).
        _validate_simple_type(name, value, type(None))
    elif origin is Annotated:
        # Check the underlying type only; Annotated metadata is not interpreted here.
        type_validator(name, value, args[0])
    elif validator := _BASIC_TYPE_VALIDATORS.get(origin):
        validator(name, value, args)
    elif isinstance(expected_type, type):  # simple types
        _validate_simple_type(name, value, expected_type)
    elif isinstance(expected_type, (ForwardRef, str)):
        # Unresolved forward references cannot be checked: skip them.
        return
    elif origin is Required:
        if value is _TYPED_DICT_DEFAULT_VALUE:
            raise TypeError(f"Field '{name}' is required but missing.")
        type_validator(name, value, args[0])
    elif origin is NotRequired:
        if value is _TYPED_DICT_DEFAULT_VALUE:
            return
        type_validator(name, value, args[0])
    else:
        raise TypeError(f"Unsupported type for field '{name}': {expected_type}")
|
|
452
|
+
|
|
453
|
+
|
|
454
|
+
def _validate_union(name: str, value: Any, args: tuple[Any, ...]) -> None:
    """Accept `value` as soon as one Union member validates it; otherwise raise a TypeError listing every failure."""
    failures: list[str] = []
    for candidate in args:
        try:
            type_validator(name, value, candidate)
        except TypeError as err:
            failures.append(str(err))
        else:
            return  # first matching member wins

    details = "; ".join(failures)
    raise TypeError(f"Field '{name}' with value {value!r} doesn't match any type in {args}. Errors: {details}")
|
|
467
|
+
|
|
468
|
+
|
|
469
|
+
def _validate_literal(name: str, value: Any, args: tuple[Any, ...]) -> None:
    """Ensure `value` is one of the allowed `Literal[...]` values (membership test, i.e. equality-based)."""
    if value in args:
        return
    raise TypeError(f"Field '{name}' expected one of {args}, got {value}")
|
|
473
|
+
|
|
474
|
+
|
|
475
|
+
def _validate_list(name: str, value: Any, args: tuple[Any, ...]) -> None:
    """Validate list[T] type.

    When no item type is given (`args` is empty, e.g. a bare `typing.List`), only the container type is checked.

    Raises:
        TypeError: If `value` is not a list or one of its items does not match the item type.
    """
    if not isinstance(value, list):
        raise TypeError(f"Field '{name}' expected a list, got {type(value).__name__}")
    if not args:
        return

    # Validate each item in the list
    item_type = args[0]
    for i, item in enumerate(value):
        try:
            type_validator(f"{name}[{i}]", item, item_type)
        except TypeError as e:
            raise TypeError(f"Invalid item at index {i} in list '{name}'") from e
|
|
487
|
+
|
|
488
|
+
|
|
489
|
+
def _validate_dict(name: str, value: Any, args: tuple[Any, ...]) -> None:
    """Validate dict[K, V] type.

    When no key/value types are given (`args` is empty, e.g. a bare `typing.Dict`), only the container type is
    checked.

    Raises:
        TypeError: If `value` is not a dict or one of its keys/values does not match the declared types.
    """
    if not isinstance(value, dict):
        raise TypeError(f"Field '{name}' expected a dict, got {type(value).__name__}")
    if not args:
        return

    # Validate keys and values
    key_type, value_type = args
    for k, v in value.items():
        try:
            type_validator(f"{name}.key", k, key_type)
            type_validator(f"{name}[{k!r}]", v, value_type)
        except TypeError as e:
            raise TypeError(f"Invalid key or value in dict '{name}'") from e
|
|
502
|
+
|
|
503
|
+
|
|
504
|
+
def _validate_tuple(name: str, value: Any, args: tuple[Any, ...]) -> None:
    """Validate Tuple type.

    Supports the variadic form `tuple[T, ...]` (every item checked against `T`) and the fixed form
    `tuple[T1, T2, ...]` (exact length required, items checked position by position).
    """
    if not isinstance(value, tuple):
        raise TypeError(f"Field '{name}' expected a tuple, got {type(value).__name__}")

    is_variadic = len(args) == 2 and args[1] is Ellipsis
    if not is_variadic and len(args) != len(value):
        raise TypeError(f"Field '{name}' expected a tuple of length {len(args)}, got {len(value)}")

    expected_types = [args[0]] * len(value) if is_variadic else list(args)
    for index, (item, expected) in enumerate(zip(value, expected_types)):
        try:
            type_validator(f"{name}[{index}]", item, expected)
        except TypeError as err:
            raise TypeError(f"Invalid item at index {index} in tuple '{name}'") from err
|
|
525
|
+
|
|
526
|
+
|
|
527
|
+
def _validate_set(name: str, value: Any, args: tuple[Any, ...]) -> None:
    """Validate set[T] type.

    When no item type is given (`args` is empty, e.g. a bare `typing.Set`), only the container type is checked.

    Raises:
        TypeError: If `value` is not a set or one of its items does not match the item type.
    """
    if not isinstance(value, set):
        raise TypeError(f"Field '{name}' expected a set, got {type(value).__name__}")
    if not args:
        return

    # Validate each item in the set (sets are unordered: items are reported without an index)
    item_type = args[0]
    for item in value:
        try:
            type_validator(f"{name} item", item, item_type)
        except TypeError as e:
            raise TypeError(f"Invalid item in set '{name}'") from e
|
|
539
|
+
|
|
540
|
+
|
|
541
|
+
def _validate_simple_type(name: str, value: Any, expected_type: type) -> None:
    """Raise TypeError unless `value` is an instance of `expected_type` (subclass instances are accepted)."""
    if isinstance(value, expected_type):
        return
    actual_type_name = type(value).__name__
    raise TypeError(
        f"Field '{name}' expected {expected_type.__name__}, got {actual_type_name} (value: {value!r})"
    )
|
|
547
|
+
|
|
548
|
+
|
|
549
|
+
def _create_type_validator(field: Field) -> Validator_T:
    """Build the single-argument validator that checks a value against `field.type`.

    Errors are reported under `field.name`. The field attributes are read at call time, through the closure.
    """

    # Kept as a named nested function rather than a lambda (the original implementation notes
    # "reference issues" with a lambda here).
    def _check_field_type(value: Any) -> None:
        type_validator(field.name, value, field.type)

    return _check_field_type
|
|
557
|
+
|
|
558
|
+
|
|
559
|
+
def _is_validator(validator: Any) -> bool:
    """Check if a function is a validator.

    A validator is a Callable that can be called with a single positional argument: its first parameter must accept
    a positional value, and every other parameter must either have a default value or be variadic (`*args` /
    `**kwargs`), since those never require a value.

    Basically, returns True if `validator(value)` is possible.
    """
    if not callable(validator):
        return False

    parameters = list(inspect.signature(validator).parameters.values())
    if not parameters:
        return False
    if parameters[0].kind not in (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.VAR_POSITIONAL,
    ):
        return False
    for parameter in parameters[1:]:
        # Variadic parameters report `empty` as default but are never required.
        if parameter.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
            continue
        if parameter.default is inspect.Parameter.empty:
            return False
    return True
|
|
584
|
+
|
|
585
|
+
|
|
586
|
+
def _is_required_or_notrequired(type_hint: Any) -> bool:
|
|
587
|
+
"""Helper to check if a type is Required/NotRequired."""
|
|
588
|
+
return type_hint in (Required, NotRequired) or (get_origin(type_hint) in (Required, NotRequired))
|
|
589
|
+
|
|
590
|
+
|
|
591
|
+
# Maps the origin of a parameterized annotation (as returned by `typing.get_origin`) to the function validating it.
_BASIC_TYPE_VALIDATORS = {
    Union: _validate_union,
    Literal: _validate_literal,
    list: _validate_list,
    dict: _validate_dict,
    tuple: _validate_tuple,
    set: _validate_set,
}

try:
    # PEP 604 unions (`int | str`) report `types.UnionType` as origin rather than `typing.Union` (Python 3.10+).
    from types import UnionType
except ImportError:  # Python < 3.10: the `X | Y` syntax is not available for types anyway
    pass
else:
    _BASIC_TYPE_VALIDATORS[UnionType] = _validate_union  # type: ignore [index]
|
|
599
|
+
|
|
600
|
+
|
|
601
|
+
# Public API of this module. `as_validated_field` is public too (used in the `strict` docstring example).
__all__ = [
    "as_validated_field",
    "strict",
    "validate_typed_dict",
    "validated_field",
    "Validator_T",
    "StrictDataclassClassValidationError",
    "StrictDataclassDefinitionError",
    "StrictDataclassFieldValidationError",
]
|