hypern 0.2.0__cp310-none-win32.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hypern/__init__.py +4 -0
- hypern/application.py +412 -0
- hypern/auth/__init__.py +0 -0
- hypern/auth/authorization.py +2 -0
- hypern/background.py +4 -0
- hypern/caching/__init__.py +0 -0
- hypern/caching/base/__init__.py +8 -0
- hypern/caching/base/backend.py +3 -0
- hypern/caching/base/key_maker.py +8 -0
- hypern/caching/cache_manager.py +56 -0
- hypern/caching/cache_tag.py +10 -0
- hypern/caching/custom_key_maker.py +11 -0
- hypern/caching/redis_backend.py +3 -0
- hypern/cli/__init__.py +0 -0
- hypern/cli/commands.py +0 -0
- hypern/config.py +149 -0
- hypern/datastructures.py +40 -0
- hypern/db/__init__.py +0 -0
- hypern/db/nosql/__init__.py +25 -0
- hypern/db/nosql/addons/__init__.py +4 -0
- hypern/db/nosql/addons/color.py +16 -0
- hypern/db/nosql/addons/daterange.py +30 -0
- hypern/db/nosql/addons/encrypted.py +53 -0
- hypern/db/nosql/addons/password.py +134 -0
- hypern/db/nosql/addons/unicode.py +10 -0
- hypern/db/sql/__init__.py +179 -0
- hypern/db/sql/addons/__init__.py +14 -0
- hypern/db/sql/addons/color.py +16 -0
- hypern/db/sql/addons/daterange.py +23 -0
- hypern/db/sql/addons/datetime.py +22 -0
- hypern/db/sql/addons/encrypted.py +58 -0
- hypern/db/sql/addons/password.py +171 -0
- hypern/db/sql/addons/ts_vector.py +46 -0
- hypern/db/sql/addons/unicode.py +15 -0
- hypern/db/sql/repository.py +290 -0
- hypern/enum.py +13 -0
- hypern/exceptions.py +97 -0
- hypern/hypern.cp310-win32.pyd +0 -0
- hypern/hypern.pyi +266 -0
- hypern/i18n/__init__.py +0 -0
- hypern/logging/__init__.py +3 -0
- hypern/logging/logger.py +82 -0
- hypern/middleware/__init__.py +5 -0
- hypern/middleware/base.py +18 -0
- hypern/middleware/cors.py +38 -0
- hypern/middleware/i18n.py +1 -0
- hypern/middleware/limit.py +176 -0
- hypern/openapi/__init__.py +5 -0
- hypern/openapi/schemas.py +53 -0
- hypern/openapi/swagger.py +3 -0
- hypern/processpool.py +106 -0
- hypern/py.typed +0 -0
- hypern/response/__init__.py +3 -0
- hypern/response/response.py +134 -0
- hypern/routing/__init__.py +4 -0
- hypern/routing/dispatcher.py +67 -0
- hypern/routing/endpoint.py +30 -0
- hypern/routing/parser.py +100 -0
- hypern/routing/route.py +284 -0
- hypern/scheduler.py +5 -0
- hypern/security.py +44 -0
- hypern/worker.py +30 -0
- hypern-0.2.0.dist-info/METADATA +127 -0
- hypern-0.2.0.dist-info/RECORD +66 -0
- hypern-0.2.0.dist-info/WHEEL +4 -0
- hypern-0.2.0.dist-info/licenses/LICENSE +24 -0
hypern/hypern.pyi
ADDED
|
@@ -0,0 +1,266 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Any, Callable, Dict, List, Tuple
|
|
5
|
+
|
|
6
|
+
@dataclass
class BaseBackend:
    """Typing stub describing the callables a cache backend exposes.

    Only the call signatures are declared here; the behaviour behind them
    is implemented outside this stub.
    """

    # Called with a single string and may return anything.
    get: Callable[[str], Any]
    # Called with (Any, str, int) -- presumably (value, key, ttl); confirm against the implementation.
    set: Callable[[Any, str, int], None]
    # Called with a single string; the name suggests prefix-based deletion -- confirm.
    delete_startswith: Callable[[str], None]
|
|
11
|
+
|
|
12
|
+
@dataclass
class RedisBackend(BaseBackend):
    """Typing stub for a backend that additionally carries a connection URL.

    The three callables are re-declared with the same signatures as on
    ``BaseBackend``; as dataclass fields they keep their inherited position,
    so ``url`` comes last in the generated ``__init__``.
    """

    # Presumably the Redis connection URL -- confirm against the implementation.
    url: str

    get: Callable[[str], Any]
    set: Callable[[Any, str, int], None]
    delete_startswith: Callable[[str], None]
|
|
19
|
+
|
|
20
|
+
@dataclass
class BaseSchemaGenerator:
    """Typing stub for the base class used to build OpenAPI schemas."""

    # Takes a string and returns a string; the name suggests it strips
    # converter suffixes from a route path -- confirm against the implementation.
    remove_converter: Callable[[str], str]
    # Takes a handler callable and returns a string derived from its docstring.
    parse_docstring: Callable[[Callable[..., Any]], str]
|
|
24
|
+
|
|
25
|
+
@dataclass
class SwaggerUI:
    """Typing stub for the Swagger UI page helper."""

    # Page title.
    title: str
    # URL of the OpenAPI document the page should load -- inferred from the name; confirm.
    openapi_url: str

    def get_html_content(self) -> str:
        """Return the HTML of the Swagger UI page as a string."""
        ...
|
|
31
|
+
|
|
32
|
+
@dataclass
class BackgroundTask:
    """
    A unit of work to be executed in the background.

    Fields:
        id: str: The task ID
        function: Callable[..., Any]: The function to be executed
        args: List | Tuple: The positional arguments passed to the function
        kwargs: Dict[str, Any]: The keyword arguments passed to the function
        timeout_secs: int: The maximum time in seconds the task is allowed to run
        cancelled: bool: Whether the task is cancelled

    **Note**: the function currently runs in sync mode, so it should be a sync function.
    """

    id: str
    function: Callable[..., Any]
    args: List | Tuple
    kwargs: Dict[str, Any]
    timeout_secs: int
    cancelled: bool

    def get_id(self) -> str:
        """Return the task ID."""
        ...

    def cancel(self) -> None:
        """Cancel the task."""
        ...

    def is_cancelled(self) -> bool:
        """Tell whether the task is cancelled."""
        ...

    def execute(self) -> Any:
        """Execute the task."""
        ...
|
|
76
|
+
|
|
77
|
+
@dataclass
class BackgroundTasks:
    """
    A collection of tasks to be executed in the background.

    **Note**: Only set tasks. pool, sender, receiver are set by the framework.
    """

    def add_task(self, task: BackgroundTask) -> str:
        """Add a task to the collection; a string is returned (presumably the task ID)."""
        ...

    def cancel_task(self, task_id: str) -> bool:
        """Cancel the task with the given ID."""
        ...

    def execute_all(self) -> None:
        """Execute every task in the collection."""
        ...

    def execute_task(self, task_id: str) -> None:
        """Execute the task with the given ID."""
        ...
|
|
108
|
+
|
|
109
|
+
class Scheduler:
    """Typing stub for the job scheduler."""

    def add_job(
        self,
        job_type: str,
        schedule_param: str,
        task: Callable[..., Any],
        timezone: str,
        dependencies: List[str],
        retry_policy: Tuple[int, int, bool] | None = None,
    ) -> str:
        """
        Register a job with the scheduler.

        params:
            job_type: str: The type of the job (e.g. "cron", "interval")

            schedule_param: str: The schedule parameter of the job. Interval in
            seconds for interval jobs, cron expression for cron jobs.

            Example:
                // sec  min   hour   day of month   month   day of week   year
                expression = "0   30   9,12,15     1,15       May-Aug  Mon,Wed,Fri  2018/2";

            task: Callable[..., Any]: The task to be executed

            timezone: str: The timezone of the job

            dependencies: List[str]: The IDs of the jobs this job depends on

            retry_policy: Tuple[int, int, bool] | None: The retry policy of the job,
            as (max_retries, retry_delay_secs, exponential_backoff)

        return:
            str: The ID of the job
        """
        ...

    def remove_job(self, job_id: str) -> None:
        """Remove a job from the scheduler."""
        ...

    def start(self) -> None:
        """Start the scheduler."""
        ...

    def stop(self) -> None:
        """Stop the scheduler."""
        ...

    def get_job_status(self, job_id: str) -> Tuple[float, float, List[str], int]:
        """Return the status tuple of a job (the meaning of each element is not documented in this stub)."""
        ...

    def get_next_run(self, job_id: str) -> float:
        """Return the next run time of a job as a float."""
        ...
|
|
172
|
+
|
|
173
|
+
@dataclass
class FunctionInfo:
    """
    Description of a callable handed to the route handler machinery.

    Attributes:
        handler (Callable): The function to be called
        is_async (bool): Whether the function is async or not
    """

    handler: Callable
    is_async: bool
|
|
185
|
+
|
|
186
|
+
class SocketHeld:
    """Typing stub for the object holding the server socket (passed to ``Server.start``)."""

    # The wrapped socket object; its concrete type is not declared in this stub.
    socket: Any
|
|
188
|
+
|
|
189
|
+
@dataclass
class Server:
    """Typing stub for the HTTP server object."""

    router: Router
    websocket_router: Any
    startup_handler: Any
    shutdown_handler: Any

    def add_route(self, route: Route) -> None:
        """Register a single route."""
        ...

    def set_router(self, router: Router) -> None:
        """Set the router used by the server."""
        ...

    def start(self, socket: SocketHeld, worker: int, max_blocking_threads: int) -> None:
        """Start the server on the given socket with the given worker / blocking-thread settings."""
        ...

    def inject(self, key: str, value: Any) -> None:
        """Register one injected value under ``key``."""
        ...

    def set_injected(self, injected: Dict[str, Any]) -> None:
        """Set the mapping of injected values."""
        ...

    def set_before_hooks(self, hooks: List[FunctionInfo]) -> None:
        """Set the list of "before" hooks."""
        ...

    def set_after_hooks(self, hooks: List[FunctionInfo]) -> None:
        """Set the list of "after" hooks."""
        ...

    def set_response_headers(self, headers: Dict[str, str]) -> None:
        """Set response headers from a name-to-value mapping."""
        ...
|
|
204
|
+
|
|
205
|
+
class Route:
    """Typing stub for a single route: a path, an HTTP method and the handler info."""

    path: str
    function: FunctionInfo
    method: str

    def matches(self, path: str, method: str) -> str:
        """Match the route against a path and method; a string is returned."""
        ...

    def clone_route(self) -> Route:
        """Return a copy of this route."""
        ...

    def update_path(self, new_path: str) -> None:
        """Replace the route's path."""
        ...

    def update_method(self, new_method: str) -> None:
        """Replace the route's HTTP method."""
        ...

    def is_valid(self) -> bool:
        """Tell whether the route is valid."""
        ...

    # NOTE: the spelling "parans" is part of the public interface and must stay.
    def get_path_parans(self) -> List[str]:
        """Return a list of strings (presumably the path parameter names)."""
        ...

    def has_parameters(self) -> bool:
        """Tell whether the route has parameters."""
        ...

    def normalized_path(self) -> str:
        """Return the normalized form of the path."""
        ...

    def same_handler(self, other: Route) -> bool:
        """Tell whether ``other`` uses the same handler as this route."""
        ...
|
|
219
|
+
|
|
220
|
+
class Router:
    """Typing stub for the collection of registered routes."""

    routes: List[Route]

    def add_route(self, route: Route) -> None:
        """Register a single route."""
        ...

    def remove_route(self, path: str, method: str) -> bool:
        """Remove the route for ``path`` and ``method``; a bool is returned."""
        ...

    # ``method`` was previously unannotated; it is typed ``str`` to match
    # ``remove_route`` and ``get_routes_by_method``.
    def get_route(self, path: str, method: str) -> Route | None:
        """Return the route for ``path`` and ``method``, or ``None``."""
        ...

    def get_routes_by_path(self, path: str) -> List[Route]:
        """Return the routes registered for ``path``."""
        ...

    def get_routes_by_method(self, method: str) -> List[Route]:
        """Return the routes registered for ``method``."""
        ...

    def extend_route(self, routes: List[Route]) -> None:
        """Register several routes at once."""
        ...
|
|
229
|
+
|
|
230
|
+
@dataclass
class Header:
    """Typing stub for a header container."""

    # Header names mapped to header values.
    headers: Dict[str, str]
|
|
233
|
+
|
|
234
|
+
@dataclass
class Response:
    """Typing stub for an HTTP response."""

    # HTTP status code, e.g. 429.
    status_code: int
    response_type: str
    # Declared as Any; callers in this package pass a plain dict of header names to values.
    headers: Any
    # Declared as str. NOTE(review): RateLimitMiddleware passes bytes here -- confirm what is accepted.
    description: str
    file_path: str
|
|
241
|
+
|
|
242
|
+
@dataclass
class QueryParams:
    """Typing stub for parsed query-string parameters."""

    # Each parameter name maps to a list of string values.
    queries: Dict[str, List[str]]
|
|
245
|
+
|
|
246
|
+
@dataclass
class UploadedFile:
    """Typing stub for one uploaded file carried in a request body."""

    # NOTE(review): the difference between `name` and `filename` is not visible
    # in this stub -- presumably form-field name vs. client file name; confirm.
    name: str
    content_type: str
    path: str
    # Declared as int; presumably the size in bytes -- confirm.
    size: int
    # Raw file content.
    content: bytes
    filename: str
|
|
254
|
+
|
|
255
|
+
@dataclass
class BodyData:
    """Typing stub for a request body."""

    # Declared as bytes; presumably the raw JSON payload -- confirm.
    json: bytes
    # Files uploaded with the request.
    files: List[UploadedFile]
|
|
259
|
+
|
|
260
|
+
@dataclass
class Request:
    """Typing stub for an incoming HTTP request.

    NOTE(review): the rate-limit middleware in this package reads
    ``request.ip_addr``, which is not declared here -- confirm whether the
    attribute exists on the real object and should be added to this stub.
    """

    query_params: QueryParams
    # Header names mapped to header values.
    headers: Dict[str, str]
    # Path parameter names mapped to their string values.
    path_params: Dict[str, str]
    body: BodyData
    # HTTP method as a string.
    method: str
|
hypern/i18n/__init__.py
ADDED
|
File without changes
|
hypern/logging/logger.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
import logging
|
|
3
|
+
import sys
|
|
4
|
+
from copy import copy
|
|
5
|
+
from datetime import datetime, timezone
|
|
6
|
+
from typing import Literal, Optional
|
|
7
|
+
|
|
8
|
+
import click
|
|
9
|
+
|
|
10
|
+
TRACE_LOG_LEVEL = 5


class ColourizedFormatter(logging.Formatter):
    """
    Log formatter that exposes extra format fields and can colour them.

    On top of the standard ``logging.Formatter`` fields it provides:

    - ``levelprefix``: the level name padded with spaces to at least 5 characters
    - ``process``: the process ID as a string
    - ``asctime``: the record's creation time in UTC, always rendered as
      ``%Y-%m-%d %H:%M:%S.%fZ`` (``datefmt`` is not used for this field)
    - ``filename``: ``"<module>/<filename>:<lineno>:"``

    When ``use_colors`` is true these fields are wrapped with ``click.style``;
    when it is false they are emitted as plain text, so no ANSI escape codes
    end up in the output.
    """

    # Maps a numeric log level to a function that colours the level name.
    level_name_colors = {
        TRACE_LOG_LEVEL: lambda level_name: click.style(str(level_name), fg="blue"),
        logging.DEBUG: lambda level_name: click.style(str(level_name), fg="cyan"),
        logging.INFO: lambda level_name: click.style(str(level_name), fg="green"),
        logging.WARNING: lambda level_name: click.style(str(level_name), fg="yellow"),
        logging.ERROR: lambda level_name: click.style(str(level_name), fg="red"),
        logging.CRITICAL: lambda level_name: click.style(str(level_name), fg="bright_red"),
    }

    def __init__(
        self,
        fmt: Optional[str] = None,
        datefmt: Optional[str] = None,
        style: Literal["%", "{", "$"] = "%",
        use_colors: Optional[bool] = None,
    ):
        """
        :param fmt: format string, forwarded to ``logging.Formatter``
        :param datefmt: date format, forwarded to ``logging.Formatter``
        :param style: format style, forwarded to ``logging.Formatter``
        :param use_colors: ``True``/``False`` to force colours on or off;
            any other value (e.g. ``None``) enables colours only when
            ``sys.stdout`` is a TTY
        """
        if use_colors in (True, False):
            self.use_colors = use_colors
        else:
            self.use_colors = sys.stdout.isatty()
        super().__init__(fmt=fmt, datefmt=datefmt, style=style)

    def color_level_name(self, level_name: str, level_no: int) -> str:
        """Return ``level_name`` coloured for ``level_no``; unknown levels are returned uncoloured."""

        def default(level_name: str) -> str:
            return str(level_name)

        func = self.level_name_colors.get(level_no, default)
        return func(level_name)

    def should_use_colors(self) -> bool:
        """Hook for subclasses; not consulted by this class itself."""
        return True

    def formatMessage(self, record: logging.LogRecord) -> str:
        """
        Format ``record`` after filling in the extra fields described on the class.

        The record is copied first so other handlers still see the original.
        """
        recordcopy = copy(record)
        levelname = recordcopy.levelname
        # Negative widths yield an empty separator, so long names are not padded.
        separator = " " * (5 - len(recordcopy.levelname))
        process = str(recordcopy.process)
        timestamp = datetime.fromtimestamp(recordcopy.created, tz=timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%fZ")
        location = f"{recordcopy.module}/{recordcopy.filename}:{recordcopy.lineno}:"

        if self.use_colors:
            levelname = self.color_level_name(levelname, recordcopy.levelno)
            if "color_message" in recordcopy.__dict__:
                recordcopy.msg = recordcopy.__dict__["color_message"]
                recordcopy.__dict__["message"] = recordcopy.getMessage()
            # Styling is applied only here, so use_colors=False stays free of
            # ANSI escape sequences (previously these three were always styled).
            process = click.style(process, fg="blue")
            timestamp = click.style(timestamp, fg=(101, 111, 104))
            location = click.style(location, fg=(101, 111, 104))

        recordcopy.__dict__["levelprefix"] = levelname + separator
        recordcopy.__dict__["process"] = process
        recordcopy.__dict__["asctime"] = timestamp
        recordcopy.__dict__["filename"] = location
        return super().formatMessage(recordcopy)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class DefaultFormatter(ColourizedFormatter):
    """Formatter variant whose ``should_use_colors`` hook reports whether stderr is a TTY."""

    def should_use_colors(self) -> bool:
        stream = sys.stderr
        return stream.isatty()
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def create_logger(name) -> logging.Logger:
    """
    Return the logger called ``name``, set to DEBUG and writing through a
    ``StreamHandler`` that uses ``DefaultFormatter`` with colours forced on.

    The call is idempotent: if the logger already has a handler formatted by
    ``DefaultFormatter``, no second handler is attached. Previously every call
    added another handler, so each message was printed once per call.

    :param name: logger name passed to ``logging.getLogger``
    :return: the configured logger
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)

    already_configured = any(isinstance(existing.formatter, DefaultFormatter) for existing in logger.handlers)
    if not already_configured:
        formatter = DefaultFormatter(fmt="%(asctime)s %(levelprefix)s %(filename)s  %(message)s", use_colors=True, datefmt="%Y-%m-%d %H:%M:%S")
        handler = logging.StreamHandler()
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger


# Package-wide logger, created once at import time.
logger = create_logger("hypern")
|
|
@@ -0,0 +1,5 @@
|
|
|
1
|
+
from .base import Middleware
|
|
2
|
+
from .cors import CORSMiddleware
|
|
3
|
+
from .limit import RateLimitMiddleware, StorageBackend, RedisBackend, InMemoryBackend
|
|
4
|
+
|
|
5
|
+
__all__ = ["Middleware", "CORSMiddleware", "RateLimitMiddleware", "StorageBackend", "RedisBackend", "InMemoryBackend"]
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from hypern.hypern import Response, Request
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
# The `Middleware` class is an abstract base class with abstract methods `before_request` and
|
|
6
|
+
# `after_request` for handling requests and responses in a web application.
|
|
7
|
+
class Middleware(ABC):
    """Abstract base for middlewares: subclasses implement both request hooks."""

    def __init__(self) -> None:
        super().__init__()
        # Starts out unset; subclasses in this package call methods on it later,
        # so something outside this class is expected to assign it.
        self.app = None

    @abstractmethod
    def before_request(self, request: Request):
        """Hook that receives the incoming request."""

    @abstractmethod
    def after_request(self, response: Response):
        """Hook that receives the outgoing response."""
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
from typing import List
|
|
2
|
+
from .base import Middleware
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class CORSMiddleware(Middleware):
    """
    Middleware that registers CORS response headers on the application from the
    configured origins, methods and headers.
    """

    def __init__(self, allow_origins: List[str] = None, allow_methods: List[str] = None, allow_headers: List[str] = None) -> None:
        """
        :param allow_origins: origins to allow; ``None`` or empty means no origin header is added
        :param allow_methods: HTTP methods to allow (upper-cased when the header is written)
        :param allow_headers: request headers to allow
        """
        super().__init__()
        self.allow_origins = allow_origins if allow_origins else []
        self.allow_methods = allow_methods if allow_methods else []
        self.allow_headers = allow_headers if allow_headers else []

    def before_request(self, request):
        """Pass the request through unchanged."""
        return request

    def after_request(self, response):
        """
        Register the Access-Control-* headers on ``self.app`` and return ``response`` unchanged.

        NOTE(review): ``Access-Control-Allow-Origin`` is registered once per
        configured origin and ``Access-Control-Allow-Credentials`` is always
        "true"; whether several origin values overwrite or accumulate depends
        on ``add_response_header`` -- confirm this matches the intended policy.

        :param response: the outgoing response
        :return: the same ``response`` object
        """
        app = self.app

        for origin in self.allow_origins:
            app.add_response_header("Access-Control-Allow-Origin", origin)

        methods_value = ", ".join(method.upper() for method in self.allow_methods)
        app.add_response_header("Access-Control-Allow-Methods", methods_value)

        headers_value = ", ".join(self.allow_headers)
        app.add_response_header("Access-Control-Allow-Headers", headers_value)

        app.add_response_header("Access-Control-Allow-Credentials", "true")
        return response
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# coming soon
|
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
import time
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
from threading import Lock
|
|
4
|
+
|
|
5
|
+
from hypern.hypern import Request, Response
|
|
6
|
+
|
|
7
|
+
from .base import Middleware
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class StorageBackend(ABC):
    """Abstract counter store used by the rate limiter."""

    @abstractmethod
    def increment(self, key, amount=1, expire=None):
        """Add ``amount`` to the counter at ``key``, optionally with an expiry, and return the new value."""

    @abstractmethod
    def get(self, key):
        """Return the current counter value for ``key``."""
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class RedisBackend(StorageBackend):
    """Counter store backed by a Redis client supplied by the caller."""

    def __init__(self, redis_client):
        self.redis = redis_client

    def increment(self, key, amount=1, expire=None):
        """
        Increment ``key`` by ``amount`` through a pipeline and return the new value.

        :param key: the Redis key to increment
        :param amount: how much to add, defaults to 1
        :param expire: if truthy, the key's expiry is (re)set to ``int(expire)``
            seconds in the same pipeline; if falsy, no expiry is set
        :return: the first result of the pipeline, i.e. the result of the increment
        """
        with self.redis.pipeline() as pipe:
            pipe.incr(key, amount)
            if expire:
                pipe.expire(key, int(expire))
            results = pipe.execute()
        return results[0]

    def get(self, key):
        """Return the stored value for ``key`` as an int, or 0 when the stored value is missing/falsy."""
        stored = self.redis.get(key)
        return int(stored or 0)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class InMemoryBackend(StorageBackend):
    """
    Counter store kept in a plain dict on the instance.

    Each entry is ``{"value": <count>, "expire": <unix time or None>}``.

    NOTE: expired entries are only dropped when the same key is touched again,
    so keys that are never revisited stay in ``self.storage``.
    """

    def __init__(self):
        self.storage = {}

    @staticmethod
    def _is_expired(entry, now):
        """Tell whether ``entry`` has an expiry set that lies before ``now``."""
        deadline = entry["expire"]
        return bool(deadline) and now > deadline

    def increment(self, key, amount=1, expire=None):
        """
        Add ``amount`` to the counter stored under ``key`` and return the new value.

        A missing key starts from 0. A key whose expiry has passed is also
        restarted from 0, so an expired counter no longer keeps accumulating.

        :param key: identifier of the counter
        :param amount: how much to add, defaults to 1
        :param expire: if truthy, the entry's expiry is set to now + ``expire`` seconds
        :return: the counter value after the increment
        """
        now = time.time()
        entry = self.storage.get(key)
        if entry is None or self._is_expired(entry, now):
            entry = {"value": 0, "expire": None}
            self.storage[key] = entry
        entry["value"] += amount
        if expire:
            entry["expire"] = now + expire
        return entry["value"]

    def get(self, key):
        """
        Return the counter stored under ``key``.

        :param key: identifier of the counter
        :return: the stored value, or 0 when the key is absent or has expired
            (an expired entry is deleted as a side effect)
        """
        entry = self.storage.get(key)
        if entry is None:
            return 0
        if self._is_expired(entry, time.time()):
            del self.storage[key]
            return 0
        return entry["value"]
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
class RateLimitMiddleware(Middleware):
    """
    Fixed-window rate limiter keyed by the request's ``ip_addr``.

    Each identifier gets one counter per window of ``window_size`` seconds;
    once the counter exceeds ``requests_per_minute`` the request is answered
    with a 429 response.
    """

    def __init__(self, storage_backend, requests_per_minute=60, window_size=60):
        """
        :param storage_backend: object providing ``increment(key, expire=...)``
        :param requests_per_minute: maximum number of requests allowed per window
            (the name says "minute", but the limit applies per ``window_size`` seconds)
        :param window_size: window length in seconds
        """
        super().__init__()
        self.storage = storage_backend
        self.requests_per_minute = requests_per_minute
        self.window_size = window_size

    def get_request_identifier(self, request: Request):
        """Return the value used to group requests: ``request.ip_addr``."""
        return request.ip_addr

    def before_request(self, request: Request):
        """
        Count the request and reject it when the window's limit is exceeded.

        :param request: the incoming request
        :return: ``request`` itself when within the limit, otherwise a 429
            ``Response`` whose ``Retry-After`` header holds the number of
            seconds left in the current window
        """
        identifier = self.get_request_identifier(request)
        current_time = int(time.time())
        window_key = f"{identifier}:{current_time // self.window_size}"

        request_count = self.storage.increment(window_key, expire=self.window_size)

        if request_count > self.requests_per_minute:
            # The window index changes at the next multiple of window_size, so
            # that is when a retry can succeed -- not a full window from now.
            seconds_left = self.window_size - current_time % self.window_size
            retry_after = max(1, int(seconds_left))
            return Response(status_code=429, description=b"Too Many Requests", headers={"Retry-After": str(retry_after)})

        return request

    def after_request(self, response):
        """Pass the response through unchanged."""
        return response
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
class ConcurrentRequestMiddleware(Middleware):
    """
    Limit the number of requests in flight at once.

    ``before_request`` increments a counter under a lock and rejects with 429
    once ``max_concurrent_requests`` is reached; ``after_request`` decrements it.
    """

    def __init__(self, max_concurrent_requests=100):
        """
        :param max_concurrent_requests: number of requests allowed in flight at once
        """
        super().__init__()
        self.max_concurrent_requests = max_concurrent_requests
        self.current_requests = 0
        self.lock = Lock()

    def get_request_identifier(self, request):
        """Return ``request.ip_addr`` (not used by the hooks of this class)."""
        return request.ip_addr

    def before_request(self, request):
        """
        Admit the request if the concurrency limit has not been reached.

        :param request: the incoming request
        :return: ``request`` itself when admitted; otherwise a 429 ``Response``
            with description "Too Many Requests" and a ``Retry-After`` header of "5"
        """
        with self.lock:
            if self.current_requests >= self.max_concurrent_requests:
                return Response(status_code=429, description="Too Many Requests", headers={"Retry-After": "5"})
            self.current_requests += 1

        return request

    def after_request(self, response):
        """
        Release one slot and return ``response`` unchanged.

        The counter is clamped at zero so an ``after_request`` without a
        matching admitted ``before_request`` cannot push it negative.

        NOTE(review): if the framework also runs this hook for requests that
        ``before_request`` rejected, the counter is still decremented for a
        request that never took a slot -- confirm the hook ordering.
        """
        with self.lock:
            if self.current_requests > 0:
                self.current_requests -= 1
        return response
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from hypern.hypern import BaseSchemaGenerator, Route as InternalRoute
|
|
5
|
+
import typing
|
|
6
|
+
import orjson
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class EndpointInfo(typing.NamedTuple):
    """One documented endpoint: where it lives, how it is called, and its handler."""

    # Route path, e.g. "/users/".
    path: str
    # HTTP method name; SchemaGenerator.get_endpoints stores it lower-cased.
    http_method: str
    # Handler callable whose docstring is parsed for the schema.
    func: typing.Callable[..., typing.Any]
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class SchemaGenerator(BaseSchemaGenerator):
    """Build an OpenAPI-style schema dict from a base schema plus the handlers' docstrings."""

    def __init__(self, base_schema: dict[str, typing.Any]) -> None:
        """
        :param base_schema: schema dict used as the starting point of every
            generated schema; it is not modified by ``get_schema``
        """
        self.base_schema = base_schema

    def get_endpoints(self, routes: list[InternalRoute]) -> list[EndpointInfo]:
        """
        Convert routes into a list of ``EndpointInfo`` entries, each holding:

        - path
            eg: /users/
        - http_method
            the route's method, lower-cased
            (e.g. 'get', 'post', 'put', 'patch', 'delete', 'options')
        - func
            the route's handler, ready to have its docstring extracted
        """
        return [
            EndpointInfo(path=route.path, http_method=route.method.lower(), func=route.function.handler)
            for route in routes
        ]

    def get_schema(self, app) -> dict[str, typing.Any]:
        """
        Return the base schema extended with one ``paths`` entry per documented route.

        Routes whose ``parse_docstring`` result is empty are skipped. The
        top-level dict, the ``paths`` dict and each per-path dict are copied
        before being filled in, so ``self.base_schema`` is left untouched and
        repeated calls do not accumulate entries in it.

        :param app: object exposing ``router.routes``
        :return: the generated schema dict
        """
        schema = dict(self.base_schema)
        # Copy two levels deep: these are exactly the dicts written to below.
        paths = {path: dict(operations) for path, operations in schema.get("paths", {}).items()}
        schema["paths"] = paths

        for endpoint in self.get_endpoints(app.router.routes):
            parsed = self.parse_docstring(endpoint.func)
            if not parsed:
                continue
            paths.setdefault(endpoint.path, {})[endpoint.http_method] = orjson.loads(parsed)

        return schema
|