modmex-lambda 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- modmex_lambda/__init__.py +62 -0
- modmex_lambda/data_classes/__init__.py +49 -0
- modmex_lambda/data_classes/api_gateway_authorizer_event.py +38 -0
- modmex_lambda/data_classes/api_gateway_proxy_event.py +328 -0
- modmex_lambda/data_classes/api_gateway_websocket_event.py +40 -0
- modmex_lambda/data_classes/cognito_user_pool_event.py +599 -0
- modmex_lambda/data_classes/common.py +441 -0
- modmex_lambda/event_handler/__init__.py +45 -0
- modmex_lambda/event_handler/api_gateway.py +331 -0
- modmex_lambda/event_handler/constants.py +3 -0
- modmex_lambda/event_handler/content_types.py +13 -0
- modmex_lambda/event_handler/cors.py +97 -0
- modmex_lambda/event_handler/dependencies/__init__.py +0 -0
- modmex_lambda/event_handler/dependencies/compat.py +231 -0
- modmex_lambda/event_handler/dependencies/dependant.py +279 -0
- modmex_lambda/event_handler/dependencies/dependency_middleware.py +423 -0
- modmex_lambda/event_handler/dependencies/depends.py +184 -0
- modmex_lambda/event_handler/dependencies/params.py +317 -0
- modmex_lambda/event_handler/dependencies/types.py +14 -0
- modmex_lambda/event_handler/exception_handler.py +70 -0
- modmex_lambda/event_handler/exceptions.py +72 -0
- modmex_lambda/event_handler/gateway_response.py +96 -0
- modmex_lambda/event_handler/middlewares.py +33 -0
- modmex_lambda/event_handler/params.py +44 -0
- modmex_lambda/event_handler/request.py +70 -0
- modmex_lambda/event_handler/response.py +60 -0
- modmex_lambda/event_handler/routing.py +507 -0
- modmex_lambda/event_handler/routing_fallbacks.py +92 -0
- modmex_lambda/event_handler/types.py +31 -0
- modmex_lambda/event_sources.py +53 -0
- modmex_lambda/exceptions.py +3 -0
- modmex_lambda/logging.py +99 -0
- modmex_lambda/params.py +3 -0
- modmex_lambda/parser.py +47 -0
- modmex_lambda/request.py +3 -0
- modmex_lambda/resolver.py +3 -0
- modmex_lambda/response.py +3 -0
- modmex_lambda/routing.py +3 -0
- modmex_lambda/shared/__init__.py +0 -0
- modmex_lambda/shared/cookies.py +84 -0
- modmex_lambda/shared/headers_serializer.py +65 -0
- modmex_lambda/shared/json_encoder.py +53 -0
- modmex_lambda/shared/types.py +4 -0
- modmex_lambda/validation.py +178 -0
- modmex_lambda-0.1.0.dist-info/METADATA +375 -0
- modmex_lambda-0.1.0.dist-info/RECORD +48 -0
- modmex_lambda-0.1.0.dist-info/WHEEL +4 -0
- modmex_lambda-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,331 @@
|
|
|
1
|
+
"""API Gateway event handler public API."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import re
|
|
6
|
+
import json
|
|
7
|
+
from typing import Any, Callable, Pattern
|
|
8
|
+
from functools import partial
|
|
9
|
+
from enum import Enum
|
|
10
|
+
from http import HTTPStatus
|
|
11
|
+
|
|
12
|
+
from modmex_lambda.event_handler import content_types
|
|
13
|
+
from modmex_lambda.event_handler.exceptions import (
|
|
14
|
+
ForbiddenError,
|
|
15
|
+
MethodNotAllowedError,
|
|
16
|
+
NotFoundError,
|
|
17
|
+
RequestValidationError,
|
|
18
|
+
UnauthorizedError,
|
|
19
|
+
)
|
|
20
|
+
from modmex_lambda.event_handler.gateway_response import GatewayResponseBuilder
|
|
21
|
+
from modmex_lambda.event_handler.request import Request
|
|
22
|
+
from modmex_lambda.event_handler.response import Response
|
|
23
|
+
from modmex_lambda.event_handler.routing import HasRoutes, Route, Router
|
|
24
|
+
from modmex_lambda.event_handler.routing_fallbacks import RoutingFallbackHandler
|
|
25
|
+
from modmex_lambda.data_classes.common import BaseProxyEvent
|
|
26
|
+
from modmex_lambda.data_classes.api_gateway_proxy_event import APIGatewayProxyEvent, APIGatewayProxyEventV2
|
|
27
|
+
from modmex_lambda.shared.types import AnyCallableT
|
|
28
|
+
from modmex_lambda.event_handler.types import IApiGatewayResolver
|
|
29
|
+
from modmex_lambda.event_handler.dependencies.dependency_middleware import DependencyMiddleware
|
|
30
|
+
from modmex_lambda.event_handler.middlewares import NextMiddleware
|
|
31
|
+
from modmex_lambda.event_handler.cors import CORSConfig
|
|
32
|
+
from modmex_lambda.shared.json_encoder import JSONEncoder
|
|
33
|
+
from modmex_lambda.event_handler.constants import DEFAULT_STATUS_CODE
|
|
34
|
+
|
|
35
|
+
# Signature of a middleware: it receives the resolver and the next callable in the
# chain (``NextMiddleware``) and returns a ``Response``.
Middleware = Callable[["ApiGatewayResolver", NextMiddleware], Response]

# Default response serializer: compact JSON (no spaces after "," or ":") encoded
# with the project's ``JSONEncoder``.
JSON_DUMP_CALL = partial(json.dumps, separators=(",", ":"), cls=JSONEncoder)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class ProxyEventType(Enum):
    """An enumeration of the supported proxy event types.

    Each member's value is the member name as a string; resolvers select the
    event wrapper class for a member through ``_PROXY_EVENT_MAP``.
    """

    APIGatewayProxyEvent = "APIGatewayProxyEvent"
    APIGatewayProxyEventV2 = "APIGatewayProxyEventV2"
    LambdaFunctionUrlEvent = "LambdaFunctionUrlEvent"
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
# Maps each supported event type to the class used to wrap the raw event dict in
# ``ApiGatewayResolver._to_proxy_event``. Lambda Function URL events share the
# HTTP API (payload v2) wrapper class.
_PROXY_EVENT_MAP: dict[ProxyEventType, type[BaseProxyEvent]] = {
    ProxyEventType.APIGatewayProxyEvent: APIGatewayProxyEvent,
    ProxyEventType.APIGatewayProxyEventV2: APIGatewayProxyEventV2,
    ProxyEventType.LambdaFunctionUrlEvent: APIGatewayProxyEventV2,
}
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class BaseRouter(HasRoutes):
    """Routing state and helpers shared by resolvers.

    ``context`` is only annotated here; subclasses must assign a dict to it before
    ``append_context``, ``clear_context`` or ``request`` are used.
    """

    current_event: BaseProxyEvent
    lambda_context: object
    context: dict[str, Any]
    # NOTE: these class-level lists are shared by every instance. ``use`` rebinds
    # ``_router_middlewares`` on the instance rather than mutating the shared list.
    _router_middlewares: list[Callable] = []
    # NOTE(review): shared mutable class attribute, unused in this module; confirm no
    # caller mutates it in place (that would leak across instances).
    processed_stack_frames: list[str] = []

    def use(self, middlewares: list[Middleware]) -> None:
        """Append ``middlewares`` to this router's global middleware chain.

        Any iterable of middleware callables is accepted (list, tuple, ...). A new list is
        bound on the instance, so the class-level default list is never mutated.
        """
        self._router_middlewares = [*self._router_middlewares, *middlewares]

    def append_context(self, **kwargs: Any) -> None:
        """Merge ``kwargs`` into the routing context dict."""
        self.context.update(kwargs)

    def clear_context(self) -> None:
        """Resets routing context"""
        self.context.clear()

    @property
    def request(self) -> Request:
        """The ``Request`` for the route currently being resolved.

        Built lazily from ``context`` (``_route``, ``_path_params``) and cached under
        ``context["_request"]``.

        Raises
        ------
        RuntimeError
            If no route has been stored in the context yet.
        """
        cached: Request | None = self.context.get("_request")
        if cached is not None:
            return cached

        route: Route | None = self.context.get("_route")
        if route is None:
            raise RuntimeError(
                "app.request is only available after route resolution. Use it inside middleware or a route handler.",
            )

        request = Request(
            route_path=route.path,
            path_parameters=self.context.get("_path_params", {}),
            current_event=self.current_event,
            context=self.context,
        )
        # Cache so repeated access during one resolution returns the same object.
        self.context["_request"] = request
        return request
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class ApiGatewayResolver(BaseRouter, IApiGatewayResolver):
    """Base API Gateway resolver: matches a proxy event to a route and builds the response.

    ``resolve`` raises ``TypeError`` unless the class defines ``_event_type``; use
    ``ApiGatewayRestResolver`` (payload v1) or ``ApiGatewayHttpResolver`` (payload v2).
    """

    _event_type: ProxyEventType

    def __init__(
        self,
        *,
        cors: CORSConfig | None = None,
        serializer: Callable[[dict], str] | None = None,
        strip_prefixes: list[str | Pattern] | None = None,
        json_body_deserializer: Callable[[str], dict] | None = None,
        logger: Any | None = None,
    ) -> None:
        """
        Parameters
        ----------
        cors: CORSConfig | None
            CORS configuration. When provided, routes default to ``cors=True``.
        serializer: Callable[[dict], str] | None
            Serializer handed to the response builder; defaults to ``JSON_DUMP_CALL``.
        strip_prefixes: list[str | Pattern] | None
            Literal or compiled-regex prefixes removed from the path before route matching.
        json_body_deserializer: Callable[[str], dict] | None
            Passed to the proxy event class when the raw event is wrapped.
        logger: Any | None
            Stored as ``_logger``.
        """
        self._cors = cors
        self._cors_enabled = cors is not None

        self._router = Router()
        self._routing_fallbacks = RoutingFallbackHandler(self)
        self._response_builder_class = GatewayResponseBuilder

        self._serializer = serializer or JSON_DUMP_CALL
        self._json_body_deserializer = json_body_deserializer

        self._router_middlewares: list[Middleware] = []

        self._logger = logger
        self.dependency_overrides: dict[Callable[..., Any], Callable[..., Any]] = {}
        self.current_event = None
        self.current_context = None
        self._strip_prefixes = strip_prefixes
        self.context: dict[str, Any] = {}
        self._dependency_middleware = DependencyMiddleware()

    @property
    def handler(self) -> Callable[[dict[str, Any], object], dict[str, Any]]:
        """The bound ``resolve`` method, suitable as a Lambda handler."""
        return self.resolve

    def middleware(self, func: Middleware | None = None) -> Any:
        """Register a global middleware.

        Usable both as ``@app.middleware`` and ``@app.middleware()``; the middleware is
        returned unchanged.
        """
        if func is None:

            def decorator(mw: Middleware) -> Middleware:
                self._router_middlewares.append(mw)
                return mw

            return decorator

        self._router_middlewares.append(func)
        return func

    def route(
        self,
        rule: str,
        method: str | list[str] | tuple[str],
        description: str | None = None,
        status_code: int | None = DEFAULT_STATUS_CODE,
        middlewares: list[Middleware] | None = None,
        cors: bool | None = None,
        compress: bool = False,
        cache_control: str | None = None,
        **_: Any,
    ) -> Callable[[AnyCallableT], AnyCallableT]:
        """Register a route on the internal ``Router``.

        ``cors=None`` inherits the resolver-level setting (enabled iff a ``CORSConfig`` was
        given). Unknown keyword arguments are accepted and ignored.
        """
        cors_enabled = self._cors_enabled if cors is None else cors
        return self._router.route(
            rule=rule,
            method=method,
            description=description,
            status_code=status_code,
            middlewares=middlewares,
            cors=cors_enabled,
            compress=compress,
            cache_control=cache_control,
        )

    def include_router(self, router: Router) -> None:
        """Merge the routes of ``router`` into this resolver's router."""
        self._router.include_router(router)

    def resolve(self, event: dict[str, Any], context: object) -> dict[str, Any]:
        """Resolve a raw proxy event into a serialized response dict.

        The per-request ``context`` is reset before resolution and always cleared afterwards,
        including when resolution raises, so no route/request state survives the call.
        """
        self.current_event = self._to_proxy_event(event)
        self.current_context = context
        self.context = {}
        try:
            return self._resolve().serialize(self.current_event, self._cors)
        finally:
            self.clear_context()

    def _resolve(self) -> GatewayResponseBuilder:
        """Match the current event to a route (or a 405/404 fallback) and wrap the response."""
        path = self._remove_prefix(self.current_event.path)
        method = self.current_event.http_method.upper()
        route, path_params, allowed_methods = self._router.match(method, path)

        if route:
            self.append_context(
                _route=route,
                _route_args=path_params,
                _path_params=path_params,
                _path=path,
            )

            response = self._call_route(route, route_arguments=path_params)
        elif allowed_methods:
            # The path exists but not for this method.
            response = self._routing_fallbacks.method_not_allowed(
                method=method,
                path=path,
                allowed_methods=allowed_methods,
            )
        else:
            response = self._routing_fallbacks.not_found(method=method, path=path)

        return self._response_builder_class(
            response=response,
            route=route,
            json_serializer=self._serializer,
        )

    def _call_route(self, route: Route, route_arguments: dict | None = None) -> Response:
        """Invoke ``route`` through the middleware chain, mapping exceptions to responses.

        An exception is re-raised when neither a registered handler nor the default error
        mapping produces a response.
        """
        try:
            return route.invoke(
                router_middlewares=self._router_middlewares,
                app=self,
                route_arguments=route_arguments or {},
            )
        except Exception as exc:
            response = self._call_exception_handler(exc)
            # Compare against None: a valid Response must never be discarded for being falsy.
            if response is not None:
                return response

            raise

    def _remove_prefix(self, path: str) -> str:
        """Remove the configured prefix from the path.

        String prefixes: an exact match yields "/", and a path starting with ``prefix + "/"``
        is stripped and returned immediately. Compiled patterns: every match is removed and
        processing continues with the next prefix. An empty result becomes "/".
        """
        if not isinstance(self._strip_prefixes, list):
            return path

        for prefix in self._strip_prefixes:
            if isinstance(prefix, str):
                if path == prefix:
                    return "/"

                if self._path_starts_with(path, prefix):
                    return path[len(prefix) :]

            if isinstance(prefix, Pattern):
                path = re.sub(prefix, "", path)

                if not path:
                    return "/"

        return path

    def _to_proxy_event(self, event: dict) -> BaseProxyEvent:
        """Wrap the raw event dict in the event class selected by ``_event_type``.

        Raises
        ------
        TypeError
            If ``_event_type`` is not set (base resolver) or has no registered event class.
        """
        event_type = getattr(self, "_event_type", None)
        if event_type is None:
            raise TypeError(
                "ApiGatewayResolver is a base resolver. Use ApiGatewayRestResolver for API Gateway REST API "
                "payload v1 or ApiGatewayHttpResolver for API Gateway HTTP API payload v2.",
            )

        event_class = _PROXY_EVENT_MAP.get(event_type)
        if event_class is None:
            raise TypeError(f"Unsupported API Gateway event type: {event_type!r}")

        return event_class(event, self._json_body_deserializer)

    def _to_response(self, result: dict | tuple | Response) -> Response:
        """Normalize a handler result into a JSON ``Response``.

        Accepts a ``Response`` (returned as-is), a ``(body, status_code)`` pair, or a bare
        body; a bare body uses the current route's status code, or 200 without a route.
        """
        if isinstance(result, Response):
            return result
        if isinstance(result, tuple) and len(result) == 2:
            result, status_code = result
        else:
            route: Route | None = self.context.get("_route")
            status_code = route.status_code if route else HTTPStatus.OK.value
        return Response(body=result, status_code=status_code, content_type=content_types.APPLICATION_JSON)

    def exception_handler(
        self,
        exc_class: type[Exception] | list[type[Exception]],
    ):
        """Decorator registering a handler for ``exc_class`` (delegates to the router)."""
        return self._router.exception_handler(exc_class)

    def _call_exception_handler(self, exp: Exception) -> Response | None:
        """Run the registered handler for ``exp``; fall back to the default error mapping.

        If the registered handler itself raises, the default mapping is applied to that new
        exception instead.
        """
        handler = self._router.exception_handler_manager.lookup_exception_handler(type(exp))
        if handler:
            try:
                return handler(exp)
            except Exception as exc:
                exp = exc

        return self._default_error_response(exp)

    def _default_error_response(self, exc: Exception) -> Response | None:
        """Map well-known framework exceptions to JSON error responses; ``None`` otherwise."""
        if isinstance(exc, RequestValidationError):
            # Expose only location and error type, not the raw input.
            errors = [{"loc": e["loc"], "type": e["type"]} for e in exc.errors()]
            return Response(
                body={
                    "message": "Validation Error",
                    "detail": errors,
                },
                status_code=HTTPStatus.BAD_REQUEST.value,
                content_type=content_types.APPLICATION_JSON,
            )
        if isinstance(exc, MethodNotAllowedError):
            return Response(
                body={"message": "Method Not Allowed"},
                status_code=HTTPStatus.METHOD_NOT_ALLOWED.value,
                content_type=content_types.APPLICATION_JSON,
            )
        if isinstance(exc, NotFoundError):
            return Response(
                body={"message": "Not Found"},
                status_code=HTTPStatus.NOT_FOUND.value,
                content_type=content_types.APPLICATION_JSON,
            )
        if isinstance(exc, UnauthorizedError):
            return Response(
                body={"message": "Unauthorized"},
                status_code=HTTPStatus.UNAUTHORIZED.value,
                content_type=content_types.APPLICATION_JSON,
            )
        if isinstance(exc, ForbiddenError):
            return Response(
                body={"message": "Forbidden"},
                status_code=HTTPStatus.FORBIDDEN.value,
                content_type=content_types.APPLICATION_JSON,
            )
        return None

    @staticmethod
    def _path_starts_with(path: str, prefix: str) -> bool:
        """Returns true if the `path` starts with a prefix plus a `/`"""
        if not isinstance(prefix, str) or prefix == "":
            return False

        return path.startswith(f"{prefix}/")
|
|
312
|
+
|
|
313
|
+
class ApiGatewayRestResolver(ApiGatewayResolver):
    """Resolver for API Gateway REST API events (payload v1).

    Raw events are wrapped in ``APIGatewayProxyEvent`` before routing.
    """

    # Selects the event wrapper class via _PROXY_EVENT_MAP.
    _event_type = ProxyEventType.APIGatewayProxyEvent
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
class ApiGatewayHttpResolver(ApiGatewayResolver):
    """Resolver for API Gateway HTTP API events (payload v2).

    Raw events are wrapped in ``APIGatewayProxyEventV2`` before routing.
    """

    # Selects the event wrapper class via _PROXY_EVENT_MAP.
    _event_type = ProxyEventType.APIGatewayProxyEventV2
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
# Names exported by ``from modmex_lambda.event_handler.api_gateway import *``.
__all__ = [
    # resolvers
    "ApiGatewayHttpResolver",
    "ApiGatewayRestResolver",
    # request / response / routing types
    "Request",
    "Response",
    "Route",
]
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
"""Common HTTP content types."""
|
|
2
|
+
|
|
3
|
+
# Structured / binary payloads
APPLICATION_JSON = "application/json"
APPLICATION_OCTET_STREAM = "application/octet-stream"

# Textual payloads
TEXT_PLAIN = "text/plain"
TEXT_HTML = "text/html"

__all__ = ["APPLICATION_JSON", "TEXT_PLAIN", "TEXT_HTML", "APPLICATION_OCTET_STREAM"]
|
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class CORSConfig:
|
|
6
|
+
|
|
7
|
+
_REQUIRED_HEADERS = ["Authorization", "Content-Type", "X-Amz-Date", "X-Api-Key", "X-Amz-Security-Token"]
|
|
8
|
+
|
|
9
|
+
def __init__(
|
|
10
|
+
self,
|
|
11
|
+
allow_origin: str = "*",
|
|
12
|
+
extra_origins: list[str] | None = None,
|
|
13
|
+
allow_headers: list[str] | None = None,
|
|
14
|
+
expose_headers: list[str] | None = None,
|
|
15
|
+
max_age: int | None = None,
|
|
16
|
+
allow_credentials: bool = False,
|
|
17
|
+
):
|
|
18
|
+
"""
|
|
19
|
+
Parameters
|
|
20
|
+
----------
|
|
21
|
+
allow_origin: str
|
|
22
|
+
The value of the `Access-Control-Allow-Origin` to send in the response. Defaults to "*", but should
|
|
23
|
+
only be used during development.
|
|
24
|
+
extra_origins: list[str] | None
|
|
25
|
+
The list of additional allowed origins.
|
|
26
|
+
allow_headers: list[str] | None
|
|
27
|
+
The list of additional allowed headers. This list is added to list of
|
|
28
|
+
built-in allowed headers: `Authorization`, `Content-Type`, `X-Amz-Date`,
|
|
29
|
+
`X-Api-Key`, `X-Amz-Security-Token`.
|
|
30
|
+
expose_headers: list[str] | None
|
|
31
|
+
A list of values to return for the Access-Control-Expose-Headers
|
|
32
|
+
max_age: int | None
|
|
33
|
+
The value for the `Access-Control-Max-Age`
|
|
34
|
+
allow_credentials: bool
|
|
35
|
+
A boolean value that sets the value of `Access-Control-Allow-Credentials`
|
|
36
|
+
"""
|
|
37
|
+
|
|
38
|
+
self._allowed_origins = [allow_origin]
|
|
39
|
+
|
|
40
|
+
if extra_origins:
|
|
41
|
+
self._allowed_origins.extend(extra_origins)
|
|
42
|
+
|
|
43
|
+
self.allow_headers = set(self._REQUIRED_HEADERS + (allow_headers or []))
|
|
44
|
+
self.expose_headers = expose_headers or []
|
|
45
|
+
self.max_age = max_age
|
|
46
|
+
self.allow_credentials = allow_credentials
|
|
47
|
+
|
|
48
|
+
def to_dict(self, origin: str | None) -> dict[str, str]:
|
|
49
|
+
"""Builds the configured Access-Control http headers"""
|
|
50
|
+
|
|
51
|
+
# If there's no Origin, don't add any CORS headers
|
|
52
|
+
if not origin:
|
|
53
|
+
return {}
|
|
54
|
+
|
|
55
|
+
# If the origin doesn't match any of the allowed origins, and we don't allow all origins ("*"),
|
|
56
|
+
# don't add any CORS headers
|
|
57
|
+
if origin not in self._allowed_origins and "*" not in self._allowed_origins:
|
|
58
|
+
return {}
|
|
59
|
+
|
|
60
|
+
# The origin matched an allowed origin, so return the CORS headers
|
|
61
|
+
headers = {
|
|
62
|
+
"Access-Control-Allow-Origin": origin,
|
|
63
|
+
"Access-Control-Allow-Headers": CORSConfig.build_allow_methods(self.allow_headers),
|
|
64
|
+
}
|
|
65
|
+
|
|
66
|
+
if self.expose_headers:
|
|
67
|
+
headers["Access-Control-Expose-Headers"] = ",".join(self.expose_headers)
|
|
68
|
+
if self.max_age is not None:
|
|
69
|
+
headers["Access-Control-Max-Age"] = str(self.max_age)
|
|
70
|
+
if origin != "*" and self.allow_credentials is True:
|
|
71
|
+
headers["Access-Control-Allow-Credentials"] = "true"
|
|
72
|
+
return headers
|
|
73
|
+
|
|
74
|
+
def allowed_origin(self, extracted_origin: str) -> str | None:
|
|
75
|
+
if extracted_origin in self._allowed_origins:
|
|
76
|
+
return extracted_origin
|
|
77
|
+
if extracted_origin is not None and "*" in self._allowed_origins:
|
|
78
|
+
return "*"
|
|
79
|
+
|
|
80
|
+
return None
|
|
81
|
+
|
|
82
|
+
@staticmethod
|
|
83
|
+
def build_allow_methods(methods: set[str]) -> str:
|
|
84
|
+
"""Build sorted comma delimited methods for Access-Control-Allow-Methods header
|
|
85
|
+
|
|
86
|
+
Parameters
|
|
87
|
+
----------
|
|
88
|
+
methods : set[str]
|
|
89
|
+
Set of HTTP Methods
|
|
90
|
+
|
|
91
|
+
Returns
|
|
92
|
+
-------
|
|
93
|
+
set[str]
|
|
94
|
+
Formatted string with all HTTP Methods allowed for CORS e.g., `GET, OPTIONS`
|
|
95
|
+
|
|
96
|
+
"""
|
|
97
|
+
return ",".join(sorted(methods))
|
|
File without changes
|
|
@@ -0,0 +1,231 @@
|
|
|
1
|
+
# mypy: ignore-errors
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from collections import deque
|
|
5
|
+
from collections.abc import Mapping, Sequence
|
|
6
|
+
from copy import copy
|
|
7
|
+
from dataclasses import dataclass, is_dataclass
|
|
8
|
+
from typing import Any, Deque, FrozenSet, List, Set, Tuple, Union, Literal, Annotated, get_args, get_origin
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
from modmex import BaseModel, ValidationError, create_model
|
|
12
|
+
from modmex_lambda.validation import ModmexValidator
|
|
13
|
+
|
|
14
|
+
from modmex_lambda.event_handler.dependencies.types import IncEx, UnionType
|
|
15
|
+
|
|
16
|
+
from modmex import Field as FieldInfo, Undefined
|
|
17
|
+
from modmex.fields import UndefinedType
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
# Alias of modmex's ``Undefined`` sentinel; presumably used to mark parameters that have
# no default (i.e. are required) — confirm against callers.
Required = Undefined

# Every supported sequence annotation (typing alias, ABC or builtin), paired with the
# concrete container type it corresponds to. Insertion order is preserved by ``dict``.
sequence_annotation_to_type = dict(
    [
        (Sequence, list),
        (List, list),
        (list, list),
        (Tuple, tuple),
        (tuple, tuple),
        (Set, set),
        (set, set),
        (FrozenSet, frozenset),
        (frozenset, frozenset),
        (Deque, deque),
        (deque, deque),
    ]
)

# The annotation keys above, as a tuple usable with issubclass().
sequence_types = tuple(sequence_annotation_to_type)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class ErrorWrapper(Exception):
    """Exception subclass with no behaviour beyond ``Exception``."""
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@dataclass
class ModelField:
    """A handler parameter described by its ``FieldInfo`` and name.

    Validation and serialization are delegated to ``ModmexValidator``.
    """

    field_info: FieldInfo
    name: str
    mode: Literal["validation", "serialization"] = "validation"

    @property
    def alias(self) -> str:
        """The configured alias, or the field name when no alias is set."""
        value = self.field_info.alias
        return value if value is not None else self.name

    @property
    def required(self) -> bool:
        """Whether ``field_info`` reports the field as required."""
        return self.field_info.is_required()

    @property
    def default(self) -> Any:
        """Same as ``get_default()``."""
        return self.get_default()

    @property
    def type_(self) -> Any:
        """The annotation stored on ``field_info``."""
        return self.field_info.annotation

    def get_default(self) -> Any:
        """Return the default value (calling a default factory), or ``None`` if required."""
        if self.field_info.is_required():
            return None
        return self.field_info.get_default(call_default_factory=True)

    def serialize(
        self,
        value: Any,
        *,
        mode: Literal["json", "python"] = "json",
        include: IncEx | None = None,
        exclude: IncEx | None = None,
        by_alias: bool = True,
        exclude_unset: bool = False,
        exclude_defaults: bool = False,
        exclude_none: bool = False,
    ) -> Any:
        """Serialize ``value`` with ``ModmexValidator``.

        NOTE(review): the keyword options are accepted for signature compatibility but are
        not forwarded to ``ModmexValidator().serialize`` and therefore have no effect.
        """
        return ModmexValidator().serialize(value)

    def validate(
        self,
        value: Any,
        *,
        loc: tuple[int | str, ...] = (),
    ) -> tuple[Any, list[dict[str, Any]] | None]:
        """Validate ``value`` against this field's annotation.

        Returns ``(validated_value, None)`` on success and ``(None, errors)`` on failure.
        """
        try:
            return (ModmexValidator().validate(value, self.field_info.annotation, list(loc)), None)
        except ValidationError as exc:
            # ``errors`` is called as a method elsewhere in this module (``.errors()``); accept
            # both a method and a plain attribute so the error list is always iterable.
            raw_errors = exc.errors() if callable(exc.errors) else exc.errors
            return None, _regenerate_error_with_loc(errors=raw_errors, loc_prefix=())

    def __hash__(self) -> int:
        # Identity-based: each ModelField is a distinct key, even when the dataclass-generated
        # __eq__ would consider two instances equal.
        return id(self)
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def _normalize_errors(errors: Sequence[Any]) -> list[dict[str, Any]]:
|
|
120
|
+
return errors # type: ignore[r
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def create_body_model(*, fields: Sequence[ModelField], model_name: str) -> type[BaseModel]:
    """Create a model class named ``model_name`` with one attribute per entry of ``fields``.

    Each attribute is declared as ``(annotation, field_info)`` from the field's ``FieldInfo``.
    """
    definitions = {}
    for model_field in fields:
        info = model_field.field_info
        definitions[model_field.name] = (info.annotation, info)
    return create_model(model_name, **definitions)
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def is_scalar_field(field: ModelField) -> bool:
    """True when the field's annotation is scalar and the field is not declared as a ``Body``."""
    # Imported locally, as in the original; presumably to avoid an import cycle with params.
    from modmex_lambda.event_handler.dependencies.params import Body

    info = field.field_info
    annotation_is_scalar = field_annotation_is_scalar(info.annotation)
    return annotation_is_scalar and not isinstance(info, Body)
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def field_annotation_is_complex(annotation: type[Any] | None) -> bool:
    """Return True when ``annotation`` describes a non-scalar value.

    * Unions are complex if any member is.
    * ``Annotated[T, ...]`` is judged by ``T``. ``copy_field_info`` re-wraps constrained
      annotations in ``Annotated``, and those would otherwise always look scalar.
    * Otherwise the annotation, or its ``get_origin``, must satisfy ``_annotation_is_complex``.
    """
    origin = get_origin(annotation)
    if origin is Union or origin is UnionType:
        return any(field_annotation_is_complex(arg) for arg in get_args(annotation))

    if origin is Annotated:
        return field_annotation_is_complex(get_args(annotation)[0])

    return _annotation_is_complex(annotation) or _annotation_is_complex(origin)
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def field_annotation_is_scalar(annotation: Any) -> bool:
    """True for ``...`` (Ellipsis) and for every annotation that is not complex."""
    if annotation is Ellipsis:
        return True
    return not field_annotation_is_complex(annotation)
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def field_annotation_is_sequence(annotation: type[Any] | None) -> bool:
    """True when ``annotation`` or its ``get_origin`` is a non-string sequence type."""
    if _annotation_is_sequence(annotation):
        return True
    origin = get_origin(annotation)
    return _annotation_is_sequence(origin)
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def _annotation_is_complex(annotation: type[Any] | None) -> bool:
    """True for model or mapping classes, non-string sequence types, and dataclasses."""
    if lenient_issubclass(annotation, (BaseModel, Mapping)):
        return True
    if _annotation_is_sequence(annotation):
        return True
    return is_dataclass(annotation)
|
|
162
|
+
|
|
163
|
+
def get_annotation_from_field_info(annotation: Any, field_info: FieldInfo, field_name: str) -> Any:
    """Return ``annotation`` unchanged.

    ``field_info`` and ``field_name`` are accepted for call-site compatibility but not used.
    """
    del field_info, field_name  # intentionally unused
    return annotation
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
    """Return a shallow copy of ``field_info`` whose ``annotation`` is derived from ``annotation``.

    ``Annotated`` wrappers are unwrapped. Any non-``FieldInfo`` metadata they carry
    (e.g. constraints) is kept by re-wrapping the base type, and ``FieldInfo`` metadata
    is dropped. The original ``field_info`` is not modified.
    """

    def _split(ann: Any) -> tuple[Any, list[Any]]:
        # Returns (base_type, non-FieldInfo metadata), inner metadata first.
        if get_origin(ann) is not Annotated:
            return ann, []
        base, *extra = get_args(ann)
        if get_origin(base) is Annotated:
            inner_base, inner_extra = _split(base)
            return inner_base, [m for m in inner_extra + extra if not isinstance(m, FieldInfo)]
        return base, [m for m in extra if not isinstance(m, FieldInfo)]

    clone = copy(field_info)
    base_type, constraints = _split(annotation)
    clone.annotation = Annotated[(base_type, *constraints)] if constraints else base_type
    return clone
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
def get_missing_field_error(loc: tuple[str, ...]) -> dict[str, Any]:
    """Build a "Field required" error dict located at ``loc`` (with ``input`` set to None).

    Uses ``ValidationError.from_exception_data`` when that constructor exists; otherwise
    returns a hand-built dict.
    """
    if not hasattr(ValidationError, "from_exception_data"):
        return {"type": "missing", "loc": loc, "msg": "Field required", "input": None}

    line_errors = [{"type": "missing", "loc": loc, "input": {}}]
    first_error = ValidationError.from_exception_data("Field required", line_errors).errors()[0]
    first_error["input"] = None
    return first_error
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
def _annotation_is_sequence(annotation: type[Any] | None) -> bool:
    """True for subclasses of ``sequence_types``; ``str`` and ``bytes`` never count."""
    # str/bytes subclass Sequence, but here they are treated as scalar values.
    is_text = lenient_issubclass(annotation, (str, bytes))
    return not is_text and lenient_issubclass(annotation, sequence_types)
|
|
218
|
+
|
|
219
|
+
def _regenerate_error_with_loc(*, errors: Sequence[Any], loc_prefix: tuple[str | int, ...]) -> list[dict[str, Any]]:
    """Return copies of ``errors`` with ``loc_prefix`` prepended to each error's ``loc``.

    Each error must be a mapping; a missing ``loc`` is treated as empty.
    """
    prefixed: list[Any] = []
    for error in _normalize_errors(errors):
        new_loc = loc_prefix + tuple(error.get("loc", ()))
        prefixed.append({**error, "loc": new_loc})
    return prefixed
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
def lenient_issubclass(cls: Any, class_or_tuple: Any) -> bool:  # pragma: no cover
    """``issubclass`` that returns False instead of raising.

    Non-class ``cls`` values, and arguments ``issubclass`` rejects with ``TypeError``,
    yield False.
    """
    if not isinstance(cls, type):
        return False
    try:
        return issubclass(cls, class_or_tuple)
    except TypeError:
        return False
|