modmex-lambda 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48):
  1. modmex_lambda/__init__.py +62 -0
  2. modmex_lambda/data_classes/__init__.py +49 -0
  3. modmex_lambda/data_classes/api_gateway_authorizer_event.py +38 -0
  4. modmex_lambda/data_classes/api_gateway_proxy_event.py +328 -0
  5. modmex_lambda/data_classes/api_gateway_websocket_event.py +40 -0
  6. modmex_lambda/data_classes/cognito_user_pool_event.py +599 -0
  7. modmex_lambda/data_classes/common.py +441 -0
  8. modmex_lambda/event_handler/__init__.py +45 -0
  9. modmex_lambda/event_handler/api_gateway.py +331 -0
  10. modmex_lambda/event_handler/constants.py +3 -0
  11. modmex_lambda/event_handler/content_types.py +13 -0
  12. modmex_lambda/event_handler/cors.py +97 -0
  13. modmex_lambda/event_handler/dependencies/__init__.py +0 -0
  14. modmex_lambda/event_handler/dependencies/compat.py +231 -0
  15. modmex_lambda/event_handler/dependencies/dependant.py +279 -0
  16. modmex_lambda/event_handler/dependencies/dependency_middleware.py +423 -0
  17. modmex_lambda/event_handler/dependencies/depends.py +184 -0
  18. modmex_lambda/event_handler/dependencies/params.py +317 -0
  19. modmex_lambda/event_handler/dependencies/types.py +14 -0
  20. modmex_lambda/event_handler/exception_handler.py +70 -0
  21. modmex_lambda/event_handler/exceptions.py +72 -0
  22. modmex_lambda/event_handler/gateway_response.py +96 -0
  23. modmex_lambda/event_handler/middlewares.py +33 -0
  24. modmex_lambda/event_handler/params.py +44 -0
  25. modmex_lambda/event_handler/request.py +70 -0
  26. modmex_lambda/event_handler/response.py +60 -0
  27. modmex_lambda/event_handler/routing.py +507 -0
  28. modmex_lambda/event_handler/routing_fallbacks.py +92 -0
  29. modmex_lambda/event_handler/types.py +31 -0
  30. modmex_lambda/event_sources.py +53 -0
  31. modmex_lambda/exceptions.py +3 -0
  32. modmex_lambda/logging.py +99 -0
  33. modmex_lambda/params.py +3 -0
  34. modmex_lambda/parser.py +47 -0
  35. modmex_lambda/request.py +3 -0
  36. modmex_lambda/resolver.py +3 -0
  37. modmex_lambda/response.py +3 -0
  38. modmex_lambda/routing.py +3 -0
  39. modmex_lambda/shared/__init__.py +0 -0
  40. modmex_lambda/shared/cookies.py +84 -0
  41. modmex_lambda/shared/headers_serializer.py +65 -0
  42. modmex_lambda/shared/json_encoder.py +53 -0
  43. modmex_lambda/shared/types.py +4 -0
  44. modmex_lambda/validation.py +178 -0
  45. modmex_lambda-0.1.0.dist-info/METADATA +375 -0
  46. modmex_lambda-0.1.0.dist-info/RECORD +48 -0
  47. modmex_lambda-0.1.0.dist-info/WHEEL +4 -0
  48. modmex_lambda-0.1.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,279 @@
1
+ from __future__ import annotations
2
+
3
+ import inspect
4
+ import re
5
+ from typing import Any, Callable, ForwardRef, Union, cast, get_args, get_origin
6
+
7
+ from modmex import BaseModel
8
+
9
+ from modmex_lambda.event_handler.dependencies.compat import ModelField, Required, create_body_model, is_scalar_field
10
+ from modmex_lambda.event_handler.dependencies.depends import DependencyParam, _get_depends_from_annotation
11
+ from modmex_lambda.event_handler.dependencies.params import (
12
+ Body,
13
+ Dependant,
14
+ File,
15
+ Form,
16
+ Param,
17
+ ParamTypes,
18
+ analyze_param,
19
+ create_response_field,
20
+ get_flat_dependant,
21
+ )
22
+ from modmex_lambda.event_handler.dependencies.types import UnionType
23
+ from modmex_lambda.event_handler.request import Request
24
+
25
+
26
def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
    """Append ``field`` to the dependant's list that matches its ``in_`` location.

    Path, query, header and cookie params each have their own list on
    ``dependant``; any other location is rejected.

    Raises:
        AssertionError: if ``field.field_info.in_`` is not path, query,
            header or cookie.
    """
    info = cast(Param, field.field_info)
    bucket = {
        ParamTypes.path: dependant.path_params,
        ParamTypes.query: dependant.query_params,
        ParamTypes.header: dependant.header_params,
        ParamTypes.cookie: dependant.cookie_params,
    }.get(info.in_)

    if bucket is not None:
        bucket.append(field)
        return

    raise AssertionError(f"Unsupported param type: {info.in_}")
39
+
40
+
41
def resolve_forward_ref_lenient(
    type_hint: Any,
    globalns: dict[str, Any] | None = None,
    localns: dict[str, Any] | None = None,
) -> Any:
    """Best-effort evaluation of a string or ``ForwardRef`` type hint.

    Args:
        type_hint: A string annotation, a ``ForwardRef``, or any other object.
        globalns: Globals used for evaluation (defaults to an empty dict).
        localns: Locals used for evaluation (defaults to ``globalns``).

    Returns:
        The evaluated type when evaluation succeeds. If evaluation raises, a
        string hint comes back as a new ``ForwardRef`` and a ``ForwardRef``
        comes back unchanged. Any other object is returned as-is.

    Note:
        Uses ``eval``: intended only for annotations written in source code,
        never for externally supplied strings.
    """
    globals_map = globalns or {}
    locals_map = localns or globals_map

    if isinstance(type_hint, str):
        expression = type_hint
    elif isinstance(type_hint, ForwardRef):
        expression = type_hint.__forward_arg__
    else:
        return type_hint

    try:
        return eval(expression, globals_map, locals_map)
    except Exception:
        # Unresolvable for now: keep it as a forward reference instead of failing.
        return type_hint if isinstance(type_hint, ForwardRef) else ForwardRef(type_hint)
62
+
63
+
64
def get_typed_annotation(annotation: Any, globalns: dict[str, Any], localns: dict[str, Any] | None = None) -> Any:
    """Resolve a string annotation against the given namespaces.

    Non-string annotations are returned unchanged. A string is wrapped in a
    ``ForwardRef`` and resolved leniently, with ``localns`` falling back to
    ``globalns`` when it is empty or omitted.
    """
    if not isinstance(annotation, str):
        return annotation

    namespace = localns or globalns
    return resolve_forward_ref_lenient(ForwardRef(annotation), globalns, namespace)
68
+
69
+
70
def get_closure_namespace(call: Callable[..., Any]) -> dict[str, Any]:
    """Build the local namespace used to resolve ``call``'s annotations.

    Starts from the ``__modmex_lambda_localns__`` mapping attached to ``call``
    (if any) and overlays the current values of the variables ``call`` closes
    over, so closure values win on name clashes.

    Closure cells that are empty (the enclosing variable was deleted, or is
    not yet bound, e.g. a decorated function referencing its own name while
    the decorator runs) are skipped instead of raising ``ValueError``.

    Args:
        call: Any callable; objects without ``__closure__`` contribute only
            their ``__modmex_lambda_localns__`` mapping.

    Returns:
        A new dict; the callable's attribute is never mutated.
    """
    namespace = dict(getattr(call, "__modmex_lambda_localns__", {}) or {})
    closure = getattr(call, "__closure__", None)
    if not closure:
        return namespace

    for name, cell in zip(call.__code__.co_freevars, closure):
        try:
            namespace[name] = cell.cell_contents
        except ValueError:
            # Empty cell: there is no value to expose for this free variable.
            continue
    return namespace
78
+
79
+
80
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
    """Return ``call``'s signature with string annotations resolved.

    Annotations are resolved against the callable's ``__globals__`` (if any),
    overlaid with its closure namespace from ``get_closure_namespace``.
    Unresolvable names stay as ``ForwardRef`` objects.
    """
    signature = inspect.signature(call)
    globalns = getattr(call, "__globals__", {})
    localns = {**globalns, **get_closure_namespace(call)}

    def _resolve(annotation: Any) -> Any:
        return get_typed_annotation(annotation, globalns, localns)

    resolved_params = [
        parameter.replace(annotation=_resolve(parameter.annotation))
        for parameter in signature.parameters.values()
    ]

    returned = signature.return_annotation
    if returned is inspect.Signature.empty:
        return inspect.Signature(resolved_params)
    return inspect.Signature(resolved_params, return_annotation=_resolve(returned))
100
+
101
+
102
def get_path_param_names(path: str) -> set[str]:
    """Return the names of all ``<name>`` placeholders in a route path."""
    return {match.group(1) for match in re.finditer(r"<(\w+)>", path)}
104
+
105
+
106
def is_request_annotation(annotation: Any) -> bool:
    """Tell whether a parameter annotation asks for the ``Request`` object.

    True for ``Request`` itself and for unions (``typing.Union`` or ``X | Y``)
    that include ``Request``, e.g. ``Request | None``.
    """
    if annotation is Request:
        return True

    origin = get_origin(annotation)
    return (origin is Union or origin is UnionType) and Request in get_args(annotation)
115
+
116
+
117
def get_dependant(
    *,
    path: str,
    call: Callable[..., Any],
    name: str | None = None,
    responses: dict[int, Any] | None = None,
) -> Dependant:
    """Build the ``Dependant`` tree describing what ``call`` needs from a request.

    Each parameter of ``call`` is classified in signature order:
      * ``Request``-annotated parameters are skipped;
      * parameters whose annotation yields a ``Depends`` marker become a nested
        ``DependencyParam`` (built recursively for the dependency callable);
      * everything else is registered as a route parameter, flagged as a path
        param when its name appears as ``<name>`` in ``path``.
    The return annotation, if present, is recorded last.

    Args:
        path: The route path template.
        call: The handler or dependency callable to inspect.
        name: Optional name stored on the resulting ``Dependant``.
        responses: Accepted for interface compatibility; not used here.
    """
    dependant = Dependant(call=call, name=name, path=path)
    path_param_names = get_path_param_names(path)
    endpoint_signature = get_typed_signature(call)

    for param_name, param in endpoint_signature.parameters.items():
        if is_request_annotation(param.annotation):
            continue

        depends = _get_depends_from_annotation(param.annotation)
        if depends is None:
            _add_route_param(
                dependant=dependant,
                param_name=param_name,
                param=param,
                is_path_param=param_name in path_param_names,
            )
        else:
            # Let the dependency resolve names from the parent's local namespace.
            _inherit_local_namespace(parent=call, dependency=depends.dependency)
            sub_dependant = get_dependant(path=path, call=depends.dependency)
            dependant.dependencies.append(
                DependencyParam(param_name=param_name, depends=depends, dependant=sub_dependant),
            )

    _add_return_annotation(dependant, endpoint_signature)
    return dependant
153
+
154
+
155
def _inherit_local_namespace(*, parent: Callable[..., Any], dependency: Callable[..., Any]) -> None:
    """Give ``dependency`` the parent's ``__modmex_lambda_localns__`` if it has none.

    A dependency that already carries the attribute is left untouched. For
    bound methods the attribute is stored on the underlying ``__func__``,
    because method objects reject attribute assignment but forward attribute
    reads to their function. Callables that cannot hold attributes at all
    (builtins, slotted instances, ...) are left without an inherited namespace
    instead of raising.
    """
    if hasattr(dependency, "__modmex_lambda_localns__"):
        return

    inherited = getattr(parent, "__modmex_lambda_localns__", {})
    target = getattr(dependency, "__func__", dependency)
    try:
        setattr(target, "__modmex_lambda_localns__", inherited)  # noqa: B010
    except (AttributeError, TypeError):
        # Best effort: the dependency simply doesn't inherit the parent's namespace.
        return
159
+
160
+
161
def _add_route_param(
    *,
    dependant: Dependant,
    param_name: str,
    param: inspect.Parameter,
    is_path_param: bool,
) -> None:
    """Analyze one handler parameter and register it on ``dependant``.

    Parameters that ``is_body_param`` classifies as body go to
    ``dependant.body_params``; the rest are filed by location through
    ``add_param_to_fields``.

    Raises:
        AssertionError: if ``analyze_param`` returns ``None`` for the parameter.
    """
    field = analyze_param(
        param_name=param_name,
        annotation=param.annotation,
        value=param.default,
        is_path_param=is_path_param,
        is_response_param=False,
    )
    if field is None:
        raise AssertionError(f"Parameter field is None for param: {param_name}")

    goes_in_body = is_body_param(param_field=field, is_path_param=is_path_param)
    if not goes_in_body:
        add_param_to_fields(field=field, dependant=dependant)
        return
    dependant.body_params.append(field)
182
+
183
+
184
def _add_return_annotation(dependant: Dependant, endpoint_signature: inspect.Signature) -> None:
    """Store the handler's return annotation as ``dependant.return_param``.

    Does nothing when the signature has no return annotation.

    Raises:
        AssertionError: if ``analyze_param`` returns ``None`` for the annotation.
    """
    annotation = endpoint_signature.return_annotation
    if annotation is inspect.Signature.empty:
        return

    field = analyze_param(
        param_name="return",
        annotation=annotation,
        value=None,
        is_path_param=False,
        is_response_param=True,
    )
    if field is None:
        raise AssertionError("Param field is None for return annotation")
    dependant.return_param = field
199
+
200
+
201
def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool:
    """Decide whether a parameter is read from the request body.

    Path params and scalar fields are never body params, nor are fields whose
    metadata is a ``Param``. Any remaining field must be declared with
    ``Body`` metadata.

    Raises:
        AssertionError: if a path param is not scalar, or a non-scalar,
            non-``Param`` field lacks ``Body`` metadata.
    """
    scalar = is_scalar_field(field=param_field)

    if is_path_param:
        if scalar:
            return False
        raise AssertionError("Path params must be of one of the supported types")

    if scalar:
        return False

    info = param_field.field_info
    if isinstance(info, Param):
        return False
    if isinstance(info, Body):
        return True
    raise AssertionError(f"Param: {param_field.name} can only be a request body, use Body()")
214
+
215
+
216
def get_flat_params(dependant: Dependant) -> list[ModelField]:
    """Collect every non-body param from ``dependant`` and its sub-dependencies.

    The order is path, query, header, then cookie params, as produced by
    ``get_flat_dependant``.
    """
    flat = get_flat_dependant(dependant)
    return [
        *flat.path_params,
        *flat.query_params,
        *flat.header_params,
        *flat.cookie_params,
    ]
224
+
225
+
226
def get_body_field(*, dependant: Dependant, name: str) -> ModelField | None:
    """Build the single field that represents the whole request body.

    Returns ``None`` when no body params exist. When all body params share
    one name and that param is not marked ``embed``, the param itself is the
    body field. Otherwise every body param is marked ``embed`` (mutating its
    ``field_info``) and a ``Body_<name>`` model wrapping them all is created;
    the body is required if any of its params is.
    """
    flat_dependant = get_flat_dependant(dependant)
    body_params = flat_dependant.body_params
    if not body_params:
        return None

    lead = body_params[0]
    distinct_names = {param.name for param in body_params}
    if len(distinct_names) == 1 and not getattr(lead.field_info, "embed", None):
        return lead

    for param in body_params:
        param.field_info.embed = True

    body_model = create_body_model(fields=body_params, model_name=f"Body_{name}")
    required = any(param.required for param in body_params)
    info_cls, info_kwargs = get_body_field_info(
        body_model=body_model,
        flat_dependant=flat_dependant,
        required=required,
    )

    return create_response_field(
        name="body",
        type_=body_model,
        default=Required if required else None,
        alias="body",
        field_info=info_cls(**info_kwargs),
    )
253
+
254
+
255
def get_body_field_info(
    *,
    body_model: type[BaseModel],
    flat_dependant: Dependant,
    required: bool,
) -> tuple[type[Body], dict[str, Any]]:
    """Return ``(Body, kwargs)`` for constructing the combined body field.

    The kwargs always carry ``annotation`` and ``alias="body"``; ``default``
    is ``None`` when the body is optional. ``media_type`` is chosen as:
    multipart if any param is ``File``, form-urlencoded if any is ``Form``,
    otherwise the shared ``media_type`` of the ``Body`` params when they all
    agree (and it is left unset when they disagree or none declare one).
    """
    kwargs: dict[str, Any] = {"annotation": body_model, "alias": "body"}
    if not required:
        kwargs["default"] = None

    body_params = flat_dependant.body_params
    if any(isinstance(param.field_info, File) for param in body_params):
        kwargs["media_type"] = "multipart/form-data"
    elif any(isinstance(param.field_info, Form) for param in body_params):
        kwargs["media_type"] = "application/x-www-form-urlencoded"
    else:
        declared = {
            param.field_info.media_type
            for param in body_params
            if isinstance(param.field_info, Body) and hasattr(param.field_info, "media_type")
        }
        if len(declared) == 1:
            kwargs["media_type"] = next(iter(declared))

    return Body, kwargs
@@ -0,0 +1,423 @@
1
+ from typing import Any, Mapping, MutableMapping, Sequence, Union, get_args, get_origin
2
+
3
+ from modmex import BaseModel, FieldInfo
4
+ from modmex_lambda.exceptions import RequestValidationError
5
+ from modmex_lambda.event_handler.middlewares import IMiddleware, NextMiddleware
6
+ from modmex_lambda.event_handler.dependencies.compat import (
7
+ _normalize_errors,
8
+ _regenerate_error_with_loc,
9
+ field_annotation_is_sequence,
10
+ get_missing_field_error,
11
+ lenient_issubclass,
12
+ is_scalar_field
13
+ )
14
+ from modmex_lambda.event_handler.types import IApiGatewayResolver
15
+ from modmex_lambda.event_handler.routing import IRoute
16
+
17
+ from modmex_lambda.event_handler.dependencies.params import ModelField, Param, UploadFile
18
+ from modmex_lambda.event_handler.dependencies.types import UnionType
19
+ from modmex_lambda.event_handler.response import Response
20
+
21
+
22
class DependencyMiddleware(IMiddleware):
    """Validate the matched route's inputs before its handler runs.

    Path, query, header and cookie params declared on ``route.dependant`` are
    validated against the current event, followed by the body when body
    params are declared. On success the validated values replace
    ``app.context["_route_args"]`` and the next middleware is called;
    otherwise all collected errors are raised together.
    """

    def _get_body(self, app: IApiGatewayResolver) -> Any:
        """Return the request body as parsed by ``current_event.json_body``.

        NOTE(review): only JSON bodies are read; the multipart helpers in this
        module are not called from here.
        """
        event = app.current_event
        return event.json_body

    def handler(self, app: IApiGatewayResolver, next_middleware: NextMiddleware) -> Response:
        """Validate request inputs, store them as route args, and continue the chain.

        Raises:
            RequestValidationError: if any parameter or body field fails validation.
        """
        route: IRoute = app.context.get("_route")
        dependant = route.dependant
        event = app.current_event

        # Pair each group of declared params with the raw values it is read from.
        # Query strings and headers may hold list values, so they are normalized first.
        sources = [
            (dependant.path_params, app.context.get("_route_args")),
            (
                dependant.query_params,
                _normalize_multi_params(event.resolved_query_string_parameters, dependant.query_params),
            ),
            (
                dependant.header_params,
                _normalize_multi_params(event.resolved_headers_field, dependant.header_params),
            ),
            (dependant.cookie_params, getattr(event, "resolved_cookies_field", {})),
        ]

        validated: dict[str, Any] = {}
        problems: list[Any] = []
        for declared, received in sources:
            field_values, field_errors = _request_params_to_args(declared, received)
            validated.update(field_values)
            problems.extend(field_errors)

        body_params = dependant.body_params
        if body_params:
            field_values, field_errors = _request_body_to_args(
                required_params=body_params,
                received_body=self._get_body(app),
            )
            validated.update(field_values)
            problems.extend(field_errors)

        if problems:
            raise RequestValidationError(_normalize_errors(problems))

        app.context["_route_args"] = validated
        return next_middleware(app)
62
+
63
+
64
+
65
def _request_params_to_args(
    required_params: Sequence[ModelField],
    received_params: Mapping[str, Any],
) -> tuple[dict[str, Any], list[dict[str, Any]]]:
    """
    Validate non-body request params.

    Each field is looked up in ``received_params`` by its alias. Present values
    are validated, and absent ones (``None``) fall back to the field's default
    or produce a missing-field error.

    Returns:
        ``(values, errors)``: validated values keyed by field name, and all errors.

    Raises:
        AssertionError: if a field's metadata is not a ``Param`` (a route
            misconfiguration, reported eagerly).
    """
    values: dict[str, Any] = {}
    errors: list[dict[str, Any]] = []

    for field in required_params:
        info = field.field_info
        if not isinstance(info, Param):
            raise AssertionError(f"Expected Param field_info, got {info}")

        alias = field.alias
        loc = (info.in_.value, alias)
        raw = received_params.get(alias)

        if raw is not None:
            values[field.name] = _validate_field(field=field, value=raw, loc=loc, existing_errors=errors)
        else:
            _handle_missing_field_value(field, values, errors, loc)

    return values, errors
94
+
95
+
96
def _request_body_to_args(
    required_params: list[ModelField],
    received_body: dict[str, Any] | None,
) -> tuple[dict[str, Any], list[dict[str, Any]]]:
    """
    Validate the request body against the route's body params.

    Args:
        required_params: The body fields declared by the route.
        received_body: The decoded request body (``None`` when absent).

    Returns:
        ``(values, errors)``: validated values keyed by field name, and every
        validation or missing-field error collected along the way. Each
        missing location is reported once.
    """
    values: dict[str, Any] = {}
    errors: list[dict[str, Any]] = []

    # Nothing to validate; this also avoids indexing required_params[0] below.
    if not required_params:
        return values, errors

    received_body, field_alias_omitted = _get_embed_body(
        field=required_params[0],
        required_params=required_params,
        received_body=received_body,
    )

    for field in required_params:
        loc = _get_body_field_location(field, field_alias_omitted)

        # The extractor records a missing-field error itself when the body is not
        # a mapping (e.g. a JSON array); skip the fallback below in that case so a
        # required field is not reported as missing twice.
        error_count = len(errors)
        value = _extract_field_value_from_body(field, received_body, loc, errors)
        if len(errors) > error_count:
            continue

        # No value: use the default, or report it missing if the field is required.
        if value is None:
            _handle_missing_field_value(field, values, errors, loc)
            continue

        value = _normalize_field_value(value=value, field_info=field.field_info)

        # UploadFile objects bypass validation — they're already constructed
        if isinstance(value, UploadFile):
            values[field.name] = value
        else:
            values[field.name] = _validate_field(field=field, value=value, loc=loc, existing_errors=errors)

    return values, errors
130
+
131
+
132
def _get_body_field_location(field: ModelField, field_alias_omitted: bool) -> tuple[str, ...]:
    """Return the error location for a body field.

    ``("body",)`` when the field stands for the whole body, otherwise
    ``("body", <alias>)``.
    """
    return ("body",) if field_alias_omitted else ("body", field.alias)
137
+
138
+
139
def _extract_field_value_from_body(
    field: ModelField,
    received_body: dict[str, Any] | None,
    loc: tuple[str, ...],
    errors: list[dict[str, Any]],
) -> Any | None:
    """Look up ``field.alias`` in the received body.

    Returns ``None`` when there is no body or the key is absent. If the body
    has no ``.get`` (e.g. it is a list), a missing-field error for ``loc`` is
    appended to ``errors`` and ``None`` is returned.
    """
    if received_body is not None:
        try:
            return received_body.get(field.alias)
        except AttributeError:
            errors.append(get_missing_field_error(loc))
    return None
154
+
155
+
156
def _handle_missing_field_value(
    field: ModelField,
    values: dict[str, Any],
    errors: list[dict[str, Any]],
    loc: tuple[str, ...],
) -> None:
    """Fill in a default for an absent optional field, or record a missing error.

    Mutates ``values`` (optional fields) or ``errors`` (required fields).
    """
    if not field.required:
        values[field.name] = field.get_default()
        return
    errors.append(get_missing_field_error(loc))
167
+
168
+
169
def _is_or_contains_sequence(annotation: Any) -> bool:
    """
    Check if annotation is a sequence, or a Union/RootModel containing one.

    Handles annotations such as:
    - List[Model] - direct sequence
    - Union[Model, List[Model]] - any Union member is (or contains) a sequence
    - Optional[List[Model]] - i.e. Union[List[Model], None]
    - RootModel[List[Model]] - the root model wraps a sequence
    - Optional[RootModel[List[Model]]] - a Union member that is such a RootModel
    - RootModel[Union[Model, List[Model]]] - a RootModel wrapping such a Union
    """
    if field_annotation_is_sequence(annotation):
        return True

    # Recurse into Union members so RootModels nested inside a Union are found.
    origin = get_origin(annotation)
    is_union = origin is Union or origin is UnionType
    if is_union and any(_is_or_contains_sequence(member) for member in get_args(annotation)):
        return True

    is_root_model = lenient_issubclass(annotation, BaseModel) and getattr(annotation, "__pydantic_root_model__", False)
    if not is_root_model:
        return False
    if not (hasattr(annotation, "model_fields") and "root" in annotation.model_fields):
        return False
    return _is_or_contains_sequence(annotation.model_fields["root"].annotation)
199
+
200
+
201
def _normalize_field_value(value: Any, field_info: FieldInfo) -> Any:
    """Normalize ``value`` using the annotation declared on ``field_info``."""
    annotation = field_info.annotation
    return _normalize_value(value=value, annotation=annotation)
204
+
205
+
206
def _normalize_value(value: Any, annotation: Any) -> Any:
    """Adapt a raw input value to the shape its annotation expects.

    - An ``UploadFile`` targeting a ``bytes`` annotation becomes its content.
    - For non-sequence annotations, a non-empty list collapses to its first item.
    - Everything else is returned unchanged.
    """
    if isinstance(value, UploadFile) and annotation is bytes:
        return value.content

    if not _is_or_contains_sequence(annotation) and isinstance(value, list) and value:
        return value[0]
    return value
217
+
218
+
219
def _validate_field(
    *,
    field: ModelField,
    value: Any,
    loc: tuple[str, ...],
    existing_errors: list[dict[str, Any]],
) -> Any:
    """
    Validate ``value`` with ``field`` and return the validated value.

    Errors are appended to ``existing_errors``. A list of errors is first
    passed through ``_regenerate_error_with_loc``; any other truthy error
    object is appended as-is.
    """
    validated_value, field_errors = field.validate(value=value, loc=loc)

    if isinstance(field_errors, list):
        existing_errors.extend(_regenerate_error_with_loc(errors=field_errors, loc_prefix=()))
    elif field_errors:
        existing_errors.append(field_errors)

    return validated_value
238
+
239
+
240
def _get_embed_body(
    *,
    field: ModelField,
    required_params: list[ModelField],
    received_body: dict[str, Any] | None,
) -> tuple[dict[str, Any] | None, bool]:
    """Prepare the body for per-field lookup.

    When the route has exactly one body param and it is NOT marked ``embed``,
    the whole body belongs to that param, so it is wrapped as
    ``{field.alias: received_body}`` and the second return value is ``True``
    (the alias was omitted by the client). Otherwise the body is returned
    unchanged with ``False``.
    """
    embed = getattr(field.field_info, "embed", None)
    alias_omitted = len(required_params) == 1 and not embed
    if not alias_omitted:
        return received_body, False
    return {field.alias: received_body}, True
255
+
256
+
257
def _normalize_multi_params(
    input_dict: MutableMapping[str, Any],
    params: Sequence[ModelField],
) -> MutableMapping[str, Any]:
    """Reshape multi-valued inputs in place to match the declared params.

    Scalar params have single-item lists unwrapped; params annotated with a
    ``BaseModel`` get their value assembled from the model's fields. Returns
    the same (mutated) mapping.
    """
    for param in params:
        if is_scalar_field(param):
            _process_scalar_param(input_dict, param)
            continue
        if lenient_issubclass(param.field_info.annotation, BaseModel):
            _process_model_param(input_dict, param)
    return input_dict
267
+
268
+
269
def _process_scalar_param(input_dict: MutableMapping[str, Any], param: ModelField) -> None:
    """Unwrap a single-item list stored under ``param.alias``, in place.

    Absent keys and lists with zero or several items are left untouched.
    """
    try:
        candidate = input_dict[param.alias]
    except KeyError:
        return

    if isinstance(candidate, list) and len(candidate) == 1:
        input_dict[param.alias] = candidate[0]
276
+
277
+
278
def _process_model_param(input_dict: MutableMapping[str, Any], param: ModelField) -> None:
    """Gather a model-typed param's fields from ``input_dict`` into one dict.

    For each field of the param's model class, the value is looked up by
    alias (or by name when the model allows it), normalized to the field's
    annotation, and stored under the alias. The resulting dict replaces
    ``input_dict[param.alias]``.
    """
    model_class = param.field_info.annotation

    collected: dict[str, Any] = {}
    for name, alias, field_annotation in _iter_model_fields(model_class):
        raw = _get_param_value(input_dict, alias, name, model_class)
        if raw is None:
            continue
        collected[alias] = _normalize_value(value=raw, annotation=field_annotation)

    input_dict[param.alias] = collected
288
+
289
+
290
def _iter_model_fields(model_class: type[BaseModel]):
    """Yield ``(name, alias, annotation)`` for each field of ``model_class``.

    Uses ``model_fields`` when it is present and non-empty (the alias defaults
    to the name). Otherwise it falls back to ``__modmex_fields__``, deriving
    aliases via ``model_config["alias_generator"]`` when one is configured.
    """
    declared = getattr(model_class, "model_fields", None)
    if declared:
        yield from ((name, info.alias or name, info.annotation) for name, info in declared.items())
        return

    generate_alias = getattr(model_class, "model_config", {}).get("alias_generator")
    for spec in getattr(model_class, "__modmex_fields__", ()):
        alias = generate_alias(spec.name) if generate_alias else spec.name
        yield spec.name, alias, spec.type
301
+
302
+
303
def _get_param_value(
    input_dict: MutableMapping[str, Any],
    field_alias: str,
    field_name: str,
    model_class: type[BaseModel],
) -> Any:
    """Look up a model field's raw value, by alias first.

    Falls back to the field's own name only when the model's ``model_config``
    enables ``validate_by_name`` or ``populate_by_name``. Returns ``None`` if
    nothing is found.
    """
    found = input_dict.get(field_alias)
    if found is not None:
        return found

    config = getattr(model_class, "model_config", {})
    by_name = config.get("validate_by_name") or config.get("populate_by_name")
    return input_dict.get(field_name) if by_name else None
318
+
319
+
320
def _extract_multipart_boundary(content_type: str) -> str | None:
    """Return the ``boundary`` parameter of a multipart Content-Type value.

    The parameter name is matched case-insensitively (MIME parameter names
    are not case sensitive, RFC 2045 §5.1). Whitespace around the value is
    ignored, and a value wrapped in double quotes is unquoted.

    Args:
        content_type: A Content-Type header value such as
            ``multipart/form-data; boundary=----abc``.

    Returns:
        The boundary string, or ``None`` if no boundary parameter is present.
    """
    for segment in content_type.split(";"):
        key, sep, raw_value = segment.partition("=")
        if not sep or key.strip().lower() != "boundary":
            continue
        boundary = raw_value.strip()
        # Only strip a real pair of surrounding quotes (a lone '"' is left alone).
        if len(boundary) >= 2 and boundary[0] == boundary[-1] == '"':
            boundary = boundary[1:-1]
        return boundary
    return None
331
+
332
+
333
def _parse_multipart_body(body: bytes, boundary: str) -> dict[str, Any]:
    """
    Parse a multipart/form-data body into a dict of field names to values.

    Parts whose Content-Disposition header has a ``filename`` parameter become
    ``UploadFile`` objects (raw content bytes, the filename, and the part's
    Content-Type if given); all other parts are decoded as UTF-8 strings.
    Multiple values for the same field name are collected into lists.
    Parts without a blank-line header/body separator, or without a ``name``
    parameter, are skipped.

    Args:
        body: The raw request body.
        boundary: The boundary string (without the leading ``--``).

    Returns:
        A mapping of field name to a single value, or to a list of values when
        the name occurs more than once.

    Raises:
        UnicodeDecodeError: if a part's header block, or a non-file part's
            body, is not valid UTF-8.

    NOTE(review): this function is not called anywhere in this module.
    """
    delimiter = f"--{boundary}".encode()
    end_delimiter = f"--{boundary}--".encode()

    result: dict[str, Any] = {}

    # Split body by the boundary delimiter
    raw_parts = body.split(delimiter)

    for raw_part in raw_parts:
        # Skip the preamble (before first boundary) and epilogue (after closing boundary).
        # The closing "--boundary--" leaves a part that starts with b"--"; when only
        # whitespace follows it, the last comparison skips it.
        if not raw_part or raw_part.strip() == b"" or raw_part.strip() == b"--":
            continue

        # Remove the end delimiter marker if present.
        # NOTE(review): parts come from splitting on ``delimiter``, which is a prefix of
        # ``end_delimiter``, so no part can end with ``end_delimiter``; this looks unreachable.
        chunk = raw_part
        if chunk.endswith(end_delimiter):
            chunk = chunk[: -len(end_delimiter)]

        # Strip leading \r\n
        if chunk.startswith(b"\r\n"):
            chunk = chunk[2:]

        # Strip trailing \r\n
        if chunk.endswith(b"\r\n"):
            chunk = chunk[:-2]

        # Split headers from body at the double CRLF; parts without one are ignored.
        header_end = chunk.find(b"\r\n\r\n")
        if header_end == -1:
            continue

        header_section = chunk[:header_end].decode("utf-8")
        body_section = chunk[header_end + 4 :]

        # Parse Content-Disposition to get the field name and optional filename
        field_name = None
        filename = None
        content_type_header = None

        for header_line in header_section.split("\r\n"):
            # Header names are compared case-insensitively.
            header_lower = header_line.lower()
            if header_lower.startswith("content-disposition:"):
                field_name = _extract_header_param(header_line, "name")
                filename = _extract_header_param(header_line, "filename")
            elif header_lower.startswith("content-type:"):
                content_type_header = header_line.split(":", 1)[1].strip()

        if field_name is None:
            continue

        # If it has a filename, it's a file upload — wrap as UploadFile
        # Otherwise it's a regular form field — decode to string
        if filename is not None:
            value: Any = UploadFile(content=body_section, filename=filename, content_type=content_type_header)
        else:
            value = body_section.decode("utf-8")

        # Collect multiple values for same field name into a list
        if field_name in result:
            existing = result[field_name]
            if isinstance(existing, list):
                existing.append(value)
            else:
                result[field_name] = [existing, value]
        else:
            result[field_name] = value

    return result
408
+
409
+
410
def _extract_header_param(header_line: str, param_name: str) -> str | None:
    """Extract a parameter value from a header line (e.g. ``name="file"`` from Content-Disposition).

    The parameter must start the line or follow ``;`` or whitespace, so that
    looking up ``name`` never matches inside ``filename``, whatever the
    parameter order. Names are matched case-insensitively, and the value may
    be a double-quoted string or a bare token.

    Args:
        header_line: A full header line, e.g.
            ``Content-Disposition: form-data; name="f"; filename="a.txt"``.
        param_name: The parameter to look up, e.g. ``"name"``.

    Returns:
        The (unquoted) parameter value, or ``None`` if the parameter is absent
        or its quoted value is unterminated.
    """
    import re

    pattern = (
        r"(?:^|[;\s])"
        + re.escape(param_name)
        + r'\s*=\s*(?:"(?P<quoted>[^"]*)"|(?P<token>[^";\s]+))'
    )
    match = re.search(pattern, header_line, flags=re.IGNORECASE)
    if match is None:
        return None
    quoted = match.group("quoted")
    return quoted if quoted is not None else match.group("token")
421
+
422
+
423
# Public API of this module: only the middleware class; the ``_``-prefixed helpers are internal.
__all__ = ["DependencyMiddleware"]