hypern 0.1.0__cp310-cp310-manylinux_2_34_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65) hide show
  1. hypern/__init__.py +4 -0
  2. hypern/application.py +234 -0
  3. hypern/auth/__init__.py +0 -0
  4. hypern/auth/authorization.py +2 -0
  5. hypern/background.py +4 -0
  6. hypern/caching/__init__.py +0 -0
  7. hypern/caching/base/__init__.py +8 -0
  8. hypern/caching/base/backend.py +3 -0
  9. hypern/caching/base/key_maker.py +8 -0
  10. hypern/caching/cache_manager.py +56 -0
  11. hypern/caching/cache_tag.py +10 -0
  12. hypern/caching/custom_key_maker.py +11 -0
  13. hypern/caching/redis_backend.py +3 -0
  14. hypern/cli/__init__.py +0 -0
  15. hypern/cli/commands.py +0 -0
  16. hypern/config.py +149 -0
  17. hypern/datastructures.py +27 -0
  18. hypern/db/__init__.py +0 -0
  19. hypern/db/nosql/__init__.py +25 -0
  20. hypern/db/nosql/addons/__init__.py +4 -0
  21. hypern/db/nosql/addons/color.py +16 -0
  22. hypern/db/nosql/addons/daterange.py +30 -0
  23. hypern/db/nosql/addons/encrypted.py +53 -0
  24. hypern/db/nosql/addons/password.py +134 -0
  25. hypern/db/nosql/addons/unicode.py +10 -0
  26. hypern/db/sql/__init__.py +176 -0
  27. hypern/db/sql/addons/__init__.py +14 -0
  28. hypern/db/sql/addons/color.py +15 -0
  29. hypern/db/sql/addons/daterange.py +22 -0
  30. hypern/db/sql/addons/datetime.py +22 -0
  31. hypern/db/sql/addons/encrypted.py +58 -0
  32. hypern/db/sql/addons/password.py +170 -0
  33. hypern/db/sql/addons/ts_vector.py +46 -0
  34. hypern/db/sql/addons/unicode.py +15 -0
  35. hypern/db/sql/repository.py +289 -0
  36. hypern/enum.py +13 -0
  37. hypern/exceptions.py +93 -0
  38. hypern/hypern.cpython-310-x86_64-linux-gnu.so +0 -0
  39. hypern/hypern.pyi +172 -0
  40. hypern/i18n/__init__.py +0 -0
  41. hypern/logging/__init__.py +3 -0
  42. hypern/logging/logger.py +91 -0
  43. hypern/middleware/__init__.py +5 -0
  44. hypern/middleware/base.py +16 -0
  45. hypern/middleware/cors.py +38 -0
  46. hypern/middleware/i18n.py +1 -0
  47. hypern/middleware/limit.py +174 -0
  48. hypern/openapi/__init__.py +5 -0
  49. hypern/openapi/schemas.py +64 -0
  50. hypern/openapi/swagger.py +3 -0
  51. hypern/py.typed +0 -0
  52. hypern/response/__init__.py +3 -0
  53. hypern/response/response.py +134 -0
  54. hypern/routing/__init__.py +4 -0
  55. hypern/routing/dispatcher.py +65 -0
  56. hypern/routing/endpoint.py +27 -0
  57. hypern/routing/parser.py +101 -0
  58. hypern/routing/router.py +279 -0
  59. hypern/scheduler.py +5 -0
  60. hypern/security.py +44 -0
  61. hypern/worker.py +30 -0
  62. hypern-0.1.0.dist-info/METADATA +121 -0
  63. hypern-0.1.0.dist-info/RECORD +65 -0
  64. hypern-0.1.0.dist-info/WHEEL +4 -0
  65. hypern-0.1.0.dist-info/licenses/LICENSE +24 -0
@@ -0,0 +1,38 @@
1
+ from typing import List
2
+ from .base import Middleware
3
+
4
+
5
class CORSMiddleware(Middleware):
    """
    Middleware that registers CORS (Cross-Origin Resource Sharing) response
    headers for a configured list of origins, methods and request headers.
    """

    def __init__(self, allow_origins: List[str] = None, allow_methods: List[str] = None, allow_headers: List[str] = None) -> None:
        super().__init__()
        # A missing (or empty) argument is normalised to an empty list.
        self.allow_origins = allow_origins if allow_origins else []
        self.allow_methods = allow_methods if allow_methods else []
        self.allow_headers = allow_headers if allow_headers else []

    def before_request(self, request):
        """Pass the request through untouched."""
        return request

    def after_request(self, response):
        """
        Register the Access-Control-* headers once per allowed origin.

        The headers are added through ``self.app.add_response_header`` rather
        than on ``response`` itself (``self.app`` is presumably attached by the
        ``Middleware`` base class or the application -- confirm there).

        NOTE(review): the source this block was recovered from had lost its
        indentation; all four header registrations are assumed to live inside
        the per-origin loop.

        :param response: The outgoing response object.
        :return: The same ``response`` object, unchanged.
        """
        for origin in self.allow_origins:
            cors_headers = (
                ("Access-Control-Allow-Origin", origin),
                ("Access-Control-Allow-Methods", ", ".join(name.upper() for name in self.allow_methods)),
                ("Access-Control-Allow-Headers", ", ".join(self.allow_headers)),
                ("Access-Control-Allow-Credentials", "true"),
            )
            for header_name, header_value in cors_headers:
                self.app.add_response_header(header_name, header_value)
        return response
@@ -0,0 +1 @@
1
+ # coming soon
@@ -0,0 +1,174 @@
1
+ from abc import ABC, abstractmethod
2
+ from threading import Lock
3
+ import time
4
+ from robyn import Response, Request
5
+ from .base import Middleware
6
+
7
+
8
class StorageBackend(ABC):
    """
    Abstract counter store.

    Concrete backends must provide ``increment`` and ``get``; the rate-limit
    middleware in this module calls ``increment`` on whichever backend it is
    given.
    """

    @abstractmethod
    def increment(self, key, amount=1, expire=None):
        """
        Add ``amount`` to the counter stored under ``key``.

        :param key: Identifier of the counter.
        :param amount: Value to add, defaults to 1.
        :param expire: Optional lifetime for the counter, in seconds.
        """

    @abstractmethod
    def get(self, key):
        """
        Read the counter stored under ``key``.

        :param key: Identifier of the counter.
        """
16
+
17
+
18
class RedisBackend(StorageBackend):
    """Counter store that delegates to a Redis client object."""

    def __init__(self, redis_client):
        self.redis = redis_client

    def increment(self, key, amount=1, expire=None):
        """
        Increment ``key`` by ``amount`` and optionally (re)set its expiry.

        Both commands are queued on one pipeline obtained from the client and
        sent with a single ``execute()`` call.

        :param key: Redis key holding the counter.
        :param amount: Value to add to the counter, defaults to 1.
        :param expire: Lifetime in seconds; when truthy it is truncated to an
            ``int`` and applied to the key. A falsy value leaves the key's
            expiry untouched.
        :return: The first pipeline result, i.e. the reply to the ``incr``
            command.
        """
        with self.redis.pipeline() as txn:
            txn.incr(key, amount)
            if expire:
                txn.expire(key, int(expire))
            results = txn.execute()
            return results[0]

    def get(self, key):
        """
        Return the counter stored under ``key`` as an ``int``.

        :param key: Redis key holding the counter.
        :return: The stored value converted to ``int``; 0 when the client
            returns a falsy value (such as ``None``).
        """
        raw_value = self.redis.get(key)
        return int(raw_value or 0)
47
+
48
+
49
class InMemoryBackend(StorageBackend):
    """
    Counter store kept in a plain dictionary of the current process.

    ``self.storage`` maps each key to ``{"value": <count>, "expire": <deadline>}``
    where the deadline is an absolute ``time.time()`` timestamp or ``None``.

    Expired counters are treated as absent by both ``increment`` and ``get``
    (matching how an expired key behaves in ``RedisBackend``), and are swept
    out periodically so that keys which are never read again do not
    accumulate forever.

    NOTE(review): no locking is performed; concurrent callers can race on a
    counter exactly as they could before this change.
    """

    # Minimum number of seconds between two sweeps of expired entries.
    _PURGE_INTERVAL = 60.0

    def __init__(self):
        self.storage = {}
        self._next_purge = time.time() + self._PURGE_INTERVAL

    def _purge_expired(self, now):
        """Remove every expired entry, at most once per ``_PURGE_INTERVAL``."""
        if now < getattr(self, "_next_purge", 0.0):
            return
        self._next_purge = now + self._PURGE_INTERVAL
        # Snapshot the items first so the dictionary is not resized while it
        # is being iterated.
        for stored_key, entry in list(self.storage.items()):
            deadline = entry["expire"]
            if deadline and now > deadline:
                self.storage.pop(stored_key, None)

    def _live_entry(self, key, now):
        """Return the non-expired entry for ``key``; drop it and return None if expired."""
        entry = self.storage.get(key)
        if entry is None:
            return None
        deadline = entry["expire"]
        if deadline and now > deadline:
            self.storage.pop(key, None)
            return None
        return entry

    def increment(self, key, amount=1, expire=None):
        """
        Add ``amount`` to the counter for ``key`` and optionally (re)set its expiry.

        A counter whose deadline has passed is discarded first, so counting
        restarts from zero instead of continuing from the stale value.

        :param key: Identifier of the counter.
        :param amount: Value to add, defaults to 1.
        :param expire: Lifetime in seconds; when truthy the deadline becomes
            "now + expire". A falsy value leaves the current deadline as is.
        :return: The counter's value after the increment.
        """
        now = time.time()
        self._purge_expired(now)

        entry = self._live_entry(key, now)
        if entry is None:
            entry = {"value": 0, "expire": None}
            self.storage[key] = entry

        entry["value"] += amount
        if expire:
            entry["expire"] = now + expire
        return entry["value"]

    def get(self, key):
        """
        Return the current value of the counter for ``key``.

        :param key: Identifier of the counter.
        :return: The stored value, or 0 when the key is unknown or has
            expired (an expired entry is deleted as a side effect).
        """
        entry = self._live_entry(key, time.time())
        return entry["value"] if entry is not None else 0
95
+
96
+
97
class RateLimitMiddleware(Middleware):
    """
    Fixed-window rate limiter.

    Each request increments a counter in the storage backend whose key is
    built from the request identifier and the index of the current time
    window; once the counter exceeds ``requests_per_minute`` a 429 response is
    returned instead of the request.
    """

    def __init__(self, storage_backend, requests_per_minute=60, window_size=60):
        super().__init__()
        self.storage = storage_backend
        # Maximum number of requests accepted per window (the window length is
        # ``window_size`` seconds, 60 by default).
        self.requests_per_minute = requests_per_minute
        self.window_size = window_size

    def get_request_identifier(self, request: Request):
        """Return the value requests are grouped by: the request's ``ip_addr``."""
        return request.ip_addr

    def before_request(self, request: Request):
        """
        Count the request and reject it when the window's limit is exceeded.

        :param request: The incoming request.
        :type request: Request
        :return: ``request`` itself when the limit has not been exceeded,
            otherwise a ``Response`` with status 429 and a ``Retry-After``
            header equal to ``window_size``.
        """
        client_id = self.get_request_identifier(request)
        now = int(time.time())
        # All requests falling into the same window share one counter key.
        counter_key = f"{client_id}:{now // self.window_size}"

        hits = self.storage.increment(counter_key, expire=self.window_size)
        over_limit = hits > self.requests_per_minute
        if not over_limit:
            return request

        return Response(status_code=429, description=b"Too Many Requests", headers={"Retry-After": str(self.window_size)})

    def after_request(self, response):
        """Pass the response through untouched."""
        return response
136
+
137
+
138
class ConcurrentRequestMiddleware(Middleware):
    # Caps the number of requests being processed at the same time; once the
    # cap is reached, new requests get a 429 response with a Retry-After header.
    def __init__(self, max_concurrent_requests=100):
        super().__init__()
        self.max_concurrent_requests = max_concurrent_requests
        # Number of requests admitted by before_request and not yet released
        # by after_request; guarded by self.lock.
        self.current_requests = 0
        self.lock = Lock()

    def get_request_identifier(self, request):
        """Return the request's ``ip_addr``."""
        return request.ip_addr

    def before_request(self, request):
        """
        Admit the request if a slot is free, otherwise reject it.

        :param request: The incoming request.
        :return: ``request`` itself after the in-flight counter has been
            incremented, or a 429 ``Response`` ("Too Many Requests",
            ``Retry-After: 5``) when the counter has already reached
            ``max_concurrent_requests``.
        """
        with self.lock:
            at_capacity = self.current_requests >= self.max_concurrent_requests
            if not at_capacity:
                self.current_requests += 1

        if at_capacity:
            return Response(status_code=429, description="Too Many Requests", headers={"Retry-After": "5"})
        return request

    def after_request(self, response):
        """Release one slot and pass the response through untouched."""
        with self.lock:
            self.current_requests = self.current_requests - 1
        return response
@@ -0,0 +1,5 @@
1
+ # -*- coding: utf-8 -*-
2
+ from .schemas import SchemaGenerator
3
+ from .swagger import SwaggerUI
4
+
5
+ __all__ = ["SchemaGenerator", "SwaggerUI"]
@@ -0,0 +1,64 @@
1
+ # -*- coding: utf-8 -*-
2
+ from robyn.router import Route
3
+ from robyn import Robyn, HttpMethod
4
+ from hypern.hypern import BaseSchemaGenerator
5
+ import typing
6
+ import orjson
7
+
8
+
9
class EndpointInfo(typing.NamedTuple):
    """Immutable record describing one routed endpoint."""

    # Route path, e.g. "/users/".
    path: str
    # Lower-case HTTP method name, e.g. "get".
    http_method: str
    # Handler callable for the route.
    func: typing.Callable[..., typing.Any]
13
+
14
+
15
class SchemaGenerator(BaseSchemaGenerator):
    """Builds an API schema dictionary from a base schema plus the app's routes."""

    def __init__(self, base_schema: dict[str, typing.Any]) -> None:
        self.base_schema = base_schema

    def get_endpoints(self, routes: list[Route]) -> list[EndpointInfo]:
        """
        Describe every route as an ``EndpointInfo``:

        - path
            eg: /users/
        - http_method
            one of 'get', 'post', 'put', 'patch', 'delete', 'options'
            ('get' is used for any route type not listed here)
        - func
            the route's handler, from which the docstring is later read
        """
        # Checked in order with ``==``; the first match wins.
        named_methods = (
            (HttpMethod.POST, "post"),
            (HttpMethod.PUT, "put"),
            (HttpMethod.PATCH, "patch"),
            (HttpMethod.DELETE, "delete"),
            (HttpMethod.OPTIONS, "options"),
        )

        collected: list[EndpointInfo] = []
        for route in routes:
            route_type = route.route_type
            verb = next((label for candidate, label in named_methods if route_type == candidate), "get")
            collected.append(EndpointInfo(path=route.route, http_method=verb, func=route.function.handler))
        return collected

    def get_schema(self, app: Robyn) -> dict[str, typing.Any]:
        """
        Return a copy of the base schema with a ``paths`` entry per documented endpoint.

        Endpoints for which ``self.parse_docstring`` (not defined in this
        class; presumably inherited from ``BaseSchemaGenerator``) returns a
        falsy value are skipped. The copy of the base schema is shallow.
        """
        schema = dict(self.base_schema)
        schema.setdefault("paths", {})
        paths = schema["paths"]

        for endpoint in self.get_endpoints(app.router.get_routes()):
            parsed = self.parse_docstring(endpoint.func)
            if not parsed:
                continue

            operations = paths.setdefault(endpoint.path, {})
            operations[endpoint.http_method] = orjson.loads(parsed)

        return schema
@@ -0,0 +1,3 @@
1
+ from hypern.hypern import SwaggerUI
2
+
3
+ __all__ = ["SwaggerUI"]
hypern/py.typed ADDED
File without changes
@@ -0,0 +1,3 @@
1
+ from .response import Response, JSONResponse, HTMLResponse, PlainTextResponse, RedirectResponse, FileResponse
2
+
3
+ __all__ = ["Response", "JSONResponse", "HTMLResponse", "PlainTextResponse", "RedirectResponse", "FileResponse"]
@@ -0,0 +1,134 @@
1
+ from __future__ import annotations
2
+
3
+ import typing
4
+ from urllib.parse import quote
5
+ from robyn import Response as RobynResponse, Headers
6
+ import orjson
7
+
8
+ from hypern.background import BackgroundTask, BackgroundTasks
9
+
10
+
11
class BaseResponse:
    """
    Framework-independent response description: status code, rendered body
    and lower-cased headers (``raw_headers``).

    Subclasses usually only override ``media_type``.
    """

    media_type = None
    charset = "utf-8"

    def __init__(
        self,
        content: typing.Any = None,
        status_code: int = 200,
        headers: typing.Mapping[str, str] | None = None,
        media_type: str | None = None,
        backgrounds: typing.List[BackgroundTask] | None = None,
    ) -> None:
        """
        :param content: Body payload; see ``render`` for the accepted kinds.
        :param status_code: HTTP status code, defaults to 200.
        :param headers: Optional headers; names are lower-cased when stored.
        :param media_type: Overrides the class-level ``media_type`` when given.
        :param backgrounds: Optional background tasks, stored as-is on the instance.
        """
        self.status_code = status_code
        if media_type is not None:
            self.media_type = media_type
        # The body must exist before init_headers so content-length can be derived.
        self.body = self.render(content)
        self.init_headers(headers)
        self.backgrounds = backgrounds

    def render(self, content: typing.Any) -> bytes | memoryview:
        """
        Convert ``content`` to the response body.

        ``None`` becomes ``b""``; ``bytes``/``memoryview`` are returned as-is;
        ``str`` is encoded with ``self.charset``; anything else is serialised
        with ``orjson.dumps``.
        """
        if content is None:
            return b""
        if isinstance(content, (bytes, memoryview)):
            return content
        if isinstance(content, str):
            return content.encode(self.charset)
        return orjson.dumps(content)  # type: ignore

    def init_headers(self, headers: typing.Mapping[str, str] | None = None) -> None:
        """
        Build ``self.raw_headers`` from ``headers`` plus derived defaults.

        ``content-length`` is added unless the caller supplied it or the
        status is informational (< 200), 204 or 304. ``content-type`` is added
        from ``self.media_type`` unless the caller supplied it; ``text/*``
        types get a ``charset`` parameter when they do not already carry one.
        """
        if headers is None:
            raw_headers: dict = {}
            populate_content_length = True
            populate_content_type = True
        else:
            raw_headers = {k.lower(): v for k, v in headers.items()}
            populate_content_length = "content-length" not in raw_headers
            populate_content_type = "content-type" not in raw_headers

        body = getattr(self, "body", None)
        if body is not None and populate_content_length and not (self.status_code < 200 or self.status_code in (204, 304)):
            # len() of a memoryview counts items, not bytes; nbytes is the
            # size actually written to the wire.
            size = body.nbytes if isinstance(body, memoryview) else len(body)
            raw_headers.setdefault("content-length", str(size))

        content_type = self.media_type
        if content_type is not None and populate_content_type:
            if content_type.startswith("text/") and "charset=" not in content_type.lower():
                content_type += "; charset=" + self.charset
            raw_headers.setdefault("content-type", content_type)

        self.raw_headers = raw_headers
62
+
63
+
64
def to_response(cls):
    """
    Class decorator that makes "instantiating" ``cls`` yield a ``RobynResponse``.

    The returned subclass overrides ``__new__`` so that a call such as
    ``JSONResponse(...)``:

    1. builds and initialises a real instance of the decorated class,
    2. hands its ``backgrounds`` (if any) to a ``BackgroundTasks`` manager and
       calls ``execute_all()`` on it,
    3. returns a ``RobynResponse`` built from the instance's status code,
       ``raw_headers`` and body.

    NOTE(review): the source this block was recovered from had lost its
    indentation; ``execute_all()`` is assumed to run unconditionally, i.e. also
    when no background task was supplied.
    """

    class ResponseWrapper(cls):
        def __new__(cls, *args, **kwargs):
            # __new__ returns a non-instance below, so Python will not call
            # __init__ automatically; do it by hand.
            built = super().__new__(cls)
            built.__init__(*args, **kwargs)

            # Execute background tasks
            runner = BackgroundTasks()
            for job in built.backgrounds or ():
                runner.add_task(job)
            runner.execute_all()
            del runner

            header_map = Headers(built.raw_headers)
            return RobynResponse(
                status_code=built.status_code,
                headers=header_map,
                description=built.body,
            )

    return ResponseWrapper
85
+
86
+
87
@to_response
class Response(BaseResponse):
    """Generic response: no default media type, bodies encoded as UTF-8."""

    media_type = None
    charset = "utf-8"
91
+
92
+
93
@to_response
class JSONResponse(BaseResponse):
    """Response whose default ``content-type`` is ``application/json``."""

    media_type = "application/json"
96
+
97
+
98
@to_response
class HTMLResponse(BaseResponse):
    """Response whose default media type is ``text/html``."""

    media_type = "text/html"
101
+
102
+
103
@to_response
class PlainTextResponse(BaseResponse):
    """Response whose default media type is ``text/plain``."""

    media_type = "text/plain"
106
+
107
+
108
@to_response
class RedirectResponse(BaseResponse):
    """Empty-bodied response carrying a ``location`` header."""

    def __init__(
        self,
        url: str,
        status_code: int = 307,
        headers: typing.Mapping[str, str] | None = None,
        backgrounds: typing.List[BackgroundTask] | None = None,
    ) -> None:
        """
        :param url: Redirect target; converted with ``str()`` and
            percent-encoded, leaving URL delimiter characters intact.
        :param status_code: HTTP status code, defaults to 307.
        :param headers: Optional extra headers.
        :param backgrounds: Optional background tasks.
        """
        super().__init__(content=b"", status_code=status_code, headers=headers, backgrounds=backgrounds)
        target = quote(str(url), safe=":/%#?=@[]!$&'()*+,;")
        self.raw_headers["location"] = target
119
+
120
+
121
@to_response
class FileResponse(BaseResponse):
    """Response that delivers ``content`` as a downloadable attachment."""

    def __init__(
        self,
        content: bytes | memoryview,
        filename: str,
        status_code: int = 200,
        headers: typing.Mapping[str, str] | None = None,
        backgrounds: typing.List[BackgroundTask] | None = None,
    ) -> None:
        """
        :param content: Raw file content.
        :param filename: Name offered to the client in ``content-disposition``.
        :param status_code: HTTP status code, defaults to 200.
        :param headers: Optional extra headers; a caller-supplied
            ``content-type`` or ``content-length`` is kept.
        :param backgrounds: Optional background tasks.
        """
        super().__init__(content=content, status_code=status_code, headers=headers, backgrounds=backgrounds)

        # The name ends up inside a quoted-string header value: drop line
        # breaks (header injection) and escape backslashes and quotes so the
        # value cannot terminate the quoted string early. Ordinary names are
        # emitted unchanged.
        safe_name = str(filename).replace("\r", "").replace("\n", "")
        safe_name = safe_name.replace("\\", "\\\\").replace('"', '\\"')
        self.raw_headers["content-disposition"] = f'attachment; filename="{safe_name}"'

        self.raw_headers.setdefault("content-type", "application/octet-stream")
        # Fallback only: the base class normally sets content-length already.
        # nbytes is used for memoryview because len() counts items, not bytes.
        size = content.nbytes if isinstance(content, memoryview) else len(content)
        self.raw_headers.setdefault("content-length", str(size))
@@ -0,0 +1,4 @@
1
+ from .router import Route
2
+ from .endpoint import HTTPEndpoint
3
+
4
+ __all__ = ["Route", "HTTPEndpoint"]
@@ -0,0 +1,65 @@
1
+ # -*- coding: utf-8 -*-
2
+ from __future__ import annotations
3
+
4
+ from robyn import Request, Response
5
+ from pydantic import BaseModel
6
+ from hypern.exceptions import BaseException
7
+ from hypern.response import JSONResponse
8
+ import typing
9
+ import asyncio
10
+ import functools
11
+ import inspect
12
+ import orjson
13
+ import traceback
14
+
15
+ from .parser import InputHandler
16
+
17
+
18
def is_async_callable(obj: typing.Any) -> bool:
    """
    Tell whether calling ``obj`` produces a coroutine.

    ``functools.partial`` wrappers are peeled off first (through their
    ``func`` attribute) so a partial around an ``async def`` is recognised.
    Besides coroutine functions, callable objects whose ``__call__`` is a
    coroutine function also count.

    :param obj: Any object, typically a request handler.
    :return: ``True`` for async callables, ``False`` otherwise.
    """
    while isinstance(obj, functools.partial):
        obj = obj.func
    return asyncio.iscoroutinefunction(obj) or (callable(obj) and asyncio.iscoroutinefunction(obj.__call__))
22
+
23
+
24
async def run_in_threadpool(func: typing.Callable, *args, **kwargs):
    """
    Run the synchronous ``func`` in a worker thread and await its result.

    Keyword arguments, when present, are bound to ``func`` with
    ``functools.partial`` before the call is handed to ``asyncio.to_thread``
    together with the positional arguments.

    :param func: Synchronous callable to execute.
    :return: Whatever ``func`` returns.
    """
    target = functools.partial(func, **kwargs) if kwargs else func
    return await asyncio.to_thread(target, *args)
29
+
30
+
31
async def dispatch(handler, request: Request, global_dependencies, router_dependencies) -> Response:
    """
    Call ``handler`` for ``request`` and always come back with a response.

    The handler's keyword arguments are produced by ``InputHandler`` from the
    handler signature. Async handlers are awaited directly; sync handlers run
    through ``run_in_threadpool``. A return value that is not already a
    ``Response`` is wrapped as ``{"message": <value>, "error_code": None}``
    with status 200; if the handler's return annotation is a pydantic model
    class, the value is validated and dumped through that model first.

    Any exception is converted into a JSON error response. Here
    ``BaseException`` is the class imported from ``hypern.exceptions`` (it
    shadows the builtin): its ``msg``, ``error_code`` and ``status`` are used
    as-is. Every other exception has its traceback printed and is reported
    with status 400 and error code ``UNKNOWN_ERROR``.

    NOTE(review): the source this block was recovered from had lost its
    indentation; the wrapping of non-``Response`` results is assumed to apply
    to both async and sync handlers.
    """
    try:
        handler_is_async = is_async_callable(handler)
        signature = inspect.signature(handler)
        arguments_builder = InputHandler(request, global_dependencies, router_dependencies)
        declared_return = signature.return_annotation
        call_kwargs = await arguments_builder.get_input_handler(signature)

        if handler_is_async:
            result = await handler(**call_kwargs)  # type: ignore
        else:
            result = await run_in_threadpool(handler, **call_kwargs)

        if isinstance(result, Response):
            return result

        if isinstance(declared_return, type) and issubclass(declared_return, BaseModel):
            result = declared_return.model_validate(result).model_dump(mode="json")  # type: ignore
        return JSONResponse(
            content=orjson.dumps({"message": result, "error_code": None}),
            status_code=200,
        )

    except Exception as e:
        payload: typing.Dict = {"message": "", "error_code": "UNKNOWN_ERROR"}
        if isinstance(e, BaseException):
            payload["error_code"] = e.error_code
            payload["message"] = e.msg
            status = e.status
        else:
            traceback.print_exc()
            payload["message"] = str(e)
            status = 400
        return JSONResponse(
            content=orjson.dumps(payload),
            status_code=status,
        )
@@ -0,0 +1,27 @@
1
+ # -*- coding: utf-8 -*-
2
+ from __future__ import annotations
3
+
4
+ from robyn import Request, Response
5
+ from hypern.response import JSONResponse
6
+ import typing
7
+ import orjson
8
+
9
+ from .dispatcher import dispatch
10
+
11
+
12
class HTTPEndpoint:
    """
    Class-based endpoint: subclasses define one method per HTTP verb
    (``get``, ``post``, ...) and ``dispatch`` routes each request to it.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def method_not_allowed(self, request: Request) -> Response:
        """
        Fallback handler used when the endpoint has no method for the verb.

        :param request: The incoming request (unused).
        :return: A 405 JSON response with error code ``METHOD_NOT_ALLOW``.
        """
        # The payload must be passed as ``content``: the response classes'
        # constructor has no ``description`` parameter, so the previous
        # keyword raised TypeError instead of producing the 405.
        return JSONResponse(
            content=orjson.dumps({"message": "Method Not Allowed", "error_code": "METHOD_NOT_ALLOW"}),
            status_code=405,
        )

    async def dispatch(self, request: Request, global_dependencies, router_dependencies) -> Response:
        """
        Pick the handler method for ``request`` and run it.

        The handler name is the lower-cased request method, except that a
        ``HEAD`` request falls back to ``get`` when the endpoint defines no
        ``head`` attribute. If no such attribute exists,
        ``method_not_allowed`` is used.

        :return: The response produced by the module-level ``dispatch`` helper.
        """
        handler_name = "get" if request.method == "HEAD" and not hasattr(self, "head") else request.method.lower()
        handler: typing.Callable[[Request], typing.Any] = getattr(  # type: ignore
            self, handler_name, self.method_not_allowed
        )
        return await dispatch(handler, request, global_dependencies, router_dependencies)
@@ -0,0 +1,101 @@
1
+ # -*- coding: utf-8 -*-
2
+ from __future__ import annotations
3
+
4
+ from pydantic import BaseModel, ValidationError
5
+ from robyn import Request
6
+ from hypern.exceptions import BadRequest, ValidationError as HypernValidationError
7
+ from hypern.auth.authorization import Authorization
8
+ from pydash import get
9
+ import typing
10
+ import inspect
11
+ import orjson
12
+
13
+
14
class ParamParser:
    """Extracts query, path or form/JSON data from a request as a plain dict."""

    def __init__(self, request: Request):
        self.request = request

    def parse_data_by_name(self, param_name: str) -> dict:
        """
        Return the request data selected by ``param_name``.

        :param param_name: One of ``query_params``, ``path_params`` or
            ``form_data`` (matched case-insensitively).
        :return: The corresponding data as a dictionary.
        :raises BadRequest: If ``param_name`` is none of the supported names.
        """
        data_parsers = {
            "query_params": self._parse_query_params,
            "path_params": self._parse_path_params,
            "form_data": self._parse_form_data,
        }

        parser = data_parsers.get(param_name.lower())
        if not parser:
            raise BadRequest(msg="Backend Error: Invalid parameter type, must be query_params, path_params or form_data.")
        return parser()

    def _parse_query_params(self) -> dict:
        """Return the query parameters, keeping only the first value of each key."""
        query_params = self.request.query_params.to_dict()
        return {k: v[0] for k, v in query_params.items()}

    def _parse_path_params(self) -> dict:
        """Return the path parameters as a new dictionary."""
        # This must return the dict itself; returning a lambda (as before)
        # made the caller's ``model_class(**data)`` fail.
        return dict(self.request.path_params.items())

    def _parse_form_data(self) -> dict:
        """Return the form fields, or the JSON body when the form is empty."""
        form_data = dict(self.request.form_data.items())
        return form_data if form_data else self.request.json()
41
+
42
+
43
class InputHandler:
    """Builds the keyword arguments a handler is called with, from its signature."""

    def __init__(self, request, global_dependencies, router_dependencies):
        self.request = request
        self.global_dependencies = global_dependencies
        self.router_dependencies = router_dependencies
        self.param_parser = ParamParser(request)

    async def parse_pydantic_model(self, param_name: str, model_class: typing.Type[BaseModel]) -> BaseModel:
        """
        Instantiate ``model_class`` from the request data named by ``param_name``.

        :param param_name: Handler parameter name; it selects the data source
            (see ``ParamParser.parse_data_by_name``).
        :param model_class: Pydantic model class to build.
        :return: The validated model instance.
        :raises HypernValidationError: When pydantic validation fails; ``msg``
            is a JSON list of ``{"field": ..., "msg": ...}`` objects.
        """
        try:
            data = self.param_parser.parse_data_by_name(param_name)
            return model_class(**data)
        except ValidationError as e:
            invalid_fields = orjson.loads(e.json())
            details = []
            for item in invalid_fields:
                # ``loc`` can be empty for errors that are not tied to a
                # single field (e.g. model-level validators); report those
                # with a null field instead of failing with IndexError.
                location = get(item, "loc") or []
                details.append(
                    {
                        "field": location[0] if location else None,
                        "msg": get(item, "msg"),
                    }
                )
            raise HypernValidationError(
                msg=orjson.dumps(details).decode("utf-8"),
            ) from e

    async def handle_special_params(self, param_name: str) -> typing.Any:
        """
        Resolve the reserved parameter names.

        :param param_name: Handler parameter name.
        :return: The request, the global dependencies or the router
            dependencies for ``request``, ``global_dependencies`` and
            ``router_dependencies`` respectively; ``None`` for any other name.
        """
        special_params = {
            "request": lambda: self.request,
            "global_dependencies": lambda: self.global_dependencies,
            "router_dependencies": lambda: self.router_dependencies,
        }
        return special_params.get(param_name, lambda: None)()

    async def get_input_handler(self, signature: inspect.Signature) -> typing.Dict[str, typing.Any]:
        """
        Parse the request data and return the kwargs for the handler.

        Each parameter is resolved by the first rule that applies:

        1. annotation is a pydantic model class -> parsed from the request,
        2. annotation is an ``Authorization`` subclass -> the awaited result of
           ``validate(request)`` on a fresh instance,
        3. reserved parameter name -> see ``handle_special_params``.

        Parameters matching none of these (or resolving to ``None`` in rule 3)
        are left out of the result.
        """
        kwargs = {}

        for param in signature.parameters.values():
            name = param.name
            ptype = param.annotation

            # Handle Pydantic models
            if isinstance(ptype, type) and issubclass(ptype, BaseModel):
                kwargs[name] = await self.parse_pydantic_model(name, ptype)
                continue

            # Handle Authorization
            if isinstance(ptype, type) and issubclass(ptype, Authorization):
                kwargs[name] = await ptype().validate(self.request)
                continue

            # Handle special parameters
            special_value = await self.handle_special_params(name)
            if special_value is not None:
                kwargs[name] = special_value

        return kwargs