lorax-arg 0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lorax/buffer.py +43 -0
- lorax/cache/__init__.py +43 -0
- lorax/cache/csv_tree_graph.py +59 -0
- lorax/cache/disk.py +467 -0
- lorax/cache/file_cache.py +142 -0
- lorax/cache/file_context.py +72 -0
- lorax/cache/lru.py +90 -0
- lorax/cache/tree_graph.py +293 -0
- lorax/cli.py +312 -0
- lorax/cloud/__init__.py +0 -0
- lorax/cloud/gcs_utils.py +205 -0
- lorax/constants.py +66 -0
- lorax/context.py +80 -0
- lorax/csv/__init__.py +7 -0
- lorax/csv/config.py +250 -0
- lorax/csv/layout.py +182 -0
- lorax/csv/newick_tree.py +234 -0
- lorax/handlers.py +998 -0
- lorax/lineage.py +456 -0
- lorax/loaders/__init__.py +0 -0
- lorax/loaders/csv_loader.py +10 -0
- lorax/loaders/loader.py +31 -0
- lorax/loaders/tskit_loader.py +119 -0
- lorax/lorax_app.py +75 -0
- lorax/manager.py +58 -0
- lorax/metadata/__init__.py +0 -0
- lorax/metadata/loader.py +426 -0
- lorax/metadata/mutations.py +146 -0
- lorax/modes.py +190 -0
- lorax/pg.py +183 -0
- lorax/redis_utils.py +30 -0
- lorax/routes.py +137 -0
- lorax/session_manager.py +206 -0
- lorax/sockets/__init__.py +55 -0
- lorax/sockets/connection.py +99 -0
- lorax/sockets/debug.py +47 -0
- lorax/sockets/decorators.py +112 -0
- lorax/sockets/file_ops.py +200 -0
- lorax/sockets/lineage.py +307 -0
- lorax/sockets/metadata.py +232 -0
- lorax/sockets/mutations.py +154 -0
- lorax/sockets/node_search.py +535 -0
- lorax/sockets/tree_layout.py +117 -0
- lorax/sockets/utils.py +10 -0
- lorax/tree_graph/__init__.py +12 -0
- lorax/tree_graph/tree_graph.py +689 -0
- lorax/utils.py +124 -0
- lorax_app/__init__.py +4 -0
- lorax_app/app.py +159 -0
- lorax_app/cli.py +114 -0
- lorax_app/static/X.png +0 -0
- lorax_app/static/assets/index-BCEGlUFi.js +2361 -0
- lorax_app/static/assets/index-iKjzUpA9.css +1 -0
- lorax_app/static/assets/localBackendWorker-BaWwjSV_.js +2 -0
- lorax_app/static/assets/renderDataWorker-BKLdiU7J.js +2 -0
- lorax_app/static/gestures/gesture-flick.ogv +0 -0
- lorax_app/static/gestures/gesture-two-finger-scroll.ogv +0 -0
- lorax_app/static/index.html +14 -0
- lorax_app/static/logo.png +0 -0
- lorax_app/static/lorax-logo.png +0 -0
- lorax_app/static/vite.svg +1 -0
- lorax_arg-0.1.dist-info/METADATA +131 -0
- lorax_arg-0.1.dist-info/RECORD +66 -0
- lorax_arg-0.1.dist-info/WHEEL +5 -0
- lorax_arg-0.1.dist-info/entry_points.txt +4 -0
- lorax_arg-0.1.dist-info/top_level.txt +2 -0
lorax/handlers.py
ADDED
|
@@ -0,0 +1,998 @@
|
|
|
1
|
+
# handlers.py
|
|
2
|
+
import os
|
|
3
|
+
import json
|
|
4
|
+
import asyncio
|
|
5
|
+
import re
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
import pandas as pd
|
|
9
|
+
import psutil
|
|
10
|
+
import tskit
|
|
11
|
+
|
|
12
|
+
from lorax.modes import CURRENT_MODE
|
|
13
|
+
from lorax.cloud.gcs_utils import get_public_gcs_dict
|
|
14
|
+
from lorax.tree_graph import construct_trees_batch, construct_tree, TreeGraph
|
|
15
|
+
from lorax.csv.layout import build_empty_layout_response, build_csv_layout_response
|
|
16
|
+
from lorax.utils import (
|
|
17
|
+
ensure_json_dict,
|
|
18
|
+
list_project_files,
|
|
19
|
+
make_json_serializable,
|
|
20
|
+
)
|
|
21
|
+
from lorax.metadata.loader import (
|
|
22
|
+
get_metadata_for_key,
|
|
23
|
+
search_samples_by_metadata,
|
|
24
|
+
get_metadata_array_for_key,
|
|
25
|
+
_get_sample_metadata_value
|
|
26
|
+
)
|
|
27
|
+
from lorax.metadata.mutations import (
|
|
28
|
+
get_mutations_in_window,
|
|
29
|
+
search_mutations_by_position
|
|
30
|
+
)
|
|
31
|
+
from lorax.buffer import mutations_to_arrow_buffer
|
|
32
|
+
from lorax.cache import get_file_context, get_file_cache_size
|
|
33
|
+
def _get_tip_shift_project_prefixes() -> list[str]:
|
|
34
|
+
"""Return project name prefixes that should shift CSV tips to y=1."""
|
|
35
|
+
raw = os.getenv("LORAX_CSV_TIP_SHIFT_PROJECTS", "heliconius")
|
|
36
|
+
parts = [p.strip().lower() for p in raw.split(",") if p.strip()]
|
|
37
|
+
return parts
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def should_shift_csv_tips(file_path: str) -> bool:
    """Decide whether CSV tips for *file_path* should be shifted to y=1.

    The path is split on both ``/`` and ``\\`` separators. The result is True
    as soon as any non-empty component starts (case-insensitively) with one
    of the prefixes returned by ``_get_tip_shift_project_prefixes()``.

    Args:
        file_path: Path to inspect (converted with ``str()``); a falsy value
            short-circuits to False.

    Returns:
        True when some path component matches a configured prefix.
    """
    if not file_path:
        return False
    prefixes = _get_tip_shift_project_prefixes()
    if not prefixes:
        return False
    # str.startswith accepts a tuple of candidates, so one call per component.
    needles = tuple(prefixes)
    return any(
        component.lower().startswith(needles)
        for component in re.split(r"[\\/]", str(file_path))
        if component
    )
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def _is_heliconius_project(file_path: str) -> bool:
|
|
58
|
+
"""Return True when the file path indicates a Heliconius project (case-insensitive)."""
|
|
59
|
+
if not file_path:
|
|
60
|
+
return False
|
|
61
|
+
parts = re.split(r"[\\/]", str(file_path))
|
|
62
|
+
for part in parts:
|
|
63
|
+
if part and str(part).lower().startswith("heliconius"):
|
|
64
|
+
return True
|
|
65
|
+
return False
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
async def cache_status():
    """Report process memory usage and file-cache statistics.

    Returns:
        dict with ``rss_MB`` and ``vms_MB`` (resident / virtual memory of the
        current process in MiB, rounded to 2 decimals), ``file_cache_size``
        (whatever ``get_file_cache_size()`` returns) and ``pid``.
    """
    pid = os.getpid()
    usage = psutil.Process(pid).memory_info()
    bytes_per_mb = 1024 * 1024
    return {
        "rss_MB": round(usage.rss / bytes_per_mb, 2),
        "vms_MB": round(usage.vms / bytes_per_mb, 2),
        "file_cache_size": get_file_cache_size(),
        "pid": pid,
    }
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
async def handle_upload(file_path, root_dir):
    """Load *file_path* via ``get_file_context`` and return the FileContext.

    Prints a completion message once the context has been awaited.
    """
    file_context = await get_file_context(file_path, root_dir)
    print("File loading complete")
    return file_context
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
async def get_projects(upload_dir, BUCKET_NAME, sid=None):
    """Collect the project listing from local storage and the GCS bucket.

    Steps:
      1. List local projects under *upload_dir*, excluding ``Uploads``.
      2. Add an ``Uploads`` pseudo-project. Its files come from
         ``<upload_dir>/Uploads`` in local mode, or from
         ``<upload_dir>/Uploads/<sid>`` otherwise. Outside local mode with no
         *sid*, no directory is read. Only names ending in ``.trees``,
         ``.trees.tsz`` or ``.csv`` are kept, de-duplicated and sorted.
      3. Drop any project keyed by *sid*.
      4. Merge public GCS projects. GCS uploads are only requested outside
         local mode, scoped to *sid*.

    Args:
        upload_dir: Local root directory (str or path-like).
        BUCKET_NAME: GCS bucket to merge projects from.
        sid: Optional session id.

    Returns:
        The merged projects mapping returned by ``get_public_gcs_dict``.
    """
    upload_dir = str(upload_dir)
    # Uploads/<sid> must not surface as standalone projects; the session's
    # uploads are added explicitly below.
    projects = list_project_files(
        upload_dir,
        {},
        root=upload_dir,
        exclude_dirs=["Uploads"],
    )

    is_local = CURRENT_MODE == "local"
    if is_local:
        uploads_root = os.path.join(upload_dir, "Uploads")
    elif sid:
        uploads_root = os.path.join(upload_dir, "Uploads", sid)
    else:
        uploads_root = None

    upload_files = set()
    if uploads_root and os.path.isdir(uploads_root):
        upload_files = {
            entry
            for entry in os.listdir(uploads_root)
            if entry.endswith((".trees", ".trees.tsz", ".csv"))
        }

    projects["Uploads"] = {
        "folder": "Uploads",
        "files": sorted(upload_files),
        "description": "",
    }
    # Remove any accidental project entry created from Uploads/<sid>.
    if sid:
        projects.pop(sid, None)

    # Public GCS projects are always merged; GCS uploads only outside local mode.
    return get_public_gcs_dict(
        BUCKET_NAME,
        sid=sid,
        projects=projects,
        include_uploads=not is_local,
        uploads_sid=None if is_local else sid,
    )
|
|
143
|
+
|
|
144
|
+
def _build_sample_name_mapping(ts, sample_name_key="name"):
    """Map lowercased sample names to their sample node ids.

    Each sample's node metadata (or ``{}`` when falsy) is normalised with
    ``ensure_json_dict``. If that raises ``TypeError`` or
    ``json.JSONDecodeError``, an empty dict is used instead. The sample name
    is ``str(metadata[sample_name_key])``, falling back to the node id as a
    string. When several samples share a lowercased name, the last one
    visited wins.

    Args:
        ts: tskit.TreeSequence
        sample_name_key: Node-metadata key holding the sample name.

    Returns:
        dict of lowercased sample name -> node id.
    """
    mapping = {}
    for sample_id in ts.samples():
        raw_meta = ts.node(sample_id).metadata or {}
        try:
            meta = ensure_json_dict(raw_meta)
        except (TypeError, json.JSONDecodeError):
            meta = {}
        label = str(meta.get(sample_name_key, f"{sample_id}"))
        mapping[label.lower()] = sample_id
    return mapping
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def _compute_lineage_paths(tree, tree_seeds, name_map, sample_colors):
|
|
169
|
+
"""
|
|
170
|
+
Compute ancestry paths for seed nodes in a tree.
|
|
171
|
+
|
|
172
|
+
Args:
|
|
173
|
+
tree: tskit.Tree object
|
|
174
|
+
tree_seeds: List of seed node IDs to trace ancestry
|
|
175
|
+
name_map: Dict mapping node_id to original name
|
|
176
|
+
sample_colors: Optional dict {sample_name: [r,g,b,a]} for coloring
|
|
177
|
+
|
|
178
|
+
Returns:
|
|
179
|
+
List of lineage dicts with path_node_ids and color
|
|
180
|
+
"""
|
|
181
|
+
tree_lineages = []
|
|
182
|
+
for seed_node in tree_seeds:
|
|
183
|
+
# Trace ancestry path from sample to root
|
|
184
|
+
path_nodes = []
|
|
185
|
+
current = seed_node
|
|
186
|
+
while current != -1 and current != tskit.NULL:
|
|
187
|
+
path_nodes.append(current)
|
|
188
|
+
current = tree.parent(current)
|
|
189
|
+
|
|
190
|
+
if len(path_nodes) > 1:
|
|
191
|
+
# Emit root -> tip to match frontend L-shape construction.
|
|
192
|
+
path_nodes = list(reversed(path_nodes))
|
|
193
|
+
|
|
194
|
+
# Get color for this lineage
|
|
195
|
+
name = name_map.get(seed_node, str(seed_node))
|
|
196
|
+
color = None
|
|
197
|
+
if sample_colors:
|
|
198
|
+
color = sample_colors.get(name.lower())
|
|
199
|
+
|
|
200
|
+
tree_lineages.append({
|
|
201
|
+
"path_node_ids": [int(n) for n in path_nodes],
|
|
202
|
+
"color": color
|
|
203
|
+
})
|
|
204
|
+
|
|
205
|
+
return tree_lineages
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def search_nodes_in_trees(
    ts,
    sample_names,
    tree_indices,
    show_lineages=False,
    sample_colors=None,
    sample_name_key="name"
):
    """
    Search for nodes matching sample names in specified trees.
    Returns highlights and optionally lineage paths.

    Name matching is case-insensitive (names are compared lowercased against
    ``_build_sample_name_mapping``). ``None`` entries in *sample_names* are
    ignored, and other non-string entries are matched by their ``str()``
    form. Tree indices outside ``[0, ts.num_trees)`` are skipped.

    Args:
        ts: tskit.TreeSequence
        sample_names: List of sample names to search for
        tree_indices: List of tree indices to search in
        show_lineages: Whether to compute lineage (ancestry) paths
        sample_colors: Optional dict {sample_name: [r,g,b,a]} for coloring
        sample_name_key: Key in node metadata used as sample name

    Returns:
        dict with:
            - highlights: {tree_idx: [{node_id, name}]}, ordered by node_id
            - lineage: {tree_idx: [{path_node_ids: [root, ..., sample], color}]}
              when show_lineages, otherwise {}
    """
    if not sample_names or not tree_indices:
        return {"highlights": {}, "lineage": {}}

    name_to_node_id = _build_sample_name_mapping(ts, sample_name_key)

    # node_id -> name as requested by the caller (used for labels/colours).
    name_map = {}
    for name in sample_names:
        if name is None:
            continue
        label = str(name)
        nid = name_to_node_id.get(label.lower())
        if nid is not None:
            name_map[nid] = label

    if not name_map:
        return {"highlights": {}, "lineage": {}}

    # Sorted so the response order is deterministic.
    target_node_ids = sorted(name_map)

    highlights = {}
    lineage = {}
    for tree_idx in tree_indices:
        tree_idx = int(tree_idx)
        if tree_idx < 0 or tree_idx >= ts.num_trees:
            continue

        tree = ts.at_index(tree_idx)

        tree_highlights = []
        tree_seeds = []  # for lineage computation
        for node_id in target_node_ids:
            if not tree.is_sample(node_id):
                continue
            tree_highlights.append({
                "node_id": int(node_id),
                "name": name_map[node_id],
            })
            tree_seeds.append(node_id)

        if tree_highlights:
            highlights[tree_idx] = tree_highlights

        if show_lineages and tree_seeds:
            tree_lineages = _compute_lineage_paths(
                tree, tree_seeds, name_map, sample_colors
            )
            if tree_lineages:
                lineage[tree_idx] = tree_lineages

    return {"highlights": highlights, "lineage": lineage}
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
def get_node_details(ts, node_name):
    """Describe one node of the tree sequence.

    Args:
        ts: tskit.TreeSequence
        node_name: Node identifier passed straight to ``ts.node`` (callers in
            this module pass an int node id).

    Returns:
        dict with ``id``, ``time``, ``population``, ``individual`` and
        ``metadata`` (passed through ``make_json_serializable``).
    """
    node = ts.node(node_name)
    return dict(
        id=node.id,
        time=node.time,
        population=node.population,
        individual=node.individual,
        metadata=make_json_serializable(node.metadata),
    )
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
def get_tree_details(ts, tree_index):
    """Summarise the tree at *tree_index* together with its mutations.

    A mutation's inherited state is its parent mutation's derived state, or
    the site's ancestral state when the mutation has no parent (-1).

    Args:
        ts: tskit.TreeSequence
        tree_index: Index passed to ``ts.at_index``.

    Returns:
        dict with ``interval``, ``num_roots``, ``num_nodes`` and
        ``mutations`` (list of dicts: id, node, site_id, position,
        derived_state, inherited_state).
    """
    tree = ts.at_index(tree_index)

    def _describe(mut):
        site = ts.site(mut.site)
        if mut.parent == -1:
            inherited = site.ancestral_state
        else:
            inherited = ts.mutation(mut.parent).derived_state
        return {
            "id": mut.id,
            "node": mut.node,  # node ID, used by the client for highlighting
            "site_id": mut.site,
            "position": site.position,
            "derived_state": mut.derived_state,
            "inherited_state": inherited,
        }

    mutation_rows = [_describe(m) for m in tree.mutations()]
    return {
        "interval": tree.interval,
        "num_roots": tree.num_roots,
        "num_nodes": tree.num_nodes,
        "mutations": mutation_rows,
    }
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
def get_individual_details(ts, individual_id):
    """Return the id, node list and metadata of one individual.

    ``nodes`` and ``metadata`` are passed through ``make_json_serializable``.
    """
    ind = ts.individual(individual_id)
    node_list = make_json_serializable(ind.nodes)
    meta = make_json_serializable(ind.metadata)
    return {"id": ind.id, "nodes": node_list, "metadata": meta}
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
def get_comprehensive_individual_details(ts, individual_id):
    """Return the full individual-table row for *individual_id*.

    Args:
        ts: tskit.TreeSequence
        individual_id: Individual id; ``None`` or ``-1`` means "no individual".

    Returns:
        None for a missing individual. Otherwise a dict containing:
        ``id`` and ``flags`` as ints; ``location`` as a list, or None when
        empty; ``parents`` and ``nodes`` as lists of ints; and ``metadata``
        passed through ``make_json_serializable``.
    """
    if individual_id is None or individual_id == -1:
        return None

    ind = ts.individual(individual_id)
    location = list(ind.location) if len(ind.location) else None
    return {
        "id": int(ind.id),
        "flags": int(ind.flags),
        "location": location,
        "parents": [int(p) for p in ind.parents],
        "nodes": [int(n) for n in ind.nodes],
        "metadata": make_json_serializable(ind.metadata),
    }
|
|
350
|
+
|
|
351
|
+
|
|
352
|
+
def get_population_details(ts, population_id):
    """Return id and metadata for a population, or None when unset.

    ``None`` and ``-1`` both mean "no population" and yield ``None``.
    Metadata is passed through ``make_json_serializable``.
    """
    if population_id is None or population_id == -1:
        return None
    population = ts.population(population_id)
    return {
        "id": int(population.id),
        "metadata": make_json_serializable(population.metadata),
    }
|
|
361
|
+
|
|
362
|
+
|
|
363
|
+
def get_mutations_for_node(ts, node_id, tree_index=None):
    """Get all mutations on a specific node, optionally limited to one tree.

    Args:
        ts: tskit.TreeSequence
        node_id: Node whose mutations are wanted.
        tree_index: If given, only mutations whose site position lies in the
            half-open interval ``[left, right)`` of that tree are returned.

    Returns:
        List of dicts (id, site_id, position, ancestral_state, derived_state,
        time, parent_mutation, metadata) in mutation-table order. ``time`` is
        None when the mutation time is unknown; ``parent_mutation`` is None
        when the mutation has no parent; ``metadata`` is None when empty.
    """
    interval = None
    if tree_index is not None:
        interval = ts.at_index(tree_index).interval

    mutations = []
    for mut in ts.mutations():
        if mut.node != node_id:
            continue

        site = ts.site(mut.site)
        if interval is not None and not (interval.left <= site.position < interval.right):
            continue

        # tskit.UNKNOWN_TIME is a NaN, so `mut.time != tskit.UNKNOWN_TIME` is
        # always True. Use the dedicated predicate so unknown times become
        # None instead of a NaN, which json.dumps would emit as `NaN`
        # (not valid JSON).
        if tskit.is_unknown_time(mut.time):
            mut_time = None
        else:
            mut_time = float(mut.time)

        mutations.append({
            "id": int(mut.id),
            "site_id": int(mut.site),
            "position": float(site.position),
            "ancestral_state": site.ancestral_state,
            "derived_state": mut.derived_state,
            "time": mut_time,
            "parent_mutation": int(mut.parent) if mut.parent != -1 else None,
            "metadata": make_json_serializable(mut.metadata) if mut.metadata else None
        })

    return mutations
|
|
394
|
+
|
|
395
|
+
|
|
396
|
+
def get_edges_for_node(ts, node_id, tree_index=None):
    """Collect the edges in which *node_id* is the parent and/or the child.

    Args:
        ts: tskit.TreeSequence
        node_id: Node to look up.
        tree_index: If given, only edges whose ``[left, right)`` span
            overlaps that tree's interval are considered.

    Returns:
        dict with ``as_parent`` and ``as_child`` lists in edge-table order.
        Each entry holds ``id``, ``parent`` and ``child`` as ints and
        ``left`` and ``right`` as floats. An edge is listed under both keys
        when the node is both its parent and its child.
    """
    as_parent = []
    as_child = []

    window = ts.at_index(tree_index).interval if tree_index is not None else None

    for edge in ts.edges():
        if window is not None and (edge.right <= window.left or edge.left >= window.right):
            continue  # no overlap with the requested tree

        is_parent = edge.parent == node_id
        is_child = edge.child == node_id
        if not (is_parent or is_child):
            continue

        record = {
            "id": int(edge.id),
            "left": float(edge.left),
            "right": float(edge.right),
            "parent": int(edge.parent),
            "child": int(edge.child),
        }
        if is_parent:
            as_parent.append(record)
        if is_child:
            as_child.append(record)

    return {"as_parent": as_parent, "as_child": as_child}
|
|
429
|
+
|
|
430
|
+
|
|
431
|
+
async def handle_details(file_path, data):
    """Handle requests for tree, node, and individual details.

    Args:
        file_path: File whose context is loaded via ``get_file_context``.
        data: Request dict. Recognised keys:
            - ``treeIndex``: tree to describe; in comprehensive mode it also
              restricts the node's mutations and edges to that tree.
            - ``node``: node id.
            - ``comprehensive``: bool, default False.

    Returns:
        A JSON string. On success it contains any of ``tree``, ``node`` and
        ``individual``, plus ``population``, ``mutations`` and ``edges`` in
        comprehensive mode. It is ``{"error": ...}`` when the file is not
        loaded or any step raises.
    """
    try:
        ctx = await get_file_context(file_path)
        if ctx is None:
            return json.dumps({"error": "Tree sequence (ts) is not set. Please upload a file first."})

        ts = ctx.tree_sequence
        return_data = {}
        comprehensive = data.get("comprehensive", False)

        # Normalise once. The client may send the index as a string, and the
        # same value is reused by get_mutations_for_node / get_edges_for_node,
        # which pass it to ts.at_index.
        tree_index = data.get("treeIndex")
        if tree_index is not None:
            tree_index = int(tree_index)
            return_data["tree"] = get_tree_details(ts, tree_index)

        node_name = data.get("node")
        if node_name is not None:
            node_id = int(node_name)
            node_details = get_node_details(ts, node_id)
            return_data["node"] = node_details

            # Auto-fetch individual details (-1 means the node has none)
            individual_id = node_details.get("individual")
            if individual_id != -1:
                if comprehensive:
                    return_data["individual"] = get_comprehensive_individual_details(
                        ts, individual_id
                    )
                else:
                    return_data["individual"] = get_individual_details(
                        ts, individual_id
                    )

            # Comprehensive mode: add population, mutations, edges
            if comprehensive:
                population_id = node_details.get("population")
                if population_id != -1:
                    return_data["population"] = get_population_details(
                        ts, population_id
                    )
                return_data["mutations"] = get_mutations_for_node(
                    ts, node_id, tree_index
                )
                return_data["edges"] = get_edges_for_node(
                    ts, node_id, tree_index
                )

        return json.dumps(return_data)
    except Exception as e:
        # Request boundary: report any failure to the client as JSON.
        return json.dumps({"error": f"Error getting details: {str(e)}"})
|
|
484
|
+
|
|
485
|
+
|
|
486
|
+
async def _collect_csv_tree_graphs(
    ts,
    indices,
    session_id,
    csv_tree_graph_cache,
    max_branch_length,
    samples_order,
    shift_tips_to_one,
):
    """Fetch cached CSV tree graphs, or parse and cache them on a miss.

    Best-effort: rows without a Newick string, or whose Newick fails to parse,
    are skipped (left for ``build_csv_layout_response`` to handle).

    Returns:
        dict of tree index -> parsed graph for the indices that succeeded.
    """
    from lorax.csv.newick_tree import parse_newick_to_tree

    graphs = {}
    has_row_max = "max_branch_length" in ts.columns
    for tree_idx in indices:
        cached = await csv_tree_graph_cache.get(session_id, tree_idx)
        if cached is not None:
            graphs[tree_idx] = cached
            continue

        # Cache miss: parse and store (best-effort)
        try:
            row = ts.iloc[tree_idx]
            newick_str = row.get("newick")
        except Exception:
            continue
        if newick_str is None or pd.isna(newick_str):
            continue

        # Optional per-row override of the global max branch length.
        tree_max_branch_length = None
        if has_row_max:
            try:
                v = row.get("max_branch_length")
                if v is not None and not (isinstance(v, float) and pd.isna(v)) and str(v).strip() != "":
                    tree_max_branch_length = float(v)
            except Exception:
                tree_max_branch_length = None

        try:
            graph = await asyncio.to_thread(
                parse_newick_to_tree,
                str(newick_str),
                max_branch_length,
                samples_order=samples_order,
                tree_max_branch_length=tree_max_branch_length,
                shift_tips_to_one=shift_tips_to_one,
            )
        except Exception:
            continue

        graphs[tree_idx] = graph
        await csv_tree_graph_cache.set(session_id, tree_idx, graph)
    return graphs


async def handle_tree_graph_query(
    file_path,
    tree_indices,
    sparsification=False,
    session_id: str = None,
    tree_graph_cache=None,
    csv_tree_graph_cache=None,
    actual_display_array=None
):
    """
    Construct trees using Numba-optimized tree_graph module.

    CSV files (loaded as a DataFrame) are laid out from their Newick strings
    via ``build_csv_layout_response``. Tree sequences are laid out with
    ``construct_trees_batch``, reusing and populating the per-session cache.

    Args:
        file_path: Path to tree sequence file
        tree_indices: Tree indices to process. Values are converted with
            ``int()``; None is treated as an empty request.
        sparsification: Enable tip-only sparsification (default False)
        session_id: Session ID for cache lookup/storage
        tree_graph_cache: TreeGraphCache instance for caching TreeGraph objects
        csv_tree_graph_cache: Cache for parsed CSV tree graphs
        actual_display_array: Indices currently visible. When given, cached
            graphs outside this set are evicted.

    Returns:
        dict with:
            - buffer: PyArrow IPC binary data containing:
                - node_id: int32 (tskit node ID)
                - parent_id: int32 (-1 for roots)
                - is_tip: bool
                - tree_idx: int32 (which tree this node belongs to)
                - x: float32 (time-based coordinate [0,1])
                - y: float32 (layout-based coordinate [0,1])
            - global_min_time: float
            - global_max_time: float
            - tree_indices: list[int]
        For CSV input, the result of ``build_csv_layout_response`` is
        returned instead. If the file is not loaded, the result is
        ``{"error": ...}``.
    """
    ctx = await get_file_context(file_path)
    if ctx is None:
        return {"error": "Tree sequence not loaded. Please load a file first."}

    ts = ctx.tree_sequence
    # Normalise once so both branches see plain ints, and a None request means
    # "no trees" rather than failing while iterating.
    indices = [] if tree_indices is None else [int(t) for t in tree_indices]

    # CSV support: parse Newick strings and build tree layout
    if isinstance(ts, pd.DataFrame):
        shift_tips_to_one = should_shift_csv_tips(ctx.file_path or file_path)
        # Get max_branch_length from config (times.values[1])
        times_values = ctx.config.get("times", {}).get("values", [0.0, 1.0])
        max_branch_length = float(times_values[1]) if len(times_values) > 1 else 1.0
        samples_order = ctx.config.get("samples") or []

        pre_parsed_graphs = {}
        # `is not None` rather than truthiness: a cache object that defines
        # __len__ would otherwise be skipped while it is still empty.
        if session_id and csv_tree_graph_cache is not None:
            pre_parsed_graphs = await _collect_csv_tree_graphs(
                ts,
                indices,
                session_id,
                csv_tree_graph_cache,
                max_branch_length,
                samples_order,
                shift_tips_to_one,
            )
            if actual_display_array is not None:
                await csv_tree_graph_cache.evict_not_visible(session_id, set(actual_display_array))

        return build_csv_layout_response(
            ts,
            indices,
            max_branch_length,
            samples_order=samples_order,
            pre_parsed_graphs=pre_parsed_graphs,
            shift_tips_to_one=shift_tips_to_one,
        )

    use_cache = bool(session_id) and tree_graph_cache is not None

    # Collect pre-cached TreeGraphs
    pre_cached_graphs = {}
    if use_cache:
        for tree_idx in indices:
            cached = await tree_graph_cache.get(session_id, tree_idx)
            if cached is not None:
                pre_cached_graphs[tree_idx] = cached
        if pre_cached_graphs:
            print(f"TreeGraph cache hits: {len(pre_cached_graphs)}/{len(indices)} trees")

    # Run in thread pool to avoid blocking the event loop
    buffer, min_time, max_time, processed_indices, newly_built = await asyncio.to_thread(
        construct_trees_batch,
        ts,
        indices,
        sparsification=sparsification,
        pre_cached_graphs=pre_cached_graphs,
    )

    # Cache newly built TreeGraphs
    if use_cache and newly_built:
        for tree_idx, graph in newly_built.items():
            await tree_graph_cache.set(session_id, tree_idx, graph)
        print(f"TreeGraph cached: {len(newly_built)} new trees for session {session_id[:8]}...")

    # Evict trees no longer in visible set (visibility-based eviction)
    if use_cache and actual_display_array is not None:
        await tree_graph_cache.evict_not_visible(session_id, set(actual_display_array))

    return {
        "buffer": buffer,
        "global_min_time": min_time,
        "global_max_time": max_time,
        "tree_indices": processed_indices
    }
|
|
623
|
+
|
|
624
|
+
|
|
625
|
+
async def get_or_construct_tree_graph(
    file_path: str,
    tree_index: int,
    session_id: str,
    tree_graph_cache
) -> "TreeGraph | None":
    """
    Get a TreeGraph from cache or construct and cache it.

    This function is used by lineage operations that need the full TreeGraph
    structure for ancestor/descendant traversal.

    Args:
        file_path: Path to tree sequence file
        tree_index: Index of the tree to get
        session_id: Session ID for cache key
        tree_graph_cache: TreeGraphCache instance

    Returns:
        TreeGraph object. Returns None if the file is not loaded or is a CSV
        (TreeGraphs are only built for tskit tree sequences).
    """
    # Check cache first
    cached = await tree_graph_cache.get(session_id, tree_index)
    if cached is not None:
        print(f"TreeGraph cache hit: session={session_id[:8]}... tree={tree_index}")
        return cached

    # Load file context
    ctx = await get_file_context(file_path)
    if ctx is None:
        return None

    ts = ctx.tree_sequence

    # Can't construct TreeGraph for CSV
    if isinstance(ts, pd.DataFrame):
        return None

    def _construct():
        # ts.tables returns a copy of the whole table collection, so take it
        # once instead of once per table.
        tables = ts.tables
        return construct_tree(
            ts,
            tables.edges,
            tables.nodes,
            list(ts.breakpoints()),
            tree_index,
            float(ts.min_time),
            float(ts.max_time),
        )

    tree_graph = await asyncio.to_thread(_construct)

    # Cache it
    await tree_graph_cache.set(session_id, tree_index, tree_graph)
    print(f"TreeGraph cached: session={session_id[:8]}... tree={tree_index}")

    return tree_graph
|
|
679
|
+
|
|
680
|
+
|
|
681
|
+
async def ensure_trees_cached(
    file_path: str,
    tree_indices: list,
    session_id: str,
    tree_graph_cache
) -> int:
    """
    Ensure multiple trees are cached for a session.

    This is called after process_postorder_layout to cache trees for
    subsequent lineage operations.

    Args:
        file_path: Path to tree sequence file
        tree_indices: List of tree indices to cache (converted with int())
        session_id: Session ID for cache key
        tree_graph_cache: TreeGraphCache instance

    Returns:
        Number of trees newly cached (not already in cache). Returns 0 when
        the file is not loaded or is a CSV.
    """
    ctx = await get_file_context(file_path)
    if ctx is None:
        return 0

    ts = ctx.tree_sequence

    if isinstance(ts, pd.DataFrame):
        return 0

    def _construct(idx, table_args):
        edges, nodes, breakpoints, min_time, max_time = table_args
        return construct_tree(ts, edges, nodes, breakpoints, idx, min_time, max_time)

    # Table data is extracted lazily on the first cache miss. ts.tables copies
    # the whole table collection, which is wasted work when every requested
    # tree is already cached; taking it once also avoids one copy per table.
    table_args = None
    newly_cached = 0

    for tree_index in tree_indices:
        tree_index = int(tree_index)

        # Skip if already cached
        if await tree_graph_cache.get(session_id, tree_index) is not None:
            continue

        if table_args is None:
            tables = ts.tables
            table_args = (
                tables.edges,
                tables.nodes,
                list(ts.breakpoints()),
                float(ts.min_time),
                float(ts.max_time),
            )

        # Construct off the event loop, then cache
        tree_graph = await asyncio.to_thread(_construct, tree_index, table_args)
        await tree_graph_cache.set(session_id, tree_index, tree_graph)
        newly_cached += 1

    if newly_cached > 0:
        print(f"Cached {newly_cached} trees for session {session_id[:8]}...")

    return newly_cached
|
|
740
|
+
|
|
741
|
+
|
|
742
|
+
def _get_matching_sample_nodes(ts, metadata_key, metadata_value, sources, sample_name_key):
    """Return the sample node ids whose metadata value equals *metadata_value*.

    Each sample's value is looked up with ``_get_sample_metadata_value``. The
    comparison is done on ``str()`` forms, and samples whose value is None
    never match.

    Args:
        ts: tskit.TreeSequence
        metadata_key: Metadata key to filter by
        metadata_value: Metadata value to match
        sources: Metadata sources to search
        sample_name_key: Key in node metadata used as sample name

    Returns:
        Set of matching node IDs
    """
    wanted = str(metadata_value)
    matches = set()
    for sample in ts.samples():
        _, value = _get_sample_metadata_value(
            ts, sample, metadata_key, sources, sample_name_key
        )
        if value is None:
            continue
        if str(value) == wanted:
            matches.add(sample)
    return matches
|
|
764
|
+
|
|
765
|
+
|
|
766
|
+
async def _ensure_tree_graph_loaded(
    ts,
    tree_idx,
    session_id,
    tree_graph_cache,
    edges,
    nodes,
    breakpoints,
    min_time,
    max_time
):
    """
    Return the TreeGraph for ``tree_idx``, building and caching it on a miss.

    Args:
        ts: tskit.TreeSequence
        tree_idx: Tree index to load
        session_id: Session ID for cache key
        tree_graph_cache: TreeGraphCache instance (awaitable get/set)
        edges, nodes, breakpoints, min_time, max_time: Pre-extracted table
            data forwarded to construct_tree

    Returns:
        TreeGraph object
    """
    from lorax.tree_graph import construct_tree

    tree_graph = await tree_graph_cache.get(session_id, tree_idx)
    if tree_graph is None:
        # Build in a worker thread (asyncio.to_thread) so construction does
        # not block the event loop, then store it for later requests.
        tree_graph = await asyncio.to_thread(
            construct_tree,
            ts, edges, nodes, breakpoints, tree_idx, min_time, max_time,
        )
        await tree_graph_cache.set(session_id, tree_idx, tree_graph)
    return tree_graph
|
|
807
|
+
|
|
808
|
+
|
|
809
|
+
async def get_highlight_positions(
    ts,
    file_path,
    metadata_key,
    metadata_value,
    tree_indices,
    session_id: str,
    tree_graph_cache,
    sources=("individual", "node", "population"),
    sample_name_key="name"
):
    """
    Get positions for all tip nodes with a specific metadata value.
    Uses cached TreeGraph objects when available.

    Args:
        ts: tskit.TreeSequence
        file_path: Path to tree sequence file. Not read by this function;
            kept for interface compatibility (graph caching is keyed by
            session_id and tree index).
        metadata_key: Metadata key to filter by
        metadata_value: Metadata value to match (compared as strings)
        tree_indices: List of tree indices to compute positions for;
            indices outside [0, ts.num_trees) are skipped
        session_id: Session ID for cache lookup
        tree_graph_cache: TreeGraphCache instance
        sources: Metadata sources to search
        sample_name_key: Key in node metadata used as sample name

    Returns:
        dict with:
        - positions: List of {node_id, tree_idx, x, y} dicts, ordered by the
          requested tree order and then by ascending node ID
    """
    if not tree_indices:
        return {"positions": []}

    # Get sample node IDs that have this metadata value
    matching_node_ids = _get_matching_sample_nodes(
        ts, metadata_key, metadata_value, sources, sample_name_key
    )
    if not matching_node_ids:
        return {"positions": []}

    # Set iteration order is an implementation detail; sort once so the
    # order of the response is well defined.
    ordered_node_ids = sorted(matching_node_ids)

    # Pre-extract tables for reuse (only needed on a cache miss). Read
    # ts.tables once: depending on the tskit version this property may
    # return a copy of the whole table collection on every access.
    tables = ts.tables
    edges = tables.edges
    nodes = tables.nodes
    breakpoints = list(ts.breakpoints())
    min_time = float(ts.min_time)
    max_time = float(ts.max_time)
    num_trees = ts.num_trees

    positions = []

    # For each requested tree, get graph and extract positions
    for raw_idx in tree_indices:
        tree_idx = int(raw_idx)
        if not 0 <= tree_idx < num_trees:
            continue

        graph = await _ensure_tree_graph_loaded(
            ts, tree_idx, session_id, tree_graph_cache,
            edges, nodes, breakpoints, min_time, max_time
        )

        # Only report matching samples that are present in this tree
        positions.extend(
            {
                "node_id": int(node_id),
                "tree_idx": tree_idx,
                "x": float(graph.x[node_id]),
                "y": float(graph.y[node_id]),
            }
            for node_id in ordered_node_ids
            if graph.in_tree[node_id]
        )

    return {"positions": positions}
|
|
881
|
+
|
|
882
|
+
|
|
883
|
+
async def get_multi_value_highlight_positions(
    ts,
    file_path,
    metadata_key,
    metadata_values,  # List[str] - Array of values (OR logic)
    tree_indices,
    session_id: str,
    tree_graph_cache,
    show_lineages: bool = False,
    sources=("individual", "node", "population"),
    sample_name_key="name"
):
    """
    Get positions for tip nodes matching ANY of the metadata values.
    Returns positions grouped by value for per-value coloring.

    Args:
        ts: tskit.TreeSequence
        file_path: Path to tree sequence file. Not read by this function;
            kept for interface compatibility (graph caching is keyed by
            session_id and tree index).
        metadata_key: Metadata key to filter by
        metadata_values: List of metadata values to match (OR logic);
            values are compared as strings and deduplicated in first-seen order
        tree_indices: List of tree indices to compute positions for;
            indices outside [0, ts.num_trees) are skipped
        session_id: Session ID for cache lookup
        tree_graph_cache: TreeGraphCache instance
        show_lineages: Whether to compute lineage (ancestry) paths
        sources: Metadata sources to search
        sample_name_key: Key in node metadata used as sample name

    Returns:
        dict with:
        - positions_by_value: {"Africa": [{node_id, tree_idx, x, y}, ...], ...}
          (one key per unique value, in first-seen order; within a value,
          ordered by requested tree and then ascending node ID)
        - lineages: {"Africa": {tree_idx: <_compute_lineage_paths result>}}
          only present if show_lineages; values/trees without paths omitted
        - total_count: int
    """
    if not tree_indices or not metadata_values:
        return {"positions_by_value": {}, "lineages": {}, "total_count": 0}

    # Deduplicate while preserving input order. A plain set() would make the
    # key order of the response vary between runs (str hash randomization).
    unique_values = list(dict.fromkeys(str(v) for v in metadata_values))

    # One metadata pass over the samples for all values, instead of one
    # full pass per value.
    matches_by_value = _group_sample_nodes_by_value(
        ts, metadata_key, unique_values, sources, sample_name_key
    )

    positions_by_value = {value: [] for value in unique_values}
    lineages_by_value = {value: {} for value in unique_values}

    # Values that actually have matching samples, with node IDs sorted so the
    # response order is well defined (sets have no meaningful order).
    active_values = [
        (value, sorted(node_ids))
        for value, node_ids in matches_by_value.items()
        if node_ids
    ]

    if active_values:
        # Pre-extract tables for reuse (only needed on a cache miss). Read
        # ts.tables once: depending on the tskit version this property may
        # return a copy of the whole table collection on every access.
        tables = ts.tables
        edges = tables.edges
        nodes = tables.nodes
        breakpoints = list(ts.breakpoints())
        min_time = float(ts.min_time)
        max_time = float(ts.max_time)
        num_trees = ts.num_trees

        for raw_idx in tree_indices:
            tree_idx = int(raw_idx)
            if not 0 <= tree_idx < num_trees:
                continue

            # Load each tree's graph once and share it across all values.
            graph = await _ensure_tree_graph_loaded(
                ts, tree_idx, session_id, tree_graph_cache,
                edges, nodes, breakpoints, min_time, max_time
            )

            tree = None  # tskit.Tree, created lazily only if lineages are needed
            for value, node_ids in active_values:
                # Matching samples that are present in this tree
                tree_seeds = [nid for nid in node_ids if graph.in_tree[nid]]
                if not tree_seeds:
                    continue

                positions_by_value[value].extend(
                    {
                        "node_id": int(nid),
                        "tree_idx": tree_idx,
                        "x": float(graph.x[nid]),
                        "y": float(graph.y[nid]),
                    }
                    for nid in tree_seeds
                )

                if show_lineages:
                    if tree is None:
                        # NOTE(review): the tree is reused across values; this
                        # assumes _compute_lineage_paths only reads it (no
                        # seek/next) — confirm if that helper changes.
                        tree = ts.at_index(tree_idx)
                    name_map = {nid: str(nid) for nid in tree_seeds}
                    tree_lineages = _compute_lineage_paths(
                        tree, tree_seeds, name_map, None  # No per-sample colors, use value color
                    )
                    if tree_lineages:
                        lineages_by_value[value][tree_idx] = tree_lineages

    total_count = sum(len(p) for p in positions_by_value.values())

    result = {
        "positions_by_value": positions_by_value,
        "total_count": total_count
    }

    if show_lineages:
        # Keep only values that produced at least one lineage
        result["lineages"] = {
            value: per_tree
            for value, per_tree in lineages_by_value.items()
            if per_tree
        }

    return result


def _group_sample_nodes_by_value(ts, metadata_key, wanted_values, sources, sample_name_key):
    """
    Map each wanted metadata value to the set of sample node IDs carrying it.

    Single-pass equivalent of calling _get_matching_sample_nodes once per
    value: every sample's metadata is looked up exactly once.

    Args:
        ts: tskit.TreeSequence
        metadata_key: Metadata key to filter by
        wanted_values: Iterable of str values; a sample matches when
            str(its value) equals one of them
        sources: Metadata sources to search
        sample_name_key: Key in node metadata used as sample name

    Returns:
        dict {value: set of node IDs}, with an entry (possibly an empty set)
        for every wanted value, in iteration order of wanted_values
    """
    groups = {value: set() for value in wanted_values}
    for node_id in ts.samples():
        _, value = _get_sample_metadata_value(
            ts, node_id, metadata_key, sources, sample_name_key
        )
        if value is None:
            continue
        bucket = groups.get(str(value))
        if bucket is not None:
            bucket.add(node_id)
    return groups
|