cadwyn 3.15.9__py3-none-any.whl → 4.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cadwyn might be problematic. Click here for more details.

@@ -0,0 +1,946 @@
1
+ import copy
2
+ import dataclasses
3
+ import functools
4
+ import inspect
5
+ import types
6
+ import typing
7
+ from collections.abc import Callable, Sequence
8
+ from datetime import date
9
+ from enum import Enum
10
+ from functools import cache
11
+ from typing import (
12
+ TYPE_CHECKING,
13
+ Annotated,
14
+ Any,
15
+ Generic,
16
+ TypeAlias,
17
+ TypeVar,
18
+ _BaseGenericAlias, # pyright: ignore[reportAttributeAccessIssue]
19
+ cast,
20
+ final,
21
+ get_args,
22
+ get_origin,
23
+ )
24
+
25
+ import fastapi.params
26
+ import fastapi.security.base
27
+ import fastapi.utils
28
+ import pydantic
29
+ import pydantic._internal._decorators
30
+ from fastapi import Response
31
+ from fastapi.routing import APIRoute
32
+ from issubclass import issubclass
33
+ from pydantic import BaseModel, Field, RootModel
34
+ from pydantic._internal import _decorators
35
+ from pydantic._internal._decorators import (
36
+ FieldSerializerDecoratorInfo,
37
+ FieldValidatorDecoratorInfo,
38
+ ModelSerializerDecoratorInfo,
39
+ ModelValidatorDecoratorInfo,
40
+ RootValidatorDecoratorInfo,
41
+ ValidatorDecoratorInfo,
42
+ )
43
+ from pydantic.fields import ComputedFieldInfo, FieldInfo
44
+ from typing_extensions import Doc, Self, _AnnotatedAlias, assert_never
45
+
46
+ from cadwyn._utils import Sentinel, UnionType, fully_unwrap_decorator
47
+ from cadwyn.exceptions import InvalidGenerationInstructionError
48
+ from cadwyn.structure.common import VersionDate
49
+ from cadwyn.structure.data import ResponseInfo
50
+ from cadwyn.structure.enums import AlterEnumSubInstruction, EnumDidntHaveMembersInstruction, EnumHadMembersInstruction
51
+ from cadwyn.structure.schemas import (
52
+ AlterSchemaSubInstruction,
53
+ FieldDidntExistInstruction,
54
+ FieldDidntHaveInstruction,
55
+ FieldExistedAsInstruction,
56
+ FieldHadInstruction,
57
+ SchemaHadInstruction,
58
+ ValidatorDidntExistInstruction,
59
+ ValidatorExistedInstruction,
60
+ _get_model_decorators,
61
+ )
62
+ from cadwyn.structure.versions import _CADWYN_REQUEST_PARAM_NAME, _CADWYN_RESPONSE_PARAM_NAME, VersionBundle
63
+
64
+ if TYPE_CHECKING:
65
+ from cadwyn.structure.versions import HeadVersion, Version, VersionBundle
66
+
67
_Call = TypeVar("_Call", bound=Callable[..., Any])

_FieldName: TypeAlias = str
_T_ANY_MODEL = TypeVar("_T_ANY_MODEL", bound=BaseModel | Enum)
_T_ENUM = TypeVar("_T_ENUM", bound=Enum)

_T_PYDANTIC_MODEL = TypeVar("_T_PYDANTIC_MODEL", bound=BaseModel)
# Maps the type of each pydantic decorator-info object to the public decorator that creates it, so a
# captured validator/serializer can be re-applied with `decorator(...)(func)` on a generated model.
PYDANTIC_DECORATOR_TYPE_TO_DECORATOR_MAP = {
    ValidatorDecoratorInfo: pydantic.validator,  # pyright: ignore[reportDeprecated]
    FieldValidatorDecoratorInfo: pydantic.field_validator,
    FieldSerializerDecoratorInfo: pydantic.field_serializer,
    RootValidatorDecoratorInfo: pydantic.root_validator,  # pyright: ignore[reportDeprecated]
    ModelValidatorDecoratorInfo: pydantic.model_validator,
    ModelSerializerDecoratorInfo: pydantic.model_serializer,
    ComputedFieldInfo: pydantic.computed_field,
}


VALIDATOR_CONFIG_KEY = "__validators__"
# Alphabetically sorted names of the `pydantic.Field` parameters that can be passed by keyword.
# `inspect.Parameter.KEYWORD_ONLY` / `POSITIONAL_OR_KEYWORD` are the public spellings of the
# private `inspect._ParameterKind` members (they are the very same objects).
_all_field_arg_names = sorted(
    name
    for name, param in inspect.signature(Field).parameters.items()
    if param.kind in {inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD}
)
EXTRA_FIELD_NAME = "json_schema_extra"


# A `Field()` with nothing passed, plus a snapshot of the value of every FieldInfo slot on it.
_empty_field_info = Field()
dict_of_empty_field_info = {k: getattr(_empty_field_info, k) for k in FieldInfo.__slots__}
98
+
99
+
100
@dataclasses.dataclass(slots=True)
class PydanticFieldWrapper:
    """Mutable record of one pydantic field, used to build versioned copies of that field.

    Only the arguments that were explicitly passed to ``Field(...)`` are tracked, in
    ``passed_field_attributes``. ``field.metadata`` is deliberately NOT maintained at all.
    """

    init_model_field: dataclasses.InitVar[FieldInfo]

    annotation: Any
    passed_field_attributes: dict[str, Any] = dataclasses.field(init=False)

    def __post_init__(self, init_model_field: FieldInfo):
        # Snapshot the explicitly passed Field() arguments of the source FieldInfo.
        explicit_arguments = _extract_passed_field_attributes(init_model_field)
        self.passed_field_attributes = explicit_arguments

    def update_attribute(self, *, name: str, value: Any):
        """Set (or overwrite) the ``Field`` argument ``name`` to ``value``."""
        self.passed_field_attributes.update({name: value})

    def delete_attribute(self, *, name: str) -> None:
        """Forget the ``Field`` argument ``name``; raises ``KeyError`` if it isn't recorded."""
        del self.passed_field_attributes[name]

    def generate_field_copy(self, generator: "_SchemaGenerator") -> pydantic.fields.FieldInfo:
        """Create a new ``FieldInfo`` whose arguments have been re-pointed at ``generator``'s version."""
        transformer = generator.annotation_transformer
        versioned_arguments = transformer.change_version_of_annotation(self.passed_field_attributes)
        return pydantic.Field(**versioned_arguments)
122
+
123
+
124
def _extract_passed_field_attributes(field_info: FieldInfo):
    """Return the ``Field()`` arguments that were explicitly set on ``field_info``.

    Only names listed in ``_all_field_arg_names`` are considered, and the result follows that
    (sorted) order.
    """
    explicitly_set = field_info._attributes_set
    attributes = {}
    for arg_name in _all_field_arg_names:
        if arg_name in explicitly_set:
            attributes[arg_name] = explicitly_set[arg_name]

    # PydanticV2 always adds frozen to _attributes_set, so a None value there is dropped:
    # we only want `frozen` if it was explicitly set.
    if "frozen" in attributes and attributes["frozen"] is None:
        del attributes["frozen"]
    return attributes
134
+
135
+
136
@dataclasses.dataclass(slots=True)
class _ModelBundle:
    """Wrappers for every versioned enum and schema, keyed by the class they wrap."""

    enums: dict[type[Enum], "_EnumWrapper"]
    schemas: dict[type[BaseModel], "_PydanticRuntimeModelWrapper"]


@dataclasses.dataclass(slots=True, kw_only=True)
class _RuntimeSchemaGenContext:
    """Everything needed to apply the changes of ``current_version`` to ``models``."""

    version_bundle: "VersionBundle"
    current_version: "Version | HeadVersion"
    models: _ModelBundle
    latest_version: "Version" = dataclasses.field(init=False)

    def __post_init__(self):
        # Not passed in: derived as the bundle's version with the greatest `value`.
        candidates = self.version_bundle.versions
        self.latest_version = max(candidates, key=lambda candidate: candidate.value)
151
+
152
+
153
def migrate_response_body(
    versions: "VersionBundle",
    latest_response_model: type[pydantic.BaseModel],
    *,
    latest_body: Any,
    version: VersionDate | str,
):
    """Convert the data to a specific version by applying all version changes from latest until that version
    in reverse order and wrapping the result in the correct version of latest_response_model.

    ``version`` may be a date or an ISO-format date string.
    """
    target_version = date.fromisoformat(version) if isinstance(version, str) else version

    # Run the body through the response migrations as though a synthetic GET endpoint had
    # returned it with status code 200.
    migrated_response = versions._migrate_response(
        ResponseInfo(Response(status_code=200), body=latest_body),
        current_version=target_version,
        head_response_model=latest_response_model,
        path="\0\0\0",
        method="GET",
    )

    # Validate against the model generated for the version `_get_closest_lesser_version` resolves to.
    resolved_version = versions._get_closest_lesser_version(target_version)
    models_of_version = _generate_versioned_models(versions)[str(resolved_version)]
    versioned_response_model: type[pydantic.BaseModel] = models_of_version[latest_response_model]
    return versioned_response_model.model_validate(migrated_response.body)
180
+
181
+
182
def _unwrap_model(model: type[_T_ANY_MODEL]) -> type[_T_ANY_MODEL]:
    """Follow ``__cadwyn_original_model__`` links until reaching a class that has none."""
    current = model
    while True:
        if not hasattr(current, "__cadwyn_original_model__"):
            return current
        current = current.__cadwyn_original_model__  # pyright: ignore[reportAttributeAccessIssue]
186
+
187
+
188
+ @dataclasses.dataclass(slots=True, kw_only=True)
189
+ class _ValidatorWrapper:
190
+ kwargs: dict[str, Any]
191
+ func: Callable
192
+ decorator: Callable
193
+ is_deleted: bool = False
194
+
195
+
196
+ @dataclasses.dataclass(slots=True, kw_only=True)
197
+ class _PerFieldValidatorWrapper(_ValidatorWrapper):
198
+ fields: list[str]
199
+
200
+
201
def _wrap_validator(func: Callable, is_pydantic_v1_style_validator: Any, decorator_info: _decorators.DecoratorInfo):
    """Capture a validator/serializer function together with the decorator call that produced it.

    Returns a ``_PerFieldValidatorWrapper`` when the decorator info carries ``fields`` and a plain
    ``_ValidatorWrapper`` otherwise.
    """
    # Unwrapping is only relevant for pydantic v1 style validators
    unwrapped = fully_unwrap_decorator(func, is_pydantic_v1_style_validator)
    raw_func = unwrapped.__func__ if inspect.ismethod(unwrapped) else unwrapped

    decorator_kwargs = dataclasses.asdict(decorator_info)
    field_names = decorator_kwargs.pop("fields", None)
    decorator = PYDANTIC_DECORATOR_TYPE_TO_DECORATOR_MAP[type(decorator_info)]

    if is_pydantic_v1_style_validator:
        # v1-style decorators take `pre: bool` rather than `mode`, so translate between the two
        decorator_kwargs["pre"] = decorator_kwargs.pop("mode", "after") != "after"

    if field_names is None:
        return _ValidatorWrapper(func=raw_func, decorator=decorator, kwargs=decorator_kwargs)
    return _PerFieldValidatorWrapper(
        func=raw_func, fields=list(field_names), decorator=decorator, kwargs=decorator_kwargs
    )
219
+
220
+
221
def _is_dunder(attr_name: str):
    """Return True when ``attr_name`` both starts and ends with ``__`` (e.g. ``__module__``)."""
    return attr_name[:2] == "__" == attr_name[-2:]
223
+
224
+
225
def _wrap_pydantic_model(model: type[_T_PYDANTIC_MODEL]) -> "_PydanticRuntimeModelWrapper[_T_PYDANTIC_MODEL]":
    """Snapshot the fields, validators and other attributes defined directly on ``model``.

    Inherited members are not captured here; they are reached through the parents' wrappers.
    """
    # Only validators whose attribute lives in this class's own namespace are captured.
    validators = {
        decorator_wrapper.cls_var_name: _wrap_validator(
            decorator_wrapper.func, decorator_wrapper.shim, decorator_wrapper.info
        )
        for decorator_wrapper in _get_model_decorators(model)
        if decorator_wrapper.cls_var_name in model.__dict__
    }
    own_annotations = model.__annotations__
    # NOTE(review): every own annotation is looked up in `model.model_fields`; an annotation that is
    # not a pydantic field (e.g. a ClassVar or a private attribute) would presumably raise KeyError here.
    fields = {
        field_name: PydanticFieldWrapper(model.model_fields[field_name], own_annotations[field_name])
        for field_name in own_annotations
    }

    claimed_names = fields | validators
    ignored_names = {"_abc_impl", "model_fields", "model_computed_fields"}
    other_attributes = {}
    for attr_name, attr_val in model.__dict__.items():
        if attr_name in claimed_names or _is_dunder(attr_name) or attr_name in ignored_names:
            continue
        other_attributes[attr_name] = attr_val
    other_attributes.update(
        model_config=model.model_config,
        __module__=model.__module__,
        __qualname__=model.__qualname__,
    )
    return _PydanticRuntimeModelWrapper(
        model,
        name=model.__name__,
        doc=model.__doc__,
        fields=fields,
        other_attributes=other_attributes,
        validators=validators,
        annotations=own_annotations.copy(),
    )
260
+
261
+
262
@final
@dataclasses.dataclass(slots=True)
class _PydanticRuntimeModelWrapper(Generic[_T_PYDANTIC_MODEL]):
    """Mutable description of a pydantic model from which a versioned copy of it can be built.

    Version changes mutate ``name``, ``fields``, ``validators`` and ``annotations``;
    ``generate_model_copy`` then creates a brand-new model class from the current state.
    """

    cls: type[_T_PYDANTIC_MODEL]
    name: str
    doc: str | None
    fields: Annotated[
        dict["_FieldName", PydanticFieldWrapper],
        Doc(
            "Fields that belong to this model, not to its parents. I.e. The ones that were either defined or overriden "
        ),
    ]
    validators: dict[str, _PerFieldValidatorWrapper | _ValidatorWrapper]
    other_attributes: dict[str, Any]
    annotations: dict[str, Any]
    # Lazily computed wrappers of the parent classes (see `_get_parents`); not carried over by __deepcopy__.
    _parents: list[Self] | None = dataclasses.field(init=False, default=None)

    def __post_init__(self):
        # This isn't actually supposed to run, it's just a precaution
        while hasattr(self.cls, "__cadwyn_original_model__"):  # pragma: no cover
            self.cls = self.cls.__cadwyn_original_model__  # pyright: ignore[reportAttributeAccessIssue]

        for k, annotation in self.annotations.items():
            if get_origin(annotation) == Annotated:
                sub_annotations = get_args(annotation)
                # Annotated cannot be copied and is cached based on "==" and "hash", while annotated_types.Interval are
                # frozen and so are consistently hashed
                self.annotations[k] = _AnnotatedAlias(
                    copy.deepcopy(sub_annotations[0]), tuple(copy.deepcopy(sub_ann) for sub_ann in sub_annotations[1:])
                )

    def __deepcopy__(self, memo: dict[int, Any]):
        # `memo` is forwarded to the nested deep copies, as the copy protocol requires, so that objects
        # shared between (or referenced cyclically from) the copied attributes are copied only once.
        # `_parents` is deliberately not copied: it is recomputed on demand against the schemas the
        # copy is used with.
        result = _PydanticRuntimeModelWrapper(
            self.cls,
            name=self.name,
            doc=self.doc,
            fields=copy.deepcopy(self.fields, memo),
            validators=copy.deepcopy(self.validators, memo),
            other_attributes=copy.deepcopy(self.other_attributes, memo),
            annotations=copy.deepcopy(self.annotations, memo),
        )
        memo[id(self)] = result
        return result

    def _get_parents(self, schemas: "dict[type, Self]"):
        """Wrappers of the classes in ``self.cls``'s MRO (excluding ``self.cls`` itself).

        Bases found in ``schemas`` use that wrapper; other ``BaseModel`` bases are wrapped on the fly.
        The list is cached on first call, so later calls ignore ``schemas``.
        """
        if self._parents is not None:
            return self._parents
        parents = []
        for base in self.cls.mro()[1:]:
            if base in schemas:
                parents.append(schemas[base])
            elif issubclass(base, BaseModel):
                parents.append(_wrap_pydantic_model(base))
        self._parents = parents
        return parents

    def _get_defined_fields_through_mro(self, schemas: "dict[type, Self]") -> dict[str, PydanticFieldWrapper]:
        """All fields visible on this model: the furthest ancestors first, each overridden by nearer ones."""
        fields = {}

        for parent in reversed(self._get_parents(schemas)):
            fields |= parent.fields

        return fields | self.fields

    def _get_defined_annotations_through_mro(self, schemas: "dict[type, Self]") -> dict[str, Any]:
        """All annotations visible on this model, resolved with the same precedence as the fields."""
        annotations = {}

        for parent in reversed(self._get_parents(schemas)):
            annotations |= parent.annotations

        return annotations | self.annotations

    def generate_model_copy(self, generator: "_SchemaGenerator") -> type[_T_PYDANTIC_MODEL]:
        """Create a new model class (same metaclass as ``cls``) from the current state of this wrapper.

        Bases, field arguments and annotations are translated through ``generator``; validators that
        were marked as deleted are left out.
        """
        # Exact type checks: _PerFieldValidatorWrapper is a subclass of _ValidatorWrapper.
        per_field_validators = {
            name: validator.decorator(*validator.fields, **validator.kwargs)(validator.func)
            for name, validator in self.validators.items()
            if not validator.is_deleted and type(validator) is _PerFieldValidatorWrapper
        }
        root_validators = {
            name: validator.decorator(**validator.kwargs)(validator.func)
            for name, validator in self.validators.items()
            if not validator.is_deleted and type(validator) is _ValidatorWrapper
        }
        fields = {name: field.generate_field_copy(generator) for name, field in self.fields.items()}
        model_copy = type(self.cls)(
            self.name,
            tuple(generator[base] for base in self.cls.__bases__),
            self.other_attributes
            | per_field_validators
            | root_validators
            | fields
            | {
                "__annotations__": generator.annotation_transformer.change_version_of_annotation(self.annotations),
                "__doc__": self.doc,
                "__qualname__": self.cls.__qualname__.removesuffix(self.cls.__name__) + self.name,
            },
        )

        # Lets `_unwrap_model` map the generated class back to the user's original class.
        model_copy.__cadwyn_original_model__ = self.cls
        return model_copy
362
+
363
+
364
class _CallableWrapper:
    """Callable proxy that FastAPI treats as interchangeable with the callable it wraps.

    __eq__ and __hash__ are needed to make sure that dependency overrides work correctly.
    Overrides are stored in a dictionary keyed by the dependency (function) itself, so the wrapper
    must hash and compare exactly like the original for an override of the original to apply to it.
    """

    def __init__(self, original_callable: Callable) -> None:
        super().__init__()
        self._original_callable = original_callable
        # Copies __name__, __doc__, __annotations__, ... onto the wrapper and sets __wrapped__.
        functools.update_wrapper(self, original_callable)

    @property
    def __globals__(self):
        """Globals of the wrapped callable; FastAPI uses them to resolve forward references in type hints.

        A real function has this as a plain attribute. Exposing it as a property keeps Python from
        trying to pickle the globals when this wrapper is deep-copied.
        """
        wrapped = self._original_callable
        return wrapped.__globals__

    def __call__(self, *args: Any, **kwargs: Any):
        wrapped = self._original_callable
        return wrapped(*args, **kwargs)

    def __hash__(self):
        wrapped = self._original_callable
        return hash(wrapped)

    def __eq__(self, value: object) -> bool:
        wrapped = self._original_callable
        return wrapped == value  # pyright: ignore[reportUnnecessaryComparison]


class _AsyncCallableWrapper(_CallableWrapper):
    """Variant of ``_CallableWrapper`` for coroutine functions: the call awaits the wrapped callable."""

    async def __call__(self, *args: Any, **kwargs: Any):
        wrapped = self._original_callable
        return await wrapped(*args, **kwargs)
397
+
398
+
399
@final
class _AnnotationTransformer:
    """Rewrites annotations, dependencies and routes so that they point at one version's models.

    Model/enum classes are replaced with ``generator[cls]``; dicts, lists, tuples, generic aliases,
    unions, ``Depends``/``Security`` markers and user callables are walked recursively.
    """

    def __init__(self, generator: "_SchemaGenerator") -> None:
        # This cache is not here for speeding things up. It's for preventing the creation of copies of the same object
        # because such copies could produce weird behaviors at runtime, especially if you/fastapi do any comparisons.
        # It's defined here and not on the method because of this: https://youtu.be/sVjtp6tGo0g
        self.generator = generator
        self.change_versions_of_a_non_container_annotation = functools.cache(
            self._change_version_of_a_non_container_annotation
        )

    def change_version_of_annotation(self, annotation: Any) -> Any:
        """Recursively go through all annotations and change them to the
        annotations corresponding to the version passed.

        So if we had an annotation "UserResponse" from "head" version, and we passed version of "2022-11-16", it would
        replace "UserResponse" with the same class but from the "2022-11-16" version.

        """
        if isinstance(annotation, dict):
            return {
                self.change_version_of_annotation(key): self.change_version_of_annotation(value)
                for key, value in annotation.items()
            }

        elif isinstance(annotation, list | tuple):
            return type(annotation)(self.change_version_of_annotation(v) for v in annotation)
        else:
            return self.change_versions_of_a_non_container_annotation(annotation)

    def migrate_router_to_version(self, router: fastapi.routing.APIRouter):
        """Migrate every ``APIRoute`` of ``router`` in place; other route types are left untouched."""
        for route in router.routes:
            if not isinstance(route, fastapi.routing.APIRoute):
                continue
            self.migrate_route_to_version(route)

    def migrate_route_to_version(self, route: fastapi.routing.APIRoute, *, ignore_response_model: bool = False):
        """Re-point ``route`` (and its APIRoute callbacks) at this version's models, in place."""
        if route.response_model is not None and not ignore_response_model:
            route.response_model = self.change_version_of_annotation(route.response_model)
            route.response_field = fastapi.utils.create_response_field(
                name="Response_" + route.unique_id,
                type_=route.response_model,
                mode="serialization",
            )
            route.secure_cloned_response_field = fastapi.utils.create_cloned_field(route.response_field)
        route.dependencies = self.change_version_of_annotation(route.dependencies)
        route.endpoint = self.change_version_of_annotation(route.endpoint)
        for callback in route.callbacks or []:
            if not isinstance(callback, fastapi.routing.APIRoute):
                continue
            self.migrate_route_to_version(callback, ignore_response_model=ignore_response_model)
        self._remake_endpoint_dependencies(route)

    def _change_version_of_a_non_container_annotation(self, annotation: Any) -> Any:
        if isinstance(annotation, _BaseGenericAlias | types.GenericAlias):
            return get_origin(annotation)[tuple(self.change_version_of_annotation(arg) for arg in get_args(annotation))]
        elif isinstance(annotation, fastapi.params.Security):
            # Security is a subclass of Depends, so it has to be checked first: rebuilding it as a plain
            # Depends would silently drop its scopes.
            return fastapi.params.Security(
                self.change_version_of_annotation(annotation.dependency),
                scopes=annotation.scopes,
                use_cache=annotation.use_cache,
            )
        elif isinstance(annotation, fastapi.params.Depends):
            return fastapi.params.Depends(
                self.change_version_of_annotation(annotation.dependency),
                use_cache=annotation.use_cache,
            )
        elif isinstance(annotation, UnionType):
            getitem = typing.Union.__getitem__  # pyright: ignore[reportAttributeAccessIssue]
            return getitem(
                tuple(self.change_version_of_annotation(a) for a in get_args(annotation)),
            )
        elif annotation is Any or isinstance(annotation, typing.NewType):
            return annotation
        elif isinstance(annotation, type):
            return self._change_version_of_type(annotation)
        elif callable(annotation):
            # Framework-provided callables and security schemes are kept as they are.
            if type(annotation).__module__.startswith(
                ("fastapi.", "pydantic.", "pydantic_core.", "starlette.")
            ) or isinstance(annotation, fastapi.params.Security | fastapi.security.base.SecurityBase):
                return annotation

            def modifier(annotation: Any):
                return self.change_version_of_annotation(annotation)

            return self._modify_callable_annotations(
                annotation,
                modifier,
                modifier,
                annotation_modifying_wrapper_factory=self._copy_function_through_class_based_wrapper,
            )
        else:
            return annotation

    def _change_version_of_type(self, annotation: type):
        """Swap a model/enum class for its versioned copy; any other class is returned unchanged."""
        if issubclass(annotation, BaseModel | Enum):
            return self.generator[annotation]
        else:
            return annotation

    @classmethod
    def _remake_endpoint_dependencies(cls, route: fastapi.routing.APIRoute):
        # Unlike get_dependant, APIRoute is the public API of FastAPI and it's (almost) guaranteed to be stable.

        route_copy = fastapi.routing.APIRoute(route.path, route.endpoint, dependencies=route.dependencies)
        route.dependant = route_copy.dependant
        route.body_field = route_copy.body_field
        _add_request_and_response_params(route)

    @classmethod
    def _modify_callable_annotations(  # pragma: no branch # because of lambdas
        cls,
        call: _Call,
        modify_annotations: Callable[[dict[str, Any]], dict[str, Any]] = lambda a: a,
        modify_defaults: Callable[[tuple[Any, ...]], tuple[Any, ...]] = lambda a: a,
        *,
        annotation_modifying_wrapper_factory: Callable[[_Call], _Call],
    ) -> _Call:
        """Wrap ``call`` and give the wrapper modified annotations, defaults and a matching signature."""
        annotation_modifying_wrapper = annotation_modifying_wrapper_factory(call)
        old_params = inspect.signature(call).parameters
        callable_annotations = annotation_modifying_wrapper.__annotations__
        annotation_modifying_wrapper.__annotations__ = modify_annotations(callable_annotations)
        annotation_modifying_wrapper.__defaults__ = modify_defaults(
            tuple(p.default for p in old_params.values() if p.default is not inspect.Signature.empty),
        )
        annotation_modifying_wrapper.__signature__ = cls._generate_signature(
            annotation_modifying_wrapper,
            old_params,
        )

        return annotation_modifying_wrapper

    @staticmethod
    def _generate_signature(
        new_callable: Callable,
        old_params: types.MappingProxyType[str, inspect.Parameter],
    ):
        """Rebuild the old parameters using ``new_callable``'s ``__defaults__`` and ``__annotations__``."""
        parameters = []
        default_counter = 0
        for param in old_params.values():
            if param.default is not inspect.Signature.empty:
                assert new_callable.__defaults__ is not None, (  # noqa: S101
                    "Defaults cannot be None here. If it is, you have found a bug in Cadwyn. "
                    "Please, report it in our issue tracker."
                )
                default = new_callable.__defaults__[default_counter]
                default_counter += 1
            else:
                default = inspect.Signature.empty
            parameters.append(
                inspect.Parameter(
                    param.name,
                    param.kind,
                    default=default,
                    annotation=new_callable.__annotations__.get(
                        param.name,
                        inspect.Signature.empty,
                    ),
                ),
            )
        return inspect.Signature(
            parameters=parameters,
            return_annotation=new_callable.__annotations__.get(
                "return",
                inspect.Signature.empty,
            ),
        )

    @classmethod
    def _copy_function_through_class_based_wrapper(cls, call: Any):
        """Separate from copy_endpoint because endpoints MUST be functions in FastAPI, they cannot be cls instances"""
        call = cls._unwrap_callable(call)

        if inspect.iscoroutinefunction(call):
            return _AsyncCallableWrapper(call)
        else:
            return _CallableWrapper(call)

    @staticmethod
    def _unwrap_callable(call: Any) -> Any:
        """Strip nested Cadwyn wrappers; for callable class instances, return their ``__call__``."""
        while hasattr(call, "_original_callable"):
            call = call._original_callable
        if not isinstance(call, types.FunctionType | types.MethodType):
            # This means that the callable is actually an instance of a regular class
            call = call.__call__

        return call
580
+
581
+
582
def _add_request_and_response_params(route: APIRoute):
    """Give the route's dependant request/response parameter names if it has none.

    Names FastAPI already found on the endpoint are kept; missing (falsy) ones get Cadwyn's names.
    """
    dependant = route.dependant
    for attr_name, fallback_name in (
        ("request_param_name", _CADWYN_REQUEST_PARAM_NAME),
        ("response_param_name", _CADWYN_RESPONSE_PARAM_NAME),
    ):
        if not getattr(dependant, attr_name):
            setattr(dependant, attr_name, fallback_name)
587
+
588
+
589
@final
class _SchemaGenerator:
    """Builds and caches the versioned copy of every model/enum for a single version's bundle."""

    __slots__ = "annotation_transformer", "model_bundle", "concrete_models"

    def __init__(self, model_bundle: _ModelBundle) -> None:
        self.annotation_transformer = _AnnotationTransformer(self)
        self.model_bundle = model_bundle
        # Generation is recursive: building one model resolves its bases (and annotations) through
        # `self[...]`, which stores whatever it builds in `self.concrete_models`. Filling that same dict
        # here, instead of building a separate dict and assigning it afterwards, guarantees that each
        # class is generated only once. A generated child then really subclasses `self[parent]`.
        self.concrete_models = {}
        for model, wrapper in (self.model_bundle.schemas | self.model_bundle.enums).items():
            if model not in self.concrete_models:
                self.concrete_models[model] = wrapper.generate_model_copy(self)

    def __getitem__(self, model: type, /) -> Any:
        """Return the versioned copy of ``model``.

        Anything that isn't a BaseModel/Enum subclass (and BaseModel/RootModel themselves) is
        returned unchanged.
        """
        if not isinstance(model, type) or not issubclass(model, BaseModel | Enum) or model in (BaseModel, RootModel):
            return model
        model = _unwrap_model(model)

        if model in self.concrete_models:
            return self.concrete_models[model]

        wrapper = self._get_wrapper_for_model(model)
        model_copy = wrapper.generate_model_copy(self)
        self.concrete_models[model] = model_copy
        return model_copy

    def _get_wrapper_for_model(
        self, model: type[BaseModel | Enum]
    ) -> "_PydanticRuntimeModelWrapper[BaseModel] | _EnumWrapper[Enum]":
        """Return the bundle's wrapper for ``model``, wrapping it and adding it to the bundle if missing."""
        model = _unwrap_model(model)

        if model in self.model_bundle.schemas:
            return self.model_bundle.schemas[model]
        elif model in self.model_bundle.enums:
            return self.model_bundle.enums[model]

        if issubclass(model, BaseModel):
            wrapper = _wrap_pydantic_model(model)
            self.model_bundle.schemas[model] = wrapper
        elif issubclass(model, Enum):
            wrapper = _EnumWrapper(model)
            self.model_bundle.enums[model] = wrapper
        else:
            assert_never(model)
        return wrapper
636
+
637
+
638
@cache
def _generate_versioned_models(versions: "VersionBundle") -> "dict[str, _SchemaGenerator]":
    """Build one ``_SchemaGenerator`` per version, keyed by ``str(version.value)``.

    Memoized per ``VersionBundle`` through ``functools.cache``.
    """
    models = _create_model_bundle(versions)

    # The head version's own changes are applied to the shared bundle first.
    _migrate_classes(
        _RuntimeSchemaGenContext(current_version=versions.head_version, models=models, version_bundle=versions)
    )

    generators_by_version = {}
    for version in versions.versions:
        version_context = _RuntimeSchemaGenContext(current_version=version, models=models, version_bundle=versions)
        # Freeze the bundle's current state for `version`, then apply `version`'s changes before moving on.
        generators_by_version[str(version.value)] = _SchemaGenerator(copy.deepcopy(models))
        # note that the last migration will not contain any version changes so we don't need to save the results
        _migrate_classes(version_context)

    return generators_by_version
653
+
654
+
655
def _create_model_bundle(versions: "VersionBundle"):
    """Wrap every versioned enum and schema of ``versions`` into a new ``_ModelBundle``."""
    enum_wrappers = {enum_cls: _EnumWrapper(enum_cls) for enum_cls in versions.versioned_enums.values()}
    schema_wrappers = {
        model_cls: _wrap_pydantic_model(model_cls) for model_cls in versions.versioned_schemas.values()
    }
    return _ModelBundle(enums=enum_wrappers, schemas=schema_wrappers)
660
+
661
+
662
def _migrate_classes(context: _RuntimeSchemaGenContext) -> None:
    """Apply the schema and enum instructions of each version change of ``context.current_version``.

    The wrappers stored in ``context.models`` are mutated in place.
    """
    schema_wrappers = context.models.schemas
    enum_wrappers = context.models.enums
    for change in context.current_version.version_changes:
        change_name = change.__name__
        _apply_alter_schema_instructions(schema_wrappers, change.alter_schema_instructions, change_name)
        _apply_alter_enum_instructions(enum_wrappers, change.alter_enum_instructions, change_name)
674
+
675
+
676
+ def _apply_alter_schema_instructions(
677
+ modified_schemas: dict[type, _PydanticRuntimeModelWrapper],
678
+ alter_schema_instructions: Sequence[AlterSchemaSubInstruction | SchemaHadInstruction],
679
+ version_change_name: str,
680
+ ) -> None:
681
+ for alter_schema_instruction in alter_schema_instructions:
682
+ schema_info = modified_schemas[alter_schema_instruction.schema]
683
+ if isinstance(alter_schema_instruction, FieldExistedAsInstruction):
684
+ _add_field_to_model(schema_info, modified_schemas, alter_schema_instruction, version_change_name)
685
+ elif isinstance(alter_schema_instruction, FieldHadInstruction | FieldDidntHaveInstruction):
686
+ _change_field_in_model(
687
+ schema_info,
688
+ modified_schemas,
689
+ alter_schema_instruction,
690
+ version_change_name,
691
+ )
692
+ elif isinstance(alter_schema_instruction, FieldDidntExistInstruction):
693
+ _delete_field_from_model(schema_info, alter_schema_instruction.name, version_change_name)
694
+ elif isinstance(alter_schema_instruction, ValidatorExistedInstruction):
695
+ validator_name = alter_schema_instruction.validator.__name__
696
+ raw_validator = cast(
697
+ pydantic._internal._decorators.PydanticDescriptorProxy, alter_schema_instruction.validator
698
+ )
699
+ schema_info.validators[validator_name] = _wrap_validator(
700
+ raw_validator.wrapped,
701
+ is_pydantic_v1_style_validator=raw_validator.shim,
702
+ decorator_info=raw_validator.decorator_info,
703
+ )
704
+ elif isinstance(alter_schema_instruction, ValidatorDidntExistInstruction):
705
+ if alter_schema_instruction.name not in schema_info.validators:
706
+ raise InvalidGenerationInstructionError(
707
+ f'You tried to delete a validator "{alter_schema_instruction.name}" from "{schema_info.name}" '
708
+ f'in "{version_change_name}" but it doesn\'t have such a validator.',
709
+ )
710
+ if schema_info.validators[alter_schema_instruction.name].is_deleted:
711
+ raise InvalidGenerationInstructionError(
712
+ f'You tried to delete a validator "{alter_schema_instruction.name}" from "{schema_info.name}" '
713
+ f'in "{version_change_name}" but it is already deleted.',
714
+ )
715
+ schema_info.validators[alter_schema_instruction.name].is_deleted = True
716
+ elif isinstance(alter_schema_instruction, SchemaHadInstruction):
717
+ _change_model(schema_info, alter_schema_instruction, version_change_name)
718
+ else:
719
+ assert_never(alter_schema_instruction)
720
+
721
+
722
+ def _apply_alter_enum_instructions(
723
+ enums: "dict[type, _EnumWrapper]",
724
+ alter_enum_instructions: Sequence[AlterEnumSubInstruction],
725
+ version_change_name: str,
726
+ ):
727
+ for alter_enum_instruction in alter_enum_instructions:
728
+ enum = enums[alter_enum_instruction.enum]
729
+ if isinstance(alter_enum_instruction, EnumDidntHaveMembersInstruction):
730
+ for member in alter_enum_instruction.members:
731
+ if member not in enum.members:
732
+ raise InvalidGenerationInstructionError(
733
+ f'You tried to delete a member "{member}" from "{enum.cls.__name__}" '
734
+ f'in "{version_change_name}" but it doesn\'t have such a member.',
735
+ )
736
+ enum.members.pop(member)
737
+ elif isinstance(alter_enum_instruction, EnumHadMembersInstruction):
738
+ for member, member_value in alter_enum_instruction.members.items():
739
+ if member in enum.members and enum.members[member] == member_value:
740
+ raise InvalidGenerationInstructionError(
741
+ f'You tried to add a member "{member}" to "{enum.cls.__name__}" '
742
+ f'in "{version_change_name}" but there is already a member with that name and value.',
743
+ )
744
+ enum.members[member] = member_value
745
+ else:
746
+ assert_never(alter_enum_instruction)
747
+
748
+
749
def _change_model(
    model: _PydanticRuntimeModelWrapper,
    alter_schema_instruction: SchemaHadInstruction,
    version_change_name: str,
):
    """Rename ``model`` as requested by a ``SchemaHadInstruction``.

    Raises:
        InvalidGenerationInstructionError: if the model already has the requested name.
    """
    requested_name = alter_schema_instruction.name
    if requested_name != model.name:
        model.name = requested_name
        return
    raise InvalidGenerationInstructionError(
        f'You tried to change the name of "{model.name}" in "{version_change_name}" '
        "but it already has the name you tried to assign.",
    )
761
+
762
+
763
def _add_field_to_model(
    model: _PydanticRuntimeModelWrapper,
    schemas: "dict[type, _PydanticRuntimeModelWrapper]",
    alter_schema_instruction: FieldExistedAsInstruction,
    version_change_name: str,
):
    """Add the field described by a ``FieldExistedAsInstruction`` to ``model``.

    Raises:
        InvalidGenerationInstructionError: if ``model`` or one of its parents already has a field
            with that name.
    """
    field_name = alter_schema_instruction.name
    if field_name in model._get_defined_fields_through_mro(schemas):
        raise InvalidGenerationInstructionError(
            f'You tried to add a field "{alter_schema_instruction.name}" to "{model.name}" '
            f'in "{version_change_name}" but there is already a field with that name.',
        )

    new_field_info = alter_schema_instruction.field
    field_annotation = new_field_info.annotation
    model.fields[field_name] = PydanticFieldWrapper(new_field_info, field_annotation)
    model.annotations[field_name] = field_annotation
779
+
780
+
781
def _change_field_in_model(
    model: _PydanticRuntimeModelWrapper,
    schemas: "dict[type, _PydanticRuntimeModelWrapper]",
    alter_schema_instruction: FieldHadInstruction | FieldDidntHaveInstruction,
    version_change_name: str,
):
    """Apply a field "had" / "didn't have" instruction to ``model``.

    The field and its annotation are looked up through the model's MRO helpers and then
    (re)registered on ``model`` itself. The instruction is then dispatched:
    ``FieldHadInstruction`` goes to ``_change_field``; anything else goes to
    ``_delete_field_attributes``.

    Raises:
        InvalidGenerationInstructionError: If the field is not among the fields returned
            by ``model._get_defined_fields_through_mro(schemas)``.
    """
    field_name = alter_schema_instruction.name
    annotations_in_mro = model._get_defined_annotations_through_mro(schemas)
    fields_in_mro = model._get_defined_fields_through_mro(schemas)

    if field_name not in fields_in_mro:
        raise InvalidGenerationInstructionError(
            f'You tried to change the field "{field_name}" from '
            f'"{model.name}" in "{version_change_name}" but it doesn\'t have such a field.',
        )

    # NOTE(review): ``field`` is exactly the object the MRO helper returned. If that helper
    # hands back a base schema's wrapper without copying it, the mutations below also
    # affect the base schema -- confirm the helper copies.
    field = fields_in_mro[field_name]
    model.fields[field_name] = field
    model.annotations[field_name] = annotations_in_mro[field_name]
    current_annotation = model.annotations[field_name]

    if not isinstance(alter_schema_instruction, FieldHadInstruction):
        _delete_field_attributes(
            model,
            alter_schema_instruction,
            version_change_name,
            field,
            current_annotation,
        )
        return

    # TODO: ``_change_field`` is a vague name for what this helper does.
    _change_field(
        model,
        alter_schema_instruction,
        version_change_name,
        annotations_in_mro,
        field,
        current_annotation,
    )
817
+
818
+
819
def _change_field(
    model: _PydanticRuntimeModelWrapper,
    alter_schema_instruction: FieldHadInstruction,
    version_change_name: str,
    defined_annotations: dict[str, Any],
    field: PydanticFieldWrapper,
    annotation: Any | None,
):
    """Apply a ``FieldHadInstruction`` to ``field`` and to ``model``'s field maps.

    The steps run in this order: type change, then rename, then per-attribute changes.
    Each part is skipped when the instruction leaves it as ``Sentinel``, and each part
    rejects a change to the value the field already has.

    Args:
        model: Wrapper whose ``fields`` / ``annotations`` are updated in place.
        alter_schema_instruction: The requested changes (``type``, ``new_name``,
            ``field_changes``).
        version_change_name: Name of the version change, used only in error messages.
        defined_annotations: Annotations visible through the model's MRO; used as the
            fallback annotation when renaming.
        field: The wrapper being modified in place.
        annotation: Not used by this function; kept so the call signature stays stable.

    Raises:
        InvalidGenerationInstructionError: If a requested type, name or attribute value
            equals the current one.
    """
    instruction = alter_schema_instruction
    old_name = instruction.name

    new_type = instruction.type
    if new_type is not Sentinel:
        if field.annotation == new_type:
            raise InvalidGenerationInstructionError(
                f'You tried to change the type of field "{old_name}" to '
                f'"{new_type}" from "{model.name}" in "{version_change_name}" '
                f'but it already has type "{field.annotation}"',
            )
        field.annotation = new_type
        model.annotations[old_name] = new_type

    new_name = instruction.new_name
    if new_name is not Sentinel:
        if new_name == old_name:
            raise InvalidGenerationInstructionError(
                f'You tried to change the name of field "{old_name}" '
                f'from "{model.name}" in "{version_change_name}" '
                "but it already has that name.",
            )
        # NOTE(review): an existing field already called ``new_name`` is silently replaced.
        model.fields[new_name] = model.fields.pop(old_name)
        model.annotations[new_name] = model.annotations.pop(old_name, defined_annotations[old_name])

    requested_changes = instruction.field_changes
    for attr_name in requested_changes.__dataclass_fields__:
        new_value = getattr(requested_changes, attr_name)
        if new_value is Sentinel:
            continue
        if field.passed_field_attributes.get(attr_name, Sentinel) == new_value:
            raise InvalidGenerationInstructionError(
                f'You tried to change the attribute "{attr_name}" of field '
                f'"{old_name}" '
                f'from "{model.name}" to {new_value!r} in "{version_change_name}" '
                "but it already has that value.",
            )
        field.update_attribute(name=attr_name, value=new_value)
861
+
862
+
863
def _delete_field_attributes(
    model: _PydanticRuntimeModelWrapper,
    alter_schema_instruction: FieldDidntHaveInstruction,
    version_change_name: str,
    field: PydanticFieldWrapper,
    annotation: Any,
) -> None:
    """Remove each attribute in ``alter_schema_instruction.attributes`` from a field.

    An attribute explicitly passed to the field (``field.passed_field_attributes``) is
    removed with ``field.delete_attribute``. Otherwise, if ``annotation`` is
    ``Annotated[...]``, every metadata object that has the attribute gets it reset to
    ``None``.

    Args:
        model: Wrapper of the schema; used only in the error message.
        alter_schema_instruction: Lists the attribute names to delete and the field name.
        version_change_name: Name of the version change, used only in the error message.
        field: The field wrapper to modify in place.
        annotation: The field's current annotation.

    Raises:
        InvalidGenerationInstructionError: If an attribute is found in neither place.
    """
    # ``get_args(Annotated[T, m1, m2])`` is ``(T, m1, m2)``. Only the metadata (m1, m2, ...)
    # can hold field attributes. The underlying type T must be skipped: its own attributes
    # (e.g. ``str.title``) are not field attributes, and ``object.__setattr__`` on a class
    # raises ``TypeError``.
    annotated_metadata = get_args(annotation)[1:] if get_origin(annotation) is Annotated else ()

    for attr_name in alter_schema_instruction.attributes:
        if attr_name in field.passed_field_attributes:
            field.delete_attribute(name=attr_name)
            continue

        holders = [metadata for metadata in annotated_metadata if hasattr(metadata, attr_name)]
        if holders:  # pragma: no branch
            for metadata in holders:
                # ``object.__setattr__`` bypasses any custom ``__setattr__`` on the metadata
                # object (presumably because some of them are frozen -- confirm).
                object.__setattr__(metadata, attr_name, None)
        else:
            raise InvalidGenerationInstructionError(
                f'You tried to delete the attribute "{attr_name}" of field "{alter_schema_instruction.name}" '
                f'from "{model.name}" in "{version_change_name}" '
                "but it already doesn't have that attribute.",
            )
885
+
886
+
887
def _delete_field_from_model(model: _PydanticRuntimeModelWrapper, field_name: str, version_change_name: str):
    """Remove ``field_name`` from ``model`` and from any per-field validators that reference it.

    Only fields present in ``model.fields`` itself can be deleted. A per-field validator left
    with no fields is flagged ``is_deleted = True``; it stays in ``model.validators``.

    Raises:
        InvalidGenerationInstructionError: If ``model.fields`` has no such field.
    """
    if field_name not in model.fields:
        raise InvalidGenerationInstructionError(
            f'You tried to delete a field "{field_name}" from "{model.name}" '
            f'in "{version_change_name}" but it doesn\'t have such a field.',
        )

    model.fields.pop(field_name)
    model.annotations.pop(field_name)

    # Iterate over a snapshot of the validator mapping.
    for name_of_validator, candidate in list(model.validators.items()):
        if not isinstance(candidate, _PerFieldValidatorWrapper):
            continue
        if field_name not in candidate.fields:
            continue
        candidate.fields.remove(field_name)
        # TODO: Marking an emptied validator as deleted (instead of e.g. raising) doesn't feel natural.
        if not candidate.fields:
            model.validators[name_of_validator].is_deleted = True
901
+
902
+
903
class _DummyEnum(Enum):
    """Member-less enum used only for its ``__dict__`` keys.

    Those keys are the ones a plain Enum subclass gets without user code, and
    ``_EnumWrapper._get_initialization_namespace_for_enum`` skips any key found here.
    """
905
+
906
+
907
@final
class _EnumWrapper(Generic[_T_ENUM]):
    """Editable description of an enum that can be rebuilt into a new enum class.

    ``cls`` holds the unwrapped enum class. ``members`` maps member names to their values;
    it is the part that gets edited (members popped or added). ``generate_model_copy``
    builds a fresh enum class from the current state.
    """

    __slots__ = "cls", "members"

    def __init__(self, cls: type[_T_ENUM]):
        self.cls = _unwrap_model(cls)
        # Members are read from the class that was passed in, not from the unwrapped one.
        self.members = {member.name: member.value for member in cls}

    def __deepcopy__(self, memo: Any):
        # The enum class itself is shared between copies; only the member map is duplicated.
        result = _EnumWrapper(self.cls)
        result.members = self.members.copy()
        memo[id(self)] = result
        return result

    def generate_model_copy(self, generator: "_SchemaGenerator") -> type[_T_ENUM]:
        """Create a new enum class from ``self.cls`` whose members are ``self.members``.

        The copy keeps the original's name, bases, docstring and user-defined non-member
        attributes. It records the original class in ``__cadwyn_original_model__``.
        ``generator`` is not used here.
        """
        enum_dict = Enum.__prepare__(self.cls.__name__, self.cls.__bases__)

        # ``__doc__`` exists in every class ``__dict__`` (``_DummyEnum``'s included), so
        # ``_get_initialization_namespace_for_enum`` filters it out. Carry the docstring over
        # explicitly; otherwise the copy ends up undocumented, and pydantic, which uses an
        # enum's docstring as its JSON-schema description, emits none.
        enum_dict["__doc__"] = self.cls.__doc__

        # Members may hold Enum instances (e.g. values added by instructions); store raw values.
        raw_member_map = {k: v.value if isinstance(v, Enum) else v for k, v in self.members.items()}
        initialization_namespace = self._get_initialization_namespace_for_enum(self.cls) | raw_member_map
        for attr_name, attr in initialization_namespace.items():
            enum_dict[attr_name] = attr
        model_copy = cast(type[_T_ENUM], type(self.cls.__name__, self.cls.__bases__, enum_dict))
        model_copy.__cadwyn_original_model__ = self.cls  # pyright: ignore[reportAttributeAccessIssue]
        return model_copy

    @staticmethod
    def _get_initialization_namespace_for_enum(enum_cls: type[Enum]):
        """Return the non-member attributes defined directly on ``enum_cls``.

        Three kinds of keys are excluded:

        - member names;
        - keys every Enum subclass gets automatically (those present in
          ``_DummyEnum.__dict__``);
        - keys whose value is the very object already provided by a class later in the MRO.
        """
        mro_without_the_class_itself = enum_cls.mro()[1:]

        # Walk from the most generic base to the most specific so that nearer bases win.
        mro_dict = {}
        for cls in reversed(mro_without_the_class_itself):
            mro_dict.update(cls.__dict__)

        return {
            k: v
            for k, v in enum_cls.__dict__.items()
            if k not in enum_cls._member_names_
            and k not in _DummyEnum.__dict__
            and (k not in mro_dict or mro_dict[k] is not v)
        }