lorax-arg 0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. lorax/buffer.py +43 -0
  2. lorax/cache/__init__.py +43 -0
  3. lorax/cache/csv_tree_graph.py +59 -0
  4. lorax/cache/disk.py +467 -0
  5. lorax/cache/file_cache.py +142 -0
  6. lorax/cache/file_context.py +72 -0
  7. lorax/cache/lru.py +90 -0
  8. lorax/cache/tree_graph.py +293 -0
  9. lorax/cli.py +312 -0
  10. lorax/cloud/__init__.py +0 -0
  11. lorax/cloud/gcs_utils.py +205 -0
  12. lorax/constants.py +66 -0
  13. lorax/context.py +80 -0
  14. lorax/csv/__init__.py +7 -0
  15. lorax/csv/config.py +250 -0
  16. lorax/csv/layout.py +182 -0
  17. lorax/csv/newick_tree.py +234 -0
  18. lorax/handlers.py +998 -0
  19. lorax/lineage.py +456 -0
  20. lorax/loaders/__init__.py +0 -0
  21. lorax/loaders/csv_loader.py +10 -0
  22. lorax/loaders/loader.py +31 -0
  23. lorax/loaders/tskit_loader.py +119 -0
  24. lorax/lorax_app.py +75 -0
  25. lorax/manager.py +58 -0
  26. lorax/metadata/__init__.py +0 -0
  27. lorax/metadata/loader.py +426 -0
  28. lorax/metadata/mutations.py +146 -0
  29. lorax/modes.py +190 -0
  30. lorax/pg.py +183 -0
  31. lorax/redis_utils.py +30 -0
  32. lorax/routes.py +137 -0
  33. lorax/session_manager.py +206 -0
  34. lorax/sockets/__init__.py +55 -0
  35. lorax/sockets/connection.py +99 -0
  36. lorax/sockets/debug.py +47 -0
  37. lorax/sockets/decorators.py +112 -0
  38. lorax/sockets/file_ops.py +200 -0
  39. lorax/sockets/lineage.py +307 -0
  40. lorax/sockets/metadata.py +232 -0
  41. lorax/sockets/mutations.py +154 -0
  42. lorax/sockets/node_search.py +535 -0
  43. lorax/sockets/tree_layout.py +117 -0
  44. lorax/sockets/utils.py +10 -0
  45. lorax/tree_graph/__init__.py +12 -0
  46. lorax/tree_graph/tree_graph.py +689 -0
  47. lorax/utils.py +124 -0
  48. lorax_app/__init__.py +4 -0
  49. lorax_app/app.py +159 -0
  50. lorax_app/cli.py +114 -0
  51. lorax_app/static/X.png +0 -0
  52. lorax_app/static/assets/index-BCEGlUFi.js +2361 -0
  53. lorax_app/static/assets/index-iKjzUpA9.css +1 -0
  54. lorax_app/static/assets/localBackendWorker-BaWwjSV_.js +2 -0
  55. lorax_app/static/assets/renderDataWorker-BKLdiU7J.js +2 -0
  56. lorax_app/static/gestures/gesture-flick.ogv +0 -0
  57. lorax_app/static/gestures/gesture-two-finger-scroll.ogv +0 -0
  58. lorax_app/static/index.html +14 -0
  59. lorax_app/static/logo.png +0 -0
  60. lorax_app/static/lorax-logo.png +0 -0
  61. lorax_app/static/vite.svg +1 -0
  62. lorax_arg-0.1.dist-info/METADATA +131 -0
  63. lorax_arg-0.1.dist-info/RECORD +66 -0
  64. lorax_arg-0.1.dist-info/WHEEL +5 -0
  65. lorax_arg-0.1.dist-info/entry_points.txt +4 -0
  66. lorax_arg-0.1.dist-info/top_level.txt +2 -0
lorax/csv/config.py ADDED
@@ -0,0 +1,250 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any, Dict, List, Tuple
5
+
6
+ import pandas as pd
7
+
8
+ from lorax.utils import extract_sample_names, max_branch_length_from_newick
9
+
10
+
11
@dataclass(frozen=True)
class CsvConfigOptions:
    """Options controlling CSV config construction."""

    # Genomic span (bp) assigned to the final tree when computing the
    # trailing interval breakpoint (last position + window_size).
    window_size: int = 50_000
14
+
15
+
16
# Columns every Newick-per-row CSV must provide.
REQUIRED_COLUMNS = ("genomic_positions", "newick")
# Optional per-tree column; also marks the boundary before metadata columns.
MAX_BRANCH_COL = "max_branch_length"
18
+
19
+
20
+ def _dedupe_preserve_order(values: List[Any]) -> List[str]:
21
+ seen = set()
22
+ out: List[str] = []
23
+ for v in values:
24
+ s = str(v)
25
+ if s in seen:
26
+ continue
27
+ seen.add(s)
28
+ out.append(s)
29
+ return out
30
+
31
+
32
def _validate_csv_df(df: pd.DataFrame) -> None:
    """Raise ValueError when *df* lacks a required column or has no rows."""
    present = set(df.columns)
    missing = [col for col in REQUIRED_COLUMNS if col not in present]
    if missing:
        raise ValueError(
            f"CSV missing required column(s): {missing}. "
            f"Expected columns: {list(REQUIRED_COLUMNS)}"
        )
    # Required columns exist at this point, so empty <=> zero rows.
    if df.empty:
        raise ValueError("CSV contains no rows.")
42
+
43
+
44
+ def _sorted_reset(df: pd.DataFrame) -> pd.DataFrame:
45
+ # Deterministic ordering: global_index == row index after this.
46
+ out = df.copy()
47
+ out["genomic_positions"] = pd.to_numeric(out["genomic_positions"], errors="raise")
48
+ out = out.sort_values("genomic_positions", kind="mergesort")
49
+ out = out.reset_index(drop=True)
50
+ return out
51
+
52
+
53
+ def _compute_intervals(genomic_positions: List[int], window_size: int) -> List[int]:
54
+ # Frontend expects N+1 breakpoints for N trees (see localBackendWorker logic).
55
+ # Each tree i spans [intervals[i], intervals[i+1]).
56
+ if not genomic_positions:
57
+ return []
58
+
59
+ positions = [int(p) for p in genomic_positions]
60
+ intervals = positions[:] # N items
61
+ last_end = int(positions[-1]) + int(window_size)
62
+ intervals.append(last_end) # N+1
63
+ return intervals
64
+
65
+
66
+ def _is_empty_value(value: Any) -> bool:
67
+ if value is None:
68
+ return True
69
+ if isinstance(value, float) and pd.isna(value):
70
+ return True
71
+ if isinstance(value, str) and value.strip() == "":
72
+ return True
73
+ return False
74
+
75
+
76
def extract_csv_metadata(df: pd.DataFrame) -> Dict[str, List[str]]:
    """Extract file-level metadata from a CSV DataFrame.

    Columns after max_branch_length are treated as metadata columns.
    File-level metadata values come only from metadata-only rows (rows whose
    tree columns are all empty).
    """
    columns = list(df.columns)
    if MAX_BRANCH_COL not in columns:
        return {}

    boundary = columns.index(MAX_BRANCH_COL)
    tree_cols = columns[: boundary + 1]
    metadata_cols = columns[boundary + 1 :]
    if not metadata_cols:
        return {}

    collected: Dict[str, List[str]] = {col: [] for col in metadata_cols}

    for _, row in df.iterrows():
        # Rows carrying any tree data are not file-level metadata rows.
        if not all(_is_empty_value(row[col]) for col in tree_cols):
            continue
        for col in metadata_cols:
            cell = row[col]
            if not _is_empty_value(cell):
                collected[col].append(str(cell))

    return collected
104
+
105
+
106
def _find_tree_info_col(df: pd.DataFrame) -> str | None:
    """Return the per-tree info column name, accepting either spelling."""
    for cand in ("tree_info", "tree info"):
        if cand in df.columns:
            return cand
    return None


def _collect_tree_info(df: pd.DataFrame, col: str) -> Dict[str, str]:
    """Map tree index (stringified row index) to its tree-info value.

    Only rows that actually carry a Newick tree contribute; row index doubles
    as tree_idx (backend CSV layout accesses trees via df.iloc[tree_idx]).
    """
    info: Dict[str, str] = {}
    for i, row in df.iterrows():
        if _is_empty_value(row.get("newick")):
            continue
        value = row.get(col)
        if _is_empty_value(value):
            continue
        info[str(int(i))] = str(value)
    return info


def _iter_newicks(df: pd.DataFrame):
    """Yield (row, newick_str) for every row with a non-empty Newick cell."""
    for _, row in df.iterrows():
        nwk = row["newick"]
        # _is_empty_value covers None, NaN floats, and blank strings.
        if _is_empty_value(nwk):
            continue
        yield row, str(nwk)


def build_csv_config(
    df: pd.DataFrame,
    file_path: str,
    *,
    options: CsvConfigOptions | None = None,
) -> Dict[str, Any]:
    """Build a Lorax-compatible config for Newick-per-row CSV.

    CSV schema:
        - genomic_positions: int (tree start position)
        - newick: str (Newick tree)

    Contract notes:
        - `intervals` must be N+1 breakpoints (required by frontend binning logic).
        - `times` uses branch length for CSV: {type: "branch length", values: [0, max]}.

    Args:
        df: Input CSV as a DataFrame; validated for required columns.
        file_path: Path of the source file; only its basename is stored.
        options: Optional tuning (window size for the last interval).

    Returns:
        Config dict consumed by the frontend on load_file.

    Raises:
        ValueError: If required columns are missing or the CSV has no rows.
    """
    import os

    options = options or CsvConfigOptions()

    _validate_csv_df(df)
    file_metadata = extract_csv_metadata(df)
    df2 = _sorted_reset(df)

    # Optional per-tree metadata ("tree_info" / "tree info" column), intended
    # for frontend per-tree UI (e.g. default per-tree colors); sent eagerly
    # with config on load_file.
    tree_info_col = _find_tree_info_col(df2)
    tree_info_map = _collect_tree_info(df2, tree_info_col) if tree_info_col else {}

    # Compute max tree height (branch length) and sample names (best-effort,
    # lightweight). Prefer the CSV-provided per-tree `max_branch_length`
    # column when present, since regex parsing only captures the max *edge*
    # length, not the full root->tip height.
    max_branch_length_all = 0.0
    samples_set: set = set()
    has_max_col = MAX_BRANCH_COL in df2.columns
    saw_valid_max_col_value = False

    for row, nwk in _iter_newicks(df2):
        if has_max_col:
            v = row.get(MAX_BRANCH_COL)
            if not _is_empty_value(v):
                try:
                    max_br = float(v)
                except (TypeError, ValueError):
                    # Invalid per-row value; the regex fallback below may apply.
                    pass
                else:
                    saw_valid_max_col_value = True
                    max_branch_length_all = max(max_branch_length_all, max_br)

        try:
            samples_set.update(extract_sample_names(nwk))
        except Exception:
            # Sample extraction is best-effort; keep loading resilient.
            pass

    # Fallback: if the CSV doesn't provide usable per-tree max heights, derive
    # a best-effort global max from Newick text (max edge length).
    if not saw_valid_max_col_value:
        for _, nwk in _iter_newicks(df2):
            try:
                max_br = float(max_branch_length_from_newick(nwk))
                max_branch_length_all = max(max_branch_length_all, max_br)
            except Exception:
                # Keep config resilient; downstream can still load.
                pass

    genomic_positions = df2["genomic_positions"].astype(int).tolist()
    intervals = _compute_intervals(genomic_positions, options.window_size)
    genome_length = int(intervals[-1]) if intervals else int(genomic_positions[-1])

    # Centered initial position: 10% of genome, minimum 1 kb (as tskit loader).
    view_span = max(genome_length * 0.1, 1000)
    midpoint = genome_length / 2.0
    start = max(0, midpoint - view_span / 2.0)
    end = min(genome_length, midpoint + view_span / 2.0)

    # Deterministic, file-level sample order for stable tip node IDs across
    # trees: explicit file metadata wins, otherwise derive from all Newicks.
    samples_meta = file_metadata.get("samples", [])
    if samples_meta:
        samples_order = _dedupe_preserve_order(samples_meta)
    else:
        samples_order = sorted(str(s) for s in samples_set)

    # Temporary workaround: CSV inputs may include the outgroup sample "etal"
    # (any case). The layout/parsing pipeline prunes it from the Newick tree,
    # so also remove it here to keep UI/search options consistent.
    samples_order = [s for s in samples_order if str(s).lower() != "etal"]

    config: Dict[str, Any] = {
        "genome_length": genome_length,
        "initial_position": [int(start), int(end)],
        "times": {"type": "branch length", "values": [0.0, float(max_branch_length_all)]},
        "intervals": intervals,
        # basename handles both separators portably (vs. split("/")).
        "filename": os.path.basename(str(file_path)),
        "sample_names": {str(s): {"sample_name": s} for s in samples_order},
        "samples": samples_order,
        "metadata_schema": {
            # CSV metadata Socket.IO handlers currently support only
            # key == "sample", so expose it as the sole metadata key.
            "metadata_keys": ["sample"]
        },
    }

    if tree_info_map:
        # Kept as "tree_info" regardless of the source column spelling.
        config["tree_info"] = tree_info_map

    return config
lorax/csv/layout.py ADDED
@@ -0,0 +1,182 @@
1
+ from __future__ import annotations
2
+
3
+ import struct
4
+ from typing import Any, Dict, List
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+ import pyarrow as pa
9
+
10
+
11
def build_empty_tree_layout_arrow_ipc() -> bytes:
    """Return a valid Arrow IPC stream holding an empty tree-layout table.

    The schema matches what the frontend expects:
    node_id:int32, parent_id:int32, is_tip:bool, tree_idx:int32,
    x:float32, y:float32, name:string.
    """
    column_types = {
        "node_id": pa.int32(),
        "parent_id": pa.int32(),
        "is_tip": pa.bool_(),
        "tree_idx": pa.int32(),
        "x": pa.float32(),
        "y": pa.float32(),
        "name": pa.string(),
    }
    # Dict insertion order fixes the schema column order.
    table = pa.table({col: pa.array([], type=t) for col, t in column_types.items()})
    sink = pa.BufferOutputStream()
    with pa.ipc.new_stream(sink, table.schema) as writer:
        writer.write_table(table)
    return sink.getvalue().to_pybytes()
38
+
39
+
40
def build_empty_layout_response(tree_indices: List[int] | None = None) -> Dict[str, Any]:
    """Build the response shape of `handle_tree_graph_query` for CSV, with no nodes."""
    indices = [] if not tree_indices else tree_indices
    return {
        "buffer": build_empty_tree_layout_arrow_ipc(),
        "global_min_time": 0.0,
        "global_max_time": 0.0,
        "tree_indices": indices,
    }
48
+
49
+
50
def _table_to_ipc_bytes(table: "pa.Table") -> bytes:
    """Serialize a PyArrow table to Arrow IPC stream bytes."""
    sink = pa.BufferOutputStream()
    writer = pa.ipc.new_stream(sink, table.schema)
    writer.write_table(table)
    writer.close()
    return sink.getvalue().to_pybytes()


def _row_max_branch_length(df: pd.DataFrame, tree_idx: int) -> float | None:
    """Best-effort per-tree max height from the 'max_branch_length' column.

    Returns None when the column is absent, the cell is empty/NaN, or the
    value cannot be parsed as a float.
    """
    if "max_branch_length" not in df.columns:
        return None
    v = df.iloc[tree_idx].get("max_branch_length")
    if v is None or (isinstance(v, float) and pd.isna(v)) or str(v).strip() == "":
        return None
    try:
        return float(v)
    except (TypeError, ValueError):
        return None


def build_csv_layout_response(
    df: pd.DataFrame,
    tree_indices: List[int],
    max_branch_length: float,
    samples_order: List[str] | None = None,
    pre_parsed_graphs: Dict[int, Any] | None = None,
    shift_tips_to_one: bool = False,
) -> Dict[str, Any]:
    """Build the PyArrow IPC buffer for CSV trees.

    Parses Newick strings from the DataFrame and generates the same buffer
    format as construct_trees_batch() for tskit files: a little-endian uint32
    length prefix, the node table IPC stream, then the (always empty for CSV)
    mutation table IPC stream.

    Args:
        df: DataFrame with 'newick' column containing Newick tree strings
        tree_indices: List of row indices (global_index) to process
        max_branch_length: From config times.values[1], for y normalization
        samples_order: Optional file-level sample order for stable tip node IDs
        pre_parsed_graphs: Optional cache of parsed graphs keyed by tree index
        shift_tips_to_one: Forwarded to the Newick parser

    Returns:
        Same format as construct_trees_batch():
        {buffer, global_min_time, global_max_time, tree_indices}
    """
    from lorax.csv.newick_tree import parse_newick_to_tree

    # Per-column accumulators across all requested trees.
    all_node_ids: List[np.ndarray] = []
    all_parent_ids: List[np.ndarray] = []
    all_is_tip: List[np.ndarray] = []
    all_tree_idx: List[np.ndarray] = []
    all_x: List[np.ndarray] = []
    all_y: List[np.ndarray] = []
    all_names: List[np.ndarray] = []

    processed_indices: List[int] = []

    for tree_idx in tree_indices:
        tree_idx = int(tree_idx)
        if tree_idx < 0 or tree_idx >= len(df):
            continue

        newick_str = df.iloc[tree_idx].get("newick")
        # Skip missing/blank trees (e.g. metadata-only rows) silently instead
        # of letting the parser fail on an empty string.
        if newick_str is None or pd.isna(newick_str) or str(newick_str).strip() == "":
            continue

        graph = pre_parsed_graphs.get(tree_idx) if pre_parsed_graphs else None
        if graph is None:
            try:
                graph = parse_newick_to_tree(
                    str(newick_str),
                    max_branch_length,
                    samples_order=samples_order,
                    tree_max_branch_length=_row_max_branch_length(df, tree_idx),
                    shift_tips_to_one=shift_tips_to_one,
                )
            except Exception as e:
                # Log error but continue with other trees
                print(f"Failed to parse Newick for tree {tree_idx}: {e}")
                continue

        n = len(graph.node_id)
        if n == 0:
            continue

        # Collect nodes for this tree.
        # SWAP coordinates to match tskit convention: time -> x, layout -> y
        all_node_ids.append(graph.node_id)
        all_parent_ids.append(graph.parent_id)
        all_is_tip.append(graph.is_tip)
        all_tree_idx.append(np.full(n, tree_idx, dtype=np.int32))
        all_x.append(graph.y.astype(np.float32))  # SWAP: time -> x
        all_y.append(graph.x.astype(np.float32))  # SWAP: layout -> y
        all_names.append(np.array(graph.name, dtype=object))

        processed_indices.append(tree_idx)

    if not all_node_ids:
        # Nothing parsed: return a valid empty buffer for the requested trees.
        return build_empty_layout_response(list(tree_indices))

    # Concatenate all per-tree arrays into a single node table.
    node_table = pa.table(
        {
            "node_id": pa.array(np.concatenate(all_node_ids), type=pa.int32()),
            "parent_id": pa.array(np.concatenate(all_parent_ids), type=pa.int32()),
            "is_tip": pa.array(np.concatenate(all_is_tip), type=pa.bool_()),
            "tree_idx": pa.array(np.concatenate(all_tree_idx), type=pa.int32()),
            "x": pa.array(np.concatenate(all_x), type=pa.float32()),
            "y": pa.array(np.concatenate(all_y), type=pa.float32()),
            "name": pa.array(np.concatenate(all_names), type=pa.string()),
        }
    )

    # Empty mutation table (CSV has no mutations)
    mut_table = pa.table(
        {
            "mut_x": pa.array([], type=pa.float32()),
            "mut_y": pa.array([], type=pa.float32()),
            "mut_tree_idx": pa.array([], type=pa.int32()),
        }
    )

    node_bytes = _table_to_ipc_bytes(node_table)
    mut_bytes = _table_to_ipc_bytes(mut_table)

    # Combine with length prefix (same format as tree_graph.py:559)
    combined = struct.pack("<I", len(node_bytes)) + node_bytes + mut_bytes

    return {
        "buffer": combined,
        "global_min_time": 0.0,
        "global_max_time": float(max_branch_length),
        "tree_indices": processed_indices,
    }
@@ -0,0 +1,234 @@
1
+ """Parse Newick strings and compute x,y layout coordinates.
2
+
3
+ This module provides tree layout computation for CSV files containing Newick strings.
4
+ It mirrors the approach in tree_graph/tree_graph.py but sources data from Newick parsing
5
+ instead of tskit tables.
6
+
7
+ Coordinate system:
8
+ - y (time): anchored time in [0,1] where tips are always 1.0 and the root is
9
+ 1 - (tree_height / global_max_height)
10
+ - x (layout): tips get sequential x, internal nodes get (min+max)/2 of children
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ from dataclasses import dataclass
16
+ from typing import List, Optional
17
+
18
+ import numpy as np
19
+
20
+ # ete3 is used for Newick parsing
21
+ from ete3 import Tree
22
+
23
+
24
def prune_outgroup_sample(tree: Tree, outgroup: str = "etal") -> None:
    """Prune an outgroup leaf from an ete3 Tree (best-effort).

    This is a temporary compatibility workaround for CSV Newick inputs that include
    a known outgroup sample (e.g. "etal") that should not be displayed.

    Behavior:
    - If outgroup is missing, do nothing.
    - Otherwise, reroot on outgroup, delete it, and clear the root branch lengths

    Args:
        tree: ete3 Tree; modified in place.
        outgroup: Leaf name to remove; matched case-insensitively.
    """
    try:
        leaf_names = tree.get_leaf_names()
        if not leaf_names:
            return

        # Match case-insensitively (e.g. "etal" vs "Etal")
        target = None
        out_l = str(outgroup).lower()
        for name in leaf_names:
            if str(name).lower() == out_l:
                target = str(name)
                break
        if target is None:
            return

        # Reroot on the outgroup first so deleting it leaves the remaining
        # topology intact; `tree & target` is ete3's lookup-node-by-name operator.
        tree.set_outgroup(target)
        (tree & target).delete()

        # Remove root branch distortion after rerooting/pruning
        tree.dist = 0.0
        for child in tree.get_children():
            child.dist = 0.0
    except Exception:
        # Best-effort: CSV loading should remain resilient.
        return
59
+
60
+
61
@dataclass
class NewickTreeGraph:
    """Tree structure from parsed Newick with layout coordinates.

    All arrays are parallel and indexed by post-order traversal position.

    Attributes:
        node_id: Assigned node IDs. When a file-level sample order is supplied
            to the parser, tip IDs come from that order and internal-node IDs
            are numbered after them; otherwise IDs follow post-order position.
        parent_id: Parent node ID (-1 for root)
        is_tip: Boolean array indicating leaf nodes
        name: Leaf names from the Newick string
        branch_length: Distance to parent
        x: Layout position [0,1] - tips spread, internal = (min+max)/2 of children
        y: Normalized time [0,1] - cumulative distance from root
    """

    node_id: np.ndarray  # int32
    parent_id: np.ndarray  # int32
    is_tip: np.ndarray  # bool
    name: List[str]  # leaf names
    branch_length: np.ndarray  # float32
    x: np.ndarray  # float32, layout position
    y: np.ndarray  # float32, normalized time
82
+
83
+
84
def shift_tree_tips_to_one(y: np.ndarray, is_tip: np.ndarray) -> np.ndarray:
    """Shift y so the maximum tip height is exactly 1.0 (clipped to [0,1])."""
    # Nothing to do without coordinates or tip flags.
    if y.size == 0 or is_tip is None or is_tip.size == 0:
        return y
    tip_max = float(y[is_tip].max()) if np.any(is_tip) else 1.0
    # Already anchored, or max is unusable (NaN/inf): leave unchanged.
    if tip_max == 1.0 or not np.isfinite(tip_max):
        return y
    offset = 1.0 - tip_max
    return np.clip(y + offset, 0.0, 1.0)
94
+
95
+
96
def parse_newick_to_tree(
    newick_str: str,
    max_branch_length: float,
    samples_order: Optional[List[str]] = None,
    *,
    tree_max_branch_length: float | None = None,
    shift_tips_to_one: bool = False,
) -> NewickTreeGraph:
    """Parse Newick string and compute x,y coordinates.

    Uses ete3 to parse the Newick string, then computes layout coordinates
    using the same algorithm as tree_graph.py:
    - y: cumulative distance from root, normalized by max_branch_length
    - x: post-order layout where tips get sequential x, internals get (min+max)/2

    Args:
        newick_str: Newick format tree string
        max_branch_length: Global max height for y anchoring (from config times.values[1])
        samples_order: Optional file-level sample order; when given, tip node
            IDs are taken from this order and every leaf name must appear in it.
        tree_max_branch_length: Optional per-tree max height. If provided, this is
            used to anchor the root at 1 - (tree_height / global_max). If omitted,
            tree height is derived from the parsed tree (max cumulative root distance).
        shift_tips_to_one: If True, shift y so the highest tip sits exactly at 1.0.

    Returns:
        NewickTreeGraph with layout coordinates normalized to [0,1]

    Raises:
        ValueError: If samples_order is given and a leaf name is missing from it.
    """
    # Parse with ete3 - format=1 includes branch lengths and internal node names
    tree = Tree(newick_str, format=1)
    # Prune the outgroup before any layout or time computation.
    # TODO: remove in feature.
    prune_outgroup_sample(tree, outgroup="etal")

    # Assign traversal order via post-order traversal (after pruning)
    nodes = list(tree.traverse("postorder"))
    if not nodes:
        # Defensive: return a fully-empty graph rather than crash on a
        # degenerate parse.
        empty_i32 = np.array([], dtype=np.int32)
        empty_bool = np.array([], dtype=np.bool_)
        empty_f32 = np.array([], dtype=np.float32)
        return NewickTreeGraph(
            node_id=empty_i32,
            parent_id=empty_i32,
            is_tip=empty_bool,
            name=[],
            branch_length=empty_f32,
            x=empty_f32,
            y=empty_f32,
        )
    # Arrays below are indexed by post-order position, not by assigned ID.
    node_index = {node: idx for idx, node in enumerate(nodes)}

    num_nodes = len(nodes)

    # Build arrays
    node_id = np.arange(num_nodes, dtype=np.int32)
    parent_id = np.full(num_nodes, -1, dtype=np.int32)
    is_tip = np.zeros(num_nodes, dtype=np.bool_)
    branch_length = np.zeros(num_nodes, dtype=np.float32)
    name = [""] * num_nodes

    # ID assignment: with samples_order, tips get the file-level sample index
    # (stable tip IDs across trees) and internal nodes are numbered after
    # them; without it, IDs simply follow post-order position.
    assigned_ids = {}
    next_id = 0
    if samples_order:
        sample_id_map = {str(name): idx for idx, name in enumerate(samples_order)}
        for node in nodes:
            if node.is_leaf():
                node_name = node.name if node.name else ""
                if node_name not in sample_id_map:
                    raise ValueError(
                        f"Leaf sample name '{node_name}' not found in samples_order. "
                        "Ensure CSV config provides a complete, file-level samples list."
                    )
                assigned_ids[node] = sample_id_map[node_name]
        next_id = len(samples_order)
        for node in nodes:
            if not node.is_leaf():
                assigned_ids[node] = next_id
                next_id += 1
    else:
        for node in nodes:
            assigned_ids[node] = next_id
            next_id += 1

    for node in nodes:
        idx = node_index[node]
        if node.up is not None:
            # parent_id stores the *assigned* ID, not the array position.
            parent_id[idx] = assigned_ids[node.up]
        is_tip[idx] = node.is_leaf()
        # NOTE(review): falsy dist (0 or None) maps to 0.0 here; ete3 defaults
        # unspecified branch lengths to 1.0 — confirm that is intended.
        branch_length[idx] = node.dist if node.dist else 0.0
        name[idx] = node.name if node.name else ""
        node_id[idx] = assigned_ids[node]

    # Compute y_raw (time): cumulative distance from root
    # Root at 0, tips at tree_height
    y_raw = np.zeros(num_nodes, dtype=np.float32)
    for node in tree.traverse("preorder"):  # Root first
        idx = node_index[node]
        if node.up is not None:
            parent_idx = node_index[node.up]
            y_raw[idx] = y_raw[parent_idx] + branch_length[idx]

    # Anchor time so tips are always 1.0 and the root is 1 - tree_height/global_max.
    #
    # If the per-tree max height is known (e.g. CSV 'max_branch_length' column), use
    # it; otherwise derive it from the parsed tree.
    tree_height = float(tree_max_branch_length) if (tree_max_branch_length or 0.0) > 0 else float(y_raw.max())
    if max_branch_length > 0:
        y = 1.0 - (tree_height - y_raw) / float(max_branch_length)
        y = np.clip(y, 0.0, 1.0)
    else:
        # Degenerate case: no global height information. Keep tips at 1.
        y = np.ones(num_nodes, dtype=np.float32)

    if shift_tips_to_one:
        y = shift_tree_tips_to_one(y, is_tip)

    # Compute x (layout): tips get sequential x, internals = (min+max)/2
    x = np.zeros(num_nodes, dtype=np.float32)
    tip_counter = 0

    for node in tree.traverse("postorder"):
        idx = node_index[node]
        if node.is_leaf():
            x[idx] = tip_counter
            tip_counter += 1
        else:
            # Children are already placed (post-order), so center the parent.
            child_xs = [x[node_index[child]] for child in node.children]
            x[idx] = (min(child_xs) + max(child_xs)) / 2.0

    # Normalize x to [0,1]
    if tip_counter > 1:
        x = x / (tip_counter - 1)

    return NewickTreeGraph(
        node_id=node_id,
        parent_id=parent_id,
        is_tip=is_tip,
        name=name,
        branch_length=branch_length,
        x=x,
        y=y,
    )