lorax_arg-0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lorax/buffer.py +43 -0
- lorax/cache/__init__.py +43 -0
- lorax/cache/csv_tree_graph.py +59 -0
- lorax/cache/disk.py +467 -0
- lorax/cache/file_cache.py +142 -0
- lorax/cache/file_context.py +72 -0
- lorax/cache/lru.py +90 -0
- lorax/cache/tree_graph.py +293 -0
- lorax/cli.py +312 -0
- lorax/cloud/__init__.py +0 -0
- lorax/cloud/gcs_utils.py +205 -0
- lorax/constants.py +66 -0
- lorax/context.py +80 -0
- lorax/csv/__init__.py +7 -0
- lorax/csv/config.py +250 -0
- lorax/csv/layout.py +182 -0
- lorax/csv/newick_tree.py +234 -0
- lorax/handlers.py +998 -0
- lorax/lineage.py +456 -0
- lorax/loaders/__init__.py +0 -0
- lorax/loaders/csv_loader.py +10 -0
- lorax/loaders/loader.py +31 -0
- lorax/loaders/tskit_loader.py +119 -0
- lorax/lorax_app.py +75 -0
- lorax/manager.py +58 -0
- lorax/metadata/__init__.py +0 -0
- lorax/metadata/loader.py +426 -0
- lorax/metadata/mutations.py +146 -0
- lorax/modes.py +190 -0
- lorax/pg.py +183 -0
- lorax/redis_utils.py +30 -0
- lorax/routes.py +137 -0
- lorax/session_manager.py +206 -0
- lorax/sockets/__init__.py +55 -0
- lorax/sockets/connection.py +99 -0
- lorax/sockets/debug.py +47 -0
- lorax/sockets/decorators.py +112 -0
- lorax/sockets/file_ops.py +200 -0
- lorax/sockets/lineage.py +307 -0
- lorax/sockets/metadata.py +232 -0
- lorax/sockets/mutations.py +154 -0
- lorax/sockets/node_search.py +535 -0
- lorax/sockets/tree_layout.py +117 -0
- lorax/sockets/utils.py +10 -0
- lorax/tree_graph/__init__.py +12 -0
- lorax/tree_graph/tree_graph.py +689 -0
- lorax/utils.py +124 -0
- lorax_app/__init__.py +4 -0
- lorax_app/app.py +159 -0
- lorax_app/cli.py +114 -0
- lorax_app/static/X.png +0 -0
- lorax_app/static/assets/index-BCEGlUFi.js +2361 -0
- lorax_app/static/assets/index-iKjzUpA9.css +1 -0
- lorax_app/static/assets/localBackendWorker-BaWwjSV_.js +2 -0
- lorax_app/static/assets/renderDataWorker-BKLdiU7J.js +2 -0
- lorax_app/static/gestures/gesture-flick.ogv +0 -0
- lorax_app/static/gestures/gesture-two-finger-scroll.ogv +0 -0
- lorax_app/static/index.html +14 -0
- lorax_app/static/logo.png +0 -0
- lorax_app/static/lorax-logo.png +0 -0
- lorax_app/static/vite.svg +1 -0
- lorax_arg-0.1.dist-info/METADATA +131 -0
- lorax_arg-0.1.dist-info/RECORD +66 -0
- lorax_arg-0.1.dist-info/WHEEL +5 -0
- lorax_arg-0.1.dist-info/entry_points.txt +4 -0
- lorax_arg-0.1.dist-info/top_level.txt +2 -0
lorax/tree_graph/tree_graph.py
@@ -0,0 +1,689 @@
"""
tree_graph.py - Numba-optimized tree construction from tskit tables.

This module provides:
- TreeGraph: Numpy-based tree representation with CSR children and x,y coordinates
- construct_tree: Build a single tree from tables (Numba-optimized)
- construct_trees_batch: Build multiple trees efficiently (includes mutation extraction)
"""

import struct
import numpy as np
import pyarrow as pa
from numba import njit
from numba.typed import Dict
from numba import types
from dataclasses import dataclass
from typing import List, Optional, Tuple

# Default cell size for sparsification (0.2% of normalized [0,1] space)
DEFAULT_SPARSIFY_CELL_SIZE = 0.002


@njit(cache=True)
def _compute_x_postorder(children_indptr, children_data, roots, num_nodes):
    """
    Numba-compiled post-order traversal for computing x (layout) coordinates.

    Tips get sequential x values (0, 1, 2, ...), internal nodes get (min + max) / 2 of children.

    Args:
        children_indptr: CSR indptr array
        children_data: CSR data array (flattened children)
        roots: Array of root node IDs
        num_nodes: Total number of nodes

    Returns:
        (x, tip_counter): x coordinates array and number of tips
    """
    x = np.full(num_nodes, -1.0, dtype=np.float32)
    tip_counter = 0

    # Pre-allocated stack arrays (avoid Python list)
    stack_nodes = np.empty(num_nodes, dtype=np.int32)
    stack_visited = np.empty(num_nodes, dtype=np.uint8)  # 0=False, 1=True

    for i in range(len(roots)):
        root = roots[i]
        stack_ptr = 0

        stack_nodes[stack_ptr] = root
        stack_visited[stack_ptr] = 0
        stack_ptr += 1

        while stack_ptr > 0:
            stack_ptr -= 1
            node = stack_nodes[stack_ptr]
            visited = stack_visited[stack_ptr]

            start = children_indptr[node]
            end = children_indptr[node + 1]
            num_children = end - start

            if visited == 0 and num_children > 0:
                # Push node back as visited
                stack_nodes[stack_ptr] = node
                stack_visited[stack_ptr] = 1
                stack_ptr += 1

                # Push children
                for j in range(start, end):
                    stack_nodes[stack_ptr] = children_data[j]
                    stack_visited[stack_ptr] = 0
                    stack_ptr += 1
            else:
                # Post-order processing
                if num_children == 0:
                    x[node] = tip_counter
                    tip_counter += 1
                else:
                    # Compute (min + max) / 2 of children (matches jstree.js)
                    min_x = x[children_data[start]]
                    max_x = x[children_data[start]]
                    for j in range(start + 1, end):
                        child_x = x[children_data[j]]
                        if child_x < min_x:
                            min_x = child_x
                        if child_x > max_x:
                            max_x = child_x
                    x[node] = (min_x + max_x) / 2.0

    return x, tip_counter


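# Illustrative sketch (not part of the published module): a minimal worked
# example of the CSR children layout consumed by _compute_x_postorder(). For a
# three-node tree whose root is node 2 with tip children 0 and 1,
# children_indptr = [0, 0, 0, 2] and children_data = [0, 1]; the traversal
# numbers the tips in pop order and places the root midway between them.
def _sketch_compute_x_postorder():
    indptr = np.array([0, 0, 0, 2], dtype=np.int32)
    data = np.array([0, 1], dtype=np.int32)
    roots = np.array([2], dtype=np.int32)
    x, tips = _compute_x_postorder(indptr, data, roots, 3)
    assert tips == 2 and x[2] == 0.5  # root at (min + max) / 2 of its children
    return x

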
@dataclass
class TreeGraph:
    """
    Graph representation using numpy arrays with CSR format for children.

    Attributes:
        parent: int32 array where parent[node_id] = parent_id (-1 for root)
        time: float32 array of raw node times
        children_indptr: int32 CSR row pointers (length = num_nodes + 1)
        children_data: int32 flattened children array
        x: float32 layout position [0,1] (tips spread, internal=(min+max)/2 of children)
        y: float32 normalized time [0,1] (min_time=0, max_time=1)
        in_tree: bool array indicating which nodes are in this tree
    """
    parent: np.ndarray
    time: np.ndarray
    children_indptr: np.ndarray
    children_data: np.ndarray
    x: np.ndarray
    y: np.ndarray
    in_tree: np.ndarray

    def children(self, node_id: int) -> np.ndarray:
        """Get children of a node as numpy array slice (zero-copy)."""
        return self.children_data[self.children_indptr[node_id]:self.children_indptr[node_id + 1]]

    def is_tip(self, node_id: int) -> bool:
        """Check if a node is a tip (no children)."""
        return self.children_indptr[node_id + 1] == self.children_indptr[node_id]

    def get_node_x(self, node_id: int) -> float:
        """Get the x (layout) coordinate for a node."""
        return self.x[node_id] if node_id >= 0 and node_id < len(self.x) else 0.5

    def to_pyarrow(self, tree_idx: int = 0) -> bytes:
        """
        Serialize TreeGraph to PyArrow IPC format for frontend rendering.

        Note: Coordinates are swapped to match backend convention:
        - Backend x = self.y (time)
        - Backend y = self.x (layout)

        Args:
            tree_idx: Tree index to include in output (for multi-tree rendering)

        Returns:
            bytes: PyArrow IPC binary data ready to send to frontend
        """
        # Get nodes that are in this tree
        indices = np.where(self.in_tree)[0].astype(np.int32)
        n = len(indices)

        if n == 0:
            # Return empty table
            table = pa.table({
                'node_id': pa.array([], type=pa.int32()),
                'parent_id': pa.array([], type=pa.int32()),
                'is_tip': pa.array([], type=pa.bool_()),
                'tree_idx': pa.array([], type=pa.int32()),
                'x': pa.array([], type=pa.float32()),
                'y': pa.array([], type=pa.float32()),
            })
        else:
            # Derive is_tip from CSR: nodes with no children
            child_counts = np.diff(self.children_indptr)
            is_tip = child_counts[indices] == 0

            # Build PyArrow table (swap x<->y for backend convention)
            table = pa.table({
                'node_id': pa.array(indices, type=pa.int32()),
                'parent_id': pa.array(self.parent[indices], type=pa.int32()),
                'is_tip': pa.array(is_tip, type=pa.bool_()),
                'tree_idx': pa.array(np.full(n, tree_idx, dtype=np.int32), type=pa.int32()),
                'x': pa.array(self.y[indices], type=pa.float32()),  # SWAP: time -> x
                'y': pa.array(self.x[indices], type=pa.float32()),  # SWAP: layout -> y
            })

        # Serialize to IPC format
        sink = pa.BufferOutputStream()
        writer = pa.ipc.new_stream(sink, table.schema)
        writer.write_table(table)
        writer.close()

        return sink.getvalue().to_pybytes()


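# Illustrative sketch (not part of the published module): the bytes returned by
# TreeGraph.to_pyarrow() are a single Arrow IPC stream, so a consumer can
# recover the table with pyarrow's stream reader.
def _sketch_read_node_table(ipc_bytes: bytes):
    # `ipc_bytes` is assumed to be the value returned by TreeGraph.to_pyarrow().
    return pa.ipc.open_stream(ipc_bytes).read_all()

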
def construct_tree(
    ts,
    edges,
    nodes,
    breakpoints,
    index: int,
    min_time: Optional[float] = None,
    max_time: Optional[float] = None
) -> TreeGraph:
    """
    Construct tree with x,y coordinates using Numba-optimized post-order traversal.

    Args:
        ts: tskit TreeSequence object
        edges: ts.tables.edges (pre-extracted for reuse)
        nodes: ts.tables.nodes (pre-extracted for reuse)
        breakpoints: list/array of breakpoints (pre-extracted for reuse)
        index: Tree index
        min_time: Optional global min time (default: ts.min_time)
        max_time: Optional global max time (default: ts.max_time)

    Returns:
        TreeGraph with CSR children and x,y coordinates in [0,1].
    """
    if index < 0 or index >= ts.num_trees:
        raise ValueError(f"Tree index {index} out of range [0, {ts.num_trees - 1}]")

    interval_left = breakpoints[index]
    num_nodes = len(nodes.time)
    node_times = nodes.time

    # Use provided min/max or compute from ts
    if min_time is None:
        min_time = ts.min_time
    if max_time is None:
        max_time = ts.max_time

    # === Edge filtering & parent array ===
    active_mask = (edges.left <= interval_left) & (edges.right > interval_left)
    active_parents = edges.parent[active_mask]
    active_children = edges.child[active_mask]

    parent = np.full(num_nodes, -1, dtype=np.int32)
    parent[active_children] = active_parents

    # === CSR children structure ===
    child_counts = np.bincount(active_parents, minlength=num_nodes).astype(np.int32)
    children_indptr = np.zeros(num_nodes + 1, dtype=np.int32)
    children_indptr[1:] = np.cumsum(child_counts)
    sort_idx = np.argsort(active_parents, kind='stable')
    children_data = active_children[sort_idx].astype(np.int32)

    # === Track which nodes are in this tree ===
    in_tree = np.zeros(num_nodes, dtype=np.bool_)
    in_tree[active_children] = True
    in_tree[active_parents] = True

    # === Y coordinate: normalized time (vectorized) ===
    # Inverted: max_time → 0, min_time → 1 (root at left, tips at right)
    time_range = max_time - min_time if max_time > min_time else 1.0
    y = ((max_time - node_times) / time_range).astype(np.float32)

    # === X coordinate: Numba-optimized post-order traversal ===
    roots = np.where(in_tree & (parent == -1))[0].astype(np.int32)
    x, tip_counter = _compute_x_postorder(children_indptr, children_data, roots, num_nodes)

    # Normalize x to [0, 1]
    if tip_counter > 1:
        x[in_tree] /= (tip_counter - 1)

    return TreeGraph(
        parent=parent,
        time=node_times.astype(np.float32),
        children_indptr=children_indptr,
        children_data=children_data,
        x=x,
        y=y,
        in_tree=in_tree
    )


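# Illustrative sketch (not part of the published module): building one tree and
# serializing it, pre-extracting the tables once as the docstring above
# recommends. The input file name is hypothetical.
def _sketch_single_tree():
    import tskit  # assumed to be installed alongside this package
    ts = tskit.load("example.trees")  # hypothetical .trees file
    graph = construct_tree(ts, ts.tables.edges, ts.tables.nodes,
                           list(ts.breakpoints()), index=0)
    return graph.to_pyarrow(tree_idx=0)  # Arrow IPC bytes for the frontend

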
def construct_trees_batch(
    ts,
    tree_indices: List[int],
    sparsification: bool = False,
    include_mutations: bool = True,
    pre_cached_graphs: Optional[dict] = None
) -> Tuple[bytes, float, float, List[int], dict]:
    """
    Construct multiple trees and return combined PyArrow buffer.

    This is the main entry point for the backend handler.

    Args:
        ts: tskit TreeSequence object
        tree_indices: List of tree indices to process
        sparsification: Enable tip-only sparsification (default False)
        include_mutations: Whether to include mutation data in buffer
        pre_cached_graphs: Optional dict mapping tree_idx -> TreeGraph for cache hits

    Returns:
        Tuple of (buffer, global_min_time, global_max_time, tree_indices, newly_built_graphs)
        where newly_built_graphs is a dict mapping tree_idx -> TreeGraph for trees constructed
    """
    # Pre-extract tables for reuse
    edges = ts.tables.edges
    nodes = ts.tables.nodes
    breakpoints = list(ts.breakpoints())

    min_time = float(ts.min_time)
    max_time = float(ts.max_time)

    # Check if tree sequence has mutations
    has_mutations = include_mutations and ts.num_mutations > 0

    # Pre-extract mutation tables and positions (avoid repeated lookups in loop)
    if has_mutations:
        sites = ts.tables.sites
        mutations = ts.tables.mutations
        mutation_positions = sites.position[mutations.site]
    else:
        sites = None
        mutations = None
        mutation_positions = None

    if len(tree_indices) == 0:
        # Return empty buffer with separate node and mutation tables
        node_table = pa.table({
            'node_id': pa.array([], type=pa.int32()),
            'parent_id': pa.array([], type=pa.int32()),
            'is_tip': pa.array([], type=pa.bool_()),
            'tree_idx': pa.array([], type=pa.int32()),
            'x': pa.array([], type=pa.float32()),
            'y': pa.array([], type=pa.float32()),
        })
        mut_table = pa.table({
            'mut_x': pa.array([], type=pa.float32()),
            'mut_y': pa.array([], type=pa.float32()),
            'mut_tree_idx': pa.array([], type=pa.int32()),
        })

        node_sink = pa.BufferOutputStream()
        node_writer = pa.ipc.new_stream(node_sink, node_table.schema)
        node_writer.write_table(node_table)
        node_writer.close()
        node_bytes = node_sink.getvalue().to_pybytes()

        mut_sink = pa.BufferOutputStream()
        mut_writer = pa.ipc.new_stream(mut_sink, mut_table.schema)
        mut_writer.write_table(mut_table)
        mut_writer.close()
        mut_bytes = mut_sink.getvalue().to_pybytes()

        combined = struct.pack('<I', len(node_bytes)) + node_bytes + mut_bytes
        return combined, min_time, max_time, [], {}

    # Estimate total nodes for pre-allocation
    sample_tree = ts.at_index(int(tree_indices[0]) if tree_indices else 0)
    estimated_nodes_per_tree = sample_tree.num_nodes
    total_estimated = estimated_nodes_per_tree * len(tree_indices) * 2

    # Pre-allocate node arrays
    all_node_ids = np.empty(total_estimated, dtype=np.int32)
    all_parent_ids = np.empty(total_estimated, dtype=np.int32)
    all_is_tip = np.empty(total_estimated, dtype=np.bool_)
    all_tree_idx = np.empty(total_estimated, dtype=np.int32)
    all_x = np.empty(total_estimated, dtype=np.float32)
    all_y = np.empty(total_estimated, dtype=np.float32)

    # Pre-allocate mutation arrays (estimate based on mutation density)
    # Simplified: only x, y, tree_idx needed
    estimated_mutations = max(1000, ts.num_mutations // max(1, ts.num_trees) * len(tree_indices) * 2)
    all_mut_tree_idx = np.empty(estimated_mutations, dtype=np.int32) if has_mutations else None
    all_mut_x = np.empty(estimated_mutations, dtype=np.float32) if has_mutations else None
    all_mut_y = np.empty(estimated_mutations, dtype=np.float32) if has_mutations else None
    all_mut_node_id = np.empty(estimated_mutations, dtype=np.int32) if has_mutations else None

    offset = 0
    mut_offset = 0
    processed_indices = []

    # Initialize cache tracking
    pre_cached_graphs = pre_cached_graphs or {}
    newly_built_graphs = {}

    for tree_idx in tree_indices:
        tree_idx = int(tree_idx)

        if tree_idx < 0 or tree_idx >= ts.num_trees:
            continue

        # Check pre-cached first, then construct if needed
        if tree_idx in pre_cached_graphs:
            graph = pre_cached_graphs[tree_idx]
        else:
            graph = construct_tree(ts, edges, nodes, breakpoints, tree_idx, min_time, max_time)
            newly_built_graphs[tree_idx] = graph  # Track for caching

        # Get nodes in tree
        indices = np.where(graph.in_tree)[0].astype(np.int32)
        n = len(indices)

        if n == 0:
            continue

        # Derive is_tip
        child_counts = np.diff(graph.children_indptr)
        is_tip = child_counts[indices] == 0

        # Get coordinates (swap for backend convention)
        node_ids = indices
        parent_ids = graph.parent[indices]
        x = graph.y[indices]  # SWAP: time -> x
        y = graph.x[indices]  # SWAP: layout -> y

        # Apply sparsification if requested
        if sparsification:
            resolution = int(1.0 / DEFAULT_SPARSIFY_CELL_SIZE)
            keep_mask = _sparsify_tips_only(
                node_ids.astype(np.int32),
                x.astype(np.float32),
                y.astype(np.float32),
                is_tip,
                parent_ids.astype(np.int32),
                resolution
            )
            node_ids = node_ids[keep_mask]
            parent_ids = parent_ids[keep_mask]
            x = x[keep_mask]
            y = y[keep_mask]
            is_tip = is_tip[keep_mask]
            n = len(node_ids)

            if n > 0:
                # Collapse unary internal nodes (keep roots even if unary).
                order = np.argsort(node_ids)
                sorted_ids = node_ids[order]
                pos = np.searchsorted(sorted_ids, parent_ids)
                parent_local = np.full(n, -1, dtype=np.int32)
                valid = (parent_ids != -1) & (pos < n) & (sorted_ids[pos] == parent_ids)
                parent_local[valid] = order[pos[valid]]
                child_counts = np.bincount(parent_local[parent_local >= 0], minlength=n)
                collapse_mask = child_counts != 1

                if not np.all(collapse_mask):
                    new_parent_ids = parent_ids.copy()
                    for i in range(n):
                        if not collapse_mask[i]:
                            continue
                        parent = new_parent_ids[i]
                        while parent != -1:
                            pos = np.searchsorted(sorted_ids, parent)
                            if pos >= n or sorted_ids[pos] != parent:
                                parent = -1
                                break
                            parent_idx = order[pos]
                            if collapse_mask[parent_idx]:
                                break
                            parent = new_parent_ids[parent_idx]
                        new_parent_ids[i] = parent

                    node_ids = node_ids[collapse_mask]
                    parent_ids = new_parent_ids[collapse_mask]
                    x = x[collapse_mask]
                    y = y[collapse_mask]
                    n = len(node_ids)

                    if n > 0:
                        order = np.argsort(node_ids)
                        sorted_ids = node_ids[order]
                        pos = np.searchsorted(sorted_ids, parent_ids)
                        parent_local = np.full(n, -1, dtype=np.int32)
                        valid = (parent_ids != -1) & (pos < n) & (sorted_ids[pos] == parent_ids)
                        parent_local[valid] = order[pos[valid]]
                        child_counts = np.bincount(parent_local[parent_local >= 0], minlength=n)
                        is_tip = child_counts == 0
                    else:
                        is_tip = is_tip[:0]

            if n == 0:
                continue

        # Ensure capacity
        while offset + n > len(all_node_ids):
            new_size = len(all_node_ids) * 2
            all_node_ids.resize(new_size, refcheck=False)
            all_parent_ids.resize(new_size, refcheck=False)
            all_is_tip.resize(new_size, refcheck=False)
            all_tree_idx.resize(new_size, refcheck=False)
            all_x.resize(new_size, refcheck=False)
            all_y.resize(new_size, refcheck=False)

        # Copy node data
        all_node_ids[offset:offset+n] = node_ids
        all_parent_ids[offset:offset+n] = parent_ids
        all_is_tip[offset:offset+n] = is_tip
        all_tree_idx[offset:offset+n] = tree_idx
        all_x[offset:offset+n] = x
        all_y[offset:offset+n] = y

        offset += n

        # Collect mutations for this tree interval (inline computation)
        if has_mutations:
            interval_left = breakpoints[tree_idx]
            interval_right = breakpoints[tree_idx + 1]

            # Filter mutations by genomic position (using pre-extracted mutation_positions)
            mask = (mutation_positions >= interval_left) & (mutation_positions < interval_right)
            mut_indices = np.where(mask)[0]

            n_muts = len(mut_indices)
            if n_muts > 0:
                # Extract mutation data (simplified: only x, y, tree_idx)
                mut_node_ids = mutations.node[mut_indices].astype(np.int32)
                mut_times = mutations.time[mut_indices]
                mut_parent_ids = graph.parent[mut_node_ids].astype(np.int32)

                # Reuse graph.x for layout position (aligned with node horizontally)
                mut_layout = graph.x[mut_node_ids].astype(np.float32)

                # Time normalization (same formula as nodes)
                time_range = max_time - min_time if max_time > min_time else 1.0

                # Compute normalized time for valid times
                mut_time_norm = (max_time - mut_times) / time_range

                # Handle NaN times: use midpoint between node and parent in normalized y space
                nan_mask = np.isnan(mut_times)
                if np.any(nan_mask):
                    node_y = graph.y[mut_node_ids[nan_mask]]
                    parent_ids_for_nan = mut_parent_ids[nan_mask]
                    # For roots (parent=-1), use 0.0 (corresponds to max_time)
                    parent_y = np.where(
                        parent_ids_for_nan >= 0,
                        graph.y[np.maximum(parent_ids_for_nan, 0)],
                        0.0
                    )
                    mut_time_norm[nan_mask] = (node_y + parent_y) / 2.0

                mut_time_norm = mut_time_norm.astype(np.float32)

                # SWAP for backend convention (same as nodes)
                mut_x = mut_time_norm  # time -> x
                mut_y = mut_layout  # layout -> y

                # Ensure mutation buffer capacity
                while mut_offset + n_muts > len(all_mut_tree_idx):
                    new_size = len(all_mut_tree_idx) * 2
                    all_mut_tree_idx.resize(new_size, refcheck=False)
                    all_mut_x.resize(new_size, refcheck=False)
                    all_mut_y.resize(new_size, refcheck=False)
                    all_mut_node_id.resize(new_size, refcheck=False)

                # Copy mutation data (only essential fields)
                all_mut_tree_idx[mut_offset:mut_offset+n_muts] = tree_idx
                all_mut_x[mut_offset:mut_offset+n_muts] = mut_x
                all_mut_y[mut_offset:mut_offset+n_muts] = mut_y
                all_mut_node_id[mut_offset:mut_offset+n_muts] = mut_node_ids

                mut_offset += n_muts

        processed_indices.append(tree_idx)

    # Trim node arrays to actual size
    all_node_ids = all_node_ids[:offset]
    all_parent_ids = all_parent_ids[:offset]
    all_is_tip = all_is_tip[:offset]
    all_tree_idx = all_tree_idx[:offset]
    all_x = all_x[:offset]
    all_y = all_y[:offset]

    # Trim mutation arrays to actual size
    if has_mutations and mut_offset > 0:
        all_mut_tree_idx = all_mut_tree_idx[:mut_offset]
        all_mut_x = all_mut_x[:mut_offset]
        all_mut_y = all_mut_y[:mut_offset]
        all_mut_node_id = all_mut_node_id[:mut_offset]

    # Build separate node table
    if offset == 0:
        node_table = pa.table({
            'node_id': pa.array([], type=pa.int32()),
            'parent_id': pa.array([], type=pa.int32()),
            'is_tip': pa.array([], type=pa.bool_()),
            'tree_idx': pa.array([], type=pa.int32()),
            'x': pa.array([], type=pa.float32()),
            'y': pa.array([], type=pa.float32()),
        })
    else:
        node_table = pa.table({
            'node_id': pa.array(all_node_ids, type=pa.int32()),
            'parent_id': pa.array(all_parent_ids, type=pa.int32()),
            'is_tip': pa.array(all_is_tip, type=pa.bool_()),
            'tree_idx': pa.array(all_tree_idx, type=pa.int32()),
            'x': pa.array(all_x, type=pa.float32()),
            'y': pa.array(all_y, type=pa.float32()),
        })

    # Build separate mutation table (simplified: only x, y, tree_idx, node_id)
    if has_mutations and mut_offset > 0:
        mut_table = pa.table({
            'mut_x': pa.array(all_mut_x, type=pa.float32()),
            'mut_y': pa.array(all_mut_y, type=pa.float32()),
            'mut_tree_idx': pa.array(all_mut_tree_idx, type=pa.int32()),
            'mut_node_id': pa.array(all_mut_node_id, type=pa.int32()),
        })
    else:
        mut_table = pa.table({
            'mut_x': pa.array([], type=pa.float32()),
            'mut_y': pa.array([], type=pa.float32()),
            'mut_tree_idx': pa.array([], type=pa.int32()),
            'mut_node_id': pa.array([], type=pa.int32()),
        })

    # Serialize node table to IPC
    node_sink = pa.BufferOutputStream()
    node_writer = pa.ipc.new_stream(node_sink, node_table.schema)
    node_writer.write_table(node_table)
    node_writer.close()
    node_bytes = node_sink.getvalue().to_pybytes()

    # Serialize mutation table to IPC
    mut_sink = pa.BufferOutputStream()
    mut_writer = pa.ipc.new_stream(mut_sink, mut_table.schema)
    mut_writer.write_table(mut_table)
    mut_writer.close()
    mut_bytes = mut_sink.getvalue().to_pybytes()

    # Combine with 4-byte length prefix for node buffer
    # Format: [4-byte node_len (little-endian)][node_bytes][mut_bytes]
    combined = struct.pack('<I', len(node_bytes)) + node_bytes + mut_bytes

    return combined, min_time, max_time, processed_indices, newly_built_graphs


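# Illustrative sketch (not part of the published module): splitting the combined
# buffer produced by construct_trees_batch(). The first four bytes are a
# little-endian uint32 length prefix for the node IPC stream; the mutation IPC
# stream occupies the remainder.
def _sketch_split_combined_buffer(combined: bytes):
    node_len = struct.unpack('<I', combined[:4])[0]
    node_table = pa.ipc.open_stream(combined[4:4 + node_len]).read_all()
    mut_table = pa.ipc.open_stream(combined[4 + node_len:]).read_all()
    return node_table, mut_table

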
@njit(cache=True)
def _sparsify_tips_only(node_ids, x, y, is_tip, parent_ids, resolution):
    """
    Optimized sparsification: only grid-dedupe tips, then trace ancestors.

    Algorithm:
    1. Grid-dedupe ONLY tip nodes (skip internal nodes in grid phase)
    2. Trace path to root for each kept tip
    3. Keep all ancestors along the path

    This is faster because tips are ~90% of nodes and we skip
    the grid computation for internal nodes. Some internal nodes
    may appear "dangling" (no children visible) - this is correct
    behavior showing the true ancestry path.

    Args:
        node_ids: int32 array of node IDs
        x: float32 array of x coordinates (normalized [0,1])
        y: float32 array of y coordinates (normalized [0,1])
        is_tip: bool array indicating tip nodes
        parent_ids: int32 array of parent IDs (-1 for roots)
        resolution: Grid resolution (e.g., 500 for cell_size=0.002)

    Returns:
        keep: bool array indicating which nodes to keep
    """
    n = len(node_ids)
    keep = np.zeros(n, dtype=np.bool_)

    if n == 0:
        return keep

    # Build node_to_idx mapping first (needed for ancestor tracing)
    max_node_id = 0
    for i in range(n):
        if node_ids[i] > max_node_id:
            max_node_id = node_ids[i]

    node_to_idx = np.full(max_node_id + 1, -1, dtype=np.int32)
    for i in range(n):
        node_to_idx[node_ids[i]] = i

    # --- Phase 1: Grid-dedupe ONLY tips ---
    seen_cells = Dict.empty(key_type=types.int64, value_type=types.int32)

    for i in range(n):
        if not is_tip[i]:
            continue  # Skip internal nodes in grid phase

        # Compute grid cell for this tip
        cx = min(int(x[i] * resolution), resolution - 1)
        cy = min(int(y[i] * resolution), resolution - 1)
        key = cx * (resolution + 1) + cy

        if key not in seen_cells:
            seen_cells[key] = i
            keep[i] = True

    # --- Phase 2: Trace ancestors for each kept tip ---
    for i in range(n):
        if keep[i] and is_tip[i]:  # Only trace from kept tips
            parent = parent_ids[i]
            while parent != -1:
                if parent > max_node_id:
                    break
                parent_idx = node_to_idx[parent]
                if parent_idx < 0:
                    break
                if keep[parent_idx]:
                    break  # Already kept, path is connected
                keep[parent_idx] = True
                parent = parent_ids[parent_idx]

    return keep
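

# Illustrative sketch (not part of the published module): an end-to-end call
# that requests the first three trees with sparsification enabled and keeps the
# newly built graphs for reuse on a later call. The input file name is
# hypothetical.
def _sketch_batch_usage():
    import tskit  # assumed to be installed alongside this package
    ts = tskit.load("example.trees")  # hypothetical .trees file
    combined, t_min, t_max, done, built = construct_trees_batch(
        ts, tree_indices=[0, 1, 2], sparsification=True
    )
    # `built` maps tree_idx -> TreeGraph and can be passed back as
    # pre_cached_graphs on subsequent calls to skip reconstruction.
    return combined, t_min, t_max, done, built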