cadwyn 3.15.10__py3-none-any.whl → 4.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cadwyn might be problematic. Click here for more details.
- cadwyn/__init__.py +26 -8
- cadwyn/__main__.py +46 -90
- cadwyn/_asts.py +9 -133
- cadwyn/_importer.py +31 -0
- cadwyn/_render.py +152 -0
- cadwyn/_utils.py +7 -107
- cadwyn/applications.py +5 -34
- cadwyn/exceptions.py +11 -3
- cadwyn/middleware.py +4 -4
- cadwyn/route_generation.py +22 -450
- cadwyn/routing.py +2 -5
- cadwyn/schema_generation.py +946 -0
- cadwyn/structure/__init__.py +0 -2
- cadwyn/structure/schemas.py +50 -49
- cadwyn/structure/versions.py +24 -137
- {cadwyn-3.15.10.dist-info → cadwyn-4.1.0.dist-info}/METADATA +4 -5
- cadwyn-4.1.0.dist-info/RECORD +27 -0
- {cadwyn-3.15.10.dist-info → cadwyn-4.1.0.dist-info}/WHEEL +1 -1
- cadwyn/_compat.py +0 -151
- cadwyn/_package_utils.py +0 -45
- cadwyn/codegen/README.md +0 -10
- cadwyn/codegen/__init__.py +0 -10
- cadwyn/codegen/_common.py +0 -168
- cadwyn/codegen/_main.py +0 -279
- cadwyn/codegen/_plugins/__init__.py +0 -0
- cadwyn/codegen/_plugins/class_migrations.py +0 -423
- cadwyn/codegen/_plugins/class_rebuilding.py +0 -109
- cadwyn/codegen/_plugins/class_renaming.py +0 -49
- cadwyn/codegen/_plugins/import_auto_adding.py +0 -64
- cadwyn/codegen/_plugins/module_migrations.py +0 -15
- cadwyn/main.py +0 -11
- cadwyn/structure/modules.py +0 -39
- cadwyn-3.15.10.dist-info/RECORD +0 -38
- {cadwyn-3.15.10.dist-info → cadwyn-4.1.0.dist-info}/LICENSE +0 -0
- {cadwyn-3.15.10.dist-info → cadwyn-4.1.0.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,946 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import dataclasses
|
|
3
|
+
import functools
|
|
4
|
+
import inspect
|
|
5
|
+
import types
|
|
6
|
+
import typing
|
|
7
|
+
from collections.abc import Callable, Sequence
|
|
8
|
+
from datetime import date
|
|
9
|
+
from enum import Enum
|
|
10
|
+
from functools import cache
|
|
11
|
+
from typing import (
|
|
12
|
+
TYPE_CHECKING,
|
|
13
|
+
Annotated,
|
|
14
|
+
Any,
|
|
15
|
+
Generic,
|
|
16
|
+
TypeAlias,
|
|
17
|
+
TypeVar,
|
|
18
|
+
_BaseGenericAlias, # pyright: ignore[reportAttributeAccessIssue]
|
|
19
|
+
cast,
|
|
20
|
+
final,
|
|
21
|
+
get_args,
|
|
22
|
+
get_origin,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
import fastapi.params
|
|
26
|
+
import fastapi.security.base
|
|
27
|
+
import fastapi.utils
|
|
28
|
+
import pydantic
|
|
29
|
+
import pydantic._internal._decorators
|
|
30
|
+
from fastapi import Response
|
|
31
|
+
from fastapi.routing import APIRoute
|
|
32
|
+
from issubclass import issubclass
|
|
33
|
+
from pydantic import BaseModel, Field, RootModel
|
|
34
|
+
from pydantic._internal import _decorators
|
|
35
|
+
from pydantic._internal._decorators import (
|
|
36
|
+
FieldSerializerDecoratorInfo,
|
|
37
|
+
FieldValidatorDecoratorInfo,
|
|
38
|
+
ModelSerializerDecoratorInfo,
|
|
39
|
+
ModelValidatorDecoratorInfo,
|
|
40
|
+
RootValidatorDecoratorInfo,
|
|
41
|
+
ValidatorDecoratorInfo,
|
|
42
|
+
)
|
|
43
|
+
from pydantic.fields import ComputedFieldInfo, FieldInfo
|
|
44
|
+
from typing_extensions import Doc, Self, _AnnotatedAlias, assert_never
|
|
45
|
+
|
|
46
|
+
from cadwyn._utils import Sentinel, UnionType, fully_unwrap_decorator
|
|
47
|
+
from cadwyn.exceptions import InvalidGenerationInstructionError
|
|
48
|
+
from cadwyn.structure.common import VersionDate
|
|
49
|
+
from cadwyn.structure.data import ResponseInfo
|
|
50
|
+
from cadwyn.structure.enums import AlterEnumSubInstruction, EnumDidntHaveMembersInstruction, EnumHadMembersInstruction
|
|
51
|
+
from cadwyn.structure.schemas import (
|
|
52
|
+
AlterSchemaSubInstruction,
|
|
53
|
+
FieldDidntExistInstruction,
|
|
54
|
+
FieldDidntHaveInstruction,
|
|
55
|
+
FieldExistedAsInstruction,
|
|
56
|
+
FieldHadInstruction,
|
|
57
|
+
SchemaHadInstruction,
|
|
58
|
+
ValidatorDidntExistInstruction,
|
|
59
|
+
ValidatorExistedInstruction,
|
|
60
|
+
_get_model_decorators,
|
|
61
|
+
)
|
|
62
|
+
from cadwyn.structure.versions import _CADWYN_REQUEST_PARAM_NAME, _CADWYN_RESPONSE_PARAM_NAME, VersionBundle
|
|
63
|
+
|
|
64
|
+
if TYPE_CHECKING:
|
|
65
|
+
from cadwyn.structure.versions import HeadVersion, Version, VersionBundle
|
|
66
|
+
|
|
67
|
+
# Type variable for any callable whose annotations/signature get rewritten per version
_Call = TypeVar("_Call", bound=Callable[..., Any])

_FieldName: TypeAlias = str

# Type variables for the model/enum classes handled by the wrappers below
_T_ANY_MODEL = TypeVar("_T_ANY_MODEL", bound=BaseModel | Enum)
_T_ENUM = TypeVar("_T_ENUM", bound=Enum)
_T_PYDANTIC_MODEL = TypeVar("_T_PYDANTIC_MODEL", bound=BaseModel)

# Maps each pydantic decorator-info type to the public decorator that produces it, so that a
# validator/serializer can be re-applied to a generated copy of its model.
PYDANTIC_DECORATOR_TYPE_TO_DECORATOR_MAP = {
    ValidatorDecoratorInfo: pydantic.validator,  # pyright: ignore[reportDeprecated]
    FieldValidatorDecoratorInfo: pydantic.field_validator,
    FieldSerializerDecoratorInfo: pydantic.field_serializer,
    RootValidatorDecoratorInfo: pydantic.root_validator,  # pyright: ignore[reportDeprecated]
    ModelValidatorDecoratorInfo: pydantic.model_validator,
    ModelSerializerDecoratorInfo: pydantic.model_serializer,
    ComputedFieldInfo: pydantic.computed_field,
}


VALIDATOR_CONFIG_KEY = "__validators__"

# Alphabetically sorted names of every `Field(...)` parameter that can be passed by keyword
_all_field_arg_names = sorted(
    parameter_name
    for parameter_name, parameter in inspect.signature(Field).parameters.items()
    if parameter.kind in {inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD}
)
EXTRA_FIELD_NAME = "json_schema_extra"


# Every slot of the `FieldInfo` produced by a bare `Field()` call, mapped to its value
_empty_field_info = Field()
dict_of_empty_field_info = {slot_name: getattr(_empty_field_info, slot_name) for slot_name in FieldInfo.__slots__}
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
@dataclasses.dataclass(slots=True)
class PydanticFieldWrapper:
    """Mutable, version-agnostic description of a single pydantic field.

    Only the keyword arguments that were explicitly passed to ``Field(...)`` are kept (in
    ``passed_field_attributes``) so that a fresh ``FieldInfo`` can be rebuilt for each version.
    We DO NOT maintain field.metadata at all.
    """

    # Only consumed by __post_init__; it is not stored on the instance.
    init_model_field: dataclasses.InitVar[FieldInfo]

    annotation: Any
    # Keyword arguments that were explicitly passed to `Field(...)` for this field
    passed_field_attributes: dict[str, Any] = dataclasses.field(init=False)

    def __post_init__(self, init_model_field: FieldInfo):
        self.passed_field_attributes = _extract_passed_field_attributes(init_model_field)

    def update_attribute(self, *, name: str, value: Any):
        """Set (or overwrite) the ``Field(...)`` keyword argument ``name``."""
        attributes = self.passed_field_attributes
        attributes[name] = value

    def delete_attribute(self, *, name: str) -> None:
        """Remove the ``Field(...)`` keyword argument ``name``; raises KeyError if it is not set."""
        del self.passed_field_attributes[name]

    def generate_field_copy(self, generator: "SchemaGenerator") -> pydantic.fields.FieldInfo:
        """Build a new ``FieldInfo`` whose arguments are converted to ``generator``'s version."""
        versioned_kwargs = generator.annotation_transformer.change_version_of_annotation(
            self.passed_field_attributes
        )
        return pydantic.Field(**versioned_kwargs)
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def _extract_passed_field_attributes(field_info: FieldInfo):
    """Return the ``Field(...)`` keyword arguments recorded as explicitly set on ``field_info``.

    Only names that are actual ``Field`` parameters (``_all_field_arg_names``) are kept, in
    alphabetical order. ``frozen`` is dropped when its recorded value is None.
    """
    explicitly_set = field_info._attributes_set
    attributes = {name: explicitly_set[name] for name in _all_field_arg_names if name in explicitly_set}
    # PydanticV2 always adds frozen to _attributes_set but we don't want it if it wasn't explicitly set
    if "frozen" in attributes and attributes["frozen"] is None:
        del attributes["frozen"]
    return attributes
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
@dataclasses.dataclass(slots=True)
class _ModelBundle:
    """Mutable wrappers for every enum and pydantic schema, keyed by the original class."""

    # original enum class -> wrapper from which its versioned copies are generated
    enums: dict[type[Enum], "_EnumWrapper"]
    # original pydantic model -> wrapper from which its versioned copies are generated
    schemas: dict[type[BaseModel], "_PydanticRuntimeModelWrapper"]
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
@dataclasses.dataclass(kw_only=True, slots=True)
class _RuntimeSchemaGenContext:
    """Everything needed to apply the changes of ``current_version`` to ``models``."""

    version_bundle: "VersionBundle"
    current_version: "Version | HeadVersion"
    # Wrapped models/enums that the version changes are applied to
    models: _ModelBundle
    # Computed in __post_init__: the entry of `version_bundle.versions` with the greatest `value`
    latest_version: "Version" = dataclasses.field(init=False)

    def __post_init__(self):
        self.latest_version = max(self.version_bundle.versions, key=lambda candidate: candidate.value)
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def migrate_response_body(
    versions: "VersionBundle",
    latest_response_model: type[pydantic.BaseModel],
    *,
    latest_body: Any,
    version: VersionDate | str,
):
    """Convert the data to a specific version by applying all version changes from latest until that version
    in reverse order and wrapping the result in the correct version of latest_response_model.

    Args:
        versions: the version bundle whose response migrations are applied.
        latest_response_model: the head (latest) response model that ``latest_body`` conforms to.
        latest_body: the response body in its latest-version form.
        version: the target version, as a date or an ISO-8601 date string.

    Returns:
        An instance of the copy of ``latest_response_model`` generated for the closest version
        that ``versions._get_closest_lesser_version`` returns for the target date.
    """
    target_date = date.fromisoformat(version) if isinstance(version, str) else version
    migrated_response = versions._migrate_response(
        ResponseInfo(Response(status_code=200), body=latest_body),
        current_version=target_date,
        head_response_model=latest_response_model,
        # Dummy path/method: the body is migrated outside of any real route.
        # NOTE(review): presumably chosen so that no path-based migration matches — confirm.
        path="\0\0\0",
        method="GET",
    )

    closest_version = versions._get_closest_lesser_version(target_date)
    generator = generate_versioned_models(versions)[str(closest_version)]
    versioned_response_model: type[pydantic.BaseModel] = generator[latest_response_model]
    return versioned_response_model.model_validate(migrated_response.body)
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def _unwrap_model(model: type[_T_ANY_MODEL]) -> type[_T_ANY_MODEL]:
    """Follow ``__cadwyn_original_model__`` links back to the user-defined class.

    Generated copies store the class they were produced from in ``__cadwyn_original_model__``;
    a class without that attribute is returned unchanged.
    """
    if not hasattr(model, "__cadwyn_original_model__"):
        return model
    return _unwrap_model(model.__cadwyn_original_model__)  # pyright: ignore[reportAttributeAccessIssue]
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
@dataclasses.dataclass(slots=True, kw_only=True)
|
|
189
|
+
class _ValidatorWrapper:
|
|
190
|
+
kwargs: dict[str, Any]
|
|
191
|
+
func: Callable
|
|
192
|
+
decorator: Callable
|
|
193
|
+
is_deleted: bool = False
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
@dataclasses.dataclass(kw_only=True, slots=True)
class _PerFieldValidatorWrapper(_ValidatorWrapper):
    """A validator bound to specific fields, re-created as ``decorator(*fields, **kwargs)(func)``."""

    # Names of the fields the validator is attached to
    fields: list[str]
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def _wrap_validator(func: Callable, is_pydantic_v1_style_validator: Any, decorator_info: _decorators.DecoratorInfo):
    """Turn a pydantic decorator (function + decorator info) into a re-applicable wrapper.

    Args:
        func: the decorated function as stored by pydantic.
        is_pydantic_v1_style_validator: the decorator's ``shim``; truthy for pydantic-v1-style
            decorators. It is forwarded to ``fully_unwrap_decorator`` and, when truthy, ``mode``
            is translated into the v1 ``pre`` flag.
        decorator_info: pydantic's ``DecoratorInfo`` describing how ``func`` was decorated.

    Returns:
        ``_PerFieldValidatorWrapper`` when the decorator info carries a non-None ``fields`` entry,
        ``_ValidatorWrapper`` otherwise.

    Raises:
        KeyError: if ``type(decorator_info)`` is not in ``PYDANTIC_DECORATOR_TYPE_TO_DECORATOR_MAP``.
    """
    # This is only for pydantic v1 style validators
    unwrapped = fully_unwrap_decorator(func, is_pydantic_v1_style_validator)
    if inspect.ismethod(unwrapped):
        unwrapped = unwrapped.__func__

    decorator_kwargs = dataclasses.asdict(decorator_info)
    field_names = decorator_kwargs.pop("fields", None)
    decorator = PYDANTIC_DECORATOR_TYPE_TO_DECORATOR_MAP[type(decorator_info)]
    if is_pydantic_v1_style_validator:
        # There's an inconsistency in their interfaces so we gotta resort to this
        decorator_kwargs["pre"] = decorator_kwargs.pop("mode", "after") != "after"

    if field_names is None:
        return _ValidatorWrapper(func=unwrapped, decorator=decorator, kwargs=decorator_kwargs)
    return _PerFieldValidatorWrapper(
        func=unwrapped, fields=list(field_names), decorator=decorator, kwargs=decorator_kwargs
    )
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def _is_dunder(attr_name: str):
|
|
222
|
+
return attr_name.startswith("__") and attr_name.endswith("__")
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def _wrap_pydantic_model(model: type[_T_PYDANTIC_MODEL]) -> "_PydanticRuntimeModelWrapper[_T_PYDANTIC_MODEL]":
    """Capture the parts of ``model`` that versioning can alter into a mutable wrapper.

    Only what ``model`` defines itself (entries of its own ``__dict__`` and ``__annotations__``)
    is captured; inherited fields and validators belong to the wrappers of its parents.

    Args:
        model: the pydantic model class to wrap.

    Returns:
        A ``_PydanticRuntimeModelWrapper`` holding the model's own fields, validators,
        annotations and remaining class attributes.
    """
    validators = {}
    for decorator_wrapper in _get_model_decorators(model):
        # Skip decorators that are inherited rather than defined on this class
        if decorator_wrapper.cls_var_name not in model.__dict__:
            continue

        validators[decorator_wrapper.cls_var_name] = _wrap_validator(
            decorator_wrapper.func, decorator_wrapper.shim, decorator_wrapper.info
        )

    model_fields = model.model_fields
    # `__annotations__` also lists names that are not pydantic fields (e.g. ClassVars and private
    # attributes). They have no entry in `model_fields`, so indexing it for them would raise KeyError.
    # Their annotations are still carried over through `annotations` below.
    fields = {
        field_name: PydanticFieldWrapper(model_fields[field_name], model.__annotations__[field_name])
        for field_name in model.__annotations__
        if field_name in model_fields
    }

    main_attributes = fields | validators
    other_attributes = {
        attr_name: attr_val
        for attr_name, attr_val in model.__dict__.items()
        if attr_name not in main_attributes
        and not (_is_dunder(attr_name) or attr_name in {"_abc_impl", "model_fields", "model_computed_fields"})
    }
    other_attributes |= {
        "model_config": model.model_config,
        "__module__": model.__module__,
        "__qualname__": model.__qualname__,
    }
    return _PydanticRuntimeModelWrapper(
        model,
        name=model.__name__,
        doc=model.__doc__,
        fields=fields,
        other_attributes=other_attributes,
        validators=validators,
        annotations=model.__annotations__.copy(),
    )
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
@final
@dataclasses.dataclass(slots=True)
class _PydanticRuntimeModelWrapper(Generic[_T_PYDANTIC_MODEL]):
    """Mutable description of a pydantic model from which per-version copies are generated.

    Version changes edit ``fields``, ``validators`` and ``annotations``. ``generate_model_copy``
    then builds a new class from them, whose ``__cadwyn_original_model__`` points back to ``cls``.
    """

    # The original (user-defined) model class
    cls: type[_T_PYDANTIC_MODEL]
    # Class name given to generated copies
    name: str
    # Docstring given to generated copies
    doc: str | None
    fields: Annotated[
        dict["_FieldName", PydanticFieldWrapper],
        Doc(
            "Fields that belong to this model, not to its parents. I.e. The ones that were either defined or overriden "
        ),
    ]
    # Validators/serializers defined on this model itself, keyed by class attribute name
    validators: dict[str, _PerFieldValidatorWrapper | _ValidatorWrapper]
    # Remaining class attributes that are copied into generated classes
    other_attributes: dict[str, Any]
    # This model's own annotations (not its parents')
    annotations: dict[str, Any]
    # Filled lazily by `_get_parents`
    _parents: list[Self] | None = dataclasses.field(init=False, default=None)

    def __post_init__(self):
        # This isn't actually supposed to run, it's just a precaution
        while hasattr(self.cls, "__cadwyn_original_model__"):  # pragma: no cover
            self.cls = self.cls.__cadwyn_original_model__  # pyright: ignore[reportAttributeAccessIssue]

        for key, annotation in list(self.annotations.items()):
            if get_origin(annotation) == Annotated:
                origin_type, *metadata = get_args(annotation)
                # Annotated cannot be copied and is cached based on "==" and "hash", while annotated_types.Interval are
                # frozen and so are consistently hashed
                self.annotations[key] = _AnnotatedAlias(
                    copy.deepcopy(origin_type), tuple(copy.deepcopy(item) for item in metadata)
                )

    def __deepcopy__(self, memo: dict[int, Any]):
        """Copy the mutable parts; ``cls`` is shared and the cached parents are not carried over."""
        duplicate = _PydanticRuntimeModelWrapper(
            self.cls,
            name=self.name,
            doc=self.doc,
            fields=copy.deepcopy(self.fields),
            validators=copy.deepcopy(self.validators),
            other_attributes=copy.deepcopy(self.other_attributes),
            annotations=copy.deepcopy(self.annotations),
        )
        memo[id(self)] = duplicate
        return duplicate

    def _get_parents(self, schemas: "dict[type, Self]"):
        """Return wrappers for every pydantic ancestor of ``cls``, in MRO order.

        Ancestors found in ``schemas`` reuse that wrapper; other ``BaseModel`` subclasses are
        wrapped on the fly. The result is cached on the instance after the first call.
        """
        if self._parents is None:
            collected = []
            for ancestor in self.cls.mro()[1:]:
                if ancestor in schemas:
                    collected.append(schemas[ancestor])
                elif issubclass(ancestor, BaseModel):
                    collected.append(_wrap_pydantic_model(ancestor))
            self._parents = collected
        return self._parents

    def _get_defined_fields_through_mro(self, schemas: "dict[type, Self]") -> dict[str, PydanticFieldWrapper]:
        """Merge fields from the most distant ancestor down to this model; nearer definitions win."""
        merged: dict[str, PydanticFieldWrapper] = {}
        for parent in reversed(self._get_parents(schemas)):
            merged.update(parent.fields)
        merged.update(self.fields)
        return merged

    def _get_defined_annotations_through_mro(self, schemas: "dict[type, Self]") -> dict[str, Any]:
        """Merge annotations from the most distant ancestor down to this model; nearer ones win."""
        merged: dict[str, Any] = {}
        for parent in reversed(self._get_parents(schemas)):
            merged.update(parent.annotations)
        merged.update(self.annotations)
        return merged

    def generate_model_copy(self, generator: "SchemaGenerator") -> type[_T_PYDANTIC_MODEL]:
        """Build this version's class from the wrapper's current state.

        The metaclass of ``cls`` is called with ``name``, the bases of ``cls`` mapped through
        ``generator``, and a namespace made of the remaining attributes, the non-deleted
        validators (per-field ones placed before model-wide ones), regenerated fields and
        version-converted annotations.
        """
        active_validators = [(name, validator) for name, validator in self.validators.items() if not validator.is_deleted]
        per_field_validators = {
            name: validator.decorator(*validator.fields, **validator.kwargs)(validator.func)
            for name, validator in active_validators
            if type(validator) is _PerFieldValidatorWrapper
        }
        root_validators = {
            name: validator.decorator(**validator.kwargs)(validator.func)
            for name, validator in active_validators
            if type(validator) is _ValidatorWrapper
        }
        fields = {name: field.generate_field_copy(generator) for name, field in self.fields.items()}

        metaclass = type(self.cls)
        # Bases are resolved before the annotations are converted, matching the original evaluation order
        bases = tuple(generator[base] for base in self.cls.__bases__)
        namespace = (
            self.other_attributes
            | per_field_validators
            | root_validators
            | fields
            | {
                "__annotations__": generator.annotation_transformer.change_version_of_annotation(self.annotations),
                "__doc__": self.doc,
                "__qualname__": self.cls.__qualname__.removesuffix(self.cls.__name__) + self.name,
            }
        )
        model_copy = metaclass(self.name, bases, namespace)

        model_copy.__cadwyn_original_model__ = self.cls
        return model_copy
|
|
362
|
+
|
|
363
|
+
|
|
364
|
+
class _CallableWrapper:
|
|
365
|
+
"""__eq__ and __hash__ are needed to make sure that dependency overrides work correctly.
|
|
366
|
+
They are based on putting dependencies (functions) as keys for the dictionary so if we want to be able to
|
|
367
|
+
override the wrapper, we need to make sure that it is equivalent to the original in __hash__ and __eq__
|
|
368
|
+
"""
|
|
369
|
+
|
|
370
|
+
def __init__(self, original_callable: Callable) -> None:
|
|
371
|
+
super().__init__()
|
|
372
|
+
self._original_callable = original_callable
|
|
373
|
+
functools.update_wrapper(self, original_callable)
|
|
374
|
+
|
|
375
|
+
@property
|
|
376
|
+
def __globals__(self):
|
|
377
|
+
"""FastAPI uses __globals__ to resolve forward references in type hints
|
|
378
|
+
It's supposed to be an attribute on the function but we use it as property to prevent python
|
|
379
|
+
from trying to pickle globals when we deepcopy this wrapper
|
|
380
|
+
"""
|
|
381
|
+
#
|
|
382
|
+
return self._original_callable.__globals__
|
|
383
|
+
|
|
384
|
+
def __call__(self, *args: Any, **kwargs: Any):
|
|
385
|
+
return self._original_callable(*args, **kwargs)
|
|
386
|
+
|
|
387
|
+
def __hash__(self):
|
|
388
|
+
return hash(self._original_callable)
|
|
389
|
+
|
|
390
|
+
def __eq__(self, value: object) -> bool:
|
|
391
|
+
return self._original_callable == value # pyright: ignore[reportUnnecessaryComparison]
|
|
392
|
+
|
|
393
|
+
|
|
394
|
+
class _AsyncCallableWrapper(_CallableWrapper):
    """Coroutine-function flavour of ``_CallableWrapper``: awaits the wrapped callable."""

    async def __call__(self, *args: Any, **kwargs: Any):
        wrapped = self._original_callable
        return await wrapped(*args, **kwargs)
|
|
397
|
+
|
|
398
|
+
|
|
399
|
+
@final
class _AnnotationTransformer:
    """Rewrites annotations, routes and callables so that they reference one version's generated models.

    Pydantic model and enum classes are swapped for ``generator[cls]``. Containers, generic aliases,
    unions, ``Depends``/``Security`` markers and user callables are rebuilt recursively.
    """

    def __init__(self, generator: "SchemaGenerator") -> None:
        # This cache is not here for speeding things up. It's for preventing the creation of copies of the same object
        # because such copies could produce weird behaviors at runtime, especially if you/fastapi do any comparisons.
        # It's defined here and not on the method because of this: https://youtu.be/sVjtp6tGo0g
        self.generator = generator
        self.change_versions_of_a_non_container_annotation = functools.cache(
            self._change_version_of_a_non_container_annotation
        )

    def change_version_of_annotation(self, annotation: Any) -> Any:
        """Recursively go through all annotations and change them to the
        annotations corresponding to the version passed.

        So if we had an annotation "UserResponse" from "head" version, and we passed version of "2022-11-16", it would
        replace "UserResponse" with the same class but from the "2022-11-16" version.

        Dicts (keys and values), lists and tuples are rebuilt element by element; every other value
        goes through the memoized non-container conversion.
        """
        if isinstance(annotation, dict):
            return {
                self.change_version_of_annotation(key): self.change_version_of_annotation(value)
                for key, value in annotation.items()
            }

        elif isinstance(annotation, list | tuple):
            return type(annotation)(self.change_version_of_annotation(v) for v in annotation)
        else:
            return self.change_versions_of_a_non_container_annotation(annotation)

    def migrate_router_to_version(self, router: fastapi.routing.APIRouter):
        """Migrate every ``APIRoute`` of ``router`` in place; other route types are left untouched."""
        for route in router.routes:
            if not isinstance(route, fastapi.routing.APIRoute):
                continue
            self.migrate_route_to_version(route)

    def migrate_route_to_version(self, route: fastapi.routing.APIRoute, *, ignore_response_model: bool = False):
        """Rewrite ``route`` (and its ``APIRoute`` callbacks) in place so that it uses this version's models.

        Args:
            route: the route to mutate.
            ignore_response_model: when True, the response model and response fields are left as they are.
        """
        if route.response_model is not None and not ignore_response_model:
            route.response_model = self.change_version_of_annotation(route.response_model)
            route.response_field = fastapi.utils.create_response_field(
                name="Response_" + route.unique_id,
                type_=route.response_model,
                mode="serialization",
            )
            route.secure_cloned_response_field = fastapi.utils.create_cloned_field(route.response_field)
        route.dependencies = self.change_version_of_annotation(route.dependencies)
        route.endpoint = self.change_version_of_annotation(route.endpoint)
        for callback in route.callbacks or []:
            if not isinstance(callback, fastapi.routing.APIRoute):
                continue
            self.migrate_route_to_version(callback, ignore_response_model=ignore_response_model)
        # The dependant/body field were computed from the old endpoint and dependencies, so rebuild them
        self._remake_endpoint_dependencies(route)

    def _change_version_of_a_non_container_annotation(self, annotation: Any) -> Any:
        """Convert one non-container annotation; memoized per instance in ``__init__``."""
        if isinstance(annotation, _BaseGenericAlias | types.GenericAlias):
            return get_origin(annotation)[tuple(self.change_version_of_annotation(arg) for arg in get_args(annotation))]
        elif isinstance(annotation, fastapi.params.Security):
            # Security is a subclass of Depends, so it must be matched first: rebuilding it as a plain
            # Depends would silently drop its scopes.
            return fastapi.params.Security(
                self.change_version_of_annotation(annotation.dependency),
                scopes=annotation.scopes,
                use_cache=annotation.use_cache,
            )
        elif isinstance(annotation, fastapi.params.Depends):
            return fastapi.params.Depends(
                self.change_version_of_annotation(annotation.dependency),
                use_cache=annotation.use_cache,
            )
        elif isinstance(annotation, UnionType):
            getitem = typing.Union.__getitem__  # pyright: ignore[reportAttributeAccessIssue]
            return getitem(
                tuple(self.change_version_of_annotation(a) for a in get_args(annotation)),
            )
        elif annotation is Any or isinstance(annotation, typing.NewType):
            return annotation
        elif isinstance(annotation, type):
            return self._change_version_of_type(annotation)
        elif callable(annotation):
            # Objects whose type comes from the frameworks themselves, and security schemes, are kept as-is
            if type(annotation).__module__.startswith(
                ("fastapi.", "pydantic.", "pydantic_core.", "starlette.")
            ) or isinstance(annotation, fastapi.params.Security | fastapi.security.base.SecurityBase):
                return annotation

            def modifier(annotation: Any):
                return self.change_version_of_annotation(annotation)

            return self._modify_callable_annotations(
                annotation,
                modifier,
                modifier,
                annotation_modifying_wrapper_factory=self._copy_function_through_class_based_wrapper,
            )
        else:
            return annotation

    def _change_version_of_type(self, annotation: type):
        """Swap pydantic models and enums for this version's copies; other classes are returned unchanged."""
        if issubclass(annotation, BaseModel | Enum):
            return self.generator[annotation]
        else:
            return annotation

    @classmethod
    def _remake_endpoint_dependencies(cls, route: fastapi.routing.APIRoute):
        """Recompute ``route.dependant`` and ``route.body_field`` from the route's current endpoint/dependencies."""
        # Unlike get_dependant, APIRoute is the public API of FastAPI and it's (almost) guaranteed to be stable.

        route_copy = fastapi.routing.APIRoute(route.path, route.endpoint, dependencies=route.dependencies)
        route.dependant = route_copy.dependant
        route.body_field = route_copy.body_field
        _add_request_and_response_params(route)

    @classmethod
    def _modify_callable_annotations(  # pragma: no branch # because of lambdas
        cls,
        call: _Call,
        modify_annotations: Callable[[dict[str, Any]], dict[str, Any]] = lambda a: a,
        modify_defaults: Callable[[tuple[Any, ...]], tuple[Any, ...]] = lambda a: a,
        *,
        annotation_modifying_wrapper_factory: Callable[[_Call], _Call],
    ) -> _Call:
        """Return a wrapper around ``call`` with transformed annotations, defaults and signature.

        Args:
            call: the callable to wrap.
            modify_annotations: applied to the wrapper's ``__annotations__`` dict.
            modify_defaults: applied to the tuple of parameter defaults of ``call``, in parameter order.
            annotation_modifying_wrapper_factory: builds the wrapper object from ``call``.
        """
        annotation_modifying_wrapper = annotation_modifying_wrapper_factory(call)
        old_params = inspect.signature(call).parameters
        callable_annotations = annotation_modifying_wrapper.__annotations__
        annotation_modifying_wrapper.__annotations__ = modify_annotations(callable_annotations)
        annotation_modifying_wrapper.__defaults__ = modify_defaults(
            tuple(p.default for p in old_params.values() if p.default is not inspect.Signature.empty),
        )
        annotation_modifying_wrapper.__signature__ = cls._generate_signature(
            annotation_modifying_wrapper,
            old_params,
        )

        return annotation_modifying_wrapper

    @staticmethod
    def _generate_signature(
        new_callable: Callable,
        old_params: types.MappingProxyType[str, inspect.Parameter],
    ):
        """Rebuild a signature with the names/kinds of ``old_params``.

        Defaults are taken, in order, from ``new_callable.__defaults__`` for every parameter that
        had a default; annotations (and the return annotation) come from
        ``new_callable.__annotations__``.
        """
        parameters = []
        default_counter = 0
        for param in old_params.values():
            if param.default is not inspect.Signature.empty:
                assert new_callable.__defaults__ is not None, (  # noqa: S101
                    "Defaults cannot be None here. If it is, you have found a bug in Cadwyn. "
                    "Please, report it in our issue tracker."
                )
                default = new_callable.__defaults__[default_counter]
                default_counter += 1
            else:
                default = inspect.Signature.empty
            parameters.append(
                inspect.Parameter(
                    param.name,
                    param.kind,
                    default=default,
                    annotation=new_callable.__annotations__.get(
                        param.name,
                        inspect.Signature.empty,
                    ),
                ),
            )
        return inspect.Signature(
            parameters=parameters,
            return_annotation=new_callable.__annotations__.get(
                "return",
                inspect.Signature.empty,
            ),
        )

    @classmethod
    def _copy_function_through_class_based_wrapper(cls, call: Any):
        """Separate from copy_endpoint because endpoints MUST be functions in FastAPI, they cannot be cls instances"""
        call = cls._unwrap_callable(call)

        if inspect.iscoroutinefunction(call):
            return _AsyncCallableWrapper(call)
        else:
            return _CallableWrapper(call)

    @staticmethod
    def _unwrap_callable(call: Any) -> Any:
        """Strip Cadwyn callable wrappers; instances of regular classes are replaced by their ``__call__``."""
        while hasattr(call, "_original_callable"):
            call = call._original_callable
        if not isinstance(call, types.FunctionType | types.MethodType):
            # This means that the callable is actually an instance of a regular class
            call = call.__call__

        return call
|
|
580
|
+
|
|
581
|
+
|
|
582
|
+
def _add_request_and_response_params(route: APIRoute):
    """Ensure the route's dependant records both a request and a response parameter name.

    Names the endpoint already declared are kept; only missing (falsy) ones are filled with
    Cadwyn's private parameter names.
    """
    dependant = route.dependant
    fallbacks = (
        ("request_param_name", _CADWYN_REQUEST_PARAM_NAME),
        ("response_param_name", _CADWYN_RESPONSE_PARAM_NAME),
    )
    for attribute, fallback in fallbacks:
        if not getattr(dependant, attribute):
            setattr(dependant, attribute, fallback)
|
|
587
|
+
|
|
588
|
+
|
|
589
|
+
@final
class SchemaGenerator:
    """Generates and caches the copies of pydantic models and enums for a single API version.

    ``generator[cls]`` returns the copy of ``cls`` for this version. Objects that are not classes,
    classes that are neither pydantic models nor enums, and ``BaseModel``/``RootModel`` themselves
    are returned unchanged.
    """

    __slots__ = "annotation_transformer", "model_bundle", "concrete_models"

    def __init__(self, model_bundle: _ModelBundle) -> None:
        self.annotation_transformer = _AnnotationTransformer(self)
        self.model_bundle = model_bundle
        # Must exist before generation starts: generating one model can recursively request others
        # (its bases, types used in its annotations) through __getitem__, which reads and fills this cache.
        self.concrete_models = {}
        # Generate through __getitem__ so every model/enum is built exactly once. A model that was already
        # produced as a dependency of another one is reused instead of being copied a second time, so e.g.
        # `self[Parent]` is the very class that `self[Child]` inherits from. (`|` builds a new dict, so
        # wrappers registered during generation do not disturb this iteration.)
        for model in self.model_bundle.schemas | self.model_bundle.enums:
            self[model]  # noqa: B018 # generates and caches the copy

    def __getitem__(self, model: type, /) -> Any:
        """Return this version's copy of ``model``, generating and caching it on first access."""
        if not isinstance(model, type) or not issubclass(model, BaseModel | Enum) or model in (BaseModel, RootModel):
            return model
        model = _unwrap_model(model)

        if model in self.concrete_models:
            return self.concrete_models[model]

        model_copy = self._get_wrapper_for_model(model).generate_model_copy(self)
        self.concrete_models[model] = model_copy
        return model_copy

    def _get_wrapper_for_model(
        self, model: type[BaseModel | Enum]
    ) -> "_PydanticRuntimeModelWrapper[BaseModel] | _EnumWrapper[Enum]":
        """Return the bundle's wrapper for ``model``, wrapping and registering it first if it is missing."""
        model = _unwrap_model(model)

        if model in self.model_bundle.schemas:
            return self.model_bundle.schemas[model]
        elif model in self.model_bundle.enums:
            return self.model_bundle.enums[model]

        if issubclass(model, BaseModel):
            wrapper = _wrap_pydantic_model(model)
            self.model_bundle.schemas[model] = wrapper
        elif issubclass(model, Enum):
            wrapper = _EnumWrapper(model)
            self.model_bundle.enums[model] = wrapper
        else:
            assert_never(model)
        return wrapper
|
|
636
|
+
|
|
637
|
+
|
|
638
|
+
@cache
def generate_versioned_models(versions: "VersionBundle") -> "dict[str, SchemaGenerator]":
    """Build one ``SchemaGenerator`` per version of ``versions``, keyed by ``str(version.value)``.

    The head version's changes are applied to freshly wrapped models first. Then, walking
    ``versions.versions`` in order, each version gets a deep-copied snapshot of the models as they
    are at that point, after which that version's own changes are applied for the next one.
    Results are memoized per ``VersionBundle`` by ``functools.cache``.
    """
    models = _create_model_bundle(versions)
    _migrate_classes(
        _RuntimeSchemaGenContext(current_version=versions.head_version, models=models, version_bundle=versions)
    )

    generators_by_version: dict[str, SchemaGenerator] = {}
    for version in versions.versions:
        context = _RuntimeSchemaGenContext(current_version=version, models=models, version_bundle=versions)
        generators_by_version[str(version.value)] = SchemaGenerator(copy.deepcopy(models))
        # note that the last migration will not contain any version changes so we don't need to save the results
        _migrate_classes(context)
    return generators_by_version
|
|
653
|
+
|
|
654
|
+
|
|
655
|
+
def _create_model_bundle(versions: "VersionBundle"):
    """Wrap every versioned enum and schema registered in ``versions`` into a fresh ``_ModelBundle``."""
    wrapped_enums = {enum: _EnumWrapper(enum) for enum in versions.versioned_enums.values()}
    wrapped_schemas = {schema: _wrap_pydantic_model(schema) for schema in versions.versioned_schemas.values()}
    return _ModelBundle(enums=wrapped_enums, schemas=wrapped_schemas)
|
|
660
|
+
|
|
661
|
+
|
|
662
|
+
def _migrate_classes(context: _RuntimeSchemaGenContext) -> None:
    """Apply every version change of ``context.current_version`` to the bundle in ``context.models``.

    For each version change, its schema instructions are applied first and then its enum instructions.
    The wrapped schemas and enums are mutated in place.
    """
    wrapped_schemas = context.models.schemas
    wrapped_enums = context.models.enums
    for change in context.current_version.version_changes:
        change_name = change.__name__
        _apply_alter_schema_instructions(wrapped_schemas, change.alter_schema_instructions, change_name)
        _apply_alter_enum_instructions(wrapped_enums, change.alter_enum_instructions, change_name)
|
|
674
|
+
|
|
675
|
+
|
|
676
|
+
def _apply_alter_schema_instructions(
|
|
677
|
+
modified_schemas: dict[type, _PydanticRuntimeModelWrapper],
|
|
678
|
+
alter_schema_instructions: Sequence[AlterSchemaSubInstruction | SchemaHadInstruction],
|
|
679
|
+
version_change_name: str,
|
|
680
|
+
) -> None:
|
|
681
|
+
for alter_schema_instruction in alter_schema_instructions:
|
|
682
|
+
schema_info = modified_schemas[alter_schema_instruction.schema]
|
|
683
|
+
if isinstance(alter_schema_instruction, FieldExistedAsInstruction):
|
|
684
|
+
_add_field_to_model(schema_info, modified_schemas, alter_schema_instruction, version_change_name)
|
|
685
|
+
elif isinstance(alter_schema_instruction, FieldHadInstruction | FieldDidntHaveInstruction):
|
|
686
|
+
_change_field_in_model(
|
|
687
|
+
schema_info,
|
|
688
|
+
modified_schemas,
|
|
689
|
+
alter_schema_instruction,
|
|
690
|
+
version_change_name,
|
|
691
|
+
)
|
|
692
|
+
elif isinstance(alter_schema_instruction, FieldDidntExistInstruction):
|
|
693
|
+
_delete_field_from_model(schema_info, alter_schema_instruction.name, version_change_name)
|
|
694
|
+
elif isinstance(alter_schema_instruction, ValidatorExistedInstruction):
|
|
695
|
+
validator_name = alter_schema_instruction.validator.__name__
|
|
696
|
+
raw_validator = cast(
|
|
697
|
+
pydantic._internal._decorators.PydanticDescriptorProxy, alter_schema_instruction.validator
|
|
698
|
+
)
|
|
699
|
+
schema_info.validators[validator_name] = _wrap_validator(
|
|
700
|
+
raw_validator.wrapped,
|
|
701
|
+
is_pydantic_v1_style_validator=raw_validator.shim,
|
|
702
|
+
decorator_info=raw_validator.decorator_info,
|
|
703
|
+
)
|
|
704
|
+
elif isinstance(alter_schema_instruction, ValidatorDidntExistInstruction):
|
|
705
|
+
if alter_schema_instruction.name not in schema_info.validators:
|
|
706
|
+
raise InvalidGenerationInstructionError(
|
|
707
|
+
f'You tried to delete a validator "{alter_schema_instruction.name}" from "{schema_info.name}" '
|
|
708
|
+
f'in "{version_change_name}" but it doesn\'t have such a validator.',
|
|
709
|
+
)
|
|
710
|
+
if schema_info.validators[alter_schema_instruction.name].is_deleted:
|
|
711
|
+
raise InvalidGenerationInstructionError(
|
|
712
|
+
f'You tried to delete a validator "{alter_schema_instruction.name}" from "{schema_info.name}" '
|
|
713
|
+
f'in "{version_change_name}" but it is already deleted.',
|
|
714
|
+
)
|
|
715
|
+
schema_info.validators[alter_schema_instruction.name].is_deleted = True
|
|
716
|
+
elif isinstance(alter_schema_instruction, SchemaHadInstruction):
|
|
717
|
+
_change_model(schema_info, alter_schema_instruction, version_change_name)
|
|
718
|
+
else:
|
|
719
|
+
assert_never(alter_schema_instruction)
|
|
720
|
+
|
|
721
|
+
|
|
722
|
+
def _apply_alter_enum_instructions(
|
|
723
|
+
enums: "dict[type, _EnumWrapper]",
|
|
724
|
+
alter_enum_instructions: Sequence[AlterEnumSubInstruction],
|
|
725
|
+
version_change_name: str,
|
|
726
|
+
):
|
|
727
|
+
for alter_enum_instruction in alter_enum_instructions:
|
|
728
|
+
enum = enums[alter_enum_instruction.enum]
|
|
729
|
+
if isinstance(alter_enum_instruction, EnumDidntHaveMembersInstruction):
|
|
730
|
+
for member in alter_enum_instruction.members:
|
|
731
|
+
if member not in enum.members:
|
|
732
|
+
raise InvalidGenerationInstructionError(
|
|
733
|
+
f'You tried to delete a member "{member}" from "{enum.cls.__name__}" '
|
|
734
|
+
f'in "{version_change_name}" but it doesn\'t have such a member.',
|
|
735
|
+
)
|
|
736
|
+
enum.members.pop(member)
|
|
737
|
+
elif isinstance(alter_enum_instruction, EnumHadMembersInstruction):
|
|
738
|
+
for member, member_value in alter_enum_instruction.members.items():
|
|
739
|
+
if member in enum.members and enum.members[member] == member_value:
|
|
740
|
+
raise InvalidGenerationInstructionError(
|
|
741
|
+
f'You tried to add a member "{member}" to "{enum.cls.__name__}" '
|
|
742
|
+
f'in "{version_change_name}" but there is already a member with that name and value.',
|
|
743
|
+
)
|
|
744
|
+
enum.members[member] = member_value
|
|
745
|
+
else:
|
|
746
|
+
assert_never(alter_enum_instruction)
|
|
747
|
+
|
|
748
|
+
|
|
749
|
+
def _change_model(
    model: _PydanticRuntimeModelWrapper,
    alter_schema_instruction: SchemaHadInstruction,
    version_change_name: str,
):
    """Rename *model* to the name requested by a ``SchemaHadInstruction``.

    Raises:
        InvalidGenerationInstructionError: if *model* already has the requested name.
    """
    requested_name = alter_schema_instruction.name
    if requested_name == model.name:
        raise InvalidGenerationInstructionError(
            f'You tried to change the name of "{model.name}" in "{version_change_name}" '
            "but it already has the name you tried to assign.",
        )
    model.name = requested_name
|
|
761
|
+
|
|
762
|
+
|
|
763
|
+
def _add_field_to_model(
    model: _PydanticRuntimeModelWrapper,
    schemas: "dict[type, _PydanticRuntimeModelWrapper]",
    alter_schema_instruction: FieldExistedAsInstruction,
    version_change_name: str,
):
    """Add the field described by a ``FieldExistedAsInstruction`` to *model*.

    The field is stored in *model*'s own ``fields`` and ``annotations``. The annotation comes from the
    instruction's field.

    Raises:
        InvalidGenerationInstructionError: if a field with that name is already defined on *model* or on
            any of its parents within *schemas*.
    """
    field_name = alter_schema_instruction.name
    if field_name in model._get_defined_fields_through_mro(schemas):
        raise InvalidGenerationInstructionError(
            f'You tried to add a field "{field_name}" to "{model.name}" '
            f'in "{version_change_name}" but there is already a field with that name.',
        )

    new_field_info = alter_schema_instruction.field
    model.fields[field_name] = PydanticFieldWrapper(new_field_info, new_field_info.annotation)
    model.annotations[field_name] = new_field_info.annotation
|
|
779
|
+
|
|
780
|
+
|
|
781
|
+
def _change_field_in_model(
    model: _PydanticRuntimeModelWrapper,
    schemas: "dict[type, _PydanticRuntimeModelWrapper]",
    alter_schema_instruction: FieldHadInstruction | FieldDidntHaveInstruction,
    version_change_name: str,
):
    """Bring a (possibly inherited) field onto *model* and apply a FieldHad/FieldDidntHave instruction to it.

    The field and its annotation are looked up through *model*'s MRO within *schemas*. They are written
    into *model*'s own ``fields``/``annotations``, then altered:
    - via `_change_field` for ``FieldHadInstruction``;
    - via `_delete_field_attributes` otherwise.

    Raises:
        InvalidGenerationInstructionError: if the field is defined neither on *model* nor on its parents.
    """
    annotations_via_mro = model._get_defined_annotations_through_mro(schemas)
    fields_via_mro = model._get_defined_fields_through_mro(schemas)
    field_name = alter_schema_instruction.name
    if field_name not in fields_via_mro:
        raise InvalidGenerationInstructionError(
            f'You tried to change the field "{field_name}" from '
            f'"{model.name}" in "{version_change_name}" but it doesn\'t have such a field.',
        )

    # NOTE(review): for an inherited field this may be the same wrapper object the parent holds. In that
    # case the in-place changes below would also affect the parent. Confirm the MRO helper returns copies.
    target_field = fields_via_mro[field_name]
    model.fields[field_name] = target_field
    model.annotations[field_name] = annotations_via_mro[field_name]
    target_annotation = model.annotations[field_name]

    if not isinstance(alter_schema_instruction, FieldHadInstruction):
        _delete_field_attributes(
            model,
            alter_schema_instruction,
            version_change_name,
            target_field,
            target_annotation,
        )
        return

    # TODO: `_change_field` is not a descriptive name
    _change_field(
        model,
        alter_schema_instruction,
        version_change_name,
        annotations_via_mro,
        target_field,
        target_annotation,
    )
|
|
817
|
+
|
|
818
|
+
|
|
819
|
+
def _change_field(
    model: _PydanticRuntimeModelWrapper,
    alter_schema_instruction: FieldHadInstruction,
    version_change_name: str,
    defined_annotations: dict[str, Any],
    field: PydanticFieldWrapper,
    annotation: Any | None,
):
    """Apply a ``FieldHadInstruction`` (type, name and attribute changes) to *field* of *model*.

    The three kinds of change are applied in this order: type, then name, then attributes. A part of the
    instruction equal to ``Sentinel`` is skipped; presumably ``Sentinel`` means "not specified".

    Args:
        model: wrapper whose ``fields``/``annotations`` are updated in place.
        alter_schema_instruction: the instruction to apply.
        version_change_name: used only in error messages.
        defined_annotations: annotations visible on *model* through its MRO. Used as the fallback value
            when renaming a field whose annotation is missing from ``model.annotations``.
        field: the field wrapper being changed; mutated in place.
        annotation: not used by this function.

    Raises:
        InvalidGenerationInstructionError: if the instruction sets the type, the name or any attribute
            to the value the field already has.
    """
    # Type change. This runs before the rename, so the annotation written here under the old name is
    # the one the rename below moves to the new name.
    if alter_schema_instruction.type is not Sentinel:
        if field.annotation == alter_schema_instruction.type:
            raise InvalidGenerationInstructionError(
                f'You tried to change the type of field "{alter_schema_instruction.name}" to '
                f'"{alter_schema_instruction.type}" from "{model.name}" in "{version_change_name}" '
                f'but it already has type "{field.annotation}"',
            )
        field.annotation = alter_schema_instruction.type
        model.annotations[alter_schema_instruction.name] = alter_schema_instruction.type

    # Rename: move both the field and its annotation from the old key to the new one.
    if alter_schema_instruction.new_name is not Sentinel:
        if alter_schema_instruction.new_name == alter_schema_instruction.name:
            raise InvalidGenerationInstructionError(
                f'You tried to change the name of field "{alter_schema_instruction.name}" '
                f'from "{model.name}" in "{version_change_name}" '
                "but it already has that name.",
            )
        model.fields[alter_schema_instruction.new_name] = model.fields.pop(alter_schema_instruction.name)
        # The pop default is evaluated eagerly, so the old name must be present in `defined_annotations`.
        model.annotations[alter_schema_instruction.new_name] = model.annotations.pop(
            alter_schema_instruction.name,
            defined_annotations[alter_schema_instruction.name],
        )

    # Attribute changes: every dataclass field of `field_changes` whose value is not `Sentinel` is applied.
    for attr_name in alter_schema_instruction.field_changes.__dataclass_fields__:
        attr_value = getattr(alter_schema_instruction.field_changes, attr_name)
        if attr_value is not Sentinel:
            # Only attributes explicitly passed to the field are compared; an unset attribute never matches.
            if field.passed_field_attributes.get(attr_name, Sentinel) == attr_value:
                raise InvalidGenerationInstructionError(
                    f'You tried to change the attribute "{attr_name}" of field '
                    f'"{alter_schema_instruction.name}" '
                    f'from "{model.name}" to {attr_value!r} in "{version_change_name}" '
                    "but it already has that value.",
                )
            field.update_attribute(name=attr_name, value=attr_value)
|
|
861
|
+
|
|
862
|
+
|
|
863
|
+
def _delete_field_attributes(
    model: _PydanticRuntimeModelWrapper,
    alter_schema_instruction: FieldDidntHaveInstruction,
    version_change_name: str,
    field: PydanticFieldWrapper,
    annotation: Any,
) -> None:
    """Remove the attributes listed in a ``FieldDidntHaveInstruction`` from *field*.

    Each attribute is handled as follows:
    - If it is among the field's explicitly passed attributes, it is deleted there.
    - Otherwise, if *annotation* is ``Annotated[T, *metadata]``, every metadata object that has the
      attribute gets it set to ``None``.

    Only the metadata entries are inspected, never the base type ``T``. Builtin types expose attributes
    whose names collide with field attributes (e.g. ``str.title``), so checking ``T`` crashed with
    ``TypeError`` when trying to overwrite them, or clobbered class attributes of user types.

    Args:
        model: the schema wrapper; only its name is used, in error messages.
        alter_schema_instruction: provides the attribute names and the field name.
        version_change_name: used only in error messages.
        field: the field wrapper to modify in place.
        annotation: the field's current annotation.

    Raises:
        InvalidGenerationInstructionError: if an attribute is found neither among the passed field
            attributes nor on any ``Annotated`` metadata object.
    """
    for attr_name in alter_schema_instruction.attributes:
        if attr_name in field.passed_field_attributes:
            field.delete_attribute(name=attr_name)
            continue

        # get_args(Annotated[T, m1, m2, ...]) == (T, m1, m2, ...); index 0 is the type, not metadata.
        metadata = get_args(annotation)[1:] if get_origin(annotation) is Annotated else ()
        metadata_with_attr = [sub_ann for sub_ann in metadata if hasattr(sub_ann, attr_name)]
        if not metadata_with_attr:
            raise InvalidGenerationInstructionError(
                f'You tried to delete the attribute "{attr_name}" of field "{alter_schema_instruction.name}" '
                f'from "{model.name}" in "{version_change_name}" '
                "but it already doesn't have that attribute.",
            )
        for sub_ann in metadata_with_attr:
            # object.__setattr__ presumably bypasses a custom/frozen __setattr__ on metadata objects.
            object.__setattr__(sub_ann, attr_name, None)
|
|
885
|
+
|
|
886
|
+
|
|
887
|
+
def _delete_field_from_model(model: _PydanticRuntimeModelWrapper, field_name: str, version_change_name: str):
    """Remove *field_name* from *model*'s own fields and annotations.

    The field is also detached from every per-field validator that lists it. A per-field validator left
    with no fields is marked as deleted.

    Raises:
        InvalidGenerationInstructionError: if *model* itself (not its parents) has no such field.
    """
    if field_name not in model.fields:
        raise InvalidGenerationInstructionError(
            f'You tried to delete a field "{field_name}" from "{model.name}" '
            f'in "{version_change_name}" but it doesn\'t have such a field.',
        )
    del model.fields[field_name]
    del model.annotations[field_name]

    for validator in model.validators.values():
        if not isinstance(validator, _PerFieldValidatorWrapper) or field_name not in validator.fields:
            continue
        validator.fields.remove(field_name)
        # TODO: This behavior doesn't feel natural
        if not validator.fields:
            validator.is_deleted = True
|
|
901
|
+
|
|
902
|
+
|
|
903
|
+
class _DummyEnum(Enum):
|
|
904
|
+
pass
|
|
905
|
+
|
|
906
|
+
|
|
907
|
+
@final
class _EnumWrapper(Generic[_T_ENUM]):
    """Mutable description of an enum, used to build per-version copies of it.

    Holds the enum class (after `_unwrap_model`) and a ``member name -> value`` mapping. Version-change
    instructions add entries to or remove entries from this mapping. `generate_model_copy` turns the
    current state into a new enum class.
    """

    __slots__ = "cls", "members"

    def __init__(self, cls: type[_T_ENUM]):
        # `_unwrap_model` is defined elsewhere in this module. Presumably it maps a generated copy back to
        # its original class (cf. `__cadwyn_original_model__` set below); confirm.
        self.cls = _unwrap_model(cls)
        # NOTE(review): members are read from `cls` as passed in, not from the unwrapped `self.cls`.
        self.members = {member.name: member.value for member in cls}

    def __deepcopy__(self, memo: Any):
        # The member mapping is copied shallowly; the enum class itself is shared, not copied.
        result = _EnumWrapper(self.cls)
        result.members = self.members.copy()
        memo[id(self)] = result
        return result

    def generate_model_copy(self, generator: "SchemaGenerator") -> type[_T_ENUM]:
        """Create a new enum class from the current state of this wrapper.

        The copy keeps the original's name, bases and carried-over attributes, and uses the current
        ``members``. The result is linked back to ``self.cls`` via ``__cadwyn_original_model__``.
        *generator* is not used here.
        """
        enum_dict = Enum.__prepare__(self.cls.__name__, self.cls.__bases__)

        # Member values that are themselves enum instances are replaced by their raw `.value`.
        raw_member_map = {k: v.value if isinstance(v, Enum) else v for k, v in self.members.items()}
        # Non-member attributes are inserted first and members after them (`|` keeps left-hand keys first).
        initialization_namespace = self._get_initialization_namespace_for_enum(self.cls) | raw_member_map
        for attr_name, attr in initialization_namespace.items():
            enum_dict[attr_name] = attr
        # type() with enum bases dispatches to the most derived metaclass of those bases, which builds the enum.
        model_copy = cast(type[_T_ENUM], type(self.cls.__name__, self.cls.__bases__, enum_dict))
        model_copy.__cadwyn_original_model__ = self.cls  # pyright: ignore[reportAttributeAccessIssue]
        return model_copy

    @staticmethod
    def _get_initialization_namespace_for_enum(enum_cls: type[Enum]):
        """Return the entries of ``enum_cls.__dict__`` that should be carried over to a copy.

        An entry is excluded when any of these hold:
        - it is one of the enum's members;
        - its key also appears in ``_DummyEnum.__dict__`` (presumably the per-class bookkeeping that the
          enum metaclass regenerates);
        - its value is the very same object found under that key in the merged namespaces of the ancestors.

        NOTE(review): the `_DummyEnum` filter also drops keys such as ``__module__``. Confirm that the copy
        not receiving the original's value is intended.
        """
        mro_without_the_class_itself = enum_cls.mro()[1:]

        # Merge ancestor namespaces from the base-most class down, so nearer ancestors override farther ones.
        mro_dict = {}
        for cls in reversed(mro_without_the_class_itself):
            mro_dict.update(cls.__dict__)

        return {
            k: v
            for k, v in enum_cls.__dict__.items()
            if k not in enum_cls._member_names_
            and k not in _DummyEnum.__dict__
            and (k not in mro_dict or mro_dict[k] is not v)
        }
|