hypern 0.3.0__cp311-cp311-win32.whl → 0.3.2__cp311-cp311-win32.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hypern/application.py +47 -8
- hypern/args_parser.py +7 -0
- hypern/caching/__init__.py +6 -0
- hypern/caching/backend.py +31 -0
- hypern/caching/redis_backend.py +200 -2
- hypern/caching/strategies.py +208 -0
- hypern/gateway/__init__.py +6 -0
- hypern/gateway/aggregator.py +32 -0
- hypern/gateway/gateway.py +41 -0
- hypern/gateway/proxy.py +60 -0
- hypern/gateway/service.py +52 -0
- hypern/hypern.cp311-win32.pyd +0 -0
- hypern/hypern.pyi +27 -17
- hypern/middleware/__init__.py +14 -2
- hypern/middleware/base.py +9 -14
- hypern/middleware/cache.py +177 -0
- hypern/middleware/compress.py +78 -0
- hypern/middleware/cors.py +6 -3
- hypern/middleware/limit.py +5 -4
- hypern/middleware/security.py +184 -0
- hypern/processpool.py +41 -2
- hypern/reload.py +26 -40
- hypern/routing/__init__.py +2 -1
- hypern/routing/queue.py +175 -0
- hypern/ws/__init__.py +4 -0
- hypern/ws/channel.py +80 -0
- hypern/ws/heartbeat.py +74 -0
- hypern/ws/room.py +76 -0
- hypern/ws/route.py +26 -0
- {hypern-0.3.0.dist-info → hypern-0.3.2.dist-info}/METADATA +1 -1
- {hypern-0.3.0.dist-info → hypern-0.3.2.dist-info}/RECORD +33 -23
- {hypern-0.3.0.dist-info → hypern-0.3.2.dist-info}/WHEEL +1 -1
- hypern/caching/base/__init__.py +0 -8
- hypern/caching/base/backend.py +0 -3
- hypern/caching/base/key_maker.py +0 -8
- hypern/caching/cache_manager.py +0 -56
- hypern/caching/cache_tag.py +0 -10
- hypern/caching/custom_key_maker.py +0 -11
- {hypern-0.3.0.dist-info → hypern-0.3.2.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,184 @@
|
|
1
|
+
import hashlib
|
2
|
+
import hmac
|
3
|
+
import secrets
|
4
|
+
import time
|
5
|
+
from base64 import b64decode, b64encode
|
6
|
+
from dataclasses import dataclass
|
7
|
+
from datetime import datetime, timedelta, timezone
|
8
|
+
from typing import Any, Dict, List, Optional
|
9
|
+
|
10
|
+
import jwt
|
11
|
+
|
12
|
+
from hypern.exceptions import Forbidden, Unauthorized
|
13
|
+
from hypern.hypern import Request, Response
|
14
|
+
from .base import Middleware, MiddlewareConfig
|
15
|
+
|
16
|
+
|
17
|
+
@dataclass
class CORSConfig:
    """CORS settings that SecurityMiddleware writes onto every response."""

    allowed_origins: List[str]  # joined with ", " into Access-Control-Allow-Origin
    allowed_methods: List[str]  # joined with ", " into Access-Control-Allow-Methods
    max_age: int  # seconds; sent as Access-Control-Max-Age


@dataclass
class SecurityConfig:
    """Feature switches and settings consumed by SecurityMiddleware.

    ``cors_configuration`` may be given either as a :class:`CORSConfig` or as a
    mapping of its fields (for example parsed from a config file); mappings are
    converted to ``CORSConfig`` in ``__post_init__``. When ``security_headers``
    is None a default set of hardening headers is installed.
    """

    rate_limiting: bool = False
    jwt_auth: bool = False
    cors_configuration: Optional[CORSConfig] = None
    csrf_protection: bool = False
    security_headers: Optional[Dict[str, str]] = None
    jwt_secret: str = ""
    jwt_algorithm: str = "HS256"
    jwt_expires_in: int = 3600  # 1 hour in seconds

    def __post_init__(self):
        # Only a non-empty mapping needs converting. A CORSConfig passed in
        # directly (as the annotation invites) is kept as-is -- unpacking it
        # with ** would raise TypeError. Falsy values are left untouched.
        if self.cors_configuration and not isinstance(self.cors_configuration, CORSConfig):
            self.cors_configuration = CORSConfig(**self.cors_configuration)

        if self.security_headers is None:
            self.security_headers = {
                "X-Frame-Options": "DENY",
                "X-Content-Type-Options": "nosniff",
                "Strict-Transport-Security": "max-age=31536000; includeSubDomains",
            }
|
45
|
+
|
46
|
+
|
47
|
+
class SecurityMiddleware(Middleware):
    """Apply the protections selected by a :class:`SecurityConfig`.

    ``before_request`` enforces, in this order, per-IP rate limiting, JWT bearer
    authentication and CSRF token checks; ``after_request`` adds the configured
    security headers and CORS headers to every response.

    NOTE(review): the rate-limit table and the CSRF signing key are held on this
    instance, so they are per process -- a CSRF token minted by one worker
    process will not validate in another. Confirm how the server shares
    middleware instances before relying on CSRF across workers.
    """

    # HTTP methods that must carry a valid X-CSRF-Token when CSRF protection is on.
    _CSRF_PROTECTED_METHODS = ("POST", "PUT", "DELETE", "PATCH")

    def __init__(self, secur_config: SecurityConfig, config: Optional[MiddlewareConfig] = None):
        """
        Args:
            secur_config: which protections to enable and their settings.
            config: generic middleware configuration forwarded to ``Middleware``.
        """
        super().__init__(config)
        self.secur_config = secur_config
        # Random HMAC key used to sign CSRF tokens; a new one per instance.
        self._secret_key = secrets.token_bytes(32)
        # How long a CSRF token stays valid, in seconds.
        self._token_lifetime = 3600
        # Fixed-window rate limit: at most ``_rate_limit_max`` requests per client
        # IP within ``_rate_limit_window`` seconds of that client's first request.
        self._rate_limit_max = 60
        self._rate_limit_window = 60
        # client IP -> {"count": requests seen in the window, "timestamp": window start}
        self._rate_limit_storage: Dict[str, Dict[str, float]] = {}

    def _rate_limit_check(self, request: Request) -> Optional[Response]:
        """Return a 429 response if the client exceeded its limit, else None."""
        if not self.secur_config.rate_limiting:
            return None

        client_ip = request.client.host
        now = time.time()
        window_start = now - self._rate_limit_window

        # Forget clients whose window has elapsed so the table does not grow forever.
        self._rate_limit_storage = {ip: hits for ip, hits in self._rate_limit_storage.items() if hits["timestamp"] > window_start}

        entry = self._rate_limit_storage.setdefault(client_ip, {"count": 0, "timestamp": now})
        entry["count"] += 1

        if entry["count"] > self._rate_limit_max:
            return Response(
                status_code=429,
                description=b"Too Many Requests",
                headers={"Retry-After": str(self._rate_limit_window)},
            )
        return None

    def _generate_jwt_token(self, user_data: Dict[str, Any]) -> str:
        """Return a signed JWT carrying *user_data* under the ``"user"`` claim.

        The token expires ``jwt_expires_in`` seconds after issuance.

        Raises:
            ValueError: no ``jwt_secret`` is configured.
        """
        if not self.secur_config.jwt_secret:
            raise ValueError("JWT secret key is not configured")

        # One timestamp so that "exp" - "iat" is exactly jwt_expires_in.
        issued_at = datetime.now(tz=timezone.utc)
        payload = {
            "user": user_data,
            "exp": issued_at + timedelta(seconds=self.secur_config.jwt_expires_in),
            "iat": issued_at,
        }
        return jwt.encode(payload, self.secur_config.jwt_secret, algorithm=self.secur_config.jwt_algorithm)

    def _verify_jwt_token(self, token: str) -> Dict[str, Any]:
        """Decode *token* with the configured secret/algorithm and return its payload.

        Raises:
            ValueError: no ``jwt_secret`` is configured.
            Unauthorized: the token is expired or otherwise invalid.
        """
        if not self.secur_config.jwt_secret:
            # For HMAC algorithms an empty key lets anyone compute a valid
            # signature, so fail closed instead of accepting forgeable tokens.
            raise ValueError("JWT secret key is not configured")
        try:
            return jwt.decode(token, self.secur_config.jwt_secret, algorithms=[self.secur_config.jwt_algorithm])
        except jwt.ExpiredSignatureError:
            raise Unauthorized("Token has expired")
        except jwt.InvalidTokenError:
            raise Unauthorized("Invalid token")

    def _generate_csrf_token(self, session_id: str) -> str:
        """Return a CSRF token for *session_id*.

        Layout: base64("<session_id>:<unix_time>:<base64(HMAC-SHA256)>"), where the
        HMAC covers "<session_id>:<unix_time>" and is keyed with this instance's key.
        """
        timestamp = str(int(time.time()))
        token_data = f"{session_id}:{timestamp}"
        signature = hmac.new(self._secret_key, token_data.encode(), hashlib.sha256).digest()
        return b64encode(f"{token_data}:{b64encode(signature).decode()}".encode()).decode()

    def _validate_csrf_token(self, token: str) -> bool:
        """Return True if *token* was signed by this instance and is at most ``_token_lifetime`` seconds old.

        NOTE(review): the embedded session id is only covered by the signature; it
        is not compared with the current request's session, so any unexpired token
        issued by this instance is accepted.
        """
        try:
            decoded_token = b64decode(token.encode()).decode()
            session_id, timestamp, signature = decoded_token.rsplit(":", 2)

            # Reject tokens older than the allowed lifetime.
            token_time = int(timestamp)
            if int(time.time()) - token_time > self._token_lifetime:
                return False

            # Recompute the signature and compare in constant time.
            expected_data = f"{session_id}:{timestamp}"
            expected_signature = hmac.new(self._secret_key, expected_data.encode(), hashlib.sha256).digest()
            return hmac.compare_digest(expected_signature, b64decode(signature))

        except (ValueError, AttributeError, TypeError):
            # Malformed base64 / text / layout -> invalid token.
            return False

    def _apply_cors_headers(self, response: Response) -> None:
        """Add the configured CORS headers to *response* (no-op when CORS is not configured)."""
        if not self.secur_config.cors_configuration:
            return

        cors = self.secur_config.cors_configuration
        # NOTE(review): browsers accept a single origin (or "*") in
        # Access-Control-Allow-Origin; a comma-joined list of several origins is
        # not a valid value. Reflecting the request's Origin would need the
        # request, which after_request does not receive.
        response.headers.update(
            {
                "Access-Control-Allow-Origin": ", ".join(cors.allowed_origins),
                "Access-Control-Allow-Methods": ", ".join(cors.allowed_methods),
                "Access-Control-Max-Age": str(cors.max_age),
                "Access-Control-Allow-Headers": "Content-Type, Authorization, X-CSRF-Token",
                "Access-Control-Allow-Credentials": "true",
            }
        )

    def _apply_security_headers(self, response: Response) -> None:
        """Add the configured security headers to *response*."""
        if self.secur_config.security_headers:
            response.headers.update(self.secur_config.security_headers)

    async def before_request(self, request: Request) -> Request | Response:
        """Run the enabled request checks.

        Returns:
            A 429 (rate limited) or 401 (bad/expired JWT) ``Response`` that
            short-circuits the request; otherwise *request* itself, with
            ``request.user`` set to the JWT payload when JWT auth is enabled.

        Raises:
            Unauthorized: JWT auth is on and the Authorization header is not
                ``Bearer <token>``.
            Forbidden: CSRF protection is on, the method is state-changing and
                the X-CSRF-Token header is missing or invalid.
        """
        rate_limit_response = self._rate_limit_check(request)
        if rate_limit_response is not None:
            return rate_limit_response

        if self.secur_config.jwt_auth:
            auth_header = request.headers.get("Authorization")
            if not auth_header or not auth_header.startswith("Bearer "):
                # NOTE(review): this path raises while a bad token below returns a
                # 401 Response; confirm the framework maps Unauthorized to 401.
                raise Unauthorized("Missing or invalid authorization header")
            # Everything after the scheme is the token; tolerate extra whitespace.
            token = auth_header.removeprefix("Bearer ").strip()
            try:
                request.user = self._verify_jwt_token(token)
            except Unauthorized as e:
                return Response(status_code=401, description=str(e))

        if self.secur_config.csrf_protection and request.method in self._CSRF_PROTECTED_METHODS:
            csrf_token = request.headers.get("X-CSRF-Token")
            if not csrf_token or not self._validate_csrf_token(csrf_token):
                raise Forbidden("CSRF token missing or invalid")

        return request

    async def after_request(self, response: Response) -> Response:
        """Add security and CORS headers to *response* and return it."""
        self._apply_security_headers(response)
        self._apply_cors_headers(response)
        return response

    def generate_csrf_token(self, request: Request) -> str:
        """Return a CSRF token for *request*, assigning it a random ``session_id`` if it has none."""
        if not hasattr(request, "session_id"):
            request.session_id = secrets.token_urlsafe(32)
        return self._generate_csrf_token(request.session_id)
|
hypern/processpool.py
CHANGED
@@ -25,11 +25,26 @@ def run_processes(
|
|
25
25
|
after_request: List[FunctionInfo],
|
26
26
|
response_headers: Dict[str, str],
|
27
27
|
reload: bool = True,
|
28
|
+
on_startup: FunctionInfo | None = None,
|
29
|
+
on_shutdown: FunctionInfo | None = None,
|
30
|
+
auto_compression: bool = False,
|
28
31
|
) -> List[Process]:
|
29
32
|
socket = SocketHeld(host, port)
|
30
33
|
|
31
34
|
process_pool = init_processpool(
|
32
|
-
router,
|
35
|
+
router,
|
36
|
+
websocket_router,
|
37
|
+
socket,
|
38
|
+
workers,
|
39
|
+
processes,
|
40
|
+
max_blocking_threads,
|
41
|
+
injectables,
|
42
|
+
before_request,
|
43
|
+
after_request,
|
44
|
+
response_headers,
|
45
|
+
on_startup,
|
46
|
+
on_shutdown,
|
47
|
+
auto_compression,
|
33
48
|
)
|
34
49
|
|
35
50
|
def terminating_signal_handler(_sig, _frame):
|
@@ -79,6 +94,9 @@ def init_processpool(
|
|
79
94
|
before_request: List[FunctionInfo],
|
80
95
|
after_request: List[FunctionInfo],
|
81
96
|
response_headers: Dict[str, str],
|
97
|
+
on_startup: FunctionInfo | None = None,
|
98
|
+
on_shutdown: FunctionInfo | None = None,
|
99
|
+
auto_compression: bool = False,
|
82
100
|
) -> List[Process]:
|
83
101
|
process_pool = []
|
84
102
|
|
@@ -86,7 +104,20 @@ def init_processpool(
|
|
86
104
|
copied_socket = socket.try_clone()
|
87
105
|
process = Process(
|
88
106
|
target=spawn_process,
|
89
|
-
args=(
|
107
|
+
args=(
|
108
|
+
router,
|
109
|
+
websocket_router,
|
110
|
+
copied_socket,
|
111
|
+
workers,
|
112
|
+
max_blocking_threads,
|
113
|
+
injectables,
|
114
|
+
before_request,
|
115
|
+
after_request,
|
116
|
+
response_headers,
|
117
|
+
on_startup,
|
118
|
+
on_shutdown,
|
119
|
+
auto_compression,
|
120
|
+
),
|
90
121
|
)
|
91
122
|
process.start()
|
92
123
|
process_pool.append(process)
|
@@ -118,6 +149,9 @@ def spawn_process(
|
|
118
149
|
before_request: List[FunctionInfo],
|
119
150
|
after_request: List[FunctionInfo],
|
120
151
|
response_headers: Dict[str, str],
|
152
|
+
on_startup: FunctionInfo | None = None,
|
153
|
+
on_shutdown: FunctionInfo | None = None,
|
154
|
+
auto_compression: bool = False,
|
121
155
|
):
|
122
156
|
loop = initialize_event_loop()
|
123
157
|
|
@@ -128,7 +162,12 @@ def spawn_process(
|
|
128
162
|
server.set_before_hooks(hooks=before_request)
|
129
163
|
server.set_after_hooks(hooks=after_request)
|
130
164
|
server.set_response_headers(headers=response_headers)
|
165
|
+
server.set_auto_compression(enabled=auto_compression)
|
131
166
|
|
167
|
+
if on_startup:
|
168
|
+
server.set_startup_handler(on_startup)
|
169
|
+
if on_shutdown:
|
170
|
+
server.set_shutdown_handler(on_shutdown)
|
132
171
|
try:
|
133
172
|
server.start(socket, workers, max_blocking_threads)
|
134
173
|
loop = asyncio.get_event_loop()
|
hypern/reload.py
CHANGED
@@ -2,6 +2,8 @@ import sys
|
|
2
2
|
import time
|
3
3
|
import subprocess
|
4
4
|
from watchdog.events import FileSystemEventHandler
|
5
|
+
import signal
|
6
|
+
import os
|
5
7
|
|
6
8
|
from .logging import logger
|
7
9
|
|
@@ -10,51 +12,35 @@ class EventHandler(FileSystemEventHandler):
|
|
10
12
|
def __init__(self, file_path: str, directory_path: str) -> None:
|
11
13
|
self.file_path = file_path
|
12
14
|
self.directory_path = directory_path
|
13
|
-
self.process = None
|
14
|
-
self.last_reload = time.time()
|
15
|
-
|
16
|
-
def stop_server(self):
|
17
|
-
if self.process:
|
18
|
-
try:
|
19
|
-
# Check if the process is still alive
|
20
|
-
if self.process.poll() is None: # None means the process is still running
|
21
|
-
self.process.terminate() # Gracefully terminate the process
|
22
|
-
self.process.wait(timeout=5) # Wait for the process to exit
|
23
|
-
else:
|
24
|
-
logger.error("Process is not running.")
|
25
|
-
except subprocess.TimeoutExpired:
|
26
|
-
logger.error("Process did not terminate in time. Forcing termination.")
|
27
|
-
self.process.kill() # Forcefully kill the process if it doesn't stop
|
28
|
-
except ProcessLookupError:
|
29
|
-
logger.error("Process does not exist.")
|
30
|
-
except Exception as e:
|
31
|
-
logger.error(f"An error occurred while stopping the server: {e}")
|
32
|
-
else:
|
33
|
-
logger.debug("No process to stop.")
|
15
|
+
self.process = None
|
16
|
+
self.last_reload = time.time()
|
34
17
|
|
35
18
|
def reload(self):
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
19
|
+
# Kill all existing processes with the same command
|
20
|
+
current_cmd = [sys.executable, *sys.argv]
|
21
|
+
|
22
|
+
try:
|
23
|
+
# Find and kill existing processes
|
24
|
+
for proc in subprocess.Popen(["ps", "aux"], stdout=subprocess.PIPE).communicate()[0].decode().splitlines():
|
25
|
+
if all(str(arg) in proc for arg in current_cmd):
|
26
|
+
pid = int(proc.split()[1])
|
27
|
+
try:
|
28
|
+
os.kill(pid, signal.SIGKILL) # NOSONAR
|
29
|
+
logger.debug(f"Killed process with PID {pid}")
|
30
|
+
except ProcessLookupError:
|
31
|
+
pass
|
32
|
+
|
33
|
+
# Start new process
|
34
|
+
self.process = subprocess.Popen(current_cmd)
|
35
|
+
self.last_reload = time.time()
|
36
|
+
logger.debug("Server reloaded successfully")
|
37
|
+
|
38
|
+
except Exception as e:
|
39
|
+
logger.error(f"Reload failed: {e}")
|
47
40
|
|
48
41
|
def on_modified(self, event) -> None:
|
49
|
-
"""
|
50
|
-
This function is a callback that will start a new server on every even change
|
51
|
-
|
52
|
-
:param event FSEvent: a data structure with info about the events
|
53
|
-
"""
|
54
|
-
|
55
|
-
# Avoid reloading multiple times when watchdog detects multiple events
|
56
42
|
if time.time() - self.last_reload < 0.5:
|
57
43
|
return
|
58
44
|
|
59
|
-
time.sleep(0.2) #
|
45
|
+
time.sleep(0.2) # Ensure file is written
|
60
46
|
self.reload()
|
hypern/routing/__init__.py
CHANGED
hypern/routing/queue.py
ADDED
@@ -0,0 +1,175 @@
|
|
1
|
+
import asyncio
|
2
|
+
import time
|
3
|
+
from contextlib import asynccontextmanager
|
4
|
+
from dataclasses import dataclass, field
|
5
|
+
from queue import PriorityQueue
|
6
|
+
from typing import Any, Dict
|
7
|
+
|
8
|
+
from hypern.hypern import Request, Response
|
9
|
+
from hypern.response import JSONResponse
|
10
|
+
from hypern.routing import HTTPEndpoint
|
11
|
+
from hypern.logging import logger
|
12
|
+
|
13
|
+
|
14
|
+
@dataclass(order=True)
class PrioritizedRequest:
    """Priority-queue entry ordered by ``(priority, timestamp)``.

    Smaller ``priority`` values sort first; entries with equal priority fall
    back to their enqueue time. The request and its future are excluded from
    comparisons, so they never need to be orderable.
    """

    priority: int
    # Enqueue time in seconds since the epoch; FIFO tie-breaker within a priority.
    timestamp: float = field(default_factory=time.time)
    request: Request | None = field(compare=False, default=None)
    # Future the waiting caller awaits; not part of the ordering key.
    future: asyncio.Future = field(default_factory=asyncio.Future, compare=False)
|
20
|
+
|
21
|
+
|
22
|
+
class QueuedHTTPEndpoint(HTTPEndpoint):
    """
    HTTPEndpoint with request queuing capabilities for high-load scenarios.

    Each call to :meth:`dispatch` places a ticket in a priority queue. A
    background task hands out at most ``max_concurrent`` processing slots,
    always to the queued ticket with the lowest ``(priority, timestamp)``.
    The request is then handled by ``HTTPEndpoint.dispatch`` with the caller's
    ``inject`` mapping, and the slot is returned when handling finishes.

    Keyword options (consumed here, not forwarded to ``HTTPEndpoint``):
        max_concurrent: maximum number of requests handled at once (default 100).
        queue_size: maximum number of requests waiting for a slot (default 1000);
            a value <= 0 means unbounded.
        request_timeout: seconds a request may wait for a slot before a 504
            response is returned (default 30).

    NOTE(review): the asyncio primitives bind to the event loop of the first
    request; confirm an instance is never shared between event loops.
    """

    def __init__(self, *args, **kwargs):
        # Remove the queue options first so HTTPEndpoint.__init__ never sees them.
        max_concurrent = kwargs.pop("max_concurrent", 100)
        queue_size = kwargs.pop("queue_size", 1000)
        request_timeout = kwargs.pop("request_timeout", 30)
        super().__init__(*args, **kwargs)

        # Queue configuration
        self._max_concurrent = max_concurrent
        self._queue_size = queue_size
        self._request_timeout = request_timeout

        # asyncio.PriorityQueue never blocks the event loop and raises
        # asyncio.QueueFull, which is what the handlers below catch.
        self._request_queue: asyncio.PriorityQueue = asyncio.PriorityQueue(maxsize=self._queue_size)
        self._active_requests = 0
        self._request_semaphore = None  # created lazily inside the running loop
        self._shutdown = False
        self._queue_task = None
        self._initialized = False

        # Metrics
        self._metrics = {"processed_requests": 0, "queued_requests": 0, "rejected_requests": 0, "avg_wait_time": 0.0}
        self._granted_requests = 0  # number of samples behind avg_wait_time

        self._fully_message = "Request queue is full"

    async def _initialize(self):
        """Create the loop-bound semaphore and start the dispatcher on first use."""
        if not self._initialized:
            self._request_semaphore = asyncio.Semaphore(self._max_concurrent)
            self._queue_task = asyncio.create_task(self._process_queue())
            self._initialized = True

    @asynccontextmanager
    async def _queue_context(self, request: Request, priority: int = 10):
        """Hold a processing slot for *request* for the duration of the ``async with`` body.

        The request waits in the priority queue until the dispatcher grants it a
        slot; lower *priority* values are served first.

        Raises:
            RuntimeError: the endpoint is shutting down.
            asyncio.QueueFull: the queue is at ``queue_size``.
            asyncio.TimeoutError: no slot was granted within ``request_timeout`` seconds.
        """
        if self._shutdown:
            raise RuntimeError("Endpoint is shutting down")

        await self._initialize()  # Ensure async components are initialized

        ticket = PrioritizedRequest(
            priority=priority,
            timestamp=time.time(),
            request=request,
            future=asyncio.get_running_loop().create_future(),
        )
        await self._enqueue_request(ticket)

        try:
            await asyncio.wait_for(ticket.future, timeout=self._request_timeout)
        except asyncio.TimeoutError:
            self._abandon_ticket(ticket)
            self._metrics["rejected_requests"] += 1
            raise asyncio.TimeoutError("Request timed out while waiting in queue") from None
        except BaseException:
            # e.g. the caller was cancelled while queued
            self._abandon_ticket(ticket)
            raise

        try:
            yield
        finally:
            self._metrics["processed_requests"] += 1
            self._release_slot()

    def _abandon_ticket(self, ticket: PrioritizedRequest) -> None:
        """Withdraw *ticket* after its caller stopped waiting.

        If the dispatcher granted a slot in the same loop iteration the wait was
        interrupted, the slot is released so capacity is not leaked; otherwise
        the future is cancelled so the dispatcher skips the ticket.
        """
        if ticket.future.done() and not ticket.future.cancelled():
            self._release_slot()
        else:
            ticket.future.cancel()

    def _release_slot(self) -> None:
        """Return a granted processing slot to the pool."""
        self._active_requests -= 1
        self._request_semaphore.release()

    def _record_wait(self, wait_time: float) -> None:
        """Fold one queue wait (seconds) into the running mean ``avg_wait_time``."""
        self._granted_requests += 1
        avg = self._metrics["avg_wait_time"]
        self._metrics["avg_wait_time"] = avg + (wait_time - avg) / self._granted_requests

    async def _enqueue_request(self, request: PrioritizedRequest):
        """Add request to the queue.

        Raises:
            asyncio.QueueFull: the queue already holds ``queue_size`` tickets.
        """
        try:
            self._request_queue.put_nowait(request)
        except asyncio.QueueFull:
            self._metrics["rejected_requests"] += 1
            raise asyncio.QueueFull(self._fully_message) from None
        self._metrics["queued_requests"] += 1

    async def _process_queue(self):
        """Background task: hand free processing slots to queued tickets.

        A slot is reserved before a ticket is taken, so the ticket chosen is the
        highest-priority one queued when capacity frees up. Tickets whose caller
        already stopped waiting are discarded.
        """
        while not self._shutdown:
            await self._request_semaphore.acquire()
            granted = False
            try:
                while not granted:
                    ticket = await self._request_queue.get()
                    self._metrics["queued_requests"] -= 1
                    if ticket.future.done():
                        # The caller timed out or was cancelled while queued.
                        continue
                    self._record_wait(time.time() - ticket.timestamp)
                    self._active_requests += 1
                    ticket.future.set_result(None)
                    granted = True
            except Exception as e:
                logger.error(f"Error processing queue: {e}")
                await asyncio.sleep(1)
            finally:
                # Once granted, the slot travels with the ticket; otherwise give it back.
                if not granted:
                    self._request_semaphore.release()

    async def dispatch(self, request: Request, inject: Dict[str, Any]) -> Response:
        """
        Enhanced dispatch method with request queuing.

        Waits for a processing slot, then delegates to ``HTTPEndpoint.dispatch``
        with the caller's *inject* mapping. Returns 503 when the queue is full,
        504 when no slot is granted in time (or the handler itself raises
        asyncio.TimeoutError), and 500 with the error message otherwise.
        """
        try:
            priority = self._get_request_priority(request)

            async with self._queue_context(request, priority):
                return await super().dispatch(request, inject)

        except asyncio.QueueFull:
            return JSONResponse(description={"error": "Server too busy", "message": self._fully_message, "retry_after": 5}, status_code=503)
        except asyncio.TimeoutError:
            return JSONResponse(
                description={
                    "error": "Request timeout",
                    "message": "Request timed out while waiting in queue",
                },
                status_code=504,
            )
        except Exception as e:
            return JSONResponse(description={"error": "Internal server error", "message": str(e)}, status_code=500)

    def _get_request_priority(self, request: Request) -> int:
        """
        Determine request priority. Override this method to implement
        custom priority logic.

        Lower values are served first: GET requests get 5, everything else 10.
        """
        if request.method == "GET":
            return 5
        return 10

    async def shutdown(self):
        """Gracefully shutdown the endpoint."""
        self._shutdown = True
        if self._queue_task and not self._queue_task.done():
            self._queue_task.cancel()
            try:
                await self._queue_task
            except asyncio.CancelledError:
                pass

    def get_metrics(self) -> Dict[str, Any]:
        """Get current queue metrics."""
        return {**self._metrics, "current_queue_size": self._request_queue.qsize(), "active_requests": self._active_requests}
|
hypern/ws/__init__.py
ADDED
hypern/ws/channel.py
ADDED
@@ -0,0 +1,80 @@
|
|
1
|
+
from dataclasses import dataclass, field
|
2
|
+
from typing import Any, Awaitable, Callable, Dict, Set
|
3
|
+
|
4
|
+
from hypern.hypern import WebSocketSession
|
5
|
+
|
6
|
+
|
7
|
+
@dataclass
class Channel:
    """A named pub/sub topic: the sessions subscribed to it and its per-event handlers."""

    name: str
    subscribers: Set[WebSocketSession] = field(default_factory=set)
    handlers: Dict[str, Callable[[WebSocketSession, Any], Awaitable[None]]] = field(default_factory=dict)

    def publish(self, event: str, data: Any, publisher: WebSocketSession = None):
        """Send ``{"channel", "event", "data"}`` to every subscriber except *publisher*."""
        message = {"channel": self.name, "event": event, "data": data}
        # Iterate over a snapshot so a send that (un)subscribes sessions cannot
        # break the iteration.
        for subscriber in list(self.subscribers):
            if subscriber != publisher:
                subscriber.send(message)

    def handle_event(self, event: str, session: WebSocketSession, data: Any):
        """Invoke the handler registered for *event*, if any, and return its result.

        Handlers are typed as coroutine functions: the returned awaitable must be
        awaited by the caller for the handler body to run. Previously it was
        discarded, so async handlers never executed. Returns None when no handler
        is registered for *event*.
        """
        handler = self.handlers.get(event)
        if handler is None:
            return None
        return handler(session, data)

    def add_subscriber(self, subscriber: WebSocketSession):
        """Add a subscriber to the channel"""
        self.subscribers.add(subscriber)

    def remove_subscriber(self, subscriber: WebSocketSession):
        """Remove a subscriber from the channel (no error if absent)"""
        self.subscribers.discard(subscriber)

    def on(self, event: str):
        """Decorator registering the wrapped function as the handler for *event*.

        A later registration for the same event replaces the earlier one.
        """

        def decorator(handler: Callable[[WebSocketSession, Any], Awaitable[None]]):
            self.handlers[event] = handler
            return handler

        return decorator
|
40
|
+
|
41
|
+
|
42
|
+
class ChannelManager:
    """Registry of channels plus a reverse index of each client's subscriptions."""

    def __init__(self):
        self.channels: Dict[str, Channel] = {}
        # client session -> names of the channels it is subscribed to
        self.client_channels: Dict[WebSocketSession, Set[str]] = {}

    def create_channel(self, channel_name: str) -> Channel:
        """Return the channel called *channel_name*, creating it on first use."""
        try:
            return self.channels[channel_name]
        except KeyError:
            fresh = self.channels[channel_name] = Channel(channel_name)
            return fresh

    def get_channel(self, channel_name: str) -> Channel:
        """Return the channel called *channel_name*, or None if it was never created."""
        return self.channels.get(channel_name, None)

    def subscribe(self, client: WebSocketSession, channel_name: str):
        """Subscribe *client* to *channel_name*, creating the channel if necessary."""
        self.create_channel(channel_name).add_subscriber(client)
        self.client_channels.setdefault(client, set()).add(channel_name)

    def unsubscribe(self, client: WebSocketSession, channel_name: str):
        """Remove *client* from *channel_name*; unknown channels/clients are ignored."""
        target = self.get_channel(channel_name)
        if target is not None:
            target.remove_subscriber(client)

        joined = self.client_channels.get(client)
        if joined is not None:
            joined.discard(channel_name)

    def unsubscribe_all(self, client: WebSocketSession):
        """Remove *client* from every channel and drop its subscription record."""
        if client not in self.client_channels:
            return
        # Iterate over a copy: unsubscribe() mutates the client's set.
        for name in tuple(self.client_channels[client]):
            self.unsubscribe(client, name)
        del self.client_channels[client]
|