hypern 0.3.11__cp312-cp312-musllinux_1_2_armv7l.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hypern/__init__.py +24 -0
- hypern/application.py +495 -0
- hypern/args_parser.py +73 -0
- hypern/auth/__init__.py +0 -0
- hypern/auth/authorization.py +2 -0
- hypern/background.py +4 -0
- hypern/caching/__init__.py +6 -0
- hypern/caching/backend.py +31 -0
- hypern/caching/redis_backend.py +201 -0
- hypern/caching/strategies.py +208 -0
- hypern/cli/__init__.py +0 -0
- hypern/cli/commands.py +0 -0
- hypern/config.py +246 -0
- hypern/database/__init__.py +0 -0
- hypern/database/sqlalchemy/__init__.py +4 -0
- hypern/database/sqlalchemy/config.py +66 -0
- hypern/database/sqlalchemy/repository.py +290 -0
- hypern/database/sqlx/__init__.py +36 -0
- hypern/database/sqlx/field.py +246 -0
- hypern/database/sqlx/migrate.py +263 -0
- hypern/database/sqlx/model.py +117 -0
- hypern/database/sqlx/query.py +904 -0
- hypern/datastructures.py +40 -0
- hypern/enum.py +13 -0
- hypern/exceptions/__init__.py +34 -0
- hypern/exceptions/base.py +62 -0
- hypern/exceptions/common.py +12 -0
- hypern/exceptions/errors.py +15 -0
- hypern/exceptions/formatters.py +56 -0
- hypern/exceptions/http.py +76 -0
- hypern/gateway/__init__.py +6 -0
- hypern/gateway/aggregator.py +32 -0
- hypern/gateway/gateway.py +41 -0
- hypern/gateway/proxy.py +60 -0
- hypern/gateway/service.py +52 -0
- hypern/hypern.cpython-312-arm-linux-musleabihf.so +0 -0
- hypern/hypern.pyi +333 -0
- hypern/i18n/__init__.py +0 -0
- hypern/logging/__init__.py +3 -0
- hypern/logging/logger.py +82 -0
- hypern/middleware/__init__.py +17 -0
- hypern/middleware/base.py +13 -0
- hypern/middleware/cache.py +177 -0
- hypern/middleware/compress.py +78 -0
- hypern/middleware/cors.py +41 -0
- hypern/middleware/i18n.py +1 -0
- hypern/middleware/limit.py +177 -0
- hypern/middleware/security.py +184 -0
- hypern/openapi/__init__.py +5 -0
- hypern/openapi/schemas.py +51 -0
- hypern/openapi/swagger.py +3 -0
- hypern/processpool.py +139 -0
- hypern/py.typed +0 -0
- hypern/reload.py +46 -0
- hypern/response/__init__.py +3 -0
- hypern/response/response.py +142 -0
- hypern/routing/__init__.py +5 -0
- hypern/routing/dispatcher.py +70 -0
- hypern/routing/endpoint.py +30 -0
- hypern/routing/parser.py +98 -0
- hypern/routing/queue.py +175 -0
- hypern/routing/route.py +280 -0
- hypern/scheduler.py +5 -0
- hypern/worker.py +274 -0
- hypern/ws/__init__.py +4 -0
- hypern/ws/channel.py +80 -0
- hypern/ws/heartbeat.py +74 -0
- hypern/ws/room.py +76 -0
- hypern/ws/route.py +26 -0
- hypern-0.3.11.dist-info/METADATA +134 -0
- hypern-0.3.11.dist-info/RECORD +74 -0
- hypern-0.3.11.dist-info/WHEEL +4 -0
- hypern-0.3.11.dist-info/licenses/LICENSE +24 -0
- hypern.libs/libgcc_s-5b5488a6.so.1 +0 -0
hypern/routing/route.py
ADDED
@@ -0,0 +1,280 @@
|
|
1
|
+
# -*- coding: utf-8 -*-
|
2
|
+
import asyncio
|
3
|
+
import inspect
|
4
|
+
from enum import Enum
|
5
|
+
from typing import Any, Callable, Dict, List, Type, Union, get_args, get_origin
|
6
|
+
|
7
|
+
import yaml # type: ignore
|
8
|
+
from pydantic import BaseModel
|
9
|
+
from pydantic.fields import FieldInfo
|
10
|
+
|
11
|
+
from hypern.auth.authorization import Authorization
|
12
|
+
from hypern.datastructures import HTTPMethod
|
13
|
+
from hypern.hypern import FunctionInfo, Request, Router
|
14
|
+
from hypern.hypern import Route as InternalRoute
|
15
|
+
|
16
|
+
from .dispatcher import dispatch
|
17
|
+
|
18
|
+
|
19
|
+
def get_field_type(field):
    """Return the declared type of a pydantic model field.

    Pydantic v1 ``ModelField`` objects expose the declared type as
    ``outer_type_``. Pydantic v2 ``FieldInfo`` objects (what
    ``model.model_fields`` yields, as used elsewhere in this module) expose it
    as ``annotation`` instead. Both shapes are accepted.

    Args:
        field: A pydantic v1 ``ModelField`` or a v2 ``FieldInfo``.

    Returns:
        The field's type annotation.

    Raises:
        AttributeError: If ``field`` has neither ``outer_type_`` nor ``annotation``.
    """
    try:
        return field.outer_type_
    except AttributeError:
        # pydantic v2 FieldInfo has no outer_type_
        return field.annotation
|
21
|
+
|
22
|
+
|
23
|
+
def pydantic_to_swagger(model: type[BaseModel] | dict):
    """Build a Swagger/OpenAPI schema fragment for ``model``.

    Two input shapes are accepted:

    * a plain ``dict`` mapping field names to types (or ``FieldInfo``): the
      result maps each name directly to its processed schema;
    * a pydantic model class: the result is
      ``{ModelName: {"type": "object", "properties": {...}}}`` with one
      property per entry of ``model.model_fields``.
    """
    if not isinstance(model, dict):
        model_name = model.__name__
        properties = {
            field_name: _process_field(field_name, field_info)
            for field_name, field_info in model.model_fields.items()
        }
        return {model_name: {"type": "object", "properties": properties}}

    # A bare mapping of name -> type: no wrapping object, just the fields.
    return {key: _process_field(key, value) for key, value in model.items()}
|
42
|
+
|
43
|
+
|
44
|
+
class SchemaProcessor:
    """Translate Python type annotations into OpenAPI/JSON-schema fragments."""

    @staticmethod
    def process_union(args: tuple) -> Dict[str, Any]:
        """Process the members of a ``Union`` / ``X | Y`` annotation.

        A ``None`` member marks the schema ``nullable``. If exactly one
        non-``None`` member remains, its schema is used directly. Otherwise the
        remaining members are combined under ``oneOf``; previously every member
        after the first was silently dropped for ``Optional[Union[A, B]]``.
        """
        non_null = [arg for arg in args if arg is not type(None)]
        is_nullable = len(non_null) != len(args)

        if len(non_null) == 1:
            schema = SchemaProcessor._process_field("", non_null[0])
        else:
            schema = {"oneOf": [SchemaProcessor._process_field("", arg) for arg in non_null]}

        if is_nullable:
            schema["nullable"] = True
        return schema

    @staticmethod
    def process_enum(annotation: Type[Enum]) -> Dict[str, Any]:
        """Process Enum types.

        Iterating the enum class (rather than ``__members__``) skips aliases, so
        each value is listed once. The schema ``type`` is ``integer`` when every
        value is an int, ``number`` when every value is an int or float, and
        ``string`` otherwise.
        """
        values = [member.value for member in annotation]

        def _all(kinds) -> bool:
            # bool is an int subclass but is not a numeric enum value here
            return bool(values) and all(isinstance(v, kinds) and not isinstance(v, bool) for v in values)

        if _all(int):
            schema_type = "integer"
        elif _all((int, float)):
            schema_type = "number"
        else:
            schema_type = "string"
        return {"type": schema_type, "enum": values}

    @staticmethod
    def process_primitive(annotation: type) -> Dict[str, str]:
        """Process primitive types (int/float/str/bool); anything else maps to ``object``."""
        type_mapping = {int: "integer", float: "number", str: "string", bool: "boolean"}
        return {"type": type_mapping.get(annotation, "object")}

    @staticmethod
    def process_list(annotation: type) -> Dict[str, Any]:
        """Process list types; ``items`` describes the element type when one is given."""
        schema = {"type": "array"}

        args = get_args(annotation)
        if args:
            item_type = args[0]
            schema["items"] = SchemaProcessor._process_field("item", item_type)
        else:
            schema["items"] = {}
        return schema

    @staticmethod
    def process_dict(annotation: type) -> Dict[str, Any]:
        """Process dict types.

        ``additionalProperties`` is only emitted for ``Dict[str, V]``, since JSON
        object keys are always strings.
        """
        schema = {"type": "object"}

        args = get_args(annotation)
        # Guard against single-argument forms such as ``dict[str]``.
        if len(args) == 2:
            key_type, value_type = args
            if key_type is str:
                schema["additionalProperties"] = SchemaProcessor._process_field("value", value_type)
        return schema

    @classmethod
    def _process_field(cls, name: str, field: Any) -> Dict[str, Any]:
        """Return the schema for one field (a bare annotation or a ``FieldInfo``).

        ``name`` is accepted for interface symmetry but does not affect the result.
        """
        import types  # needed only to recognise PEP 604 ``X | Y`` unions

        annotation = field.annotation if isinstance(field, FieldInfo) else field
        origin = get_origin(annotation)

        # Union[...] and ``X | Y`` report different origins.
        if origin is Union or origin is types.UnionType:
            return cls.process_union(get_args(annotation))

        # Checked before primitives so ``class Color(str, Enum)`` is an enum.
        if isinstance(annotation, type) and issubclass(annotation, Enum):
            return cls.process_enum(annotation)

        if annotation in (int, float, str, bool):
            return cls.process_primitive(annotation)

        if annotation is list or origin is list:
            return cls.process_list(annotation)

        if annotation is dict or origin is dict:
            return cls.process_dict(annotation)

        if isinstance(annotation, type) and issubclass(annotation, BaseModel):
            # pydantic_to_swagger returns {ModelName: schema}; unwrap it so the
            # nested schema is valid where a property/items schema is expected.
            return pydantic_to_swagger(annotation)[annotation.__name__]

        # Fallback for complex/unrecognised types
        return {"type": "object"}
|
126
|
+
|
127
|
+
|
128
|
+
def _process_field(name: str, field: Any) -> Dict[str, Any]:
    """
    Return the JSON-schema fragment describing a single field.

    Module-level wrapper around ``SchemaProcessor._process_field`` so callers
    in this module do not have to reference the class directly.

    Args:
        name: Field name, forwarded unchanged.
        field: A bare type annotation or a pydantic ``FieldInfo`` object.

    Returns:
        Dictionary representing the JSON schema for the field.
    """
    processor = SchemaProcessor
    field_schema = processor._process_field(name, field)
    return field_schema
|
140
|
+
|
141
|
+
|
142
|
+
class Route:
    """Group HTTP handlers under one path prefix and build a ``Router`` for them.

    Handlers come from two sources:

    * a class-based ``endpoint`` whose attributes named after HTTP verbs
      (``get``, ``post``, ...) are routed through an instance's ``dispatch``;
    * functions registered with the ``get``/``post``/... decorators below.

    Each handler's signature is turned into a YAML Swagger snippet and stored
    on the generated internal route as ``doc``.
    """

    def __init__(
        self,
        path: str,
        endpoint: Callable[..., Any] | None = None,
        *,
        name: str | None = None,
        tags: List[str] | None = None,
    ) -> None:
        self.path = path
        self.endpoint = endpoint
        self.tags = tags or ["Default"]
        self.name = name

        # Upper-cased attribute name -> HTTPMethod; selects verb methods on
        # class-based endpoints.
        self.http_methods = {
            "GET": HTTPMethod.GET,
            "POST": HTTPMethod.POST,
            "PUT": HTTPMethod.PUT,
            "DELETE": HTTPMethod.DELETE,
            "PATCH": HTTPMethod.PATCH,
            "HEAD": HTTPMethod.HEAD,
            "OPTIONS": HTTPMethod.OPTIONS,
        }
        # Internal routes registered through the get/post/... decorators.
        self.functional_handlers = []

    def _process_authorization(self, item: type, docs: Dict) -> None:
        """Add a ``security`` entry when ``item`` is an ``Authorization`` subclass.

        The subclass is instantiated with no arguments to read its ``name``.
        """
        if isinstance(item, type) and issubclass(item, Authorization):
            auth_obj = item()
            docs["security"] = [{auth_obj.name: []}]

    def _process_model_params(self, key: str, item: type, docs: Dict) -> None:
        """Document a pydantic-model parameter according to its parameter name.

        * ``form_data`` -> JSON ``requestBody``
        * ``query_params`` -> one ``in: query`` parameter per model field
        * ``path_params`` -> one required ``in: path`` parameter per model field

        Parameters that are not pydantic models are ignored.
        """
        if not (isinstance(item, type) and issubclass(item, BaseModel)):
            return

        if key == "form_data":
            docs["requestBody"] = {"content": {"application/json": {"schema": pydantic_to_swagger(item).get(item.__name__)}}}
        elif key == "query_params":
            query_params = [{"name": param, "in": "query", "schema": _process_field(param, field)} for param, field in item.model_fields.items()]
            # Extend rather than assign, so path parameters documented by an
            # earlier argument are not discarded.
            docs.setdefault("parameters", []).extend(query_params)
        elif key == "path_params":
            path_params = [
                {"name": param, "in": "path", "required": True, "schema": _process_field(param, field)} for param, field in item.model_fields.items()
            ]
            docs.setdefault("parameters", []).extend(path_params)

    def _process_response(self, response_type: type, docs: Dict) -> None:
        """Document a ``200`` JSON response when the return annotation is a pydantic model."""
        if isinstance(response_type, type) and issubclass(response_type, BaseModel):
            docs["responses"] = {
                "200": {
                    "description": "Successful response",
                    "content": {"application/json": {"schema": pydantic_to_swagger(response_type).get(response_type.__name__)}},
                }
            }

    def swagger_generate(self, signature: inspect.Signature, summary: str = "Document API") -> str:
        """Return a YAML Swagger snippet describing a handler's signature.

        Args:
            signature: The handler's ``inspect.Signature``.
            summary: Operation summary. Callers pass ``func.__doc__``, which
                is ``None`` for undocumented handlers; the default is used then.

        Returns:
            The YAML-dumped documentation dict.
        """
        if summary is None:
            summary = "Document API"
        # OpenAPI "responses" is a mapping of status code -> response object.
        _docs: Dict = {"summary": summary, "tags": self.tags, "responses": {}, "name": self.name}

        for key, param in signature.parameters.items():
            item = param.annotation
            self._process_authorization(item, _docs)
            self._process_model_params(key, item, _docs)

        self._process_response(signature.return_annotation, _docs)
        return yaml.dump(_docs)

    def _combine_path(self, path1: str, path2: str) -> str:
        """Join two URL path segments with exactly one ``/`` between them."""
        if path1.endswith("/") and path2.startswith("/"):
            return path1 + path2[1:]
        if not path1.endswith("/") and not path2.startswith("/"):
            return path1 + "/" + path2
        return path1 + path2

    def make_internal_route(self, path, handler, method) -> InternalRoute:
        """Wrap ``handler`` in a ``FunctionInfo`` and build the native route."""
        # inspect.iscoroutinefunction: asyncio's variant is deprecated since 3.14.
        is_async = inspect.iscoroutinefunction(handler)
        func_info = FunctionInfo(handler=handler, is_async=is_async)
        return InternalRoute(path=path, function=func_info, method=method)

    def __call__(self, app, *args: Any, **kwds: Any) -> Any:
        """Build a ``Router`` containing every registered handler.

        Raises:
            ValueError: If neither an endpoint class nor functional handlers exist.
        """
        router = Router(self.path)

        if not self.endpoint and not self.functional_handlers:
            raise ValueError(f"No handler found for route: {self.path}")

        for route in self.functional_handlers:
            router.add_route(route=route)
        if not self.endpoint:
            return router

        # Class-based routes: only verb methods defined directly on the class
        # (its own __dict__) are registered; each gets its own instance whose
        # ``dispatch`` serves the request.
        for name, func in self.endpoint.__dict__.items():
            method = name.upper()
            if method not in self.http_methods:
                continue
            doc = self.swagger_generate(inspect.signature(func), func.__doc__)
            endpoint_obj = self.endpoint()
            route = self.make_internal_route(path="/", handler=endpoint_obj.dispatch, method=method)
            route.doc = doc
            router.add_route(route=route)
        return router

    def add_route(
        self,
        path: str,
        method: str,
    ) -> Callable:
        """Return a decorator registering a function as a handler for ``method`` on ``path``."""

        def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
            async def functional_wrapper(request: Request, inject: Dict[str, Any]) -> Any:
                return await dispatch(func, request, inject)

            sig = inspect.signature(func)
            route = self.make_internal_route(path=path, handler=functional_wrapper, method=method.upper())
            route.doc = self.swagger_generate(sig, func.__doc__)

            self.functional_handlers.append(route)
            # Return the original function so the decorated name stays bound to
            # it; previously it was rebound to None.
            return func

        return decorator

    def get(self, path: str) -> Callable:
        """Register a GET handler."""
        return self.add_route(path, "GET")

    def post(self, path: str) -> Callable:
        """Register a POST handler."""
        return self.add_route(path, "POST")

    def put(self, path: str) -> Callable:
        """Register a PUT handler."""
        return self.add_route(path, "PUT")

    def delete(self, path: str) -> Callable:
        """Register a DELETE handler."""
        return self.add_route(path, "DELETE")

    def patch(self, path: str) -> Callable:
        """Register a PATCH handler."""
        return self.add_route(path, "PATCH")

    def head(self, path: str) -> Callable:
        """Register a HEAD handler."""
        return self.add_route(path, "HEAD")

    def options(self, path: str) -> Callable:
        """Register an OPTIONS handler."""
        return self.add_route(path, "OPTIONS")
|
hypern/scheduler.py
ADDED
hypern/worker.py
ADDED
@@ -0,0 +1,274 @@
|
|
1
|
+
import asyncio
|
2
|
+
import logging
|
3
|
+
import os
|
4
|
+
import time
|
5
|
+
import traceback
|
6
|
+
from concurrent.futures import ThreadPoolExecutor
|
7
|
+
from functools import partial, wraps
|
8
|
+
from typing import Callable, Dict, List
|
9
|
+
|
10
|
+
from celery import Celery
|
11
|
+
from celery.result import AsyncResult
|
12
|
+
from celery.signals import (
|
13
|
+
after_setup_logger,
|
14
|
+
after_setup_task_logger,
|
15
|
+
task_failure,
|
16
|
+
task_postrun,
|
17
|
+
task_prerun,
|
18
|
+
worker_ready,
|
19
|
+
worker_shutdown,
|
20
|
+
)
|
21
|
+
from kombu import Exchange, Queue
|
22
|
+
|
23
|
+
|
24
|
+
class Worker(Celery):
    """Celery application preconfigured for hypern background jobs.

    Declares ``default``, ``high_priority`` and ``low_priority`` queues (plus
    any passed in ``queues``). Configures JSON serialization with gzip
    compression and late acknowledgement. Connects signal handlers that log task
    lifecycle events and record per-task start times (read by ``monitor_task``).
    Coroutine functions can be registered with :meth:`task`.
    """

    def __init__(
        self,
        main: str = None,
        broker_url: str = None,
        result_backend: str = "rpc://",
        queues: Dict[str, Dict] = None,
        task_routes: Dict[str, str] = None,
        imports: List[str] = None,
        **kwargs,
    ):
        super().__init__(main, **kwargs)

        # The Celery app object has no ``logger`` attribute, so keep our own.
        self._logger = logging.getLogger("hypern")
        self._executor = ThreadPoolExecutor()
        # task_id -> {"start": epoch seconds}, filled by the prerun handler.
        self._task_timings = {}
        # task pattern -> queue name, maintained by add_task_route().
        self._task_route_mapping: Dict[str, str] = {}

        self.default_exchange = Exchange("default", type="direct")
        self.priority_exchange = Exchange("priority", type="direct")

        default_queues = {
            "default": {"exchange": self.default_exchange, "routing_key": "default"},
            "high_priority": {"exchange": self.priority_exchange, "routing_key": "high"},
            "low_priority": {"exchange": self.priority_exchange, "routing_key": "low"},
        }
        if queues:
            default_queues.update(queues)

        self._queues = {
            name: Queue(
                name,
                exchange=config.get("exchange", self.default_exchange),
                routing_key=config.get("routing_key", name),
                queue_arguments=config.get("arguments", {}),
            )
            for name, config in default_queues.items()
        }

        self.conf.update(
            broker_url=broker_url,
            result_backend=result_backend,
            # Worker Pool Configuration
            worker_pool="solo",
            worker_pool_restarts=True,
            broker_connection_retry_on_startup=True,
            # Worker Configuration
            worker_prefetch_multiplier=1,
            worker_max_tasks_per_child=1000,
            worker_concurrency=os.cpu_count(),
            # Task Settings
            task_acks_late=True,
            task_reject_on_worker_lost=True,
            task_time_limit=3600,
            task_soft_time_limit=3000,
            task_default_retry_delay=300,
            task_max_retries=3,
            # Memory Management
            worker_max_memory_per_child=200000,  # 200MB
            # Task Routing
            task_routes=task_routes,
            task_queues=list(self._queues.values()),
            # Performance Settings
            task_compression="gzip",
            result_compression="gzip",
            task_serializer="json",
            result_serializer="json",
            accept_content=["json"],
            imports=imports,
            task_default_exchange=self.default_exchange.name,
            task_default_routing_key="default",
        )

        self._setup_signals()

    def _setup_signals(self):
        """Connect worker/task lifecycle and logger-setup signal handlers.

        The receivers are local closures. Celery's dispatcher holds receivers
        by weak reference unless ``weak=False`` is passed, and nothing else
        references these functions, so ``weak=False`` keeps them alive.
        """

        @worker_ready.connect(weak=False)
        def on_worker_ready(sender, **kwargs):
            self._logger.info(f"Worker {sender.hostname} is ready")

        @worker_shutdown.connect(weak=False)
        def on_worker_shutdown(sender, **kwargs):
            self._logger.info(f"Worker {sender.hostname} is shutting down")
            self._executor.shutdown(wait=True)

        @task_prerun.connect(weak=False)
        def task_prerun_handler(task_id, task, *args, **kwargs):
            self._task_timings[task_id] = {"start": time.time()}
            self._logger.info(f"Task {task.name}[{task_id}] started")

        @task_postrun.connect(weak=False)
        def task_postrun_handler(task_id, task, *args, retval=None, **kwargs):
            if task_id in self._task_timings:
                start_time = self._task_timings[task_id]["start"]
                duration = time.time() - start_time
                self._logger.info(f"Task {task.name}[{task_id}] completed in {duration:.2f}s")
                del self._task_timings[task_id]

        @task_failure.connect(weak=False)
        def task_failure_handler(sender=None, task_id=None, exception=None, einfo=None, **kwargs):
            # Celery sends the task as ``sender`` and the error as ``exception``;
            # ``einfo`` carries the formatted traceback when available.
            task_name = getattr(sender, "name", sender)
            details = str(einfo) if einfo is not None else traceback.format_exc()
            self._logger.error(f"Task {task_name}[{task_id}] failed: {exception}\n{details}")

        @after_setup_logger.connect(weak=False)
        def setup_celery_logger(logger, *args, **kwargs):
            existing_logger = logging.getLogger("hypern")
            logger.handlers = existing_logger.handlers
            logger.filters = existing_logger.filters
            logger.level = existing_logger.level

        @after_setup_task_logger.connect(weak=False)
        def setup_task_logger(logger, *args, **kwargs):
            existing_logger = logging.getLogger("hypern")
            logger.handlers = existing_logger.handlers
            logger.filters = existing_logger.filters
            logger.level = existing_logger.level

    def create_queue(self, name: str, exchange: Exchange = None, routing_key: str = None, arguments: Dict = None) -> Queue:
        """Declare an additional queue so tasks can be routed to it.

        Defaults mirror the queues built in ``__init__``: the ``default``
        exchange, and a routing key equal to the queue name.

        Example:
            app.create_queue('email_queue')
            app.add_task_route('tasks.email.*', 'email_queue')
        """
        queue = Queue(
            name,
            exchange=exchange if exchange is not None else self.default_exchange,
            routing_key=routing_key if routing_key is not None else name,
            queue_arguments=arguments or {},
        )
        self._queues[name] = queue
        self.conf.task_queues = list(self._queues.values())
        return queue

    def add_task_routes(self, routes: Dict[str, str]) -> None:
        """
        Add several task routes at once (see ``add_task_route``).

        Example:
            app.add_task_routes({
                'tasks.email.*': 'email_queue',
                'tasks.payment.process': 'payment_queue',
                'tasks.high_priority.*': 'high_priority'
            })
        """
        for task_pattern, queue in routes.items():
            self.add_task_route(task_pattern, queue)

    def add_task_route(self, task_pattern: str, queue: str) -> None:
        """
        Add a task route to the Celery app.

        Raises:
            ValueError: If ``queue`` has not been declared.

        Example:
            app.add_task_route('tasks.email.send', 'email_queue')
            app.add_task_route('tasks.payment.*', 'payment_queue')
        """
        if queue not in self._queues:
            raise ValueError(f"Queue '{queue}' does not exist. Create it first using create_queue()")

        self._task_route_mapping[task_pattern] = queue

        # Copy so a routes dict supplied by the caller is not mutated in place.
        routes = dict(self.conf.task_routes or {})
        routes[task_pattern] = {"queue": queue}
        self.conf.task_routes = routes

        self._logger.info(f"Added route: {task_pattern} -> {queue}")

    def task(self, *args, **opts):
        """
        Register a task; supports both regular and ``async def`` functions.

        Usable as ``@app.task``, ``@app.task()`` or ``@app.task(name=..., ...)``.
        A coroutine function is wrapped in a synchronous callable that runs it to
        completion on a fresh event loop. ``functools.wraps`` keeps the
        function's ``__name__``/``__module__``, which Celery uses to derive the
        task name; without it, every async task would be named ``wrapped``.
        """
        base_task = Celery.task.__get__(self)

        # Bare ``@app.task``: the decorated function is the only positional arg.
        bare_func = None
        if len(args) == 1 and callable(args[0]):
            bare_func, args = args[0], ()

        def decorator(func):
            if not asyncio.iscoroutinefunction(func):
                return base_task(*args, **opts)(func)

            @wraps(func)
            def wrapped(*fargs, **fkwargs):
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
                try:
                    return loop.run_until_complete(func(*fargs, **fkwargs))
                finally:
                    # Don't leave a closed loop installed as the current loop.
                    asyncio.set_event_loop(None)
                    loop.close()

            return base_task(*args, **opts)(wrapped)

        if bare_func is not None:
            return decorator(bare_func)
        return decorator

    async def async_send_task(self, task_name: str, *args, **kwargs) -> AsyncResult:
        """
        Version of send_task() that is async: the blocking publish runs in the executor.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self._executor, partial(self.send_task, task_name, args=args, kwargs=kwargs))

    async def async_result(self, task_id: str) -> Dict:
        """
        Get the result of a task asynchronously.

        The result attributes are read in the executor because they may hit the
        result backend.
        """
        async_result = self.AsyncResult(task_id)
        loop = asyncio.get_running_loop()

        result = await loop.run_in_executor(
            self._executor,
            lambda: {
                "task_id": task_id,
                "status": async_result.status,
                "result": async_result.result,
                "traceback": async_result.traceback,
                "date_done": async_result.date_done,
            },
        )
        return result

    def get_queue_length(self, queue_name: str) -> int:
        """
        Get the number of messages in a queue (passive declare; the queue must exist).
        """
        with self.connection_or_acquire() as conn:
            channel = conn.channel()
            try:
                declared = Queue(queue_name, channel=channel).queue_declare(passive=True)
            finally:
                # Close the channel we opened instead of leaking it.
                channel.close()
            return declared.message_count

    async def chain_tasks(self, tasks: list) -> AsyncResult:
        """
        Chain task signatures (``s1 | s2 | ...``) and apply the chain asynchronously.

        Raises:
            ValueError: If ``tasks`` is empty.
        """
        if not tasks:
            raise ValueError("chain_tasks() requires at least one task signature")
        chain = tasks[0]
        for task in tasks[1:]:
            chain = chain | task
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self._executor, chain.apply_async)

    def register_task_middleware(self, middleware: Callable):
        """
        Intercept task registration with ``middleware``.

        After this call, each ``app.task(...)`` invocation calls
        ``middleware(original_task_decorator, *args, **kwargs)`` and returns its
        result. The middleware runs when tasks are *registered*, not on each task
        execution, and must return what ``task`` would.
        """

        def task_middleware(task):
            @wraps(task)
            def _wrapped(*args, **kwargs):
                return middleware(task, *args, **kwargs)

            return _wrapped

        self.task = task_middleware(self.task)

    def monitor_task(self, task_id: str) -> dict:
        """
        Get monitoring data for a task.

        Timing data is only present in the process whose signal handlers saw the
        task start.
        """
        result = self.AsyncResult(task_id)
        timing_info = self._task_timings.get(task_id, {})

        monitoring_data = {
            "task_id": task_id,
            "status": result.status,
            "start_time": timing_info.get("start"),
            "duration": time.time() - timing_info["start"] if timing_info.get("start") else None,
            "result": result.result if result.ready() else None,
            "traceback": result.traceback,
        }

        return monitoring_data
|
hypern/ws/__init__.py
ADDED
hypern/ws/channel.py
ADDED
@@ -0,0 +1,80 @@
|
|
1
|
+
from dataclasses import dataclass, field
|
2
|
+
from typing import Any, Awaitable, Callable, Dict, Set
|
3
|
+
|
4
|
+
from hypern.hypern import WebSocketSession
|
5
|
+
|
6
|
+
|
7
|
+
@dataclass
class Channel:
    """A named pub/sub channel holding WebSocket subscribers and event handlers.

    Attributes (all dataclass fields):
        name: Channel name, included in every published message.
        subscribers: Sessions that receive ``publish`` broadcasts.
        handlers: Event name -> handler, registered via ``on``.
    """

    name: str
    subscribers: Set[WebSocketSession] = field(default_factory=set)
    handlers: Dict[str, Callable[[WebSocketSession, Any], Awaitable[None]]] = field(default_factory=dict)

    def publish(self, event: str, data: Any, publisher: WebSocketSession = None):
        """Publish an event to all subscribers except the publisher.

        Each subscriber is sent ``{"channel": name, "event": event, "data": data}``.
        """
        for subscriber in self.subscribers:
            if subscriber != publisher:
                subscriber.send({"channel": self.name, "event": event, "data": data})

    def handle_event(self, event: str, session: WebSocketSession, data: Any):
        """Invoke the handler registered for ``event``, if any.

        Returns:
            The handler's return value, or ``None`` when no handler is registered.
            Handlers are declared as returning awaitables. Returning the value
            (previously discarded) lets async callers ``await`` it, so the
            coroutine actually runs.
        """
        if event in self.handlers:
            return self.handlers[event](session, data)
        return None

    def add_subscriber(self, subscriber: WebSocketSession):
        """Add a subscriber to the channel (no-op if already subscribed)."""
        self.subscribers.add(subscriber)

    def remove_subscriber(self, subscriber: WebSocketSession):
        """Remove a subscriber from the channel (no-op if absent)."""
        self.subscribers.discard(subscriber)

    def on(self, event: str):
        """Decorator registering ``handler`` for ``event``; replaces any previous one."""

        def decorator(handler: Callable[[WebSocketSession, Any], Awaitable[None]]):
            self.handlers[event] = handler
            return handler

        return decorator
|
40
|
+
|
41
|
+
|
42
|
+
class ChannelManager:
    """Track channels by name and the channel names each client is subscribed to."""

    def __init__(self):
        # channel name -> Channel
        self.channels: Dict[str, Channel] = {}
        # client session -> names of channels it is subscribed to (never empty)
        self.client_channels: Dict[WebSocketSession, Set[str]] = {}

    def create_channel(self, channel_name: str) -> Channel:
        """Return the channel called ``channel_name``, creating it if it doesn't exist."""
        if channel_name not in self.channels:
            self.channels[channel_name] = Channel(channel_name)
        return self.channels[channel_name]

    def get_channel(self, channel_name: str) -> Channel:
        """Get a channel by name; returns ``None`` if it does not exist."""
        return self.channels.get(channel_name)

    def subscribe(self, client: WebSocketSession, channel_name: str):
        """Subscribe a client to a channel, creating the channel if needed."""
        channel = self.create_channel(channel_name)
        channel.add_subscriber(client)
        self.client_channels.setdefault(client, set()).add(channel_name)

    def unsubscribe(self, client: WebSocketSession, channel_name: str):
        """Unsubscribe a client from a channel.

        When the client's last subscription is removed its entry is dropped, so
        the session object is not retained indefinitely.
        """
        channel = self.get_channel(channel_name)
        if channel:
            channel.remove_subscriber(client)

        subscribed = self.client_channels.get(client)
        if subscribed is not None:
            subscribed.discard(channel_name)
            if not subscribed:
                del self.client_channels[client]

    def unsubscribe_all(self, client: WebSocketSession):
        """Unsubscribe a client from all channels and forget the client."""
        channel_names = self.client_channels.get(client)
        if channel_names is None:
            return
        # Iterate a copy: unsubscribe() mutates (and may delete) the set.
        for channel_name in list(channel_names):
            self.unsubscribe(client, channel_name)
        # unsubscribe() already removes the entry once it is empty.
        self.client_channels.pop(client, None)
|