modmex-lambda 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- modmex_lambda/__init__.py +62 -0
- modmex_lambda/data_classes/__init__.py +49 -0
- modmex_lambda/data_classes/api_gateway_authorizer_event.py +38 -0
- modmex_lambda/data_classes/api_gateway_proxy_event.py +328 -0
- modmex_lambda/data_classes/api_gateway_websocket_event.py +40 -0
- modmex_lambda/data_classes/cognito_user_pool_event.py +599 -0
- modmex_lambda/data_classes/common.py +441 -0
- modmex_lambda/event_handler/__init__.py +45 -0
- modmex_lambda/event_handler/api_gateway.py +331 -0
- modmex_lambda/event_handler/constants.py +3 -0
- modmex_lambda/event_handler/content_types.py +13 -0
- modmex_lambda/event_handler/cors.py +97 -0
- modmex_lambda/event_handler/dependencies/__init__.py +0 -0
- modmex_lambda/event_handler/dependencies/compat.py +231 -0
- modmex_lambda/event_handler/dependencies/dependant.py +279 -0
- modmex_lambda/event_handler/dependencies/dependency_middleware.py +423 -0
- modmex_lambda/event_handler/dependencies/depends.py +184 -0
- modmex_lambda/event_handler/dependencies/params.py +317 -0
- modmex_lambda/event_handler/dependencies/types.py +14 -0
- modmex_lambda/event_handler/exception_handler.py +70 -0
- modmex_lambda/event_handler/exceptions.py +72 -0
- modmex_lambda/event_handler/gateway_response.py +96 -0
- modmex_lambda/event_handler/middlewares.py +33 -0
- modmex_lambda/event_handler/params.py +44 -0
- modmex_lambda/event_handler/request.py +70 -0
- modmex_lambda/event_handler/response.py +60 -0
- modmex_lambda/event_handler/routing.py +507 -0
- modmex_lambda/event_handler/routing_fallbacks.py +92 -0
- modmex_lambda/event_handler/types.py +31 -0
- modmex_lambda/event_sources.py +53 -0
- modmex_lambda/exceptions.py +3 -0
- modmex_lambda/logging.py +99 -0
- modmex_lambda/params.py +3 -0
- modmex_lambda/parser.py +47 -0
- modmex_lambda/request.py +3 -0
- modmex_lambda/resolver.py +3 -0
- modmex_lambda/response.py +3 -0
- modmex_lambda/routing.py +3 -0
- modmex_lambda/shared/__init__.py +0 -0
- modmex_lambda/shared/cookies.py +84 -0
- modmex_lambda/shared/headers_serializer.py +65 -0
- modmex_lambda/shared/json_encoder.py +53 -0
- modmex_lambda/shared/types.py +4 -0
- modmex_lambda/validation.py +178 -0
- modmex_lambda-0.1.0.dist-info/METADATA +375 -0
- modmex_lambda-0.1.0.dist-info/RECORD +48 -0
- modmex_lambda-0.1.0.dist-info/WHEEL +4 -0
- modmex_lambda-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,279 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import inspect
|
|
4
|
+
import re
|
|
5
|
+
from typing import Any, Callable, ForwardRef, Union, cast, get_args, get_origin
|
|
6
|
+
|
|
7
|
+
from modmex import BaseModel
|
|
8
|
+
|
|
9
|
+
from modmex_lambda.event_handler.dependencies.compat import ModelField, Required, create_body_model, is_scalar_field
|
|
10
|
+
from modmex_lambda.event_handler.dependencies.depends import DependencyParam, _get_depends_from_annotation
|
|
11
|
+
from modmex_lambda.event_handler.dependencies.params import (
|
|
12
|
+
Body,
|
|
13
|
+
Dependant,
|
|
14
|
+
File,
|
|
15
|
+
Form,
|
|
16
|
+
Param,
|
|
17
|
+
ParamTypes,
|
|
18
|
+
analyze_param,
|
|
19
|
+
create_response_field,
|
|
20
|
+
get_flat_dependant,
|
|
21
|
+
)
|
|
22
|
+
from modmex_lambda.event_handler.dependencies.types import UnionType
|
|
23
|
+
from modmex_lambda.event_handler.request import Request
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
    """Append a non-body parameter field to the matching list on *dependant*.

    The field's ``field_info.in_`` picks the destination: ``path_params``,
    ``query_params``, ``header_params`` or ``cookie_params``.

    Raises:
        AssertionError: if ``field_info.in_`` is none of those four locations.
    """
    location = cast(Param, field.field_info).in_
    bucket = {
        ParamTypes.path: dependant.path_params,
        ParamTypes.query: dependant.query_params,
        ParamTypes.header: dependant.header_params,
        ParamTypes.cookie: dependant.cookie_params,
    }.get(location)

    if bucket is None:
        raise AssertionError(f"Unsupported param type: {location}")

    bucket.append(field)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def resolve_forward_ref_lenient(
    type_hint: Any,
    globalns: dict[str, Any] | None = None,
    localns: dict[str, Any] | None = None,
) -> Any:
    """Best-effort evaluation of a string or ``ForwardRef`` type hint.

    A string or ``ForwardRef`` is evaluated with ``eval`` against *globalns*/*localns*.
    A missing or empty *globalns* becomes a fresh dict, and a missing or empty
    *localns* falls back to the global scope. Any other object is returned unchanged.

    If evaluation raises, the hint is returned unresolved instead of propagating the
    error: a string is wrapped in a new ``ForwardRef`` and a ``ForwardRef`` is returned
    as-is.

    NOTE: this uses ``eval``, so it must only ever see annotation text taken from
    source code, never request data.
    """
    global_scope = globalns or {}
    local_scope = localns or global_scope

    if isinstance(type_hint, str):
        expression = type_hint
    elif isinstance(type_hint, ForwardRef):
        expression = type_hint.__forward_arg__
    else:
        return type_hint

    try:
        return eval(expression, global_scope, local_scope)
    except Exception:
        # Deliberately lenient: leave unresolvable hints as forward references.
        return type_hint if isinstance(type_hint, ForwardRef) else ForwardRef(type_hint)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def get_typed_annotation(annotation: Any, globalns: dict[str, Any], localns: dict[str, Any] | None = None) -> Any:
    """Resolve a string annotation via ``resolve_forward_ref_lenient``.

    Non-string annotations are returned untouched. String annotations are wrapped in a
    ``ForwardRef`` and evaluated. *localns* defaults to *globalns* when missing or empty.
    """
    if not isinstance(annotation, str):
        return annotation

    scope = localns or globalns
    return resolve_forward_ref_lenient(ForwardRef(annotation), globalns, scope)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def get_closure_namespace(call: Callable[..., Any]) -> dict[str, Any]:
    """Collect names visible to *call* that are not in its module globals.

    The result starts as a copy of ``call.__modmex_lambda_localns__`` (if set and
    non-empty). It is then updated with the current values of the free variables
    captured in ``call.__closure__``.

    Closure cells whose variable has not been bound yet (for example a name the
    enclosing function assigns only after creating *call*) are skipped. Reading
    ``cell_contents`` on such a cell raises ``ValueError``, which previously aborted
    signature analysis.
    """
    namespace = dict(getattr(call, "__modmex_lambda_localns__", {}) or {})

    closure = getattr(call, "__closure__", None)
    if not closure:
        return namespace

    code = getattr(call, "__code__", None)
    if code is None:
        # A closure without a code object cannot be mapped back to variable names.
        return namespace

    for name, cell in zip(code.co_freevars, closure):
        try:
            namespace[name] = cell.cell_contents
        except ValueError:
            # Empty cell: the enclosing variable is not bound yet, so there is no value to expose.
            continue

    return namespace
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
    """Return ``inspect.signature(call)`` with string annotations resolved where possible.

    Names are looked up in the callable's ``__globals__`` (an empty dict when it has
    none), overlaid with ``get_closure_namespace(call)``. String annotations that cannot
    be resolved come back as ``ForwardRef`` objects. The return annotation is resolved
    the same way, unless it is absent.
    """
    original = inspect.signature(call)
    module_globals = getattr(call, "__globals__", {})
    lookup = {**module_globals, **get_closure_namespace(call)}

    parameters = [
        parameter.replace(annotation=get_typed_annotation(parameter.annotation, module_globals, lookup))
        for parameter in original.parameters.values()
    ]

    returns = original.return_annotation
    if returns is inspect.Signature.empty:
        return inspect.Signature(parameters)

    return inspect.Signature(parameters, return_annotation=get_typed_annotation(returns, module_globals, lookup))
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def get_path_param_names(path: str) -> set[str]:
    """Return the distinct names of ``<name>`` placeholders in a route *path*.

    Only word characters (``\\w+``) between angle brackets count as a placeholder.
    """
    return {match.group(1) for match in re.finditer(r"<(\w+)>", path)}
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def is_request_annotation(annotation: Any) -> bool:
    """Tell whether a parameter annotation refers to ``Request``.

    True for ``Request`` itself, and for unions (``typing.Union`` or ``UnionType``)
    that list ``Request`` among their members, such as ``Request | None``.
    """
    if annotation is Request:
        return True

    origin = get_origin(annotation)
    if origin is not Union and origin is not UnionType:
        return False

    return Request in get_args(annotation)
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def get_dependant(
    *,
    path: str,
    call: Callable[..., Any],
    name: str | None = None,
    responses: dict[int, Any] | None = None,
) -> Dependant:
    """Build the ``Dependant`` tree describing the inputs of *call*.

    Parameters of *call*'s (annotation-resolved) signature are handled as follows:

    * annotated with ``Request``: skipped;
    * carrying a ``Depends`` marker: recorded as a ``DependencyParam`` whose
      ``dependant`` is built recursively (nested dependants get no ``name``);
    * anything else: classified by ``_add_route_param``, with names that appear as
      ``<placeholder>`` in *path* treated as path parameters.

    The return annotation, if any, is attached by ``_add_return_annotation``.

    Args:
        path: Route path; its placeholders identify path parameters.
        call: The endpoint or dependency callable to analyse.
        name: Stored on the top-level ``Dependant``.
        responses: Accepted for interface compatibility; not used here.
    """
    result = Dependant(call=call, name=name, path=path)
    names_in_path = get_path_param_names(path)
    signature = get_typed_signature(call)

    for parameter_name, parameter in signature.parameters.items():
        annotation = parameter.annotation
        if is_request_annotation(annotation):
            continue

        depends = _get_depends_from_annotation(annotation)
        if depends is None:
            _add_route_param(
                dependant=result,
                param_name=parameter_name,
                param=parameter,
                is_path_param=parameter_name in names_in_path,
            )
            continue

        # Hand the parent's extra namespace down so the dependency's string
        # annotations can be resolved against the same names.
        _inherit_local_namespace(parent=call, dependency=depends.dependency)
        sub_dependant = get_dependant(path=path, call=depends.dependency)
        result.dependencies.append(
            DependencyParam(
                param_name=parameter_name,
                depends=depends,
                dependant=sub_dependant,
            ),
        )

    _add_return_annotation(result, signature)
    return result
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def _inherit_local_namespace(*, parent: Callable[..., Any], dependency: Callable[..., Any]) -> None:
|
|
156
|
+
if hasattr(dependency, "__modmex_lambda_localns__"):
|
|
157
|
+
return
|
|
158
|
+
setattr(dependency, "__modmex_lambda_localns__", getattr(parent, "__modmex_lambda_localns__", {}))
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def _add_route_param(
    *,
    dependant: Dependant,
    param_name: str,
    param: inspect.Parameter,
    is_path_param: bool,
) -> None:
    """Turn one signature parameter into a field and file it on *dependant*.

    The field comes from ``analyze_param``. Fields that ``is_body_param`` classifies
    as body fields go to ``dependant.body_params``; all others go through
    ``add_param_to_fields``.

    Raises:
        AssertionError: if ``analyze_param`` returns ``None``.
    """
    model_field = analyze_param(
        param_name=param_name,
        annotation=param.annotation,
        value=param.default,
        is_path_param=is_path_param,
        is_response_param=False,
    )
    if model_field is None:
        raise AssertionError(f"Parameter field is None for param: {param_name}")

    goes_in_body = is_body_param(param_field=model_field, is_path_param=is_path_param)
    if not goes_in_body:
        add_param_to_fields(field=model_field, dependant=dependant)
        return

    dependant.body_params.append(model_field)
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
def _add_return_annotation(dependant: Dependant, endpoint_signature: inspect.Signature) -> None:
    """Store a field for the endpoint's return annotation on ``dependant.return_param``.

    Does nothing when the signature has no return annotation.

    Raises:
        AssertionError: if ``analyze_param`` returns ``None`` for the annotation.
    """
    annotation = endpoint_signature.return_annotation
    if annotation is not inspect.Signature.empty:
        return_field = analyze_param(
            param_name="return",
            annotation=annotation,
            value=None,
            is_path_param=False,
            is_response_param=True,
        )
        if return_field is None:
            raise AssertionError("Param field is None for return annotation")
        dependant.return_param = return_field
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool:
    """Decide whether *param_field* is read from the request body.

    * Scalar fields are never body fields.
    * Non-scalar path parameters are rejected.
    * Non-scalar fields whose ``field_info`` is a ``Param`` are not body fields.
    * Non-scalar fields whose ``field_info`` is a ``Body`` are body fields.

    Raises:
        AssertionError: for a non-scalar path parameter, or for a non-scalar field whose
            ``field_info`` is neither ``Param`` nor ``Body``.
    """
    if is_scalar_field(field=param_field):
        return False

    if is_path_param:
        raise AssertionError("Path params must be of one of the supported types")

    info = param_field.field_info
    if isinstance(info, Param):
        return False
    if isinstance(info, Body):
        return True

    raise AssertionError(f"Param: {param_field.name} can only be a request body, use Body()")
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
def get_flat_params(dependant: Dependant) -> list[ModelField]:
    """Return every non-body parameter field of *dependant* and its sub-dependencies.

    The result is a new list, ordered path, query, header, then cookie fields.
    """
    flattened = get_flat_dependant(dependant)
    return [
        *flattened.path_params,
        *flattened.query_params,
        *flattened.header_params,
        *flattened.cookie_params,
    ]
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
def get_body_field(*, dependant: Dependant, name: str) -> ModelField | None:
    """Build the single field describing the request body of *dependant*.

    Returns:
        * ``None`` when there are no body parameters (including in sub-dependencies).
        * The first body field as-is, when all body fields share one name and that field
          is not marked ``embed``.
        * Otherwise, a synthetic field named and aliased ``"body"``, typed as a model
          ``Body_{name}`` built from all body fields. It is required if any body field is
          required. As a side effect, every body field's ``field_info.embed`` is set to
          ``True``.
    """
    flat = get_flat_dependant(dependant)
    body_params = flat.body_params
    if not body_params:
        return None

    first = body_params[0]
    distinct_names = {param.name for param in body_params}
    if len(distinct_names) == 1 and not getattr(first.field_info, "embed", None):
        return first

    for param in body_params:
        setattr(param.field_info, "embed", True)  # noqa: B010

    body_model = create_body_model(fields=body_params, model_name=f"Body_{name}")
    is_required = any(param.required for param in body_params)
    info_cls, info_kwargs = get_body_field_info(
        body_model=body_model,
        flat_dependant=flat,
        required=is_required,
    )

    return create_response_field(
        name="body",
        type_=body_model,
        default=Required if is_required else None,
        alias="body",
        field_info=info_cls(**info_kwargs),
    )
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
def get_body_field_info(
    *,
    body_model: type[BaseModel],
    flat_dependant: Dependant,
    required: bool,
) -> tuple[type[Body], dict[str, Any]]:
    """Return the field-info class and keyword arguments for a combined body field.

    The class is always ``Body``. The keyword arguments contain:

    * ``annotation`` (*body_model*) and ``alias`` (``"body"``);
    * ``default=None`` when *required* is false;
    * ``media_type``: ``multipart/form-data`` if any body field uses ``File``, otherwise
      ``application/x-www-form-urlencoded`` if any uses ``Form``, otherwise the media
      type shared by all ``Body`` fields (omitted when they disagree or there are none).
    """
    kwargs: dict[str, Any] = {"annotation": body_model, "alias": "body"}
    if not required:
        kwargs["default"] = None

    infos = [field.field_info for field in flat_dependant.body_params]

    if any(isinstance(info, File) for info in infos):
        kwargs["media_type"] = "multipart/form-data"
    elif any(isinstance(info, Form) for info in infos):
        kwargs["media_type"] = "application/x-www-form-urlencoded"
    else:
        media_types = [info.media_type for info in infos if isinstance(info, Body) and hasattr(info, "media_type")]
        if len(set(media_types)) == 1:
            kwargs["media_type"] = media_types[0]

    return Body, kwargs
|
|
@@ -0,0 +1,423 @@
|
|
|
1
|
+
from typing import Any, Mapping, MutableMapping, Sequence, Union, get_args, get_origin
|
|
2
|
+
|
|
3
|
+
from modmex import BaseModel, FieldInfo
|
|
4
|
+
from modmex_lambda.exceptions import RequestValidationError
|
|
5
|
+
from modmex_lambda.event_handler.middlewares import IMiddleware, NextMiddleware
|
|
6
|
+
from modmex_lambda.event_handler.dependencies.compat import (
|
|
7
|
+
_normalize_errors,
|
|
8
|
+
_regenerate_error_with_loc,
|
|
9
|
+
field_annotation_is_sequence,
|
|
10
|
+
get_missing_field_error,
|
|
11
|
+
lenient_issubclass,
|
|
12
|
+
is_scalar_field
|
|
13
|
+
)
|
|
14
|
+
from modmex_lambda.event_handler.types import IApiGatewayResolver
|
|
15
|
+
from modmex_lambda.event_handler.routing import IRoute
|
|
16
|
+
|
|
17
|
+
from modmex_lambda.event_handler.dependencies.params import ModelField, Param, UploadFile
|
|
18
|
+
from modmex_lambda.event_handler.dependencies.types import UnionType
|
|
19
|
+
from modmex_lambda.event_handler.response import Response
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class DependencyMiddleware(IMiddleware):
    """Validate the current route's parameters before handing over to the next middleware.

    The route is taken from ``app.context["_route"]``. Its path, query, header and cookie
    fields are validated against the route args, the event's resolved query string,
    headers and cookies. When the route declares body fields, the body from ``_get_body``
    is validated too. All errors are gathered and raised together as a
    ``RequestValidationError``. On success, ``app.context["_route_args"]`` is replaced by
    the validated values.

    NOTE(review): the body is read only from ``current_event.json_body``. The multipart
    helpers in this module (``_parse_multipart_body``, ``_extract_multipart_boundary``)
    are not called here; confirm how form/file bodies are meant to be decoded.
    """

    def _get_body(self, app: IApiGatewayResolver) -> Any:
        """Return the payload used for body-parameter validation (the event's ``json_body``)."""
        return app.current_event.json_body

    def handler(self, app: IApiGatewayResolver, next_middleware: NextMiddleware) -> Response:
        """Validate route inputs, store them as route args, then call *next_middleware*.

        Raises:
            RequestValidationError: if any parameter or body field fails validation.
        """
        route: IRoute = app.context.get("_route")
        dependant = route.dependant

        # Each entry pairs declared fields with the raw values to validate them against.
        # Query/header mappings are normalised (in place) before validation.
        sources = [
            (dependant.path_params, app.context.get("_route_args")),
            (
                dependant.query_params,
                _normalize_multi_params(app.current_event.resolved_query_string_parameters, dependant.query_params),
            ),
            (
                dependant.header_params,
                _normalize_multi_params(app.current_event.resolved_headers_field, dependant.header_params),
            ),
            (dependant.cookie_params, getattr(app.current_event, "resolved_cookies_field", {})),
        ]

        validated: dict[str, Any] = {}
        problems: list[Any] = []

        for fields, received in sources:
            resolved, field_errors = _request_params_to_args(fields, received)
            validated.update(resolved)
            problems.extend(field_errors)

        body_fields = dependant.body_params
        if body_fields:
            resolved, field_errors = _request_body_to_args(
                required_params=body_fields,
                received_body=self._get_body(app),
            )
            validated.update(resolved)
            problems.extend(field_errors)

        if problems:
            raise RequestValidationError(_normalize_errors(problems))

        app.context["_route_args"] = validated
        return next_middleware(app)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def _request_params_to_args(
    required_params: Sequence[ModelField],
    received_params: Mapping[str, Any],
) -> tuple[dict[str, Any], list[dict[str, Any]]]:
    """
    Validate non-body parameters against the values received with the request.

    Each field is looked up in *received_params* by its alias. A missing (``None``)
    value is either reported as an error (required field) or replaced by the field's
    default. Error locations take the form ``(field_info.in_.value, alias)``.

    Returns:
        A ``(values, errors)`` pair: validated values keyed by field name, and the list
        of validation errors.

    Raises:
        AssertionError: if a field's ``field_info`` is not a ``Param``.
    """
    values: dict[str, Any] = {}
    errors: list[dict[str, Any]] = []

    for field in required_params:
        info = field.field_info
        # Fail fast on misconfigured fields rather than producing confusing errors later.
        if not isinstance(info, Param):
            raise AssertionError(f"Expected Param field_info, got {info}")

        location = (info.in_.value, field.alias)
        raw = received_params.get(field.alias)

        if raw is not None:
            values[field.name] = _validate_field(field=field, value=raw, loc=location, existing_errors=errors)
        else:
            _handle_missing_field_value(field, values, errors, location)

    return values, errors
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def _request_body_to_args(
    required_params: list[ModelField],
    received_body: dict[str, Any] | None,
) -> tuple[dict[str, Any], list[dict[str, Any]]]:
    """
    Convert the request body to a dictionary of values using validation, and returns a list of errors.

    When a single, non-embedded body field is declared, the whole body is that field's
    value (see ``_get_embed_body``). Otherwise each field is read from the body by alias.

    Per field:

    * If the body cannot be read as a mapping, a missing-field error is recorded once and
      the field is skipped. Previously it then also fell through to the missing-value
      handling, which duplicated the error for required fields or silently applied a
      default for optional ones.
    * A missing (``None``) value is reported as an error for required fields, or replaced
      by the default otherwise.
    * ``UploadFile`` values are stored as-is; everything else is validated.

    Returns:
        A ``(values, errors)`` pair keyed by field name.
    """
    values: dict[str, Any] = {}
    errors: list[dict[str, Any]] = []

    if not required_params:
        # Nothing declared; also avoids indexing an empty list below.
        return values, errors

    received_body, field_alias_omitted = _get_embed_body(
        field=required_params[0],
        required_params=required_params,
        received_body=received_body,
    )

    for field in required_params:
        loc = _get_body_field_location(field, field_alias_omitted)

        errors_before = len(errors)
        value = _extract_field_value_from_body(field, received_body, loc, errors)
        if len(errors) > errors_before:
            # The body was not a mapping; the error for this location is already recorded.
            continue

        # If we don't have a value, see if it's required or has a default
        if value is None:
            _handle_missing_field_value(field, values, errors, loc)
            continue

        value = _normalize_field_value(value=value, field_info=field.field_info)

        # UploadFile objects bypass Pydantic validation — they're already constructed
        if isinstance(value, UploadFile):
            values[field.name] = value
        else:
            values[field.name] = _validate_field(field=field, value=value, loc=loc, existing_errors=errors)

    return values, errors
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def _get_body_field_location(field: ModelField, field_alias_omitted: bool) -> tuple[str, ...]:
|
|
133
|
+
"""Get the location tuple for a body field based on whether the field alias is omitted."""
|
|
134
|
+
if field_alias_omitted:
|
|
135
|
+
return ("body",)
|
|
136
|
+
return ("body", field.alias)
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def _extract_field_value_from_body(
|
|
140
|
+
field: ModelField,
|
|
141
|
+
received_body: dict[str, Any] | None,
|
|
142
|
+
loc: tuple[str, ...],
|
|
143
|
+
errors: list[dict[str, Any]],
|
|
144
|
+
) -> Any | None:
|
|
145
|
+
"""Extract field value from the received body, handling potential AttributeError."""
|
|
146
|
+
if received_body is None:
|
|
147
|
+
return None
|
|
148
|
+
|
|
149
|
+
try:
|
|
150
|
+
return received_body.get(field.alias)
|
|
151
|
+
except AttributeError:
|
|
152
|
+
errors.append(get_missing_field_error(loc))
|
|
153
|
+
return None
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def _handle_missing_field_value(
|
|
157
|
+
field: ModelField,
|
|
158
|
+
values: dict[str, Any],
|
|
159
|
+
errors: list[dict[str, Any]],
|
|
160
|
+
loc: tuple[str, ...],
|
|
161
|
+
) -> None:
|
|
162
|
+
"""Handle the case when a field value is missing."""
|
|
163
|
+
if field.required:
|
|
164
|
+
errors.append(get_missing_field_error(loc))
|
|
165
|
+
else:
|
|
166
|
+
values[field.name] = field.get_default()
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def _is_or_contains_sequence(annotation: Any) -> bool:
    """
    Report whether *annotation* is, or wraps, a sequence type.

    Recognised forms include:
    - ``List[Model]``: a sequence itself.
    - ``Union[Model, List[Model]]`` / ``Optional[List[Model]]``: any union member
      (checked recursively) that is or wraps a sequence.
    - ``RootModel[List[Model]]``: a root model whose ``root`` annotation is or wraps a
      sequence, including ``RootModel[Union[Model, List[Model]]]``.
    - ``Optional[RootModel[List[Model]]]``: the union and root-model rules combined.
    """
    if field_annotation_is_sequence(annotation):
        return True

    origin = get_origin(annotation)
    is_union = origin is Union or origin is UnionType
    if is_union and any(_is_or_contains_sequence(member) for member in get_args(annotation)):
        return True

    is_root_model = lenient_issubclass(annotation, BaseModel) and getattr(annotation, "__pydantic_root_model__", False)
    if not is_root_model:
        return False

    if not (hasattr(annotation, "model_fields") and "root" in annotation.model_fields):
        return False

    return _is_or_contains_sequence(annotation.model_fields["root"].annotation)
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def _normalize_field_value(value: Any, field_info: FieldInfo) -> Any:
    """Apply ``_normalize_value`` using the annotation declared on *field_info*."""
    declared = field_info.annotation
    return _normalize_value(value=value, annotation=declared)
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
def _normalize_value(value: Any, annotation: Any) -> Any:
    """Adapt a raw request value to the shape its annotation expects.

    * An ``UploadFile`` for a ``bytes`` annotation becomes its ``content``.
    * For annotations that are not (and do not wrap) a sequence, a non-empty list is
      collapsed to its first element.
    * Everything else is returned unchanged.
    """
    if isinstance(value, UploadFile) and annotation is bytes:
        return value.content

    collapse = not _is_or_contains_sequence(annotation) and isinstance(value, list) and bool(value)
    return value[0] if collapse else value
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
def _validate_field(
|
|
220
|
+
*,
|
|
221
|
+
field: ModelField,
|
|
222
|
+
value: Any,
|
|
223
|
+
loc: tuple[str, ...],
|
|
224
|
+
existing_errors: list[dict[str, Any]],
|
|
225
|
+
):
|
|
226
|
+
"""
|
|
227
|
+
Validate a field, and append any errors to the existing_errors list.
|
|
228
|
+
"""
|
|
229
|
+
validated_value, errors = field.validate(value=value, loc=loc)
|
|
230
|
+
|
|
231
|
+
if isinstance(errors, list):
|
|
232
|
+
processed_errors = _regenerate_error_with_loc(errors=errors, loc_prefix=())
|
|
233
|
+
existing_errors.extend(processed_errors)
|
|
234
|
+
elif errors:
|
|
235
|
+
existing_errors.append(errors)
|
|
236
|
+
|
|
237
|
+
return validated_value
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
def _get_embed_body(
|
|
241
|
+
*,
|
|
242
|
+
field: ModelField,
|
|
243
|
+
required_params: list[ModelField],
|
|
244
|
+
received_body: dict[str, Any] | None,
|
|
245
|
+
) -> tuple[dict[str, Any] | None, bool]:
|
|
246
|
+
field_info = field.field_info
|
|
247
|
+
embed = getattr(field_info, "embed", None)
|
|
248
|
+
|
|
249
|
+
# If the field is an embed, and the field alias is omitted, we need to wrap the received body in the field alias.
|
|
250
|
+
field_alias_omitted = len(required_params) == 1 and not embed
|
|
251
|
+
if field_alias_omitted:
|
|
252
|
+
received_body = {field.alias: received_body}
|
|
253
|
+
|
|
254
|
+
return received_body, field_alias_omitted
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
def _normalize_multi_params(
|
|
258
|
+
input_dict: MutableMapping[str, Any],
|
|
259
|
+
params: Sequence[ModelField],
|
|
260
|
+
) -> MutableMapping[str, Any]:
|
|
261
|
+
for param in params:
|
|
262
|
+
if is_scalar_field(param):
|
|
263
|
+
_process_scalar_param(input_dict, param)
|
|
264
|
+
elif lenient_issubclass(param.field_info.annotation, BaseModel):
|
|
265
|
+
_process_model_param(input_dict, param)
|
|
266
|
+
return input_dict
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
def _process_scalar_param(input_dict: MutableMapping[str, Any], param: ModelField) -> None:
|
|
270
|
+
try:
|
|
271
|
+
value = input_dict[param.alias]
|
|
272
|
+
if isinstance(value, list) and len(value) == 1:
|
|
273
|
+
input_dict[param.alias] = value[0]
|
|
274
|
+
except KeyError:
|
|
275
|
+
pass
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
def _process_model_param(input_dict: MutableMapping[str, Any], param: ModelField) -> None:
    """Gather the inputs for a model-typed parameter under ``param.alias``.

    For every field of the model class (``param.field_info.annotation``), the value is
    looked up via ``_get_param_value``. Values that are not ``None`` are normalised
    with ``_normalize_value`` and stored under the field's alias. The resulting dict
    replaces ``input_dict[param.alias]``.
    """
    model_class = param.field_info.annotation

    collected: dict[str, Any] = {}
    for name, alias, annotation in _iter_model_fields(model_class):
        raw = _get_param_value(input_dict, alias, name, model_class)
        if raw is None:
            continue
        collected[alias] = _normalize_value(value=raw, annotation=annotation)

    input_dict[param.alias] = collected
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
def _iter_model_fields(model_class: type[BaseModel]):
|
|
291
|
+
model_fields = getattr(model_class, "model_fields", None)
|
|
292
|
+
if model_fields:
|
|
293
|
+
for field_name, field_info in model_fields.items():
|
|
294
|
+
yield field_name, field_info.alias or field_name, field_info.annotation
|
|
295
|
+
return
|
|
296
|
+
|
|
297
|
+
alias_generator = getattr(model_class, "model_config", {}).get("alias_generator")
|
|
298
|
+
for field in getattr(model_class, "__modmex_fields__", ()):
|
|
299
|
+
field_alias = alias_generator(field.name) if alias_generator else field.name
|
|
300
|
+
yield field.name, field_alias, field.type
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
def _get_param_value(
|
|
304
|
+
input_dict: MutableMapping[str, Any],
|
|
305
|
+
field_alias: str,
|
|
306
|
+
field_name: str,
|
|
307
|
+
model_class: type[BaseModel],
|
|
308
|
+
) -> Any:
|
|
309
|
+
value = input_dict.get(field_alias)
|
|
310
|
+
if value is not None:
|
|
311
|
+
return value
|
|
312
|
+
|
|
313
|
+
model_config = getattr(model_class, "model_config", {})
|
|
314
|
+
if model_config.get("validate_by_name") or model_config.get("populate_by_name"):
|
|
315
|
+
value = input_dict.get(field_name)
|
|
316
|
+
|
|
317
|
+
return value
|
|
318
|
+
|
|
319
|
+
|
|
320
|
+
def _extract_multipart_boundary(content_type: str) -> str | None:
|
|
321
|
+
"""Extract the boundary string from a multipart/form-data content-type header."""
|
|
322
|
+
for segment in content_type.split(";"):
|
|
323
|
+
stripped = segment.strip()
|
|
324
|
+
if stripped.startswith("boundary="):
|
|
325
|
+
boundary = stripped[len("boundary=") :]
|
|
326
|
+
# Remove optional quotes around boundary
|
|
327
|
+
if boundary.startswith('"') and boundary.endswith('"'):
|
|
328
|
+
boundary = boundary[1:-1]
|
|
329
|
+
return boundary
|
|
330
|
+
return None
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
def _parse_multipart_body(body: bytes, boundary: str) -> dict[str, Any]:
    """
    Parse a multipart/form-data body into a dict of field names to values.

    Each part is split into headers and content at the first blank line (CRLF CRLF).
    The field name and optional filename come from its Content-Disposition header.

    - Parts with a ``filename`` become ``UploadFile`` objects holding the raw bytes, the
      filename and the part's Content-Type (or ``None``).
    - Other parts are decoded as UTF-8 strings.
    - Parts without a field name, or without a header/content separator, are skipped.

    Multiple values for the same field name are collected into lists.

    NOTE(review): header and text-field decoding is strict UTF-8, so non-UTF-8 bytes
    raise ``UnicodeDecodeError``.
    """
    delimiter = f"--{boundary}".encode()
    end_delimiter = f"--{boundary}--".encode()

    result: dict[str, Any] = {}

    # Split body by the boundary delimiter
    raw_parts = body.split(delimiter)

    for raw_part in raw_parts:
        # Skip the preamble (before first boundary) and epilogue (after closing boundary)
        # when they are empty/whitespace, or just the "--" left over from the closing delimiter.
        # A non-empty preamble without a CRLF CRLF separator is dropped further down.
        if not raw_part or raw_part.strip() == b"" or raw_part.strip() == b"--":
            continue

        # Remove the end delimiter marker if present
        # NOTE(review): after splitting on "--{boundary}" no part can still contain that
        # sequence, so this check appears never to match.
        chunk = raw_part
        if chunk.endswith(end_delimiter):
            chunk = chunk[: -len(end_delimiter)]

        # Strip leading \r\n
        if chunk.startswith(b"\r\n"):
            chunk = chunk[2:]

        # Strip trailing \r\n (the CRLF that precedes the next delimiter)
        if chunk.endswith(b"\r\n"):
            chunk = chunk[:-2]

        # Split headers from body at the double CRLF
        header_end = chunk.find(b"\r\n\r\n")
        if header_end == -1:
            continue

        header_section = chunk[:header_end].decode("utf-8")
        body_section = chunk[header_end + 4 :]

        # Parse Content-Disposition to get the field name and optional filename
        field_name = None
        filename = None
        content_type_header = None

        for header_line in header_section.split("\r\n"):
            # Header names are compared case-insensitively; values keep their case.
            header_lower = header_line.lower()
            if header_lower.startswith("content-disposition:"):
                field_name = _extract_header_param(header_line, "name")
                filename = _extract_header_param(header_line, "filename")
            elif header_lower.startswith("content-type:"):
                content_type_header = header_line.split(":", 1)[1].strip()

        if field_name is None:
            continue

        # If it has a filename, it's a file upload — wrap as UploadFile
        # Otherwise it's a regular form field — decode to string
        if filename is not None:
            value: Any = UploadFile(content=body_section, filename=filename, content_type=content_type_header)
        else:
            value = body_section.decode("utf-8")

        # Collect multiple values for same field name into a list
        if field_name in result:
            existing = result[field_name]
            if isinstance(existing, list):
                existing.append(value)
            else:
                result[field_name] = [existing, value]
        else:
            result[field_name] = value

    return result
|
|
408
|
+
|
|
409
|
+
|
|
410
|
+
def _extract_header_param(header_line: str, param_name: str) -> str | None:
|
|
411
|
+
"""Extract a parameter value from a header line (e.g., name="file" from Content-Disposition)."""
|
|
412
|
+
search = f'{param_name}="'
|
|
413
|
+
idx = header_line.find(search)
|
|
414
|
+
if idx == -1:
|
|
415
|
+
return None
|
|
416
|
+
start = idx + len(search)
|
|
417
|
+
end = header_line.find('"', start)
|
|
418
|
+
if end == -1:
|
|
419
|
+
return None
|
|
420
|
+
return header_line[start:end]
|
|
421
|
+
|
|
422
|
+
|
|
423
|
+
# Only the middleware class is exported by star-imports; the underscore-prefixed
# helpers in this module are internal.
__all__ = ["DependencyMiddleware"]
|