lorax-arg 0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lorax/buffer.py +43 -0
- lorax/cache/__init__.py +43 -0
- lorax/cache/csv_tree_graph.py +59 -0
- lorax/cache/disk.py +467 -0
- lorax/cache/file_cache.py +142 -0
- lorax/cache/file_context.py +72 -0
- lorax/cache/lru.py +90 -0
- lorax/cache/tree_graph.py +293 -0
- lorax/cli.py +312 -0
- lorax/cloud/__init__.py +0 -0
- lorax/cloud/gcs_utils.py +205 -0
- lorax/constants.py +66 -0
- lorax/context.py +80 -0
- lorax/csv/__init__.py +7 -0
- lorax/csv/config.py +250 -0
- lorax/csv/layout.py +182 -0
- lorax/csv/newick_tree.py +234 -0
- lorax/handlers.py +998 -0
- lorax/lineage.py +456 -0
- lorax/loaders/__init__.py +0 -0
- lorax/loaders/csv_loader.py +10 -0
- lorax/loaders/loader.py +31 -0
- lorax/loaders/tskit_loader.py +119 -0
- lorax/lorax_app.py +75 -0
- lorax/manager.py +58 -0
- lorax/metadata/__init__.py +0 -0
- lorax/metadata/loader.py +426 -0
- lorax/metadata/mutations.py +146 -0
- lorax/modes.py +190 -0
- lorax/pg.py +183 -0
- lorax/redis_utils.py +30 -0
- lorax/routes.py +137 -0
- lorax/session_manager.py +206 -0
- lorax/sockets/__init__.py +55 -0
- lorax/sockets/connection.py +99 -0
- lorax/sockets/debug.py +47 -0
- lorax/sockets/decorators.py +112 -0
- lorax/sockets/file_ops.py +200 -0
- lorax/sockets/lineage.py +307 -0
- lorax/sockets/metadata.py +232 -0
- lorax/sockets/mutations.py +154 -0
- lorax/sockets/node_search.py +535 -0
- lorax/sockets/tree_layout.py +117 -0
- lorax/sockets/utils.py +10 -0
- lorax/tree_graph/__init__.py +12 -0
- lorax/tree_graph/tree_graph.py +689 -0
- lorax/utils.py +124 -0
- lorax_app/__init__.py +4 -0
- lorax_app/app.py +159 -0
- lorax_app/cli.py +114 -0
- lorax_app/static/X.png +0 -0
- lorax_app/static/assets/index-BCEGlUFi.js +2361 -0
- lorax_app/static/assets/index-iKjzUpA9.css +1 -0
- lorax_app/static/assets/localBackendWorker-BaWwjSV_.js +2 -0
- lorax_app/static/assets/renderDataWorker-BKLdiU7J.js +2 -0
- lorax_app/static/gestures/gesture-flick.ogv +0 -0
- lorax_app/static/gestures/gesture-two-finger-scroll.ogv +0 -0
- lorax_app/static/index.html +14 -0
- lorax_app/static/logo.png +0 -0
- lorax_app/static/lorax-logo.png +0 -0
- lorax_app/static/vite.svg +1 -0
- lorax_arg-0.1.dist-info/METADATA +131 -0
- lorax_arg-0.1.dist-info/RECORD +66 -0
- lorax_arg-0.1.dist-info/WHEEL +5 -0
- lorax_arg-0.1.dist-info/entry_points.txt +4 -0
- lorax_arg-0.1.dist-info/top_level.txt +2 -0
lorax/session_manager.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
1
|
+
|
|
2
|
+
import os
|
|
3
|
+
import json
|
|
4
|
+
from uuid import uuid4
|
|
5
|
+
from datetime import datetime, timezone
|
|
6
|
+
from typing import Optional, Dict, Tuple
|
|
7
|
+
|
|
8
|
+
from lorax.redis_utils import create_redis_client
|
|
9
|
+
from fastapi import Request, Response
|
|
10
|
+
|
|
11
|
+
from lorax.constants import (
|
|
12
|
+
SESSION_COOKIE,
|
|
13
|
+
COOKIE_MAX_AGE,
|
|
14
|
+
MAX_SOCKETS_PER_SESSION,
|
|
15
|
+
ENFORCE_CONNECTION_LIMITS,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
def _is_https_request(request: Request) -> bool:
    """
    Determine whether the *original* request was HTTPS.

    In production, Lorax typically runs behind an HTTPS load balancer / proxy.
    In that setup, the app may see an internal hop as http://, but the proxy
    sets X-Forwarded-Proto=https (or Forwarded: proto=https).

    Resolution order:
      1. ``X-Forwarded-Proto`` - the first (left-most) value decides.
      2. ``Forwarded`` - the first ``proto=`` parameter found, scanning the
         comma-separated elements left to right, decides. The value may be a
         bare token or a quoted string (``proto="https"``).
      3. The scheme of the URL the application itself received.

    Returns:
        True if the selected source says ``https``, False otherwise.
    """
    # NOTE(review): both headers are supplied by whoever talks to the app; this
    # assumes the fronting proxy overwrites them - confirm for the deployment.
    xf_proto = (request.headers.get("x-forwarded-proto") or "").split(",")[0].strip().lower()
    if xf_proto:
        return xf_proto == "https"

    forwarded = request.headers.get("forwarded") or ""
    # Each element is a ';'-separated list of name=value pairs. A plain substring
    # test would miss the quoted form and would also match a later hop or a
    # value that merely starts with "https".
    for element in forwarded.split(","):
        for pair in element.split(";"):
            name, sep, value = pair.partition("=")
            if sep and name.strip().lower() == "proto":
                return value.strip().strip('"').lower() == "https"

    return request.url.scheme == "https"
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class Session:
    """Per-user session with socket connection tracking.

    All state is kept as JSON-friendly values (strings, dicts) so a session can
    be round-tripped through ``to_dict`` / ``from_dict``. Timestamps are
    ISO-8601 strings produced from timezone-aware UTC datetimes.
    """

    def __init__(self, sid, file_path=None, socket_connections=None, last_activity=None):
        """
        Args:
            sid: Session identifier.
            file_path: Path of the file associated with the session, if any.
            socket_connections: Existing ``{socket_sid: connected_at}`` mapping;
                a falsy value starts the session with no sockets.
            last_activity: ISO timestamp of the last activity; defaults to the
                creation time.
        """
        self.sid = sid
        self.file_path = file_path
        self.created_at = datetime.now(timezone.utc).isoformat()
        self.last_activity = last_activity or self.created_at
        # socket_connections: {socket_sid: connected_at_iso_string}
        self.socket_connections: Dict[str, str] = socket_connections or {}

    def to_dict(self):
        """Return the session state as a plain dict (the inverse of ``from_dict``)."""
        return {
            "sid": self.sid,
            "file_path": self.file_path,
            "created_at": self.created_at,
            "last_activity": self.last_activity,
            "socket_connections": self.socket_connections,
        }

    @staticmethod
    def from_dict(data):
        """Rebuild a session from a dict produced by ``to_dict``.

        Only ``"sid"`` is required (``KeyError`` if absent). A stored
        ``created_at`` replaces the freshly generated one, and when
        ``last_activity`` is missing it falls back to that stored creation time
        rather than to the moment of deserialization.
        """
        session = Session(
            sid=data["sid"],
            file_path=data.get("file_path"),
            socket_connections=data.get("socket_connections"),
            last_activity=data.get("last_activity"),
        )
        stored_created_at = data.get("created_at")
        if stored_created_at:
            session.created_at = stored_created_at
            if not data.get("last_activity"):
                session.last_activity = stored_created_at
        return session

    def update_activity(self):
        """Update last activity timestamp."""
        self.last_activity = datetime.now(timezone.utc).isoformat()

    def add_socket(self, socket_sid: str) -> Optional[str]:
        """
        Add a socket connection.

        If limits are enforced and the session is already at the limit, the
        oldest tracked socket is dropped to make room. Re-adding a socket that
        is already tracked only refreshes its timestamp and never evicts
        anything, so the returned sid can never be ``socket_sid`` itself.

        Returns:
            socket_sid of oldest connection to replace, or None if under limit
        """
        self.update_activity()

        replaced_sid = None
        if (
            socket_sid not in self.socket_connections
            and self.socket_connections  # nothing to evict from an empty mapping
            and ENFORCE_CONNECTION_LIMITS
            and len(self.socket_connections) >= MAX_SOCKETS_PER_SESSION
        ):
            # ISO strings from the same clock/timezone sort chronologically,
            # so the smallest timestamp is the oldest connection.
            replaced_sid = min(self.socket_connections, key=self.socket_connections.get)
            del self.socket_connections[replaced_sid]

        self.socket_connections[socket_sid] = datetime.now(timezone.utc).isoformat()
        return replaced_sid

    def remove_socket(self, socket_sid: str):
        """Remove a socket connection (a no-op for unknown sids) and record activity."""
        self.socket_connections.pop(socket_sid, None)
        self.update_activity()

    def get_socket_count(self) -> int:
        """Get current socket connection count."""
        return len(self.socket_connections)

    def is_at_connection_limit(self) -> bool:
        """Check if session is at connection limit (always False when limits are off)."""
        if not ENFORCE_CONNECTION_LIMITS:
            return False
        return len(self.socket_connections) >= MAX_SOCKETS_PER_SESSION
|
|
114
|
+
|
|
115
|
+
class SessionManager:
    """Create, load and persist ``Session`` objects.

    Sessions are stored in Redis under ``sessions:<sid>`` when a Redis client is
    available (injected, or created from ``redis_url``); otherwise they are kept
    in a process-local dict.
    """

    def __init__(
        self,
        redis_url: Optional[str] = None,
        *,
        redis_client=None,
        redis_cluster: bool = False,
    ):
        """
        Args:
            redis_url: URL used to build a client when none is injected.
            redis_client: Pre-built client; takes precedence over ``redis_url``.
                Its methods are awaited, so it is presumably an asyncio client.
            redis_cluster: Forwarded to ``create_redis_client`` as ``cluster``.
        """
        self.redis_url = redis_url
        self.redis_client = redis_client
        self.redis_cluster = redis_cluster
        self.memory_sessions: Dict[str, Session] = {}

        if self.redis_client is None and self.redis_url:
            self.redis_client = create_redis_client(
                self.redis_url,
                decode_responses=True,
                cluster=self.redis_cluster,
            )

        if self.redis_client:
            # The URL may embed credentials (redis://user:password@host), so
            # only a redacted form is ever written to the log.
            print(f"✅ SessionManager using Redis at {self._redact_url(self.redis_url)}")
        else:
            print("⚠️ SessionManager running in in-memory mode")

    @staticmethod
    def _redact_url(url: Optional[str]) -> str:
        """Return *url* without any ``user:password@`` part, for logging only."""
        if not url:
            # A client was injected directly; there is no URL to show.
            return "<injected client>"
        from urllib.parse import urlsplit, urlunsplit  # local: only needed here

        parts = urlsplit(url)
        if "@" not in parts.netloc:
            return url
        return urlunsplit(parts._replace(netloc=parts.netloc.rsplit("@", 1)[-1]))

    async def get_session(self, sid: str) -> Optional[Session]:
        """Retrieve a session by SID.

        Returns None when ``sid`` is empty, when nothing is stored for it, or
        when the stored Redis payload cannot be decoded into a session.
        """
        if not sid:
            return None

        if not self.redis_client:
            return self.memory_sessions.get(sid)

        data = await self.redis_client.get(f"sessions:{sid}")
        if not data:
            return None
        try:
            return Session.from_dict(json.loads(data))
        except (ValueError, KeyError, TypeError) as exc:
            # A corrupt entry is treated like an expired one so the caller can
            # recover by creating a new session instead of failing the request.
            print(f"⚠️ Ignoring unreadable session payload for {sid}: {exc}")
            return None

    async def create_session(self, sid: Optional[str] = None) -> Session:
        """Create and persist a new session.

        A random UUID4 is used when ``sid`` is not provided. An existing
        session stored under the same ``sid`` is overwritten.
        """
        if not sid:
            sid = str(uuid4())

        session = Session(sid)
        await self.save_session(session)
        return session

    async def save_session(self, session: Session) -> None:
        """Persist session state.

        In Redis mode the entry is written with ``setex`` and ``COOKIE_MAX_AGE``
        as its expiry, so every save restarts the expiry. In-memory entries are
        never expired by this class.
        """
        if self.redis_client:
            await self.redis_client.setex(
                f"sessions:{session.sid}",
                COOKIE_MAX_AGE,
                json.dumps(session.to_dict()),
            )
        else:
            self.memory_sessions[session.sid] = session

    async def get_or_create_session(self, request: Request, response: Response) -> Tuple[str, Session]:
        """Helper to handle cookie extraction and setting.

        Looks up the session named by the request's session cookie. If there is
        no cookie or no stored session, a new session is created and its cookie
        is set on ``response``.

        Returns:
            ``(sid, session)`` for the existing or newly created session.
        """
        sid = request.cookies.get(SESSION_COOKIE)
        # Cross-site usage (e.g. lorax.ucsc.edu -> api.lorax.in) requires SameSite=None,
        # and browsers require SameSite=None cookies to be Secure.
        # Detect original HTTPS behind proxies via X-Forwarded-Proto / Forwarded.
        secure = _is_https_request(request)

        session = None
        if sid:
            session = await self.get_session(sid)

        if not session:
            # Create new if missing or expired
            session = await self.create_session()
            # For local HTTP dev, fall back to Lax so the cookie is accepted.
            samesite = "none" if secure else "lax"
            response.set_cookie(
                key=SESSION_COOKIE,
                value=session.sid,
                httponly=True,
                samesite=samesite,
                max_age=COOKIE_MAX_AGE,
                secure=secure,
            )

        return session.sid, session

    async def health_check(self):
        """Return the result of a Redis ``ping`` in Redis mode, True in in-memory mode."""
        if self.redis_client:
            return await self.redis_client.ping()
        return True
|
|
lorax/sockets/__init__.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Lorax Socket Event Handlers.
|
|
3
|
+
|
|
4
|
+
This package provides modularized socket event handling:
|
|
5
|
+
- connection: connect, disconnect, ping events
|
|
6
|
+
- file_ops: load_file, details, query events
|
|
7
|
+
- tree_layout: process_postorder_layout, cache_trees events
|
|
8
|
+
- metadata: fetch_metadata_*, search_metadata events
|
|
9
|
+
- mutations: query_mutations_window, search_mutations events
|
|
10
|
+
- node_search: search_nodes, get_highlight_positions events
|
|
11
|
+
- lineage: ancestors, descendants, mrca, subtree events
|
|
12
|
+
- debug: cache_stats events
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from lorax.sockets.utils import is_csv_session_file
|
|
16
|
+
from lorax.sockets.decorators import (
|
|
17
|
+
require_session,
|
|
18
|
+
with_session,
|
|
19
|
+
with_file_loaded,
|
|
20
|
+
csv_not_supported,
|
|
21
|
+
socket_error_handler,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
# Will be populated by individual modules after split
|
|
25
|
+
from lorax.sockets.connection import register_connection_events
|
|
26
|
+
from lorax.sockets.file_ops import register_file_events
|
|
27
|
+
from lorax.sockets.tree_layout import register_tree_layout_events
|
|
28
|
+
from lorax.sockets.metadata import register_metadata_events
|
|
29
|
+
from lorax.sockets.mutations import register_mutations_events
|
|
30
|
+
from lorax.sockets.node_search import register_node_search_events
|
|
31
|
+
from lorax.sockets.lineage import register_lineage_events
|
|
32
|
+
from lorax.sockets.debug import register_debug_events
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def register_socket_events(sio):
    """Attach every Lorax socket event group to the server object *sio*.

    Each group's ``register_*`` function is called once with ``sio``, in a
    fixed order starting with the connection events.
    """
    registrars = (
        register_connection_events,
        register_file_events,
        register_tree_layout_events,
        register_metadata_events,
        register_mutations_events,
        register_node_search_events,
        register_lineage_events,
        register_debug_events,
    )
    for register in registrars:
        register(sio)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
# Public API of the package: the single registration entry point plus the
# helpers re-exported from lorax.sockets.utils and lorax.sockets.decorators.
__all__ = [
    "register_socket_events",
    "is_csv_session_file",
    "require_session",
    "with_session",
    "with_file_loaded",
    "csv_not_supported",
    "socket_error_handler",
]
|
|
lorax/sockets/connection.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Connection event handlers for Lorax Socket.IO.
|
|
3
|
+
|
|
4
|
+
Handles connect, disconnect, and ping events.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from datetime import datetime
|
|
8
|
+
from http.cookies import SimpleCookie
|
|
9
|
+
|
|
10
|
+
from lorax.context import session_manager
|
|
11
|
+
from lorax.constants import ERROR_SESSION_NOT_FOUND, ERROR_CONNECTION_REPLACED
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
# Mapping from socket_sid to lorax_sid for disconnect handling.
# Module-level dict: an entry is written by the `connect` handler and popped
# by the `disconnect` handler, which receives only the socket sid.
_socket_to_session: dict = {}
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def register_connection_events(sio):
    """Register connection-related socket events.

    Defines three handlers and registers them on *sio* with ``@sio.event``:
    ``connect``, ``disconnect`` and ``ping``.
    """
    # Function-scope imports so this block does not depend on edits to the
    # module's import section.
    from datetime import timezone
    from http.cookies import CookieError

    @sio.event
    async def connect(sid, environ, auth=None):
        """Bind a new socket to the session named by its ``lorax_sid`` cookie.

        Emits ``error`` to the socket when the cookie or the session is missing.
        Otherwise records the socket on the session, disconnects any socket it
        replaced, and emits ``session-restored`` (a file is loaded) or
        ``status``.
        """
        print(f"🔌 Socket.IO connected: {sid}")

        cookies = SimpleCookie()
        try:
            cookies.load(environ.get("HTTP_COOKIE", ""))
        except CookieError as e:
            # SimpleCookie rejects some cookie names that browsers still send.
            # An unrelated bad cookie must not abort the handshake with an
            # exception, so continue with whatever was parsed.
            print(f"⚠️ Ignoring malformed Cookie header for socket {sid}: {e}")

        # NOTE(review): hard-coded cookie name; presumably must match
        # lorax.constants.SESSION_COOKIE - confirm.
        lorax_sid_cookie = cookies.get("lorax_sid")
        session_id = lorax_sid_cookie.value if lorax_sid_cookie else None

        if not session_id:
            print(f"⚠️ No lorax_sid cookie found for socket {sid}")
            await sio.emit("error", {
                "code": ERROR_SESSION_NOT_FOUND,
                "message": "Session not found. Please refresh the page."
            }, to=sid)
            return

        # Validate session exists
        session = await session_manager.get_session(session_id)
        if not session:
            print(f"⚠️ Session not found: {session_id}")
            await sio.emit("error", {
                "code": ERROR_SESSION_NOT_FOUND,
                "message": "Session expired. Please refresh the page."
            }, to=sid)
            return

        # Track this socket connection
        replaced_socket_sid = session.add_socket(sid)
        await session_manager.save_session(session)

        # Store mapping for disconnect handling
        _socket_to_session[sid] = session_id

        # If we replaced an old connection, notify it
        if replaced_socket_sid:
            print(f"🔄 Replacing old socket {replaced_socket_sid} with new socket {sid}")
            await sio.emit("connection-replaced", {
                "code": ERROR_CONNECTION_REPLACED,
                "message": "This connection was replaced by a newer tab. Please use the new tab.",
            }, to=replaced_socket_sid)
            # Disconnect the old socket; best effort, a failure here must not
            # prevent the new socket from being set up.
            try:
                await sio.disconnect(replaced_socket_sid)
            except Exception as e:
                print(f"Warning: Failed to disconnect old socket: {e}")

        # Send session state
        if session.file_path:
            await sio.emit("session-restored", {
                "lorax_sid": session_id,
                "file_path": session.file_path
            }, to=sid)
        else:
            await sio.emit("status", {
                "message": "Connected",
                "lorax_sid": session_id
            }, to=sid)

    @sio.event
    async def disconnect(sid):
        """Forget the socket and remove it from its session's tracked sockets."""
        print(f"🔌 Socket.IO disconnected: {sid}")

        # Remove socket from session tracking
        session_id = _socket_to_session.pop(sid, None)
        if session_id:
            session = await session_manager.get_session(session_id)
            if session:
                session.remove_socket(sid)
                await session_manager.save_session(session)

    @sio.event
    async def ping(sid, data):
        """Reply with ``pong`` carrying the current UTC time as an ISO-8601 string."""
        await sio.emit("pong", {
            "type": "pong",
            # Timezone-aware, so the string carries an explicit +00:00 offset
            # (datetime.utcnow() is naive and deprecated since Python 3.12).
            "time": datetime.now(timezone.utc).isoformat()
        }, to=sid)
|
lorax/sockets/debug.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Debug event handlers for Lorax Socket.IO.
|
|
3
|
+
|
|
4
|
+
Handles cache statistics and debugging events.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from lorax.context import tree_graph_cache
|
|
8
|
+
from lorax.sockets.decorators import require_session
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def register_debug_events(sio):
    """Register debug-related socket events."""

    @sio.event
    async def get_cache_stats(sid, data):
        """Report TreeGraph cache statistics for debugging.

        data: {
            lorax_sid: str
        }

        Returns: {
            session_trees: int,         # number of trees cached for this session
            cached_tree_indices: list,  # keys of the session's cached trees
            stats: ...                  # whatever tree_graph_cache.get_stats() returns
        }
        or {"error": str} when the session is missing or anything raises.
        """
        try:
            session_key = data.get("lorax_sid")
            if not await require_session(session_key, sid, sio):
                return {"error": "Session not found"}

            # Per-session view first, then the cache-wide statistics.
            trees_for_session = await tree_graph_cache.get_all_for_session(session_key)
            cache_wide = tree_graph_cache.get_stats()

            return {
                "session_trees": len(trees_for_session),
                "cached_tree_indices": list(trees_for_session.keys()),
                "stats": cache_wide,
            }
        except Exception as e:
            print(f"❌ Get cache stats error: {e}")
            return {"error": str(e)}
|
|
lorax/sockets/decorators.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Socket event decorators for Lorax.
|
|
3
|
+
|
|
4
|
+
Provides common patterns as decorators to reduce boilerplate:
|
|
5
|
+
- Session validation
|
|
6
|
+
- File loaded checks
|
|
7
|
+
- CSV not supported handling
|
|
8
|
+
- Error wrapping
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import functools
|
|
12
|
+
from typing import Callable, Any
|
|
13
|
+
|
|
14
|
+
from lorax.context import session_manager
|
|
15
|
+
from lorax.constants import ERROR_SESSION_NOT_FOUND, ERROR_NO_FILE_LOADED
|
|
16
|
+
from lorax.sockets.utils import is_csv_session_file
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
async def require_session(lorax_sid: str, socket_sid: str, sio) -> Any:
    """
    Look up the session for ``lorax_sid``, reporting a miss to the client.

    When no session is found, an ``error`` event carrying
    ``ERROR_SESSION_NOT_FOUND`` is emitted to ``socket_sid``.

    Returns the session if found, None otherwise.
    """
    found = await session_manager.get_session(lorax_sid)
    if found:
        return found

    payload = {
        "code": ERROR_SESSION_NOT_FOUND,
        "message": "Session expired. Please refresh the page.",
    }
    await sio.emit("error", payload, to=socket_sid)
    return None
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def with_session(sio):
    """
    Decorator factory: resolve the session before running a socket handler.

    The registered handler takes (sid, data); ``data`` must provide
    ``lorax_sid``. The wrapped function is called as (sid, data, session).
    If the session cannot be resolved the handler is skipped and None is
    returned (``require_session`` has already notified the client).
    """
    def decorator(func: Callable):
        @functools.wraps(func)
        async def wrapper(sid, data):
            session = await require_session(data.get("lorax_sid"), sid, sio)
            if session:
                return await func(sid, data, session)
            return None
        return wrapper
    return decorator
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def with_file_loaded(sio, error_event: str = "error"):
    """
    Decorator factory: run the handler only when the session has a file.

    Must be used after @with_session; the handler receives (sid, data, session).
    When ``session.file_path`` is empty, ``ERROR_NO_FILE_LOADED`` is emitted on
    ``error_event`` to the calling socket and the handler is not invoked.
    """
    def decorator(func: Callable):
        @functools.wraps(func)
        async def wrapper(sid, data, session):
            if session.file_path:
                return await func(sid, data, session)

            print(f"⚠️ No file loaded for session {data.get('lorax_sid')}")
            payload = {
                "code": ERROR_NO_FILE_LOADED,
                "message": "No file loaded. Please load a file first.",
            }
            await sio.emit(error_event, payload, to=sid)
            return None
        return wrapper
    return decorator
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def csv_not_supported(sio, result_event: str, empty_result: dict = None):
    """
    Decorator that returns early with message if file is CSV.

    Must be used after @with_session. The handler receives (sid, data, session).

    Args:
        sio: Object whose ``emit`` is used to answer the client.
        result_event: Event name the early answer is emitted on.
        empty_result: Payload to emit for CSV sessions. Any dict, including an
            empty one, is sent as given; only ``None`` selects the default
            ``{"error": "<handler name> is not supported for CSV yet."}``.
    """
    def decorator(func: Callable):
        @functools.wraps(func)
        async def wrapper(sid, data, session):
            if is_csv_session_file(session.file_path):
                # Compare against None: a caller-supplied {} is a legitimate
                # "empty" answer and must not be replaced by the error payload.
                if empty_result is not None:
                    result = empty_result
                else:
                    result = {"error": f"{func.__name__} is not supported for CSV yet."}
                await sio.emit(result_event, result, to=sid)
                return
            return await func(sid, data, session)
        return wrapper
    return decorator
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def socket_error_handler(sio, result_event: str = "error"):
    """
    Decorator factory: catch any exception raised by a socket handler.

    On failure the error is printed, ``{"error": <message>}`` is emitted on
    ``result_event`` to the first positional argument (the socket sid) when
    both are truthy, and the same shape is returned so callback-style handlers
    also see it.
    """
    def decorator(func: Callable):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            try:
                return await func(*args, **kwargs)
            except Exception as e:
                print(f"❌ {func.__name__} error: {e}")
                message = str(e)
                # Emit-style handlers get an event; callback-style handlers
                # get the returned dict.
                target = args[0] if args else None
                if result_event and target:
                    await sio.emit(result_event, {"error": message}, to=target)
                return {"error": message}
        return wrapper
    return decorator
|