cadwyn 5.4.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cadwyn/__init__.py +44 -0
- cadwyn/__main__.py +78 -0
- cadwyn/_asts.py +155 -0
- cadwyn/_importer.py +31 -0
- cadwyn/_internal/__init__.py +0 -0
- cadwyn/_internal/context_vars.py +9 -0
- cadwyn/_render.py +155 -0
- cadwyn/_utils.py +79 -0
- cadwyn/applications.py +484 -0
- cadwyn/changelogs.py +503 -0
- cadwyn/dependencies.py +5 -0
- cadwyn/exceptions.py +78 -0
- cadwyn/middleware.py +131 -0
- cadwyn/py.typed +0 -0
- cadwyn/route_generation.py +536 -0
- cadwyn/routing.py +159 -0
- cadwyn/schema_generation.py +1162 -0
- cadwyn/static/__init__.py +0 -0
- cadwyn/static/docs.html +136 -0
- cadwyn/structure/__init__.py +31 -0
- cadwyn/structure/common.py +18 -0
- cadwyn/structure/data.py +249 -0
- cadwyn/structure/endpoints.py +170 -0
- cadwyn/structure/enums.py +42 -0
- cadwyn/structure/schemas.py +338 -0
- cadwyn/structure/versions.py +756 -0
- cadwyn-5.4.6.dist-info/METADATA +90 -0
- cadwyn-5.4.6.dist-info/RECORD +31 -0
- cadwyn-5.4.6.dist-info/WHEEL +4 -0
- cadwyn-5.4.6.dist-info/entry_points.txt +2 -0
- cadwyn-5.4.6.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,1162 @@
|
|
|
1
|
+
import ast
|
|
2
|
+
import copy
|
|
3
|
+
import dataclasses
|
|
4
|
+
import functools
|
|
5
|
+
import inspect
|
|
6
|
+
import sys
|
|
7
|
+
import textwrap
|
|
8
|
+
import types
|
|
9
|
+
import typing
|
|
10
|
+
from collections.abc import Callable, Sequence
|
|
11
|
+
from datetime import date
|
|
12
|
+
from enum import Enum
|
|
13
|
+
from functools import cache
|
|
14
|
+
from typing import (
|
|
15
|
+
TYPE_CHECKING,
|
|
16
|
+
Annotated,
|
|
17
|
+
ClassVar,
|
|
18
|
+
Generic,
|
|
19
|
+
Union,
|
|
20
|
+
_BaseGenericAlias, # pyright: ignore[reportAttributeAccessIssue]
|
|
21
|
+
cast,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
import fastapi.params
|
|
25
|
+
import fastapi.security.base
|
|
26
|
+
import fastapi.utils
|
|
27
|
+
import pydantic
|
|
28
|
+
import pydantic._internal._decorators
|
|
29
|
+
from fastapi import Response
|
|
30
|
+
from fastapi.dependencies.utils import is_async_gen_callable, is_coroutine_callable, is_gen_callable
|
|
31
|
+
from fastapi.routing import APIRoute
|
|
32
|
+
from pydantic import BaseModel, Field, RootModel
|
|
33
|
+
from pydantic._internal import _decorators
|
|
34
|
+
from pydantic._internal._decorators import (
|
|
35
|
+
FieldSerializerDecoratorInfo,
|
|
36
|
+
FieldValidatorDecoratorInfo,
|
|
37
|
+
ModelSerializerDecoratorInfo,
|
|
38
|
+
ModelValidatorDecoratorInfo,
|
|
39
|
+
RootValidatorDecoratorInfo,
|
|
40
|
+
ValidatorDecoratorInfo,
|
|
41
|
+
)
|
|
42
|
+
from pydantic._internal._known_annotated_metadata import collect_known_metadata
|
|
43
|
+
from pydantic._internal._typing_extra import try_eval_type as pydantic_try_eval_type
|
|
44
|
+
from pydantic.fields import ComputedFieldInfo, FieldInfo
|
|
45
|
+
from pydantic_core import PydanticUndefined
|
|
46
|
+
from typing_extensions import (
|
|
47
|
+
Any,
|
|
48
|
+
Doc,
|
|
49
|
+
Self,
|
|
50
|
+
TypeAlias,
|
|
51
|
+
TypeAliasType,
|
|
52
|
+
TypeVar,
|
|
53
|
+
_AnnotatedAlias,
|
|
54
|
+
assert_never,
|
|
55
|
+
final,
|
|
56
|
+
get_args,
|
|
57
|
+
get_origin,
|
|
58
|
+
overload,
|
|
59
|
+
)
|
|
60
|
+
|
|
61
|
+
from cadwyn._utils import (
|
|
62
|
+
DATACLASS_KW_ONLY,
|
|
63
|
+
DATACLASS_SLOTS,
|
|
64
|
+
Sentinel,
|
|
65
|
+
UnionType,
|
|
66
|
+
fully_unwrap_decorator,
|
|
67
|
+
get_name_of_function_wrapped_in_pydantic_validator,
|
|
68
|
+
lenient_issubclass,
|
|
69
|
+
)
|
|
70
|
+
from cadwyn.exceptions import CadwynError, InvalidGenerationInstructionError
|
|
71
|
+
from cadwyn.structure.common import VersionType
|
|
72
|
+
from cadwyn.structure.data import ResponseInfo
|
|
73
|
+
from cadwyn.structure.enums import AlterEnumSubInstruction, EnumDidntHaveMembersInstruction, EnumHadMembersInstruction
|
|
74
|
+
from cadwyn.structure.schemas import (
|
|
75
|
+
AlterSchemaSubInstruction,
|
|
76
|
+
FieldDidntExistInstruction,
|
|
77
|
+
FieldDidntHaveInstruction,
|
|
78
|
+
FieldExistedAsInstruction,
|
|
79
|
+
FieldHadInstruction,
|
|
80
|
+
SchemaHadInstruction,
|
|
81
|
+
ValidatorDidntExistInstruction,
|
|
82
|
+
ValidatorExistedInstruction,
|
|
83
|
+
_get_model_decorators,
|
|
84
|
+
)
|
|
85
|
+
from cadwyn.structure.versions import _CADWYN_REQUEST_PARAM_NAME, _CADWYN_RESPONSE_PARAM_NAME, VersionBundle
|
|
86
|
+
|
|
87
|
+
if TYPE_CHECKING:
|
|
88
|
+
from cadwyn.structure.versions import HeadVersion, Version, VersionBundle
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
# Type variable for any callable whose annotations/signature get rewritten per version.
_Call = TypeVar("_Call", bound=Callable[..., Any])

# Readability alias: a string that names a model field.
_FieldName: TypeAlias = str

# Type variables for the kinds of classes that receive versioned copies.
_T_PYDANTIC_MODEL = TypeVar("_T_PYDANTIC_MODEL", bound=BaseModel)
_T_ENUM = TypeVar("_T_ENUM", bound=Enum)
_T_ANY_MODEL = TypeVar("_T_ANY_MODEL", bound=Union[BaseModel, Enum])
|
|
98
|
+
# For every kind of decorator metadata pydantic records on a model, the public decorator that
# re-creates it when a versioned copy of the model is generated.
PYDANTIC_DECORATOR_TYPE_TO_DECORATOR_MAP = {
    ValidatorDecoratorInfo: pydantic.validator,  # pyright: ignore[reportDeprecated]
    FieldValidatorDecoratorInfo: pydantic.field_validator,
    FieldSerializerDecoratorInfo: pydantic.field_serializer,
    RootValidatorDecoratorInfo: pydantic.root_validator,  # pyright: ignore[reportDeprecated]
    ModelValidatorDecoratorInfo: pydantic.model_validator,
    ModelSerializerDecoratorInfo: pydantic.model_serializer,
    ComputedFieldInfo: pydantic.computed_field,
}
# Names exported by pydantic; type aliases with these names are never re-versioned.
_PYDANTIC_ALL_EXPORTED_NAMES = {*pydantic.__all__}
# Base classes that belong to pydantic itself (extended below when pydantic-settings is installed).
_DEFAULT_PYDANTIC_CLASSES = (BaseModel, RootModel)

try:
    from pydantic_settings import BaseSettings
except ImportError:  # pragma: no cover
    pass
else:
    _DEFAULT_PYDANTIC_CLASSES = (*_DEFAULT_PYDANTIC_CLASSES, BaseSettings)
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
VALIDATOR_CONFIG_KEY = "__validators__"
# Every argument name that ``pydantic.Field`` accepts by keyword, sorted alphabetically.
_all_field_arg_names = sorted(
    arg_name
    for arg_name, parameter in inspect.signature(Field).parameters.items()
    if parameter.kind in {inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD}
)
EXTRA_FIELD_NAME = "json_schema_extra"


# Snapshot of every ``FieldInfo`` slot on a ``Field()`` created with no arguments.
_empty_field_info = Field()
dict_of_empty_field_info = {slot: getattr(_empty_field_info, slot) for slot in FieldInfo.__slots__}
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
@dataclasses.dataclass(**DATACLASS_SLOTS)
class PydanticFieldWrapper:
    """Editable description of one pydantic field, used to re-create it in a versioned model.

    Only the ``Field(...)`` keyword arguments extracted by ``_extract_passed_field_attributes``
    are tracked. We DO NOT maintain field.metadata at all.
    """

    # Only consumed by __post_init__; it is not stored on the instance.
    init_model_field: dataclasses.InitVar[FieldInfo]

    annotation: Any
    name_from_newer_version: str

    # Keyword arguments handed to ``pydantic.Field`` when the field is regenerated.
    passed_field_attributes: dict[str, Any] = dataclasses.field(init=False)

    def __post_init__(self, init_model_field: FieldInfo):
        extracted = _extract_passed_field_attributes(init_model_field)
        self.passed_field_attributes = extracted

    def update_attribute(self, *, name: str, value: Any):
        """Set or overwrite the ``Field`` argument called *name*."""
        self.passed_field_attributes.update({name: value})

    def delete_attribute(self, *, name: str) -> None:
        """Drop the ``Field`` argument called *name*; raises ``KeyError`` when it is absent."""
        del self.passed_field_attributes[name]

    def generate_field_copy(self, generator: "SchemaGenerator") -> pydantic.fields.FieldInfo:
        """Build a new ``FieldInfo`` whose arguments point at *generator*'s versions of models/enums."""
        versioned_kwargs = generator.annotation_transformer.change_version_of_annotation(
            self.passed_field_attributes
        )
        return pydantic.Field(**versioned_kwargs)
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def _extract_passed_field_attributes(field_info: FieldInfo) -> dict[str, object]:
    """Collect the ``Field(...)`` keyword arguments that describe *field_info*.

    Merges the attributes pydantic recorded in ``field_info._attributes_set`` with the constraints
    it recognizes in ``field_info.metadata``; keeps only names that ``pydantic.Field`` accepts and
    skips ``frozen=None`` (presumably because ``None`` there just means "not set").
    """
    known_metadata = collect_known_metadata(field_info.metadata)[0]
    candidates = field_info._attributes_set | known_metadata
    passed: dict[str, object] = {}
    for key, value in candidates.items():
        if key not in _all_field_arg_names:
            continue
        if key == "frozen" and value is None:
            continue
        passed[key] = value
    return passed
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
@dataclasses.dataclass(**DATACLASS_SLOTS)
class _ModelBundle:
    """Registry of editable wrappers for enums and pydantic schemas, keyed by class."""

    # enum class -> its editable wrapper
    enums: dict[type[Enum], "_EnumWrapper"]
    # pydantic model class -> its editable wrapper
    schemas: dict[type[BaseModel], "_PydanticModelWrapper"]
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
@dataclasses.dataclass(**DATACLASS_SLOTS, **DATACLASS_KW_ONLY)
class _RuntimeSchemaGenContext:
    """State carried around while generating schemas for ``current_version``."""

    version_bundle: "VersionBundle"
    current_version: "Union[Version, HeadVersion]"
    models: _ModelBundle
    # Computed in __post_init__: the bundle's version with the greatest ``value``.
    latest_version: "Version" = dataclasses.field(init=False)

    def __post_init__(self):
        def _value_of(version):
            return version.value

        self.latest_version = max(self.version_bundle.versions, key=_value_of)
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
def migrate_response_body(
    versions: "VersionBundle",
    latest_response_model: type[pydantic.BaseModel],
    *,
    latest_body: Any,
    version: Union[VersionType, date],
) -> Any:
    """Migrate *latest_body* from the head schema down to *version* and validate the result.

    The version changes between the latest version and *version* are applied in reverse order,
    and the migrated body is validated with *version*'s copy of ``latest_response_model``.
    A ``date`` is converted to its ISO string first; the closest version not newer than it is used.

    Raises:
        CadwynError: if the resolved version is not part of *versions*.
    """
    version = versions._get_closest_lesser_version(version.isoformat() if isinstance(version, date) else version)
    if version not in versions._version_values_set:
        raise CadwynError(f"Version {version} not found in version bundle")

    migrated_response = versions._migrate_response(
        ResponseInfo(Response(status_code=200), body=latest_body),
        current_version=version,
        head_response_model=latest_response_model,
        head_route=None,
    )
    models_of_version = generate_versioned_models(versions)[str(version)]
    versioned_response_model: type[pydantic.BaseModel] = models_of_version[latest_response_model]
    return versioned_response_model.model_validate(migrated_response.body)
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
def _unwrap_model(model: type[_T_ANY_MODEL]) -> type[_T_ANY_MODEL]:
    """Follow ``__cadwyn_original_model__`` links until reaching a class that has none."""
    unwrapped = model
    while hasattr(unwrapped, "__cadwyn_original_model__"):
        unwrapped = unwrapped.__cadwyn_original_model__  # pyright: ignore[reportAttributeAccessIssue]
    return unwrapped
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
@dataclasses.dataclass(**DATACLASS_SLOTS, **DATACLASS_KW_ONLY)
class _ValidatorWrapper:
    """A captured pydantic decorator that is not bound to specific fields.

    ``_PydanticModelWrapper.generate_model_copy`` re-applies it as
    ``decorator(**kwargs)(func)`` unless ``is_deleted`` is set.
    """

    # Keyword arguments passed to ``decorator`` when it is re-applied.
    kwargs: dict[str, Any]
    # The undecorated function (see ``_wrap_validator``).
    func: Callable
    # The pydantic decorator factory (a value of PYDANTIC_DECORATOR_TYPE_TO_DECORATOR_MAP).
    decorator: Callable
    # Deleted validators are skipped when a model copy is generated.
    is_deleted: bool = False
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
@dataclasses.dataclass(**DATACLASS_SLOTS, **DATACLASS_KW_ONLY)
class _PerFieldValidatorWrapper(_ValidatorWrapper):
    """A captured pydantic decorator that targets named fields (re-applied as ``decorator(*fields, **kwargs)``)."""

    # Field names passed positionally to the decorator.
    fields: list[str] = dataclasses.field(default_factory=list)
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
def _wrap_validator(func: Callable, is_pydantic_v1_style_validator: Any, decorator_info: _decorators.DecoratorInfo):
    """Capture a pydantic decorator so it can be re-applied to a generated model.

    Args:
        func: the decorated function as pydantic stores it.
        is_pydantic_v1_style_validator: pydantic's shim for v1-style validators; falsy otherwise.
        decorator_info: pydantic's metadata for the decorator; its dataclass fields become kwargs.

    Returns:
        ``_PerFieldValidatorWrapper`` when the metadata lists ``fields``, else ``_ValidatorWrapper``.
    """
    # This is only for pydantic v1 style validators
    raw_func = fully_unwrap_decorator(func, is_pydantic_v1_style_validator)
    if inspect.ismethod(raw_func):
        raw_func = raw_func.__func__

    decorator_kwargs = dataclasses.asdict(decorator_info)
    target_fields = decorator_kwargs.pop("fields", None)
    if isinstance(decorator_info, ComputedFieldInfo):
        # computed_field() does not accept wrapped_property
        decorator_kwargs.pop("wrapped_property", None)

    decorator = PYDANTIC_DECORATOR_TYPE_TO_DECORATOR_MAP[type(decorator_info)]

    if is_pydantic_v1_style_validator:
        # v1-style decorators take ``pre=`` where the recorded metadata has ``mode=``
        decorator_kwargs["pre"] = decorator_kwargs.pop("mode", "after") != "after"
        # v1-style root_validator(pre=False) must be given skip_on_failure=True
        if isinstance(decorator_info, RootValidatorDecoratorInfo) and decorator_info.mode == "after":  # pragma: no cover # TODO
            decorator_kwargs["skip_on_failure"] = True

    if target_fields is None:
        return _ValidatorWrapper(func=raw_func, decorator=decorator, kwargs=decorator_kwargs)
    return _PerFieldValidatorWrapper(
        func=raw_func, fields=list(target_fields), decorator=decorator, kwargs=decorator_kwargs
    )
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
def _is_dunder(attr_name: str):
    """Return True when *attr_name* both starts and ends with a double underscore."""
    return attr_name[:2] == "__" == attr_name[-2:]
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
def _wrap_pydantic_model(model: type[_T_PYDANTIC_MODEL]) -> "_PydanticModelWrapper[_T_PYDANTIC_MODEL]":
    """Snapshot *model* into an editable ``_PydanticModelWrapper``.

    Captures the fields and decorators that *model* itself declares (not inherited ones), its
    annotations, and the remaining entries of its class ``__dict__`` so that a modified copy of
    the model can be generated later.
    """
    # In case we have a forwardref within one of the fields
    # For example, when "from __future__ import annotations" is used in the file with the schema
    if model is not BaseModel:
        model.model_rebuild(raise_errors=False)
        model = cast("type[_T_PYDANTIC_MODEL]", model)

    decorators = _get_model_decorators(model)
    validators = {}
    for decorator_wrapper in decorators:
        # Only decorators defined on this very class; inherited ones are captured by the parents' wrappers.
        if decorator_wrapper.cls_var_name not in model.__dict__:
            continue

        wrapped_validator = _wrap_validator(decorator_wrapper.func, decorator_wrapper.shim, decorator_wrapper.info)
        validators[decorator_wrapper.cls_var_name] = wrapped_validator

    def _rebuild_annotated(name: str):
        # Used for string annotations: rebuild a real annotation from pydantic's resolved annotation,
        # re-attaching the field's metadata via Annotated[...] when there is any.
        if field_info := model.model_fields.get(name):
            if not field_info.metadata:
                return field_info.annotation

            if sys.version_info >= (3, 13):
                return Annotated.__getitem__((field_info.annotation, *field_info.metadata))  # pyright: ignore[reportAttributeAccessIssue]
            else:
                return Annotated.__class_getitem__((field_info.annotation, *field_info.metadata))  # pyright: ignore[reportAttributeAccessIssue]
        return model.__annotations__[name]  # pragma: no cover

    annotations = {
        name: value if not isinstance(value, str) else _rebuild_annotated(name)
        for name, value in model.__annotations__.items()
    }

    if sys.version_info >= (3, 10):
        defined_fields = model.__annotations__
    else:
        # model_fields always contains inherited fields too, and before Python 3.10 a class that
        # declares no annotations of its own exposes its parent's ``__annotations__`` through normal
        # attribute lookup, so ``model.__annotations__`` may also list inherited fields.
        # So we need to get the list of fields the class itself defines from the AST.
        try:
            defined_fields, _ = _get_field_and_validator_names_from_model(model)
        except OSError:  # pragma: no cover
            # Source code unavailable (inspect.getsource failed): fall back to all model fields.
            defined_fields = model.model_fields
        annotations = {
            name: value
            for name, value in annotations.items()
            # We need to filter out fields that were inherited
            if name not in model.model_fields or name in defined_fields
        }
    fields = {
        field_name: PydanticFieldWrapper(
            model.model_fields[field_name],
            annotations[field_name],
            field_name,
        )
        for field_name in model.__annotations__
        if field_name in defined_fields and field_name in model.model_fields
    }

    # Everything else in the class namespace (methods, class attributes, ...) is carried over as-is,
    # except dunders and pydantic-managed attributes.
    main_attributes = fields | validators
    other_attributes = {
        attr_name: attr_val
        for attr_name, attr_val in model.__dict__.items()
        if attr_name not in main_attributes
        and not (_is_dunder(attr_name) or attr_name in {"_abc_impl", "model_fields", "model_computed_fields"})
    }

    other_attributes |= {
        "model_config": model.model_config,
        "__module__": model.__module__,
        "__qualname__": model.__qualname__,
    }
    return _PydanticModelWrapper(
        model,
        name=model.__name__,
        doc=model.__doc__,
        fields=fields,
        other_attributes=other_attributes,
        validators=validators,
        annotations=annotations,
    )
|
|
346
|
+
|
|
347
|
+
|
|
348
|
+
@cache
def _get_field_and_validator_names_from_model(cls: type) -> tuple[set[_FieldName], set[str]]:
    """Parse the source of *cls* to find the fields and validators its own class body declares.

    Returns ``(field_names, validator_names)``:
    * field names are annotated assignments whose target is one of ``cls.model_fields``;
    * validator names are decorated methods that ``_get_validator_info_or_none`` recognizes.

    Results are cached per class. ``inspect.getsource`` raises ``OSError`` when the source is unavailable.
    """
    model_fields = cls.model_fields
    class_source = textwrap.dedent(inspect.getsource(cls))
    class_node = cast("ast.ClassDef", ast.parse(class_source).body[0])

    field_names: set[_FieldName] = set()
    validator_names: set[str] = set()
    for node in class_node.body:
        if isinstance(node, ast.AnnAssign):
            if isinstance(node.target, ast.Name) and node.target.id in model_fields:
                field_names.add(node.target.id)
        elif isinstance(node, ast.FunctionDef) and node.decorator_list:
            validator_name = _get_validator_info_or_none(node)
            if validator_name is not None:
                validator_names.add(validator_name)
    return field_names, validator_names
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
def _get_validator_info_or_none(method: ast.FunctionDef) -> Union[str, None]:
    """Return ``method.name`` if one of its decorators looks like a pydantic validator, else ``None``.

    A decorator looks like a validator when its source text (the called expression, for calls)
    ends with ``validator``. Handled shapes:

    * ``Name(id="root_validator")``
    * ``Call(func=Name(id="validator"), args=[Constant(value="foo")])``
    * ``Attribute(value=Name(id="pydantic"), attr="root_validator")``
    * ``Call(func=Attribute(value=Name(id="pydantic"), attr="root_validator"), args=[])``
    """
    for decorator in method.decorator_list:
        if isinstance(decorator, ast.Call):
            decorator_target = decorator.func
        elif isinstance(decorator, (ast.Name, ast.Attribute)):
            decorator_target = decorator
        else:
            continue
        if ast.unparse(decorator_target).endswith("validator"):
            return method.name
    return None
|
|
383
|
+
|
|
384
|
+
|
|
385
|
+
@final
@dataclasses.dataclass(**DATACLASS_SLOTS)
class _PydanticModelWrapper(Generic[_T_PYDANTIC_MODEL]):
    """Editable snapshot of a pydantic model from which a new model class can be generated.

    Version changes edit the wrapper's fields, validators, annotations and other class
    attributes; ``generate_model_copy`` then builds a fresh class from that state.
    """

    # The user-defined model class (__post_init__ unwraps Cadwyn-generated copies).
    cls: type[_T_PYDANTIC_MODEL] = dataclasses.field(repr=False)
    # Class name used for the generated copy.
    name: str
    doc: Union[str, None] = dataclasses.field(repr=False)
    fields: Annotated[
        dict["_FieldName", PydanticFieldWrapper],
        Doc(
            "Fields that belong to this model, not to its parents. "
            "I.e. The ones that were either defined or overridden "
        ),
    ] = dataclasses.field(repr=False)
    # Decorators declared on this class (validators, serializers, computed fields), keyed by attribute name.
    validators: dict[str, Union[_PerFieldValidatorWrapper, _ValidatorWrapper]] = dataclasses.field(repr=False)
    # Remaining class-namespace entries, copied into the generated class's namespace.
    other_attributes: dict[str, Any] = dataclasses.field(repr=False)
    annotations: dict[str, Any] = dataclasses.field(repr=False)
    # Lazily filled by _get_parents.
    _parents: Union[list[Self], None] = dataclasses.field(init=False, default=None, repr=False)

    def __post_init__(self):
        # This isn't actually supposed to run, it's just a precaution
        while hasattr(self.cls, "__cadwyn_original_model__"):  # pragma: no cover
            self.cls = self.cls.__cadwyn_original_model__  # pyright: ignore[reportAttributeAccessIssue]

        for k, annotation in self.annotations.items():
            if get_origin(annotation) == Annotated:
                sub_annotations = get_args(annotation)
                # Annotated cannot be copied and is cached based on "==" and "hash", while annotated_types.Interval are
                # frozen and so are consistently hashed
                # Rebuilding it from deep copies presumably keeps later edits of this wrapper from
                # leaking into the original model's annotation objects.
                self.annotations[k] = _AnnotatedAlias(
                    copy.deepcopy(sub_annotations[0]), tuple(copy.deepcopy(sub_ann) for sub_ann in sub_annotations[1:])
                )

    def __deepcopy__(self, memo: dict[int, Any]):
        # _parents is intentionally not copied: the copy recomputes it lazily. Constructing the copy
        # also runs __post_init__ again on the copied annotations.
        result = _PydanticModelWrapper(
            self.cls,
            name=self.name,
            doc=self.doc,
            fields=copy.deepcopy(self.fields),
            validators=copy.deepcopy(self.validators),
            other_attributes=copy.deepcopy(self.other_attributes),
            annotations=copy.deepcopy(self.annotations),
        )
        memo[id(self)] = result
        return result

    def __hash__(self) -> int:
        # Identity-based hash. NOTE(review): the dataclass-generated __eq__ compares field values, so two
        # equal wrappers can hash differently — confirm wrappers are only used as identity keys.
        return hash(id(self))

    def _get_parents(self, schemas: "dict[type, Self]"):
        """Return wrappers for the pydantic bases in ``cls``'s MRO (nearest first), caching the result.

        Bases found in *schemas* reuse that wrapper; other ``BaseModel`` subclasses are wrapped on the fly.
        """
        if self._parents is not None:
            return self._parents
        parents = []
        for base in self.cls.mro()[1:]:
            if base in schemas:
                parents.append(schemas[base])
            elif lenient_issubclass(base, BaseModel):
                parents.append(_wrap_pydantic_model(base))
        self._parents = parents
        return parents

    def _get_defined_fields_through_mro(self, schemas: "dict[type, Self]") -> dict[str, PydanticFieldWrapper]:
        """Merge fields from the most distant ancestor down to this class; nearer definitions win."""
        fields = {}

        for parent in reversed(self._get_parents(schemas)):
            fields |= parent.fields

        return fields | self.fields

    def _get_defined_annotations_through_mro(self, schemas: "dict[type, Self]") -> dict[str, Any]:
        """Merge annotations from the most distant ancestor down to this class; nearer definitions win."""
        annotations = {}

        for parent in reversed(self._get_parents(schemas)):
            annotations |= parent.annotations

        return annotations | self.annotations

    def generate_model_copy(self, generator: "SchemaGenerator") -> type[_T_PYDANTIC_MODEL]:
        """Create a new model class from the wrapper's current state.

        Bases, field arguments and annotations are mapped to *generator*'s version; deleted
        validators are left out. The new class records the original in ``__cadwyn_original_model__``.
        """
        # Exact type checks: _PerFieldValidatorWrapper subclasses _ValidatorWrapper, so isinstance would match both.
        per_field_validators = {
            name: validator.decorator(*validator.fields, **validator.kwargs)(validator.func)
            for name, validator in self.validators.items()
            if not validator.is_deleted and type(validator) == _PerFieldValidatorWrapper  # noqa: E721
        }
        # Decorators without a field list (model validators/serializers, computed fields, ...).
        root_validators = {
            name: validator.decorator(**validator.kwargs)(validator.func)
            for name, validator in self.validators.items()
            if not validator.is_deleted and type(validator) == _ValidatorWrapper  # noqa: E721
        }
        fields = {name: field.generate_field_copy(generator) for name, field in self.fields.items()}
        # Call the model's metaclass directly to build the class with the assembled namespace.
        model_copy = type(self.cls)(
            self.name,
            tuple(generator[cast("type[BaseModel]", base)] for base in self.cls.__bases__),
            self.other_attributes
            | per_field_validators
            | root_validators
            | fields
            | {
                "__annotations__": generator.annotation_transformer.change_version_of_annotation(self.annotations),
                "__doc__": self.doc,
                # Keep the enclosing-scope prefix of the original qualname, but use the (possibly renamed) name.
                "__qualname__": self.cls.__qualname__.removesuffix(self.cls.__name__) + self.name,
            },
        )
        model_copy.__cadwyn_original_model__ = self.cls
        return model_copy
|
|
488
|
+
|
|
489
|
+
|
|
490
|
+
def is_regular_function(call: Callable):
    """Return True if *call* is a plain Python function or a bound method."""
    return isinstance(call, (types.FunctionType, types.MethodType))


class _CallableWrapper:
    """Transparent proxy around an endpoint/dependency callable.

    Metadata (``__name__``, ``__doc__``, ``__wrapped__``, ...) is copied from the original so that
    the wrapper can be given its own annotations and signature without touching the original.
    """

    # __eq__ and __hash__ are needed to make sure that dependency overrides work correctly.
    # They are based on putting dependencies (functions) as keys for the dictionary so if we want to be able to
    # override the wrapper, we need to make sure that it is equivalent to the original in __hash__ and __eq__

    def __init__(self, original_callable: Callable) -> None:
        super().__init__()
        self._original_callable = original_callable
        # Callable objects may lack function metadata of their own, so copy it from their __call__.
        if not is_regular_function(original_callable):
            original_callable = original_callable.__call__

        functools.update_wrapper(self, original_callable)

    @property
    def __globals__(self):
        # FastAPI uses __globals__ to resolve forward references in type hints
        # It's supposed to be an attribute on the function but we use it as property to prevent python
        # from trying to pickle globals when we deepcopy this wrapper
        return self._original_callable.__globals__

    def __call__(self, *args: Any, **kwargs: Any):
        return self._original_callable(*args, **kwargs)

    def __hash__(self):
        return hash(self._original_callable)

    def __eq__(self, value: object) -> bool:
        return self._original_callable == value


class _AsyncCallableWrapper(_CallableWrapper):
    """Wrapper for coroutine functions: awaits the original and returns its result."""

    async def __call__(self, *args: Any, **kwargs: Any):
        return await self._original_callable(*args, **kwargs)


class _GeneratorCallableWrapper(_CallableWrapper):
    """Wrapper for generator functions (``yield`` dependencies).

    ``yield from`` forwards ``send()``/``throw()``/``close()`` to the original generator, so code
    around its ``yield`` sees exceptions thrown in by the caller.
    """

    def __call__(self, *args: Any, **kwargs: Any):
        yield from self._original_callable(*args, **kwargs)


class _AsyncGeneratorCallableWrapper(_CallableWrapper):
    """Wrapper for async generator functions (async ``yield`` dependencies).

    Async generators have no ``yield from``, so ``asend()``/``athrow()``/``aclose()`` are forwarded to
    the original generator by hand. A plain ``async for ...: yield`` loop would not forward them:
    an exception thrown into the wrapper (as ``contextlib.asynccontextmanager`` does when the body
    of the ``async with`` fails) would never reach the original generator, so its ``except`` blocks
    would not run and its ``finally`` would be postponed until it is garbage-collected.
    """

    async def __call__(self, *args: Any, **kwargs: Any):
        inner = self._original_callable(*args, **kwargs)
        try:
            try:
                value = await inner.__anext__()
            except StopAsyncIteration:
                return
            while True:
                try:
                    sent = yield value
                except GeneratorExit:
                    # The wrapper is being closed; the ``finally`` below closes the original generator.
                    raise
                except BaseException as exc:  # forward everything else, like ``yield from`` does
                    try:
                        value = await inner.athrow(exc)
                    except StopAsyncIteration:
                        return
                else:
                    try:
                        value = await inner.asend(sent)
                    except StopAsyncIteration:
                        return
        finally:
            # No-op when the original generator already finished.
            await inner.aclose()
|
|
538
|
+
|
|
539
|
+
|
|
540
|
+
@final
|
|
541
|
+
class _AnnotationTransformer:
|
|
542
|
+
def __init__(self, generator: "SchemaGenerator") -> None:
    """Bind the transformer to *generator*, which supplies the versioned models and enums."""
    self.generator = generator
|
|
544
|
+
|
|
545
|
+
def change_version_of_annotation(self, annotation: Any) -> Any:
    """Recursively replace models/enums inside *annotation* with the generator's versioned copies.

    For example, a "UserResponse" annotation from the "head" version becomes the "UserResponse"
    class generated for the generator's version (e.g. "2022-11-16").

    Dicts (keys and values), lists and tuples are rebuilt with transformed items; every other value
    is delegated to ``_change_version_of_a_non_container_annotation``.
    """
    if isinstance(annotation, dict):
        transformed = {}
        for key, value in annotation.items():
            new_key = self.change_version_of_annotation(key)
            transformed[new_key] = self.change_version_of_annotation(value)
        return transformed

    if isinstance(annotation, (list, tuple)):
        container_type = type(annotation)
        return container_type(map(self.change_version_of_annotation, annotation))

    return self._change_version_of_a_non_container_annotation(annotation)
|
|
562
|
+
|
|
563
|
+
def migrate_router_to_version(self, router: fastapi.routing.APIRouter):
    """Migrate every ``APIRoute`` of *router* in place; other route types are left untouched."""
    api_routes = (candidate for candidate in router.routes if isinstance(candidate, fastapi.routing.APIRoute))
    for api_route in api_routes:
        self.migrate_route_to_version(api_route)
|
|
568
|
+
|
|
569
|
+
def migrate_route_to_version(self, route: fastapi.routing.APIRoute, *, ignore_response_model: bool = False):
    """Rewrite *route* in place so that it uses this version's models.

    Unless ``ignore_response_model`` is set or the route has no response model, the response model and
    the response fields derived from it are replaced. Route dependencies and the endpoint are always
    versioned, ``APIRoute`` callbacks are migrated recursively, and finally the route's dependant and
    body field are rebuilt.
    """
    has_migratable_response = route.response_model is not None and not ignore_response_model
    if has_migratable_response:
        route.response_model = self.change_version_of_annotation(route.response_model)
        route.response_field = fastapi.utils.create_model_field(
            name="Response_" + route.unique_id,
            type_=route.response_model,
            mode="serialization",
        )
        route.secure_cloned_response_field = fastapi.utils.create_cloned_field(route.response_field)

    route.dependencies = self.change_version_of_annotation(route.dependencies)
    route.endpoint = self.change_version_of_annotation(route.endpoint)

    api_callbacks = (cb for cb in route.callbacks or [] if isinstance(cb, fastapi.routing.APIRoute))
    for api_callback in api_callbacks:
        self.migrate_route_to_version(api_callback, ignore_response_model=ignore_response_model)

    self._remake_endpoint_dependencies(route)
|
|
585
|
+
|
|
586
|
+
def _change_version_of_a_non_container_annotation(self, annotation: Any) -> Any:
    """Version a single annotation-like value that is not a dict, list or tuple.

    Checked in order: generic aliases (arguments rewritten recursively, ``ClassVar`` special-cased),
    ``TypeAliasType`` (recreated unless it comes from pydantic), ``Security``/``Depends`` markers
    (their dependency is rewritten), PEP 604 unions, ``Any``/``NewType`` (returned unchanged),
    classes (``_change_version_of_type``), and other callables, which are copied with rewritten
    annotations and defaults unless their type comes from fastapi/pydantic/pydantic_core/starlette
    or they are security schemes. Anything else is returned unchanged.
    """
    from typing_inspection.typing_objects import is_any, is_newtype, is_typealiastype

    if isinstance(annotation, (types.GenericAlias, _BaseGenericAlias)):
        origin = get_origin(annotation)
        args = get_args(annotation)
        # Classvar does not support generic tuple arguments
        if origin is ClassVar:
            return ClassVar[self.change_version_of_annotation(args[0])]
        return origin[tuple(self.change_version_of_annotation(arg) for arg in get_args(annotation))]
    elif is_typealiastype(annotation):
        # Aliases that belong to pydantic are left as they are.
        if (
            annotation.__module__ is not None and (annotation.__module__.startswith("pydantic."))
        ) or annotation.__name__ in _PYDANTIC_ALL_EXPORTED_NAMES:
            return annotation
        else:
            return TypeAliasType(  # pyright: ignore[reportGeneralTypeIssues]
                name=annotation.__name__,
                value=self.change_version_of_annotation(annotation.__value__),
                type_params=self.change_version_of_annotation(annotation.__type_params__),
            )
    # Security must be checked before Depends: fastapi.params.Security subclasses Depends
    # and its scopes have to be preserved.
    elif isinstance(annotation, fastapi.params.Security):
        return fastapi.params.Security(
            self.change_version_of_annotation(annotation.dependency),
            scopes=annotation.scopes,
            use_cache=annotation.use_cache,
        )
    elif isinstance(annotation, fastapi.params.Depends):
        return fastapi.params.Depends(
            self.change_version_of_annotation(annotation.dependency),
            use_cache=annotation.use_cache,
        )
    elif isinstance(annotation, UnionType):  # pragma: no cover
        # Rebuild ``X | Y`` as typing.Union[...] over the versioned members.
        getitem = typing.Union.__getitem__  # pyright: ignore[reportAttributeAccessIssue]
        return getitem(
            tuple(self.change_version_of_annotation(a) for a in get_args(annotation)),
        )
    elif is_any(annotation) or is_newtype(annotation):
        return annotation
    elif isinstance(annotation, type):
        return self._change_version_of_type(annotation)
    elif callable(annotation):
        # Framework-provided callables and security schemes are never copied.
        if type(annotation).__module__.startswith(
            ("fastapi.", "pydantic.", "pydantic_core.", "starlette.")
        ) or isinstance(annotation, fastapi.security.base.SecurityBase):
            return annotation

        def modifier(annotation: Any):
            return self.change_version_of_annotation(annotation)

        # The same modifier versions both the annotations and the default values (e.g. nested Depends).
        return self._modify_callable_annotations(
            annotation,
            modifier,
            modifier,
            annotation_modifying_wrapper_factory=self._copy_function_through_class_based_wrapper,
        )
    else:
        return annotation
|
|
644
|
+
|
|
645
|
+
def _change_version_of_type(self, annotation: type):
    """Map pydantic models and enums to the generator's versioned copy; other classes pass through."""
    if not lenient_issubclass(annotation, (BaseModel, Enum)):
        return annotation
    return self.generator[annotation]
|
|
650
|
+
|
|
651
|
+
@classmethod
def _remake_endpoint_dependencies(cls, route: fastapi.routing.APIRoute):
    """Recompute ``route.dependant`` and ``route.body_field`` after its endpoint/dependencies changed.

    A throwaway ``APIRoute`` is built from the route's path, endpoint and dependencies. Only its
    ``dependant`` and ``body_field`` are copied back, after which
    ``_add_request_and_response_params(route)`` is applied.
    """
    # Unlike get_dependant, APIRoute is the public API of FastAPI and it's (almost) guaranteed to be stable.
    #
    # response_model=None: the throwaway route's response handling is never used. Without it, FastAPI infers
    # a response model from the endpoint's return annotation (when its signature exposes one) and builds a
    # response field for it, which is wasted work and raises for annotations that are not valid pydantic
    # fields (e.g. ``Union[Response, dict]``) even if the real route was declared with response_model=None.
    # NOTE(review): methods/name/generate_unique_id_function are not forwarded, so body_field is named after the
    # throwaway route's unique_id rather than the real route's — confirm this is intended.
    route_copy = fastapi.routing.APIRoute(
        route.path,
        route.endpoint,
        dependencies=route.dependencies,
        response_model=None,
    )
    route.dependant = route_copy.dependant
    route.body_field = route_copy.body_field
    _add_request_and_response_params(route)
|
|
659
|
+
|
|
660
|
+
@classmethod
def _modify_callable_annotations(  # pragma: no branch # because of lambdas
    cls,
    call: _Call,
    modify_annotations: Callable[[dict[str, Any]], dict[str, Any]] = lambda a: a,
    modify_defaults: Callable[[tuple[Any, ...]], tuple[Any, ...]] = lambda a: a,
    *,
    annotation_modifying_wrapper_factory: Callable[[_Call], _Call],
) -> _Call:
    """Return a wrapped copy of *call* with rewritten annotations, defaults and signature.

    Args:
        call: the callable to copy; it is not modified itself.
        modify_annotations: transforms the wrapper's annotations dict.
        modify_defaults: transforms the tuple of parameter defaults.
        annotation_modifying_wrapper_factory: builds the wrapper object that receives the new
            ``__annotations__``, ``__defaults__`` and ``__signature__``.
    """
    annotation_modifying_wrapper = annotation_modifying_wrapper_factory(call)
    old_params = inspect.signature(call).parameters
    callable_annotations = annotation_modifying_wrapper.__annotations__
    # String annotations (e.g. under ``from __future__ import annotations``) are passed through
    # _try_eval_type with the original callable's globals before being transformed.
    callable_annotations = {
        k: v if type(v) is not str else _try_eval_type(v, call.__globals__) for k, v in callable_annotations.items()
    }
    annotation_modifying_wrapper.__annotations__ = modify_annotations(callable_annotations)
    # Defaults of every parameter that has one, in signature order (keyword-only ones included);
    # _generate_signature reads them back in that same order.
    annotation_modifying_wrapper.__defaults__ = modify_defaults(
        tuple(p.default for p in old_params.values() if p.default is not inspect.Signature.empty),
    )
    annotation_modifying_wrapper.__signature__ = cls._generate_signature(
        annotation_modifying_wrapper,
        old_params,
    )

    return annotation_modifying_wrapper
|
|
685
|
+
|
|
686
|
+
@staticmethod
def _generate_signature(
    new_callable: Callable,
    old_params: types.MappingProxyType[str, inspect.Parameter],
):
    """Build an ``inspect.Signature`` for ``new_callable`` shaped like ``old_params``.

    Parameter names and kinds come from ``old_params``. Annotations (including
    ``"return"``) come from ``new_callable.__annotations__``. Every parameter that had
    a default in ``old_params`` takes the next value of ``new_callable.__defaults__``,
    consumed in declaration order across all parameter kinds.
    """
    annotations = new_callable.__annotations__
    new_parameters = []
    next_default_index = 0
    for old_param in old_params.values():
        if old_param.default is inspect.Signature.empty:
            default = inspect.Signature.empty
        else:
            assert new_callable.__defaults__ is not None, (  # noqa: S101
                "Defaults cannot be None here. If it is, you have found a bug in Cadwyn. "
                "Please, report it in our issue tracker."
            )
            default = new_callable.__defaults__[next_default_index]
            next_default_index += 1
        new_parameters.append(
            inspect.Parameter(
                old_param.name,
                old_param.kind,
                default=default,
                annotation=annotations.get(old_param.name, inspect.Signature.empty),
            ),
        )
    return_annotation = annotations.get("return", inspect.Signature.empty)
    return inspect.Signature(parameters=new_parameters, return_annotation=return_annotation)
|
|
721
|
+
|
|
722
|
+
@classmethod
def _copy_function_through_class_based_wrapper(cls, call: Any):
    """Wrap ``call`` in the class-based wrapper that matches its calling convention.

    Kept apart from endpoint copying because FastAPI requires endpoints to be real
    functions, whereas this produces class instances.
    """
    unwrapped = cls._unwrap_callable(call)
    # For instances of regular classes, the calling convention is decided by their
    # ``__call__``; the wrapper itself still receives the unwrapped object.
    inspected = unwrapped if is_regular_function(unwrapped) else unwrapped.__call__

    if is_async_gen_callable(inspected):
        wrapper_cls = _AsyncGeneratorCallableWrapper
    elif is_coroutine_callable(inspected):
        wrapper_cls = _AsyncCallableWrapper
    elif is_gen_callable(inspected):
        wrapper_cls = _GeneratorCallableWrapper
    else:
        wrapper_cls = _CallableWrapper
    return wrapper_cls(unwrapped)
|
|
739
|
+
|
|
740
|
+
@staticmethod
def _unwrap_callable(call: Any) -> Any:
    """Follow the ``_original_callable`` chain down to the innermost object.

    Objects without an ``_original_callable`` attribute are returned as-is.
    """
    current = call
    while True:
        try:
            current = current._original_callable
        except AttributeError:
            return current
|
|
746
|
+
|
|
747
|
+
|
|
748
|
+
def _add_request_and_response_params(route: APIRoute):
    """Give ``route.dependant`` Cadwyn's request/response parameter names when unset.

    Each of ``request_param_name`` and ``response_param_name`` is replaced only when
    it is falsy; names that are already set are left untouched.
    """
    dependant = route.dependant
    if not dependant.request_param_name:
        dependant.request_param_name = _CADWYN_REQUEST_PARAM_NAME
    if not dependant.response_param_name:
        dependant.response_param_name = _CADWYN_RESPONSE_PARAM_NAME
|
|
753
|
+
|
|
754
|
+
|
|
755
|
+
@final
class SchemaGenerator:
    """Generate and cache the concrete classes of versioned pydantic models and enums.

    ``model_bundle`` holds the wrappers describing how models/enums look in one
    version; ``concrete_models`` maps an original (unwrapped) class to the class copy
    generated from its wrapper.
    """

    __slots__ = "annotation_transformer", "concrete_models", "model_bundle"

    def __init__(self, model_bundle: _ModelBundle) -> None:
        self.annotation_transformer = _AnnotationTransformer(self)
        self.model_bundle = model_bundle
        # The cache must exist before the first copy is generated: generating one model
        # may call back into ``self[...]`` (e.g. via ``annotation_transformer``) for the
        # classes it references, which reads and fills this dict.
        self.concrete_models = {}
        for original_model, wrapper in (self.model_bundle.schemas | self.model_bundle.enums).items():
            # Reuse a copy that was already produced while resolving a reference from a
            # previously generated model. Generating it a second time would leave that
            # model pointing at a different class object than the one cached here.
            if original_model not in self.concrete_models:
                self.concrete_models[original_model] = wrapper.generate_model_copy(self)

    def __getitem__(self, model: type[_T_ANY_MODEL], /) -> type[_T_ANY_MODEL]:
        """Return the generated copy of ``model`` for this generator's version.

        Non-classes, classes that are neither ``BaseModel`` nor ``Enum`` subclasses,
        and classes in ``_DEFAULT_PYDANTIC_CLASSES`` are returned unchanged. Otherwise
        the model is unwrapped, looked up in ``concrete_models``, and generated (then
        cached) on a miss.
        """
        if (
            not isinstance(model, type)
            or not lenient_issubclass(model, (BaseModel, Enum))
            or model in _DEFAULT_PYDANTIC_CLASSES
        ):
            return model
        model = _unwrap_model(model)

        if model in self.concrete_models:
            return self.concrete_models[model]

        wrapper = self._get_wrapper_for_model(model)
        model_copy = wrapper.generate_model_copy(self)
        self.concrete_models[model] = model_copy
        return cast("type[_T_ANY_MODEL]", model_copy)

    @overload
    def _get_wrapper_for_model(self, model: type[BaseModel]) -> "_PydanticModelWrapper[BaseModel]": ...
    @overload
    def _get_wrapper_for_model(self, model: type[Enum]) -> "_EnumWrapper[Enum]": ...

    def _get_wrapper_for_model(
        self, model: type[Union[BaseModel, Enum]]
    ) -> "Union[_PydanticModelWrapper[BaseModel], _EnumWrapper[Enum]]":
        """Return the wrapper describing ``model`` in this generator's bundle.

        Models that are not in the bundle yet get a fresh wrapper
        (``_wrap_pydantic_model`` for pydantic models, ``_EnumWrapper`` for enums),
        which is registered in ``model_bundle`` before being returned.
        """
        model = _unwrap_model(model)

        if model in self.model_bundle.schemas:
            return self.model_bundle.schemas[model]
        elif model in self.model_bundle.enums:
            return self.model_bundle.enums[model]

        if lenient_issubclass(model, BaseModel):
            # TODO: My god, what if one of its fields is in our concrete schemas and we don't use it? :O
            # TODO: Add an argument with our concrete schemas for _wrap_pydantic_model
            wrapper = _wrap_pydantic_model(model)
            self.model_bundle.schemas[model] = wrapper
        elif lenient_issubclass(model, Enum):
            wrapper = _EnumWrapper(model)
            self.model_bundle.enums[model] = wrapper
        else:
            assert_never(model)
        return wrapper
|
|
811
|
+
|
|
812
|
+
|
|
813
|
+
@cache
def generate_versioned_models(versions: "VersionBundle") -> "dict[str, SchemaGenerator]":
    """Build one ``SchemaGenerator`` per version of ``versions``.

    The result maps ``str(version.value)`` to that version's generator and is memoized
    by ``functools.cache`` (keyed on ``versions``).

    A single model bundle is migrated step by step. First the changes of
    ``versions.head_version`` are applied. Then, for each entry of ``versions.versions``
    in iteration order, the bundle is deep-copied into that version's generator
    *before* the version's own changes are applied to it.
    """
    models = _create_model_bundle(versions)
    generators_by_version = {}

    head_context = _RuntimeSchemaGenContext(
        current_version=versions.head_version,
        models=models,
        version_bundle=versions,
    )
    _migrate_classes(head_context)

    for version in versions.versions:
        version_context = _RuntimeSchemaGenContext(current_version=version, models=models, version_bundle=versions)
        # Snapshot the bundle as it looks for ``version`` before rolling it back further.
        generators_by_version[str(version.value)] = SchemaGenerator(copy.deepcopy(models))
        # After the last version this migration is still applied, but never snapshotted.
        _migrate_classes(version_context)

    return generators_by_version
|
|
828
|
+
|
|
829
|
+
|
|
830
|
+
def _create_model_bundle(versions: "VersionBundle"):
    """Wrap every versioned enum and schema of ``versions`` into a new ``_ModelBundle``.

    Keys are the classes listed by ``versions``; values are freshly created wrappers.
    """
    enum_wrappers = {enum: _EnumWrapper(enum) for enum in versions.versioned_enums.values()}
    schema_wrappers = {schema: _wrap_pydantic_model(schema) for schema in versions.versioned_schemas.values()}
    return _ModelBundle(enums=enum_wrappers, schemas=schema_wrappers)
|
|
835
|
+
|
|
836
|
+
|
|
837
|
+
def _migrate_classes(context: _RuntimeSchemaGenContext) -> None:
    """Apply every change of ``context.current_version`` to ``context.models`` in place.

    For each version change, its schema instructions are applied to the bundle's
    schema wrappers, then its enum instructions to the bundle's enum wrappers.
    """
    schema_wrappers = context.models.schemas
    enum_wrappers = context.models.enums
    for version_change in context.current_version.changes:
        change_name = version_change.__name__
        _apply_alter_schema_instructions(schema_wrappers, version_change.alter_schema_instructions, change_name)
        _apply_alter_enum_instructions(enum_wrappers, version_change.alter_enum_instructions, change_name)
|
|
849
|
+
|
|
850
|
+
|
|
851
|
+
def _apply_alter_schema_instructions(
    modified_schemas: dict[type, _PydanticModelWrapper],
    alter_schema_instructions: Sequence[Union[AlterSchemaSubInstruction, SchemaHadInstruction]],
    version_change_name: str,
) -> None:
    """Apply the schema instructions of one version change to ``modified_schemas`` in place.

    Each instruction targets ``modified_schemas[instruction.schema]``. It is dispatched on
    its type: field added, field changed or attributes deleted, field removed, validator
    (re)registered, validator deleted, or schema renamed. ``version_change_name`` is
    used in error messages.

    Raises:
        InvalidGenerationInstructionError: when a validator to delete does not exist or
            is already deleted (other checks live in the helpers called here).
    """
    for instruction in alter_schema_instructions:
        schema_info = modified_schemas[instruction.schema]

        if isinstance(instruction, FieldExistedAsInstruction):
            _add_field_to_model(schema_info, modified_schemas, instruction, version_change_name)
        elif isinstance(instruction, (FieldHadInstruction, FieldDidntHaveInstruction)):
            _change_field_in_model(schema_info, modified_schemas, instruction, version_change_name)
        elif isinstance(instruction, FieldDidntExistInstruction):
            _delete_field_from_model(schema_info, instruction.name, version_change_name)
        elif isinstance(instruction, ValidatorExistedInstruction):
            # Register the validator under the name of the function it wraps, rebuilt from
            # the pieces of pydantic's descriptor proxy.
            registered_name = get_name_of_function_wrapped_in_pydantic_validator(instruction.validator)
            proxy = cast("pydantic._internal._decorators.PydanticDescriptorProxy", instruction.validator)
            schema_info.validators[registered_name] = _wrap_validator(
                proxy.wrapped,
                is_pydantic_v1_style_validator=proxy.shim,
                decorator_info=proxy.decorator_info,
            )
        elif isinstance(instruction, ValidatorDidntExistInstruction):
            validator_name = instruction.name
            if validator_name not in schema_info.validators:
                raise InvalidGenerationInstructionError(
                    f'You tried to delete a validator "{validator_name}" from "{schema_info.name}" '
                    f'in "{version_change_name}" but it doesn\'t have such a validator.',
                )
            existing_validator = schema_info.validators[validator_name]
            if existing_validator.is_deleted:
                raise InvalidGenerationInstructionError(
                    f'You tried to delete a validator "{validator_name}" from "{schema_info.name}" '
                    f'in "{version_change_name}" but it is already deleted.',
                )
            existing_validator.is_deleted = True
        elif isinstance(instruction, SchemaHadInstruction):
            _change_model(schema_info, instruction, version_change_name)
        else:
            assert_never(instruction)
|
|
895
|
+
|
|
896
|
+
|
|
897
|
+
def _apply_alter_enum_instructions(
    enums: "dict[type, _EnumWrapper]",
    alter_enum_instructions: Sequence[AlterEnumSubInstruction],
    version_change_name: str,
):
    """Apply member removals/additions of one version change to the enum wrappers in place.

    Raises:
        InvalidGenerationInstructionError: when removing a member the enum does not have,
            or adding a member that already exists with the same value.
    """
    for instruction in alter_enum_instructions:
        wrapper = enums[instruction.enum]
        members = wrapper.members

        if isinstance(instruction, EnumDidntHaveMembersInstruction):
            for member_name in instruction.members:
                if member_name not in members:
                    raise InvalidGenerationInstructionError(
                        f'You tried to delete a member "{member_name}" from "{wrapper.cls.__name__}" '
                        f'in "{version_change_name}" but it doesn\'t have such a member.',
                    )
                del members[member_name]
        elif isinstance(instruction, EnumHadMembersInstruction):
            for member_name, member_value in instruction.members.items():
                # Re-adding an existing name with a *different* value is allowed (it overwrites).
                if member_name in members and members[member_name] == member_value:
                    raise InvalidGenerationInstructionError(
                        f'You tried to add a member "{member_name}" to "{wrapper.cls.__name__}" '
                        f'in "{version_change_name}" but there is already a member with that name and value.',
                    )
                members[member_name] = member_value
        else:
            assert_never(instruction)
|
|
922
|
+
|
|
923
|
+
|
|
924
|
+
def _change_model(
    model: _PydanticModelWrapper,
    alter_schema_instruction: SchemaHadInstruction,
    version_change_name: str,
):
    """Rename ``model`` to the name carried by ``alter_schema_instruction``.

    Raises:
        InvalidGenerationInstructionError: if the model already has that name.
    """
    requested_name = alter_schema_instruction.name
    if requested_name != model.name:
        model.name = requested_name
        return

    raise InvalidGenerationInstructionError(
        f'You tried to change the name of "{model.name}" in "{version_change_name}" '
        "but it already has the name you tried to assign.",
    )
|
|
936
|
+
|
|
937
|
+
|
|
938
|
+
def _add_field_to_model(
    model: _PydanticModelWrapper,
    schemas: "dict[type, _PydanticModelWrapper]",
    alter_schema_instruction: FieldExistedAsInstruction,
    version_change_name: str,
):
    """Add the field described by ``alter_schema_instruction`` to ``model``.

    Regular fields get a ``PydanticFieldWrapper`` in ``model.fields`` plus their
    annotation. ``ClassVar`` "fields" only get an annotation and, when a default is
    given, a value in ``model.other_attributes``.

    Raises:
        InvalidGenerationInstructionError: if a field with that name is already defined
            on ``model`` or its parents (as reported by ``_get_defined_fields_through_mro``).
    """
    field_name = alter_schema_instruction.name
    field_info = alter_schema_instruction.field
    if field_name in model._get_defined_fields_through_mro(schemas):
        raise InvalidGenerationInstructionError(
            f'You tried to add a field "{field_name}" to "{model.name}" '
            f'in "{version_change_name}" but there is already a field with that name.',
        )

    annotation = field_info.annotation
    if get_origin(annotation) is not ClassVar:
        model.fields[field_name] = PydanticFieldWrapper(field_info, annotation, field_name)
        model.annotations[field_name] = annotation
        return

    # ClassVar entries must stay out of model.fields: they live only in the annotations,
    # with their value (if any) kept among the other class attributes.
    model.annotations[field_name] = annotation
    if field_info.default is not PydanticUndefined:
        model.other_attributes[field_name] = field_info.default
|
|
965
|
+
|
|
966
|
+
|
|
967
|
+
def _change_field_in_model(
    model: _PydanticModelWrapper,
    schemas: "dict[type, _PydanticModelWrapper]",
    alter_schema_instruction: Union[FieldHadInstruction, FieldDidntHaveInstruction],
    version_change_name: str,
):
    """Apply a field "had" or "didn't have" instruction to ``model``.

    The field (possibly inherited) and its annotation are first attached to ``model``
    itself. The instruction is then delegated to ``_change_field`` (for
    ``FieldHadInstruction``) or ``_delete_field_attributes`` (otherwise).

    Raises:
        InvalidGenerationInstructionError: if no field with that name is defined on
            ``model`` or its parents.
    """
    field_name = alter_schema_instruction.name
    annotations_through_mro = model._get_defined_annotations_through_mro(schemas)
    fields_through_mro = model._get_defined_fields_through_mro(schemas)
    if field_name not in fields_through_mro:
        raise InvalidGenerationInstructionError(
            f'You tried to change the field "{field_name}" from '
            f'"{model.name}" in "{version_change_name}" but it doesn\'t have such a field.',
        )

    # NOTE(review): this attaches the very wrapper object returned by
    # _get_defined_fields_through_mro; if that returns a parent's own wrapper (not a copy),
    # the changes below would also be visible on the parent model — confirm.
    field = fields_through_mro[field_name]
    model.fields[field_name] = field
    model.annotations[field_name] = annotations_through_mro[field_name]
    current_annotation = model.annotations[field_name]

    if not isinstance(alter_schema_instruction, FieldHadInstruction):
        _delete_field_attributes(
            model,
            alter_schema_instruction,
            version_change_name,
            field,
            current_annotation,
        )
        return

    # TODO: This naming sucks
    _change_field(
        model,
        alter_schema_instruction,
        version_change_name,
        annotations_through_mro,
        field,
        current_annotation,
    )
|
|
1003
|
+
|
|
1004
|
+
|
|
1005
|
+
def _change_field(
    model: _PydanticModelWrapper,
    alter_schema_instruction: FieldHadInstruction,
    version_change_name: str,
    defined_annotations: dict[str, Any],
    field: PydanticFieldWrapper,
    annotation: Union[Any, None],
):
    """Apply a ``FieldHadInstruction`` to ``field`` of ``model``.

    Changes are applied in this order: new type, new name, then new field-attribute
    values (every attribute of ``field_changes`` that is not ``Sentinel``). The
    ``annotation`` argument is accepted but not used here.

    Raises:
        InvalidGenerationInstructionError: when a requested change would be a no-op
            (same type, same name, or an attribute already holding that value).
    """
    old_name = alter_schema_instruction.name
    new_type = alter_schema_instruction.type
    new_name = alter_schema_instruction.new_name

    if new_type is not Sentinel:
        if field.annotation == new_type:
            raise InvalidGenerationInstructionError(
                f'You tried to change the type of field "{old_name}" to '
                f'"{new_type}" from "{model.name}" in "{version_change_name}" '
                f'but it already has type "{field.annotation}"',
            )
        field.annotation = new_type
        model.annotations[old_name] = new_type

    if new_name is not Sentinel:
        if new_name == old_name:
            raise InvalidGenerationInstructionError(
                f'You tried to change the name of field "{old_name}" '
                f'from "{model.name}" in "{version_change_name}" '
                "but it already has that name.",
            )
        model.fields[new_name] = model.fields.pop(old_name)
        # Fall back to the MRO-resolved annotation if the model has none of its own.
        model.annotations[new_name] = model.annotations.pop(old_name, defined_annotations[old_name])

    field_changes = alter_schema_instruction.field_changes
    for attr_name in field_changes.__dataclass_fields__:
        attr_value = getattr(field_changes, attr_name)
        if attr_value is Sentinel:
            continue
        if field.passed_field_attributes.get(attr_name, Sentinel) == attr_value:
            raise InvalidGenerationInstructionError(
                f'You tried to change the attribute "{attr_name}" of field '
                f'"{old_name}" '
                f'from "{model.name}" to {attr_value!r} in "{version_change_name}" '
                "but it already has that value.",
            )
        field.update_attribute(name=attr_name, value=attr_value)
|
|
1047
|
+
|
|
1048
|
+
|
|
1049
|
+
def _delete_field_attributes(
    model: _PydanticModelWrapper,
    alter_schema_instruction: FieldDidntHaveInstruction,
    version_change_name: str,
    field: PydanticFieldWrapper,
    annotation: Any,
) -> None:
    """Remove the attributes listed in a ``FieldDidntHaveInstruction`` from ``field``.

    An attribute counts as deleted when either:

    * it is among ``field.passed_field_attributes`` (removed via ``field.delete_attribute``), or
    * ``annotation`` is ``Annotated[T, *metadata]`` and some metadata object has that
      attribute (it is set to ``None`` on that object).

    Only the ``Annotated`` metadata is touched, never the underlying type ``T``. For
    example, ``str`` has a ``title`` method: overwriting it on the builtin raises
    ``TypeError``, and on a user class it would mutate that class globally.

    Raises:
        InvalidGenerationInstructionError: if an attribute was found in neither place.
    """
    # get_args(Annotated[T, m1, m2]) == (T, m1, m2); skip T and keep only the metadata.
    annotated_metadata = get_args(annotation)[1:] if get_origin(annotation) == Annotated else ()

    for attr_name in alter_schema_instruction.attributes:
        deleted = False

        if attr_name in field.passed_field_attributes:
            field.delete_attribute(name=attr_name)
            deleted = True
        for metadata in annotated_metadata:
            if hasattr(metadata, attr_name):
                # object.__setattr__ bypasses any custom __setattr__ defined on the metadata object.
                object.__setattr__(metadata, attr_name, None)
                deleted = True
        if not deleted:
            raise InvalidGenerationInstructionError(
                f'You tried to delete the attribute "{attr_name}" of field "{alter_schema_instruction.name}" '
                f'from "{model.name}" in "{version_change_name}" '
                "but it already doesn't have that attribute.",
            )
|
|
1075
|
+
|
|
1076
|
+
|
|
1077
|
+
def _delete_field_from_model(model: _PydanticModelWrapper, field_name: str, version_change_name: str):
    """Remove ``field_name`` from ``model``.

    Three kinds of "field" are recognized, checked in this order:

    * a regular field (in ``model.fields``): removed together with its annotation and
      dropped from every per-field validator listing it. A per-field validator left
      with no fields is marked deleted.
    * a ``pydantic.computed_field`` stored among ``model.validators``: marked deleted,
      and its annotation (if any) removed.
    * a ``ClassVar`` (present only in ``model.annotations``): its annotation and any
      value in ``model.other_attributes`` are removed.

    Raises:
        InvalidGenerationInstructionError: if ``field_name`` matches none of the above.
    """
    if field_name in model.fields:
        model.fields.pop(field_name)
        model.annotations.pop(field_name)
        # Only the validator objects are mutated below, never the dict itself, so
        # iterating it directly is safe.
        for validator in model.validators.values():
            if isinstance(validator, _PerFieldValidatorWrapper) and field_name in validator.fields:
                validator.fields.remove(field_name)
                # TODO: This behavior doesn't feel natural
                if not validator.fields:
                    validator.is_deleted = True

    elif (
        field_name in model.validators
        and isinstance(model.validators[field_name], _ValidatorWrapper)
        and hasattr(model.validators[field_name], "decorator")
        and model.validators[field_name].decorator == pydantic.computed_field
    ):
        model.validators[field_name].is_deleted = True
        model.annotations.pop(field_name, None)
    elif field_name in model.annotations and get_origin(model.annotations[field_name]) is ClassVar:
        # ClassVar fields exist in annotations but not in model.fields
        model.annotations.pop(field_name)
        model.other_attributes.pop(field_name, None)
    else:
        raise InvalidGenerationInstructionError(
            f'You tried to delete a field "{field_name}" from "{model.name}" '
            f'in "{version_change_name}" but it doesn\'t have such a field.',
        )
|
|
1107
|
+
|
|
1108
|
+
|
|
1109
|
+
class _DummyEnum(Enum):
    # Deliberately member-less: its __dict__ lists the attributes every Enum subclass
    # receives by default, which the enum-copying code filters out of copied namespaces.
    ...
|
|
1111
|
+
|
|
1112
|
+
|
|
1113
|
+
@final
class _EnumWrapper(Generic[_T_ENUM]):
    """Editable description of an enum, from which per-version enum classes are generated.

    Slots:
        cls: the original enum class (passed through ``_unwrap_model``).
        name: the class name used for generated copies.
        members: mapping of member name to member value; version-change instructions
            add and remove entries here.
    """

    __slots__ = "cls", "members", "name"

    def __init__(self, cls: type[_T_ENUM]):
        self.cls = _unwrap_model(cls)
        self.name = cls.__name__
        self.members = {member.name: member.value for member in cls}

    def __deepcopy__(self, memo: Any):
        """Copy the wrapper: ``cls`` is shared, the editable state is duplicated.

        ``name`` and ``members`` are carried over from ``self`` instead of being
        recomputed from ``cls``, so the copy reflects every edit already applied.
        The members mapping is copied shallowly.
        """
        result = _EnumWrapper(self.cls)
        result.name = self.name
        result.members = self.members.copy()
        memo[id(self)] = result
        return result

    def generate_model_copy(self, generator: "SchemaGenerator") -> type[_T_ENUM]:
        """Create a new enum class named ``self.name`` whose members are ``self.members``.

        The new class has the same bases and docstring as ``self.cls`` and re-declares
        the original class's own non-member attributes. ``generator`` is not used here.
        The copy records the original class in ``__cadwyn_original_model__``.
        """
        enum_dict = Enum.__prepare__(self.name, self.cls.__bases__)

        # Members may be stored as Enum instances; the class namespace needs raw values.
        raw_member_map = {k: v.value if isinstance(v, Enum) else v for k, v in self.members.items()}
        initialization_namespace = self._get_initialization_namespace_for_enum(self.cls) | raw_member_map
        for attr_name, attr in initialization_namespace.items():
            enum_dict[attr_name] = attr
        enum_dict["__doc__"] = self.cls.__doc__
        # type(...) defers to the enum metaclass of the bases, so this builds a real Enum class.
        model_copy = cast("type[_T_ENUM]", type(self.name, self.cls.__bases__, enum_dict))
        model_copy.__cadwyn_original_model__ = self.cls  # pyright: ignore[reportAttributeAccessIssue]
        return model_copy

    @staticmethod
    def _get_initialization_namespace_for_enum(enum_cls: type[Enum]):
        """Return the attributes of ``enum_cls`` that a copy must re-declare.

        Keeps entries of ``enum_cls.__dict__`` that are not members, are not attributes
        every Enum gets by default (present on ``_DummyEnum``), and are not inherited
        unchanged from a base class.
        """
        mro_without_the_class_itself = enum_cls.mro()[1:]

        # Later (more derived) bases override earlier ones, mirroring attribute lookup.
        mro_dict = {}
        for cls in reversed(mro_without_the_class_itself):
            mro_dict.update(cls.__dict__)

        return {
            k: v
            for k, v in enum_cls.__dict__.items()
            if k not in enum_cls._member_names_
            and k not in _DummyEnum.__dict__
            and (k not in mro_dict or mro_dict[k] is not v)
        }
|
|
1155
|
+
|
|
1156
|
+
|
|
1157
|
+
def _try_eval_type(value: Any, globals: dict[str, Any]) -> Any:
    """Evaluate ``value`` (a type annotation) via pydantic's helper against ``globals``.

    Returns the evaluated object on success and ``value`` unchanged otherwise.
    """
    evaluated, succeeded = pydantic_try_eval_type(value, globals)
    return evaluated if succeeded else value  # pragma: no cover # failure path not expected in practice
|