modmex-lambda 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- modmex_lambda/__init__.py +62 -0
- modmex_lambda/data_classes/__init__.py +49 -0
- modmex_lambda/data_classes/api_gateway_authorizer_event.py +38 -0
- modmex_lambda/data_classes/api_gateway_proxy_event.py +328 -0
- modmex_lambda/data_classes/api_gateway_websocket_event.py +40 -0
- modmex_lambda/data_classes/cognito_user_pool_event.py +599 -0
- modmex_lambda/data_classes/common.py +441 -0
- modmex_lambda/event_handler/__init__.py +45 -0
- modmex_lambda/event_handler/api_gateway.py +331 -0
- modmex_lambda/event_handler/constants.py +3 -0
- modmex_lambda/event_handler/content_types.py +13 -0
- modmex_lambda/event_handler/cors.py +97 -0
- modmex_lambda/event_handler/dependencies/__init__.py +0 -0
- modmex_lambda/event_handler/dependencies/compat.py +231 -0
- modmex_lambda/event_handler/dependencies/dependant.py +279 -0
- modmex_lambda/event_handler/dependencies/dependency_middleware.py +423 -0
- modmex_lambda/event_handler/dependencies/depends.py +184 -0
- modmex_lambda/event_handler/dependencies/params.py +317 -0
- modmex_lambda/event_handler/dependencies/types.py +14 -0
- modmex_lambda/event_handler/exception_handler.py +70 -0
- modmex_lambda/event_handler/exceptions.py +72 -0
- modmex_lambda/event_handler/gateway_response.py +96 -0
- modmex_lambda/event_handler/middlewares.py +33 -0
- modmex_lambda/event_handler/params.py +44 -0
- modmex_lambda/event_handler/request.py +70 -0
- modmex_lambda/event_handler/response.py +60 -0
- modmex_lambda/event_handler/routing.py +507 -0
- modmex_lambda/event_handler/routing_fallbacks.py +92 -0
- modmex_lambda/event_handler/types.py +31 -0
- modmex_lambda/event_sources.py +53 -0
- modmex_lambda/exceptions.py +3 -0
- modmex_lambda/logging.py +99 -0
- modmex_lambda/params.py +3 -0
- modmex_lambda/parser.py +47 -0
- modmex_lambda/request.py +3 -0
- modmex_lambda/resolver.py +3 -0
- modmex_lambda/response.py +3 -0
- modmex_lambda/routing.py +3 -0
- modmex_lambda/shared/__init__.py +0 -0
- modmex_lambda/shared/cookies.py +84 -0
- modmex_lambda/shared/headers_serializer.py +65 -0
- modmex_lambda/shared/json_encoder.py +53 -0
- modmex_lambda/shared/types.py +4 -0
- modmex_lambda/validation.py +178 -0
- modmex_lambda-0.1.0.dist-info/METADATA +375 -0
- modmex_lambda-0.1.0.dist-info/RECORD +48 -0
- modmex_lambda-0.1.0.dist-info/WHEEL +4 -0
- modmex_lambda-0.1.0.dist-info/licenses/LICENSE +21 -0
modmex_lambda/logging.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
"""Lightweight structured logger for Lambda workloads."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import sys
|
|
7
|
+
import traceback
|
|
8
|
+
from datetime import datetime, timezone
|
|
9
|
+
from typing import Any, Callable, TextIO
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class Logger:
    """Lightweight structured logger that writes one JSON object per line.

    Every record starts with ``timestamp`` (UTC, ISO-8601), ``level``,
    ``service`` and ``message``. It is then updated, in order, with:

    * the persistent keys added via :meth:`append_keys`;
    * ``request_id`` taken from the Lambda context, when available;
    * ``correlation_id`` taken from the event headers, unless a persistent
      ``correlation_id`` key is already present;
    * the keyword arguments of the individual log call.

    Later sources overwrite earlier ones on key clashes, except for the
    ``correlation_id`` rule above. There is no level filtering: every call
    writes a record.
    """

    def __init__(
        self,
        *,
        service: str,
        stream: TextIO | None = None,
        json_serializer: Callable[[dict[str, Any]], str] | None = None,
        correlation_id_header: str = "x-correlation-id",
    ) -> None:
        """Create a logger.

        Args:
            service: Value written to the ``service`` field of every record.
            stream: Text stream records are written to; defaults to ``sys.stdout``.
            json_serializer: Callable turning a record dict into a JSON string.
                Defaults to compact ``json.dumps`` with ``default=str``, so
                values that are not JSON-serializable are stringified instead of raising.
            correlation_id_header: Event header (matched case-insensitively)
                whose value is logged as ``correlation_id``.
        """
        self._service = service
        self._stream = stream or sys.stdout
        self._serialize = json_serializer or (lambda payload: json.dumps(payload, separators=(",", ":"), default=str))
        self._correlation_id_header = correlation_id_header.lower()
        self._persistent_keys: dict[str, Any] = {}
        self._context: object | None = None
        self._event: dict[str, Any] | None = None

    def set_context(self, *, context: object | None = None, event: dict[str, Any] | None = None) -> None:
        """Remember the Lambda ``context`` and ``event`` used to enrich later records.

        Both are replaced on every call; calling without arguments clears them.
        """
        self._context = context
        self._event = event

    def append_keys(self, **kwargs: Any) -> None:
        """Add (or overwrite) keys that are included in every subsequent record."""
        self._persistent_keys.update(kwargs)

    def clear_state(self) -> None:
        """Drop all keys added with :meth:`append_keys` (context and event are kept)."""
        self._persistent_keys.clear()

    def debug(self, message: str, **kwargs: Any) -> None:
        """Write a ``DEBUG`` record; ``kwargs`` become extra record fields."""
        self._log("DEBUG", message, **kwargs)

    def info(self, message: str, **kwargs: Any) -> None:
        """Write an ``INFO`` record; ``kwargs`` become extra record fields."""
        self._log("INFO", message, **kwargs)

    def warning(self, message: str, **kwargs: Any) -> None:
        """Write a ``WARNING`` record; ``kwargs`` become extra record fields."""
        self._log("WARNING", message, **kwargs)

    def error(self, message: str, *, exc_info: bool = False, **kwargs: Any) -> None:
        """Write an ``ERROR`` record.

        Args:
            message: Human-readable message.
            exc_info: When True and an exception is currently being handled,
                its formatted traceback is added under ``exception``.
            **kwargs: Extra record fields.
        """
        self._log("ERROR", message, exc_info=exc_info, **kwargs)

    def critical(self, message: str, *, exc_info: bool = False, **kwargs: Any) -> None:
        """Write a ``CRITICAL`` record (``exc_info`` behaves as in :meth:`error`)."""
        self._log("CRITICAL", message, exc_info=exc_info, **kwargs)

    def _log(self, level: str, message: str, *, exc_info: bool = False, **kwargs: Any) -> None:
        """Build the record for one log call and write it to the stream as a single line."""
        record: dict[str, Any] = {
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "level": level,
            "service": self._service,
            "message": message,
        }
        record.update(self._persistent_keys)

        request_id = self._extract_request_id()
        if request_id:
            record["request_id"] = request_id

        correlation_id = self._extract_correlation_id()
        # A correlation_id supplied via append_keys wins over the header value.
        if correlation_id and "correlation_id" not in record:
            record["correlation_id"] = correlation_id

        record.update(kwargs)

        # Only attach a traceback while an exception is actually being handled;
        # otherwise traceback.format_exc() yields the meaningless "NoneType: None".
        if exc_info and sys.exc_info()[0] is not None:
            record["exception"] = traceback.format_exc()

        self._stream.write(self._serialize(record) + "\n")

    def _extract_request_id(self) -> str | None:
        """Return ``context.aws_request_id`` as a string, or None when unset/empty."""
        if self._context is None:
            return None
        value = getattr(self._context, "aws_request_id", None)
        return str(value) if value else None

    def _extract_correlation_id(self) -> str | None:
        """Return the configured correlation header from ``event["headers"]``, if present.

        Header names are compared case-insensitively; None values are skipped.
        """
        event = self._event
        if not isinstance(event, dict):
            return None

        headers = event.get("headers")
        if not isinstance(headers, dict):
            return None

        for key, value in headers.items():
            if str(key).lower() == self._correlation_id_header and value is not None:
                return str(value)
        return None
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
# Public names exported by ``from modmex_lambda.logging import *``.
__all__ = ["Logger"]
|
modmex_lambda/params.py
ADDED
modmex_lambda/parser.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
"""Lightweight event parser APIs."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
from functools import wraps
|
|
7
|
+
from typing import Any, Callable, TYPE_CHECKING
|
|
8
|
+
|
|
9
|
+
from modmex.errors import ValidationError
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from .validation import Validator
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def parse(*, event: Any, model: Any, validator: "Validator | None" = None) -> Any:
    """Turn a raw Lambda ``event`` into ``model`` using a validator.

    ``event`` may already be decoded Python data or a JSON ``str``; strings are
    decoded with :func:`json.loads` first (other types, including ``bytes``, are
    passed through untouched). Coercion is delegated to ``validator.validate``;
    when no validator is given, a new ``ModmexValidator`` is created per call.

    Raises:
        ValidationError: when ``event`` is a string that is not valid JSON
            (error type ``value_error.jsondecode``); otherwise whatever the
            validator raises is propagated.
    """
    if validator is None:
        # Imported lazily so ``validation`` is only loaded when the default is needed.
        from .validation import ModmexValidator

        validator = ModmexValidator()

    payload = event
    if isinstance(payload, str):
        try:
            payload = json.loads(payload)
        except json.JSONDecodeError as exc:
            error = {"loc": [], "msg": str(exc), "type": "value_error.jsondecode"}
            raise ValidationError(errors=[error]) from exc

    return validator.validate(payload, model)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def event_parser(*, model: Any, validator: "Validator | None" = None) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Build a decorator that parses the Lambda ``event`` before the handler runs.

    The decorated handler is called with ``parse(event=event, model=model,
    validator=validator)`` in place of the raw event. ``context`` and any extra
    positional/keyword arguments are forwarded unchanged, and the handler's
    return value is returned as-is.
    """

    def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
        @wraps(handler)
        def wrapper(event: Any, context: object, *args: Any, **kwargs: Any) -> Any:
            return handler(parse(event=event, model=model, validator=validator), context, *args, **kwargs)

        return wrapper

    return decorator
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
# Public names exported by ``from modmex_lambda.parser import *``.
__all__ = ["parse", "event_parser"]
|
modmex_lambda/request.py
ADDED
modmex_lambda/routing.py
ADDED
|
File without changes
|
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from enum import Enum
|
|
4
|
+
from io import StringIO
|
|
5
|
+
from typing import TYPE_CHECKING
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from datetime import datetime
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class SameSite(Enum):
    """Values for the ``SameSite`` attribute of a ``Set-Cookie`` header."""

    # Empty value; presumably means "send no explicit SameSite attribute" — confirm intent.
    DEFAULT_MODE = ""
    LAX_MODE = "Lax"
    STRICT_MODE = "Strict"
    # Per browser rules, SameSite=None generally also requires the Secure flag.
    NONE_MODE = "None"
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _format_date(timestamp: datetime) -> str:
|
|
20
|
+
return timestamp.strftime("%a, %d %b %Y %H:%M:%S GMT")
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class Cookie:
    """An HTTP cookie whose ``str()`` is the value of a ``Set-Cookie`` response header."""

    def __init__(
        self,
        name: str,
        value: str,
        path: str = "",
        domain: str = "",
        secure: bool = True,
        http_only: bool = False,
        max_age: int | None = None,
        expires: datetime | None = None,
        same_site: SameSite | None = None,
        custom_attributes: list[str] | None = None,
    ):
        """Store the cookie attributes.

        Args:
            name: Cookie name.
            value: Cookie value, written verbatim (no quoting or escaping is applied).
            path: ``Path`` attribute; omitted when empty.
            domain: ``Domain`` attribute; omitted when empty.
            secure: Emit the ``Secure`` flag (on by default).
            http_only: Emit the ``HttpOnly`` flag.
            max_age: ``Max-Age`` in seconds. Zero or negative values are written
                as ``Max-Age=0`` (expire immediately); ``None`` omits the attribute.
            expires: ``Expires`` timestamp; ``None`` omits the attribute.
            same_site: ``SameSite`` policy; ``None`` or a member with an empty
                value omits the attribute.
            custom_attributes: Extra attribute strings appended verbatim.
        """
        self.name = name
        self.value = value
        self.path = path
        self.domain = domain
        self.secure = secure
        self.expires = expires
        self.max_age = max_age
        self.http_only = http_only
        self.same_site = same_site
        self.custom_attributes = custom_attributes

    def __str__(self) -> str:
        """Render the cookie as a ``Set-Cookie`` value (attributes separated by ``"; "``)."""
        payload = StringIO()
        payload.write(f"{self.name}={self.value}")

        if self.path:
            payload.write(f"; Path={self.path}")

        if self.domain:
            payload.write(f"; Domain={self.domain}")

        if self.expires is not None:
            payload.write(f"; Expires={_format_date(self.expires)}")

        # Compare against None: max_age=0 is falsy but is the standard way to delete
        # a cookie, so it must still be emitted. Negative values are normalized to 0.
        if self.max_age is not None:
            payload.write(f"; Max-Age={self.max_age if self.max_age > 0 else 0}")

        if self.http_only:
            payload.write("; HttpOnly")

        if self.secure:
            payload.write("; Secure")

        # Enum members are always truthy, so check the value itself: an empty value
        # (SameSite.DEFAULT_MODE) would otherwise render an invalid "SameSite=".
        if self.same_site is not None and self.same_site.value:
            payload.write(f"; SameSite={self.same_site.value}")

        for attr in self.custom_attributes or ():
            payload.write(f"; {attr}")

        return payload.getvalue()
|
|
84
|
+
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
from abc import ABC, abstractmethod
|
|
5
|
+
from collections import defaultdict
|
|
6
|
+
from typing import Any
|
|
7
|
+
from modmex_lambda.shared.cookies import Cookie
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class BaseHeadersSerializer(ABC):
    """Abstract strategy that turns response headers and cookies into payload fields.

    Concrete subclasses decide how an event source expects headers and
    ``Set-Cookie`` values to be laid out in the Lambda response.
    """

    @abstractmethod
    def serialize(self, headers: dict[str, str | list[str]], cookies: list[Cookie]) -> dict[str, Any]:
        """Return the response-payload fragment carrying ``headers`` and ``cookies``.

        Subclasses must override this; the base implementation always raises
        ``NotImplementedError``.
        """
        raise NotImplementedError
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class HttpApiHeadersSerializer(BaseHeadersSerializer):
    def serialize(self, headers: dict[str, str | list[str]], cookies: list[Cookie]) -> dict[str, Any]:
        """
        Build the header/cookie fields for HTTP API (payload format 2.0) and Lambda Function URL responses.

        Format 2.0 has no ``multiValueHeaders``. A header given as a list is therefore
        collapsed into one comma-separated value, and headers whose value is ``None``
        are dropped. Cookies are returned separately under ``cookies`` as rendered
        ``Set-Cookie`` strings.
        """
        combined_headers: dict[str, str] = {
            name: value if isinstance(value, str) else ", ".join(value)
            for name, value in headers.items()
            # headers explicitly set to None are omitted
            if value is not None
        }
        rendered_cookies = [str(cookie) for cookie in cookies]
        return {"headers": combined_headers, "cookies": rendered_cookies}
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class MultiValueHeadersSerializer(BaseHeadersSerializer):
    def serialize(self, headers: dict[str, str | list[str]], cookies: list[Cookie]) -> dict[str, Any]:
        """
        Build the ``multiValueHeaders`` field for REST API (and multi-value ALB) responses.

        Every header maps to a list of values. A string value becomes a one-element
        list, a list value contributes all of its items, and headers whose value is
        ``None`` are omitted. Each cookie is rendered with ``str()`` and appended to
        the ``Set-Cookie`` list. The returned mapping is a plain ``dict``.
        """
        payload: defaultdict[str, list[str]] = defaultdict(list)
        for key, values in headers.items():
            # omit headers with explicit null values
            if values is None:
                continue

            if isinstance(values, str):
                payload[key].append(values)
            else:
                payload[key].extend(values)

        if cookies:
            payload["Set-Cookie"].extend(str(cookie) for cookie in cookies)

        # Return a plain dict: a defaultdict would silently create an empty header
        # entry whenever a caller merely looks up a missing key in the response.
        return {"multiValueHeaders": dict(payload)}
|
|
65
|
+
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
from json import JSONEncoder as BaseJSONEncoder
|
|
2
|
+
from datetime import date, datetime, time, timedelta
|
|
3
|
+
from decimal import Decimal
|
|
4
|
+
from enum import Enum
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class JSONEncoder(BaseJSONEncoder):
    """``json.JSONEncoder`` subclass that also handles models, dataclasses and common stdlib types.

    ``default`` converts, checked in this order:

    * objects exposing ``model_dump_json()`` (Modmex/pydantic-style models) -> that
      JSON decoded back into Python data, so the model is embedded as a JSON value
      rather than as a quoted JSON string;
    * dataclass instances -> ``dataclasses.asdict(obj)``;
    * ``Enum`` members -> ``obj.value``;
    * ``datetime`` / ``date`` / ``time`` -> ``obj.isoformat()``;
    * ``timedelta`` -> ``obj.total_seconds()``;
    * ``Decimal`` -> ``float(obj)`` (precision beyond a float is lost).

    Anything else is delegated to ``json.JSONEncoder.default``, which raises ``TypeError``.
    """

    def default(self, obj: object) -> object:
        import dataclasses
        import json

        if hasattr(obj, "model_dump_json"):
            # model_dump_json() returns a *string*; returning it directly makes the
            # encoder emit a double-encoded, quoted JSON string. Decode it so the
            # model's own serialization is embedded as a nested JSON value.
            return json.loads(obj.model_dump_json())

        # Instances only: dataclasses.asdict() rejects dataclass classes.
        if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
            return dataclasses.asdict(obj)

        if isinstance(obj, Enum):
            return obj.value

        # datetime is a subclass of date, so one check covers all three types.
        if isinstance(obj, (datetime, date, time)):
            return obj.isoformat()

        if isinstance(obj, timedelta):
            return obj.total_seconds()

        if isinstance(obj, Decimal):
            return float(obj)

        return super().default(obj)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def is_dataclass(obj) -> bool:
    """Return True if ``obj`` is a dataclass *instance*; dataclass classes return False.

    Classes are excluded so the result can safely gate ``dataclasses.asdict``,
    which only accepts instances.
    """
    import dataclasses

    return dataclasses.is_dataclass(obj) and not isinstance(obj, type)
|
|
43
|
+
|
|
44
|
+
def is_modmex(obj) -> bool:
    """Report whether ``obj`` looks like a Modmex model, i.e. exposes ``model_dump_json``."""
    try:
        getattr(obj, "model_dump_json")
    except AttributeError:
        return False
    return True
|
|
46
|
+
|
|
47
|
+
def is_pydantic(obj) -> bool:
    """Report whether ``obj`` looks like a pydantic v2 model (has ``model_dump_json``)."""
    missing = object()
    return getattr(obj, "model_dump_json", missing) is not missing
|
|
49
|
+
|
|
50
|
+
def dataclass_to_dict(obj) -> dict:
    """Convert a dataclass instance (recursively, including nested dataclasses) into a dict."""
    from dataclasses import asdict

    return asdict(obj)
|
|
@@ -0,0 +1,178 @@
|
|
|
1
|
+
"""Validation adapter layer backed by Modmex."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import inspect
|
|
6
|
+
import json
|
|
7
|
+
import types
|
|
8
|
+
from collections.abc import Mapping, Sequence
|
|
9
|
+
from datetime import date, datetime
|
|
10
|
+
from decimal import Decimal
|
|
11
|
+
from enum import Enum
|
|
12
|
+
from typing import Annotated, Any, Literal, Protocol, Union, get_args, get_origin
|
|
13
|
+
|
|
14
|
+
from modmex import BaseModel
|
|
15
|
+
from modmex.errors import ValidationError
|
|
16
|
+
from modmex.validation import (
|
|
17
|
+
bool_validator,
|
|
18
|
+
decimal_validator,
|
|
19
|
+
float_validator,
|
|
20
|
+
int_validator,
|
|
21
|
+
parse_date,
|
|
22
|
+
parse_datetime,
|
|
23
|
+
str_validator,
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class Validator(Protocol):
    """Structural interface for a validation backend.

    Any object that provides matching ``validate``, ``serialize`` and ``dumps``
    methods satisfies this protocol; no inheritance is required.
    """

    def validate(self, value: Any, target_type: Any, loc: list[Any] | None = None) -> Any:
        """Coerce ``value`` into ``target_type``; ``loc`` optionally gives the error-location path."""

    def serialize(self, value: Any) -> Any:
        """Convert a model-like ``value`` into plain Python data."""

    def dumps(self, value: Any) -> str:
        """Render ``value`` as a JSON string."""
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class ModmexValidator:
|
|
39
|
+
def __init__(self) -> None:
|
|
40
|
+
self._primitive_validators = {
|
|
41
|
+
str: str_validator,
|
|
42
|
+
int: int_validator,
|
|
43
|
+
bool: bool_validator,
|
|
44
|
+
float: float_validator,
|
|
45
|
+
datetime: parse_datetime,
|
|
46
|
+
date: parse_date,
|
|
47
|
+
Decimal: decimal_validator,
|
|
48
|
+
}
|
|
49
|
+
|
|
50
|
+
def validate(self, value: Any, target_type: Any, loc: list[Any] | None = None) -> Any:
|
|
51
|
+
return self._validate_value(target_type, value, list(loc or []))
|
|
52
|
+
|
|
53
|
+
def serialize(self, value: Any) -> Any:
|
|
54
|
+
if hasattr(value, "model_dump"):
|
|
55
|
+
return value.model_dump()
|
|
56
|
+
if hasattr(value, "to_dict"):
|
|
57
|
+
return value.to_dict()
|
|
58
|
+
return value
|
|
59
|
+
|
|
60
|
+
def dumps(self, value: Any) -> str:
|
|
61
|
+
if hasattr(value, "model_dump_json"):
|
|
62
|
+
return value.model_dump_json()
|
|
63
|
+
return json.dumps(self.serialize(value), separators=(",", ":"), default=self.serialize)
|
|
64
|
+
|
|
65
|
+
def _validate_value(self, annotation: Any, value: Any, loc: list[Any]) -> Any:
|
|
66
|
+
if annotation is inspect.Signature.empty or annotation is Any:
|
|
67
|
+
return value
|
|
68
|
+
|
|
69
|
+
if value is None:
|
|
70
|
+
if self._accepts_none(annotation):
|
|
71
|
+
return None
|
|
72
|
+
raise ValidationError(errors=[{"loc": loc, "msg": "Field required", "type": "missing"}])
|
|
73
|
+
|
|
74
|
+
origin = get_origin(annotation)
|
|
75
|
+
args = get_args(annotation)
|
|
76
|
+
|
|
77
|
+
if origin is not None:
|
|
78
|
+
if origin is Annotated:
|
|
79
|
+
return self._validate_value(args[0], value, loc)
|
|
80
|
+
|
|
81
|
+
if origin in (list, Sequence):
|
|
82
|
+
item_type = args[0] if args else Any
|
|
83
|
+
return [self._validate_value(item_type, item, loc + [idx]) for idx, item in enumerate(value)]
|
|
84
|
+
|
|
85
|
+
if origin is tuple:
|
|
86
|
+
return self._validate_tuple(args, value, loc)
|
|
87
|
+
|
|
88
|
+
if origin is dict:
|
|
89
|
+
return self._validate_dict(args, value, loc)
|
|
90
|
+
|
|
91
|
+
if origin in (Union, types.UnionType):
|
|
92
|
+
return self._validate_union(args, value, loc)
|
|
93
|
+
|
|
94
|
+
if self._is_literal(annotation):
|
|
95
|
+
return self._validate_literal(annotation, value, loc)
|
|
96
|
+
|
|
97
|
+
return self._validate_leaf(annotation, value, loc)
|
|
98
|
+
|
|
99
|
+
def _validate_tuple(self, args: tuple[Any, ...], value: Any, loc: list[Any]) -> tuple[Any, ...]:
|
|
100
|
+
if not isinstance(value, tuple):
|
|
101
|
+
raise ValidationError(errors=[{"loc": loc, "msg": "must be a tuple", "type": "type_error.tuple"}])
|
|
102
|
+
if len(args) == 2 and args[1] is Ellipsis:
|
|
103
|
+
return tuple(self._validate_value(args[0], item, loc + [idx]) for idx, item in enumerate(value))
|
|
104
|
+
if len(value) != len(args):
|
|
105
|
+
raise ValidationError(errors=[{"loc": loc, "msg": "Tuple length mismatch", "type": "type_error"}])
|
|
106
|
+
return tuple(self._validate_value(arg, item, loc + [idx]) for idx, (arg, item) in enumerate(zip(args, value)))
|
|
107
|
+
|
|
108
|
+
def _validate_dict(self, args: tuple[Any, ...], value: Any, loc: list[Any]) -> dict[Any, Any]:
|
|
109
|
+
if not isinstance(value, Mapping):
|
|
110
|
+
raise ValidationError(errors=[{"loc": loc, "msg": "must be a dict", "type": "type_error.dict"}])
|
|
111
|
+
key_type, value_type = args if args else (Any, Any)
|
|
112
|
+
validated: dict[Any, Any] = {}
|
|
113
|
+
for key, item in value.items():
|
|
114
|
+
valid_key = self._validate_value(key_type, key, loc + ["<key>"])
|
|
115
|
+
validated[valid_key] = self._validate_value(value_type, item, loc + [key])
|
|
116
|
+
return validated
|
|
117
|
+
|
|
118
|
+
def _validate_union(self, args: tuple[Any, ...], value: Any, loc: list[Any]) -> Any:
|
|
119
|
+
candidate_errors: list[dict[str, Any]] = []
|
|
120
|
+
for candidate in args:
|
|
121
|
+
try:
|
|
122
|
+
return self._validate_value(candidate, value, loc)
|
|
123
|
+
except ValidationError as exc:
|
|
124
|
+
candidate_errors.extend(exc.errors)
|
|
125
|
+
raise ValidationError(
|
|
126
|
+
errors=candidate_errors or [{"loc": loc, "msg": "Invalid union type", "type": "type_error"}]
|
|
127
|
+
)
|
|
128
|
+
|
|
129
|
+
def _validate_literal(self, annotation: Any, value: Any, loc: list[Any]) -> Any:
|
|
130
|
+
if value in get_args(annotation):
|
|
131
|
+
return value
|
|
132
|
+
raise ValidationError(
|
|
133
|
+
errors=[{"loc": loc, "msg": f"Unexpected literal value: {value}", "type": "literal_error"}]
|
|
134
|
+
)
|
|
135
|
+
|
|
136
|
+
def _validate_leaf(self, annotation: Any, value: Any, loc: list[Any]) -> Any:
|
|
137
|
+
try:
|
|
138
|
+
primitive_validator = self._primitive_validators.get(annotation)
|
|
139
|
+
if primitive_validator is not None:
|
|
140
|
+
return primitive_validator(value)
|
|
141
|
+
|
|
142
|
+
if inspect.isclass(annotation) and issubclass(annotation, Enum):
|
|
143
|
+
return annotation(value)
|
|
144
|
+
|
|
145
|
+
if inspect.isclass(annotation) and issubclass(annotation, BaseModel):
|
|
146
|
+
if isinstance(value, annotation):
|
|
147
|
+
return value
|
|
148
|
+
if not isinstance(value, Mapping):
|
|
149
|
+
raise ValidationError(errors=[{"loc": loc, "msg": "Expected object", "type": "type_error"}])
|
|
150
|
+
return annotation(**value)
|
|
151
|
+
|
|
152
|
+
if isinstance(value, annotation):
|
|
153
|
+
return value
|
|
154
|
+
|
|
155
|
+
if isinstance(value, Mapping):
|
|
156
|
+
return annotation(**value)
|
|
157
|
+
|
|
158
|
+
return annotation(value)
|
|
159
|
+
except ValidationError as exc:
|
|
160
|
+
raise ValidationError(errors=[self._prefix_error(error, loc) for error in exc.errors]) from exc
|
|
161
|
+
except (TypeError, ValueError) as exc:
|
|
162
|
+
raise ValidationError(errors=[{"loc": loc, "msg": str(exc), "type": "type_error"}]) from exc
|
|
163
|
+
|
|
164
|
+
def _accepts_none(self, annotation: Any) -> bool:
|
|
165
|
+
origin = get_origin(annotation)
|
|
166
|
+
if origin is None:
|
|
167
|
+
return annotation is type(None)
|
|
168
|
+
return any(arg is type(None) for arg in get_args(annotation))
|
|
169
|
+
|
|
170
|
+
def _is_literal(self, annotation: Any) -> bool:
|
|
171
|
+
return get_origin(annotation) is Literal
|
|
172
|
+
|
|
173
|
+
def _prefix_error(self, error: dict[str, Any], prefix: list[Any]) -> dict[str, Any]:
|
|
174
|
+
return {
|
|
175
|
+
"loc": [*prefix, *list(error.get("loc", []))],
|
|
176
|
+
"msg": error.get("msg", "Validation error"),
|
|
177
|
+
"type": error.get("type", "type_error"),
|
|
178
|
+
}
|