wellapi 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,923 @@
1
+ import inspect
2
+ import re
3
+ from collections import deque
4
+ from collections.abc import Mapping, Sequence
5
+ from copy import copy, deepcopy
6
+ from dataclasses import dataclass, is_dataclass
7
+ from functools import lru_cache
8
+ from types import UnionType
9
+
10
+ # ruff: noqa: UP035
11
+ from typing import (
12
+ Annotated,
13
+ Any,
14
+ Callable,
15
+ Deque,
16
+ ForwardRef,
17
+ FrozenSet,
18
+ List,
19
+ Literal,
20
+ Set,
21
+ Tuple,
22
+ Union,
23
+ cast,
24
+ )
25
+
26
+ from pydantic import (
27
+ BaseModel,
28
+ PydanticSchemaGenerationError,
29
+ ValidationError,
30
+ create_model,
31
+ )
32
+ from pydantic._internal._typing_extra import try_eval_type
33
+ from pydantic._internal._utils import lenient_issubclass
34
+ from pydantic.fields import FieldInfo
35
+ from pydantic_core import PydanticUndefined
36
+ from typing_extensions import get_args, get_origin
37
+
38
+ from wellapi import params
39
+ from wellapi.datastructures import Headers, ImmutableMultiDict, QueryParams
40
+ from wellapi.dependencies.models import (
41
+ Dependant,
42
+ ModelField,
43
+ SecurityRequirement,
44
+ _regenerate_error_with_loc,
45
+ )
46
+ from wellapi.exceptions import WellAPIError
47
+ from wellapi.models import RequestAPIGateway, RequestSQS, ResponseAPIGateway
48
+ from wellapi.security import OAuth2, SecurityBase
49
+
50
+
51
def get_dependant(
    *,
    path: str,
    call: Callable[..., Any],
    type_: Literal["endpoint", "queue", "job"],
    name: str | None = None,
    security_scopes: list[str] | None = None,
    use_cache: bool = True,
) -> Dependant:
    """Inspect ``call``'s signature and build its dependency tree.

    Each parameter is classified by ``analyze_param`` and routed to exactly one
    destination:

    * a nested sub-dependency (``Depends``/``Security`` marker),
    * an injected framework object (API Gateway request/response, SQS request),
    * the body parameter list (``params.Body`` fields),
    * the path / query / header / cookie parameter lists (other fields).

    Args:
        path: Route path; ``{name}`` placeholders mark path parameters.
        call: Endpoint, handler or dependency callable to inspect.
        type_: Handler kind, forwarded to parameter analysis.
        name: Parameter name this dependant is injected under, if any.
        security_scopes: Scopes forwarded to sub-dependencies.
        use_cache: Whether the solved value may be reused from the cache.

    Returns:
        The populated ``Dependant``.
    """
    path_names = get_path_param_names(path)
    parameters = get_typed_signature(call).parameters
    dependant = Dependant(
        call=call,
        name=name,
        path=path,
        security_scopes=security_scopes,
        use_cache=use_cache,
    )
    for param_name, param in parameters.items():
        details = analyze_param(
            param_name=param_name,
            annotation=param.annotation,
            value=param.default,
            is_path_param=param_name in path_names,
            type_=type_,
        )
        if details.depends is not None:
            # Nested dependency: build its own tree recursively.
            dependant.dependencies.append(
                get_param_sub_dependant(
                    param_name=param_name,
                    depends=details.depends,
                    path=path,
                    security_scopes=security_scopes,
                    type_=type_,
                )
            )
        elif add_non_field_param_to_dependency(
            param_name=param_name,
            type_annotation=details.type_annotation,
            dependant=dependant,
        ):
            # Injected framework object; it must not also carry a field marker.
            assert details.field is None, (
                f"Cannot specify multiple FastAPI annotations for {param_name!r}"
            )
        else:
            field = details.field
            assert field is not None
            if isinstance(field.field_info, params.Body):
                dependant.body_params.append(field)
            else:
                add_param_to_fields(field=field, dependant=dependant)
    return dependant
104
+
105
+
106
+ def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
107
+ field_info = field.field_info
108
+ field_info_in = getattr(field_info, "in_", None)
109
+ if field_info_in == params.ParamTypes.path:
110
+ dependant.path_params.append(field)
111
+ elif field_info_in == params.ParamTypes.query:
112
+ dependant.query_params.append(field)
113
+ elif field_info_in == params.ParamTypes.header:
114
+ dependant.header_params.append(field)
115
+ else:
116
+ assert field_info_in == params.ParamTypes.cookie, (
117
+ f"non-body parameters must be in path, query, header or cookie: {field.name}"
118
+ )
119
+ dependant.cookie_params.append(field)
120
+
121
+
122
def add_non_field_param_to_dependency(
    *, param_name: str, type_annotation: Any, dependant: Dependant
) -> bool | None:
    """Register a parameter that receives a framework object instead of a value.

    If ``type_annotation`` is a subclass of ``RequestAPIGateway``,
    ``ResponseAPIGateway`` or ``RequestSQS``, the checks run in that order. The
    first match stores ``param_name`` in the corresponding ``Dependant``
    attribute and the function returns ``True``. With no match it returns
    ``None``.
    """
    injectable_targets = (
        (RequestAPIGateway, "request_param_name"),
        (ResponseAPIGateway, "response_param_name"),
        (RequestSQS, "request_sqs_param_name"),
    )
    for base_type, attribute in injectable_targets:
        if lenient_issubclass(type_annotation, base_type):
            setattr(dependant, attribute, param_name)
            return True
    return None
136
+
137
+
138
def get_param_sub_dependant(
    *,
    param_name: str,
    depends: params.Depends,
    path: str,
    type_: Literal["endpoint", "queue", "job"],
    security_scopes: list[str] | None = None,
) -> Dependant:
    """Build the sub-dependant for a parameter declared with ``Depends``.

    ``depends.dependency`` must already be resolved (truthy). The sub-dependant
    is injected under ``param_name``.
    """
    dependency = depends.dependency
    assert dependency
    return get_sub_dependant(
        depends=depends,
        dependency=dependency,
        path=path,
        type_=type_,
        name=param_name,
        security_scopes=security_scopes,
    )
155
+
156
+
157
def get_parameterless_sub_dependant(
    *,
    depends: params.Depends,
    path: str,
    type_: Literal["endpoint", "queue", "job"],
) -> Dependant:
    """Build a sub-dependant that is not bound to any handler parameter.

    Its result is never injected under a name, so ``depends.dependency`` must
    be given explicitly as a callable.
    """
    dependency = depends.dependency
    assert callable(dependency), (
        "A parameter-less dependency must have a callable dependency"
    )
    return get_sub_dependant(
        depends=depends, dependency=dependency, path=path, type_=type_
    )
169
+
170
+
171
def get_sub_dependant(
    *,
    depends: params.Depends,
    dependency: Callable[..., Any],
    path: str,
    type_: Literal["endpoint", "queue", "job"],
    name: str | None = None,
    security_scopes: list[str] | None = None,
) -> Dependant:
    """Build the ``Dependant`` for a single dependency callable.

    A ``params.Security`` marker adds its ``scopes`` to the scopes forwarded to
    the sub-dependency tree. A ``SecurityBase`` dependency (e.g. an auth scheme)
    is recorded as a ``SecurityRequirement`` on the resulting dependant. Only
    ``OAuth2`` schemes carry the accumulated scopes; other schemes get an empty
    scope list.

    Args:
        depends: The ``Depends``/``Security`` marker that declared the dependency.
        dependency: The callable to analyse.
        path: Route path, used to detect path parameters.
        type_: Handler kind, forwarded to parameter analysis.
        name: Parameter name the result is injected under, if any.
        security_scopes: Scopes inherited from the parent. This list is never
            modified.

    Returns:
        The ``Dependant`` for ``dependency``.
    """
    # Work on a fresh list. The caller's list is also stored on the parent
    # Dependant and passed to every sibling dependency, so extending it in
    # place would leak this dependency's scopes into the parent and siblings.
    scopes: list[str] = list(security_scopes or [])
    if isinstance(depends, params.Security):
        scopes.extend(depends.scopes)
    security_requirement = None
    if isinstance(dependency, SecurityBase):
        use_scopes: list[str] = scopes if isinstance(dependency, OAuth2) else []
        security_requirement = SecurityRequirement(
            security_scheme=dependency, scopes=use_scopes
        )
    sub_dependant = get_dependant(
        path=path,
        call=dependency,
        name=name,
        type_=type_,
        security_scopes=scopes,
        use_cache=depends.use_cache,
    )
    if security_requirement:
        sub_dependant.security_requirements.append(security_requirement)
    return sub_dependant
203
+
204
+
205
def get_path_param_names(path: str) -> set[str]:
    """Return the names of all ``{placeholder}`` segments in ``path``.

    Matching is non-greedy, so ``"/{a}{b}"`` yields ``{"a", "b"}``.
    """
    return {match.group(1) for match in re.finditer("{(.*?)}", path)}
207
+
208
+
209
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
    """Return ``call``'s signature with every parameter annotation resolved.

    String and forward-reference annotations are evaluated against
    ``call.__globals__``, or an empty namespace if the callable has none. The
    return annotation is not carried over to the new signature.
    """
    original = inspect.signature(call)
    namespace = getattr(call, "__globals__", {})
    resolved: list[inspect.Parameter] = []
    for parameter in original.parameters.values():
        resolved.append(
            inspect.Parameter(
                name=parameter.name,
                kind=parameter.kind,
                default=parameter.default,
                annotation=get_typed_annotation(parameter.annotation, namespace),
            )
        )
    return inspect.Signature(resolved)
223
+
224
+
225
def get_typed_annotation(annotation: Any, globalns: dict[str, Any]) -> Any:
    """Evaluate ``annotation`` against ``globalns``, best effort.

    Plain strings are wrapped in ``ForwardRef`` first. ``try_eval_type`` is
    used for both the global and local namespace, and only its evaluated value
    is returned.
    """
    candidate = ForwardRef(annotation) if isinstance(annotation, str) else annotation
    evaluated, _ = try_eval_type(candidate, globalns, globalns)
    return evaluated
230
+
231
+
232
# ruff: noqa: UP006
# Every recognised sequence annotation (the ``Sequence`` ABC, the ``typing``
# aliases and the matching builtins) mapped to the concrete container type.
# Insertion order: ``Sequence`` first, then each alias followed by its builtin.
sequence_annotation_to_type = {
    Sequence: list,
    **{
        annotation: container
        for alias, container in (
            (List, list),
            (Tuple, tuple),
            (Set, set),
            (FrozenSet, frozenset),
            (Deque, deque),
        )
        for annotation in (alias, container)
    },
}

# All of the keys above as a tuple, usable as an issubclass() target.
sequence_types = tuple(sequence_annotation_to_type)
248
+
249
+
250
def _annotation_is_sequence(annotation: type[Any] | None) -> bool:
    """Return True if ``annotation`` is a known sequence type.

    ``str`` and ``bytes`` (and their subclasses) are excluded: they are
    sequences in Python but are treated as scalar values here.
    """
    is_text = lenient_issubclass(annotation, (str, bytes))
    return not is_text and lenient_issubclass(annotation, sequence_types)
254
+
255
+
256
def _annotation_is_complex(annotation: type[Any] | None) -> bool:
    """Return True for Pydantic models, mappings, sequences and dataclasses."""
    if lenient_issubclass(annotation, (BaseModel, Mapping)):
        return True
    if _annotation_is_sequence(annotation):
        return True
    return is_dataclass(annotation)
262
+
263
+
264
def field_annotation_is_complex(annotation: type[Any] | None) -> bool:
    """Return True if ``annotation`` cannot be parsed from a single scalar value.

    A union is complex when any of its members is. Otherwise the annotation
    counts as complex if it, or its generic origin, is a model, mapping,
    sequence or dataclass, or if the origin exposes a Pydantic core-schema hook.
    """
    origin = get_origin(annotation)
    if origin is Union or origin is UnionType:
        return any(map(field_annotation_is_complex, get_args(annotation)))
    if _annotation_is_complex(annotation) or _annotation_is_complex(origin):
        return True
    return hasattr(origin, "__pydantic_core_schema__") or hasattr(
        origin, "__get_pydantic_core_schema__"
    )
275
+
276
+
277
def field_annotation_is_scalar(annotation: Any) -> bool:
    """Return True if ``annotation`` describes a single scalar value.

    ``Ellipsis`` counts as scalar so that variadic tuples such as
    ``tuple[int, ...]`` are handled nicely.
    """
    if annotation is Ellipsis:
        return True
    return not field_annotation_is_complex(annotation)
280
+
281
+
282
def create_model_field(
    name: str,
    type_: Any,
    default: Any | None = PydanticUndefined,
    field_info: FieldInfo | None = None,
    alias: str | None = None,
    mode: Literal["validation", "serialization"] = "validation",
) -> ModelField:
    """Wrap an annotation in a ``ModelField``.

    If ``field_info`` is missing (or falsy), a plain ``FieldInfo`` is built
    from ``type_``, ``default`` and ``alias``.

    Raises:
        WellAPIError: if building the ``ModelField`` raises ``RuntimeError`` or
            ``PydanticSchemaGenerationError``, e.g. when ``type_`` is not a
            valid Pydantic field type. The original exception is suppressed.
    """
    info = field_info or FieldInfo(annotation=type_, default=default, alias=alias)
    try:
        model_field = ModelField(name=name, field_info=info, mode=mode)
    except (RuntimeError, PydanticSchemaGenerationError):
        raise WellAPIError(
            "Invalid args for response field! Hint: "
            f"check that {type_} is a valid Pydantic field type. "
            "If you are using a return type annotation that is not a valid Pydantic "
            "field (e.g. Union[Response, dict, None]) you can disable generating the "
            "response model from the type annotation with the path operation decorator "
            "parameter response_model=None. Read more: "
            "https://fastapi.tiangolo.com/tutorial/response-model/"
        ) from None
    return model_field
308
+
309
+
310
def is_scalar_field(field: ModelField) -> bool:
    """Return True if ``field`` holds one scalar and is not a body parameter."""
    info = field.field_info
    if not field_annotation_is_scalar(info.annotation):
        return False
    return not isinstance(info, params.Body)
314
+
315
+
316
def is_scalar_sequence_field(field: ModelField) -> bool:
    """Return True if ``field`` is annotated as a sequence of scalar values."""
    annotation = field.field_info.annotation
    return field_annotation_is_scalar_sequence(annotation)
318
+
319
+
320
def field_annotation_is_scalar_sequence(annotation: type[Any] | None) -> bool:
    """Return True if ``annotation`` is a sequence whose items are scalars.

    For a union, at least one member must be such a sequence, and every other
    member must be scalar. Any complex non-sequence member rejects the union.
    """
    origin = get_origin(annotation)
    if origin is Union or origin is UnionType:
        found_sequence = False
        for member in get_args(annotation):
            if field_annotation_is_scalar_sequence(member):
                found_sequence = True
            elif not field_annotation_is_scalar(member):
                return False
        return found_sequence
    if not field_annotation_is_sequence(annotation):
        return False
    return all(field_annotation_is_scalar(item) for item in get_args(annotation))
335
+
336
+
337
def field_annotation_is_sequence(annotation: type[Any] | None) -> bool:
    """Return True if ``annotation`` (or any union member) is a sequence type.

    Both the annotation itself and its generic origin are checked, so
    ``list[int]`` is recognised through its origin ``list``.
    """
    origin = get_origin(annotation)
    if origin is Union or origin is UnionType:
        return any(field_annotation_is_sequence(member) for member in get_args(annotation))
    return _annotation_is_sequence(annotation) or _annotation_is_sequence(origin)
347
+
348
+
349
@dataclass
class ParamDetails:
    """Result of ``analyze_param`` for one signature parameter."""

    # Annotation used for type checks: the first ``Annotated`` argument when the
    # parameter is ``Annotated[...]``, otherwise the raw annotation (``Any`` if
    # the parameter was not annotated).
    type_annotation: Any
    # ``Depends``/``Security`` marker when the parameter is a sub-dependency.
    depends: params.Depends | None
    # Field wrapper when the parameter is a path/query/header/cookie/body value;
    # ``None`` for dependencies and injected framework objects.
    field: ModelField | None
354
+
355
+
356
def analyze_param(
    *,
    param_name: str,
    annotation: Any,
    value: Any,
    is_path_param: bool,
    type_: Literal["endpoint", "queue", "job"],
) -> ParamDetails:
    """Classify one signature parameter of a handler or dependency.

    The annotation (including ``Annotated[...]`` metadata) and the default value
    decide whether the parameter is:

    * a sub-dependency (``Depends``),
    * an injected framework object (``RequestAPIGateway``, ``ResponseAPIGateway``
      or ``RequestSQS``), or
    * a value field (path, query, header, cookie or body).

    Parameters without a marker become ``Path`` when they appear in the route
    path, ``Body`` when their annotation is not scalar, and ``Query`` otherwise.

    Args:
        param_name: Name of the parameter in the signature.
        annotation: Resolved annotation, or ``inspect.Signature.empty``.
        value: Default value, or ``inspect.Signature.empty``.
        is_path_param: Whether ``param_name`` is a ``{...}`` placeholder in the
            route path.
        type_: Handler kind. ``"queue"`` and ``"job"`` parameters may not use
            ``params.Param`` markers, and ``"job"`` parameters must be ``Body``.

    Returns:
        ``ParamDetails`` with the resolved type annotation, the ``Depends``
        marker (if any) and the ``ModelField`` (if the parameter is a value).

    Raises:
        AssertionError: on conflicting or unsupported declarations. These are
            ``assert`` statements and are skipped under ``python -O``.
    """
    field_info = None
    depends = None
    type_annotation: Any = Any
    use_annotation: Any = Any
    if annotation is not inspect.Signature.empty:
        use_annotation = annotation
        type_annotation = annotation
    # Extract Annotated info
    if get_origin(use_annotation) is Annotated:
        annotated_args = get_args(annotation)
        # The first Annotated argument is the real type; the rest is metadata.
        type_annotation = annotated_args[0]
        fastapi_annotations = [
            arg
            for arg in annotated_args[1:]
            if isinstance(arg, FieldInfo | params.Depends)
        ]
        fastapi_specific_annotations = [
            arg
            for arg in fastapi_annotations
            if isinstance(arg, params.Param | params.Body | params.Depends)
        ]
        # With several framework markers, the last one wins.
        if fastapi_specific_annotations:
            fastapi_annotation: FieldInfo | params.Depends | None = (
                fastapi_specific_annotations[-1]
            )
        else:
            fastapi_annotation = None
        # Set default for Annotated FieldInfo
        if isinstance(fastapi_annotation, FieldInfo):
            # Copy `field_info` because we mutate `field_info.default` below.
            field_info = copy_field_info(
                field_info=fastapi_annotation, annotation=use_annotation
            )
            assert (
                field_info.default is PydanticUndefined
                or field_info.default is Ellipsis
            ), (
                f"`{field_info.__class__.__name__}` default value cannot be set in"
                f" `Annotated` for {param_name!r}. Set the default value with `=` instead."
            )
            if value is not inspect.Signature.empty:
                assert not is_path_param, "Path parameters cannot have default values"
                field_info.default = value
            else:
                field_info.default = PydanticUndefined
        # Get Annotated Depends
        elif isinstance(fastapi_annotation, params.Depends):
            depends = fastapi_annotation
    # Get Depends from default value
    if isinstance(value, params.Depends):
        assert depends is None, (
            "Cannot specify `Depends` in `Annotated` and default value"
            f" together for {param_name!r}"
        )
        assert field_info is None, (
            "Cannot specify a FastAPI annotation in `Annotated` and `Depends` as a"
            f" default value together for {param_name!r}"
        )
        depends = value
    # Get FieldInfo from default value
    elif isinstance(value, FieldInfo):
        assert field_info is None, (
            "Cannot specify FastAPI annotations in `Annotated` and default value"
            f" together for {param_name!r}"
        )
        # NOTE(review): the default-value FieldInfo is used (and later mutated:
        # annotation, alias, in_) without copying, so one FieldInfo instance
        # reused as the default of several parameters is shared between them.
        field_info = value
        field_info.annotation = type_annotation

    # Get Depends from type annotation
    if depends is not None and depends.dependency is None:
        # Copy `depends` before mutating it
        depends = copy(depends)
        depends.dependency = type_annotation

    # Handle non-param type annotations like Request
    if lenient_issubclass(
        type_annotation,
        (RequestAPIGateway, ResponseAPIGateway, RequestSQS),
    ):
        assert depends is None, f"Cannot specify `Depends` for type {type_annotation!r}"
        assert field_info is None, (
            f"Cannot specify FastAPI annotation for type {type_annotation!r}"
        )
    # No field_info and no depends were found in Annotated or in the default value,
    # so infer the parameter kind.
    elif field_info is None and depends is None:
        default_value = value if value is not inspect.Signature.empty else Ellipsis
        if is_path_param:
            # We might check here that `default_value is RequiredParam`, but the fact is that the same
            # parameter might sometimes be a path parameter and sometimes not. See
            # `tests/test_infer_param_optionality.py` for an example.
            field_info = params.Path(annotation=use_annotation)
        elif not field_annotation_is_scalar(annotation=type_annotation):
            field_info = params.Body(annotation=use_annotation, default=default_value)
        else:
            field_info = params.Query(annotation=use_annotation, default=default_value)

    field = None
    # It's a field_info, not a dependency
    if field_info is not None:
        # Queue and job handlers have no path/query/header/cookie inputs.
        if type_ in ("queue", "job"):
            assert not isinstance(field_info, params.Param), (
                f"Cannot use `{field_info.__class__.__name__}` for {type_} param"
            )
        if type_ == "job":
            assert isinstance(field_info, params.Body), (
                f"Cannot use `{field_info.__class__.__name__}` for job param"
            )
        # Handle field_info.in_
        if is_path_param:
            assert isinstance(field_info, params.Path), (
                f"Cannot use `{field_info.__class__.__name__}` for path param"
                f" {param_name!r}"
            )
        elif (
            isinstance(field_info, params.Param)
            and getattr(field_info, "in_", None) is None
        ):
            field_info.in_ = params.ParamTypes.query

        # Without an explicit alias, markers with a truthy `convert_underscores`
        # (header-style) use a dash-separated alias; everything else falls back
        # to the parameter name.
        if not field_info.alias and getattr(field_info, "convert_underscores", None):
            alias = param_name.replace("_", "-")
        else:
            alias = field_info.alias or param_name
        field_info.alias = alias
        field = create_model_field(
            name=param_name,
            type_=use_annotation,
            default=field_info.default,
            alias=alias,
            field_info=field_info,
        )
        if is_path_param:
            assert is_scalar_field(field=field), (
                "Path params must be of one of the supported types"
            )
        elif isinstance(field_info, params.Query):
            # Query values must be scalars, sequences of scalars, or a Pydantic model.
            assert (
                is_scalar_field(field)
                or is_scalar_sequence_field(field)
                or lenient_issubclass(field.type_, BaseModel)
            )

    return ParamDetails(type_annotation=type_annotation, depends=depends, field=field)
507
+
508
+
509
def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
    """Return a shallow copy of ``field_info`` re-targeted at ``annotation``.

    ``from_annotation`` on the same FieldInfo subclass re-parses ``annotation``.
    Its ``metadata`` and ``annotation`` replace those of the copy, and all other
    attributes are kept from ``field_info``. The original is left untouched.
    """
    reparsed = type(field_info).from_annotation(annotation)
    duplicate = copy(field_info)
    duplicate.metadata = reparsed.metadata
    duplicate.annotation = reparsed.annotation
    return duplicate
516
+
517
+
518
+ CacheKey = tuple[Callable[..., Any] | None, tuple[str, ...]]
519
+
520
+
521
def get_flat_dependant(
    dependant: Dependant,
    *,
    skip_repeats: bool = False,
    visited: list[CacheKey] | None = None,
) -> Dependant:
    """Merge ``dependant`` and its whole sub-dependency tree into one Dependant.

    The result holds copies of the parameter lists (path, query, header,
    cookie, body) and the security requirements. Within each list, a
    dependant's own entries come first, followed depth-first by those of its
    sub-dependencies.

    Args:
        dependant: Root of the tree to flatten.
        skip_repeats: Skip sub-dependencies whose ``cache_key`` has already been
            visited, so a shared dependency contributes only once.
        visited: Cache keys seen so far. It is appended to in place and shared
            by the recursive calls.
    """
    seen: list[CacheKey] = [] if visited is None else visited
    seen.append(dependant.cache_key)

    flat = Dependant(
        path_params=dependant.path_params.copy(),
        query_params=dependant.query_params.copy(),
        header_params=dependant.header_params.copy(),
        cookie_params=dependant.cookie_params.copy(),
        body_params=dependant.body_params.copy(),
        security_requirements=dependant.security_requirements.copy(),
        use_cache=dependant.use_cache,
        path=dependant.path,
    )
    merged_attributes = (
        "path_params",
        "query_params",
        "header_params",
        "cookie_params",
        "body_params",
        "security_requirements",
    )
    for child in dependant.dependencies:
        if skip_repeats and child.cache_key in seen:
            continue
        flat_child = get_flat_dependant(child, skip_repeats=skip_repeats, visited=seen)
        for attribute in merged_attributes:
            getattr(flat, attribute).extend(getattr(flat_child, attribute))
    return flat
554
+
555
+
556
def _should_embed_body_fields(fields: list[ModelField]) -> bool:
    """Decide whether body parameters must be nested under their own names.

    Embedding is needed when the body parameters carry more than one distinct
    name, or when the first one asks for it through a truthy ``embed``. Names
    are counted once each because several dependencies may declare the same
    body field.
    """
    if not fields:
        return False
    distinct_names = {field.name for field in fields}
    if len(distinct_names) > 1:
        return True
    return bool(getattr(fields[0].field_info, "embed", None))
571
+
572
+
573
def get_body_field(
    *, flat_dependant: Dependant, name: str, embed_body_fields: bool
) -> ModelField | None:
    """Return one ``ModelField`` describing the whole request body.

    * No body parameters: ``None``.
    * Not embedded: the first body parameter itself.
    * Embedded: a field named and aliased ``"body"``. Its type is a generated
      model ``Body_<name>`` with one attribute per body parameter. The field is
      optional (default ``None``) unless some body parameter is required, and it
      takes ``media_type`` when all ``Body`` parameters agree on one.

    This field only describes the body. Request validation is done per body
    parameter by ``request_body_to_args``.
    """
    body_params = flat_dependant.body_params
    if not body_params:
        return None
    if not embed_body_fields:
        return body_params[0]
    BodyModel = create_body_model(fields=body_params, model_name="Body_" + name)
    body_kwargs: dict[str, Any] = {"annotation": BodyModel, "alias": "body"}
    if not any(param.required for param in body_params):
        body_kwargs["default"] = None

    media_types = [
        param.field_info.media_type
        for param in body_params
        if isinstance(param.field_info, params.Body)
    ]
    if len(set(media_types)) == 1:
        body_kwargs["media_type"] = media_types[0]

    return create_model_field(
        name="body",
        type_=BodyModel,
        alias="body",
        field_info=params.Body(**body_kwargs),
    )
618
+
619
+
620
def create_body_model(
    *, fields: Sequence[ModelField], model_name: str
) -> type[BaseModel]:
    """Create a Pydantic model named ``model_name`` with one attribute per field.

    Each attribute reuses the field's annotation and ``FieldInfo``. If several
    fields share a name, the last one wins.
    """
    definitions: dict[str, Any] = {}
    for field in fields:
        definitions[field.name] = (field.field_info.annotation, field.field_info)
    return create_model(model_name, **definitions)
626
+
627
+
628
@dataclass
class SolvedDependency:
    """Result returned by ``solve_dependencies`` for one Dependant."""

    # Keyword arguments resolved for the dependant's callable, keyed by parameter name.
    values: dict[str, Any]
    # Validation errors from this dependant's parameters and its sub-dependencies.
    errors: list[Any]
    # Response object shared by the whole dependency tree for this request.
    response: ResponseAPIGateway
    # Results of already-called dependencies, keyed by each Dependant's cache_key.
    dependency_cache: dict[tuple[Callable[..., Any], tuple[str]], Any]
634
+
635
+
636
def solve_dependencies(
    *,
    request: RequestAPIGateway,
    dependant: Dependant,
    body: dict[str, Any] | None = None,
    response: ResponseAPIGateway | None = None,
    dependency_cache: dict[tuple[Callable[..., Any], tuple[str]], Any] | None = None,
    embed_body_fields: bool,
) -> SolvedDependency:
    """Resolve the keyword arguments for ``dependant``'s callable from a request.

    Processing order:

    1. Sub-dependencies are solved recursively and called synchronously.
    2. Path, query, header and cookie parameters are read from ``request`` and
       validated.
    3. Body parameters are read from ``body`` and validated.
    4. Special parameters receive the request or response object.

    A sub-dependency whose own inputs fail validation is not called. Its errors
    are collected and processing continues with the next one.

    Args:
        request: The incoming event. NOTE(review): annotated as
            ``RequestAPIGateway``, but the isinstance checks below also
            anticipate ``RequestSQS``. Confirm that SQS requests expose
            ``path_params``/``query_params``/``headers``/``cookies``.
        dependant: The dependency tree to solve.
        body: Parsed request body, if any.
        response: Shared response object. A new one is created when omitted.
        dependency_cache: Results of dependencies already called for this
            request. An empty dict is replaced by a new one, so callers should
            read the cache from the returned ``SolvedDependency``.
        embed_body_fields: Whether body parameters are nested under their names.

    Returns:
        ``SolvedDependency`` with the collected values, errors, the response
        object and the updated cache.
    """
    values: dict[str, Any] = {}
    errors: list[Any] = []
    if response is None:
        response = ResponseAPIGateway()
        # NOTE(review): presumably cleared so that values left unset by the
        # handler can be told apart later; confirm against the caller.
        del response.headers["content-length"]
        response.statusCode = None  # type: ignore
    dependency_cache = dependency_cache or {}
    sub_dependant: Dependant
    for sub_dependant in dependant.dependencies:
        # Type-narrowing casts only; these assignments do not change the values.
        sub_dependant.call = cast(Callable[..., Any], sub_dependant.call)
        sub_dependant.cache_key = cast(
            tuple[Callable[..., Any], tuple[str]], sub_dependant.cache_key
        )
        call = sub_dependant.call
        use_sub_dependant = sub_dependant

        # The sub-dependency's own inputs are solved (and validated) before the
        # cache lookup below, even when a cached result will be reused.
        solved_result = solve_dependencies(
            request=request,
            dependant=use_sub_dependant,
            body=body,
            response=response,
            dependency_cache=dependency_cache,
            embed_body_fields=embed_body_fields,
        )
        dependency_cache.update(solved_result.dependency_cache)
        if solved_result.errors:
            errors.extend(solved_result.errors)
            continue
        if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
            solved = dependency_cache[sub_dependant.cache_key]
        else:
            solved = call(**solved_result.values)
        # Parameter-less dependencies (name is None) are called only for their effects.
        if sub_dependant.name is not None:
            values[sub_dependant.name] = solved
        if sub_dependant.cache_key not in dependency_cache:
            dependency_cache[sub_dependant.cache_key] = solved
    path_values, path_errors = request_params_to_args(
        dependant.path_params, request.path_params
    )
    query_values, query_errors = request_params_to_args(
        dependant.query_params, request.query_params
    )
    header_values, header_errors = request_params_to_args(
        dependant.header_params, request.headers
    )
    cookie_values, cookie_errors = request_params_to_args(
        dependant.cookie_params, request.cookies
    )
    values.update(path_values)
    values.update(query_values)
    values.update(header_values)
    values.update(cookie_values)
    errors += path_errors + query_errors + header_errors + cookie_errors
    if dependant.body_params:
        (
            body_values,
            body_errors,
        ) = request_body_to_args(  # body_params checked above
            body_fields=dependant.body_params,
            received_body=body,
            embed_body_fields=embed_body_fields,
        )
        values.update(body_values)
        errors.extend(body_errors)
    # Inject framework objects into the parameters registered for them.
    if dependant.http_connection_param_name:
        values[dependant.http_connection_param_name] = request
    if dependant.request_param_name and isinstance(request, RequestAPIGateway):
        values[dependant.request_param_name] = request
    if dependant.request_sqs_param_name and isinstance(request, RequestSQS):
        values[dependant.request_sqs_param_name] = request
    if dependant.response_param_name:
        values[dependant.response_param_name] = response
    return SolvedDependency(
        values=values,
        errors=errors,
        response=response,
        dependency_cache=dependency_cache,
    )
723
+
724
+
725
def request_params_to_args(
    fields: Sequence[ModelField],
    received_params: Mapping[str, Any] | QueryParams | Headers,
) -> tuple[dict[str, Any], list[Any]]:
    """Extract and validate the parameters of one location (path, query, ...).

    Args:
        fields: Parameter fields that all belong to the same location.
        received_params: Raw values for that location from the request.

    Returns:
        ``(values, errors)``: validated values keyed by field name, and the
        validation errors. When ``fields`` is a single Pydantic-model field,
        the model is validated from all received parameters and returned
        under that field's name.

    Raises:
        AssertionError: if a field's ``FieldInfo`` is not a ``params.Param``.
    """
    values: dict[str, Any] = {}
    errors: list[dict[str, Any]] = []

    if not fields:
        return values, errors

    first_field = fields[0]
    if len(fields) == 1 and lenient_issubclass(first_field.type_, BaseModel):
        # A lone Pydantic model groups several parameters of this location.
        return _request_params_to_model_args(first_field, received_params)

    # Independent fields: each one is looked up and validated on its own, so
    # the merged parameter dict needed by the model path is not built here.
    for field in fields:
        value = _get_multidict_value(field, received_params)
        field_info = field.field_info
        assert isinstance(field_info, params.Param), (
            "Params must be subclasses of Param"
        )
        loc = (field_info.in_.value, field.alias)
        v_, errors_ = _validate_value_with_model_field(
            field=field, value=value, values=values, loc=loc
        )
        if errors_:
            errors.extend(errors_)
        else:
            values[field.name] = v_
    return values, errors


def _request_params_to_model_args(
    model_field: ModelField,
    received_params: Mapping[str, Any] | QueryParams | Headers,
) -> tuple[dict[str, Any], list[Any]]:
    """Validate a single Pydantic-model parameter from all received params.

    Values for the model's own fields are looked up by alias, with underscore
    conversion for headers, and stored under the field name. Every other
    received key is passed through unchanged, and the model's configuration
    decides how to treat it.
    """
    # If headers are in a Pydantic model, the way to disable convert_underscores
    # would be with Header(convert_underscores=False) at the Pydantic model level
    default_convert_underscores = getattr(
        model_field.field_info, "convert_underscores", True
    )
    params_to_process: dict[str, Any] = {}
    processed_keys: set[str] = set()

    for field in get_cached_model_fields(model_field.type_):
        alias = None
        if isinstance(received_params, Headers):
            # Fields extracted from a model don't have a Header FieldInfo, so
            # fall back to the model-level convert_underscores setting.
            convert_underscores = getattr(
                field.field_info, "convert_underscores", default_convert_underscores
            )
            if convert_underscores:
                alias = (
                    field.alias
                    if field.alias != field.name
                    else field.name.replace("_", "-")
                )
        value = _get_multidict_value(field, received_params, alias=alias)
        if value is not None:
            params_to_process[field.name] = value
        processed_keys.add(alias or field.alias)
        processed_keys.add(field.name)

    for key, value in received_params.items():
        if key not in processed_keys:
            params_to_process[key] = value

    field_info = model_field.field_info
    assert isinstance(field_info, params.Param), (
        "Params must be subclasses of Param"
    )
    loc: tuple[str, ...] = (field_info.in_.value,)
    v_, errors_ = _validate_value_with_model_field(
        field=model_field, value=params_to_process, values={}, loc=loc
    )
    return {model_field.name: v_}, errors_
802
+
803
+
804
def _get_multidict_value(
    field: ModelField, values: Mapping[str, Any], alias: str | None = None
) -> Any:
    """Look up the raw value for ``field`` in ``values``.

    The lookup key is ``alias`` if given, otherwise ``field.alias``. Sequence
    fields read every value for the key when ``values`` is an
    ``ImmutableMultiDict`` or ``Headers``. A missing value, or an empty
    sequence for a sequence field, yields ``None`` for required fields and a
    deep copy of the default for optional ones.
    """
    key = alias or field.alias
    is_sequence = is_sequence_field(field)
    if is_sequence and isinstance(values, ImmutableMultiDict | Headers):
        raw = values.getlist(key)
    else:
        raw = values.get(key, None)
    is_missing = raw is None or (is_sequence and len(raw) == 0)
    if not is_missing:
        return raw
    return None if field.required else deepcopy(field.default)
818
+
819
+
820
def is_sequence_field(field: ModelField) -> bool:
    """Return True if ``field``'s annotation is (or contains in a union) a sequence."""
    annotation = field.field_info.annotation
    return field_annotation_is_sequence(annotation)
822
+
823
+
824
def _validate_value_with_model_field(
    *, field: ModelField, value: Any, values: dict[str, Any], loc: tuple[str, ...]
) -> tuple[Any, list[Any]]:
    """Validate ``value`` against ``field``.

    Returns ``(validated_value, [])`` on success and ``(None, errors)`` on
    failure. A ``None`` value counts as missing: required fields get a
    "missing" error at ``loc``, and optional fields get a deep copy of their
    default.
    """
    if value is None:
        if not field.required:
            return deepcopy(field.default), []
        return None, [get_missing_field_error(loc=loc)]
    validated, outcome = field.validate(value, values, loc=loc)
    if isinstance(outcome, Exception):
        return None, [outcome]
    if isinstance(outcome, list):
        return None, _regenerate_error_with_loc(errors=outcome, loc_prefix=())
    return validated, []
840
+
841
+
842
def get_missing_field_error(loc: tuple[str, ...]) -> dict[str, Any]:
    """Build a Pydantic-style "Field required" error dict located at ``loc``.

    The error has no documentation URL, and its ``input`` is set to ``None``.
    """
    exc = ValidationError.from_exception_data(
        "Field required", [{"type": "missing", "loc": loc, "input": {}}]
    )
    missing_error = exc.errors(include_url=False)[0]
    missing_error["input"] = None
    return missing_error  # type: ignore[return-value]
848
+
849
+
850
def request_body_to_args(
    body_fields: list[ModelField],
    received_body: dict[str, Any] | None,
    embed_body_fields: bool,
) -> tuple[dict[str, Any], list[dict[str, Any]]]:
    """Validate the request body against the body parameters.

    With a single, non-embedded body field, the whole body is that field's
    value (error location ``("body",)``). Otherwise each field is read from the
    body by alias (location ``("body", alias)``). A body that is not a mapping
    (e.g. a JSON list) reports every field as missing.

    Returns:
        ``(values, errors)``: validated values keyed by field name, and errors.

    Raises:
        AssertionError: if ``body_fields`` is empty.
    """
    values: dict[str, Any] = {}
    errors: list[dict[str, Any]] = []
    assert body_fields, "request_body_to_args() should be called with fields"
    first_field = body_fields[0]

    if len(body_fields) == 1 and not embed_body_fields:
        whole_loc: tuple[str, ...] = ("body",)
        validated, field_errors = _validate_value_with_model_field(
            field=first_field, value=received_body, values=values, loc=whole_loc
        )
        return {first_field.name: validated}, field_errors

    for field in body_fields:
        loc: tuple[str, ...] = ("body", field.alias)
        value: Any = None
        if received_body is not None:
            try:
                value = received_body.get(field.alias)
            except AttributeError:
                # The received body is not a mapping (e.g. a list).
                errors.append(get_missing_field_error(loc))
                continue
        validated, field_errors = _validate_value_with_model_field(
            field=field, value=value, values=values, loc=loc
        )
        if field_errors:
            errors.extend(field_errors)
        else:
            values[field.name] = validated
    return values, errors
886
+
887
+
888
@lru_cache
def get_cached_model_fields(model: type[BaseModel]) -> list[ModelField]:
    """Return one ``ModelField`` per field of ``model``, memoised per model class.

    Memoisation uses ``functools.lru_cache`` with its default ``maxsize`` (128).
    The same list object is returned on cache hits, so callers must not
    mutate it.
    """
    model_fields: list[ModelField] = []
    for field_name, info in model.model_fields.items():
        model_fields.append(ModelField(field_info=info, name=field_name))
    return model_fields
894
+
895
+
896
def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
    """Return ``call``'s resolved return annotation, or ``None`` if it has none.

    A string or forward-reference annotation is evaluated against
    ``call.__globals__``, or an empty namespace if the callable has none.
    """
    return_annotation = inspect.signature(call).return_annotation
    if return_annotation is inspect.Signature.empty:
        return None
    namespace = getattr(call, "__globals__", {})
    return get_typed_annotation(return_annotation, namespace)
905
+
906
+
907
def _get_flat_fields_from_params(fields: List[ModelField]) -> List[ModelField]:
    """Expand a lone Pydantic-model parameter into the model's own fields.

    Any other list of fields, including an empty one, is returned unchanged
    (the same list object).
    """
    if len(fields) == 1 and lenient_issubclass(fields[0].type_, BaseModel):
        return get_cached_model_fields(fields[0].type_)
    return fields
915
+
916
+
917
def get_flat_params(dependant: Dependant) -> List[ModelField]:
    """List all path, query, header and cookie parameters of the dependency tree.

    Repeated sub-dependencies are skipped, and a lone Pydantic-model parameter
    of a location is expanded into the model's fields. The result is a new list
    in path, query, header, cookie order.
    """
    flat = get_flat_dependant(dependant, skip_repeats=True)
    collected: List[ModelField] = []
    for group in (
        flat.path_params,
        flat.query_params,
        flat.header_params,
        flat.cookie_params,
    ):
        collected.extend(_get_flat_fields_from_params(group))
    return collected