lorax-arg 0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66) hide show
  1. lorax/buffer.py +43 -0
  2. lorax/cache/__init__.py +43 -0
  3. lorax/cache/csv_tree_graph.py +59 -0
  4. lorax/cache/disk.py +467 -0
  5. lorax/cache/file_cache.py +142 -0
  6. lorax/cache/file_context.py +72 -0
  7. lorax/cache/lru.py +90 -0
  8. lorax/cache/tree_graph.py +293 -0
  9. lorax/cli.py +312 -0
  10. lorax/cloud/__init__.py +0 -0
  11. lorax/cloud/gcs_utils.py +205 -0
  12. lorax/constants.py +66 -0
  13. lorax/context.py +80 -0
  14. lorax/csv/__init__.py +7 -0
  15. lorax/csv/config.py +250 -0
  16. lorax/csv/layout.py +182 -0
  17. lorax/csv/newick_tree.py +234 -0
  18. lorax/handlers.py +998 -0
  19. lorax/lineage.py +456 -0
  20. lorax/loaders/__init__.py +0 -0
  21. lorax/loaders/csv_loader.py +10 -0
  22. lorax/loaders/loader.py +31 -0
  23. lorax/loaders/tskit_loader.py +119 -0
  24. lorax/lorax_app.py +75 -0
  25. lorax/manager.py +58 -0
  26. lorax/metadata/__init__.py +0 -0
  27. lorax/metadata/loader.py +426 -0
  28. lorax/metadata/mutations.py +146 -0
  29. lorax/modes.py +190 -0
  30. lorax/pg.py +183 -0
  31. lorax/redis_utils.py +30 -0
  32. lorax/routes.py +137 -0
  33. lorax/session_manager.py +206 -0
  34. lorax/sockets/__init__.py +55 -0
  35. lorax/sockets/connection.py +99 -0
  36. lorax/sockets/debug.py +47 -0
  37. lorax/sockets/decorators.py +112 -0
  38. lorax/sockets/file_ops.py +200 -0
  39. lorax/sockets/lineage.py +307 -0
  40. lorax/sockets/metadata.py +232 -0
  41. lorax/sockets/mutations.py +154 -0
  42. lorax/sockets/node_search.py +535 -0
  43. lorax/sockets/tree_layout.py +117 -0
  44. lorax/sockets/utils.py +10 -0
  45. lorax/tree_graph/__init__.py +12 -0
  46. lorax/tree_graph/tree_graph.py +689 -0
  47. lorax/utils.py +124 -0
  48. lorax_app/__init__.py +4 -0
  49. lorax_app/app.py +159 -0
  50. lorax_app/cli.py +114 -0
  51. lorax_app/static/X.png +0 -0
  52. lorax_app/static/assets/index-BCEGlUFi.js +2361 -0
  53. lorax_app/static/assets/index-iKjzUpA9.css +1 -0
  54. lorax_app/static/assets/localBackendWorker-BaWwjSV_.js +2 -0
  55. lorax_app/static/assets/renderDataWorker-BKLdiU7J.js +2 -0
  56. lorax_app/static/gestures/gesture-flick.ogv +0 -0
  57. lorax_app/static/gestures/gesture-two-finger-scroll.ogv +0 -0
  58. lorax_app/static/index.html +14 -0
  59. lorax_app/static/logo.png +0 -0
  60. lorax_app/static/lorax-logo.png +0 -0
  61. lorax_app/static/vite.svg +1 -0
  62. lorax_arg-0.1.dist-info/METADATA +131 -0
  63. lorax_arg-0.1.dist-info/RECORD +66 -0
  64. lorax_arg-0.1.dist-info/WHEEL +5 -0
  65. lorax_arg-0.1.dist-info/entry_points.txt +4 -0
  66. lorax_arg-0.1.dist-info/top_level.txt +2 -0
lorax/handlers.py ADDED
@@ -0,0 +1,998 @@
1
+ # handlers.py
2
+ import os
3
+ import json
4
+ import asyncio
5
+ import re
6
+
7
+ import numpy as np
8
+ import pandas as pd
9
+ import psutil
10
+ import tskit
11
+
12
+ from lorax.modes import CURRENT_MODE
13
+ from lorax.cloud.gcs_utils import get_public_gcs_dict
14
+ from lorax.tree_graph import construct_trees_batch, construct_tree, TreeGraph
15
+ from lorax.csv.layout import build_empty_layout_response, build_csv_layout_response
16
+ from lorax.utils import (
17
+ ensure_json_dict,
18
+ list_project_files,
19
+ make_json_serializable,
20
+ )
21
+ from lorax.metadata.loader import (
22
+ get_metadata_for_key,
23
+ search_samples_by_metadata,
24
+ get_metadata_array_for_key,
25
+ _get_sample_metadata_value
26
+ )
27
+ from lorax.metadata.mutations import (
28
+ get_mutations_in_window,
29
+ search_mutations_by_position
30
+ )
31
+ from lorax.buffer import mutations_to_arrow_buffer
32
+ from lorax.cache import get_file_context, get_file_cache_size
33
def _get_tip_shift_project_prefixes() -> list[str]:
    """Return project name prefixes that should shift CSV tips to y=1.

    Prefixes are read from the LORAX_CSV_TIP_SHIFT_PROJECTS environment
    variable (comma-separated, default "heliconius"); each entry is
    stripped and lowercased, and empty entries are dropped.
    """
    configured = os.getenv("LORAX_CSV_TIP_SHIFT_PROJECTS", "heliconius")
    return [entry.strip().lower() for entry in configured.split(",") if entry.strip()]
38
+
39
+
40
def should_shift_csv_tips(file_path: str) -> bool:
    """Return True when the file path matches a configured project prefix."""
    if not file_path:
        return False
    prefixes = _get_tip_shift_project_prefixes()
    if not prefixes:
        return False
    # Compare every non-empty path component (either separator style)
    # against the configured prefixes, case-insensitively.
    components = (c for c in re.split(r"[\\/]", str(file_path)) if c)
    return any(
        str(component).lower().startswith(prefix)
        for component in components
        for prefix in prefixes
    )
55
+
56
+
57
def _is_heliconius_project(file_path: str) -> bool:
    """Return True when the file path indicates a Heliconius project (case-insensitive)."""
    if not file_path:
        return False
    # Any path component (split on / or \) starting with "heliconius" qualifies.
    return any(
        str(component).lower().startswith("heliconius")
        for component in re.split(r"[\\/]", str(file_path))
        if component
    )
66
+
67
+
68
async def cache_status():
    """Return current memory usage and cache statistics for this process."""
    mem = psutil.Process(os.getpid()).memory_info()
    bytes_per_mb = 1024 * 1024
    return {
        "rss_MB": round(mem.rss / bytes_per_mb, 2),
        "vms_MB": round(mem.vms / bytes_per_mb, 2),
        "file_cache_size": get_file_cache_size(),
        "pid": os.getpid(),
    }
81
+
82
+
83
async def handle_upload(file_path, root_dir):
    """Load a file and return its FileContext."""
    context = await get_file_context(file_path, root_dir)
    print("File loading complete")
    return context
88
+
89
+
90
async def get_projects(upload_dir, BUCKET_NAME, sid=None):
    """List all projects and their files from local uploads and GCS bucket.

    Args:
        upload_dir: Root directory holding project folders and an "Uploads"
            folder; coerced to str below.
        BUCKET_NAME: GCS bucket name forwarded to get_public_gcs_dict.
        sid: Optional session id; non-local modes scope uploads under
            Uploads/<sid>.

    Returns:
        dict mapping project name to {"folder", "files", "description"},
        merged with GCS-side projects.
    """
    projects = {}
    upload_dir = str(upload_dir)
    # Avoid listing Uploads/<sid> as separate projects; add session-scoped uploads below.
    projects = list_project_files(
        upload_dir,
        projects,
        root=upload_dir,
        exclude_dirs=["Uploads"],
    )

    # Prefer session-scoped Uploads/<sid> when available (non-local); local uses flat Uploads
    upload_files = []
    uploads_root = None
    if CURRENT_MODE == "local":
        uploads_root = os.path.join(upload_dir, "Uploads")
    else:
        uploads_root = os.path.join(upload_dir, "Uploads", sid) if sid else None

    if uploads_root and os.path.isdir(uploads_root):
        for item in os.listdir(uploads_root):
            # Only surface supported tree-sequence / CSV file extensions.
            if item.endswith((".trees", ".trees.tsz", ".csv")):
                upload_files.append(item)

    # "Uploads" is always present, even when empty.
    projects["Uploads"] = {
        "folder": "Uploads",
        "files": sorted(set(upload_files)),
        "description": "",
    }
    # Remove accidental project entry created from Uploads/<sid>
    if sid and sid in projects:
        projects.pop(sid, None)

    # Merge GCS projects: always include non-Uploads; Uploads only per mode rules
    if CURRENT_MODE == "local":
        projects = get_public_gcs_dict(
            BUCKET_NAME,
            sid=sid,
            projects=projects,
            include_uploads=False,
            uploads_sid=None,
        )
    else:
        projects = get_public_gcs_dict(
            BUCKET_NAME,
            sid=sid,
            projects=projects,
            include_uploads=True,
            uploads_sid=sid,
        )

    return projects
143
+
144
def _build_sample_name_mapping(ts, sample_name_key="name"):
    """Build a mapping from lowercase sample name to node id.

    Args:
        ts: tskit.TreeSequence
        sample_name_key: Key in node metadata used as sample name

    Returns:
        dict mapping lowercase sample name to node_id. Falls back to the
        stringified node id when metadata lacks the key or cannot be
        decoded as a JSON dict.
    """
    mapping = {}
    for sample_id in ts.samples():
        raw_meta = ts.node(sample_id).metadata or {}
        try:
            meta = ensure_json_dict(raw_meta)
        except (TypeError, json.JSONDecodeError):
            meta = {}
        sample_name = str(meta.get(sample_name_key, f"{sample_id}"))
        mapping[sample_name.lower()] = sample_id
    return mapping
166
+
167
+
168
def _compute_lineage_paths(tree, tree_seeds, name_map, sample_colors):
    """Compute ancestry paths for seed nodes in a tree.

    Args:
        tree: tskit.Tree object
        tree_seeds: List of seed node IDs to trace ancestry
        name_map: Dict mapping node_id to original name
        sample_colors: Optional dict {sample_name: [r,g,b,a]} for coloring

    Returns:
        List of lineage dicts with path_node_ids (root -> tip order) and color.
    """
    lineages = []
    for seed in tree_seeds:
        # Walk from the seed up to the root of this tree.
        path = []
        node = seed
        while node != -1 and node != tskit.NULL:
            path.append(node)
            node = tree.parent(node)

        if len(path) > 1:
            # Emit root -> tip to match frontend L-shape construction.
            path.reverse()

        # Color is looked up by the lowercase sample name, when provided.
        seed_name = name_map.get(seed, str(seed))
        color = sample_colors.get(seed_name.lower()) if sample_colors else None

        lineages.append({
            "path_node_ids": [int(n) for n in path],
            "color": color,
        })

    return lineages
206
+
207
+
208
def search_nodes_in_trees(
    ts,
    sample_names,
    tree_indices,
    show_lineages=False,
    sample_colors=None,
    sample_name_key="name"
):
    """
    Search for nodes matching sample names in specified trees.
    Returns highlights and optionally lineage paths.

    Args:
        ts: tskit.TreeSequence
        sample_names: List of sample names to search for
        tree_indices: List of tree indices to search in
        show_lineages: Whether to compute lineage (ancestry) paths
        sample_colors: Optional dict {sample_name: [r,g,b,a]} for coloring
        sample_name_key: Key in node metadata used as sample name

    Returns:
        dict with:
            - highlights: {tree_idx: [{node_id, name}]}
            - lineage: {tree_idx: [{path_node_ids, color}]} if show_lineages
    """
    if not sample_names or not tree_indices:
        return {"highlights": {}, "lineage": {}}

    # Resolve requested names (case-insensitive) to node ids.
    name_to_node_id = _build_sample_name_mapping(ts, sample_name_key)

    name_map = {}  # node_id -> original (as-requested) name
    for requested in sample_names:
        node_id = name_to_node_id.get(requested.lower())
        if node_id is not None:
            name_map[node_id] = requested
    target_node_ids = set(name_map)

    if not target_node_ids:
        return {"highlights": {}, "lineage": {}}

    highlights = {}
    lineage = {}

    for raw_idx in tree_indices:
        idx = int(raw_idx)
        if not (0 <= idx < ts.num_trees):
            continue

        tree = ts.at_index(idx)

        # Keep only the target samples that actually appear in this tree.
        matched = [nid for nid in target_node_ids if tree.is_sample(nid)]
        if not matched:
            continue

        highlights[idx] = [
            {"node_id": int(nid), "name": name_map.get(nid, str(nid))}
            for nid in matched
        ]

        # Compute ancestry paths on demand.
        if show_lineages:
            paths = _compute_lineage_paths(tree, matched, name_map, sample_colors)
            if paths:
                lineage[idx] = paths

    return {"highlights": highlights, "lineage": lineage}
288
+
289
+
290
def get_node_details(ts, node_name):
    """Get details for a specific node in the tree sequence."""
    record = ts.node(node_name)
    return {
        "id": record.id,
        "time": record.time,
        "population": record.population,
        "individual": record.individual,
        "metadata": make_json_serializable(record.metadata),
    }
300
+
301
+
302
def get_tree_details(ts, tree_index):
    """Get details for a specific tree at the given index."""
    tree = ts.at_index(tree_index)

    mutation_records = []
    for mutation in tree.mutations():
        site = ts.site(mutation.site)
        # The inherited state is the parent mutation's derived state, or the
        # site's ancestral state when there is no parent mutation.
        if mutation.parent != -1:
            inherited = ts.mutation(mutation.parent).derived_state
        else:
            inherited = site.ancestral_state
        mutation_records.append({
            "id": mutation.id,
            "node": mutation.node,  # Node ID for highlighting
            "site_id": mutation.site,
            "position": site.position,
            "derived_state": mutation.derived_state,
            "inherited_state": inherited,
        })

    return {
        "interval": tree.interval,
        "num_roots": tree.num_roots,
        "num_nodes": tree.num_nodes,
        "mutations": mutation_records,
    }
324
+
325
+
326
def get_individual_details(ts, individual_id):
    """Get details for a specific individual in the tree sequence."""
    record = ts.individual(individual_id)
    return {
        "id": record.id,
        "nodes": make_json_serializable(record.nodes),
        "metadata": make_json_serializable(record.metadata),
    }
334
+
335
+
336
def get_comprehensive_individual_details(ts, individual_id):
    """Get comprehensive individual table data including location, parents, flags."""
    # Sentinel ids (None / -1) mean "no individual" in tskit tables.
    if individual_id is None or individual_id == -1:
        return None

    record = ts.individual(individual_id)
    location = list(record.location) if len(record.location) > 0 else None
    parents = [int(p) for p in record.parents] if len(record.parents) > 0 else []
    return {
        "id": int(record.id),
        "flags": int(record.flags),
        "location": location,
        "parents": parents,
        "nodes": [int(n) for n in record.nodes],
        "metadata": make_json_serializable(record.metadata),
    }
350
+
351
+
352
def get_population_details(ts, population_id):
    """Get population table data."""
    # Sentinel ids (None / -1) mean "no population" in tskit tables.
    if population_id is None or population_id == -1:
        return None
    record = ts.population(population_id)
    return {
        "id": int(record.id),
        "metadata": make_json_serializable(record.metadata),
    }
361
+
362
+
363
def get_mutations_for_node(ts, node_id, tree_index=None):
    """Get all mutations on a specific node, optionally filtered by tree interval."""
    # Resolve the genomic interval of the requested tree, if any.
    interval = None
    if tree_index is not None:
        interval = ts.at_index(tree_index).interval

    results = []
    for mutation in ts.mutations():
        if mutation.node != node_id:
            continue

        site = ts.site(mutation.site)

        # Keep only mutations whose site position falls inside the interval.
        if interval is not None and not (interval.left <= site.position < interval.right):
            continue

        results.append({
            "id": int(mutation.id),
            "site_id": int(mutation.site),
            "position": float(site.position),
            "ancestral_state": site.ancestral_state,
            "derived_state": mutation.derived_state,
            "time": float(mutation.time) if mutation.time != tskit.UNKNOWN_TIME else None,
            "parent_mutation": int(mutation.parent) if mutation.parent != -1 else None,
            "metadata": make_json_serializable(mutation.metadata) if mutation.metadata else None,
        })

    return results
394
+
395
+
396
def get_edges_for_node(ts, node_id, tree_index=None):
    """Get all edges where this node is parent or child."""
    # Resolve the genomic interval of the requested tree, if any.
    interval = None
    if tree_index is not None:
        interval = ts.at_index(tree_index).interval

    result = {
        "as_parent": [],  # Edges where node is parent
        "as_child": []    # Edges where node is child
    }

    for edge in ts.edges():
        # Skip edges that do not overlap the tree's genomic interval.
        if interval is not None and (edge.right <= interval.left or edge.left >= interval.right):
            continue
        # Skip edges that do not touch the requested node at all.
        if edge.parent != node_id and edge.child != node_id:
            continue

        record = {
            "id": int(edge.id),
            "left": float(edge.left),
            "right": float(edge.right),
            "parent": int(edge.parent),
            "child": int(edge.child),
        }
        if edge.parent == node_id:
            result["as_parent"].append(record)
        if edge.child == node_id:
            result["as_child"].append(record)

    return result
429
+
430
+
431
async def handle_details(file_path, data):
    """Handle requests for tree, node, and individual details.

    Args:
        file_path: Path of the loaded tree-sequence file.
        data: Request dict; recognized keys are "treeIndex", "node", and
            "comprehensive" (bool, default False).

    Returns:
        JSON string containing any of "tree", "node", "individual" and — in
        comprehensive mode — "population", "mutations", "edges"; or a JSON
        object with an "error" key on failure.
    """
    try:
        ctx = await get_file_context(file_path)
        if ctx is None:
            return json.dumps({"error": "Tree sequence (ts) is not set. Please upload a file first."})

        ts = ctx.tree_sequence
        return_data = {}
        tree_index = data.get("treeIndex")
        comprehensive = data.get("comprehensive", False)

        if tree_index is not None:
            return_data["tree"] = get_tree_details(ts, int(tree_index))

        node_name = data.get("node")
        if node_name is not None:
            node_id = int(node_name)
            node_details = get_node_details(ts, node_id)
            return_data["node"] = node_details

            # Auto-fetch individual details
            if node_details.get("individual") != -1:
                if comprehensive:
                    return_data["individual"] = get_comprehensive_individual_details(
                        ts, node_details.get("individual")
                    )
                else:
                    return_data["individual"] = get_individual_details(
                        ts, node_details.get("individual")
                    )

            # Comprehensive mode: add population, mutations, edges
            if comprehensive:
                # Population
                if node_details.get("population") != -1:
                    return_data["population"] = get_population_details(
                        ts, node_details.get("population")
                    )

                # Mutations on this node
                # NOTE(review): tree_index is forwarded unconverted here;
                # callers appear to send ints — confirm.
                return_data["mutations"] = get_mutations_for_node(
                    ts, node_id, tree_index
                )

                # Edges for this node
                return_data["edges"] = get_edges_for_node(
                    ts, node_id, tree_index
                )

        return json.dumps(return_data)
    except Exception as e:
        # Boundary handler: report the failure to the client instead of raising.
        return json.dumps({"error": f"Error getting details: {str(e)}"})
484
+
485
+
486
async def handle_tree_graph_query(
    file_path,
    tree_indices,
    sparsification=False,
    session_id: str = None,
    tree_graph_cache=None,
    csv_tree_graph_cache=None,
    actual_display_array=None
):
    """
    Construct trees using Numba-optimized tree_graph module.

    Args:
        file_path: Path to tree sequence file
        tree_indices: List of tree indices to process
        sparsification: Enable tip-only sparsification (default False)
        session_id: Session ID for cache lookup/storage
        tree_graph_cache: TreeGraphCache instance for caching TreeGraph objects
        csv_tree_graph_cache: Cache of parsed Newick trees for CSV-backed files
        actual_display_array: Currently visible tree indices; used to evict
            cached trees that are no longer displayed

    Returns:
        dict with:
            - buffer: PyArrow IPC binary data containing:
                - node_id: int32 (tskit node ID)
                - parent_id: int32 (-1 for roots)
                - is_tip: bool
                - tree_idx: int32 (which tree this node belongs to)
                - x: float32 (time-based coordinate [0,1])
                - y: float32 (layout-based coordinate [0,1])
            - global_min_time: float
            - global_max_time: float
            - tree_indices: list[int]
    """
    ctx = await get_file_context(file_path)
    if ctx is None:
        return {"error": "Tree sequence not loaded. Please load a file first."}

    ts = ctx.tree_sequence

    # CSV support: parse Newick strings and build tree layout
    if isinstance(ts, pd.DataFrame):
        shift_tips_to_one = should_shift_csv_tips(ctx.file_path or file_path)
        # Get max_branch_length from config (times.values[1])
        times_values = ctx.config.get("times", {}).get("values", [0.0, 1.0])
        max_branch_length = float(times_values[1]) if len(times_values) > 1 else 1.0
        indices = [int(t) for t in (tree_indices or [])]
        samples_order = ctx.config.get("samples") or []
        pre_parsed_graphs = {}
        if session_id and csv_tree_graph_cache:
            from lorax.csv.newick_tree import parse_newick_to_tree

            for tree_idx in indices:
                cached = await csv_tree_graph_cache.get(session_id, int(tree_idx))
                if cached is not None:
                    pre_parsed_graphs[int(tree_idx)] = cached
                    continue

                # Cache miss: parse and store (best-effort)
                try:
                    newick_str = ts.iloc[int(tree_idx)].get("newick")
                except Exception:
                    newick_str = None
                if newick_str is None or pd.isna(newick_str):
                    continue

                # Per-tree max branch length overrides the config-level value
                # when the column exists and holds a usable number.
                tree_max_branch_length = None
                if "max_branch_length" in ts.columns:
                    try:
                        v = ts.iloc[int(tree_idx)].get("max_branch_length")
                        if v is not None and not (isinstance(v, float) and pd.isna(v)) and str(v).strip() != "":
                            tree_max_branch_length = float(v)
                    except Exception:
                        tree_max_branch_length = None

                # Parse off-thread; a failed parse simply skips this tree.
                try:
                    graph = await asyncio.to_thread(
                        parse_newick_to_tree,
                        str(newick_str),
                        max_branch_length,
                        samples_order=samples_order,
                        tree_max_branch_length=tree_max_branch_length,
                        shift_tips_to_one=shift_tips_to_one,
                    )
                except Exception:
                    continue

                pre_parsed_graphs[int(tree_idx)] = graph
                await csv_tree_graph_cache.set(session_id, int(tree_idx), graph)

            # Drop cached CSV trees that are no longer visible.
            # NOTE(review): placed inside the csv_tree_graph_cache guard since
            # it dereferences that cache — confirm against the original layout.
            if actual_display_array is not None:
                await csv_tree_graph_cache.evict_not_visible(session_id, set(actual_display_array))

        return build_csv_layout_response(
            ts,
            indices,
            max_branch_length,
            samples_order=samples_order,
            pre_parsed_graphs=pre_parsed_graphs,
            shift_tips_to_one=shift_tips_to_one,
        )

    # Collect pre-cached TreeGraphs
    pre_cached_graphs = {}
    if session_id and tree_graph_cache:
        for tree_idx in tree_indices:
            cached = await tree_graph_cache.get(session_id, int(tree_idx))
            if cached is not None:
                pre_cached_graphs[int(tree_idx)] = cached
        if pre_cached_graphs:
            print(f"TreeGraph cache hits: {len(pre_cached_graphs)}/{len(tree_indices)} trees")

    # Run in thread pool to avoid blocking
    def process_trees():
        return construct_trees_batch(
            ts,
            tree_indices,
            sparsification=sparsification,
            pre_cached_graphs=pre_cached_graphs
        )

    buffer, min_time, max_time, processed_indices, newly_built = await asyncio.to_thread(process_trees)

    # Cache newly built TreeGraphs
    if session_id and tree_graph_cache and newly_built:
        for tree_idx, graph in newly_built.items():
            await tree_graph_cache.set(session_id, tree_idx, graph)
        print(f"TreeGraph cached: {len(newly_built)} new trees for session {session_id[:8]}...")

    # Evict trees no longer in visible set (visibility-based eviction)
    if session_id and tree_graph_cache and actual_display_array is not None:
        await tree_graph_cache.evict_not_visible(session_id, set(actual_display_array))

    return {
        "buffer": buffer,
        "global_min_time": min_time,
        "global_max_time": max_time,
        "tree_indices": processed_indices
    }
623
+
624
+
625
async def get_or_construct_tree_graph(
    file_path: str,
    tree_index: int,
    session_id: str,
    tree_graph_cache
) -> TreeGraph:
    """
    Get a TreeGraph from cache or construct and cache it.

    This function is used by lineage operations that need the full TreeGraph
    structure for ancestor/descendant traversal.

    Args:
        file_path: Path to tree sequence file
        tree_index: Index of the tree to get
        session_id: Session ID for cache key
        tree_graph_cache: TreeGraphCache instance

    Returns:
        TreeGraph object, or None if file not loaded
    """
    # Check cache first
    cached = await tree_graph_cache.get(session_id, tree_index)
    if cached is not None:
        print(f"TreeGraph cache hit: session={session_id[:8]}... tree={tree_index}")
        return cached

    # Load file context
    ctx = await get_file_context(file_path)
    if ctx is None:
        return None

    ts = ctx.tree_sequence

    # Can't construct TreeGraph for CSV
    if isinstance(ts, pd.DataFrame):
        return None

    # Construct tree graph (tables extracted once, work done off-thread)
    def _construct():
        edges = ts.tables.edges
        nodes = ts.tables.nodes
        breakpoints = list(ts.breakpoints())
        min_time = float(ts.min_time)
        max_time = float(ts.max_time)
        return construct_tree(ts, edges, nodes, breakpoints, tree_index, min_time, max_time)

    tree_graph = await asyncio.to_thread(_construct)

    # Cache it
    await tree_graph_cache.set(session_id, tree_index, tree_graph)
    print(f"TreeGraph cached: session={session_id[:8]}... tree={tree_index}")

    return tree_graph
679
+
680
+
681
async def ensure_trees_cached(
    file_path: str,
    tree_indices: list,
    session_id: str,
    tree_graph_cache
) -> int:
    """
    Ensure multiple trees are cached for a session.

    This is called after process_postorder_layout to cache trees for
    subsequent lineage operations.

    Args:
        file_path: Path to tree sequence file
        tree_indices: List of tree indices to cache
        session_id: Session ID for cache key
        tree_graph_cache: TreeGraphCache instance

    Returns:
        Number of trees newly cached (not already in cache)
    """
    ctx = await get_file_context(file_path)
    if ctx is None:
        return 0

    ts = ctx.tree_sequence

    # CSV-backed contexts have no tskit tables; nothing to cache.
    if isinstance(ts, pd.DataFrame):
        return 0

    # Pre-extract tables once; they are shared by every construct_tree call.
    edges = ts.tables.edges
    nodes = ts.tables.nodes
    breakpoints = list(ts.breakpoints())
    min_time = float(ts.min_time)
    max_time = float(ts.max_time)

    # Hoisted out of the loop (was redefined per iteration): one closure
    # reused for every tree index.
    def _construct(idx):
        return construct_tree(ts, edges, nodes, breakpoints, idx, min_time, max_time)

    newly_cached = 0
    for tree_index in tree_indices:
        tree_index = int(tree_index)

        # Skip if already cached
        if await tree_graph_cache.get(session_id, tree_index) is not None:
            continue

        # Construct off-thread and cache.
        tree_graph = await asyncio.to_thread(_construct, tree_index)
        await tree_graph_cache.set(session_id, tree_index, tree_graph)
        newly_cached += 1

    if newly_cached > 0:
        print(f"Cached {newly_cached} trees for session {session_id[:8]}...")

    return newly_cached
740
+
741
+
742
def _get_matching_sample_nodes(ts, metadata_key, metadata_value, sources, sample_name_key):
    """
    Find all sample node IDs that match a metadata value.

    Comparison is done on the string form of both values.

    Args:
        ts: tskit.TreeSequence
        metadata_key: Metadata key to filter by
        metadata_value: Metadata value to match
        sources: Metadata sources to search
        sample_name_key: Key in node metadata used as sample name

    Returns:
        Set of matching node IDs
    """
    wanted = str(metadata_value)
    matches = set()
    for sample_id in ts.samples():
        _, value = _get_sample_metadata_value(
            ts, sample_id, metadata_key, sources, sample_name_key
        )
        if value is not None and str(value) == wanted:
            matches.add(sample_id)
    return matches
764
+
765
+
766
async def _ensure_tree_graph_loaded(
    ts,
    tree_idx,
    session_id,
    tree_graph_cache,
    edges,
    nodes,
    breakpoints,
    min_time,
    max_time
):
    """
    Get tree graph from cache or construct and cache it.

    Args:
        ts: tskit.TreeSequence
        tree_idx: Tree index to load
        session_id: Session ID for cache key
        tree_graph_cache: TreeGraphCache instance
        edges, nodes, breakpoints, min_time, max_time: Pre-extracted table data

    Returns:
        TreeGraph object
    """
    from lorax.tree_graph import construct_tree

    # Fast path: already cached for this session.
    cached = await tree_graph_cache.get(session_id, tree_idx)
    if cached is not None:
        return cached

    # Miss: build off-thread so the event loop stays responsive.
    graph = await asyncio.to_thread(
        construct_tree, ts, edges, nodes, breakpoints, tree_idx, min_time, max_time
    )

    # Cache for subsequent lookups.
    await tree_graph_cache.set(session_id, tree_idx, graph)

    return graph
807
+
808
+
809
async def get_highlight_positions(
    ts,
    file_path,
    metadata_key,
    metadata_value,
    tree_indices,
    session_id: str,
    tree_graph_cache,
    sources=("individual", "node", "population"),
    sample_name_key="name"
):
    """
    Get positions for all tip nodes with a specific metadata value.
    Uses cached TreeGraph objects when available.

    Args:
        ts: tskit.TreeSequence
        file_path: Path to tree sequence file (for cache key)
        metadata_key: Metadata key to filter by
        metadata_value: Metadata value to match
        tree_indices: List of tree indices to compute positions for
        session_id: Session ID for cache lookup
        tree_graph_cache: TreeGraphCache instance
        sources: Metadata sources to search
        sample_name_key: Key in node metadata used as sample name

    Returns:
        dict with:
            - positions: List of {node_id, tree_idx, x, y} dicts
    """
    if not tree_indices:
        return {"positions": []}

    # Get sample node IDs that have this metadata value
    matching_node_ids = _get_matching_sample_nodes(
        ts, metadata_key, metadata_value, sources, sample_name_key
    )

    if not matching_node_ids:
        return {"positions": []}

    # Pre-extract tables for reuse (only needed if cache miss)
    edges = ts.tables.edges
    nodes = ts.tables.nodes
    breakpoints = list(ts.breakpoints())
    min_time = float(ts.min_time)
    max_time = float(ts.max_time)

    positions = []

    # For each requested tree, get graph and extract positions
    for tree_idx in tree_indices:
        tree_idx = int(tree_idx)
        # Ignore out-of-range indices silently.
        if tree_idx < 0 or tree_idx >= ts.num_trees:
            continue

        graph = await _ensure_tree_graph_loaded(
            ts, tree_idx, session_id, tree_graph_cache,
            edges, nodes, breakpoints, min_time, max_time
        )

        # Extract positions for matching nodes that are in this tree
        for node_id in matching_node_ids:
            if graph.in_tree[node_id]:
                positions.append({
                    "node_id": int(node_id),
                    "tree_idx": tree_idx,
                    "x": float(graph.x[node_id]),
                    "y": float(graph.y[node_id])
                })

    return {"positions": positions}
881
+
882
+
883
async def get_multi_value_highlight_positions(
    ts,
    file_path,
    metadata_key,
    metadata_values,  # List[str] - Array of values (OR logic)
    tree_indices,
    session_id: str,
    tree_graph_cache,
    show_lineages: bool = False,
    sources=("individual", "node", "population"),
    sample_name_key="name"
):
    """
    Get positions for tip nodes matching ANY of the metadata values.
    Returns positions grouped by value for per-value coloring.

    Args:
        ts: tskit.TreeSequence
        file_path: Path to tree sequence file (for cache key)
        metadata_key: Metadata key to filter by
        metadata_values: List of metadata values to match (OR logic)
        tree_indices: List of tree indices to compute positions for
        session_id: Session ID for cache lookup
        tree_graph_cache: TreeGraphCache instance
        show_lineages: Whether to compute lineage (ancestry) paths
        sources: Metadata sources to search
        sample_name_key: Key in node metadata used as sample name

    Returns:
        dict with:
            - positions_by_value: {"Africa": [{node_id, tree_idx, x, y}, ...], ...}
            - lineages: {"Africa": {tree_idx: [{path_node_ids, color}]}} if show_lineages
            - total_count: int
    """
    if not tree_indices or not metadata_values:
        return {"positions_by_value": {}, "lineages": {}, "total_count": 0}

    # Deduplicate values
    unique_values = list(set(str(v) for v in metadata_values))

    # Pre-extract tables for reuse (only needed if cache miss)
    edges = ts.tables.edges
    nodes = ts.tables.nodes
    breakpoints = list(ts.breakpoints())
    min_time = float(ts.min_time)
    max_time = float(ts.max_time)

    positions_by_value = {}
    lineages = {} if show_lineages else None
    total_count = 0

    # For each value, find matching samples
    for value in unique_values:
        matching_node_ids = _get_matching_sample_nodes(
            ts, metadata_key, value, sources, sample_name_key
        )

        # Values with no matching samples still appear with an empty list.
        if not matching_node_ids:
            positions_by_value[value] = []
            continue

        value_positions = []
        value_lineages = {} if show_lineages else None

        # For each requested tree, get graph and extract positions
        for tree_idx in tree_indices:
            tree_idx = int(tree_idx)
            # Ignore out-of-range indices silently.
            if tree_idx < 0 or tree_idx >= ts.num_trees:
                continue

            graph = await _ensure_tree_graph_loaded(
                ts, tree_idx, session_id, tree_graph_cache,
                edges, nodes, breakpoints, min_time, max_time
            )

            tree_positions = []
            tree_seeds = []  # For lineage computation

            # Extract positions for matching nodes that are in this tree
            for node_id in matching_node_ids:
                if graph.in_tree[node_id]:
                    tree_positions.append({
                        "node_id": int(node_id),
                        "tree_idx": tree_idx,
                        "x": float(graph.x[node_id]),
                        "y": float(graph.y[node_id])
                    })
                    tree_seeds.append(node_id)

            value_positions.extend(tree_positions)

            # Compute lineage paths if requested
            if show_lineages and tree_seeds:
                tree = ts.at_index(tree_idx)
                name_map = {nid: str(nid) for nid in tree_seeds}
                tree_lineages = _compute_lineage_paths(
                    tree, tree_seeds, name_map, None  # No per-sample colors, use value color
                )
                if tree_lineages:
                    value_lineages[tree_idx] = tree_lineages

        positions_by_value[value] = value_positions
        total_count += len(value_positions)

        if show_lineages and value_lineages:
            lineages[value] = value_lineages

    result = {
        "positions_by_value": positions_by_value,
        "total_count": total_count
    }

    # "lineages" key is only present when lineage computation was requested.
    if show_lineages:
        result["lineages"] = lineages

    return result