lorax-arg 0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. lorax/buffer.py +43 -0
  2. lorax/cache/__init__.py +43 -0
  3. lorax/cache/csv_tree_graph.py +59 -0
  4. lorax/cache/disk.py +467 -0
  5. lorax/cache/file_cache.py +142 -0
  6. lorax/cache/file_context.py +72 -0
  7. lorax/cache/lru.py +90 -0
  8. lorax/cache/tree_graph.py +293 -0
  9. lorax/cli.py +312 -0
  10. lorax/cloud/__init__.py +0 -0
  11. lorax/cloud/gcs_utils.py +205 -0
  12. lorax/constants.py +66 -0
  13. lorax/context.py +80 -0
  14. lorax/csv/__init__.py +7 -0
  15. lorax/csv/config.py +250 -0
  16. lorax/csv/layout.py +182 -0
  17. lorax/csv/newick_tree.py +234 -0
  18. lorax/handlers.py +998 -0
  19. lorax/lineage.py +456 -0
  20. lorax/loaders/__init__.py +0 -0
  21. lorax/loaders/csv_loader.py +10 -0
  22. lorax/loaders/loader.py +31 -0
  23. lorax/loaders/tskit_loader.py +119 -0
  24. lorax/lorax_app.py +75 -0
  25. lorax/manager.py +58 -0
  26. lorax/metadata/__init__.py +0 -0
  27. lorax/metadata/loader.py +426 -0
  28. lorax/metadata/mutations.py +146 -0
  29. lorax/modes.py +190 -0
  30. lorax/pg.py +183 -0
  31. lorax/redis_utils.py +30 -0
  32. lorax/routes.py +137 -0
  33. lorax/session_manager.py +206 -0
  34. lorax/sockets/__init__.py +55 -0
  35. lorax/sockets/connection.py +99 -0
  36. lorax/sockets/debug.py +47 -0
  37. lorax/sockets/decorators.py +112 -0
  38. lorax/sockets/file_ops.py +200 -0
  39. lorax/sockets/lineage.py +307 -0
  40. lorax/sockets/metadata.py +232 -0
  41. lorax/sockets/mutations.py +154 -0
  42. lorax/sockets/node_search.py +535 -0
  43. lorax/sockets/tree_layout.py +117 -0
  44. lorax/sockets/utils.py +10 -0
  45. lorax/tree_graph/__init__.py +12 -0
  46. lorax/tree_graph/tree_graph.py +689 -0
  47. lorax/utils.py +124 -0
  48. lorax_app/__init__.py +4 -0
  49. lorax_app/app.py +159 -0
  50. lorax_app/cli.py +114 -0
  51. lorax_app/static/X.png +0 -0
  52. lorax_app/static/assets/index-BCEGlUFi.js +2361 -0
  53. lorax_app/static/assets/index-iKjzUpA9.css +1 -0
  54. lorax_app/static/assets/localBackendWorker-BaWwjSV_.js +2 -0
  55. lorax_app/static/assets/renderDataWorker-BKLdiU7J.js +2 -0
  56. lorax_app/static/gestures/gesture-flick.ogv +0 -0
  57. lorax_app/static/gestures/gesture-two-finger-scroll.ogv +0 -0
  58. lorax_app/static/index.html +14 -0
  59. lorax_app/static/logo.png +0 -0
  60. lorax_app/static/lorax-logo.png +0 -0
  61. lorax_app/static/vite.svg +1 -0
  62. lorax_arg-0.1.dist-info/METADATA +131 -0
  63. lorax_arg-0.1.dist-info/RECORD +66 -0
  64. lorax_arg-0.1.dist-info/WHEEL +5 -0
  65. lorax_arg-0.1.dist-info/entry_points.txt +4 -0
  66. lorax_arg-0.1.dist-info/top_level.txt +2 -0
lorax/lineage.py ADDED
@@ -0,0 +1,456 @@
1
+ """
2
+ Lineage Operations for TreeGraph-based ancestry and descendant tracing.
3
+
4
+ Provides efficient operations using cached TreeGraph objects:
5
+ - Ancestor tracing: Path from node to root
6
+ - Descendant finding: All nodes below a given node
7
+ - Node search: Filter nodes by metadata/attributes
8
+ """
9
+
10
+ import numpy as np
11
+ from typing import List, Dict, Optional, Any, TYPE_CHECKING
12
+ from collections import deque
13
+
14
+ if TYPE_CHECKING:
15
+ from lorax.tree_graph import TreeGraph
16
+ from lorax.tree_graph_cache import TreeGraphCache
17
+
18
+
19
async def get_ancestors(
    tree_graph_cache: "TreeGraphCache",
    session_id: str,
    tree_index: int,
    node_id: int
) -> Dict[str, Any]:
    """
    Trace the ancestry of a node back to the tree root.

    Walks the parent array of the cached TreeGraph until a node whose
    parent is -1 (the root) is reached.

    Args:
        tree_graph_cache: The TreeGraph cache instance
        session_id: Session identifier
        tree_index: Tree index to query
        node_id: Node ID to trace ancestors for

    Returns:
        Dict with:
        - ancestors: node IDs from the query node's parent up to the root
        - path: list of {node_id, time, x, y} dicts for visualization,
          starting with the query node itself
        - error: message when the tree is not cached or the node is invalid
    """
    tg = await tree_graph_cache.get(session_id, tree_index)
    if tg is None:
        return {
            "error": f"Tree {tree_index} not cached. Request tree layout first.",
            "ancestors": [],
            "path": []
        }

    # Reject IDs outside the node table.
    if node_id < 0 or node_id >= len(tg.parent):
        return {
            "error": f"Invalid node_id {node_id}",
            "ancestors": [],
            "path": []
        }

    # The node must participate in this particular local tree.
    if not tg.in_tree[node_id]:
        return {
            "error": f"Node {node_id} is not in tree {tree_index}",
            "ancestors": [],
            "path": []
        }

    def _point(n):
        # The backend stores time on y and layout position on x; the
        # frontend expects the opposite, so the axes are swapped here.
        return {
            "node_id": int(n),
            "time": float(tg.time[n]),
            "x": float(tg.y[n]),
            "y": float(tg.x[n])
        }

    ancestors: List[int] = []
    path = [_point(node_id)]  # path starts at the query node itself

    current = tg.parent[node_id]
    while current != -1:
        ancestors.append(int(current))
        path.append(_point(current))
        current = tg.parent[current]

    return {
        "ancestors": ancestors,
        "path": path,
        "tree_index": tree_index,
        "query_node": node_id
    }
98
+
99
+
100
async def get_descendants(
    tree_graph_cache: "TreeGraphCache",
    session_id: str,
    tree_index: int,
    node_id: int,
    include_tips_only: bool = False
) -> Dict[str, Any]:
    """
    Collect every descendant of a node via breadth-first traversal.

    Reads the CSR children structure of the cached TreeGraph.

    Args:
        tree_graph_cache: The TreeGraph cache instance
        session_id: Session identifier
        tree_index: Tree index to query
        node_id: Node ID to find descendants for
        include_tips_only: If True, "descendants" contains only tip (leaf)
            nodes; "total_descendants" still counts all descendants

    Returns:
        Dict with:
        - descendants: descendant node IDs (tips only when requested)
        - tips: tip node IDs (always present for convenience)
        - error: message when the tree is not cached or the node is invalid
    """
    tg = await tree_graph_cache.get(session_id, tree_index)
    if tg is None:
        return {
            "error": f"Tree {tree_index} not cached. Request tree layout first.",
            "descendants": [],
            "tips": []
        }

    # Reject IDs outside the node table.
    if node_id < 0 or node_id >= len(tg.parent):
        return {
            "error": f"Invalid node_id {node_id}",
            "descendants": [],
            "tips": []
        }

    # The node must participate in this particular local tree.
    if not tg.in_tree[node_id]:
        return {
            "error": f"Node {node_id} is not in tree {tree_index}",
            "descendants": [],
            "tips": []
        }

    descendants: List[int] = []
    tips: List[int] = []

    # Breadth-first walk; deque gives O(1) pops from the left end.
    pending = deque(tg.children(node_id).tolist())
    while pending:
        current = pending.popleft()
        descendants.append(int(current))
        if tg.is_tip(current):
            # Leaf: record it, nothing further to expand.
            tips.append(int(current))
        else:
            pending.extend(tg.children(current).tolist())

    result = {
        "tips": tips,
        "tree_index": tree_index,
        "query_node": node_id,
        "total_descendants": len(descendants)
    }
    result["descendants"] = tips if include_tips_only else descendants
    return result
179
+
180
+
181
async def search_nodes_by_criteria(
    tree_graph_cache: "TreeGraphCache",
    session_id: str,
    tree_index: int,
    criteria: Dict[str, Any]
) -> Dict[str, Any]:
    """
    Find all nodes in a tree that satisfy a set of filter criteria.

    Supported criteria keys:
        min_time / max_time: node-time bounds (inclusive)
        is_tip: True restricts to tips, False to internal nodes
        has_children: inverse of is_tip
        node_ids: explicit list of candidate node IDs

    Args:
        tree_graph_cache: The TreeGraph cache instance
        session_id: Session identifier
        tree_index: Tree index to search
        criteria: Dict of filter criteria

    Returns:
        Dict with:
        - matches: matching node IDs
        - positions: {node_id, x, y, time} per match
        - error: message when the tree is not cached
    """
    tg = await tree_graph_cache.get(session_id, tree_index)
    if tg is None:
        return {
            "error": f"Tree {tree_index} not cached. Request tree layout first.",
            "matches": [],
            "positions": []
        }

    matches: List[int] = []
    positions: List[Dict[str, Any]] = []

    # Candidates are the nodes present in this local tree...
    candidates = np.where(tg.in_tree)[0]

    # ...optionally narrowed to an explicit ID list first.
    wanted = criteria.get("node_ids")
    if wanted is not None:
        wanted_set = set(wanted)
        candidates = [n for n in candidates if n in wanted_set]

    for nid in candidates:
        if not _matches_criteria(tg, nid, criteria):
            continue
        matches.append(int(nid))
        positions.append({
            "node_id": int(nid),
            "x": float(tg.y[nid]),   # time -> frontend x
            "y": float(tg.x[nid]),   # layout -> frontend y
            "time": float(tg.time[nid])
        })

    return {
        "matches": matches,
        "positions": positions,
        "tree_index": tree_index,
        "criteria": criteria,
        "total_matches": len(matches)
    }
246
+
247
+
248
+ def _matches_criteria(tg: "TreeGraph", node_id: int, criteria: Dict[str, Any]) -> bool:
249
+ """
250
+ Check if a node matches all specified criteria.
251
+
252
+ Args:
253
+ tg: TreeGraph object
254
+ node_id: Node ID to check
255
+ criteria: Dict of filter criteria
256
+
257
+ Returns:
258
+ True if node matches all criteria
259
+ """
260
+ # Time range filters
261
+ if "min_time" in criteria:
262
+ if tg.time[node_id] < criteria["min_time"]:
263
+ return False
264
+
265
+ if "max_time" in criteria:
266
+ if tg.time[node_id] > criteria["max_time"]:
267
+ return False
268
+
269
+ # Tip/internal filter
270
+ if "is_tip" in criteria:
271
+ is_tip = tg.is_tip(node_id)
272
+ if criteria["is_tip"] != is_tip:
273
+ return False
274
+
275
+ # Has children filter (inverse of is_tip)
276
+ if "has_children" in criteria:
277
+ has_children = not tg.is_tip(node_id)
278
+ if criteria["has_children"] != has_children:
279
+ return False
280
+
281
+ # Y (layout) position range
282
+ if "min_y" in criteria:
283
+ if tg.x[node_id] < criteria["min_y"]: # x in backend = y in frontend
284
+ return False
285
+
286
+ if "max_y" in criteria:
287
+ if tg.x[node_id] > criteria["max_y"]:
288
+ return False
289
+
290
+ return True
291
+
292
+
293
async def get_subtree(
    tree_graph_cache: "TreeGraphCache",
    session_id: str,
    tree_index: int,
    root_node_id: int
) -> Dict[str, Any]:
    """
    Extract the full subtree rooted at a node, with structure preserved.

    Args:
        tree_graph_cache: The TreeGraph cache instance
        session_id: Session identifier
        tree_index: Tree index to query
        root_node_id: Root of the subtree

    Returns:
        Dict with:
        - nodes: [{node_id, parent_id, x, y, time, is_tip}, ...] in BFS order
        - edges: [{parent, child}, ...]
        - error: message when the tree is not cached or the node is invalid
    """
    tg = await tree_graph_cache.get(session_id, tree_index)
    if tg is None:
        return {
            "error": f"Tree {tree_index} not cached. Request tree layout first.",
            "nodes": [],
            "edges": []
        }

    # Reject IDs outside the node table.
    if root_node_id < 0 or root_node_id >= len(tg.parent):
        return {
            "error": f"Invalid node_id {root_node_id}",
            "nodes": [],
            "edges": []
        }

    # The node must participate in this particular local tree.
    if not tg.in_tree[root_node_id]:
        return {
            "error": f"Node {root_node_id} is not in tree {tree_index}",
            "nodes": [],
            "edges": []
        }

    nodes: List[Dict[str, Any]] = []
    edges: List[Dict[str, int]] = []
    seen = set()

    # Breadth-first walk from the subtree root; the seen-set guards
    # against revisiting a node if it is ever enqueued twice.
    pending = deque([root_node_id])
    while pending:
        nid = pending.popleft()
        if nid in seen:
            continue
        seen.add(nid)

        nodes.append({
            "node_id": int(nid),
            "parent_id": int(tg.parent[nid]),
            "x": float(tg.y[nid]),   # time -> frontend x
            "y": float(tg.x[nid]),   # layout -> frontend y
            "time": float(tg.time[nid]),
            "is_tip": tg.is_tip(nid)
        })

        for child in tg.children(nid):
            edges.append({
                "parent": int(nid),
                "child": int(child)
            })
            pending.append(child)

    return {
        "nodes": nodes,
        "edges": edges,
        "tree_index": tree_index,
        "root_node": root_node_id,
        "total_nodes": len(nodes)
    }
376
+
377
+
378
async def get_mrca(
    tree_graph_cache: "TreeGraphCache",
    session_id: str,
    tree_index: int,
    node_ids: List[int]
) -> Dict[str, Any]:
    """
    Find the Most Recent Common Ancestor (MRCA) of a set of nodes.

    Each node's ancestor set includes the node itself, so the MRCA of an
    (ancestor, descendant) pair is the ancestor.

    NOTE(review): the MRCA is picked as the common ancestor with the
    *largest* tg.time, i.e. this assumes larger time = more recent. Under
    tskit's time-ago convention that would select the oldest instead —
    confirm which convention TreeGraph.time uses.

    Args:
        tree_graph_cache: The TreeGraph cache instance
        session_id: Session identifier
        tree_index: Tree index to query
        node_ids: List of node IDs to find MRCA for

    Returns:
        Dict with:
        - mrca: MRCA node ID, or None on any failure
        - mrca_time: time of the MRCA node
        - mrca_position: {x, y} of the MRCA (frontend axes)
        - error: message on failure
    """
    # Need at least a pair of nodes for a common ancestor to be meaningful.
    if not node_ids or len(node_ids) < 2:
        return {
            "error": "Need at least 2 nodes to find MRCA",
            "mrca": None
        }

    tg = await tree_graph_cache.get(session_id, tree_index)
    if tg is None:
        return {
            "error": f"Tree {tree_index} not cached. Request tree layout first.",
            "mrca": None
        }

    # Validate every query node before doing any tracing.
    n_total = len(tg.parent)
    for node_id in node_ids:
        if node_id < 0 or node_id >= n_total:
            return {"error": f"Invalid node_id {node_id}", "mrca": None}
        if not tg.in_tree[node_id]:
            return {"error": f"Node {node_id} not in tree {tree_index}", "mrca": None}

    def _lineage(nid):
        # Set of nid plus all of its ancestors up to the root.
        seen = set()
        while nid != -1:
            seen.add(nid)
            nid = tg.parent[nid]
        return seen

    # Intersect lineages pairwise to get the shared ancestors.
    common = _lineage(node_ids[0])
    for node_id in node_ids[1:]:
        common &= _lineage(node_id)

    if not common:
        return {
            "error": "No common ancestor found",
            "mrca": None,
            "tree_index": tree_index
        }

    # Most recent = greatest time among the shared ancestors (see NOTE above).
    mrca = max(common, key=lambda n: tg.time[n])

    return {
        "mrca": int(mrca),
        "mrca_time": float(tg.time[mrca]),
        "mrca_position": {
            "x": float(tg.y[mrca]),
            "y": float(tg.x[mrca])
        },
        "tree_index": tree_index,
        "query_nodes": node_ids
    }
lorax/loaders/csv_loader.py ADDED
@@ -0,0 +1,10 @@
1
+ from lorax.csv.config import CsvConfigOptions, build_csv_config
2
+
3
+
4
def get_config_csv(df, file_path, root_dir, window_size=50000):
    """Extract configuration from a Newick-per-row CSV file.

    Thin backwards-compatibility wrapper; the real logic lives in
    `lorax.csv.config` for encapsulation and re-use.

    NOTE(review): ``root_dir`` is accepted but never forwarded to
    ``build_csv_config`` — confirm whether it is still needed.
    """
    options = CsvConfigOptions(window_size=window_size)
    return build_csv_config(df, str(file_path), options=options)
lorax/loaders/loader.py ADDED
@@ -0,0 +1,31 @@
1
+ """
2
+ Config computation for loaded files.
3
+
4
+ Provides compute_config() which creates configuration from a tree sequence.
5
+ The config is cached within FileContext (see cache/file_cache.py), not here.
6
+ """
7
+
8
+ from lorax.loaders.csv_loader import get_config_csv
9
+ from lorax.loaders.tskit_loader import get_config_tskit
10
+
11
+
12
def compute_config(ts, file_path, root_dir):
    """
    Compute config for a tree sequence.

    Called by file_cache.py when loading a new FileContext. The config is
    cached within the FileContext, so this function keeps no cache of its
    own.

    Args:
        ts: tskit.TreeSequence or pandas.DataFrame (for CSV)
        file_path: Path to the source file
        root_dir: Root directory for relative paths

    Returns:
        dict: Configuration including intervals, sample counts, etc.
    """
    # tskit-native formats go through the tskit extractor; anything else
    # is treated as a Newick-per-row CSV. str.endswith accepts a tuple,
    # which replaces the duplicated endswith calls.
    if file_path.endswith(('.tsz', '.trees')):
        return get_config_tskit(ts, file_path, root_dir)
    return get_config_csv(ts, file_path, root_dir)
lorax/loaders/tskit_loader.py ADDED
@@ -0,0 +1,119 @@
1
+ import os
2
+ import tskit
3
+ from lorax.metadata.loader import get_metadata_schema
4
+
5
+
6
+ def _get_project_name(file_path, root_dir):
7
+ """Return the immediate parent directory name for the file."""
8
+ if not file_path:
9
+ return None
10
+ try:
11
+ parent_dir = os.path.basename(os.path.dirname(str(file_path)))
12
+ except Exception:
13
+ parent_dir = None
14
+ if parent_dir:
15
+ return parent_dir
16
+ if root_dir:
17
+ return os.path.basename(os.path.normpath(str(root_dir)))
18
+ return None
19
+
20
def get_config_tskit(ts, file_path, root_dir):
    """Extract configuration and metadata from a tree sequence file.

    Note: Uses get_metadata_schema() for lightweight initial load.
    Full metadata mappings are fetched on-demand via fetch_metadata_for_key.

    Args:
        ts: tskit.TreeSequence to summarize.
        file_path: Path of the source file (used for filename/project labels).
        root_dir: Root directory, fallback source for the project name.

    Returns:
        dict with genome length, initial window, timeline, intervals and
        metadata schema, or None if extraction fails.
    """
    try:
        intervals = list(ts.breakpoints())
        times = [ts.min_time, ts.max_time]
        genome_length = ts.sequence_length

        # Timeline unit label for UI: normalize unknown -> "Coalescent Time"
        time_units = getattr(ts, "time_units", None)
        time_units_str = str(time_units) if time_units is not None else "unknown"
        timeline_type = "Coalescent Time" if time_units_str.strip().lower() == "unknown" else time_units_str

        # Compute centered initial position (10% of genome, minimum 1kb)
        window_size = max(genome_length * 0.1, 1000)
        midpoint = genome_length / 2.0
        start = max(0, midpoint - window_size / 2.0)
        end = min(genome_length, midpoint + window_size / 2.0)

        sample_names = {}
        # Use schema-only extraction for lightweight initial load
        metadata_schema = get_metadata_schema(ts, sources=("individual", "node", "population"))

        filename = os.path.basename(file_path)
        project_name = _get_project_name(file_path, root_dir)

        config = {
            'genome_length': genome_length,
            'initial_position': [int(start), int(end)],
            'times': {'type': timeline_type, 'values': times},
            'intervals': intervals,
            'filename': str(filename),
            'project': project_name,
            # node_times removed - now sent per-query from handle_layout_query for efficiency
            'sample_names': sample_names,
            # Send schema only - full mappings fetched on-demand
            'metadata_schema': metadata_schema
        }
        return config
    except Exception as e:
        # Best-effort contract: callers treat None as "config unavailable".
        # Fixed message to name this function (was "get_config").
        print("Error in get_config_tskit", e)
        return None
67
+
68
def extract_node_mutations_tables(ts):
    """Extract mutations keyed by genomic position for UI display.

    Returns:
        dict: {str(int(position)): {"mutation": "A->G", "node": node_id}},
        skipping silent mutations where ancestral == derived state.
        NOTE(review): relies on ``ts.sites_ancestral_state`` and
        ``ts.mutations_derived_state`` being per-mutation array-likes —
        confirm against the tskit version in use.
    """
    tables = ts.tables
    sites, mutations = tables.sites, tables.mutations

    # One entry per mutation: the position of the site it occurred at.
    positions = sites.position[mutations.site]
    ancestral = ts.sites_ancestral_state
    derived = ts.mutations_derived_state

    result = {}
    for position, anc, der, node in zip(positions, ancestral, derived, mutations.node):
        if anc == der:
            # Silent mutation: no state change worth displaying.
            continue
        result[str(int(position))] = {
            "mutation": f"{anc}->{der}",
            "node": int(node)
        }

    return result
90
+
91
+
92
def extract_mutations_by_node(ts):
    """Extract mutations grouped by node ID for tree building.

    Silent mutations (ancestral == derived) are skipped.

    Returns:
        dict: {node_id (int): [{"position": int, "mutation": "A10G"}, ...]}
    """
    tables = ts.tables
    sites, mutations = tables.sites, tables.mutations

    # One entry per mutation: the position of the site it occurred at.
    positions = sites.position[mutations.site]
    ancestral = ts.sites_ancestral_state
    derived = ts.mutations_derived_state

    grouped = {}
    for position, anc, der, node in zip(positions, ancestral, derived, mutations.node):
        if anc == der:
            continue  # silent mutation, nothing to report
        pos_int = int(position)
        grouped.setdefault(int(node), []).append({
            "position": pos_int,
            "mutation": f"{anc}{pos_int}{der}"
        })

    return grouped
lorax/lorax_app.py ADDED
@@ -0,0 +1,75 @@
1
"""
Socket.IO version of the Lorax backend (single-process, no Gunicorn).

Run with:
    uvicorn lorax.lorax_app:sio_app --host 0.0.0.0 --port 8080 --reload
"""
import os
import socketio
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.gzip import GZipMiddleware
from dotenv import load_dotenv

from lorax.context import REDIS_CLUSTER_URL, REDIS_CLUSTER
from lorax.constants import (
    SOCKET_PING_TIMEOUT, SOCKET_PING_INTERVAL, MAX_HTTP_BUFFER_SIZE
)
from lorax.routes import router
from lorax.sockets import register_socket_events

load_dotenv()

# FastAPI app with gzip for large layout/metadata payloads.
app = FastAPI(title="Lorax Backend", version="1.0.0")
app.add_middleware(GZipMiddleware, minimum_size=1000)


# Comma-separated origin list, overridable via the ALLOWED_ORIGINS env var.
ALLOWED_ORIGINS = [
    o.strip() for o in os.getenv("ALLOWED_ORIGINS", "http://localhost:5173,http://localhost:3001,http://localhost:3000").split(",")
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Optional Redis-backed Socket.IO manager for multi-process fan-out.
# python-socketio's Redis manager does not support Redis Cluster.
_client_manager = None
if REDIS_CLUSTER_URL and not REDIS_CLUSTER:
    _client_manager = socketio.AsyncRedisManager(REDIS_CLUSTER_URL)
elif REDIS_CLUSTER_URL and REDIS_CLUSTER:
    print("Warning: Socket.IO Redis manager does not support Redis Cluster; running without shared manager.")

# Build the server once; only the optional client_manager differs
# (previously two near-identical constructor calls were duplicated).
_server_kwargs = {
    "async_mode": "asgi",
    "cors_allowed_origins": "*",
    "logger": False,
    "engineio_logger": False,
    "ping_timeout": SOCKET_PING_TIMEOUT,
    "ping_interval": SOCKET_PING_INTERVAL,
    "max_http_buffer_size": MAX_HTTP_BUFFER_SIZE,
}
if _client_manager:
    _server_kwargs["client_manager"] = _client_manager
sio = socketio.AsyncServer(**_server_kwargs)

# Socket.IO wraps the FastAPI app so both share one ASGI entry point.
sio_app = socketio.ASGIApp(sio, other_asgi_app=app)

# Include Routes
app.include_router(router)

# Register Socket Events
register_socket_events(sio)
+ register_socket_events(sio)