schemez 1.1.0__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- schemez/__init__.py +21 -0
- schemez/bind_kwargs.py +193 -0
- schemez/create_type.py +340 -0
- schemez/executable.py +211 -0
- schemez/functionschema.py +772 -0
- schemez/helpers.py +72 -51
- schemez/schema.py +2 -1
- schemez/schema_generators.py +215 -0
- schemez/typedefs.py +205 -0
- {schemez-1.1.0.dist-info → schemez-1.2.0.dist-info}/METADATA +2 -1
- schemez-1.2.0.dist-info/RECORD +19 -0
- schemez-1.1.0.dist-info/RECORD +0 -13
- {schemez-1.1.0.dist-info → schemez-1.2.0.dist-info}/WHEEL +0 -0
schemez/__init__.py
CHANGED
@@ -26,19 +26,40 @@ from schemez.schemadef.schemadef import (
|
|
26
26
|
)
|
27
27
|
from schemez.pydantic_types import ModelIdentifier, ModelTemperature, MimeType
|
28
28
|
|
29
|
+
from schemez.executable import create_executable, ExecutableFunction
|
30
|
+
from schemez.functionschema import FunctionType, create_schema
|
31
|
+
from schemez.schema_generators import (
|
32
|
+
create_schemas_from_callables,
|
33
|
+
create_schemas_from_module,
|
34
|
+
create_schemas_from_class,
|
35
|
+
create_constructor_schema,
|
36
|
+
)
|
37
|
+
from schemez.typedefs import OpenAIFunctionDefinition, OpenAIFunctionTool
|
38
|
+
|
29
39
|
__version__ = version("schemez")
|
30
40
|
|
31
41
|
__all__ = [
|
42
|
+
"ExecutableFunction",
|
43
|
+
"FunctionType",
|
32
44
|
"ImportedSchemaDef",
|
33
45
|
"InlineSchemaDef",
|
34
46
|
"JSONCode",
|
35
47
|
"MimeType",
|
36
48
|
"ModelIdentifier",
|
37
49
|
"ModelTemperature",
|
50
|
+
"OpenAIFunctionDefinition",
|
51
|
+
"OpenAIFunctionTool",
|
38
52
|
"PythonCode",
|
39
53
|
"Schema",
|
40
54
|
"SchemaDef",
|
41
55
|
"SchemaField",
|
42
56
|
"TOMLCode",
|
43
57
|
"YAMLCode",
|
58
|
+
"__version__",
|
59
|
+
"create_constructor_schema",
|
60
|
+
"create_executable",
|
61
|
+
"create_schema",
|
62
|
+
"create_schemas_from_callables",
|
63
|
+
"create_schemas_from_class",
|
64
|
+
"create_schemas_from_module",
|
44
65
|
]
|
schemez/bind_kwargs.py
ADDED
@@ -0,0 +1,193 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import inspect
|
4
|
+
from typing import TYPE_CHECKING, Any
|
5
|
+
|
6
|
+
|
7
|
+
if TYPE_CHECKING:
|
8
|
+
from collections.abc import Callable
|
9
|
+
|
10
|
+
|
11
|
+
class BoundFunction[T]:
|
12
|
+
"""A function with pre-bound parameters.
|
13
|
+
|
14
|
+
This class wraps a function and binds some parameters to fixed values,
|
15
|
+
while updating the signature and docstring to reflect only the remaining
|
16
|
+
parameters that can still be provided when calling.
|
17
|
+
"""
|
18
|
+
|
19
|
+
def __init__(self, func: Callable[..., T], **bound_kwargs: Any):
|
20
|
+
"""Initialize with a function and parameters to bind.
|
21
|
+
|
22
|
+
Args:
|
23
|
+
func: The function to wrap
|
24
|
+
**bound_kwargs: Parameters to bind to fixed values
|
25
|
+
|
26
|
+
Raises:
|
27
|
+
ValueError: If any parameter name is not in the function signature
|
28
|
+
"""
|
29
|
+
self.func = func
|
30
|
+
self.bound_kwargs = bound_kwargs
|
31
|
+
self.__name__ = func.__name__
|
32
|
+
self.__module__ = func.__module__
|
33
|
+
self.__qualname__ = func.__qualname__
|
34
|
+
self.__doc__ = self._update_docstring(func.__doc__)
|
35
|
+
self.__annotations__ = self._update_annotations(
|
36
|
+
getattr(func, "__annotations__", {})
|
37
|
+
)
|
38
|
+
self.__signature__ = self._update_signature()
|
39
|
+
|
40
|
+
# Verify all bound kwargs are valid parameters
|
41
|
+
sig = inspect.signature(func)
|
42
|
+
for param in bound_kwargs:
|
43
|
+
if param not in sig.parameters:
|
44
|
+
msg = f"Parameter {param!r} not found in signature of {func.__name__}"
|
45
|
+
raise ValueError(msg)
|
46
|
+
|
47
|
+
def __call__(self, *args: Any, **kwargs: Any) -> T:
|
48
|
+
"""Call the function with the bound parameters.
|
49
|
+
|
50
|
+
Args:
|
51
|
+
*args: Positional arguments for the function
|
52
|
+
**kwargs: Keyword arguments for the function
|
53
|
+
|
54
|
+
Returns:
|
55
|
+
The return value from the wrapped function
|
56
|
+
"""
|
57
|
+
# Combine bound parameters with provided parameters
|
58
|
+
all_kwargs = {**self.bound_kwargs, **kwargs}
|
59
|
+
return self.func(*args, **all_kwargs)
|
60
|
+
|
61
|
+
def _update_signature(self) -> inspect.Signature:
|
62
|
+
"""Create a new signature excluding bound parameters.
|
63
|
+
|
64
|
+
Returns:
|
65
|
+
Updated signature without bound parameters
|
66
|
+
"""
|
67
|
+
sig = inspect.signature(self.func)
|
68
|
+
parameters = [
|
69
|
+
param
|
70
|
+
for name, param in sig.parameters.items()
|
71
|
+
if name not in self.bound_kwargs
|
72
|
+
]
|
73
|
+
return sig.replace(parameters=parameters)
|
74
|
+
|
75
|
+
def _update_annotations(self, annotations: dict[str, Any]) -> dict[str, Any]:
|
76
|
+
"""Remove bound parameters from annotations.
|
77
|
+
|
78
|
+
Args:
|
79
|
+
annotations: Original function annotations
|
80
|
+
|
81
|
+
Returns:
|
82
|
+
Updated annotations dictionary
|
83
|
+
"""
|
84
|
+
return {
|
85
|
+
name: ann
|
86
|
+
for name, ann in annotations.items()
|
87
|
+
if name not in self.bound_kwargs and name != "return"
|
88
|
+
}
|
89
|
+
|
90
|
+
def _update_docstring(self, docstring: str | None) -> str | None:
|
91
|
+
"""Update docstring to remove bound parameters.
|
92
|
+
|
93
|
+
Args:
|
94
|
+
docstring: Original function docstring
|
95
|
+
|
96
|
+
Returns:
|
97
|
+
Updated docstring with bound parameters removed
|
98
|
+
"""
|
99
|
+
if not docstring:
|
100
|
+
return docstring
|
101
|
+
|
102
|
+
lines = docstring.splitlines()
|
103
|
+
new_lines = []
|
104
|
+
|
105
|
+
# Find the Args section and modify it
|
106
|
+
in_args_section = False
|
107
|
+
skip_lines = False
|
108
|
+
current_param = None
|
109
|
+
|
110
|
+
for line in lines:
|
111
|
+
# Check if entering Args section
|
112
|
+
if "Args:" in line:
|
113
|
+
in_args_section = True
|
114
|
+
new_lines.append(line)
|
115
|
+
continue
|
116
|
+
if in_args_section and line.strip() and not line.startswith(" "):
|
117
|
+
in_args_section = False
|
118
|
+
if in_args_section and ":" in line:
|
119
|
+
# Get parameter name from the line
|
120
|
+
param_name = line.strip().split(":", 1)[0].strip()
|
121
|
+
if param_name in self.bound_kwargs:
|
122
|
+
skip_lines = True
|
123
|
+
current_param = param_name
|
124
|
+
else:
|
125
|
+
skip_lines = False
|
126
|
+
current_param = None
|
127
|
+
if (
|
128
|
+
in_args_section
|
129
|
+
and current_param
|
130
|
+
and line.strip()
|
131
|
+
and ":" in line.lstrip()
|
132
|
+
):
|
133
|
+
new_param = line.strip().split(":", 1)[0].strip()
|
134
|
+
if new_param != current_param:
|
135
|
+
skip_lines = False
|
136
|
+
current_param = None
|
137
|
+
|
138
|
+
# Add the line if not skipping
|
139
|
+
if not skip_lines:
|
140
|
+
new_lines.append(line)
|
141
|
+
|
142
|
+
return "\n".join(new_lines)
|
143
|
+
|
144
|
+
|
145
|
+
if __name__ == "__main__":
|
146
|
+
import asyncio
|
147
|
+
|
148
|
+
async def search_db(
|
149
|
+
query: str,
|
150
|
+
k: int = 5,
|
151
|
+
filters: dict[str, list[str]] | None = None,
|
152
|
+
min_score: float = 0.7,
|
153
|
+
) -> list[dict]:
|
154
|
+
"""Search the database for relevant information.
|
155
|
+
|
156
|
+
Args:
|
157
|
+
query: Search query text
|
158
|
+
k: Number of results to return
|
159
|
+
filters: Filters to apply to search
|
160
|
+
2nd line
|
161
|
+
min_score: Minimum relevance score
|
162
|
+
|
163
|
+
Returns:
|
164
|
+
List of search results
|
165
|
+
|
166
|
+
Example:
|
167
|
+
>>> await search_db("quantum computing", k=3)
|
168
|
+
"""
|
169
|
+
print(f"query={query}, k={k}, filters={filters}, min_score={min_score}")
|
170
|
+
return [{"id": 1, "score": 0.9}, {"id": 2, "score": 0.8}]
|
171
|
+
|
172
|
+
# Create a bound version
|
173
|
+
simple_search = BoundFunction(search_db, k=3, min_score=0.8)
|
174
|
+
print("Original function:")
|
175
|
+
print(f"Signature: {inspect.signature(search_db)}")
|
176
|
+
print(f"Docstring:\n{search_db.__doc__}")
|
177
|
+
|
178
|
+
# Print bound function info
|
179
|
+
print("\nBound function:")
|
180
|
+
print(f"Signature: {inspect.signature(simple_search)}")
|
181
|
+
print(f"Docstring:\n{simple_search.__doc__}")
|
182
|
+
|
183
|
+
# Run both functions to compare
|
184
|
+
async def run_test():
|
185
|
+
print("\nCalling original function:")
|
186
|
+
result1 = await search_db("quantum computing")
|
187
|
+
print(f"Result: {result1}")
|
188
|
+
|
189
|
+
print("\nCalling bound function:")
|
190
|
+
result2 = await simple_search("quantum computing")
|
191
|
+
print(f"Result: {result2}")
|
192
|
+
|
193
|
+
asyncio.run(run_test())
|
schemez/create_type.py
ADDED
@@ -0,0 +1,340 @@
|
|
1
|
+
"""Convert JSON schema to appropriate Python type with validation.
|
2
|
+
|
3
|
+
Credits to Marvin / prefect for original code.
|
4
|
+
"""
|
5
|
+
|
6
|
+
from copy import deepcopy
|
7
|
+
from dataclasses import MISSING, field, make_dataclass
|
8
|
+
from datetime import datetime
|
9
|
+
from enum import Enum
|
10
|
+
import hashlib
|
11
|
+
import json
|
12
|
+
import re
|
13
|
+
from typing import Annotated, Any, ForwardRef, Literal, Optional, Union
|
14
|
+
|
15
|
+
from pydantic import AnyUrl, EmailStr, Field, Json, StringConstraints, model_validator
|
16
|
+
|
17
|
+
|
18
|
+
__all__ = ["jsonschema_to_type", "merge_defaults"]
|
19
|
+
|
20
|
+
|
21
|
+
FORMAT_TYPES = {"date-time": datetime, "email": EmailStr, "uri": AnyUrl, "json": Json}
|
22
|
+
|
23
|
+
_classes: dict[str | tuple[str, ...], Any] = {}
|
24
|
+
|
25
|
+
|
26
|
+
def jsonschema_to_type(
|
27
|
+
schema: dict[str, Any], name: str | None = None
|
28
|
+
) -> type | ForwardRef | Enum:
|
29
|
+
# Always use the top-level schema for references
|
30
|
+
if schema.get("type") == "object":
|
31
|
+
return _create_dataclass(schema, name, schemas=schema)
|
32
|
+
if name:
|
33
|
+
msg = f"Can not apply name to non-object schema: {name}"
|
34
|
+
raise ValueError(msg)
|
35
|
+
return schema_to_type(schema, schemas=schema)
|
36
|
+
|
37
|
+
|
38
|
+
def _hash_schema(schema: dict[str, Any]) -> str:
    """Return a SHA-256 hex digest of the schema's key-sorted JSON form."""
    canonical = json.dumps(schema, sort_keys=True)
    digest = hashlib.sha256(canonical.encode())
    return digest.hexdigest()
|
41
|
+
|
42
|
+
|
43
|
+
def resolve_ref(ref: str, schemas: dict[str, Any]) -> dict[str, Any]:
    """Resolve a local JSON Schema reference to its target schema.

    Args:
        ref: Reference such as "#/$defs/Foo".
        schemas: Root schema the reference is resolved against.

    Returns:
        The referenced schema, or an empty dict when the path cannot be
        followed (missing key or a non-mapping value along the way).
    """
    # Only strip the leading fragment marker, not "#/" occurring inside a key
    pointer = ref.removeprefix("#/")
    current: Any = schemas
    for raw in pointer.split("/"):
        if not isinstance(current, dict):
            # Path walks through a list/scalar: treat as unresolvable
            return {}
        if raw in current:
            current = current[raw]
            continue
        # JSON Pointer escapes: "~1" encodes "/", "~0" encodes "~"
        decoded = raw.replace("~1", "/").replace("~0", "~")
        current = current.get(decoded, {})
    return current
|
50
|
+
|
51
|
+
|
52
|
+
def create_string_type(schema: dict[str, Any]) -> Any:
    """Create string type with optional constraints.

    Resolution order: a "const" yields a ``Literal``; a "format" yields the
    mapped type from ``FORMAT_TYPES`` (plain ``str`` for "uri-reference" and
    unknown formats); otherwise length/pattern constraints are applied.

    The return annotation is ``Any`` because ``type | Annotated`` is
    evaluated eagerly in this module and bare ``Annotated`` is not a valid
    union member at runtime.
    """
    if "const" in schema:
        return Literal[schema["const"]]  # type: ignore

    if fmt := schema.get("format"):
        # "uri-reference" stays a plain str; "uri" is covered by FORMAT_TYPES
        if fmt == "uri-reference":
            return str
        return FORMAT_TYPES.get(fmt, str)

    keyword_map = {
        "min_length": "minLength",
        "max_length": "maxLength",
        "pattern": "pattern",
    }
    constraints = {
        target: schema[source]
        for target, source in keyword_map.items()
        if schema.get(source) is not None
    }

    return Annotated[str, StringConstraints(**constraints)] if constraints else str
|
75
|
+
|
76
|
+
|
77
|
+
def create_numeric_type(
|
78
|
+
base: type[int | float], schema: dict[str, Any]
|
79
|
+
) -> type | Annotated: # type: ignore
|
80
|
+
"""Create numeric type with optional constraints."""
|
81
|
+
if "const" in schema:
|
82
|
+
return Literal[schema["const"]] # type: ignore
|
83
|
+
|
84
|
+
constraints = {
|
85
|
+
k: v
|
86
|
+
for k, v in {
|
87
|
+
"gt": schema.get("exclusiveMinimum"),
|
88
|
+
"ge": schema.get("minimum"),
|
89
|
+
"lt": schema.get("exclusiveMaximum"),
|
90
|
+
"le": schema.get("maximum"),
|
91
|
+
"multiple_of": schema.get("multipleOf"),
|
92
|
+
}.items()
|
93
|
+
if v is not None
|
94
|
+
}
|
95
|
+
|
96
|
+
return Annotated[base, Field(**constraints)] if constraints else base
|
97
|
+
|
98
|
+
|
99
|
+
def create_enum(name: str, values: list[Any]) -> type | Enum:
|
100
|
+
"""Create enum type from list of values."""
|
101
|
+
if all(isinstance(v, str) for v in values):
|
102
|
+
return Enum(name, {v.upper(): v for v in values})
|
103
|
+
return Literal[tuple(values)] # type: ignore
|
104
|
+
|
105
|
+
|
106
|
+
def create_array_type(schema: dict[str, Any], schemas: dict[str, Any]) -> Any:
    """Create list/set type with optional constraints.

    A list-valued "items" (positional schemas) produces ``list`` of the union
    of the item types; a single "items" schema produces ``set`` when
    "uniqueItems" is truthy and ``list`` otherwise. "minItems"/"maxItems" are
    attached as ``Field`` length constraints.

    The return annotation is ``Any`` because ``type | Annotated`` is
    evaluated eagerly in this module and bare ``Annotated`` is not a valid
    union member at runtime.
    """
    items = schema.get("items", {})
    if isinstance(items, list):
        # Handle positional item schemas
        item_types = [schema_to_type(s, schemas) for s in items]
        # Union of no types raises TypeError, so an empty list means "anything"
        combined = Union[tuple(item_types)] if item_types else Any  # type: ignore # noqa: UP007
        base = list[combined]  # type: ignore
    else:
        # Handle single item schema
        item_type = schema_to_type(items, schemas)
        container = set if schema.get("uniqueItems") else list
        base = container[item_type]  # type: ignore

    keyword_map = {"min_length": "minItems", "max_length": "maxItems"}
    constraints = {
        target: schema[source]
        for target, source in keyword_map.items()
        if schema.get(source) is not None
    }

    return Annotated[base, Field(**constraints)] if constraints else base
|
132
|
+
|
133
|
+
|
134
|
+
def schema_to_type( # noqa: PLR0911
|
135
|
+
schema: dict[str, Any], schemas: dict[str, Any]
|
136
|
+
) -> type | ForwardRef | Enum:
|
137
|
+
"""Convert schema to appropriate Python type."""
|
138
|
+
if not schema:
|
139
|
+
return object
|
140
|
+
if "type" not in schema and "properties" in schema:
|
141
|
+
return _create_dataclass(schema, schema.get("title"), schemas)
|
142
|
+
|
143
|
+
# Handle references first
|
144
|
+
if "$ref" in schema:
|
145
|
+
ref = schema["$ref"]
|
146
|
+
# Handle self-reference
|
147
|
+
if ref == "#":
|
148
|
+
return ForwardRef(schema.get("title", "Root"))
|
149
|
+
return schema_to_type(resolve_ref(ref, schemas), schemas)
|
150
|
+
|
151
|
+
if "const" in schema:
|
152
|
+
return Literal[schema["const"]] # type: ignore
|
153
|
+
|
154
|
+
if "enum" in schema:
|
155
|
+
return create_enum(f"Enum_{len(_classes)}", schema["enum"])
|
156
|
+
|
157
|
+
schema_type = schema.get("type")
|
158
|
+
if not schema_type:
|
159
|
+
return Any # type: ignore
|
160
|
+
|
161
|
+
if isinstance(schema_type, list):
|
162
|
+
# Create a copy of the schema for each type, but keep all constraints
|
163
|
+
types = []
|
164
|
+
for t in schema_type:
|
165
|
+
type_schema = schema.copy()
|
166
|
+
type_schema["type"] = t
|
167
|
+
types.append(schema_to_type(type_schema, schemas))
|
168
|
+
has_null = type(None) in types
|
169
|
+
types = [t for t in types if t is not type(None)]
|
170
|
+
if has_null:
|
171
|
+
return Optional[tuple(types) if len(types) > 1 else types[0]] # type: ignore # noqa: UP045
|
172
|
+
return Union[tuple(types)] # type: ignore # noqa: UP007
|
173
|
+
|
174
|
+
type_handlers = {
|
175
|
+
"string": lambda s: create_string_type(s),
|
176
|
+
"integer": lambda s: create_numeric_type(int, s),
|
177
|
+
"number": lambda s: create_numeric_type(float, s),
|
178
|
+
"boolean": lambda _: bool,
|
179
|
+
"null": lambda _: type(None),
|
180
|
+
"array": lambda s: create_array_type(s, schemas),
|
181
|
+
"object": lambda s: _create_dataclass(s, s.get("title"), schemas),
|
182
|
+
}
|
183
|
+
|
184
|
+
return type_handlers.get(schema_type, lambda _: Any)(schema)
|
185
|
+
|
186
|
+
|
187
|
+
def sanitize_name(name: str) -> str:
    """Convert string to valid Python identifier.

    Non-alphanumeric characters become underscores, runs of underscores are
    collapsed, the result is lower-cased, and names not starting with an
    ASCII letter get a "field_" prefix. A result that is a Python keyword
    gets a trailing underscore, since dataclass fields cannot be keywords.
    """
    import keyword

    cleaned = re.sub(r"[^0-9a-zA-Z_]", "_", name)
    cleaned = re.sub(r"__+", "_", cleaned).lower()
    if not name or not re.match(r"[a-zA-Z]", name[0]):
        cleaned = f"field_{cleaned}"
    cleaned = re.sub(r"__+", "_", cleaned).strip("_")
    if keyword.iskeyword(cleaned):
        cleaned = f"{cleaned}_"
    return cleaned
|
195
|
+
|
196
|
+
|
197
|
+
def get_default_value(
|
198
|
+
schema: dict[str, Any], prop_name: str, parent_default: dict[str, Any] | None = None
|
199
|
+
) -> Any:
|
200
|
+
"""Get default value with proper priority ordering.
|
201
|
+
|
202
|
+
1. Value from parent's default if it exists
|
203
|
+
2. Property's own default if it exists
|
204
|
+
3. None
|
205
|
+
"""
|
206
|
+
if parent_default is not None and prop_name in parent_default:
|
207
|
+
return parent_default[prop_name]
|
208
|
+
return schema.get("default")
|
209
|
+
|
210
|
+
|
211
|
+
def create_field_with_default(
    field_type: type,
    default_value: Any,
    schema: dict[str, Any],
) -> Any:
    """Build a dataclass field from a default value.

    Dict, list, and ``None`` defaults all collapse to a ``None`` default;
    any other value is used as the field default directly. ``field_type``
    and ``schema`` are accepted but not consulted.
    """
    collapses_to_none = default_value is None or isinstance(default_value, (dict, list))
    chosen = None if collapses_to_none else default_value
    return field(default=chosen)
|
220
|
+
|
221
|
+
|
222
|
+
def _create_dataclass(
    schema: dict[str, Any], name: str | None = None, schemas: dict[str, Any] | None = None
) -> type | ForwardRef:
    """Create dataclass from object schema.

    Args:
        schema: Object schema describing the properties.
        name: Class name; falls back to the schema's "title", then "Root".
        schemas: Root schema used to resolve references inside properties;
            falls back to ``schema`` when not provided.

    Returns:
        The generated (and cached) dataclass, or a ``ForwardRef`` to it when
        the class is a self-reference or is still being built (recursion).
    """
    name = name or schema.get("title") or "Root"
    cache_key = (_hash_schema(schema), name)
    original_schema = schema.copy()  # Store copy for validator
    if cache_key in _classes:
        existing = _classes[cache_key]
        # None marks a class currently under construction
        return ForwardRef(name) if existing is None else existing

    _classes[cache_key] = None
    try:
        if "$ref" in schema:
            ref = schema["$ref"]
            if ref == "#":
                # Nothing is built here, so do not leave a placeholder behind
                _classes.pop(cache_key, None)
                return ForwardRef(name)
            schema = resolve_ref(ref, schemas or {})

        # Root used for resolving nested references
        ref_root = schemas or schema
        required = schema.get("required", [])
        fields = []
        for prop_name, prop_schema in schema.get("properties", {}).items():
            field_name = sanitize_name(prop_name)
            if prop_schema.get("$ref") == "#":
                field_type: type | ForwardRef | Enum = ForwardRef(name)
            else:
                field_type = schema_to_type(prop_schema, ref_root)
            default_val = prop_schema.get("default", MISSING)
            is_required = prop_name in required
            meta = {"alias": prop_name}
            if default_val is not MISSING:
                if isinstance(default_val, dict | list):
                    # Mutable defaults need a fresh copy per instance
                    field_def = field(
                        default_factory=lambda d=default_val: deepcopy(d), metadata=meta
                    )
                else:
                    field_def = field(default=default_val, metadata=meta)
            elif is_required:
                field_def = field(metadata=meta)
            else:
                field_def = field(default=None, metadata=meta)

            if is_required:
                fields.append((field_name, field_type, field_def))
            else:
                fields.append((field_name, Optional[field_type], field_def))  # type: ignore # noqa: UP045

        cls = make_dataclass(name, fields, kw_only=True)

        @model_validator(mode="before")
        @classmethod
        def _apply_defaults(cls, data):
            if isinstance(data, dict):
                return merge_defaults(data, original_schema)
            return data

        cls._apply_defaults = _apply_defaults  # type: ignore
    except Exception:
        # A stale placeholder would make later calls return a dangling ForwardRef
        _classes.pop(cache_key, None)
        raise

    _classes[cache_key] = cls
    return cls
|
284
|
+
|
285
|
+
|
286
|
+
def merge_defaults(
|
287
|
+
data: dict[str, Any],
|
288
|
+
schema: dict[str, Any],
|
289
|
+
parent_default: dict[str, Any] | None = None,
|
290
|
+
) -> dict[str, Any]:
|
291
|
+
"""Merge defaults with provided data at all levels."""
|
292
|
+
if not data:
|
293
|
+
if parent_default:
|
294
|
+
result = dict(parent_default)
|
295
|
+
elif "default" in schema:
|
296
|
+
result = dict(schema["default"])
|
297
|
+
else:
|
298
|
+
result = {}
|
299
|
+
elif parent_default:
|
300
|
+
result = dict(parent_default)
|
301
|
+
for key, value in data.items():
|
302
|
+
if (
|
303
|
+
isinstance(value, dict)
|
304
|
+
and key in result
|
305
|
+
and isinstance(result[key], dict)
|
306
|
+
):
|
307
|
+
# recursively merge nested dicts
|
308
|
+
result[key] = merge_defaults(value, {"properties": {}}, result[key])
|
309
|
+
else:
|
310
|
+
result[key] = value
|
311
|
+
else:
|
312
|
+
result = dict(data)
|
313
|
+
|
314
|
+
# For each property in the schema
|
315
|
+
for prop_name, prop_schema in schema.get("properties", {}).items():
|
316
|
+
# If property is missing, apply defaults in priority order
|
317
|
+
if prop_name not in result:
|
318
|
+
if parent_default and prop_name in parent_default:
|
319
|
+
result[prop_name] = parent_default[prop_name]
|
320
|
+
elif "default" in prop_schema:
|
321
|
+
result[prop_name] = prop_schema["default"]
|
322
|
+
|
323
|
+
# If property exists and is an object, recursively merge
|
324
|
+
if (
|
325
|
+
prop_name in result
|
326
|
+
and isinstance(result[prop_name], dict)
|
327
|
+
and prop_schema.get("type") == "object"
|
328
|
+
):
|
329
|
+
# Get the appropriate default for this nested object
|
330
|
+
nested_default = None
|
331
|
+
if parent_default and prop_name in parent_default:
|
332
|
+
nested_default = parent_default[prop_name]
|
333
|
+
elif "default" in prop_schema:
|
334
|
+
nested_default = prop_schema["default"]
|
335
|
+
|
336
|
+
result[prop_name] = merge_defaults(
|
337
|
+
result[prop_name], prop_schema, nested_default
|
338
|
+
)
|
339
|
+
|
340
|
+
return result
|