lorax-arg 0.1 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. lorax/buffer.py +43 -0
  2. lorax/cache/__init__.py +43 -0
  3. lorax/cache/csv_tree_graph.py +59 -0
  4. lorax/cache/disk.py +467 -0
  5. lorax/cache/file_cache.py +142 -0
  6. lorax/cache/file_context.py +72 -0
  7. lorax/cache/lru.py +90 -0
  8. lorax/cache/tree_graph.py +293 -0
  9. lorax/cli.py +312 -0
  10. lorax/cloud/__init__.py +0 -0
  11. lorax/cloud/gcs_utils.py +205 -0
  12. lorax/constants.py +66 -0
  13. lorax/context.py +80 -0
  14. lorax/csv/__init__.py +7 -0
  15. lorax/csv/config.py +250 -0
  16. lorax/csv/layout.py +182 -0
  17. lorax/csv/newick_tree.py +234 -0
  18. lorax/handlers.py +998 -0
  19. lorax/lineage.py +456 -0
  20. lorax/loaders/__init__.py +0 -0
  21. lorax/loaders/csv_loader.py +10 -0
  22. lorax/loaders/loader.py +31 -0
  23. lorax/loaders/tskit_loader.py +119 -0
  24. lorax/lorax_app.py +75 -0
  25. lorax/manager.py +58 -0
  26. lorax/metadata/__init__.py +0 -0
  27. lorax/metadata/loader.py +426 -0
  28. lorax/metadata/mutations.py +146 -0
  29. lorax/modes.py +190 -0
  30. lorax/pg.py +183 -0
  31. lorax/redis_utils.py +30 -0
  32. lorax/routes.py +137 -0
  33. lorax/session_manager.py +206 -0
  34. lorax/sockets/__init__.py +55 -0
  35. lorax/sockets/connection.py +99 -0
  36. lorax/sockets/debug.py +47 -0
  37. lorax/sockets/decorators.py +112 -0
  38. lorax/sockets/file_ops.py +200 -0
  39. lorax/sockets/lineage.py +307 -0
  40. lorax/sockets/metadata.py +232 -0
  41. lorax/sockets/mutations.py +154 -0
  42. lorax/sockets/node_search.py +535 -0
  43. lorax/sockets/tree_layout.py +117 -0
  44. lorax/sockets/utils.py +10 -0
  45. lorax/tree_graph/__init__.py +12 -0
  46. lorax/tree_graph/tree_graph.py +689 -0
  47. lorax/utils.py +124 -0
  48. lorax_app/__init__.py +4 -0
  49. lorax_app/app.py +159 -0
  50. lorax_app/cli.py +114 -0
  51. lorax_app/static/X.png +0 -0
  52. lorax_app/static/assets/index-BCEGlUFi.js +2361 -0
  53. lorax_app/static/assets/index-iKjzUpA9.css +1 -0
  54. lorax_app/static/assets/localBackendWorker-BaWwjSV_.js +2 -0
  55. lorax_app/static/assets/renderDataWorker-BKLdiU7J.js +2 -0
  56. lorax_app/static/gestures/gesture-flick.ogv +0 -0
  57. lorax_app/static/gestures/gesture-two-finger-scroll.ogv +0 -0
  58. lorax_app/static/index.html +14 -0
  59. lorax_app/static/logo.png +0 -0
  60. lorax_app/static/lorax-logo.png +0 -0
  61. lorax_app/static/vite.svg +1 -0
  62. lorax_arg-0.1.dist-info/METADATA +131 -0
  63. lorax_arg-0.1.dist-info/RECORD +66 -0
  64. lorax_arg-0.1.dist-info/WHEEL +5 -0
  65. lorax_arg-0.1.dist-info/entry_points.txt +4 -0
  66. lorax_arg-0.1.dist-info/top_level.txt +2 -0
lorax/tree_graph/tree_graph.py
@@ -0,0 +1,689 @@
+ """
+ tree_graph.py - Numba-optimized tree construction from tskit tables.
+
+ This module provides:
+ - TreeGraph: Numpy-based tree representation with CSR children and x,y coordinates
+ - construct_tree: Build a single tree from tables (Numba-optimized)
+ - construct_trees_batch: Build multiple trees efficiently (includes mutation extraction)
+ """
+
+ import struct
+ import numpy as np
+ import pyarrow as pa
+ from numba import njit
+ from numba.typed import Dict
+ from numba import types
+ from dataclasses import dataclass
+ from typing import List, Optional, Tuple
+
+ # Default cell size for sparsification (0.2% of normalized [0,1] space)
+ DEFAULT_SPARSIFY_CELL_SIZE = 0.002
+
+
+ @njit(cache=True)
+ def _compute_x_postorder(children_indptr, children_data, roots, num_nodes):
+     """
+     Numba-compiled post-order traversal for computing x (layout) coordinates.
+
+     Tips get sequential x values (0, 1, 2, ...), internal nodes get (min + max) / 2 of children.
+
+     Args:
+         children_indptr: CSR indptr array
+         children_data: CSR data array (flattened children)
+         roots: Array of root node IDs
+         num_nodes: Total number of nodes
+
+     Returns:
+         (x, tip_counter): x coordinates array and number of tips
+     """
+     x = np.full(num_nodes, -1.0, dtype=np.float32)
+     tip_counter = 0
+
+     # Pre-allocated stack arrays (avoid Python list)
+     stack_nodes = np.empty(num_nodes, dtype=np.int32)
+     stack_visited = np.empty(num_nodes, dtype=np.uint8)  # 0=False, 1=True
+
+     for i in range(len(roots)):
+         root = roots[i]
+         stack_ptr = 0
+
+         stack_nodes[stack_ptr] = root
+         stack_visited[stack_ptr] = 0
+         stack_ptr += 1
+
+         while stack_ptr > 0:
+             stack_ptr -= 1
+             node = stack_nodes[stack_ptr]
+             visited = stack_visited[stack_ptr]
+
+             start = children_indptr[node]
+             end = children_indptr[node + 1]
+             num_children = end - start
+
+             if visited == 0 and num_children > 0:
+                 # Push node back as visited
+                 stack_nodes[stack_ptr] = node
+                 stack_visited[stack_ptr] = 1
+                 stack_ptr += 1
+
+                 # Push children
+                 for j in range(start, end):
+                     stack_nodes[stack_ptr] = children_data[j]
+                     stack_visited[stack_ptr] = 0
+                     stack_ptr += 1
+             else:
+                 # Post-order processing
+                 if num_children == 0:
+                     x[node] = tip_counter
+                     tip_counter += 1
+                 else:
+                     # Compute (min + max) / 2 of children (matches jstree.js)
+                     min_x = x[children_data[start]]
+                     max_x = x[children_data[start]]
+                     for j in range(start + 1, end):
+                         child_x = x[children_data[j]]
+                         if child_x < min_x:
+                             min_x = child_x
+                         if child_x > max_x:
+                             max_x = child_x
+                     x[node] = (min_x + max_x) / 2.0
+
+     return x, tip_counter
+
+
+ @dataclass
+ class TreeGraph:
+     """
+     Graph representation using numpy arrays with CSR format for children.
+
+     Attributes:
+         parent: int32 array where parent[node_id] = parent_id (-1 for root)
+         time: float32 array of raw node times
+         children_indptr: int32 CSR row pointers (length = num_nodes + 1)
+         children_data: int32 flattened children array
+         x: float32 layout position [0,1] (tips spread, internal=(min+max)/2 of children)
+         y: float32 normalized time [0,1] (min_time=0, max_time=1)
+         in_tree: bool array indicating which nodes are in this tree
+     """
+     parent: np.ndarray
+     time: np.ndarray
+     children_indptr: np.ndarray
+     children_data: np.ndarray
+     x: np.ndarray
+     y: np.ndarray
+     in_tree: np.ndarray
+
+     def children(self, node_id: int) -> np.ndarray:
+         """Get children of a node as numpy array slice (zero-copy)."""
+         return self.children_data[self.children_indptr[node_id]:self.children_indptr[node_id + 1]]
+
+     def is_tip(self, node_id: int) -> bool:
+         """Check if a node is a tip (no children)."""
+         return self.children_indptr[node_id + 1] == self.children_indptr[node_id]
+
+     def get_node_x(self, node_id: int) -> float:
+         """Get the x (layout) coordinate for a node."""
+         return self.x[node_id] if node_id >= 0 and node_id < len(self.x) else 0.5
+
+     def to_pyarrow(self, tree_idx: int = 0) -> bytes:
+         """
+         Serialize TreeGraph to PyArrow IPC format for frontend rendering.
+
+         Note: Coordinates are swapped to match backend convention:
+         - Backend x = self.y (time)
+         - Backend y = self.x (layout)
+
+         Args:
+             tree_idx: Tree index to include in output (for multi-tree rendering)
+
+         Returns:
+             bytes: PyArrow IPC binary data ready to send to frontend
+         """
+         # Get nodes that are in this tree
+         indices = np.where(self.in_tree)[0].astype(np.int32)
+         n = len(indices)
+
+         if n == 0:
+             # Return empty table
+             table = pa.table({
+                 'node_id': pa.array([], type=pa.int32()),
+                 'parent_id': pa.array([], type=pa.int32()),
+                 'is_tip': pa.array([], type=pa.bool_()),
+                 'tree_idx': pa.array([], type=pa.int32()),
+                 'x': pa.array([], type=pa.float32()),
+                 'y': pa.array([], type=pa.float32()),
+             })
+         else:
+             # Derive is_tip from CSR: nodes with no children
+             child_counts = np.diff(self.children_indptr)
+             is_tip = child_counts[indices] == 0
+
+             # Build PyArrow table (swap x<->y for backend convention)
+             table = pa.table({
+                 'node_id': pa.array(indices, type=pa.int32()),
+                 'parent_id': pa.array(self.parent[indices], type=pa.int32()),
+                 'is_tip': pa.array(is_tip, type=pa.bool_()),
+                 'tree_idx': pa.array(np.full(n, tree_idx, dtype=np.int32), type=pa.int32()),
+                 'x': pa.array(self.y[indices], type=pa.float32()),  # SWAP: time -> x
+                 'y': pa.array(self.x[indices], type=pa.float32()),  # SWAP: layout -> y
+             })
+
+         # Serialize to IPC format
+         sink = pa.BufferOutputStream()
+         writer = pa.ipc.new_stream(sink, table.schema)
+         writer.write_table(table)
+         writer.close()
+
+         return sink.getvalue().to_pybytes()
+
+
+ def construct_tree(
+     ts,
+     edges,
+     nodes,
+     breakpoints,
+     index: int,
+     min_time: Optional[float] = None,
+     max_time: Optional[float] = None
+ ) -> TreeGraph:
+     """
+     Construct tree with x,y coordinates using Numba-optimized post-order traversal.
+
+     Args:
+         ts: tskit TreeSequence object
+         edges: ts.tables.edges (pre-extracted for reuse)
+         nodes: ts.tables.nodes (pre-extracted for reuse)
+         breakpoints: list/array of breakpoints (pre-extracted for reuse)
+         index: Tree index
+         min_time: Optional global min time (default: ts.min_time)
+         max_time: Optional global max time (default: ts.max_time)
+
+     Returns:
+         TreeGraph with CSR children and x,y coordinates in [0,1].
+     """
+     if index < 0 or index >= ts.num_trees:
+         raise ValueError(f"Tree index {index} out of range [0, {ts.num_trees - 1}]")
+
+     interval_left = breakpoints[index]
+     num_nodes = len(nodes.time)
+     node_times = nodes.time
+
+     # Use provided min/max or compute from ts
+     if min_time is None:
+         min_time = ts.min_time
+     if max_time is None:
+         max_time = ts.max_time
+
+     # === Edge filtering & parent array ===
+     active_mask = (edges.left <= interval_left) & (edges.right > interval_left)
+     active_parents = edges.parent[active_mask]
+     active_children = edges.child[active_mask]
+
+     parent = np.full(num_nodes, -1, dtype=np.int32)
+     parent[active_children] = active_parents
+
+     # === CSR children structure ===
+     child_counts = np.bincount(active_parents, minlength=num_nodes).astype(np.int32)
+     children_indptr = np.zeros(num_nodes + 1, dtype=np.int32)
+     children_indptr[1:] = np.cumsum(child_counts)
+     sort_idx = np.argsort(active_parents, kind='stable')
+     children_data = active_children[sort_idx].astype(np.int32)
+
+     # === Track which nodes are in this tree ===
+     in_tree = np.zeros(num_nodes, dtype=np.bool_)
+     in_tree[active_children] = True
+     in_tree[active_parents] = True
+
+     # === Y coordinate: normalized time (vectorized) ===
+     # Inverted: max_time → 0, min_time → 1 (root at left, tips at right)
+     time_range = max_time - min_time if max_time > min_time else 1.0
+     y = ((max_time - node_times) / time_range).astype(np.float32)
+
+     # === X coordinate: Numba-optimized post-order traversal ===
+     roots = np.where(in_tree & (parent == -1))[0].astype(np.int32)
+     x, tip_counter = _compute_x_postorder(children_indptr, children_data, roots, num_nodes)
+
+     # Normalize x to [0, 1]
+     if tip_counter > 1:
+         x[in_tree] /= (tip_counter - 1)
+
+     return TreeGraph(
+         parent=parent,
+         time=node_times.astype(np.float32),
+         children_indptr=children_indptr,
+         children_data=children_data,
+         x=x,
+         y=y,
+         in_tree=in_tree
+     )
+
+
+ def construct_trees_batch(
+     ts,
+     tree_indices: List[int],
+     sparsification: bool = False,
+     include_mutations: bool = True,
+     pre_cached_graphs: Optional[dict] = None
+ ) -> Tuple[bytes, float, float, List[int], dict]:
+     """
+     Construct multiple trees and return combined PyArrow buffer.
+
+     This is the main entry point for the backend handler.
+
+     Args:
+         ts: tskit TreeSequence object
+         tree_indices: List of tree indices to process
+         sparsification: Enable tip-only sparsification (default False)
+         include_mutations: Whether to include mutation data in buffer
+         pre_cached_graphs: Optional dict mapping tree_idx -> TreeGraph for cache hits
+
+     Returns:
+         Tuple of (buffer, global_min_time, global_max_time, tree_indices, newly_built_graphs)
+         where newly_built_graphs is a dict mapping tree_idx -> TreeGraph for trees constructed
+     """
+     # Pre-extract tables for reuse
+     edges = ts.tables.edges
+     nodes = ts.tables.nodes
+     breakpoints = list(ts.breakpoints())
+
+     min_time = float(ts.min_time)
+     max_time = float(ts.max_time)
+
+     # Check if tree sequence has mutations
+     has_mutations = include_mutations and ts.num_mutations > 0
+
+     # Pre-extract mutation tables and positions (avoid repeated lookups in loop)
+     if has_mutations:
+         sites = ts.tables.sites
+         mutations = ts.tables.mutations
+         mutation_positions = sites.position[mutations.site]
+     else:
+         sites = None
+         mutations = None
+         mutation_positions = None
+
+     if len(tree_indices) == 0:
+         # Return empty buffer with separate node and mutation tables
+         node_table = pa.table({
+             'node_id': pa.array([], type=pa.int32()),
+             'parent_id': pa.array([], type=pa.int32()),
+             'is_tip': pa.array([], type=pa.bool_()),
+             'tree_idx': pa.array([], type=pa.int32()),
+             'x': pa.array([], type=pa.float32()),
+             'y': pa.array([], type=pa.float32()),
+         })
+         mut_table = pa.table({
+             'mut_x': pa.array([], type=pa.float32()),
+             'mut_y': pa.array([], type=pa.float32()),
+             'mut_tree_idx': pa.array([], type=pa.int32()),
+         })
+
+         node_sink = pa.BufferOutputStream()
+         node_writer = pa.ipc.new_stream(node_sink, node_table.schema)
+         node_writer.write_table(node_table)
+         node_writer.close()
+         node_bytes = node_sink.getvalue().to_pybytes()
+
+         mut_sink = pa.BufferOutputStream()
+         mut_writer = pa.ipc.new_stream(mut_sink, mut_table.schema)
+         mut_writer.write_table(mut_table)
+         mut_writer.close()
+         mut_bytes = mut_sink.getvalue().to_pybytes()
+
+         combined = struct.pack('<I', len(node_bytes)) + node_bytes + mut_bytes
+         return combined, min_time, max_time, [], {}
+
+     # Estimate total nodes for pre-allocation
+     sample_tree = ts.at_index(int(tree_indices[0]) if tree_indices else 0)
+     estimated_nodes_per_tree = sample_tree.num_nodes
+     total_estimated = estimated_nodes_per_tree * len(tree_indices) * 2
+
+     # Pre-allocate node arrays
+     all_node_ids = np.empty(total_estimated, dtype=np.int32)
+     all_parent_ids = np.empty(total_estimated, dtype=np.int32)
+     all_is_tip = np.empty(total_estimated, dtype=np.bool_)
+     all_tree_idx = np.empty(total_estimated, dtype=np.int32)
+     all_x = np.empty(total_estimated, dtype=np.float32)
+     all_y = np.empty(total_estimated, dtype=np.float32)
+
+     # Pre-allocate mutation arrays (estimate based on mutation density)
+     # Simplified: only x, y, tree_idx needed
+     estimated_mutations = max(1000, ts.num_mutations // max(1, ts.num_trees) * len(tree_indices) * 2)
+     all_mut_tree_idx = np.empty(estimated_mutations, dtype=np.int32) if has_mutations else None
+     all_mut_x = np.empty(estimated_mutations, dtype=np.float32) if has_mutations else None
+     all_mut_y = np.empty(estimated_mutations, dtype=np.float32) if has_mutations else None
+     all_mut_node_id = np.empty(estimated_mutations, dtype=np.int32) if has_mutations else None
+
+     offset = 0
+     mut_offset = 0
+     processed_indices = []
+
+     # Initialize cache tracking
+     pre_cached_graphs = pre_cached_graphs or {}
+     newly_built_graphs = {}
+
+     for tree_idx in tree_indices:
+         tree_idx = int(tree_idx)
+
+         if tree_idx < 0 or tree_idx >= ts.num_trees:
+             continue
+
+         # Check pre-cached first, then construct if needed
+         if tree_idx in pre_cached_graphs:
+             graph = pre_cached_graphs[tree_idx]
+         else:
+             graph = construct_tree(ts, edges, nodes, breakpoints, tree_idx, min_time, max_time)
+             newly_built_graphs[tree_idx] = graph  # Track for caching
+
+         # Get nodes in tree
+         indices = np.where(graph.in_tree)[0].astype(np.int32)
+         n = len(indices)
+
+         if n == 0:
+             continue
+
+         # Derive is_tip
+         child_counts = np.diff(graph.children_indptr)
+         is_tip = child_counts[indices] == 0
+
+         # Get coordinates (swap for backend convention)
+         node_ids = indices
+         parent_ids = graph.parent[indices]
+         x = graph.y[indices]  # SWAP: time -> x
+         y = graph.x[indices]  # SWAP: layout -> y
+
+         # Apply sparsification if requested
+         if sparsification:
+             resolution = int(1.0 / DEFAULT_SPARSIFY_CELL_SIZE)
+             keep_mask = _sparsify_tips_only(
+                 node_ids.astype(np.int32),
+                 x.astype(np.float32),
+                 y.astype(np.float32),
+                 is_tip,
+                 parent_ids.astype(np.int32),
+                 resolution
+             )
+             node_ids = node_ids[keep_mask]
+             parent_ids = parent_ids[keep_mask]
+             x = x[keep_mask]
+             y = y[keep_mask]
+             is_tip = is_tip[keep_mask]
+             n = len(node_ids)
+
+             if n > 0:
+                 # Collapse unary internal nodes (keep roots even if unary).
+                 order = np.argsort(node_ids)
+                 sorted_ids = node_ids[order]
+                 pos = np.searchsorted(sorted_ids, parent_ids)
+                 parent_local = np.full(n, -1, dtype=np.int32)
+                 valid = (parent_ids != -1) & (pos < n) & (sorted_ids[pos] == parent_ids)
+                 parent_local[valid] = order[pos[valid]]
+                 child_counts = np.bincount(parent_local[parent_local >= 0], minlength=n)
+                 collapse_mask = child_counts != 1
+
+                 if not np.all(collapse_mask):
+                     new_parent_ids = parent_ids.copy()
+                     for i in range(n):
+                         if not collapse_mask[i]:
+                             continue
+                         parent = new_parent_ids[i]
+                         while parent != -1:
+                             pos = np.searchsorted(sorted_ids, parent)
+                             if pos >= n or sorted_ids[pos] != parent:
+                                 parent = -1
+                                 break
+                             parent_idx = order[pos]
+                             if collapse_mask[parent_idx]:
+                                 break
+                             parent = new_parent_ids[parent_idx]
+                         new_parent_ids[i] = parent
+
+                     node_ids = node_ids[collapse_mask]
+                     parent_ids = new_parent_ids[collapse_mask]
+                     x = x[collapse_mask]
+                     y = y[collapse_mask]
+                     n = len(node_ids)
+
+                     if n > 0:
+                         order = np.argsort(node_ids)
+                         sorted_ids = node_ids[order]
+                         pos = np.searchsorted(sorted_ids, parent_ids)
+                         parent_local = np.full(n, -1, dtype=np.int32)
+                         valid = (parent_ids != -1) & (pos < n) & (sorted_ids[pos] == parent_ids)
+                         parent_local[valid] = order[pos[valid]]
+                         child_counts = np.bincount(parent_local[parent_local >= 0], minlength=n)
+                         is_tip = child_counts == 0
+                     else:
+                         is_tip = is_tip[:0]
+
+         if n == 0:
+             continue
+
+         # Ensure capacity
+         while offset + n > len(all_node_ids):
+             new_size = len(all_node_ids) * 2
+             all_node_ids.resize(new_size, refcheck=False)
+             all_parent_ids.resize(new_size, refcheck=False)
+             all_is_tip.resize(new_size, refcheck=False)
+             all_tree_idx.resize(new_size, refcheck=False)
+             all_x.resize(new_size, refcheck=False)
+             all_y.resize(new_size, refcheck=False)
+
+         # Copy node data
+         all_node_ids[offset:offset+n] = node_ids
+         all_parent_ids[offset:offset+n] = parent_ids
+         all_is_tip[offset:offset+n] = is_tip
+         all_tree_idx[offset:offset+n] = tree_idx
+         all_x[offset:offset+n] = x
+         all_y[offset:offset+n] = y
+
+         offset += n
+
+         # Collect mutations for this tree interval (inline computation)
+         if has_mutations:
+             interval_left = breakpoints[tree_idx]
+             interval_right = breakpoints[tree_idx + 1]
+
+             # Filter mutations by genomic position (using pre-extracted mutation_positions)
+             mask = (mutation_positions >= interval_left) & (mutation_positions < interval_right)
+             mut_indices = np.where(mask)[0]
+
+             n_muts = len(mut_indices)
+             if n_muts > 0:
+                 # Extract mutation data (simplified: only x, y, tree_idx)
+                 mut_node_ids = mutations.node[mut_indices].astype(np.int32)
+                 mut_times = mutations.time[mut_indices]
+                 mut_parent_ids = graph.parent[mut_node_ids].astype(np.int32)
+
+                 # Reuse graph.x for layout position (aligned with node horizontally)
+                 mut_layout = graph.x[mut_node_ids].astype(np.float32)
+
+                 # Time normalization (same formula as nodes)
+                 time_range = max_time - min_time if max_time > min_time else 1.0
+
+                 # Compute normalized time for valid times
+                 mut_time_norm = (max_time - mut_times) / time_range
+
+                 # Handle NaN times: use midpoint between node and parent in normalized y space
+                 nan_mask = np.isnan(mut_times)
+                 if np.any(nan_mask):
+                     node_y = graph.y[mut_node_ids[nan_mask]]
+                     parent_ids_for_nan = mut_parent_ids[nan_mask]
+                     # For roots (parent=-1), use 0.0 (corresponds to max_time)
+                     parent_y = np.where(
+                         parent_ids_for_nan >= 0,
+                         graph.y[np.maximum(parent_ids_for_nan, 0)],
+                         0.0
+                     )
+                     mut_time_norm[nan_mask] = (node_y + parent_y) / 2.0
+
+                 mut_time_norm = mut_time_norm.astype(np.float32)
+
+                 # SWAP for backend convention (same as nodes)
+                 mut_x = mut_time_norm  # time -> x
+                 mut_y = mut_layout  # layout -> y
+
+                 # Ensure mutation buffer capacity
+                 while mut_offset + n_muts > len(all_mut_tree_idx):
+                     new_size = len(all_mut_tree_idx) * 2
+                     all_mut_tree_idx.resize(new_size, refcheck=False)
+                     all_mut_x.resize(new_size, refcheck=False)
+                     all_mut_y.resize(new_size, refcheck=False)
+                     all_mut_node_id.resize(new_size, refcheck=False)
+
+                 # Copy mutation data (only essential fields)
+                 all_mut_tree_idx[mut_offset:mut_offset+n_muts] = tree_idx
+                 all_mut_x[mut_offset:mut_offset+n_muts] = mut_x
+                 all_mut_y[mut_offset:mut_offset+n_muts] = mut_y
+                 all_mut_node_id[mut_offset:mut_offset+n_muts] = mut_node_ids
+
+                 mut_offset += n_muts
+
+         processed_indices.append(tree_idx)
+
+     # Trim node arrays to actual size
+     all_node_ids = all_node_ids[:offset]
+     all_parent_ids = all_parent_ids[:offset]
+     all_is_tip = all_is_tip[:offset]
+     all_tree_idx = all_tree_idx[:offset]
+     all_x = all_x[:offset]
+     all_y = all_y[:offset]
+
+     # Trim mutation arrays to actual size
+     if has_mutations and mut_offset > 0:
+         all_mut_tree_idx = all_mut_tree_idx[:mut_offset]
+         all_mut_x = all_mut_x[:mut_offset]
+         all_mut_y = all_mut_y[:mut_offset]
+         all_mut_node_id = all_mut_node_id[:mut_offset]
+
+     # Build separate node table
+     if offset == 0:
+         node_table = pa.table({
+             'node_id': pa.array([], type=pa.int32()),
+             'parent_id': pa.array([], type=pa.int32()),
+             'is_tip': pa.array([], type=pa.bool_()),
+             'tree_idx': pa.array([], type=pa.int32()),
+             'x': pa.array([], type=pa.float32()),
+             'y': pa.array([], type=pa.float32()),
+         })
+     else:
+         node_table = pa.table({
+             'node_id': pa.array(all_node_ids, type=pa.int32()),
+             'parent_id': pa.array(all_parent_ids, type=pa.int32()),
+             'is_tip': pa.array(all_is_tip, type=pa.bool_()),
+             'tree_idx': pa.array(all_tree_idx, type=pa.int32()),
+             'x': pa.array(all_x, type=pa.float32()),
+             'y': pa.array(all_y, type=pa.float32()),
+         })
+
+     # Build separate mutation table (simplified: only x, y, tree_idx, node_id)
+     if has_mutations and mut_offset > 0:
+         mut_table = pa.table({
+             'mut_x': pa.array(all_mut_x, type=pa.float32()),
+             'mut_y': pa.array(all_mut_y, type=pa.float32()),
+             'mut_tree_idx': pa.array(all_mut_tree_idx, type=pa.int32()),
+             'mut_node_id': pa.array(all_mut_node_id, type=pa.int32()),
+         })
+     else:
+         mut_table = pa.table({
+             'mut_x': pa.array([], type=pa.float32()),
+             'mut_y': pa.array([], type=pa.float32()),
+             'mut_tree_idx': pa.array([], type=pa.int32()),
+             'mut_node_id': pa.array([], type=pa.int32()),
+         })
+
+     # Serialize node table to IPC
+     node_sink = pa.BufferOutputStream()
+     node_writer = pa.ipc.new_stream(node_sink, node_table.schema)
+     node_writer.write_table(node_table)
+     node_writer.close()
+     node_bytes = node_sink.getvalue().to_pybytes()
+
+     # Serialize mutation table to IPC
+     mut_sink = pa.BufferOutputStream()
+     mut_writer = pa.ipc.new_stream(mut_sink, mut_table.schema)
+     mut_writer.write_table(mut_table)
+     mut_writer.close()
+     mut_bytes = mut_sink.getvalue().to_pybytes()
+
+     # Combine with 4-byte length prefix for node buffer
+     # Format: [4-byte node_len (little-endian)][node_bytes][mut_bytes]
+     combined = struct.pack('<I', len(node_bytes)) + node_bytes + mut_bytes
+
+     return combined, min_time, max_time, processed_indices, newly_built_graphs
+
+
+ @njit(cache=True)
+ def _sparsify_tips_only(node_ids, x, y, is_tip, parent_ids, resolution):
+     """
+     Optimized sparsification: only grid-dedupe tips, then trace ancestors.
+
+     Algorithm:
+     1. Grid-dedupe ONLY tip nodes (skip internal nodes in grid phase)
+     2. Trace path to root for each kept tip
+     3. Keep all ancestors along the path
+
+     This is faster because tips are ~90% of nodes and we skip
+     the grid computation for internal nodes. Some internal nodes
+     may appear "dangling" (no children visible) - this is correct
+     behavior showing the true ancestry path.
+
+     Args:
+         node_ids: int32 array of node IDs
+         x: float32 array of x coordinates (normalized [0,1])
+         y: float32 array of y coordinates (normalized [0,1])
+         is_tip: bool array indicating tip nodes
+         parent_ids: int32 array of parent IDs (-1 for roots)
+         resolution: Grid resolution (e.g., 500 for cell_size=0.002)
+
+     Returns:
+         keep: bool array indicating which nodes to keep
+     """
+     n = len(node_ids)
+     keep = np.zeros(n, dtype=np.bool_)
+
+     if n == 0:
+         return keep
+
+     # Build node_to_idx mapping first (needed for ancestor tracing)
+     max_node_id = 0
+     for i in range(n):
+         if node_ids[i] > max_node_id:
+             max_node_id = node_ids[i]
+
+     node_to_idx = np.full(max_node_id + 1, -1, dtype=np.int32)
+     for i in range(n):
+         node_to_idx[node_ids[i]] = i
+
+     # --- Phase 1: Grid-dedupe ONLY tips ---
+     seen_cells = Dict.empty(key_type=types.int64, value_type=types.int32)
+
+     for i in range(n):
+         if not is_tip[i]:
+             continue  # Skip internal nodes in grid phase
+
+         # Compute grid cell for this tip
+         cx = min(int(x[i] * resolution), resolution - 1)
+         cy = min(int(y[i] * resolution), resolution - 1)
+         key = cx * (resolution + 1) + cy
+
+         if key not in seen_cells:
+             seen_cells[key] = i
+             keep[i] = True
+
+     # --- Phase 2: Trace ancestors for each kept tip ---
+     for i in range(n):
+         if keep[i] and is_tip[i]:  # Only trace from kept tips
+             parent = parent_ids[i]
+             while parent != -1:
+                 if parent > max_node_id:
+                     break
+                 parent_idx = node_to_idx[parent]
+                 if parent_idx < 0:
+                     break
+                 if keep[parent_idx]:
+                     break  # Already kept, path is connected
+                 keep[parent_idx] = True
+                 parent = parent_ids[parent_idx]
+
+     return keep
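
For orientation, a minimal consumer-side sketch of construct_trees_batch and the combined buffer it returns, following the layout documented in the module (a 4-byte little-endian node-table length, then the node-table IPC stream, then the mutation-table IPC stream). The example.trees path is a placeholder, and the sketch assumes tskit and pyarrow are installed alongside this package.

import struct

import pyarrow as pa
import tskit

from lorax.tree_graph.tree_graph import construct_trees_batch

# Hypothetical input file; any tskit tree sequence would do here.
ts = tskit.load("example.trees")

buf, min_time, max_time, processed, new_graphs = construct_trees_batch(
    ts, tree_indices=[0, 1, 2], sparsification=False, include_mutations=True
)

# Split the combined buffer: [4-byte node_len (little-endian)][node IPC][mutation IPC]
node_len = struct.unpack_from("<I", buf, 0)[0]
node_table = pa.ipc.open_stream(buf[4:4 + node_len]).read_all()
mut_table = pa.ipc.open_stream(buf[4 + node_len:]).read_all()

print(node_table.column_names)   # node_id, parent_id, is_tip, tree_idx, x, y
print(mut_table.num_rows, "mutations in trees", processed)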
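
A similarly hedged sketch of the single-tree path: construct_tree with the pre-extracted tables, followed by the TreeGraph accessors. Again, example.trees is a placeholder and tree index 0 is arbitrary.

import numpy as np
import tskit

from lorax.tree_graph.tree_graph import TreeGraph, construct_tree

ts = tskit.load("example.trees")      # hypothetical path
edges = ts.tables.edges               # pre-extracted once, as construct_trees_batch does
nodes = ts.tables.nodes
breakpoints = list(ts.breakpoints())

graph: TreeGraph = construct_tree(ts, edges, nodes, breakpoints, index=0)

# Roots are in-tree nodes with no parent; children() is a zero-copy CSR slice.
roots = np.where(graph.in_tree & (graph.parent == -1))[0]
for root in roots:
    print(root, graph.children(root), "tip" if graph.is_tip(root) else "internal")

# Serialize one tree for the frontend (x/y are swapped inside to_pyarrow).
ipc_bytes = graph.to_pyarrow(tree_idx=0)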